From cb391ba234e411234ea33e85a60df546efff8e09 Mon Sep 17 00:00:00 2001 From: yanzhenxiang2020 Date: Fri, 27 Nov 2020 18:45:59 +0800 Subject: [PATCH] fix identity grad --- mindspore/ops/_grad/grad_array_ops.py | 10 ++++++++++ mindspore/ops/operations/array_ops.py | 14 ++++++++------ 2 files changed, 18 insertions(+), 6 deletions(-) diff --git a/mindspore/ops/_grad/grad_array_ops.py b/mindspore/ops/_grad/grad_array_ops.py index 7788613452..4ef13eec44 100644 --- a/mindspore/ops/_grad/grad_array_ops.py +++ b/mindspore/ops/_grad/grad_array_ops.py @@ -456,6 +456,16 @@ def get_bprop_sparse_gather_v2(self): return bprop +@bprop_getters.register(P.Identity) +def get_bprop_identity(self): + """Generate bprop for Identity""" + + def bprop(x, out, dout): + return (dout,) + + return bprop + + @bprop_getters.register(inner.Range) def get_bprop_range(self): """Generate bprop for Range""" diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py index a51729fa0e..86cae4b783 100644 --- a/mindspore/ops/operations/array_ops.py +++ b/mindspore/ops/operations/array_ops.py @@ -4104,14 +4104,14 @@ class Meshgrid(PrimitiveWithInfer): Tensors, A Tuple of N N-D Tensor objects. Examples: - >>> x = np.array([1, 2, 3, 4]).astype(np.int32) - >>> y = np.array([5, 6, 7]).astype(np.int32) - >>> z = np.array([8, 9, 0, 1, 2]).astype(np.int32) + >>> x = Tensor(np.array([1, 2, 3, 4]).astype(np.int32)) + >>> y = Tensor(np.array([5, 6, 7]).astype(np.int32)) + >>> z = Tensor(np.array([8, 9, 0, 1, 2]).astype(np.int32)) >>> inputs = (x, y, z) >>> meshgrid = ops.Meshgrid(indexing="xy") >>> output = meshgrid(inputs) >>> print(output) - (Tensor(shape=[3, 4, 6], dtype=UInt32, value= + (Tensor(shape=[3, 4, 6], dtype=Int32, value= [[[1, 1, 1, 1, 1], [2, 2, 2, 2, 2], [3, 3, 3, 3, 3], @@ -4124,7 +4124,7 @@ class Meshgrid(PrimitiveWithInfer): [2, 2, 2, 2, 2], [3, 3, 3, 3, 3], [4, 4, 4, 4, 4]]]), - Tensor(shape=[3, 4, 6], dtype=UInt32, value= + Tensor(shape=[3, 4, 6], dtype=Int32, value= [[[5, 5, 5, 5, 5], [5, 5, 5, 5, 5], [5, 5, 5, 5, 5], @@ -4137,7 +4137,7 @@ class Meshgrid(PrimitiveWithInfer): [7, 7, 7, 7, 7], [7, 7, 7, 7, 7], [7, 7, 7, 7, 7]]]), - Tensor(shape=[3, 4, 6], dtype=UInt32, value= + Tensor(shape=[3, 4, 6], dtype=Int32, value= [[[8, 9, 0, 1, 2], [8, 9, 0, 1, 2], [8, 9, 0, 1, 2], @@ -4611,6 +4611,8 @@ class Identity(PrimitiveWithInfer): """Initialize identity""" def __infer__(self, x): + validator.check_subclass("x", x['dtype'], mstype.tensor, self.name) + validator.check_tensor_dtype_valid('x', x['dtype'], mstype.number_type + (mstype.bool_,), self.name) out = {'shape': x['shape'], 'dtype': x['dtype'], 'value': None}