|
|
|
@ -252,6 +252,21 @@ AbstractBasePtr InferImplReluGrad(const AnalysisEnginePtr &, const PrimitivePtr
|
|
|
|
|
return out->Broaden();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
AbstractBasePtr InferImplFusedSparseAdam(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
|
|
|
|
const AbstractBasePtrList &args_spec_list) {
|
|
|
|
|
// the output is useless, so we dont have to focus on the output shape
|
|
|
|
|
MS_EXCEPTION_IF_NULL(args_spec_list[1]);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(args_spec_list[2]);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(args_spec_list[3]);
|
|
|
|
|
|
|
|
|
|
auto dx = args_spec_list[1]->Broaden();
|
|
|
|
|
auto dscale = args_spec_list[2]->Broaden();
|
|
|
|
|
auto dbias = args_spec_list[3]->Broaden();
|
|
|
|
|
|
|
|
|
|
AbstractBasePtrList rets = {dx, dscale, dbias};
|
|
|
|
|
return std::make_shared<AbstractTuple>(rets);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
AbstractBasePtr InferImplConv2DBackpropInput(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
|
|
|
|
const AbstractBasePtrList &args_spec_list) {
|
|
|
|
|
// Inputs: three tensors(doutput, input, filters).
|
|
|
|
|