From 9c33d3ffe0e18d90573f8f86fb264eb297ff75ca Mon Sep 17 00:00:00 2001 From: liubuyu Date: Mon, 14 Dec 2020 14:39:01 +0800 Subject: [PATCH] add infer func for fused_sparse_adam --- mindspore/core/abstract/infer_functions.h | 2 ++ mindspore/core/abstract/prim_nn.cc | 15 +++++++++++++++ mindspore/core/abstract/primitive_infer_map.cc | 1 + mindspore/core/base/core_ops.h | 1 + 4 files changed, 19 insertions(+) diff --git a/mindspore/core/abstract/infer_functions.h b/mindspore/core/abstract/infer_functions.h index e2642da41f..eb2fb97107 100644 --- a/mindspore/core/abstract/infer_functions.h +++ b/mindspore/core/abstract/infer_functions.h @@ -49,6 +49,8 @@ AbstractBasePtr InferImplPoolingGrad(const AnalysisEnginePtr &, const PrimitiveP const AbstractBasePtrList &args_spec_list); AbstractBasePtr InferImplFusedBatchNorm(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplFusedSparseAdam(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); AbstractBasePtr InferImplFusedBatchNormGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list); AbstractBasePtr InferImplBatchNormGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, diff --git a/mindspore/core/abstract/prim_nn.cc b/mindspore/core/abstract/prim_nn.cc index 40d25a836a..4f7cced440 100644 --- a/mindspore/core/abstract/prim_nn.cc +++ b/mindspore/core/abstract/prim_nn.cc @@ -252,6 +252,21 @@ AbstractBasePtr InferImplReluGrad(const AnalysisEnginePtr &, const PrimitivePtr return out->Broaden(); } +AbstractBasePtr InferImplFusedSparseAdam(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + // the output is useless, so we dont have to focus on the output shape + MS_EXCEPTION_IF_NULL(args_spec_list[1]); + MS_EXCEPTION_IF_NULL(args_spec_list[2]); + MS_EXCEPTION_IF_NULL(args_spec_list[3]); + + auto dx = args_spec_list[1]->Broaden(); + auto dscale = args_spec_list[2]->Broaden(); + auto dbias = args_spec_list[3]->Broaden(); + + AbstractBasePtrList rets = {dx, dscale, dbias}; + return std::make_shared(rets); +} + AbstractBasePtr InferImplConv2DBackpropInput(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list) { // Inputs: three tensors(doutput, input, filters). diff --git a/mindspore/core/abstract/primitive_infer_map.cc b/mindspore/core/abstract/primitive_infer_map.cc index baacb9ce91..c45cbd08d4 100644 --- a/mindspore/core/abstract/primitive_infer_map.cc +++ b/mindspore/core/abstract/primitive_infer_map.cc @@ -101,6 +101,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() { {prim::kPrimPooling, {InferImplPooling, true}}, {prim::kPrimPoolingGrad, {InferImplPoolingGrad, true}}, {prim::kPrimFusedBatchNorm, {InferImplFusedBatchNorm, true}}, + {prim::kPrimFusedSparseAdam, {InferImplFusedSparseAdam, true}}, {prim::kPrimFusedBatchNormGrad, {InferImplFusedBatchNormGrad, true}}, {prim::kPrimBatchNormGrad, {InferImplBatchNormGrad, true}}, {prim::kPrimReluGrad, {InferImplReluGrad, true}}, diff --git a/mindspore/core/base/core_ops.h b/mindspore/core/base/core_ops.h index cd8be927ad..0a0f3d0796 100644 --- a/mindspore/core/base/core_ops.h +++ b/mindspore/core/base/core_ops.h @@ -140,6 +140,7 @@ inline const PrimitivePtr kPrimApplyCenteredRMSProp = std::make_shared("AvgPool"); inline const PrimitivePtr kPrimAvgPoolGrad = std::make_shared("AvgPoolGrad"); inline const PrimitivePtr kPrimAvgPoolGradVm = std::make_shared("AvgPoolGradVm"); +inline const PrimitivePtr kPrimFusedSparseAdam = std::make_shared("FusedSparseAdam"); inline const PrimitivePtr kPrimFusedBatchNorm = std::make_shared("FusedBatchNorm"); inline const PrimitivePtr kPrimFusedBatchNormEx = std::make_shared("FusedBatchNormEx"); inline const PrimitivePtr kPrimConv2D = std::make_shared("Conv2D");