From 017ff492afab70cf4c8ede9fc8117d70054305f5 Mon Sep 17 00:00:00 2001 From: jiangjinsheng Date: Mon, 22 Jun 2020 17:39:15 +0800 Subject: [PATCH] vm for MatrixDiag,MatrixDiagPart.MatrixSetDiag --- mindspore/ccsrc/kernel/tbe/tbe_adapter.cc | 5 +- mindspore/nn/layer/basic.py | 114 +++++++++++++- mindspore/ops/_op_impl/tbe/__init__.py | 3 + mindspore/ops/_op_impl/tbe/matrix_diag.py | 45 ++++++ .../ops/_op_impl/tbe/matrix_diag_part.py | 45 ++++++ mindspore/ops/_op_impl/tbe/matrix_set_diag.py | 46 ++++++ mindspore/ops/operations/_inner_ops.py | 141 ++++++++++++++++++ tests/ut/python/ops/test_nn_ops.py | 17 +++ tests/ut/python/ops/test_ops.py | 19 +++ 9 files changed, 433 insertions(+), 2 deletions(-) create mode 100644 mindspore/ops/_op_impl/tbe/matrix_diag.py create mode 100644 mindspore/ops/_op_impl/tbe/matrix_diag_part.py create mode 100644 mindspore/ops/_op_impl/tbe/matrix_set_diag.py diff --git a/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc b/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc index 9fd71e0e30..b7bad4fff8 100644 --- a/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc +++ b/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc @@ -124,7 +124,10 @@ static std::map tbe_func_adapter_map = { {"a_cos_grad", "acos_grad"}, {"histogram_fixed_width", "histogram_fixed_width_d"}, {"broadcast_to", "broadcast_to_d"}, - {"inplace_update", "inplace_update_d"}}; + {"inplace_update", "inplace_update_d"}, + {"matrix_diag", "matrix_diag_d"}, + {"matrix_diag_part", "matrix_diag_part_d"}, + {"matrix_set_diag", "matrix_set_diag_d"}}; void TbeAdapter::NormalizeFuncName(std::string *func_name) { if (func_name == nullptr) { diff --git a/mindspore/nn/layer/basic.py b/mindspore/nn/layer/basic.py index 548fbcec1e..b1d5af48c9 100644 --- a/mindspore/nn/layer/basic.py +++ b/mindspore/nn/layer/basic.py @@ -31,9 +31,12 @@ from mindspore.ops import _selected_ops from ..cell import Cell from .activation import get_activation from ..._checkparam import Validator as validator +from ..._checkparam import Rel -__all__ = ['Dropout', 'Flatten', 'Dense', 'ClipByNorm', 'Norm', 'OneHot', 'Pad', 'Unfold'] +__all__ = ['Dropout', 'Flatten', 'Dense', 'ClipByNorm', 'Norm', 'OneHot', 'Pad', 'Unfold', + 'MatrixDiag', 'MatrixDiagPart', 'MatrixSetDiag'] + class Dropout(Cell): r""" @@ -527,3 +530,112 @@ class Unfold(Cell): ret = self.extract_image_patches(x_transpose) ret_transpose = self.transpose(ret, self.format_NCHW) return ret_transpose + + +@constexpr +def _get_matrix_diag_assist(x_shape, x_dtype): + validator.check_integer("x rank", len(x_shape), 1, Rel.GE, "_get_matrix_diag_assist") + base_eye = np.eye(x_shape[-1], x_shape[-1]).reshape(-1) + assist = np.tile(base_eye, x_shape[:-1]).reshape(x_shape + (x_shape[-1],)) + return Tensor(assist, x_dtype) + + +@constexpr +def _get_matrix_diag_part_assist(x_shape, x_dtype): + validator.check_integer("x rank", len(x_shape), 2, Rel.GE, "_get_matrix_diag_part_assist") + base_eye = np.eye(x_shape[-2], x_shape[-1]).reshape(-1) + assist = np.tile(base_eye, x_shape[:-2]).reshape(x_shape) + return Tensor(assist, x_dtype) + + +class MatrixDiag(Cell): + """ + Returns a batched diagonal tensor with a given batched diagonal values. + + Inputs: + - **x** (Tensor) - The diagonal values. It can be of the following data types: + float32, float16, int32, int8, uint8. + + Outputs: + Tensor, same type as input `x`. The shape should be x.shape + (x.shape[-1], ). + + Examples: + >>> x = Tensor(np.array([1, -1]), mstype.float32) + >>> matrix_diag = nn.MatrixDiag() + >>> result = matrix_diag(x) + [[1. 0.] + [0. -1.]] + """ + def __init__(self): + super(MatrixDiag, self).__init__() + self.matrix_diag = inner.MatrixDiag() + self.dtype = P.DType() + + def construct(self, input_x): + x_shape = F.shape(input_x) + x_dtype = self.dtype(input_x) + assist = _get_matrix_diag_assist(x_shape, x_dtype) + out_matrix_diag = self.matrix_diag(input_x, assist) + return out_matrix_diag + + +class MatrixDiagPart(Cell): + r""" + Returns the batched diagonal part of a batched tensor. + + Inputs: + - **x** (Tensor) - The batched tensor. It can be of the following data types: + float32, float16, int32, int8, uint8. + + Outputs: + Tensor, same type as input `x`. The shape should be x.shape[:-2] + [min(x.shape[-2:])]. + + Examples: + >>> x = Tensor([[[-1, 0], [0, 1]], [-1, 0], [0, 1]], [[-1, 0], [0, 1]]], mindspore.float32) + >>> matrix_diag_part = nn.MatrixDiagPart() + >>> result = matrix_diag_part(x) + [[-1., 1.], [-1., 1.], [-1., 1.]] + """ + def __init__(self): + super(MatrixDiagPart, self).__init__() + self.matrix_diag_part = inner.MatrixDiagPart() + self.dtype = P.DType() + + def construct(self, input_x): + x_shape = F.shape(input_x) + x_dtype = self.dtype(input_x) + assist = _get_matrix_diag_part_assist(x_shape, x_dtype) + out_matrix_diag_part = self.matrix_diag_part(input_x, assist) + return out_matrix_diag_part + + +class MatrixSetDiag(Cell): + r""" + Modify the batched diagonal part of a batched tensor. + + Inputs: + - **x** (Tensor) - The batched tensor. It can be of the following data types: + float32, float16, int32, int8, uint8. + - **diagonal** (Tensor) - The diagonal values. + + Outputs: + Tensor, same type as input `x`. The shape same as `x`. + + Examples: + >>> x = Tensor([[[-1, 0], [0, 1]], [-1, 0], [0, 1]], [[-1, 0], [0, 1]]], mindspore.float32) + >>> diagonal = Tensor([[-1., 2.], [-1., 1.], [-1., 1.]], mindspore.float32) + >>> matrix_set_diag = nn.MatrixSetDiag() + >>> result = matrix_set_diag(x, diagonal) + [[[-1, 0], [0, 2]], [-1, 0], [0, 1]], [[-1, 0], [0, 1]]] + """ + def __init__(self): + super(MatrixSetDiag, self).__init__() + self.matrix_set_diag = inner.MatrixSetDiag() + self.dtype = P.DType() + + def construct(self, input_x, diagonal): + x_shape = F.shape(input_x) + x_dtype = self.dtype(input_x) + assist = _get_matrix_diag_part_assist(x_shape, x_dtype) + out_matrix_set_diag = self.matrix_set_diag(input_x, diagonal, assist) + return out_matrix_set_diag diff --git a/mindspore/ops/_op_impl/tbe/__init__.py b/mindspore/ops/_op_impl/tbe/__init__.py index 8351761935..631ec1bf44 100644 --- a/mindspore/ops/_op_impl/tbe/__init__.py +++ b/mindspore/ops/_op_impl/tbe/__init__.py @@ -264,3 +264,6 @@ from .inplace_update import _inplace_update_tbe from .splitv import _split_v_tbe from .in_top_k import _in_top_k_tbe from .lin_space import _lin_space_tbe +from .matrix_diag import _matrix_diag_tbe +from .matrix_diag_part import _matrix_diag_part_tbe +from .matrix_set_diag import _matrix_set_diag_tbe diff --git a/mindspore/ops/_op_impl/tbe/matrix_diag.py b/mindspore/ops/_op_impl/tbe/matrix_diag.py new file mode 100644 index 0000000000..9d080e34a2 --- /dev/null +++ b/mindspore/ops/_op_impl/tbe/matrix_diag.py @@ -0,0 +1,45 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""MatrixDiagD op""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +matrix_diag_d_op_info = TBERegOp("MatrixDiag") \ + .fusion_type("ELEMWISE") \ + .async_flag(False) \ + .binfile_name("matrix_diag_d.so") \ + .compute_cost(10) \ + .kernel_name("matrix_diag_d") \ + .partial_flag(True) \ + .input(0, "x", False, "required", "all") \ + .input(1, "assist", False, "required", "all") \ + .output(0, "y", False, "required", "all") \ + .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \ + .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ + .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \ + .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ + .dtype_format(DataType.I32_5HD, DataType.I32_5HD, DataType.I32_5HD) \ + .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ + .dtype_format(DataType.I8_5HD, DataType.I8_5HD, DataType.I8_5HD) \ + .dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.I8_Default) \ + .dtype_format(DataType.U8_5HD, DataType.U8_5HD, DataType.U8_5HD) \ + .dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.U8_Default) \ + .get_op_info() + + +@op_info_register(matrix_diag_d_op_info) +def _matrix_diag_tbe(): + """MatrixDiagD TBE register""" + return diff --git a/mindspore/ops/_op_impl/tbe/matrix_diag_part.py b/mindspore/ops/_op_impl/tbe/matrix_diag_part.py new file mode 100644 index 0000000000..1cb320bbce --- /dev/null +++ b/mindspore/ops/_op_impl/tbe/matrix_diag_part.py @@ -0,0 +1,45 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""MatrixDiagPartD op""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +matrix_diag_part_d_op_info = TBERegOp("MatrixDiagPart") \ + .fusion_type("ELEMWISE") \ + .async_flag(False) \ + .binfile_name("matrix_diag_part_d.so") \ + .compute_cost(10) \ + .kernel_name("matrix_diag_part_d") \ + .partial_flag(True) \ + .input(0, "x", False, "required", "all") \ + .input(1, "assist", False, "required", "all") \ + .output(0, "y", False, "required", "all") \ + .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \ + .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ + .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \ + .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ + .dtype_format(DataType.I32_5HD, DataType.I32_5HD, DataType.I32_5HD) \ + .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ + .dtype_format(DataType.I8_5HD, DataType.I8_5HD, DataType.I8_5HD) \ + .dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.I8_Default) \ + .dtype_format(DataType.U8_5HD, DataType.U8_5HD, DataType.U8_5HD) \ + .dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.U8_Default) \ + .get_op_info() + + +@op_info_register(matrix_diag_part_d_op_info) +def _matrix_diag_part_tbe(): + """MatrixDiagPartD TBE register""" + return diff --git a/mindspore/ops/_op_impl/tbe/matrix_set_diag.py b/mindspore/ops/_op_impl/tbe/matrix_set_diag.py new file mode 100644 index 0000000000..db0b460084 --- /dev/null +++ b/mindspore/ops/_op_impl/tbe/matrix_set_diag.py @@ -0,0 +1,46 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""MatrixSetDiagD op""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +matrix_diag_d_op_info = TBERegOp("MatrixSetDiag") \ + .fusion_type("ELEMWISE") \ + .async_flag(False) \ + .binfile_name("matrix_diag_d.so") \ + .compute_cost(10) \ + .kernel_name("matrix_diag_d") \ + .partial_flag(True) \ + .input(0, "x", False, "required", "all") \ + .input(1, "diagonal", False, "required", "all") \ + .input(2, "assist", False, "required", "all") \ + .output(0, "y", False, "required", "all") \ + .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \ + .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ + .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \ + .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ + .dtype_format(DataType.I32_5HD, DataType.I32_5HD, DataType.I32_5HD, DataType.I32_5HD) \ + .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ + .dtype_format(DataType.I8_5HD, DataType.I8_5HD, DataType.I8_5HD, DataType.I8_5HD) \ + .dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.I8_Default, DataType.I8_Default) \ + .dtype_format(DataType.U8_5HD, DataType.U8_5HD, DataType.U8_5HD, DataType.U8_5HD) \ + .dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.U8_Default, DataType.U8_Default) \ + .get_op_info() + + +@op_info_register(matrix_diag_d_op_info) +def _matrix_set_diag_tbe(): + """MatrixSetDiagD TBE register""" + return diff --git a/mindspore/ops/operations/_inner_ops.py b/mindspore/ops/operations/_inner_ops.py index e89a104623..49834fc168 100644 --- a/mindspore/ops/operations/_inner_ops.py +++ b/mindspore/ops/operations/_inner_ops.py @@ -367,3 +367,144 @@ class LinSpace(PrimitiveWithInfer): args = {"assist": assist, "start": start, "stop": stop} validator.check_tensor_type_same(args, (mstype.float32,), self.name) return assist + + +class MatrixDiag(PrimitiveWithInfer): + """ + Returns a batched diagonal tensor with a given batched diagonal values. + + Inputs: + - **x** (Tensor) - A tensor which to be element-wise multi by `assist`. It can be of the following data types: + float32, float16, int32, int8, uint8. + - **assist** (Tensor) - A eye tensor of the same type as `x`. It's rank must greater than or equal to 2 and + it's last dimension must equal to the second to last dimension. + + Outputs: + Tensor, has the same type and shape as input `assist`. + + Examples: + >>> x = Tensor(np.array([1, -1]), mstype.float32) + >>> assist = Tensor(np.arange(-12, 0).reshape(3, 2, 2), mindspore.float32) + >>> matrix_diag = P.MatrixDiag() + >>> result = matrix_diag(x, assist) + [[[-12. 11.] + [-10. 9.]] + [[ -8. 7.] + [ -6. 5.]] + [[ -4. 3.] + [ -2. 1.]]] + """ + + @prim_attr_register + def __init__(self): + """init MatrixDiag""" + + def infer_dtype(self, x_dtype, assist_dtype): + valid_type = [mstype.float16, mstype.float32, mstype.int32, mstype.int8, mstype.uint8] + args = {"x": x_dtype, "assist": assist_dtype} + validator.check_tensor_type_same(args, valid_type, self.name) + return x_dtype + + def infer_shape(self, x_shape, assist_shape): + validator.check_integer("assist rank", len(assist_shape), 2, Rel.GE, self.name) + validator.check('rank of x', len(x_shape)+1, + 'rank of assist', len(assist_shape), Rel.LE, self.name) + validator.check('assist\'s penultimate dimension', assist_shape[-2], 'assist\'s last dimension', + assist_shape[-1], Rel.EQ, self.name) + + r_end_dim = -len(x_shape) + r_idx = -1 + while r_idx >= r_end_dim: + if x_shape[r_idx] != 1: + validator.check("reverse x dim %d" % r_idx, x_shape[r_idx], "reverse assist dim %d" % + assist_shape[r_idx-1], assist_shape[r_idx-1], Rel.EQ, self.name) + r_idx = r_idx - 1 + + return assist_shape + + +class MatrixDiagPart(PrimitiveWithInfer): + r""" + Returns the batched diagonal part of a batched tensor. + + Inputs: + - **x** (Tensor) - The batched tensor. It can be of the following data types: + float32, float16, int32, int8, uint8. + - **assist** (Tensor) - A eye tensor of the same type as `x`. With shape same as `x`. + + Outputs: + Tensor, data type same as input `x`. The shape should be x.shape[:-2] + [min(x.shape[-2:])]. + + Examples: + >>> x = Tensor([[[-1, 0], [0, 1]], [-1, 0], [0, 1]], [[-1, 0], [0, 1]]], mindspore.float32) + >>> assist = Tensor(np.arange(-12, 0).reshape(3, 2, 2), mindspore.float32) + >>> matrix_diag_part = P.MatrixDiagPart() + >>> result = matrix_diag_part(x, assist) + [[12., -9.], [8., -5.], [4., -1.]] + """ + + @prim_attr_register + def __init__(self): + """init MatrixDiagPart""" + + def infer_dtype(self, x_dtype, assist_dtype): + valid_type = [mstype.float16, mstype.float32, mstype.int32, mstype.int8, mstype.uint8] + args = {"x": x_dtype, "assist": assist_dtype} + validator.check_tensor_type_same(args, valid_type, self.name) + return x_dtype + + def infer_shape(self, x_shape, assist_shape): + validator.check_integer("x rank", len(x_shape), 2, Rel.GE, self.name) + validator.check("x shape", x_shape, "assist shape", assist_shape, Rel.EQ, self.name) + + if assist_shape[-2] < assist_shape[-1]: + out_shape = assist_shape[:-1] + else: + out_shape = assist_shape[:-2] + assist_shape[-1:] + return out_shape + + +class MatrixSetDiag(PrimitiveWithInfer): + r""" + Modify the batched diagonal part of a batched tensor. + + Inputs: + - **x** (Tensor) - The batched tensor. It can be of the following data types: + float32, float16, int32, int8, uint8. + - **assist** (Tensor) - A eye tensor of the same type as `x`. With shape same as `x`. + - **diagonal** (Tensor) - The diagonal values. + + Outputs: + Tensor, data type same as input `x`. The shape same as `x`. + + Examples: + >>> x = Tensor([[[-1, 0], [0, 1]], [-1, 0], [0, 1]], [[-1, 0], [0, 1]]], mindspore.float32) + >>> diagonal = Tensor([[-1., 2.], [-1., 1.], [-1., 1.]], mindspore.float32) + >>> matrix_set_diag = P.MatrixSetDiag() + >>> result = matrix_set_diag(x, diagonal) + [[[-1, 0], [0, 2]], [-1, 0], [0, 1]], [[-1, 0], [0, 1]]] + + """ + + @prim_attr_register + def __init__(self): + """init MatrixSetDiag""" + + def infer_dtype(self, x_dtype, diagonal_dtype, assist_dtype): + valid_type = [mstype.float16, mstype.float32, mstype.int32, mstype.int8, mstype.uint8] + args = {"x": x_dtype, "diagonal": diagonal_dtype, "assist": assist_dtype} + validator.check_tensor_type_same(args, valid_type, self.name) + return x_dtype + + def infer_shape(self, x_shape, diagonal_shape, assist_shape): + validator.check_integer("x rank", len(x_shape), 2, Rel.GE, self.name) + validator.check("x shape", x_shape, "assist shape", assist_shape, Rel.EQ, self.name) + + if x_shape[-2] < x_shape[-1]: + validator.check("x shape excluding the last dimension", x_shape[:-1], "diagnoal shape", + diagonal_shape, Rel.EQ, self.name) + else: + validator.check("x shape excluding the second to last dimension", x_shape[:-2]+x_shape[-1:], + "diagonal shape", diagonal_shape, Rel.EQ, self.name) + + return assist_shape diff --git a/tests/ut/python/ops/test_nn_ops.py b/tests/ut/python/ops/test_nn_ops.py index ab5eed0cd1..e950707234 100644 --- a/tests/ut/python/ops/test_nn_ops.py +++ b/tests/ut/python/ops/test_nn_ops.py @@ -370,6 +370,7 @@ def test_conv2d_same_primitive(): super(Conv2DSameNet, self).__init__() self.conv1 = nn.Conv2d(16, 64, (1, 41), (1, 4), "same", 0, 1, has_bias=True) self.conv2 = nn.Conv2d(16, 64, (1, 41), (1, 4), "same", 0, 1, has_bias=True) + def construct(self, x, y): r1 = self.conv1(x) r2 = self.conv2(y) @@ -576,6 +577,22 @@ test_cases = [ Tensor(np.ones([1, 3, 4, 4], np.float32)), Tensor(np.ones(3, np.float32))], }), + ('MatrixDiag', { + 'block': nn.MatrixDiag(), + 'desc_inputs': [Tensor(np.array([1, 2, 3]).astype(np.float32))], + 'skip': ['backward'] + }), + ('MatrixDiagPart', { + 'block': nn.MatrixDiagPart(), + 'desc_inputs': [Tensor(np.array([[1, 2, 3], [4, 5, 6]]).astype(np.float32))], + 'skip': ['backward'] + }), + ('MatrixSetDiag', { + 'block': nn.MatrixSetDiag(), + 'desc_inputs': [Tensor(np.array([[1, 2, 3], [4, 5, 6]]).astype(np.float32)), + Tensor(np.array([1, 2]).astype(np.float32))], + 'skip': ['backward'] + }), ] test_cases_for_verify_exception = [ diff --git a/tests/ut/python/ops/test_ops.py b/tests/ut/python/ops/test_ops.py index cdbb818454..cf6a6705ab 100755 --- a/tests/ut/python/ops/test_ops.py +++ b/tests/ut/python/ops/test_ops.py @@ -1612,6 +1612,25 @@ test_case_array_ops = [ Tensor(5, mstype.int32)], 'skip': ['backward'], }), + ('MatrixDiag', { + 'block': inner.MatrixDiag(), + 'desc_inputs': [Tensor(np.array([1, -1]), mstype.float32), + Tensor(np.arange(-12, 0).reshape(3, 2, 2), mstype.float32)], + 'skip': ['backward'], + }), + ('MatrixDiagPart', { + 'block': inner.MatrixDiagPart(), + 'desc_inputs': [Tensor(np.arange(12).reshape(3, 2, 2), mstype.float32), + Tensor(np.arange(-12, 0).reshape(3, 2, 2), mstype.float32)], + 'skip': ['backward'], + }), + ('MatrixSetDiag', { + 'block': inner.MatrixSetDiag(), + 'desc_inputs': [Tensor(np.arange(12).reshape(3, 2, 2), mstype.float32), + Tensor(np.arange(6).reshape(3, 2), mstype.float32), + Tensor(np.arange(-12, 0).reshape(3, 2, 2), mstype.float32)], + 'skip': ['backward'], + }), ] test_case_other_ops = [