diff --git a/mindspore/core/abstract/prim_nn.cc b/mindspore/core/abstract/prim_nn.cc index ad8691f087..039f7e35ca 100644 --- a/mindspore/core/abstract/prim_nn.cc +++ b/mindspore/core/abstract/prim_nn.cc @@ -242,8 +242,14 @@ AbstractBasePtr InferImplBatchNormGrad(const AnalysisEnginePtr &, const Primitiv AbstractBasePtr InferImplReluGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list) { // Inputs: two tensors(y_backprop, x). - CheckArgsSize(primitive->name(), args_spec_list, 2); - return args_spec_list[1]->Broaden(); + const std::string op_name = primitive->name(); + CheckArgsSize(op_name, args_spec_list, 2); + auto dout = CheckArg(op_name, args_spec_list, 0); + auto out = CheckArg(op_name, args_spec_list, 1); + (void)CheckDtypeSame(op_name, out, dout); + (void)CheckShapeSame(op_name, out, dout); + + return out->Broaden(); } AbstractBasePtr InferImplConv2DBackpropInput(const AnalysisEnginePtr &, const PrimitivePtr &primitive,