diff --git a/mindspore/ops/_grad/grad_math_ops.py b/mindspore/ops/_grad/grad_math_ops.py index e6be460752..715f30c6f2 100755 --- a/mindspore/ops/_grad/grad_math_ops.py +++ b/mindspore/ops/_grad/grad_math_ops.py @@ -165,6 +165,22 @@ def get_bprop_tensor_add(self): return bprop +@bprop_getters.register(P.MatrixInverse) +def get_bprop_matrix_inverse(self): + """Grad definition for `MatrixInverse` operation.""" + batchmatmul_a = P.math_ops.BatchMatMul(transpose_a=True) + batchmatmul_b = P.math_ops.BatchMatMul(transpose_b=True) + neg = P.Neg() + + def bprop(x, out, dout): + dx = batchmatmul_b(dout, out) + dx = batchmatmul_a(out, dx) + dx = neg(dx) + return dx + + return bprop + + @bprop_getters.register(P.Neg) def get_bprop_neg(self): """Grad definition for `Neg` operation.""" diff --git a/mindspore/ops/operations/math_ops.py b/mindspore/ops/operations/math_ops.py index f281282677..fc5714e39c 100644 --- a/mindspore/ops/operations/math_ops.py +++ b/mindspore/ops/operations/math_ops.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd +# Copyright 2020-2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -4136,6 +4136,9 @@ class MatrixInverse(PrimitiveWithInfer): Returns the inverse of the input matrix. If the matrix is irreversible, an error may be reported or an unknown result may be returned + Note: + The parameter 'adjoint' is only supporting False right now. Because complex number is not supported at present. + Args: adjoint (bool) : An optional bool. Default: False. @@ -4146,6 +4149,9 @@ class MatrixInverse(PrimitiveWithInfer): Outputs: Tensor, has the same type and shape as input `x`. + Supported Platforms: + ``GPU`` + Examples: >>> mindspore.set_seed(1) >>> x = Tensor(np.random.uniform(-2, 2, (2, 2, 2)), mindspore.float32) @@ -4161,7 +4167,7 @@ class MatrixInverse(PrimitiveWithInfer): @prim_attr_register def __init__(self, adjoint=False): """Initialize MatrixInverse""" - validator.check_value_type("adjoint", adjoint, [bool], self.name) + validator.check_type_name("adjoint", adjoint, False, self.name) self.adjoint = adjoint def infer_dtype(self, x_dtype):