|
|
|
@ -235,20 +235,19 @@ class Layer(core.Layer):
|
|
|
|
|
else:
|
|
|
|
|
object.__delattr__(self, name)
|
|
|
|
|
|
|
|
|
|
def state_dict(self, destination=None, prefix='', include_sublayers=True):
|
|
|
|
|
def state_dict(self, destination=None, include_sublayers=True):
|
|
|
|
|
if destination is None:
|
|
|
|
|
destination = collections.OrderedDict()
|
|
|
|
|
for name, data in self._parameters.items():
|
|
|
|
|
if data is not None:
|
|
|
|
|
destination[prefix + name] = data
|
|
|
|
|
destination[data.name] = data
|
|
|
|
|
|
|
|
|
|
if include_sublayers:
|
|
|
|
|
for layer_name, layer_item in self._sub_layers.items():
|
|
|
|
|
if layer_item is not None:
|
|
|
|
|
destination_temp = destination.copy()
|
|
|
|
|
destination_temp.update(
|
|
|
|
|
layer_item.state_dict(destination_temp, prefix +
|
|
|
|
|
layer_name + ".",
|
|
|
|
|
layer_item.state_dict(destination_temp,
|
|
|
|
|
include_sublayers))
|
|
|
|
|
destination = destination_temp
|
|
|
|
|
return destination
|
|
|
|
|