!10686 fix infer rank list typo and add testcase

From: @gong_zi_yan
Reviewed-by: @stsuteng,@zhunaipan
Signed-off-by: @stsuteng
pull/10686/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit 2e684df5b1

@ -259,7 +259,7 @@ def _infer_rank_list(train_map, predict_map=None):
logger.warning("predict_map does not contain %s", param_name)
continue
predict_layout = predict_map[param_name]
dev_num = np.array(predict_layout[0].prod())
dev_num = np.array(predict_layout[0]).prod()
# optimization pass
if _check_same_layout(train_layout, predict_layout):
dev_rank = _get_global_rank()

@ -19,6 +19,7 @@ import mindspore.nn as nn
from mindspore import Tensor, Model
from mindspore.ops import operations as P
from mindspore import context
from mindspore.parallel._utils import _infer_rank_list
class Net(nn.Cell):
@ -71,3 +72,48 @@ def test_edge_case():
context.set_auto_parallel_context(full_batch=True, enable_parallel_optimizer=True)
with pytest.raises(RuntimeError):
model.predict(inputs)
# standalone predict
def test_infer_rank_list1():
train_map = {'weight': [[4, 8], [-1, 0]]}
predict_map = None
rank_list = _infer_rank_list(train_map, predict_map)["weight"]
assert list(rank_list[0]) == [0, 1, 2, 3, 4, 5, 6, 7]
assert rank_list[1] is False
# similar layout: gpt3 prediction mode
def test_infer_rank_list2():
train_map = {'weight': [[4, 8], [-1, 0]]}
predict_map = {'weight': [[8], [-1, 0]]}
rank_list = _infer_rank_list(train_map, predict_map)
expect_map = {'weight': ([0], True)}
assert rank_list == expect_map
# same layout
def test_infer_rank_list3():
train_map = {'weight': [[4, 8], [-1, 0]]}
predict_map = {'weight': [[4, 8], [-1, 0]]}
rank_list = _infer_rank_list(train_map, predict_map)
expect_map = {'weight': ([0], True)}
assert rank_list == expect_map
# totally different layout
def test_infer_rank_list4():
train_map = {'weight': [[4, 8], [-1, 0]]}
predict_map = {'weight': [[2, 2], [1, 0]]}
rank_list = _infer_rank_list(train_map, predict_map)["weight"]
assert list(rank_list[0]) == [0, 1, 2, 3, 4, 5, 6, 7]
assert rank_list[1] is False
# full shape ckpt
def test_infer_rank_list5():
train_map = {'weight': [[8], [-1, -1]]}
predict_map = {'weight': [[2, 2], [1, 0]]}
rank_list = _infer_rank_list(train_map, predict_map)
expect_map = {'weight': ([0], False)}
assert rank_list == expect_map

Loading…
Cancel
Save