!2747 refactor StridedSlice op

Merge pull request !2747 from zhangbuxue/refactor_the_StridedSlice_op
pull/2747/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit 30fc1bd0c3

File diff suppressed because it is too large Load Diff

@ -35,25 +35,6 @@ from ....mindspore_test_framework.pipeline.gradient.compile_gradient \
import pipeline_for_compile_grad_ge_graph_for_case_by_case_config
def test_tensor_scatter_update():
class TensorScatterUpdateNet(nn.Cell):
"""TensorScatterUpdate net definition"""
def __init__(self):
super(TensorScatterUpdateNet, self).__init__()
self.tensor_scatter_update = P.TensorScatterUpdate()
def construct(self, x, i, u):
out = self.tensor_scatter_update(x, i, u)
return out
net = TensorScatterUpdateNet()
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
x = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)), mstype.float32)
indices = Tensor(np.array([[0, 0], [1, 1]], np.int32))
updates = Tensor(np.ones([2, 5], np.float32))
net(x, indices, updates)
class InputBackward(nn.Cell):
def __init__(self, network):
super(InputBackward, self).__init__()
@ -446,6 +427,7 @@ class SparseApplyAdagradNet(nn.Cell):
out = self.sparse_apply_adagrad(self.var, self.accum, grad, indices)
return out
class ApplyRMSNet(nn.Cell):
def __init__(self):
super(ApplyRMSNet, self).__init__()
@ -496,6 +478,60 @@ class NormalNet(nn.Cell):
return out
class StridedSliceNet(nn.Cell):
def __init__(self):
super(StridedSliceNet, self).__init__()
self.begins = (1, 2, 3, 2, 1)
self.ends = (5, 6, 7, 8, 9)
self.strides = (1, 2, 3, 2, 1)
self.strided_slice_0 = P.StridedSlice(begin_mask=3, end_mask=5, ellipsis_mask=4,
shrink_axis_mask=2, new_axis_mask=8)
self.strided_slice_1 = P.StridedSlice(begin_mask=5, end_mask=2, ellipsis_mask=2,
shrink_axis_mask=6, new_axis_mask=10)
self.strided_slice_2 = P.StridedSlice(begin_mask=3, end_mask=3, ellipsis_mask=4,
shrink_axis_mask=5, new_axis_mask=13)
self.strided_slice_3 = P.StridedSlice(begin_mask=0, end_mask=0, ellipsis_mask=4,
shrink_axis_mask=12, new_axis_mask=15)
self.const_0 = Tensor(np.ones([6, 8, 9, 1, 8], np.float32))
self.const_1 = Tensor(np.ones([5, 7, 8, 1, 8], np.float32))
self.const_2 = Tensor(np.ones([1, 3, 7, 8, 9, 1, 8], np.float32))
self.const_3 = Tensor(np.ones([1, 1, 6, 7, 8, 9, 1, 8], np.float32))
def construct(self, x):
out_0 = self.strided_slice_0(x, self.begins, self.ends, self.strides) + self.const_0
out_1 = self.strided_slice_1(x, self.begins, self.ends, self.strides) + self.const_1
out_2 = self.strided_slice_2(x, self.begins, self.ends, self.strides) + self.const_2
out_3 = self.strided_slice_3(x, self.begins, self.ends, self.strides) + self.const_3
return out_0, out_1, out_2, out_3
def test_strided_slice_const():
class StridedSLiceConstNet(nn.Cell):
"""StridedSLiceConstNet net definition"""
def __init__(self):
super(StridedSLiceConstNet, self).__init__()
self.begins = (0, 2, -5, 2, 1)
self.ends = (0, 6, 9, 8, 9)
self.strides = (1, 2, 1, 2, 1)
self.strided_slice = P.StridedSlice(begin_mask=2,
end_mask=6,
ellipsis_mask=4,
shrink_axis_mask=6,
new_axis_mask=18)
def construct(self, x):
out = self.strided_slice(x, self.begins, self.ends, self.strides)
return out
net = StridedSLiceConstNet()
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
x = Tensor(np.ones([6, 7, 8, 9, 10]), mstype.float32)
ret = net(x)
assert ret.shape == (0, 1, 7, 8, 9, 3, 1)
assert (ret.asnumpy() == np.array([], np.float32).reshape([0, 1, 7, 8, 9, 3, 1])).all()
test_case_math_ops = [
('BitwiseAnd', {
'block': P.BitwiseAnd(),
@ -1366,6 +1402,10 @@ test_case_nn_ops = [
'desc_inputs': [Tensor(np.array([[128, 32, 32, 64], [128, 32, 32, 64]]).astype(np.float16))],
'desc_bprop': [Tensor(np.array([[128, 32, 32, 64], [128, 32, 32, 64]]).astype(np.float16))],
'skip': ['backward']}),
('StridedSliceNet', {
'block': StridedSliceNet(),
'desc_inputs': [[6, 7, 8, 9, 10]],
'skip': ['backward']}),
('OneHot', {
'block': P.OneHot(),
'desc_const': [3, Tensor(1.0, mstype.float32), Tensor(0.0, mstype.float32)],
@ -1763,7 +1803,7 @@ test_case_other_ops = [
'desc_bprop': [([3, 3], {'dtype': np.int32})]}),
('TensorScatterUpdate', {
'block': P.TensorScatterUpdate(),
'desc_inputs': (Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)), mstype.float32),
'desc_inputs': (Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)), mstype.float32),
Tensor(np.array([[0, 1], [1, 2]], np.int32)),
Tensor(np.ones([2, 5], np.float32) * 99)),
'desc_bprop': [([3, 4, 5], {'dtype': np.float32})]}),
@ -1930,11 +1970,10 @@ test_case_other_ops = [
]
test_case_quant_ops = [
('AscendQuant_1', {
'block': inner.AscendQuant(0.5, 0.0, False, "Round"),
'desc_inputs': [Tensor(np.random.rand(1,2,4,4), mstype.float32)],
'desc_inputs': [Tensor(np.random.rand(1, 2, 4, 4), mstype.float32)],
'skip': ['backward']}),
('AscendQuant_2', {
'block': inner.AscendQuant(80.0, 10.0, True, "Round"),
@ -2027,6 +2066,18 @@ raise_set = [
'block': (nn.SSIM(), {'exception': ValueError}),
'desc_inputs': [Tensor(np.ones((1, 3, 8, 8)), mstype.float32),
Tensor(np.ones((1, 3, 8, 8)), mstype.float32)]}),
('StridedSlice_0', {
'block': (P.StridedSlice(), {'exception': ValueError}),
'desc_const': [(1, 2.2, 3), (3, 4, 5), (1, 1, 1)],
'desc_inputs': [[4, 5, 6, 7]]}),
('StridedSlice_1', {
'block': (P.StridedSlice(), {'exception': ValueError}),
'desc_const': [(1, 2, 3), (3, 4, 5), (1, 1)],
'desc_inputs': [[4, 5, 6, 7]]}),
('StridedSlice_2', {
'block': (P.StridedSlice(), {'exception': ValueError}),
'desc_const': [(1, 2, 3), (3, 4, 5), (1, 1, 0)],
'desc_inputs': [[4, 5, 6, 7]]}),
]

@ -25,6 +25,7 @@ from ....mindspore_test_framework.pipeline.forward.compile_forward \
import pipeline_for_compile_forward_ge_graph_for_case_by_case_config, \
pipeline_for_compile_forward_ge_graph_for_case_by_case_config_exception
class NetWorkSlicePositive(Cell):
def __init__(self):
super(NetWorkSlicePositive, self).__init__()
@ -1159,10 +1160,8 @@ def test_tensor_slice_reduce_out_of_bounds_neg():
input_tensor = Tensor(np.ones([6, 8, 10], np.int32))
net = NetWork()
with pytest.raises(ValueError) as ex:
with pytest.raises(ValueError):
net(input_tensor)
assert "For 'StridedSlice' the `begin[0]` should be an int and must greater or equal to -6, but got `-7`" in str(
ex.value)
def test_tensor_slice_reduce_out_of_bounds_positive():
@ -1177,6 +1176,5 @@ def test_tensor_slice_reduce_out_of_bounds_positive():
input_tensor = Tensor(np.ones([6, 8, 10], np.int32))
net = NetWork()
with pytest.raises(ValueError) as ex:
with pytest.raises(ValueError):
net(input_tensor)
assert "For 'StridedSlice' the `begin[0]` should be an int and must less than 6, but got `6`" in str(ex.value)

Loading…
Cancel
Save