|
|
|
@ -889,7 +889,6 @@ def _merge_param_with_strategy(sliced_data, parameter_name, strategy, is_even):
|
|
|
|
|
raise ValueError(f"The sliced_parameters length should be equal to device_count. "
|
|
|
|
|
f"the sliced_parameters length is {len(sliced_data)} but device_count is {device_count}.")
|
|
|
|
|
|
|
|
|
|
merged_tensor = None
|
|
|
|
|
if not param_split_shape:
|
|
|
|
|
if not is_even:
|
|
|
|
|
raise ValueError("The shape of every parameter in sliced_parameters should be the same "
|
|
|
|
@ -1052,7 +1051,6 @@ def merge_sliced_parameter(sliced_parameters, strategy=None):
|
|
|
|
|
layerwise_parallel = sliced_parameters[0].layerwise_parallel
|
|
|
|
|
requires_grad = sliced_parameters[0].requires_grad
|
|
|
|
|
sliced_data = [parameter.data.asnumpy() for parameter in sliced_parameters]
|
|
|
|
|
merged_parameter = None
|
|
|
|
|
|
|
|
|
|
if not strategy:
|
|
|
|
|
merged_tensor = Tensor(np.concatenate(sliced_data))
|
|
|
|
@ -1121,7 +1119,7 @@ def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=
|
|
|
|
|
param_rank = rank_list[param.name][0]
|
|
|
|
|
skip_merge_split = rank_list[param.name][1]
|
|
|
|
|
for rank in param_rank:
|
|
|
|
|
sliced_param = _load_single_param(checkpoint_filenames[rank], param.name)
|
|
|
|
|
sliced_param = load_checkpoint(checkpoint_filenames[rank])[param.name]
|
|
|
|
|
sliced_params.append(sliced_param)
|
|
|
|
|
if skip_merge_split:
|
|
|
|
|
split_param = sliced_params[0]
|
|
|
|
@ -1213,59 +1211,3 @@ def _merge_and_split(sliced_params, train_strategy, predict_strategy):
|
|
|
|
|
layerwise_parallel = merged_param.layerwise_parallel
|
|
|
|
|
split_param = Parameter(split_tensor, param_name, requires_grad, layerwise_parallel)
|
|
|
|
|
return split_param
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _load_single_param(ckpt_file_name, param_name):
|
|
|
|
|
"""Load a parameter from checkpoint."""
|
|
|
|
|
checkpoint_list = Checkpoint()
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
with open(ckpt_file_name, "rb") as f:
|
|
|
|
|
pb_content = f.read()
|
|
|
|
|
checkpoint_list.ParseFromString(pb_content)
|
|
|
|
|
except BaseException as e:
|
|
|
|
|
logger.error("Failed to read the checkpoint file `%s` during load single parameter,"
|
|
|
|
|
" please check the correct of the file.", ckpt_file_name)
|
|
|
|
|
raise ValueError(e.__str__())
|
|
|
|
|
|
|
|
|
|
parameter = None
|
|
|
|
|
try:
|
|
|
|
|
param_data_list = []
|
|
|
|
|
for element_id, element in enumerate(checkpoint_list.value):
|
|
|
|
|
if element.tag != param_name:
|
|
|
|
|
continue
|
|
|
|
|
data = element.tensor.tensor_content
|
|
|
|
|
data_type = element.tensor.tensor_type
|
|
|
|
|
np_type = tensor_to_np_type[data_type]
|
|
|
|
|
ms_type = tensor_to_ms_type[data_type]
|
|
|
|
|
element_data = np.frombuffer(data, np_type)
|
|
|
|
|
param_data_list.append(element_data)
|
|
|
|
|
if (element_id == len(checkpoint_list.value) - 1) or \
|
|
|
|
|
(element.tag != checkpoint_list.value[element_id + 1].tag):
|
|
|
|
|
param_data = np.concatenate((param_data_list), axis=0)
|
|
|
|
|
param_data_list.clear()
|
|
|
|
|
dims = element.tensor.dims
|
|
|
|
|
if dims == [0]:
|
|
|
|
|
if 'Float' in data_type:
|
|
|
|
|
param_data = float(param_data[0])
|
|
|
|
|
elif 'Int' in data_type:
|
|
|
|
|
param_data = int(param_data[0])
|
|
|
|
|
parameter = Parameter(Tensor(param_data, ms_type), name=element.tag)
|
|
|
|
|
elif dims == [1]:
|
|
|
|
|
parameter = Parameter(Tensor(param_data, ms_type), name=element.tag)
|
|
|
|
|
else:
|
|
|
|
|
param_dim = []
|
|
|
|
|
for dim in dims:
|
|
|
|
|
param_dim.append(dim)
|
|
|
|
|
param_value = param_data.reshape(param_dim)
|
|
|
|
|
parameter = Parameter(Tensor(param_value, ms_type), name=element.tag)
|
|
|
|
|
break
|
|
|
|
|
|
|
|
|
|
except BaseException as e:
|
|
|
|
|
logger.error("Failed to load the checkpoint file `%s`.", ckpt_file_name)
|
|
|
|
|
raise RuntimeError(e.__str__())
|
|
|
|
|
|
|
|
|
|
if parameter is None:
|
|
|
|
|
raise ValueError(f"There is no parameter named {param_name} in this checkpoint file {ckpt_file_name}, "
|
|
|
|
|
f"please check parameter name or checkpoint file.")
|
|
|
|
|
return parameter
|
|
|
|
|