|
|
|
@ -19,6 +19,7 @@ import mindspore.nn as nn
|
|
|
|
|
from mindspore import Tensor, Model
|
|
|
|
|
from mindspore.ops import operations as P
|
|
|
|
|
from mindspore import context
|
|
|
|
|
from mindspore.parallel._utils import _infer_rank_list
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Net(nn.Cell):
|
|
|
|
@ -71,3 +72,48 @@ def test_edge_case():
|
|
|
|
|
context.set_auto_parallel_context(full_batch=True, enable_parallel_optimizer=True)
|
|
|
|
|
with pytest.raises(RuntimeError):
|
|
|
|
|
model.predict(inputs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# standalone predict
|
|
|
|
|
def test_infer_rank_list1():
|
|
|
|
|
train_map = {'weight': [[4, 8], [-1, 0]]}
|
|
|
|
|
predict_map = None
|
|
|
|
|
rank_list = _infer_rank_list(train_map, predict_map)["weight"]
|
|
|
|
|
assert list(rank_list[0]) == [0, 1, 2, 3, 4, 5, 6, 7]
|
|
|
|
|
assert rank_list[1] is False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# similar layout: gpt3 prediction mode
|
|
|
|
|
def test_infer_rank_list2():
|
|
|
|
|
train_map = {'weight': [[4, 8], [-1, 0]]}
|
|
|
|
|
predict_map = {'weight': [[8], [-1, 0]]}
|
|
|
|
|
rank_list = _infer_rank_list(train_map, predict_map)
|
|
|
|
|
expect_map = {'weight': ([0], True)}
|
|
|
|
|
assert rank_list == expect_map
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# same layout
|
|
|
|
|
def test_infer_rank_list3():
|
|
|
|
|
train_map = {'weight': [[4, 8], [-1, 0]]}
|
|
|
|
|
predict_map = {'weight': [[4, 8], [-1, 0]]}
|
|
|
|
|
rank_list = _infer_rank_list(train_map, predict_map)
|
|
|
|
|
expect_map = {'weight': ([0], True)}
|
|
|
|
|
assert rank_list == expect_map
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# totally different layout
|
|
|
|
|
def test_infer_rank_list4():
|
|
|
|
|
train_map = {'weight': [[4, 8], [-1, 0]]}
|
|
|
|
|
predict_map = {'weight': [[2, 2], [1, 0]]}
|
|
|
|
|
rank_list = _infer_rank_list(train_map, predict_map)["weight"]
|
|
|
|
|
assert list(rank_list[0]) == [0, 1, 2, 3, 4, 5, 6, 7]
|
|
|
|
|
assert rank_list[1] is False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# full shape ckpt
|
|
|
|
|
def test_infer_rank_list5():
|
|
|
|
|
train_map = {'weight': [[8], [-1, -1]]}
|
|
|
|
|
predict_map = {'weight': [[2, 2], [1, 0]]}
|
|
|
|
|
rank_list = _infer_rank_list(train_map, predict_map)
|
|
|
|
|
expect_map = {'weight': ([0], False)}
|
|
|
|
|
assert rank_list == expect_map
|
|
|
|
|