|
|
|
@ -98,7 +98,7 @@ class DataLayerV2(Layer):
|
|
|
|
|
|
|
|
|
|
super(DataLayerV2, self).__init__(name=name, parent_layers=dict())
|
|
|
|
|
|
|
|
|
|
def to_proto_impl(self, **kwargs):
|
|
|
|
|
def to_proto_impl(self, context=None, **kwargs):
|
|
|
|
|
args = dict()
|
|
|
|
|
args['size'] = self.type.dim
|
|
|
|
|
for each in kwargs:
|
|
|
|
@ -142,46 +142,11 @@ class WithExtraParent(Layer):
|
|
|
|
|
else:
|
|
|
|
|
return context[self.name]
|
|
|
|
|
|
|
|
|
|
# parse parents
|
|
|
|
|
kwargs = dict()
|
|
|
|
|
# parse extra_parent
|
|
|
|
|
for p in self.__extra_parent__:
|
|
|
|
|
p.to_proto(context=context)
|
|
|
|
|
|
|
|
|
|
for layer_name in self.__parent_layers__:
|
|
|
|
|
if not isinstance(self.__parent_layers__[layer_name],
|
|
|
|
|
collections.Sequence):
|
|
|
|
|
v1_layer = self.__parent_layers__[layer_name].to_proto(
|
|
|
|
|
context=context)
|
|
|
|
|
else:
|
|
|
|
|
v1_layer = map(lambda x: x.to_proto(context=context),
|
|
|
|
|
self.__parent_layers__[layer_name])
|
|
|
|
|
kwargs[layer_name] = v1_layer
|
|
|
|
|
|
|
|
|
|
# parse self
|
|
|
|
|
if self.context_name() is None:
|
|
|
|
|
return self.to_proto_impl(context=context, **kwargs)
|
|
|
|
|
elif self.context_name() not in context:
|
|
|
|
|
context[self.context_name()] = self.to_proto_impl(
|
|
|
|
|
context=context, **kwargs)
|
|
|
|
|
|
|
|
|
|
# parse children.
|
|
|
|
|
aaa = self.__children_layers__
|
|
|
|
|
for layer, pnames in self.__children_layers__:
|
|
|
|
|
drop = False
|
|
|
|
|
|
|
|
|
|
# child will only be parsed if all parents are in context.
|
|
|
|
|
for pname in pnames:
|
|
|
|
|
if pname not in context:
|
|
|
|
|
drop = True
|
|
|
|
|
break
|
|
|
|
|
if drop:
|
|
|
|
|
continue
|
|
|
|
|
layer.to_proto(context=context)
|
|
|
|
|
|
|
|
|
|
if self.use_context_name():
|
|
|
|
|
return context[self.context_name()]
|
|
|
|
|
else:
|
|
|
|
|
return context[self.name]
|
|
|
|
|
return super(WithExtraParent, self).to_proto(context=context)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MemoryV2(WithExtraParent):
|
|
|
|
@ -307,7 +272,7 @@ class MixedLayerV2(Layer):
|
|
|
|
|
def __exit__(self, *args, **kwargs):
|
|
|
|
|
self.finalized = True
|
|
|
|
|
|
|
|
|
|
def to_proto_impl(self, **kwargs):
|
|
|
|
|
def to_proto_impl(self, context=None, **kwargs):
|
|
|
|
|
args = dict()
|
|
|
|
|
for each in kwargs:
|
|
|
|
|
args[each] = kwargs[each]
|
|
|
|
@ -371,7 +336,7 @@ class RecurrentLayerOutput(Layer):
|
|
|
|
|
def context_name(self):
|
|
|
|
|
return self.__recurrent_name__ + ".end"
|
|
|
|
|
|
|
|
|
|
def to_proto_impl(self, **kwargs):
|
|
|
|
|
def to_proto_impl(self, context=None, **kwargs):
|
|
|
|
|
for l in self.__parents__:
|
|
|
|
|
RecurrentLayerGroupSetOutLink(l.name)
|
|
|
|
|
RecurrentLayerGroupEnd(name=self.__recurrent_name__)
|
|
|
|
|