!11978 Add grad impl for op MatrixInverse

From: @yuan_shen_zhou
Reviewed-by: @liangchenghui,@wuxuejian
Signed-off-by: @liangchenghui
pull/11978/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit c2d120e714

@ -165,6 +165,22 @@ def get_bprop_tensor_add(self):
return bprop return bprop
@bprop_getters.register(P.MatrixInverse)
def get_bprop_matrix_inverse(self):
"""Grad definition for `MatrixInverse` operation."""
batchmatmul_a = P.math_ops.BatchMatMul(transpose_a=True)
batchmatmul_b = P.math_ops.BatchMatMul(transpose_b=True)
neg = P.Neg()
def bprop(x, out, dout):
dx = batchmatmul_b(dout, out)
dx = batchmatmul_a(out, dx)
dx = neg(dx)
return dx
return bprop
@bprop_getters.register(P.Neg) @bprop_getters.register(P.Neg)
def get_bprop_neg(self): def get_bprop_neg(self):
"""Grad definition for `Neg` operation.""" """Grad definition for `Neg` operation."""

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd # Copyright 2020-2021 Huawei Technologies Co., Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -4136,6 +4136,9 @@ class MatrixInverse(PrimitiveWithInfer):
Returns the inverse of the input matrix. If the matrix is irreversible, an error may be reported or an unknown Returns the inverse of the input matrix. If the matrix is irreversible, an error may be reported or an unknown
result may be returned result may be returned
Note:
The parameter 'adjoint' is only supporting False right now. Because complex number is not supported at present.
Args: Args:
adjoint (bool) : An optional bool. Default: False. adjoint (bool) : An optional bool. Default: False.
@ -4146,6 +4149,9 @@ class MatrixInverse(PrimitiveWithInfer):
Outputs: Outputs:
Tensor, has the same type and shape as input `x`. Tensor, has the same type and shape as input `x`.
Supported Platforms:
``GPU``
Examples: Examples:
>>> mindspore.set_seed(1) >>> mindspore.set_seed(1)
>>> x = Tensor(np.random.uniform(-2, 2, (2, 2, 2)), mindspore.float32) >>> x = Tensor(np.random.uniform(-2, 2, (2, 2, 2)), mindspore.float32)
@ -4161,7 +4167,7 @@ class MatrixInverse(PrimitiveWithInfer):
@prim_attr_register @prim_attr_register
def __init__(self, adjoint=False): def __init__(self, adjoint=False):
"""Initialize MatrixInverse""" """Initialize MatrixInverse"""
validator.check_value_type("adjoint", adjoint, [bool], self.name) validator.check_type_name("adjoint", adjoint, False, self.name)
self.adjoint = adjoint self.adjoint = adjoint
def infer_dtype(self, x_dtype): def infer_dtype(self, x_dtype):

Loading…
Cancel
Save