From 6affe1401c1073f58d6e220d3c21296a3fe38756 Mon Sep 17 00:00:00 2001 From: changzherui Date: Mon, 17 Aug 2020 23:39:36 +0800 Subject: [PATCH] add save checkpoint check during parallel --- mindspore/train/serialization.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/mindspore/train/serialization.py b/mindspore/train/serialization.py index 1b4bcdcb52..19bb90fab4 100644 --- a/mindspore/train/serialization.py +++ b/mindspore/train/serialization.py @@ -396,15 +396,17 @@ def _get_merged_param_data(net, param_name, param_data): Returns: Tensor, the combined tensor which with the whole data value. """ - layout = [] layout = net.parameter_layout_dict[param_name] - if len(layout) < 2: + if len(layout) < 5: logger.info("layout dict does not contain the key %s", param_name) return param_data dev_mat = layout[0] tensor_map = layout[1] field_size = layout[3] + uniform_split = layout[4] + if uniform_split[0] == 0: + raise RuntimeError("Save checkpoint only support uniform split tensor now.") from mindspore.parallel._cell_wrapper import get_allgather_cell from mindspore.parallel._tensor import _reshape_param_data, _reshape_param_data_with_weight