!9742 fix distribute predict

From: @gong_zi_yan
Reviewed-by: @stsuteng, @zhunaipan
Signed-off-by: @stsuteng
pull/9742/MERGE
Committed by mindspore-ci-bot (via Gitee), 4 years ago
commit 1ff38be903

@@ -218,6 +218,11 @@ def _check_similar_layout(tensor_layout1, tensor_layout2):
     return True
 
 
+def _check_same_layout(tensor_layout1, tensor_layout2):
+    """check if two tensor layouts are same"""
+    return tensor_layout1[0] == tensor_layout2[0] and tensor_layout1[1] == tensor_layout2[1]
+
+
 def _remove_repeated_slices(tensor_layout):
     """generate unrepeated tensor layout"""
     import copy
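
As a quick illustration (not part of the patch), the new helper treats two layouts as "the same" only when both the device matrix and the tensor map match; a runnable sketch with hypothetical layout tuples of the form (device_matrix, tensor_map), the helper copied from the hunk above so it runs standalone:

def _check_same_layout(tensor_layout1, tensor_layout2):
    """check if two tensor layouts are same"""
    return tensor_layout1[0] == tensor_layout2[0] and tensor_layout1[1] == tensor_layout2[1]

# Hypothetical layouts: device matrix [4, 2] over 8 devices, tensor map [1, 0].
train_layout = ([4, 2], [1, 0])
predict_layout_same = ([4, 2], [1, 0])
predict_layout_other = ([8], [0])

assert _check_same_layout(train_layout, predict_layout_same)       # identical slicing: keep local slice
assert not _check_same_layout(train_layout, predict_layout_other)  # different slicing: merge/split needed
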
@@ -236,9 +241,14 @@ def _infer_rank_list(train_map, predict_map=None):
     ret = {}
     for param_name in train_map:
         train_layout = train_map[param_name]
-        new_train_layout = _remove_repeated_slices(train_layout)
+        predict_layout = predict_map[param_name]
         train_dev_mat = train_layout[0]
         dev_num = np.array(train_dev_mat).prod()
+        if _check_same_layout(train_layout, predict_layout):
+            dev_rank = _get_global_rank()
+            ret[param_name] = ([dev_rank], True)
+            continue
+        new_train_layout = _remove_repeated_slices(train_layout)
         array = np.arange(dev_num).reshape(train_dev_mat)
         index = ()
         for i in new_train_layout[0]:
@@ -248,16 +258,20 @@ def _infer_rank_list(train_map, predict_map=None):
                 index = index + (slice(None),)
         rank_list = array[index].flatten()
         if not predict_map:
-            ret[param_name] = rank_list
+            ret[param_name] = (rank_list, False)
             continue
         if param_name not in predict_map:
             logger.warning("predict_map does not contain %s", param_name)
             continue
-        predict_layout = predict_map[param_name]
         # optimization pass
         if _check_similar_layout(train_layout, predict_layout):
-            dev_rank = _get_global_rank()
-            ret[param_name] = [rank_list[dev_rank]]
+            if len(rank_list) == 1:
+                ret[param_name] = (rank_list, True)
+            elif len(rank_list) == dev_num:
+                dev_rank = _get_global_rank()
+                ret[param_name] = ([rank_list[dev_rank]], True)
+            else:
+                ret[param_name] = (rank_list, False)
         else:
-            ret[param_name] = rank_list
+            ret[param_name] = (rank_list, False)
     return ret
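
For readers following the logic, here is a self-contained sketch (hypothetical 2x2 device matrix, not from the patch) of how the rank list is derived from the device matrix and of the new (rank_list, skip_merge_split) return convention:

import numpy as np

train_dev_mat = [2, 2]                        # hypothetical training device matrix over 4 ranks
new_dev_mat = [1, 2]                          # after _remove_repeated_slices: first axis was repeated
dev_num = np.array(train_dev_mat).prod()      # 4

array = np.arange(dev_num).reshape(train_dev_mat)   # [[0, 1], [2, 3]]
index = ()
for i in new_dev_mat:
    if i == 1:
        index = index + (0,)                  # repeated axis: any row holds the same data, take the first
    else:
        index = index + (slice(None),)
rank_list = array[index].flatten()            # array([0, 1]): ranks whose checkpoints hold distinct slices

# With this fix each entry is a (rank_list, skip_merge_split) pair instead of a bare rank list.
entry = (rank_list, False)
print(entry)
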

@@ -746,18 +746,20 @@ class Model:
         """
         Generate parameter layout for the predict network in auto or semi auto parallel mode.
 
-        Data could be a single tensor, a list of tensor, or a tuple of tensor.
+        Data could be a single tensor or multiple tensors.
 
         Note:
            Batch data should be put together in one tensor.
 
         Args:
-            predict_data (Tensor): Tensor of predict data. can be array, list or tuple.
+            predict_data (Tensor): One tensor or multiple tensors of predict data.
 
         Returns:
            parameter_layout_dict (dict): Parameter layout dictionary used for load distributed checkpoint
 
         Examples:
+            >>> context.set_context(mode=context.GRAPH_MODE)
+            >>> context.set_auto_parallel_context(full_batch=True, parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
             >>> input_data = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]), mindspore.float32)
             >>> model = Model(Net())
             >>> model.infer_predict_layout(input_data)
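
The dictionary returned by infer_predict_layout is what the updated load_distributed_checkpoint below consumes as predict_strategy; a hedged sketch of that flow, assuming an 8-rank training job, hypothetical per-rank checkpoint paths, and a user-defined Net as in the example above:

import numpy as np
import mindspore
from mindspore import Tensor, context
from mindspore.context import ParallelMode
from mindspore.train import Model
from mindspore.train.serialization import load_distributed_checkpoint

context.set_context(mode=context.GRAPH_MODE)
context.set_auto_parallel_context(full_batch=True, parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)

net = Net()                                   # user-defined predict network, as in the docstring example
model = Model(net)
input_data = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]), mindspore.float32)

predict_strategy = model.infer_predict_layout(input_data)        # the parameter layout dictionary
ckpt_files = ["rank_{}/net.ckpt".format(i) for i in range(8)]    # hypothetical per-rank checkpoint files
load_distributed_checkpoint(net, ckpt_files, predict_strategy)
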

@@ -950,11 +950,12 @@ def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=
         sliced_params = []
         if param.name not in rank_list.keys():
             continue
-        param_rank = rank_list[param.name]
+        param_rank = rank_list[param.name][0]
+        skip_merge_split = rank_list[param.name][1]
         for rank in param_rank:
             sliced_param = _load_single_param(checkpoint_filenames[rank], param.name)
             sliced_params.append(sliced_param)
-        if len(sliced_params) == 1:
+        if skip_merge_split:
             split_param = sliced_params[0]
         else:
             param_unique_strategy = _remove_repeated_slices(train_strategy[param.name])
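
Each rank_list entry therefore tells the loader both which checkpoint ranks to read and whether the merge/split pass can be skipped; a small illustrative sketch with hypothetical parameter names and a stand-in for _load_single_param:

# Hypothetical rank_list as produced by _infer_rank_list after this change.
rank_list = {
    "fc.weight": ([3], True),                  # same layout: load rank 3's slice and use it directly
    "embedding.table": ([0, 1, 2, 3], False),  # different layout: gather all slices, then merge/split
}

def load_slice(rank, name):
    """Stand-in for _load_single_param(checkpoint_filenames[rank], name)."""
    return "slice-of-{}-from-rank-{}".format(name, rank)

for name, (param_rank, skip_merge_split) in rank_list.items():
    sliced_params = [load_slice(rank, name) for rank in param_rank]
    if skip_merge_split:
        split_param = sliced_params[0]         # the single slice is already correct
    else:
        split_param = sliced_params            # placeholder for the real merge/split pass
    print(name, split_param)
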
