|
|
|
@ -396,15 +396,17 @@ def _get_merged_param_data(net, param_name, param_data):
|
|
|
|
|
Returns:
|
|
|
|
|
Tensor, the combined tensor which with the whole data value.
|
|
|
|
|
"""
|
|
|
|
|
layout = []
|
|
|
|
|
layout = net.parameter_layout_dict[param_name]
|
|
|
|
|
if len(layout) < 2:
|
|
|
|
|
if len(layout) < 5:
|
|
|
|
|
logger.info("layout dict does not contain the key %s", param_name)
|
|
|
|
|
return param_data
|
|
|
|
|
|
|
|
|
|
dev_mat = layout[0]
|
|
|
|
|
tensor_map = layout[1]
|
|
|
|
|
field_size = layout[3]
|
|
|
|
|
uniform_split = layout[4]
|
|
|
|
|
if uniform_split[0] == 0:
|
|
|
|
|
raise RuntimeError("Save checkpoint only support uniform split tensor now.")
|
|
|
|
|
|
|
|
|
|
from mindspore.parallel._cell_wrapper import get_allgather_cell
|
|
|
|
|
from mindspore.parallel._tensor import _reshape_param_data, _reshape_param_data_with_weight
|
|
|
|
|