|
|
|
@ -242,8 +242,14 @@ AbstractBasePtr InferImplBatchNormGrad(const AnalysisEnginePtr &, const Primitiv
|
|
|
|
|
AbstractBasePtr InferImplReluGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
|
|
|
|
const AbstractBasePtrList &args_spec_list) {
|
|
|
|
|
// Inputs: two tensors(y_backprop, x).
|
|
|
|
|
CheckArgsSize(primitive->name(), args_spec_list, 2);
|
|
|
|
|
return args_spec_list[1]->Broaden();
|
|
|
|
|
const std::string op_name = primitive->name();
|
|
|
|
|
CheckArgsSize(op_name, args_spec_list, 2);
|
|
|
|
|
auto dout = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
|
|
|
|
|
auto out = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
|
|
|
|
|
(void)CheckDtypeSame(op_name, out, dout);
|
|
|
|
|
(void)CheckShapeSame(op_name, out, dout);
|
|
|
|
|
|
|
|
|
|
return out->Broaden();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
AbstractBasePtr InferImplConv2DBackpropInput(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
|
|
|
|