|
|
|
@ -264,11 +264,12 @@ class Cell:
|
|
|
|
|
logger.info("layout dict does not contain the key %s", key)
|
|
|
|
|
continue
|
|
|
|
|
if self.parameters_dict()[key].sliced:
|
|
|
|
|
logger.info("Param %s is from initializer, already sliced.", key)
|
|
|
|
|
logger.info("Param %s is already sliced.", key)
|
|
|
|
|
continue
|
|
|
|
|
layout = self.parameter_layout_dict[key]
|
|
|
|
|
new_tensor = _load_tensor_by_layout(tensor, layout)
|
|
|
|
|
self.parameters_dict()[key].set_parameter_data(new_tensor)
|
|
|
|
|
self.parameters_dict()[key].sliced = True
|
|
|
|
|
elif isinstance(params, OrderedDict):
|
|
|
|
|
for key in params:
|
|
|
|
|
tensor = params[key].data
|
|
|
|
@ -276,11 +277,12 @@ class Cell:
|
|
|
|
|
logger.info("layout dict does not contain the key %s", key)
|
|
|
|
|
continue
|
|
|
|
|
if params[key].sliced:
|
|
|
|
|
logger.info("Param %s is from initializer, already sliced.", key)
|
|
|
|
|
logger.info("Param %s is already sliced.", key)
|
|
|
|
|
continue
|
|
|
|
|
layout = self.parameter_layout_dict[key]
|
|
|
|
|
new_tensor = _load_tensor_by_layout(tensor, layout)
|
|
|
|
|
params[key].set_parameter_data(new_tensor)
|
|
|
|
|
params[key].sliced = True
|
|
|
|
|
else:
|
|
|
|
|
raise TypeError('Parameters need OrderedDict type, but got {}'.
|
|
|
|
|
format(type(params)))
|
|
|
|
@ -435,14 +437,17 @@ class Cell:
|
|
|
|
|
"""
|
|
|
|
|
raise NotImplementedError
|
|
|
|
|
|
|
|
|
|
def init_parameters_data(self, recurse=True):
|
|
|
|
|
def init_parameters_data(self, recurse=True, auto_parallel_mode=False):
|
|
|
|
|
"""Init parameters' data."""
|
|
|
|
|
for param in self.get_parameters(expand=recurse):
|
|
|
|
|
if param.name not in self.parameter_layout_dict:
|
|
|
|
|
logger.info("Layout dict does not contain the key %s.", param.name)
|
|
|
|
|
if not auto_parallel_mode:
|
|
|
|
|
param.init_data()
|
|
|
|
|
elif param.name not in self.parameter_layout_dict:
|
|
|
|
|
logger.info("Layout dict does not contain the key %s.", param.name)
|
|
|
|
|
param.init_data(set_sliced=True)
|
|
|
|
|
else:
|
|
|
|
|
layout = self.parameter_layout_dict[param.name]
|
|
|
|
|
param.init_data(layout)
|
|
|
|
|
param.init_data(layout, set_sliced=True)
|
|
|
|
|
|
|
|
|
|
def parameters_dict(self, recurse=True):
|
|
|
|
|
"""
|
|
|
|
|