|
|
|
@ -14,6 +14,7 @@
|
|
|
|
|
# ============================================================================
|
|
|
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
|
import pytest
|
|
|
|
|
|
|
|
|
|
import mindspore as ms
|
|
|
|
|
import mindspore.nn as nn
|
|
|
|
@ -112,6 +113,17 @@ def test_unsortedsegmentsum_model_parallel_index_slice_3d():
|
|
|
|
|
compile_graph(x, y, num_segments, strategy1, strategy2)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_unsortedsegmentsum_model_parallel_index_slice_diff_inputs():
|
|
|
|
|
context.set_auto_parallel_context(device_num=4, global_rank=0)
|
|
|
|
|
x = Tensor(np.ones((4, 4, 8)), ms.float32)
|
|
|
|
|
y = Tensor(np.ones((4, 4)), ms.int32)
|
|
|
|
|
num_segments = 16
|
|
|
|
|
strategy1 = (2, 2, 1)
|
|
|
|
|
strategy2 = (2, 4)
|
|
|
|
|
with pytest.raises(RuntimeError):
|
|
|
|
|
compile_graph(x, y, num_segments, strategy1, strategy2)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_unsortedsegmentsum_model_parallel_vector_slice_2d():
|
|
|
|
|
context.set_auto_parallel_context(device_num=4, global_rank=0)
|
|
|
|
|
x = Tensor(np.ones((4, 8)), ms.float32)
|
|
|
|
|