!14694 modify load parallel ckpt

From: @changzherui
Reviewed-by: @kingxian,@zhoufeng54
Signed-off-by: @kingxian
pull/14694/MERGE
Committed by mindspore-ci-bot via Gitee
commit 68676a0d99

@@ -192,7 +192,6 @@ void IrExportBuilder::BuildParameters(const FuncGraphPtr &func_graph, mind_ir::G
       mind_ir::TensorProto *parameter_proto = graph_proto->add_parameter();
       parameter_proto->set_name(param_name);
       SetParamToTensorProto(param, parameter_proto);
-      auto tensor = std::dynamic_pointer_cast<tensor::Tensor>(param->default_param());
     } else {
       mind_ir::ValueInfoProto *input_proto = graph_proto->add_input();
       input_proto->set_name(param_name);

@@ -183,7 +183,7 @@ class DatasetHelper:
         >>> train_dataset = create_custom_dataset()
         >>> set_helper = DatasetHelper(train_dataset, dataset_sink_mode=False)
         >>> for next_element in set_helper:
-        >>>     print(next_element)
+        ...     print(next_element)
     """
     def __init__(self, dataset, dataset_sink_mode=True, sink_size=-1, epoch_num=1):
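
The corrected doctest above switches the loop body to the `...` continuation prompt. Below is a minimal sketch of the same non-sink iteration pattern; the tiny in-memory dataset stands in for the docstring's placeholder `create_custom_dataset()`, and the `mindspore.train.dataset_helper` import path is an assumption that may differ between MindSpore versions.

# Minimal sketch: iterating a dataset through DatasetHelper with dataset_sink_mode=False.
# NumpySlicesDataset stands in for create_custom_dataset() (assumption).
import numpy as np
import mindspore.dataset as ds
from mindspore.train.dataset_helper import DatasetHelper  # import path may vary by version

data = {"x": np.arange(6, dtype=np.float32).reshape(3, 2)}
train_dataset = ds.NumpySlicesDataset(data, shuffle=False).batch(1)

set_helper = DatasetHelper(train_dataset, dataset_sink_mode=False)
for next_element in set_helper:
    print(next_element)  # each element holds the Tensors of one batch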

@@ -889,7 +889,6 @@ def _merge_param_with_strategy(sliced_data, parameter_name, strategy, is_even):
         raise ValueError(f"The sliced_parameters length should be equal to device_count. "
                          f"the sliced_parameters length is {len(sliced_data)} but device_count is {device_count}.")
-    merged_tensor = None
     if not param_split_shape:
         if not is_even:
             raise ValueError("The shape of every parameter in sliced_parameters should be the same "
@@ -1052,7 +1051,6 @@ def merge_sliced_parameter(sliced_parameters, strategy=None):
     layerwise_parallel = sliced_parameters[0].layerwise_parallel
     requires_grad = sliced_parameters[0].requires_grad
     sliced_data = [parameter.data.asnumpy() for parameter in sliced_parameters]
-    merged_parameter = None
     if not strategy:
         merged_tensor = Tensor(np.concatenate(sliced_data))
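
For context on the hunk above: when no strategy is passed, merge_sliced_parameter simply concatenates the numpy data of every slice along axis 0 and wraps the result back into a Tensor, so the removed `merged_parameter = None` initialization was dead code. A minimal sketch of that no-strategy path, with hypothetical slice values and parameter name:

# Sketch of the no-strategy merge path (hypothetical data and parameter name).
import numpy as np
from mindspore import Tensor, Parameter
from mindspore.train.serialization import merge_sliced_parameter

sliced_parameters = [
    Parameter(Tensor(np.array([0.0, 1.0], np.float32)), "network.weight"),
    Parameter(Tensor(np.array([2.0, 3.0], np.float32)), "network.weight"),
]
merged = merge_sliced_parameter(sliced_parameters)  # concatenates slices along axis 0
print(merged.data.asnumpy())  # expected: [0. 1. 2. 3.]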
@@ -1121,7 +1119,7 @@ def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=
         param_rank = rank_list[param.name][0]
         skip_merge_split = rank_list[param.name][1]
         for rank in param_rank:
-            sliced_param = _load_single_param(checkpoint_filenames[rank], param.name)
+            sliced_param = load_checkpoint(checkpoint_filenames[rank])[param.name]
             sliced_params.append(sliced_param)
         if skip_merge_split:
             split_param = sliced_params[0]
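
The hunk above replaces the private `_load_single_param` helper (deleted in the next hunk) with the public `load_checkpoint`, which returns a dict mapping parameter names to Parameter objects, so fetching one sliced parameter becomes a plain dictionary lookup. A minimal sketch of that per-rank pattern; the checkpoint paths and parameter name are hypothetical:

# Sketch of the new per-rank load path (hypothetical checkpoint paths and parameter name).
from mindspore.train.serialization import load_checkpoint

checkpoint_filenames = ["rank_0/net.ckpt", "rank_1/net.ckpt"]
sliced_params = []
for rank in (0, 1):
    param_dict = load_checkpoint(checkpoint_filenames[rank])  # {name: Parameter}
    sliced_params.append(param_dict["conv1.weight"])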
@@ -1213,59 +1211,3 @@ def _merge_and_split(sliced_params, train_strategy, predict_strategy):
     layerwise_parallel = merged_param.layerwise_parallel
     split_param = Parameter(split_tensor, param_name, requires_grad, layerwise_parallel)
     return split_param
-
-
-def _load_single_param(ckpt_file_name, param_name):
-    """Load a parameter from checkpoint."""
-    checkpoint_list = Checkpoint()
-    try:
-        with open(ckpt_file_name, "rb") as f:
-            pb_content = f.read()
-        checkpoint_list.ParseFromString(pb_content)
-    except BaseException as e:
-        logger.error("Failed to read the checkpoint file `%s` during load single parameter,"
-                     " please check the correct of the file.", ckpt_file_name)
-        raise ValueError(e.__str__())
-
-    parameter = None
-    try:
-        param_data_list = []
-        for element_id, element in enumerate(checkpoint_list.value):
-            if element.tag != param_name:
-                continue
-            data = element.tensor.tensor_content
-            data_type = element.tensor.tensor_type
-            np_type = tensor_to_np_type[data_type]
-            ms_type = tensor_to_ms_type[data_type]
-            element_data = np.frombuffer(data, np_type)
-            param_data_list.append(element_data)
-            if (element_id == len(checkpoint_list.value) - 1) or \
-                    (element.tag != checkpoint_list.value[element_id + 1].tag):
-                param_data = np.concatenate((param_data_list), axis=0)
-                param_data_list.clear()
-                dims = element.tensor.dims
-                if dims == [0]:
-                    if 'Float' in data_type:
-                        param_data = float(param_data[0])
-                    elif 'Int' in data_type:
-                        param_data = int(param_data[0])
-                    parameter = Parameter(Tensor(param_data, ms_type), name=element.tag)
-                elif dims == [1]:
-                    parameter = Parameter(Tensor(param_data, ms_type), name=element.tag)
-                else:
-                    param_dim = []
-                    for dim in dims:
-                        param_dim.append(dim)
-                    param_value = param_data.reshape(param_dim)
-                    parameter = Parameter(Tensor(param_value, ms_type), name=element.tag)
-                break
-    except BaseException as e:
-        logger.error("Failed to load the checkpoint file `%s`.", ckpt_file_name)
-        raise RuntimeError(e.__str__())
-
-    if parameter is None:
-        raise ValueError(f"There is no parameter named {param_name} in this checkpoint file {ckpt_file_name}, "
-                         f"please check parameter name or checkpoint file.")
-    return parameter
