|
|
|
@ -225,15 +225,6 @@ def load_param_into_net(net, parameter_dict):
|
|
|
|
|
raise TypeError(msg)
|
|
|
|
|
|
|
|
|
|
logger.info("Execute load parameter into net process.")
|
|
|
|
|
for name in parameter_dict:
|
|
|
|
|
for _, param in net.parameters_and_names():
|
|
|
|
|
if name == param.name and param.layerwise_parallel:
|
|
|
|
|
# layerwise parallel parameter data loaded from checkpoint file,
|
|
|
|
|
# was a complete(merged) data, need to be splited
|
|
|
|
|
new_param = parameter_dict[param.name]
|
|
|
|
|
_load_tensor_for_layerwise(new_param, param)
|
|
|
|
|
break
|
|
|
|
|
|
|
|
|
|
param_not_load = []
|
|
|
|
|
for _, param in net.parameters_and_names():
|
|
|
|
|
if param.name in parameter_dict:
|
|
|
|
@ -363,34 +354,6 @@ def _get_merged_param_data(net, param_name, param_data):
|
|
|
|
|
return param_data
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _load_tensor_for_layerwise(new_param, old_param):
    """
    Replaces parameters with sliced tensors by layerwise parallel strategies.

    When the data of `new_param` and `old_param` already have the same shape,
    nothing is done. Otherwise the data of `new_param` is handed to
    `_load_tensor` together with a one-element device matrix holding the
    communication group size and a tensor map that is 0 for dimension 0 and -1
    for every other dimension; the returned tensor is then stored back into
    `new_param` through `set_parameter_data`.

    Args:
        new_param (Parameter): The new layerwise parallel parameter, will be loaded into net.
        old_param (Parameter): The current parameter in the net.

    Raises:
        TypeError: If the data of `new_param` or of `old_param` is not a Tensor.
    """
    if not isinstance(new_param.data, Tensor) or not isinstance(old_param.data, Tensor):
        logger.error("Failed to combine the net and the parameters.")
        msg = ("layerwise parallel parameter should be a Tensor, but got {}.".format(type(new_param.data)))
        raise TypeError(msg)

    # Shapes already agree, so there is nothing to replace.
    if old_param.data.shape() == new_param.data.shape():
        return

    # Imported at call time rather than module level, as in the original code.
    # NOTE(review): presumably this avoids a circular import -- confirm.
    from mindspore.parallel._tensor import _load_tensor
    from mindspore.communication.management import get_group_size

    dev_mat = [get_group_size()]
    shape = new_param.data.shape()
    # dim 0 set 0, others set -1. The map is built explicitly here so that
    # `tensor_map` is always defined before it is passed to `_load_tensor`.
    tensor_map = [0] + [-1] * (len(shape) - 1)

    new_tensor = _load_tensor(new_param.data, dev_mat, tensor_map)
    new_param.set_parameter_data(new_tensor)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _fill_param_into_net(net, parameter_list):
|
|
|
|
|
"""
|
|
|
|
|
Fills parameter_list into net.
|
|
|
|
|