|
|
|
@ -35,25 +35,6 @@ from ....mindspore_test_framework.pipeline.gradient.compile_gradient \
|
|
|
|
|
import pipeline_for_compile_grad_ge_graph_for_case_by_case_config
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_tensor_scatter_update():
|
|
|
|
|
class TensorScatterUpdateNet(nn.Cell):
|
|
|
|
|
"""TensorScatterUpdate net definition"""
|
|
|
|
|
|
|
|
|
|
def __init__(self):
|
|
|
|
|
super(TensorScatterUpdateNet, self).__init__()
|
|
|
|
|
self.tensor_scatter_update = P.TensorScatterUpdate()
|
|
|
|
|
|
|
|
|
|
def construct(self, x, i, u):
|
|
|
|
|
out = self.tensor_scatter_update(x, i, u)
|
|
|
|
|
return out
|
|
|
|
|
net = TensorScatterUpdateNet()
|
|
|
|
|
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
|
|
|
|
|
x = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)), mstype.float32)
|
|
|
|
|
indices = Tensor(np.array([[0, 0], [1, 1]], np.int32))
|
|
|
|
|
updates = Tensor(np.ones([2, 5], np.float32))
|
|
|
|
|
net(x, indices, updates)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class InputBackward(nn.Cell):
|
|
|
|
|
def __init__(self, network):
|
|
|
|
|
super(InputBackward, self).__init__()
|
|
|
|
@ -446,6 +427,7 @@ class SparseApplyAdagradNet(nn.Cell):
|
|
|
|
|
out = self.sparse_apply_adagrad(self.var, self.accum, grad, indices)
|
|
|
|
|
return out
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ApplyRMSNet(nn.Cell):
|
|
|
|
|
def __init__(self):
|
|
|
|
|
super(ApplyRMSNet, self).__init__()
|
|
|
|
@ -496,6 +478,60 @@ class NormalNet(nn.Cell):
|
|
|
|
|
return out
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class StridedSliceNet(nn.Cell):
|
|
|
|
|
def __init__(self):
|
|
|
|
|
super(StridedSliceNet, self).__init__()
|
|
|
|
|
self.begins = (1, 2, 3, 2, 1)
|
|
|
|
|
self.ends = (5, 6, 7, 8, 9)
|
|
|
|
|
self.strides = (1, 2, 3, 2, 1)
|
|
|
|
|
self.strided_slice_0 = P.StridedSlice(begin_mask=3, end_mask=5, ellipsis_mask=4,
|
|
|
|
|
shrink_axis_mask=2, new_axis_mask=8)
|
|
|
|
|
self.strided_slice_1 = P.StridedSlice(begin_mask=5, end_mask=2, ellipsis_mask=2,
|
|
|
|
|
shrink_axis_mask=6, new_axis_mask=10)
|
|
|
|
|
self.strided_slice_2 = P.StridedSlice(begin_mask=3, end_mask=3, ellipsis_mask=4,
|
|
|
|
|
shrink_axis_mask=5, new_axis_mask=13)
|
|
|
|
|
self.strided_slice_3 = P.StridedSlice(begin_mask=0, end_mask=0, ellipsis_mask=4,
|
|
|
|
|
shrink_axis_mask=12, new_axis_mask=15)
|
|
|
|
|
self.const_0 = Tensor(np.ones([6, 8, 9, 1, 8], np.float32))
|
|
|
|
|
self.const_1 = Tensor(np.ones([5, 7, 8, 1, 8], np.float32))
|
|
|
|
|
self.const_2 = Tensor(np.ones([1, 3, 7, 8, 9, 1, 8], np.float32))
|
|
|
|
|
self.const_3 = Tensor(np.ones([1, 1, 6, 7, 8, 9, 1, 8], np.float32))
|
|
|
|
|
|
|
|
|
|
def construct(self, x):
|
|
|
|
|
out_0 = self.strided_slice_0(x, self.begins, self.ends, self.strides) + self.const_0
|
|
|
|
|
out_1 = self.strided_slice_1(x, self.begins, self.ends, self.strides) + self.const_1
|
|
|
|
|
out_2 = self.strided_slice_2(x, self.begins, self.ends, self.strides) + self.const_2
|
|
|
|
|
out_3 = self.strided_slice_3(x, self.begins, self.ends, self.strides) + self.const_3
|
|
|
|
|
return out_0, out_1, out_2, out_3
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_strided_slice_const():
|
|
|
|
|
class StridedSLiceConstNet(nn.Cell):
|
|
|
|
|
"""StridedSLiceConstNet net definition"""
|
|
|
|
|
|
|
|
|
|
def __init__(self):
|
|
|
|
|
super(StridedSLiceConstNet, self).__init__()
|
|
|
|
|
self.begins = (0, 2, -5, 2, 1)
|
|
|
|
|
self.ends = (0, 6, 9, 8, 9)
|
|
|
|
|
self.strides = (1, 2, 1, 2, 1)
|
|
|
|
|
self.strided_slice = P.StridedSlice(begin_mask=2,
|
|
|
|
|
end_mask=6,
|
|
|
|
|
ellipsis_mask=4,
|
|
|
|
|
shrink_axis_mask=6,
|
|
|
|
|
new_axis_mask=18)
|
|
|
|
|
|
|
|
|
|
def construct(self, x):
|
|
|
|
|
out = self.strided_slice(x, self.begins, self.ends, self.strides)
|
|
|
|
|
return out
|
|
|
|
|
|
|
|
|
|
net = StridedSLiceConstNet()
|
|
|
|
|
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
|
|
|
|
|
x = Tensor(np.ones([6, 7, 8, 9, 10]), mstype.float32)
|
|
|
|
|
ret = net(x)
|
|
|
|
|
assert ret.shape == (0, 1, 7, 8, 9, 3, 1)
|
|
|
|
|
assert (ret.asnumpy() == np.array([], np.float32).reshape([0, 1, 7, 8, 9, 3, 1])).all()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
test_case_math_ops = [
|
|
|
|
|
('BitwiseAnd', {
|
|
|
|
|
'block': P.BitwiseAnd(),
|
|
|
|
@ -1366,6 +1402,10 @@ test_case_nn_ops = [
|
|
|
|
|
'desc_inputs': [Tensor(np.array([[128, 32, 32, 64], [128, 32, 32, 64]]).astype(np.float16))],
|
|
|
|
|
'desc_bprop': [Tensor(np.array([[128, 32, 32, 64], [128, 32, 32, 64]]).astype(np.float16))],
|
|
|
|
|
'skip': ['backward']}),
|
|
|
|
|
('StridedSliceNet', {
|
|
|
|
|
'block': StridedSliceNet(),
|
|
|
|
|
'desc_inputs': [[6, 7, 8, 9, 10]],
|
|
|
|
|
'skip': ['backward']}),
|
|
|
|
|
('OneHot', {
|
|
|
|
|
'block': P.OneHot(),
|
|
|
|
|
'desc_const': [3, Tensor(1.0, mstype.float32), Tensor(0.0, mstype.float32)],
|
|
|
|
@ -1763,7 +1803,7 @@ test_case_other_ops = [
|
|
|
|
|
'desc_bprop': [([3, 3], {'dtype': np.int32})]}),
|
|
|
|
|
('TensorScatterUpdate', {
|
|
|
|
|
'block': P.TensorScatterUpdate(),
|
|
|
|
|
'desc_inputs': (Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)), mstype.float32),
|
|
|
|
|
'desc_inputs': (Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)), mstype.float32),
|
|
|
|
|
Tensor(np.array([[0, 1], [1, 2]], np.int32)),
|
|
|
|
|
Tensor(np.ones([2, 5], np.float32) * 99)),
|
|
|
|
|
'desc_bprop': [([3, 4, 5], {'dtype': np.float32})]}),
|
|
|
|
@ -1930,11 +1970,10 @@ test_case_other_ops = [
|
|
|
|
|
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
test_case_quant_ops = [
|
|
|
|
|
('AscendQuant_1', {
|
|
|
|
|
'block': inner.AscendQuant(0.5, 0.0, False, "Round"),
|
|
|
|
|
'desc_inputs': [Tensor(np.random.rand(1,2,4,4), mstype.float32)],
|
|
|
|
|
'desc_inputs': [Tensor(np.random.rand(1, 2, 4, 4), mstype.float32)],
|
|
|
|
|
'skip': ['backward']}),
|
|
|
|
|
('AscendQuant_2', {
|
|
|
|
|
'block': inner.AscendQuant(80.0, 10.0, True, "Round"),
|
|
|
|
@ -2027,6 +2066,18 @@ raise_set = [
|
|
|
|
|
'block': (nn.SSIM(), {'exception': ValueError}),
|
|
|
|
|
'desc_inputs': [Tensor(np.ones((1, 3, 8, 8)), mstype.float32),
|
|
|
|
|
Tensor(np.ones((1, 3, 8, 8)), mstype.float32)]}),
|
|
|
|
|
('StridedSlice_0', {
|
|
|
|
|
'block': (P.StridedSlice(), {'exception': ValueError}),
|
|
|
|
|
'desc_const': [(1, 2.2, 3), (3, 4, 5), (1, 1, 1)],
|
|
|
|
|
'desc_inputs': [[4, 5, 6, 7]]}),
|
|
|
|
|
('StridedSlice_1', {
|
|
|
|
|
'block': (P.StridedSlice(), {'exception': ValueError}),
|
|
|
|
|
'desc_const': [(1, 2, 3), (3, 4, 5), (1, 1)],
|
|
|
|
|
'desc_inputs': [[4, 5, 6, 7]]}),
|
|
|
|
|
('StridedSlice_2', {
|
|
|
|
|
'block': (P.StridedSlice(), {'exception': ValueError}),
|
|
|
|
|
'desc_const': [(1, 2, 3), (3, 4, 5), (1, 1, 0)],
|
|
|
|
|
'desc_inputs': [[4, 5, 6, 7]]}),
|
|
|
|
|
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|