diff --git a/mindspore/train/serialization.py b/mindspore/train/serialization.py index 1b4bcdcb52..19bb90fab4 100644 --- a/mindspore/train/serialization.py +++ b/mindspore/train/serialization.py @@ -396,15 +396,17 @@ def _get_merged_param_data(net, param_name, param_data): Returns: Tensor, the combined tensor which with the whole data value. """ - layout = [] layout = net.parameter_layout_dict[param_name] - if len(layout) < 2: + if len(layout) < 5: logger.info("layout dict does not contain the key %s", param_name) return param_data dev_mat = layout[0] tensor_map = layout[1] field_size = layout[3] + uniform_split = layout[4] + if uniform_split[0] == 0: + raise RuntimeError("Save checkpoint only support uniform split tensor now.") from mindspore.parallel._cell_wrapper import get_allgather_cell from mindspore.parallel._tensor import _reshape_param_data, _reshape_param_data_with_weight