updated to test files + DynamicOpCheckFix

lintfix

Adding Dynamic_shape_depends
pull/9338/head
danishnxt 4 years ago
parent f6450a614b
commit bf5ceaae57

@ -216,7 +216,7 @@ AbstractBasePtr InferImplUnsortedSegmentSum(const AnalysisEnginePtr &, const Pri
// check if dynamic shape
bool x_is_dyn = (!x->shape()->min_shape().empty() && !x->shape()->max_shape().empty());
bool ids_is_dyn = (!segment_ids->shape()->min_shape().empty() && !segment_ids->shape()->max_shape().empty());
bool op_is_dynamic = x_is_dyn && ids_is_dyn;
bool op_is_dynamic = x_is_dyn || ids_is_dyn;
auto x_shape = x->shape()->shape();
ShapeVector shape;
int64_t num_segments_value = 0;
@ -297,7 +297,7 @@ AbstractBasePtr InferImplUnsortedSegmentMax(const AnalysisEnginePtr &, const Pri
// check if dynamic shape
bool x_is_dyn = (!x->shape()->min_shape().empty() && !x->shape()->max_shape().empty());
bool ids_is_dyn = (!segment_ids->shape()->min_shape().empty() && !segment_ids->shape()->max_shape().empty());
bool op_is_dynamic = x_is_dyn && ids_is_dyn;
bool op_is_dynamic = x_is_dyn || ids_is_dyn;
auto x_shape = x->shape()->shape();
ShapeVector shape;
int64_t num_segments_value = 0;
@ -374,7 +374,7 @@ AbstractBasePtr InferImplUnsortedSegmentMin(const AnalysisEnginePtr &, const Pri
// check if dynamic shape
bool x_is_dyn = (!x->shape()->min_shape().empty() && !x->shape()->max_shape().empty());
bool ids_is_dyn = (!segment_ids->shape()->min_shape().empty() && !segment_ids->shape()->max_shape().empty());
bool op_is_dynamic = x_is_dyn && ids_is_dyn;
bool op_is_dynamic = x_is_dyn || ids_is_dyn;
auto x_shape = x->shape()->shape();
ShapeVector shape;
int64_t num_segments_value = 0;

@ -1917,6 +1917,7 @@ class UnsortedSegmentMin(PrimitiveWithCheck):
def __init__(self):
"""Initialize UnsortedSegmentMin"""
self.init_prim_io_names(inputs=['x', 'segment_ids', 'num_segments'], outputs=['y'])
self.add_prim_attr("dynamic_shape_depends", [2])
def __check__(self, x, segment_ids, num_segments):
segment_ids_shape = segment_ids['shape']

@ -941,22 +941,32 @@ def test_gather2():
# Dynamic Shape testing ahead
class GatherNetDynamic1(nn.Cell):
def __init__(self):
super(GatherNetDynamic1, self).__init__()
class GatherNetDynamic(nn.Cell):
def __init__(self, axis=0, dyn_a=True, dyn_b=True):
super(GatherNetDynamic, self).__init__()
self.gather = P.GatherV2()
self.gpu_convert_to_dynamic_shape = inner.GpuConvertToDynamicShape()
self.to_dyn_1 = dyn_a
self.to_dyn_2 = dyn_b
self.axis = axis
def construct(self, x, indices):
# Testing only second input dynamic
indices_dyn = self.gpu_convert_to_dynamic_shape(indices)
return self.gather(x, indices_dyn, 0)
# testing selective inputs being dynamic
if self.to_dyn_1:
x = self.gpu_convert_to_dynamic_shape(x)
if self.to_dyn_2:
indices = self.gpu_convert_to_dynamic_shape(indices)
return self.gather(x, indices, self.axis)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_gather_dynamic_1():
def test_gatherV2_dyn_ab():
"""
Tests for Dynamic shape with both inputs dynamic
"""
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
gather = GatherNetDynamic()
x = Tensor(np.array([[4., 5., 4., 1., 5.,],
[4., 9., 5., 6., 4.,],
[9., 8., 4., 3., 6.,],
@ -968,14 +978,10 @@ def test_gather_dynamic_1():
[3., 7., 2., 7., 4.,],
[4., 2., 8., 2., 9.,]]
).astype(np.float32))
indices = Tensor(np.array([[4000, 1, 300000]]).astype(np.int32))
expect = np.array([[[0., 0., 0., 0., 0.],
[4., 9., 5., 6., 4.],
[0., 0., 0., 0., 0.]]])
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
gather = GatherNetDynamic1()
output = gather(x, indices)
error = np.ones(shape=output.asnumpy().shape) * 1.0e-6
diff = output.asnumpy() - expect
@ -983,21 +989,44 @@ def test_gather_dynamic_1():
assert np.all(-diff < error)
class GatherNetDynamic2(nn.Cell):
def __init__(self):
super(GatherNetDynamic2, self).__init__()
self.gather = P.GatherV2()
self.gpu_convert_to_dynamic_shape = inner.GpuConvertToDynamicShape()
def construct(self, x, indices):
# Testing only first input dynamic
x_dyn = self.gpu_convert_to_dynamic_shape(x)
return self.gather(x_dyn, indices, -1)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_gather_dynamic_2():
def test_gatherV2_dyn_a():
"""
Tests for Dynamic shape with only first input dynamic
"""
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
gather = GatherNetDynamic(-1, True, False)
# test 1
x = Tensor(np.array([[4., 5., 4., 1., 5.,],
[4., 9., 5., 6., 4.,],
[9., 8., 4., 3., 6.,],
[0., 4., 2., 2., 8.,],
[1., 8., 6., 2., 8.,],
[8., 1., 9., 7., 3.,],
[7., 9., 2., 5., 7.,],
[9., 8., 6., 8., 5.,],
[3., 7., 2., 7., 4.,],
[4., 2., 8., 2., 9.,]]
).astype(np.float32))
indices = Tensor(np.array([[4000, 1, 300000]]).astype(np.int32))
expect = np.array([[[0., 5., 0.]],
[[0., 9., 0.]],
[[0., 8., 0.]],
[[0., 4., 0.]],
[[0., 8., 0.]],
[[0., 1., 0.]],
[[0., 9., 0.]],
[[0., 8., 0.]],
[[0., 7., 0.]],
[[0., 2., 0.]]]).astype(np.float32)
output = gather(x, indices)
error = np.ones(shape=output.asnumpy().shape) * 1.0e-6
diff = output.asnumpy() - expect
assert np.all(diff < error)
assert np.all(-diff < error)
# test 2
x = Tensor(np.arange(2 * 3 * 4 * 5, dtype=np.float32).reshape(2, 3, 4, 5))
indices = Tensor(np.array([1, 3, 4], dtype='i4'))
expect = np.array([[[[1., 3., 4.],
@ -1029,8 +1058,6 @@ def test_gather_dynamic_2():
[106., 108., 109.],
[111., 113., 114.],
[116., 118., 119.]]]])
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
gather = GatherNetDynamic2()
output = gather(x, indices)
error = np.ones(shape=output.asnumpy().shape) * 1.0e-6
diff = output.asnumpy() - expect
@ -1038,56 +1065,70 @@ def test_gather_dynamic_2():
assert np.all(-diff < error)
class GatherNetDynamic3(nn.Cell):
def __init__(self):
super(GatherNetDynamic3, self).__init__()
self.gather = P.GatherV2()
self.gpu_convert_to_dynamic_shape = inner.GpuConvertToDynamicShape()
def construct(self, x, indices):
# Testing both inputs dynamic shapes
x_dyn = self.gpu_convert_to_dynamic_shape(x)
indices_dyn = self.gpu_convert_to_dynamic_shape(indices)
return self.gather(x_dyn, indices_dyn, -1)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_gather_dynamic_3():
def test_gatherV2_dyn_b():
"""
Tests for Dynamic shape with only second input dynamic
"""
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
gather = GatherNetDynamic(-1, False, True)
# test 1
x = Tensor(np.array([[4., 5., 4., 1., 5.,],
[4., 9., 5., 6., 4.,],
[9., 8., 4., 3., 6.,],
[0., 4., 2., 2., 8.,],
[1., 8., 6., 2., 8.,],
[8., 1., 9., 7., 3.,],
[7., 9., 2., 5., 7.,],
[9., 8., 6., 8., 5.,],
[3., 7., 2., 7., 4.,],
[4., 2., 8., 2., 9.,]]
).astype(np.float32))
indices = Tensor(np.array([[4000, 1, 300000]]).astype(np.int32))
expect = np.array([[[0., 5., 0.]],
[[0., 9., 0.]],
[[0., 8., 0.]],
[[0., 4., 0.]],
[[0., 8., 0.]],
[[0., 1., 0.]],
[[0., 9., 0.]],
[[0., 8., 0.]],
[[0., 7., 0.]],
[[0., 2., 0.]]]).astype(np.float32)
output = gather(x, indices)
error = np.ones(shape=output.asnumpy().shape) * 1.0e-6
diff = output.asnumpy() - expect
assert np.all(diff < error)
assert np.all(-diff < error)
# test 2
x = Tensor(np.arange(2 * 3 * 4 * 5, dtype=np.float32).reshape(2, 3, 4, 5))
indices = Tensor(np.array([1, 3, 4], dtype='i4'))
expect = np.array([[[[1., 3., 4.],
[6., 8., 9.],
[11., 13., 14.],
[16., 18., 19.]],
[[21., 23., 24.],
[26., 28., 29.],
[31., 33., 34.],
[36., 38., 39.]],
[[41., 43., 44.],
[46., 48., 49.],
[51., 53., 54.],
[56., 58., 59.]]],
[[[61., 63., 64.],
[66., 68., 69.],
[71., 73., 74.],
[76., 78., 79.]],
[[81., 83., 84.],
[86., 88., 89.],
[91., 93., 94.],
[96., 98., 99.]],
[[101., 103., 104.],
[106., 108., 109.],
[111., 113., 114.],
[116., 118., 119.]]]])
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
gather = GatherNetDynamic3()
output = gather(x, indices)
error = np.ones(shape=output.asnumpy().shape) * 1.0e-6
diff = output.asnumpy() - expect

@ -204,27 +204,36 @@ def test_3d_single_init():
[-3.4028235e+38, -3.4028235e+38, -3.4028235e+38]]]).astype(np.float32)
np.testing.assert_array_almost_equal(output, expect)
# For testing Dynamic Shape operation
class UnsortedSegmentMaxDynNet(nn.Cell):
def __init__(self, num_segments):
def __init__(self, num_segments, dyn_a=True, dyn_b=True):
super(UnsortedSegmentMaxDynNet, self).__init__()
self.unsorted_segment_max = P.UnsortedSegmentMax()
self.gpu_convert_to_dynamic_shape = inner.GpuConvertToDynamicShape()
self.num_segments = num_segments
self.to_dyn_1 = dyn_a
self.to_dyn_2 = dyn_b
def construct(self, data, ids):
dyn_data = self.gpu_convert_to_dynamic_shape(data)
dyn_ids = self.gpu_convert_to_dynamic_shape(ids)
return self.unsorted_segment_max(dyn_data, dyn_ids, self.num_segments)
# testing selective inputs being dynamic
if self.to_dyn_1:
data = self.gpu_convert_to_dynamic_shape(data)
if self.to_dyn_2:
ids = self.gpu_convert_to_dynamic_shape(ids)
return self.unsorted_segment_max(data, ids, self.num_segments)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_3d_float32_dyn():
def test_3d_float32_dyn_ab():
"""
Tests for Dynamic shape with both inputs dynamic
"""
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
num_segments = 4
net = UnsortedSegmentMaxDynNet(num_segments)
# input 1
input_x = Tensor(np.arange(
4 * 5 * 3, dtype=np.float32).reshape(4, 5, 3), dtype=mindspore.float32)
segment_ids = Tensor([3, 0, 1, -1], mstype.int32)
@ -251,16 +260,21 @@ def test_3d_float32_dyn():
[1.2000000e+01, 1.3000000e+01, 1.4000000e+01]]]).astype(np.float32)
np.testing.assert_array_almost_equal(output, expect)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_3d_single_init_dyn():
def test_3d_single_init_dyn_a():
"""
Tests for Dynamic shape with first input dynamic
"""
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
# test 1
input_x = Tensor(np.arange(
4 * 5 * 3, dtype=np.float32).reshape(4, 5, 3), dtype=mindspore.float32)
segment_ids = Tensor([3, 0, 1, -1], mstype.int32)
num_segments = 4
net = UnsortedSegmentMaxDynNet(num_segments)
net = UnsortedSegmentMaxDynNet(num_segments, True, False)
output = net(input_x, segment_ids).asnumpy()
expect = np.array([[[1.5000000e+01, 1.6000000e+01, 1.7000000e+01],
[1.8000000e+01, 1.9000000e+01, 2.0000000e+01],
@ -283,8 +297,79 @@ def test_3d_single_init_dyn():
[9.0000000e+00, 1.0000000e+01, 1.1000000e+01],
[1.2000000e+01, 1.3000000e+01, 1.4000000e+01]]]).astype(np.float32)
np.testing.assert_array_almost_equal(output, expect)
# test 2
input_x = Tensor(np.arange(
4 * 7 * 2, dtype=np.float32).reshape(4, 7, 2), dtype=mindspore.float32)
segment_ids = Tensor([3, 0, 1, -1], mstype.int32)
output = net(input_x, segment_ids).asnumpy()
expect = np.array([[[1.4000000e+01, 1.5000000e+01],
[1.6000000e+01, 1.7000000e+01],
[1.8000000e+01, 1.9000000e+01],
[2.0000000e+01, 2.1000000e+01],
[2.2000000e+01, 2.3000000e+01],
[2.4000000e+01, 2.5000000e+01],
[2.6000000e+01, 2.7000000e+01]],
[[2.8000000e+01, 2.9000000e+01],
[3.0000000e+01, 3.1000000e+01],
[3.2000000e+01, 3.3000000e+01],
[3.4000000e+01, 3.5000000e+01],
[3.6000000e+01, 3.7000000e+01],
[3.8000000e+01, 3.9000000e+01],
[4.0000000e+01, 4.1000000e+01]],
[[-3.4028235e+38, -3.4028235e+38],
[-3.4028235e+38, -3.4028235e+38],
[-3.4028235e+38, -3.4028235e+38],
[-3.4028235e+38, -3.4028235e+38],
[-3.4028235e+38, -3.4028235e+38],
[-3.4028235e+38, -3.4028235e+38],
[-3.4028235e+38, -3.4028235e+38]],
[[0.0000000e+00, 1.0000000e+00],
[2.0000000e+00, 3.0000000e+00],
[4.0000000e+00, 5.0000000e+00],
[6.0000000e+00, 7.0000000e+00],
[8.0000000e+00, 9.0000000e+00],
[1.0000000e+01, 1.1000000e+01],
[1.2000000e+01, 1.3000000e+01]]]).astype(np.float32)
np.testing.assert_array_almost_equal(output, expect)
# changing the input shape here for same net
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_3d_single_init_dyn_b():
"""
Tests for Dynamic shape with second input dynamic
"""
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
# input 1
input_x = Tensor(np.arange(
4 * 5 * 3, dtype=np.float32).reshape(4, 5, 3), dtype=mindspore.float32)
segment_ids = Tensor([3, 0, 1, -1], mstype.int32)
num_segments = 4
net = UnsortedSegmentMaxDynNet(num_segments, False, True)
output = net(input_x, segment_ids).asnumpy()
expect = np.array([[[1.5000000e+01, 1.6000000e+01, 1.7000000e+01],
[1.8000000e+01, 1.9000000e+01, 2.0000000e+01],
[2.1000000e+01, 2.2000000e+01, 2.3000000e+01],
[2.4000000e+01, 2.5000000e+01, 2.6000000e+01],
[2.7000000e+01, 2.8000000e+01, 2.9000000e+01]],
[[3.0000000e+01, 3.1000000e+01, 3.2000000e+01],
[3.3000000e+01, 3.4000000e+01, 3.5000000e+01],
[3.6000000e+01, 3.7000000e+01, 3.8000000e+01],
[3.9000000e+01, 4.0000000e+01, 4.1000000e+01],
[4.2000000e+01, 4.3000000e+01, 4.4000000e+01]],
[[-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
[-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
[-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
[-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
[-3.4028235e+38, -3.4028235e+38, -3.4028235e+38]],
[[0.0000000e+00, 1.0000000e+00, 2.0000000e+00],
[3.0000000e+00, 4.0000000e+00, 5.0000000e+00],
[6.0000000e+00, 7.0000000e+00, 8.0000000e+00],
[9.0000000e+00, 1.0000000e+01, 1.1000000e+01],
[1.2000000e+01, 1.3000000e+01, 1.4000000e+01]]]).astype(np.float32)
np.testing.assert_array_almost_equal(output, expect)
# input 2
input_x = Tensor(np.arange(
4 * 7 * 2, dtype=np.float32).reshape(4, 7, 2), dtype=mindspore.float32)
segment_ids = Tensor([3, 0, 1, -1], mstype.int32)

@ -207,22 +207,29 @@ def test_3d_single_init():
# For testing Dynamic Shape operation
class UnsortedSegmentMinDynNet(nn.Cell):
def __init__(self, num_segments):
def __init__(self, num_segments, dyn_a=True, dyn_b=True):
super(UnsortedSegmentMinDynNet, self).__init__()
self.unsorted_segment_min = P.UnsortedSegmentMin()
self.gpu_convert_to_dynamic_shape = inner.GpuConvertToDynamicShape()
self.num_segments = num_segments
self.to_dyn_1 = dyn_a
self.to_dyn_2 = dyn_b
def construct(self, data, ids):
dyn_data = self.gpu_convert_to_dynamic_shape(data)
dyn_ids = self.gpu_convert_to_dynamic_shape(ids)
return self.unsorted_segment_min(dyn_data, dyn_ids, self.num_segments)
# testing selective inputs being dynamic
if self.to_dyn_1:
data = self.gpu_convert_to_dynamic_shape(data)
if self.to_dyn_2:
ids = self.gpu_convert_to_dynamic_shape(ids)
return self.unsorted_segment_min(data, ids, self.num_segments)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_3d_float32_dyn():
def test_3d_float32_ab_dyn():
"""
Test for Dynamic shape with both inputs dynamic
"""
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
input_x = Tensor(np.arange(
4 * 5 * 3, dtype=np.float32).reshape(4, 5, 3), dtype=mindspore.float32)
@ -251,11 +258,14 @@ def test_3d_float32_dyn():
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_3d_single_init_dyn():
def test_3d_float32_a_dyn():
"""
Tests for Dynamic shape with only first input dynamic
"""
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
num_segments = 4
net = UnsortedSegmentMinDynNet(num_segments)
net = UnsortedSegmentMinDynNet(num_segments, True, False)
# test 1
input_x = Tensor(np.arange(
4 * 5 * 3, dtype=np.float32).reshape(4, 5, 3), dtype=mindspore.float32)
segment_ids = Tensor([3, 0, 1, -1], mstype.int32)
@ -281,8 +291,79 @@ def test_3d_single_init_dyn():
[9.0000000e+00, 1.0000000e+01, 1.1000000e+01],
[1.2000000e+01, 1.3000000e+01, 1.4000000e+01]]]).astype(np.float32)
np.testing.assert_array_almost_equal(output, expect)
# test 2
input_x = Tensor(np.arange(
4 * 7 * 2, dtype=np.float32).reshape(4, 7, 2), dtype=mindspore.float32)
segment_ids = Tensor([3, 0, 1, -1], mstype.int32)
output = net(input_x, segment_ids).asnumpy()
expect = np.array([[[1.4000000e+01, 1.5000000e+01],
[1.6000000e+01, 1.7000000e+01],
[1.8000000e+01, 1.9000000e+01],
[2.0000000e+01, 2.1000000e+01],
[2.2000000e+01, 2.3000000e+01],
[2.4000000e+01, 2.5000000e+01],
[2.6000000e+01, 2.7000000e+01]],
[[2.8000000e+01, 2.9000000e+01],
[3.0000000e+01, 3.1000000e+01],
[3.2000000e+01, 3.3000000e+01],
[3.4000000e+01, 3.5000000e+01],
[3.6000000e+01, 3.7000000e+01],
[3.8000000e+01, 3.9000000e+01],
[4.0000000e+01, 4.1000000e+01]],
[[3.4028235e+38, 3.4028235e+38],
[3.4028235e+38, 3.4028235e+38],
[3.4028235e+38, 3.4028235e+38],
[3.4028235e+38, 3.4028235e+38],
[3.4028235e+38, 3.4028235e+38],
[3.4028235e+38, 3.4028235e+38],
[3.4028235e+38, 3.4028235e+38]],
[[0.0000000e+00, 1.0000000e+00],
[2.0000000e+00, 3.0000000e+00],
[4.0000000e+00, 5.0000000e+00],
[6.0000000e+00, 7.0000000e+00],
[8.0000000e+00, 9.0000000e+00],
[1.0000000e+01, 1.1000000e+01],
[1.2000000e+01, 1.3000000e+01]]]).astype(np.float32)
np.testing.assert_array_almost_equal(output, expect)
# changing the input shape here for same net
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_3d_float32_b_dyn():
"""
Tests for Dynamic shape with only second input dynamic
"""
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
num_segments = 4
net = UnsortedSegmentMinDynNet(num_segments, False, True)
# test 1
input_x = Tensor(np.arange(
4 * 5 * 3, dtype=np.float32).reshape(4, 5, 3), dtype=mindspore.float32)
segment_ids = Tensor([3, 0, 1, -1], mstype.int32)
output = net(input_x, segment_ids).asnumpy()
expect = np.array([[[1.5000000e+01, 1.6000000e+01, 1.7000000e+01],
[1.8000000e+01, 1.9000000e+01, 2.0000000e+01],
[2.1000000e+01, 2.2000000e+01, 2.3000000e+01],
[2.4000000e+01, 2.5000000e+01, 2.6000000e+01],
[2.7000000e+01, 2.8000000e+01, 2.9000000e+01]],
[[3.0000000e+01, 3.1000000e+01, 3.2000000e+01],
[3.3000000e+01, 3.4000000e+01, 3.5000000e+01],
[3.6000000e+01, 3.7000000e+01, 3.8000000e+01],
[3.9000000e+01, 4.0000000e+01, 4.1000000e+01],
[4.2000000e+01, 4.3000000e+01, 4.4000000e+01]],
[[3.4028235e+38, 3.4028235e+38, 3.4028235e+38],
[3.4028235e+38, 3.4028235e+38, 3.4028235e+38],
[3.4028235e+38, 3.4028235e+38, 3.4028235e+38],
[3.4028235e+38, 3.4028235e+38, 3.4028235e+38],
[3.4028235e+38, 3.4028235e+38, 3.4028235e+38]],
[[0.0000000e+00, 1.0000000e+00, 2.0000000e+00],
[3.0000000e+00, 4.0000000e+00, 5.0000000e+00],
[6.0000000e+00, 7.0000000e+00, 8.0000000e+00],
[9.0000000e+00, 1.0000000e+01, 1.1000000e+01],
[1.2000000e+01, 1.3000000e+01, 1.4000000e+01]]]).astype(np.float32)
np.testing.assert_array_almost_equal(output, expect)
# test 2
input_x = Tensor(np.arange(
4 * 7 * 2, dtype=np.float32).reshape(4, 7, 2), dtype=mindspore.float32)
segment_ids = Tensor([3, 0, 1, -1], mstype.int32)

@ -83,25 +83,21 @@ def test_3D():
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.]],
[[45., 47., 49.],
[51., 53., 55.],
[57., 59., 61.],
[63., 65., 67.],
[69., 71., 73.]],
[[0., 1., 2.],
[3., 4., 5.],
[6., 7., 8.],
[9., 10., 11.],
[12., 13., 14.]],
[[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.]],
[[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
@ -112,32 +108,39 @@ def test_3D():
# Testing Dynamic Shape
class UnsortedSegmentSumDynNet(nn.Cell):
def __init__(self, num_segments):
def __init__(self, num_segments, dyn_a=True, dyn_b=True):
super(UnsortedSegmentSumDynNet, self).__init__()
self.unsorted_segment_sum = P.UnsortedSegmentSum()
self.to_dyn_op = inner.GpuConvertToDynamicShape()
self.gpu_convert_to_dynamic_shape = inner.GpuConvertToDynamicShape()
self.num_segments = num_segments
self.to_dyn_1 = dyn_a
self.to_dyn_2 = dyn_b
def construct(self, data, ids):
data_dyn = self.to_dyn_op(data)
ids_dyn = self.to_dyn_op(ids)
return self.unsorted_segment_sum(data_dyn, ids_dyn, self.num_segments)
# testing selective inputs being dynamic
if self.to_dyn_1:
data = self.gpu_convert_to_dynamic_shape(data)
if self.to_dyn_2:
ids = self.gpu_convert_to_dynamic_shape(ids)
return self.unsorted_segment_sum(data, ids, self.num_segments)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_dyn():
def test_dyn_ab():
"""
Tests for Dynamic shape with both inputs dynamic
"""
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
num_segments = 4
net = UnsortedSegmentSumDynNet(num_segments)
# test 1
input_x = Tensor([1, 2, 3, 4], mstype.float32)
segment_ids = Tensor([0, 0, 1, 2], mstype.int32)
output = net(input_x, segment_ids)
expect = [3, 3, 4, 0]
assert (output.asnumpy() == expect).all()
# test 2
input_x = Tensor([[1, 2, 3, 4],
[5, 6, 7, 8],
[9, 10, 11, 12]], mstype.float32)
@ -148,7 +151,7 @@ def test_dyn():
[1, 2, 3, 4],
[0, 0, 0, 0]]
assert (output.asnumpy() == expect).all()
# test 3
input_x = Tensor(np.arange(4 * 5 * 3, dtype=np.float32).reshape(4, 5, 3))
segment_ids = Tensor([2, 1, 1, -1], mstype.int32)
output = net(input_x, segment_ids)
@ -157,19 +160,16 @@ def test_dyn():
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.]],
[[45., 47., 49.],
[51., 53., 55.],
[57., 59., 61.],
[63., 65., 67.],
[69., 71., 73.]],
[[0., 1., 2.],
[3., 4., 5.],
[6., 7., 8.],
[9., 10., 11.],
[12., 13., 14.]],
[[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
@ -181,17 +181,20 @@ def test_dyn():
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_dyn_1():
def test_dyn_a():
"""
Tests for Dynamic shape with first input dynamic
"""
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
num_segments = 6
net = UnsortedSegmentSumDynNet(num_segments)
net = UnsortedSegmentSumDynNet(num_segments, True, False)
# test 1
input_x = Tensor([1, 2, 3, 4], mstype.float32)
segment_ids = Tensor([0, 0, 1, 2], mstype.int32)
output = net(input_x, segment_ids)
expect = [3, 3, 4, 0, 0, 0]
assert (output.asnumpy() == expect).all()
# test 2
input_x = Tensor([[1, 2, 3, 4],
[5, 6, 7, 8],
[9, 10, 11, 12]], mstype.float32)
@ -204,7 +207,7 @@ def test_dyn_1():
[0, 0, 0, 0],
[0, 0, 0, 0]]
assert (output.asnumpy() == expect).all()
# test 3
input_x = Tensor(np.arange(4 * 5 * 3, dtype=np.float32).reshape(4, 5, 3))
segment_ids = Tensor([2, 1, 1, -1], mstype.int32)
output = net(input_x, segment_ids)
@ -213,31 +216,92 @@ def test_dyn_1():
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.]],
[[45., 47., 49.],
[51., 53., 55.],
[57., 59., 61.],
[63., 65., 67.],
[69., 71., 73.]],
[[0., 1., 2.],
[3., 4., 5.],
[6., 7., 8.],
[9., 10., 11.],
[12., 13., 14.]],
[[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.]],
[[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.]],
[[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.]]]
assert (output.asnumpy() == expect).all()
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_dyn_b():
"""
Tests for Dynamic shape with second input dynamic
"""
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
num_segments = 6
net = UnsortedSegmentSumDynNet(num_segments, False, True)
# test 1
input_x = Tensor([1, 2, 3, 4], mstype.float32)
segment_ids = Tensor([0, 0, 1, 2], mstype.int32)
output = net(input_x, segment_ids)
expect = [3, 3, 4, 0, 0, 0]
assert (output.asnumpy() == expect).all()
# test 2
input_x = Tensor([[1, 2, 3, 4],
[5, 6, 7, 8],
[9, 10, 11, 12]], mstype.float32)
segment_ids = Tensor([2, 1, 1], mstype.int32)
output = net(input_x, segment_ids)
expect = [[0, 0, 0, 0],
[14, 16, 18, 20],
[1, 2, 3, 4],
[0, 0, 0, 0],
[0, 0, 0, 0],
[0, 0, 0, 0]]
assert (output.asnumpy() == expect).all()
# test 3
input_x = Tensor(np.arange(4 * 5 * 3, dtype=np.float32).reshape(4, 5, 3))
segment_ids = Tensor([2, 1, 1, -1], mstype.int32)
output = net(input_x, segment_ids)
expect = [[[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.]],
[[45., 47., 49.],
[51., 53., 55.],
[57., 59., 61.],
[63., 65., 67.],
[69., 71., 73.]],
[[0., 1., 2.],
[3., 4., 5.],
[6., 7., 8.],
[9., 10., 11.],
[12., 13., 14.]],
[[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.]],
[[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.]],
[[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],

Loading…
Cancel
Save