add shape type check for relugrad

pull/5995/head
panyifeng 5 years ago
parent b7425d3e0c
commit 90943f799a

@ -242,8 +242,14 @@ AbstractBasePtr InferImplBatchNormGrad(const AnalysisEnginePtr &, const Primitiv
AbstractBasePtr InferImplReluGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// Inputs: two tensors(y_backprop, x).
CheckArgsSize(primitive->name(), args_spec_list, 2);
return args_spec_list[1]->Broaden();
const std::string op_name = primitive->name();
CheckArgsSize(op_name, args_spec_list, 2);
auto dout = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
auto out = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
(void)CheckDtypeSame(op_name, out, dout);
(void)CheckShapeSame(op_name, out, dout);
return out->Broaden();
}
AbstractBasePtr InferImplConv2DBackpropInput(const AnalysisEnginePtr &, const PrimitivePtr &primitive,

Loading…
Cancel
Save