|
|
|
|
@ -140,10 +140,13 @@ class Layer(object):
|
|
|
|
|
if self.name is None:
|
|
|
|
|
return self.to_proto_impl(**kwargs)
|
|
|
|
|
elif isinstance(self, MemoryV2):
|
|
|
|
|
return self.to_proto_impl(**kwargs)
|
|
|
|
|
elif self.name not in context:
|
|
|
|
|
context[self.name] = self.to_proto_impl(**kwargs)
|
|
|
|
|
name = self.name + "#__memory__"
|
|
|
|
|
if name not in context:
|
|
|
|
|
context[name] = self.to_proto_impl(**kwargs)
|
|
|
|
|
return context[name]
|
|
|
|
|
|
|
|
|
|
if self.name not in context:
|
|
|
|
|
context[self.name] = self.to_proto_impl(**kwargs)
|
|
|
|
|
return context[self.name]
|
|
|
|
|
|
|
|
|
|
def to_proto_impl(self, **kwargs):
|
|
|
|
|
@ -256,9 +259,32 @@ class LayerOutputV2(Layer):
|
|
|
|
|
return self.layer_output
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class StaticInputV2(Layer):
|
|
|
|
|
def __init__(self, **kwargs):
|
|
|
|
|
self.__parent_names__ = ['input']
|
|
|
|
|
other_kwargs = dict()
|
|
|
|
|
parent_layers = dict()
|
|
|
|
|
for pname in self.__parent_names__:
|
|
|
|
|
if kwargs.has_key(pname):
|
|
|
|
|
parent_layers[pname] = kwargs[pname]
|
|
|
|
|
for key in kwargs.keys():
|
|
|
|
|
if key not in self.__parent_names__:
|
|
|
|
|
other_kwargs[key] = kwargs[key]
|
|
|
|
|
self.__kwargs__ = other_kwargs
|
|
|
|
|
super(StaticInputV2, self).__init__(parent_layers=parent_layers)
|
|
|
|
|
|
|
|
|
|
def to_proto_impl(self, **kwargs):
|
|
|
|
|
args = dict()
|
|
|
|
|
for each in kwargs:
|
|
|
|
|
args[each] = kwargs[each]
|
|
|
|
|
for each in self.__kwargs__:
|
|
|
|
|
args[each] = self.__kwargs__[each]
|
|
|
|
|
return conf_helps.StaticInput(**args)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class RecurrentGroupV2(Layer):
|
|
|
|
|
def __init__(self, name, **kwargs):
|
|
|
|
|
self.__parent_names__ = ['input']
|
|
|
|
|
self.__parent_names__ = ['input', 'boot_layer']
|
|
|
|
|
other_kwargs = dict()
|
|
|
|
|
parent_layers = dict()
|
|
|
|
|
for pname in self.__parent_names__:
|
|
|
|
|
@ -443,7 +469,8 @@ layer_list = [
|
|
|
|
|
['nce', 'nce_layer', ['input', 'label']],
|
|
|
|
|
['hsigmoid', 'hsigmoid', ['input', 'label']],
|
|
|
|
|
# check layers
|
|
|
|
|
['eos', 'eos_layer', ['input']]
|
|
|
|
|
['eos', 'eos_layer', ['input']],
|
|
|
|
|
['gru_step_layer', 'gru_step_layer', ['input', 'output_mem']]
|
|
|
|
|
]
|
|
|
|
|
for l in layer_list:
|
|
|
|
|
globals()[l[0]] = __convert_to_v2__(l[1], l[2])
|
|
|
|
|
|