From e7a24611f42646245dde3710b448bb4473b5848e Mon Sep 17 00:00:00 2001
From: Ziyan
Date: Wed, 9 Dec 2020 21:34:27 +0800
Subject: [PATCH] fix distributed predict

---
 mindspore/parallel/_utils.py     | 26 ++++++++++++++++++++------
 mindspore/train/model.py         |  6 ++++--
 mindspore/train/serialization.py |  5 +++--
 3 files changed, 27 insertions(+), 10 deletions(-)

diff --git a/mindspore/parallel/_utils.py b/mindspore/parallel/_utils.py
index aaa9898ec0..e82ac20ee5 100644
--- a/mindspore/parallel/_utils.py
+++ b/mindspore/parallel/_utils.py
@@ -218,6 +218,11 @@ def _check_similar_layout(tensor_layout1, tensor_layout2):
     return True


+def _check_same_layout(tensor_layout1, tensor_layout2):
+    """check whether two tensor layouts are the same"""
+    return tensor_layout1[0] == tensor_layout2[0] and tensor_layout1[1] == tensor_layout2[1]
+
+
 def _remove_repeated_slices(tensor_layout):
     """generate unrepeated tensor layout"""
     import copy
@@ -236,9 +241,14 @@ def _infer_rank_list(train_map, predict_map=None):
     ret = {}
     for param_name in train_map:
         train_layout = train_map[param_name]
-        new_train_layout = _remove_repeated_slices(train_layout)
+        predict_layout = predict_map[param_name]
         train_dev_mat = train_layout[0]
         dev_num = np.array(train_dev_mat).prod()
+        if _check_same_layout(train_layout, predict_layout):
+            dev_rank = _get_global_rank()
+            ret[param_name] = ([dev_rank], True)
+            continue
+        new_train_layout = _remove_repeated_slices(train_layout)
         array = np.arange(dev_num).reshape(train_dev_mat)
         index = ()
         for i in new_train_layout[0]:
@@ -248,16 +258,20 @@ def _infer_rank_list(train_map, predict_map=None):
                 index = index + (slice(None),)
         rank_list = array[index].flatten()
         if not predict_map:
-            ret[param_name] = rank_list
+            ret[param_name] = (rank_list, False)
             continue
         if param_name not in predict_map:
             logger.warning("predict_map does not contain %s", param_name)
             continue
-        predict_layout = predict_map[param_name]
         # optimization pass
         if _check_similar_layout(train_layout, predict_layout):
-            dev_rank = _get_global_rank()
-            ret[param_name] = [rank_list[dev_rank]]
+            if len(rank_list) == 1:
+                ret[param_name] = (rank_list, True)
+            elif len(rank_list) == dev_num:
+                dev_rank = _get_global_rank()
+                ret[param_name] = ([rank_list[dev_rank]], True)
+            else:
+                ret[param_name] = (rank_list, False)
         else:
-            ret[param_name] = rank_list
+            ret[param_name] = (rank_list, False)
     return ret
diff --git a/mindspore/train/model.py b/mindspore/train/model.py
index a8ed30346b..ee2c18be46 100755
--- a/mindspore/train/model.py
+++ b/mindspore/train/model.py
@@ -746,18 +746,20 @@ class Model:
         """
         Generate parameter layout for the predict network in auto or semi auto parallel mode.

-        Data could be a single tensor, a list of tensor, or a tuple of tensor.
+        Data could be a single tensor or multiple tensors.

         Note:
             Batch data should be put together in one tensor.

         Args:
-            predict_data (Tensor): Tensor of predict data. can be array, list or tuple.
+            predict_data (Tensor): One tensor or multiple tensors of predict data.

         Returns:
             parameter_layout_dict (dict): Parameter layout dictionary used for load distributed checkpoint

         Examples:
+            >>> context.set_context(mode=context.GRAPH_MODE)
+            >>> context.set_auto_parallel_context(full_batch=True, parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
             >>> input_data = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]), mindspore.float32)
             >>> model = Model(Net())
             >>> model.infer_predict_layout(input_data)
diff --git a/mindspore/train/serialization.py b/mindspore/train/serialization.py
index 872f6d0b2d..546d85b075 100644
--- a/mindspore/train/serialization.py
+++ b/mindspore/train/serialization.py
@@ -950,11 +950,12 @@ def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=
         sliced_params = []
         if param.name not in rank_list.keys():
             continue
-        param_rank = rank_list[param.name]
+        param_rank = rank_list[param.name][0]
+        skip_merge_split = rank_list[param.name][1]
         for rank in param_rank:
             sliced_param = _load_single_param(checkpoint_filenames[rank], param.name)
             sliced_params.append(sliced_param)
-        if len(sliced_params) == 1:
+        if skip_merge_split:
             split_param = sliced_params[0]
         else:
             param_unique_strategy = _remove_repeated_slices(train_strategy[param.name])
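
Reviewer note: below is a minimal, self-contained sketch of the decision this patch adds to _infer_rank_list. The toy layouts and helper names are hypothetical illustrations, not MindSpore APIs; a layout is modelled as (device_matrix, tensor_map), mirroring the indices [0] and [1] that _check_same_layout compares. The point it shows: when the training and prediction layouts of a parameter are identical, the current rank only needs its own checkpoint slice, and the later merge/split step can be skipped, which is what the new (rank_list, skip_merge_split) tuple communicates.

    # Illustration only: toy layouts, not real MindSpore tensor-layout objects.
    import numpy as np

    def check_same_layout(layout_a, layout_b):
        # Same device matrix and same tensor map -> identical sharding.
        return layout_a[0] == layout_b[0] and layout_a[1] == layout_b[1]

    def infer_ranks(train_layout, predict_layout, global_rank):
        """Return (rank_list, skip_merge_split) in the spirit of the patched _infer_rank_list."""
        if check_same_layout(train_layout, predict_layout):
            # Identical sharding for training and prediction: this rank loads only
            # its own slice and no merge/split is needed afterwards.
            return [global_rank], True
        # Different sharding: slices from other training ranks are needed and must
        # be merged and re-split later (simplified here to "all training ranks").
        dev_num = int(np.prod(train_layout[0]))
        return list(range(dev_num)), False

    # 2x4 device matrix reused unchanged for prediction vs. a reshaped one.
    train = ([2, 4], [1, 0])
    print(infer_ranks(train, ([2, 4], [1, 0]), global_rank=3))  # ([3], True)
    print(infer_ranks(train, ([8], [0, -1]), global_rank=3))    # ([0, ..., 7], False)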
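
A matching sketch of how the loading side is meant to consume the flag; the load_one_slice and merge_and_split callables are hypothetical stand-ins for what serialization.py actually does with _load_single_param and the merge/split helpers:

    def load_param_slices(rank_list, skip_merge_split, load_one_slice, merge_and_split):
        # Mirrors the patched branch in load_distributed_checkpoint: gather the
        # requested slices, then either use the single slice directly or rebuild
        # the parameter for the predict layout.
        sliced_params = [load_one_slice(rank) for rank in rank_list]
        if skip_merge_split:
            return sliced_params[0]
        return merge_and_split(sliced_params)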