|
|
|
@ -218,6 +218,11 @@ def _check_similar_layout(tensor_layout1, tensor_layout2):
|
|
|
|
|
return True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _check_same_layout(tensor_layout1, tensor_layout2):
|
|
|
|
|
"""check if two tensor layouts are same"""
|
|
|
|
|
return tensor_layout1[0] == tensor_layout2[0] and tensor_layout1[1] == tensor_layout2[1]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _remove_repeated_slices(tensor_layout):
|
|
|
|
|
"""generate unrepeated tensor layout"""
|
|
|
|
|
import copy
|
|
|
|
@ -236,9 +241,14 @@ def _infer_rank_list(train_map, predict_map=None):
|
|
|
|
|
ret = {}
|
|
|
|
|
for param_name in train_map:
|
|
|
|
|
train_layout = train_map[param_name]
|
|
|
|
|
new_train_layout = _remove_repeated_slices(train_layout)
|
|
|
|
|
predict_layout = predict_map[param_name]
|
|
|
|
|
train_dev_mat = train_layout[0]
|
|
|
|
|
dev_num = np.array(train_dev_mat).prod()
|
|
|
|
|
if _check_same_layout(train_layout, predict_layout):
|
|
|
|
|
dev_rank = _get_global_rank()
|
|
|
|
|
ret[param_name] = ([dev_rank], True)
|
|
|
|
|
continue
|
|
|
|
|
new_train_layout = _remove_repeated_slices(train_layout)
|
|
|
|
|
array = np.arange(dev_num).reshape(train_dev_mat)
|
|
|
|
|
index = ()
|
|
|
|
|
for i in new_train_layout[0]:
|
|
|
|
@ -248,16 +258,20 @@ def _infer_rank_list(train_map, predict_map=None):
|
|
|
|
|
index = index + (slice(None),)
|
|
|
|
|
rank_list = array[index].flatten()
|
|
|
|
|
if not predict_map:
|
|
|
|
|
ret[param_name] = rank_list
|
|
|
|
|
ret[param_name] = (rank_list, False)
|
|
|
|
|
continue
|
|
|
|
|
if param_name not in predict_map:
|
|
|
|
|
logger.warning("predict_map does not contain %s", param_name)
|
|
|
|
|
continue
|
|
|
|
|
predict_layout = predict_map[param_name]
|
|
|
|
|
# optimization pass
|
|
|
|
|
if _check_similar_layout(train_layout, predict_layout):
|
|
|
|
|
dev_rank = _get_global_rank()
|
|
|
|
|
ret[param_name] = [rank_list[dev_rank]]
|
|
|
|
|
if len(rank_list) == 1:
|
|
|
|
|
ret[param_name] = (rank_list, True)
|
|
|
|
|
elif len(rank_list) == dev_num:
|
|
|
|
|
dev_rank = _get_global_rank()
|
|
|
|
|
ret[param_name] = ([rank_list[dev_rank]], True)
|
|
|
|
|
else:
|
|
|
|
|
ret[param_name] = (rank_list, False)
|
|
|
|
|
else:
|
|
|
|
|
ret[param_name] = rank_list
|
|
|
|
|
ret[param_name] = (rank_list, False)
|
|
|
|
|
return ret
|
|
|
|
|