!9742 fix distribute predict

From: @gong_zi_yan
Reviewed-by: @stsuteng, @zhunaipan
Signed-off-by: @stsuteng
pull/9742/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 1ff38be903

@@ -218,6 +218,11 @@ def _check_similar_layout(tensor_layout1, tensor_layout2):
     return True


+def _check_same_layout(tensor_layout1, tensor_layout2):
+    """check if two tensor layouts are same"""
+    return tensor_layout1[0] == tensor_layout2[0] and tensor_layout1[1] == tensor_layout2[1]
+
+
 def _remove_repeated_slices(tensor_layout):
     """generate unrepeated tensor layout"""
     import copy
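Editor's note: the new `_check_same_layout` helper treats a tensor layout as a pair whose first element is the device matrix and whose second element is the tensor map, and reports two layouts as identical only when both parts match. A minimal, self-contained sketch of that behaviour (the layout values below are made up, not taken from the patch):

def _check_same_layout(tensor_layout1, tensor_layout2):
    """check if two tensor layouts are same"""
    return tensor_layout1[0] == tensor_layout2[0] and tensor_layout1[1] == tensor_layout2[1]

# Hypothetical layouts: (device matrix, tensor map).
train_layout = ([4, 2], [1, 0])
predict_layout = ([4, 2], [1, 0])
print(_check_same_layout(train_layout, predict_layout))   # True: prediction uses the same sharding
print(_check_same_layout(train_layout, ([8], [0])))       # False: different device matrix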
@@ -236,9 +241,14 @@ def _infer_rank_list(train_map, predict_map=None):
     ret = {}
     for param_name in train_map:
         train_layout = train_map[param_name]
-        new_train_layout = _remove_repeated_slices(train_layout)
+        predict_layout = predict_map[param_name]
         train_dev_mat = train_layout[0]
         dev_num = np.array(train_dev_mat).prod()
+        if _check_same_layout(train_layout, predict_layout):
+            dev_rank = _get_global_rank()
+            ret[param_name] = ([dev_rank], True)
+            continue
+        new_train_layout = _remove_repeated_slices(train_layout)
         array = np.arange(dev_num).reshape(train_dev_mat)
         index = ()
         for i in new_train_layout[0]:
@@ -248,16 +258,20 @@ def _infer_rank_list(train_map, predict_map=None):
                 index = index + (slice(None),)
         rank_list = array[index].flatten()
         if not predict_map:
-            ret[param_name] = rank_list
+            ret[param_name] = (rank_list, False)
             continue
         if param_name not in predict_map:
             logger.warning("predict_map does not contain %s", param_name)
             continue
-        predict_layout = predict_map[param_name]
         # optimization pass
         if _check_similar_layout(train_layout, predict_layout):
-            dev_rank = _get_global_rank()
-            ret[param_name] = [rank_list[dev_rank]]
+            if len(rank_list) == 1:
+                ret[param_name] = (rank_list, True)
+            elif len(rank_list) == dev_num:
+                dev_rank = _get_global_rank()
+                ret[param_name] = ([rank_list[dev_rank]], True)
+            else:
+                ret[param_name] = (rank_list, False)
         else:
-            ret[param_name] = rank_list
+            ret[param_name] = (rank_list, False)
     return ret
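Editor's note: taken together, the two hunks above change the contract of `_infer_rank_list`. Every entry in the returned dictionary is now a `(rank_list, skip_merge_split)` pair, and a parameter whose training and prediction layouts are identical short-circuits to the local rank with the flag set. A small sketch of how a rank list is derived from the device matrix and what the new return value looks like (all concrete values are invented for illustration):

import numpy as np

# Hypothetical training device matrix for 8 devices.
train_dev_mat = [2, 4]
dev_num = int(np.array(train_dev_mat).prod())        # 8
array = np.arange(dev_num).reshape(train_dev_mat)    # device ranks arranged by the device matrix

# Suppose the de-duplicated layout keeps only the first slice along the last axis:
rank_list = array[(slice(None), 0)].flatten()        # array([0, 4])

# New return shape: {param_name: (rank_list, skip_merge_split)}
ret = {
    "conv1.weight": (rank_list, False),   # slices from ranks 0 and 4 still need merging/splitting
    "fc.bias": ([3], True),               # identical layouts: use the slice of local rank 3 as-is
}
print(ret)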

@@ -746,18 +746,20 @@ class Model:
         """
         Generate parameter layout for the predict network in auto or semi auto parallel mode.
-        Data could be a single tensor, a list of tensor, or a tuple of tensor.
+        Data could be a single tensor or multiple tensors.

         Note:
             Batch data should be put together in one tensor.

         Args:
-            predict_data (Tensor): Tensor of predict data. can be array, list or tuple.
+            predict_data (Tensor): One tensor or multiple tensors of predict data.

         Returns:
             parameter_layout_dict (dict): Parameter layout dictionary used for load distributed checkpoint

         Examples:
+            >>> context.set_context(mode=context.GRAPH_MODE)
+            >>> context.set_auto_parallel_context(full_batch=True, parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
             >>> input_data = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]), mindspore.float32)
             >>> model = Model(Net())
             >>> model.infer_predict_layout(input_data)
         """
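Editor's note: the extended docstring example above is the entry point of the distributed-predict flow this PR fixes. The layout dictionary returned by `infer_predict_layout` is what `load_distributed_checkpoint` (next hunk) consumes as `predict_strategy`. A hedged end-to-end sketch; `Net()` and the checkpoint paths are placeholders, and the import paths are assumed to match the MindSpore version this commit targets:

import numpy as np
import mindspore
from mindspore import context, Tensor
from mindspore.context import ParallelMode
from mindspore.train.model import Model
from mindspore.train.serialization import load_distributed_checkpoint

context.set_context(mode=context.GRAPH_MODE)
context.set_auto_parallel_context(full_batch=True, parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)

net = Net()                                   # Net() is assumed to be defined elsewhere
model = Model(net)
input_data = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]), mindspore.float32)

# Layout of every parameter under the prediction parallel strategy.
predict_strategy = model.infer_predict_layout(input_data)

# Hypothetical per-rank checkpoint files produced by the training job.
ckpt_files = ["./rank_{}/net.ckpt".format(i) for i in range(8)]
load_distributed_checkpoint(net, ckpt_files, predict_strategy)

output = model.predict(input_data)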

@@ -950,11 +950,12 @@ def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=
         sliced_params = []
         if param.name not in rank_list.keys():
             continue
-        param_rank = rank_list[param.name]
+        param_rank = rank_list[param.name][0]
+        skip_merge_split = rank_list[param.name][1]
         for rank in param_rank:
             sliced_param = _load_single_param(checkpoint_filenames[rank], param.name)
             sliced_params.append(sliced_param)
-        if len(sliced_params) == 1:
+        if skip_merge_split:
             split_param = sliced_params[0]
         else:
             param_unique_strategy = _remove_repeated_slices(train_strategy[param.name])
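Editor's note: the serialization-side change is where the new flag pays off. Whether a loaded slice can be used directly is no longer inferred from the number of loaded slices (`len(sliced_params) == 1`) but carried explicitly as `skip_merge_split` from `_infer_rank_list`. A toy, self-contained illustration of how the pair steers the loading logic (all names and values are invented; `_load_single_param` and the real merge/split path are stubbed out):

# rank_list mimics the structure returned by _infer_rank_list after this PR.
rank_list = {
    "fc.bias": ([3], True),            # identical layouts: take rank 3's slice as-is
    "conv1.weight": ([0, 4], False),   # differing layouts: slices must be merged and re-split
}

def load_slice(rank, name):
    """Stand-in for _load_single_param(checkpoint_filenames[rank], name)."""
    return "slice(%s, rank=%d)" % (name, rank)

for name, (param_rank, skip_merge_split) in rank_list.items():
    sliced_params = [load_slice(rank, name) for rank in param_rank]
    if skip_merge_split:
        value = sliced_params[0]                                      # use the single slice directly
    else:
        value = "merge_and_resplit(%s)" % ", ".join(sliced_params)    # placeholder for the real path
    print(name, "->", value)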
