diff --git a/mindspore/ccsrc/frontend/operator/ops.h b/mindspore/ccsrc/frontend/operator/ops.h index 3b49bfbaf1..6932dfafd7 100755 --- a/mindspore/ccsrc/frontend/operator/ops.h +++ b/mindspore/ccsrc/frontend/operator/ops.h @@ -30,45 +30,6 @@ namespace prim { ValuePtr GetPythonOps(const std::string &op_name, const std::string &module_name = "mindspore._extends.parse.standard_method", bool use_signature = false); - -// Primitives only used by frontend; -// Type introspection -inline const PrimitivePtr kPrimTypeOf = std::make_shared("typeof"); -inline const PrimitivePtr kPrimHasType = std::make_shared("hastype"); - -inline const PrimitivePtr kPrimResolve = std::make_shared("resolve"); -inline const PrimitivePtr kPrimEmbed = std::make_shared("embed"); -inline const PrimitivePtr kPrimRefToEmbed = std::make_shared("RefToEmbed"); -inline const PrimitivePtr kPrimCreateInstance = std::make_shared("create_instance"); - -// Other miscellaneous -inline const PrimitivePtr kPrimGetRefOrigin = std::make_shared("get_ref_origin"); -inline const PrimitivePtr kPrimInsertGradientOf = std::make_shared("InsertGradientOf"); -inline const PrimitivePtr kPrimCheckBprop = std::make_shared("CheckBprop"); -inline const PrimitivePtr kPrimMixedPrecisionCast = std::make_shared("mixed_precision_cast"); -inline const PrimitivePtr kPrimMakeRecord = std::make_shared("make_record"); - -// Structures - -inline const PrimitivePtr kPrimListMap = std::make_shared("list_map"); -inline const PrimitivePtr kPrimListReduce = std::make_shared("list_reduce"); -inline const PrimitivePtr kPrimTupleReversed = std::make_shared("tuple_reversed"); -inline const PrimitivePtr kPrimReducedShape = std::make_shared("reduced_shape"); -inline const PrimitivePtr kPrimTupleDiv = std::make_shared("tuple_div"); -inline const PrimitivePtr kPrimTupleToArray = std::make_shared("tuple_to_array"); -inline const PrimitivePtr kPrimShapeMul = std::make_shared("shape_mul"); -inline const PrimitivePtr kPrimTupleEqual = std::make_shared("tuple_equal"); -inline const PrimitivePtr kPrimListEqual = std::make_shared("list_equal"); -inline const PrimitivePtr kPrimMakeRange = std::make_shared("make_range"); -inline const PrimitivePtr kPrimStopGradient = std::make_shared("stop_gradient"); -inline const PrimitivePtr kPrimStringEqual = std::make_shared("string_equal"); -inline const PrimitivePtr kPrimStringConcat = std::make_shared("string_concat"); -inline const PrimitivePtr kPrimDictLen = std::make_shared("dict_len"); - -inline const PrimitivePtr kPrimFakeBprop = std::make_shared("fake_bprop"); - -inline const PrimitivePtr kPrimBroadcastGradientArgs = std::make_shared("BroadcastGradientArgs"); - class UnpackGraphPrimitive : public Primitive { public: explicit UnpackGraphPrimitive(const std::string &name, const bool &with_sens, const bool &need_unpack_args) diff --git a/mindspore/ccsrc/frontend/operator/ops_front_infer_function.cc b/mindspore/ccsrc/frontend/operator/ops_front_infer_function.cc index b0c275354e..a463fd0195 100644 --- a/mindspore/ccsrc/frontend/operator/ops_front_infer_function.cc +++ b/mindspore/ccsrc/frontend/operator/ops_front_infer_function.cc @@ -639,55 +639,26 @@ AbstractBasePtr InferImplMakeRecord(const AnalysisEnginePtr &, const PrimitivePt return std::make_shared(cls->tag(), abs_attributes, cls->methods()); } - -AbstractBasePtr InferImplAssign(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - // Inputs: Ref, value, [universal] - CheckRequiredArgsSize(primitive->name(), args_spec_list, 2); - - MS_LOG(DEBUG) << "InferImplAssign " << args_spec_list[0]; - auto type = args_spec_list[0]->BuildType(); - if (type->type_id() == kObjectTypeRefKey) { - return args_spec_list[1]->Broaden(); - } else { - return args_spec_list[0]; - } -} - -AbstractBasePtr InferImplLoad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - // Inputs: Ref/Tensor, universal - CheckArgsSize(primitive->name(), args_spec_list, 2); - auto ref_abs = dyn_cast(args_spec_list[0]); - if (ref_abs != nullptr) { - // Return tensor value if input is Ref. - return ref_abs->CloneAsTensor(); - } - return args_spec_list[0]->Broaden(); -} - -REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(TypeOf, prim::kPrimTypeOf, InferImplTypeof); -REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(HasType, prim::kPrimHasType, InferImplHasType); -REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(MakeRecord, prim::kPrimMakeRecord, InferImplMakeRecord); -REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(ListMap, prim::kPrimListMap, InferImplListMap); -REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(ListReduce, prim::kPrimListReduce, InferImplListReduce); -REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(TupleReversed, prim::kPrimTupleReversed, InferImplTupleReversed); -REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(ReducedShape, prim::kPrimReducedShape, InferImplReduceShape); -REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(TupleDiv, prim::kPrimTupleDiv, InferImplTupleDiv); -REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(TupleToArray, prim::kPrimTupleToArray, InferImplTuple2Array); -REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(ShapeMul, prim::kPrimShapeMul, InferImplShapeMul); -REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(TupleEqual, prim::kPrimTupleEqual, InferImplTupleEqual); -REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(ListEqual, prim::kPrimListEqual, InferImplListEqual); -REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(MakeRange, prim::kPrimMakeRange, InferImplMakeRange); -REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(StopGradient, prim::kPrimStopGradient, InferImplStopGradient); -REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(StringEqual, prim::kPrimStringEqual, InferImplStringEqual); -REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(StringConcat, prim::kPrimStringConcat, InferImplStringConcat); -REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(DictLen, prim::kPrimDictLen, InferImplDictLen); -REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(FakeBprop, prim::kPrimFakeBprop, InferImplFakeBprop); -REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(J, prim::kPrimJ, InferImplJ); -REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(BroadcastGradientArgs, prim::kPrimBroadcastGradientArgs, - InferImplBroadcastGradientArgs); -REGISTER_PRIMITIVE_EVAL_IMPL(Assign, prim::kPrimAssign, InferImplAssign); -REGISTER_PRIMITIVE_EVAL_IMPL(Load, prim::kPrimLoad, InferImplLoad); +REGISTER_PRIMITIVE_EVAL_IMPL(TypeOf, prim::kPrimTypeOf, InferImplTypeof, nullptr, false); +REGISTER_PRIMITIVE_EVAL_IMPL(HasType, prim::kPrimHasType, InferImplHasType, nullptr, false); +REGISTER_PRIMITIVE_EVAL_IMPL(MakeRecord, prim::kPrimMakeRecord, InferImplMakeRecord, nullptr, false); +REGISTER_PRIMITIVE_EVAL_IMPL(ListMap, prim::kPrimListMap, InferImplListMap, nullptr, false); +REGISTER_PRIMITIVE_EVAL_IMPL(ListReduce, prim::kPrimListReduce, InferImplListReduce, nullptr, false); +REGISTER_PRIMITIVE_EVAL_IMPL(TupleReversed, prim::kPrimTupleReversed, InferImplTupleReversed, nullptr, false); +REGISTER_PRIMITIVE_EVAL_IMPL(ReducedShape, prim::kPrimReducedShape, InferImplReduceShape, nullptr, false); +REGISTER_PRIMITIVE_EVAL_IMPL(TupleDiv, prim::kPrimTupleDiv, InferImplTupleDiv, nullptr, false); +REGISTER_PRIMITIVE_EVAL_IMPL(TupleToArray, prim::kPrimTupleToArray, InferImplTuple2Array, nullptr, false); +REGISTER_PRIMITIVE_EVAL_IMPL(ShapeMul, prim::kPrimShapeMul, InferImplShapeMul, nullptr, false); +REGISTER_PRIMITIVE_EVAL_IMPL(TupleEqual, prim::kPrimTupleEqual, InferImplTupleEqual, nullptr, false); +REGISTER_PRIMITIVE_EVAL_IMPL(ListEqual, prim::kPrimListEqual, InferImplListEqual, nullptr, false); +REGISTER_PRIMITIVE_EVAL_IMPL(MakeRange, prim::kPrimMakeRange, InferImplMakeRange, nullptr, false); +REGISTER_PRIMITIVE_EVAL_IMPL(StopGradient, prim::kPrimStopGradient, InferImplStopGradient, nullptr, false); +REGISTER_PRIMITIVE_EVAL_IMPL(StringEqual, prim::kPrimStringEqual, InferImplStringEqual, nullptr, false); +REGISTER_PRIMITIVE_EVAL_IMPL(StringConcat, prim::kPrimStringConcat, InferImplStringConcat, nullptr, false); +REGISTER_PRIMITIVE_EVAL_IMPL(DictLen, prim::kPrimDictLen, InferImplDictLen, nullptr, false); +REGISTER_PRIMITIVE_EVAL_IMPL(FakeBprop, prim::kPrimFakeBprop, InferImplFakeBprop, nullptr, false); +REGISTER_PRIMITIVE_EVAL_IMPL(J, prim::kPrimJ, InferImplJ, nullptr, false); +REGISTER_PRIMITIVE_EVAL_IMPL(BroadcastGradientArgs, prim::kPrimBroadcastGradientArgs, InferImplBroadcastGradientArgs, + nullptr, false); } // namespace abstract } // namespace mindspore diff --git a/mindspore/ccsrc/frontend/operator/ops_front_infer_function.h b/mindspore/ccsrc/frontend/operator/ops_front_infer_function.h index 6ddacd199d..33ef55ef49 100644 --- a/mindspore/ccsrc/frontend/operator/ops_front_infer_function.h +++ b/mindspore/ccsrc/frontend/operator/ops_front_infer_function.h @@ -59,18 +59,6 @@ AbstractBasePtr InferImplFakeBprop(const AnalysisEnginePtr &, const PrimitivePtr const AbstractBasePtrList &args_spec_list); AbstractBasePtr InferImplMakeRecord(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list); - -class RegisterFrontendPrimitiveEvalHelper { - public: - RegisterFrontendPrimitiveEvalHelper(const PrimitivePtr &primitive, const StandardPrimitiveEvalImpl &impl) { - const StandardPrimitiveImplReg impl_reg{impl, false}; - RegisterStandardPrimitiveImpl(primitive, impl_reg); - } - ~RegisterFrontendPrimitiveEvalHelper() = default; -}; - -#define REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(name, primitive, impl) \ - static auto helper_##name = RegisterFrontendPrimitiveEvalHelper(primitive, impl) } // namespace abstract } // namespace mindspore diff --git a/mindspore/core/abstract/infer_functions.h b/mindspore/core/abstract/infer_functions.h index 8eb6a6d05b..e8089e4f34 100644 --- a/mindspore/core/abstract/infer_functions.h +++ b/mindspore/core/abstract/infer_functions.h @@ -308,6 +308,10 @@ AbstractBasePtr InferImplSparseSoftmaxCrossEntropyWithLogits(const AnalysisEngin const AbstractBasePtrList &args_spec_list); AbstractBasePtr InferImplDType(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplLoad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplAssign(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); template AbstractBasePtr InferTupleOrListOrDictLen(const std::string &op_name, const AbstractBasePtrList &args_spec_list) { // Inputs: a tuple or list or dict. diff --git a/mindspore/core/abstract/prim_others.cc b/mindspore/core/abstract/prim_others.cc index 7a15d3a067..37688918b8 100644 --- a/mindspore/core/abstract/prim_others.cc +++ b/mindspore/core/abstract/prim_others.cc @@ -577,5 +577,31 @@ AbstractBasePtr InferImplDType(const AnalysisEnginePtr &, const PrimitivePtr &pr abstract->set_value(value); return abstract; } + +AbstractBasePtr InferImplLoad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + // Inputs: Ref/Tensor, universal + CheckArgsSize(primitive->name(), args_spec_list, 2); + auto ref_abs = dyn_cast(args_spec_list[0]); + if (ref_abs != nullptr) { + // Return tensor value if input is Ref. + return ref_abs->CloneAsTensor(); + } + return args_spec_list[0]->Broaden(); +} + +AbstractBasePtr InferImplAssign(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + // Inputs: Ref, value, [universal] + CheckRequiredArgsSize(primitive->name(), args_spec_list, 2); + + MS_LOG(DEBUG) << "InferImplAssign " << args_spec_list[0]; + auto type = args_spec_list[0]->BuildType(); + if (type->type_id() == kObjectTypeRefKey) { + return args_spec_list[1]->Broaden(); + } else { + return args_spec_list[0]; + } +} } // namespace abstract } // namespace mindspore diff --git a/mindspore/core/abstract/primitive_infer_map.cc b/mindspore/core/abstract/primitive_infer_map.cc index d0cf64dadd..913e5a6181 100644 --- a/mindspore/core/abstract/primitive_infer_map.cc +++ b/mindspore/core/abstract/primitive_infer_map.cc @@ -55,158 +55,161 @@ std::vector GetDependsFormMap(const CNodePtr &cnode) { PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() { static PrimitiveEvalImplMap prim_eval_implement_map = { // Statements - {prim::kPrimReturn, {InferImplReturn, true}}, - {prim::kPrimSwitch, {InferImplSwitch, true}}, - {prim::kPrimSwitchLayer, {InferImplSwitchLayer, true}}, - {prim::kPrimIs_, {InferImplIs_, true}}, - {prim::kPrimIsNot, {InferImplIsNot, true}}, - {prim::kPrimInDict, {InferImplInDict, true}}, - {prim::kPrimNotInDict, {InferImplNotInDict, true}}, - {prim::kPrimIsConsant, {InferImplIsConstant, true}}, + {prim::kPrimReturn, {InferImplReturn, nullptr, true}}, + {prim::kPrimSwitch, {InferImplSwitch, nullptr, true}}, + {prim::kPrimSwitchLayer, {InferImplSwitchLayer, nullptr, true}}, + {prim::kPrimIs_, {InferImplIs_, nullptr, true}}, + {prim::kPrimIsNot, {InferImplIsNot, nullptr, true}}, + {prim::kPrimInDict, {InferImplInDict, nullptr, true}}, + {prim::kPrimNotInDict, {InferImplNotInDict, nullptr, true}}, + {prim::kPrimIsConsant, {InferImplIsConstant, nullptr, true}}, // Maths - {prim::kPrimSquare, {InferImplSquare, true}}, - {prim::kPrimMatMul, {InferImplMatMul, true}}, - {prim::kPrimBatchMatMul, {InferImplBatchMatMul, true}}, - {prim::kPrimMaximumGrad, {InferImplMinOrMaxGrad, true}}, - {prim::kPrimMinimumGrad, {InferImplMinOrMaxGrad, true}}, - {prim::kPrimSqrt, {InferImplSqrt, true}}, + {prim::kPrimSquare, {InferImplSquare, nullptr, true}}, + {prim::kPrimMatMul, {InferImplMatMul, nullptr, true}}, + {prim::kPrimBatchMatMul, {InferImplBatchMatMul, nullptr, true}}, + {prim::kPrimMaximumGrad, {InferImplMinOrMaxGrad, nullptr, true}}, + {prim::kPrimMinimumGrad, {InferImplMinOrMaxGrad, nullptr, true}}, + {prim::kPrimSqrt, {InferImplSqrt, nullptr, true}}, // Array - {prim::kPrimRange, {InferImplRange, true}}, - {prim::kPrimScalarToArray, {InferImplScalarToArray, true}}, - {prim::kPrimArrayToScalar, {InferImplArrayToScalar, true}}, - {prim::kPrimBroadcastShape, {InferImplBroadCastShape, true}}, - {prim::kPrimUnique, {InferImplUnique, true}}, - {prim::kPrimUniqueGrad, {InferImplUniqueGrad, true}}, - {prim::kPrimGather, {InferImplGatherV2, true}}, - {prim::kPrimEmbeddingLookup, {InferImplEmbeddingLookup, true}}, - {prim::kPrimSparseGatherV2, {InferImplGatherV2, true}}, - {prim::kPrimUnsortedSegmentMax, {InferImplUnsortedSegmentMax, true}}, - {prim::kPrimUnsortedSegmentMin, {InferImplUnsortedSegmentMin, true}}, - {prim::kPrimScatterAdd, {InferImplScatterAdd, true}}, - {prim::kPrimSubAndFilter, {InferImplSubAndFilter, true}}, - {prim::kPrimScatterUpdate, {InferImplScatterUpdate, true}}, - {prim::kPrimMapCacheIdx, {InferImplMapCacheIdx, true}}, - {prim::kPrimDynamicAssign, {InferImplDynamicAssign, true}}, - {prim::kPrimCacheSwapTable, {InferImplCacheSwapTable, true}}, - {prim::kPrimUpdateCache, {InferImplUpdateCache, true}}, - {prim::kPrimComputeAccidentalHits, {InferImplComputeAccidentalHits, true}}, - {prim::kPrimPadAndShift, {InferImplPadAndShift, true}}, - {prim::kPrimDynamicShape, {InferImplDynamicShape, true}}, - {prim::kPrimMapUniform, {InferImplMapUniform, true}}, - {prim::kPrimSplit, {InferImplSplit, true}}, - {prim::kPrimSequenceMask, {InferImplSequenceMask, true}}, + {prim::kPrimRange, {InferImplRange, nullptr, true}}, + {prim::kPrimScalarToArray, {InferImplScalarToArray, nullptr, true}}, + {prim::kPrimArrayToScalar, {InferImplArrayToScalar, nullptr, true}}, + {prim::kPrimBroadcastShape, {InferImplBroadCastShape, nullptr, true}}, + {prim::kPrimUnique, {InferImplUnique, nullptr, true}}, + {prim::kPrimUniqueGrad, {InferImplUniqueGrad, nullptr, true}}, + {prim::kPrimGather, {InferImplGatherV2, nullptr, true}}, + {prim::kPrimEmbeddingLookup, {InferImplEmbeddingLookup, nullptr, true}}, + {prim::kPrimSparseGatherV2, {InferImplGatherV2, nullptr, true}}, + {prim::kPrimUnsortedSegmentMax, {InferImplUnsortedSegmentMax, nullptr, true}}, + {prim::kPrimUnsortedSegmentMin, {InferImplUnsortedSegmentMin, nullptr, true}}, + {prim::kPrimScatterAdd, {InferImplScatterAdd, nullptr, true}}, + {prim::kPrimSubAndFilter, {InferImplSubAndFilter, nullptr, true}}, + {prim::kPrimScatterUpdate, {InferImplScatterUpdate, nullptr, true}}, + {prim::kPrimMapCacheIdx, {InferImplMapCacheIdx, nullptr, true}}, + {prim::kPrimDynamicAssign, {InferImplDynamicAssign, nullptr, true}}, + {prim::kPrimCacheSwapTable, {InferImplCacheSwapTable, nullptr, true}}, + {prim::kPrimUpdateCache, {InferImplUpdateCache, nullptr, true}}, + {prim::kPrimComputeAccidentalHits, {InferImplComputeAccidentalHits, nullptr, true}}, + {prim::kPrimPadAndShift, {InferImplPadAndShift, nullptr, true}}, + {prim::kPrimDynamicShape, {InferImplDynamicShape, nullptr, true}}, + {prim::kPrimMapUniform, {InferImplMapUniform, nullptr, true}}, + {prim::kPrimSplit, {InferImplSplit, nullptr, true}}, + {prim::kPrimSequenceMask, {InferImplSequenceMask, nullptr, true}}, // Structure - {prim::kPrimMakeTuple, {InferImplMakeTuple, true}}, - {prim::kPrimMakeList, {InferImplMakeList, true}}, - {prim::kPrimMakeDict, {InferImplMakeDict, true}}, - {prim::kPrimMakeSlice, {InferImplMakeSlice, true}}, - {prim::kPrimMakeKeywordArg, {InferImplMakeKwarg, true}}, - {prim::kPrimExtractKeywordArg, {InferImplExtractKwarg, true}}, - {prim::kPrimTupleGetItem, {InferImplTupleGetItem, true}}, - {prim::kPrimListGetItem, {InferImplListGetItem, true}}, - {prim::kPrimTupleSetItem, {InferImplTupleSetItem, true}}, - {prim::kPrimListSetItem, {InferImplListSetItem, true}}, - {prim::kPrimDictGetItem, {InferImplDictGetItem, true}}, - {prim::kPrimDictSetItem, {InferImplDictSetItem, true}}, - {prim::kPrimDictGetKeys, {InferImplDictGetKeys, true}}, - {prim::kPrimDictGetValues, {InferImplDictGetValues, true}}, - {prim::kPrimListAppend, {InferImplListAppend, true}}, - {prim::kPrimTupleLen, {InferImplTupleLen, true}}, - {prim::kPrimListLen, {InferImplListLen, true}}, - {prim::kPrimArrayLen, {InferImplArrayLen, true}}, + {prim::kPrimMakeTuple, {InferImplMakeTuple, nullptr, true}}, + {prim::kPrimMakeList, {InferImplMakeList, nullptr, true}}, + {prim::kPrimMakeDict, {InferImplMakeDict, nullptr, true}}, + {prim::kPrimMakeSlice, {InferImplMakeSlice, nullptr, true}}, + {prim::kPrimMakeKeywordArg, {InferImplMakeKwarg, nullptr, true}}, + {prim::kPrimExtractKeywordArg, {InferImplExtractKwarg, nullptr, true}}, + {prim::kPrimTupleGetItem, {InferImplTupleGetItem, nullptr, true}}, + {prim::kPrimListGetItem, {InferImplListGetItem, nullptr, true}}, + {prim::kPrimTupleSetItem, {InferImplTupleSetItem, nullptr, true}}, + {prim::kPrimListSetItem, {InferImplListSetItem, nullptr, true}}, + {prim::kPrimDictGetItem, {InferImplDictGetItem, nullptr, true}}, + {prim::kPrimDictSetItem, {InferImplDictSetItem, nullptr, true}}, + {prim::kPrimDictGetKeys, {InferImplDictGetKeys, nullptr, true}}, + {prim::kPrimDictGetValues, {InferImplDictGetValues, nullptr, true}}, + {prim::kPrimListAppend, {InferImplListAppend, nullptr, true}}, + {prim::kPrimTupleLen, {InferImplTupleLen, nullptr, true}}, + {prim::kPrimListLen, {InferImplListLen, nullptr, true}}, + {prim::kPrimArrayLen, {InferImplArrayLen, nullptr, true}}, // NN - {prim::kPrimPooling, {InferImplPooling, true}}, - {prim::kPrimPoolingGrad, {InferImplPoolingGrad, true}}, - {prim::kPrimBatchNorm, {InferImplBatchNorm, true}}, - {prim::kPrimReluGrad, {InferImplReluGrad, true}}, - {prim::kPrimConv2D, {InferImplConv2D, true}}, - {prim::kPrimBiasAdd, {InferImplBiasAdd, true}}, - {prim::kPrimRelu, {InferImplRelu, true}}, - {prim::kPrimRelu6, {InferImplRelu, true}}, - {prim::kPrimZerosLike, {InferImplZerosLike, true}}, - {prim::kPrimBpropCut, {InferImplBpropCut, true}}, - {prim::kPrimLayerNorm, {InferImplLayerNorm, true}}, - {prim::kPrimLayerNormGrad, {InferImplLayerNormGrad, true}}, - {prim::kPrimDropout, {InferImplDropout, true}}, - {prim::kPrimDropoutGenMask, {InferImplDropoutGenMask, true}}, - {prim::kPrimSparseApplyFtrl, {InferImplSparseApplyFtrl, true}}, - {prim::kPrimSparseApplyProximalAdagrad, {InferImplSparseApplyProximalAdagrad, true}}, - {prim::kPrimSGD, {InferImplSGD, true}}, - {prim::kPrimCTCGreedyDecoder, {InferImplCTCGreedyDecoder, true}}, + {prim::kPrimPooling, {InferImplPooling, nullptr, true}}, + {prim::kPrimPoolingGrad, {InferImplPoolingGrad, nullptr, true}}, + {prim::kPrimBatchNorm, {InferImplBatchNorm, nullptr, true}}, + {prim::kPrimReluGrad, {InferImplReluGrad, nullptr, true}}, + {prim::kPrimConv2D, {InferImplConv2D, nullptr, true}}, + {prim::kPrimBiasAdd, {InferImplBiasAdd, nullptr, true}}, + {prim::kPrimRelu, {InferImplRelu, nullptr, true}}, + {prim::kPrimRelu6, {InferImplRelu, nullptr, true}}, + {prim::kPrimZerosLike, {InferImplZerosLike, nullptr, true}}, + {prim::kPrimBpropCut, {InferImplBpropCut, nullptr, true}}, + {prim::kPrimLayerNorm, {InferImplLayerNorm, nullptr, true}}, + {prim::kPrimLayerNormGrad, {InferImplLayerNormGrad, nullptr, true}}, + {prim::kPrimDropout, {InferImplDropout, nullptr, true}}, + {prim::kPrimDropoutGenMask, {InferImplDropoutGenMask, nullptr, true}}, + {prim::kPrimSparseApplyFtrl, {InferImplSparseApplyFtrl, nullptr, true}}, + {prim::kPrimSparseApplyProximalAdagrad, {InferImplSparseApplyProximalAdagrad, nullptr, true}}, + {prim::kPrimSGD, {InferImplSGD, nullptr, true}}, + {prim::kPrimCTCGreedyDecoder, {InferImplCTCGreedyDecoder, nullptr, true}}, // Others - {prim::kPrimIdentity, {InferImplIdentity, true}}, + {prim::kPrimIdentity, {InferImplIdentity, nullptr, true}}, + {prim::kPrimLoad, {InferImplLoad, nullptr, true}}, + {prim::kPrimAssign, {InferImplAssign, nullptr, true}}, // Set impl to null as it will use PartialEvaluator; - {prim::kPrimPartial, {nullptr, true}}, - {prim::kPrimEnvGetItem, {InferImplEnvGetItem, true}}, - {prim::kPrimEnvSetItem, {InferImplEnvSetItem, true}}, - {prim::kPrimEnvAdd, {InferImplEnvAdd, true}}, - {prim::kPrimMakeRefKey, {InferImplMakeRefKey, true}}, - {prim::kPrimMakeRef, {InferImplMakeRef, true}}, - {prim::kPrimGetRefKey, {InferImplGetRefKey, true}}, - {prim::kPrimGetRefValue, {InferImplGetRefValue, true}}, - {prim::kPrimStateSetItem, {InferImplStateSetItem, true}}, - {prim::kPrimDepend, {InferImplDepend, true}}, - {prim::kPrimUpdateState, {InferImplUpdateState, true}}, - {prim::kPrimControlDepend, {InferImplControlDepend, true}}, + {prim::kPrimPartial, {nullptr, nullptr, true}}, + {prim::kPrimEnvGetItem, {InferImplEnvGetItem, nullptr, true}}, + {prim::kPrimEnvSetItem, {InferImplEnvSetItem, nullptr, true}}, + {prim::kPrimEnvAdd, {InferImplEnvAdd, nullptr, true}}, + {prim::kPrimMakeRefKey, {InferImplMakeRefKey, nullptr, true}}, + {prim::kPrimMakeRef, {InferImplMakeRef, nullptr, true}}, + {prim::kPrimGetRefKey, {InferImplGetRefKey, nullptr, true}}, + {prim::kPrimGetRefValue, {InferImplGetRefValue, nullptr, true}}, + {prim::kPrimStateSetItem, {InferImplStateSetItem, nullptr, true}}, + {prim::kPrimDepend, {InferImplDepend, nullptr, true}}, + {prim::kPrimUpdateState, {InferImplUpdateState, nullptr, true}}, + {prim::kPrimControlDepend, {InferImplControlDepend, nullptr, true}}, // Debug - {prim::kPrimDebug, {InferImplDebug, true}}, + {prim::kPrimDebug, {InferImplDebug, nullptr, true}}, // Dynamic shape testing - {prim::kPrimGpuConvertToDynamicShape, {InferImplGpuConvertToDynamicShape, true}}, + {prim::kPrimGpuConvertToDynamicShape, {InferImplGpuConvertToDynamicShape, nullptr, true}}, // SparseTensor - {prim::kPrimMakeSparseTensor, {InferImplMakeSparseTensor, true}}, - {prim::kPrimSparseTensorGetValues, {InferImplSparseTensorGetValues, true}}, - {prim::kPrimSparseTensorGetIndices, {InferImplSparseTensorGetIndices, true}}, - {prim::kPrimSparseTensorGetDenseShape, {InferImplSparseTensorGetDenseShape, true}}, + {prim::kPrimMakeSparseTensor, {InferImplMakeSparseTensor, nullptr, true}}, + {prim::kPrimSparseTensorGetValues, {InferImplSparseTensorGetValues, nullptr, true}}, + {prim::kPrimSparseTensorGetIndices, {InferImplSparseTensorGetIndices, nullptr, true}}, + {prim::kPrimSparseTensorGetDenseShape, {InferImplSparseTensorGetDenseShape, nullptr, true}}, // RowTensor - {prim::kPrimMakeRowTensor, {InferImplMakeRowTensor, true}}, - {prim::kPrimRowTensorGetValues, {InferImplRowTensorGetValues, true}}, - {prim::kPrimRowTensorGetIndices, {InferImplRowTensorGetIndices, true}}, - {prim::kPrimRowTensorGetDenseShape, {InferImplRowTensorGetDenseShape, true}}, - {prim::kPrimRowTensorAdd, {InferImplRowTensorAdd, false}}, + {prim::kPrimMakeRowTensor, {InferImplMakeRowTensor, nullptr, true}}, + + {prim::kPrimRowTensorGetValues, {InferImplRowTensorGetValues, nullptr, true}}, + {prim::kPrimRowTensorGetIndices, {InferImplRowTensorGetIndices, nullptr, true}}, + {prim::kPrimRowTensorGetDenseShape, {InferImplRowTensorGetDenseShape, nullptr, true}}, + {prim::kPrimRowTensorAdd, {InferImplRowTensorAdd, nullptr, false}}, // Comm Ops - {prim::kPrimAllSwap, {InferImplAllSwap, true}}, - {prim::kPrimMemCpyAsync, {InferImplMemCpyAsync, true}}, + {prim::kPrimAllSwap, {InferImplAllSwap, nullptr, true}}, + {prim::kPrimMemCpyAsync, {InferImplMemCpyAsync, nullptr, true}}, }; return prim_eval_implement_map; } PrimitiveEvalImplMap &GetPrimitiveToBackendEvalImplMap() { static PrimitiveEvalImplMap prim_backend_eval_implement_map = { - {prim::kPrimMul, {InferImplMul, true}}, - {prim::kPrimAdd, {InferImplAdd, true}}, - {prim::kPrimSqrtGrad, {InferImplSqrtGrad, true}}, - {prim::kPrimSub, {InferImplSub, true}}, - {prim::kPrimEqual, {InferImplEqual, true}}, - {prim::kPrimReduceSum, {InferImplReduceFunc, true}}, - {prim::kPrimReduceMean, {InferImplReduceFunc, true}}, - {prim::kPrimReduceAll, {InferImplReduceFunc, true}}, - {prim::kPrimReduceAny, {InferImplReduceFunc, true}}, - {prim::kPrimReduceMax, {InferImplReduceFunc, true}}, - {prim::kPrimReduceMin, {InferImplReduceFunc, true}}, - {prim::kPrimBiasAddGrad, {InferImplBiasAddGrad, true}}, - {prim::kPrimReduceScatter, {InferImplReduceScatter, true}}, - {prim::kPrimCast, {InferImplCast, true}}, - {prim::kPrimExpandDims, {InferImplExpandDims, true}}, - {prim::kPrimAllReduce, {InferImplAllReduce, true}}, - {prim::kPrimBroadcast, {InferImplBroadcast, true}}, - {prim::kPrimAllGather, {InferImplAllGather, true}}, - {prim::kPrimMinimum, {InferImplMinimum, true}}, - {prim::kPrimDivNoNan, {InferImplDivNoNan, true}}, - {prim::kPrimLinSpace, {InferImplLinSpace, true}}, - {prim::kPrimAddN, {InferImplAddN, true}}, + {prim::kPrimMul, {InferImplMul, nullptr, true}}, + {prim::kPrimAdd, {InferImplAdd, nullptr, true}}, + {prim::kPrimSqrtGrad, {InferImplSqrtGrad, nullptr, true}}, + {prim::kPrimSub, {InferImplSub, nullptr, true}}, + {prim::kPrimEqual, {InferImplEqual, nullptr, true}}, + {prim::kPrimReduceSum, {InferImplReduceFunc, nullptr, true}}, + {prim::kPrimReduceMean, {InferImplReduceFunc, nullptr, true}}, + {prim::kPrimReduceAll, {InferImplReduceFunc, nullptr, true}}, + {prim::kPrimReduceAny, {InferImplReduceFunc, nullptr, true}}, + {prim::kPrimReduceMax, {InferImplReduceFunc, nullptr, true}}, + {prim::kPrimReduceMin, {InferImplReduceFunc, nullptr, true}}, + {prim::kPrimBiasAddGrad, {InferImplBiasAddGrad, nullptr, true}}, + {prim::kPrimReduceScatter, {InferImplReduceScatter, nullptr, true}}, + {prim::kPrimCast, {InferImplCast, nullptr, true}}, + {prim::kPrimExpandDims, {InferImplExpandDims, nullptr, true}}, + {prim::kPrimAllReduce, {InferImplAllReduce, nullptr, true}}, + {prim::kPrimBroadcast, {InferImplBroadcast, nullptr, true}}, + {prim::kPrimAllGather, {InferImplAllGather, nullptr, true}}, + {prim::kPrimMinimum, {InferImplMinimum, nullptr, true}}, + {prim::kPrimDivNoNan, {InferImplDivNoNan, nullptr, true}}, + {prim::kPrimLinSpace, {InferImplLinSpace, nullptr, true}}, + {prim::kPrimAddN, {InferImplAddN, nullptr, true}}, - {prim::kPrimLess, {InferImplLess, true}}, - {prim::kPrimStack, {InferImplStack, true}}, - {prim::kPrimPad, {InferImplPad, true}}, - {prim::kPrimUnsortedSegmentSum, {InferImplUnsortedSegmentSum, true}}, - {prim::kPrimDiv, {InferImplDiv, true}}, - {prim::kPrimRealDiv, {InferImplRealDiv, true}}, - {prim::kPrimShape, {InferImplShape, false}}, - {prim::kPrimTranspose, {InferImplTranspose, true}}, - {prim::kPrimReshape, {InferImplReshape, true}}, - {prim::kPrimConcat, {InferImplConcat, true}}, - {prim::kPrimArgMaxWithValue, {InferImplArgMaxWithValue, true}}, - {prim::kPrimFusedSparseAdam, {InferImplFusedSparseAdam, true}}, + {prim::kPrimLess, {InferImplLess, nullptr, true}}, + {prim::kPrimStack, {InferImplStack, nullptr, true}}, + {prim::kPrimPad, {InferImplPad, nullptr, true}}, + {prim::kPrimUnsortedSegmentSum, {InferImplUnsortedSegmentSum, nullptr, true}}, + {prim::kPrimDiv, {InferImplDiv, nullptr, true}}, + {prim::kPrimRealDiv, {InferImplRealDiv, nullptr, true}}, + {prim::kPrimShape, {InferImplShape, nullptr, false}}, + {prim::kPrimTranspose, {InferImplTranspose, nullptr, true}}, + {prim::kPrimReshape, {InferImplReshape, nullptr, true}}, + {prim::kPrimConcat, {InferImplConcat, nullptr, true}}, + {prim::kPrimArgMaxWithValue, {InferImplArgMaxWithValue, nullptr, true}}, + {prim::kPrimFusedSparseAdam, {InferImplFusedSparseAdam, nullptr, true}}, }; return prim_backend_eval_implement_map; } diff --git a/mindspore/core/abstract/primitive_infer_map.h b/mindspore/core/abstract/primitive_infer_map.h index f175c2e30c..c4b6d984c9 100644 --- a/mindspore/core/abstract/primitive_infer_map.h +++ b/mindspore/core/abstract/primitive_infer_map.h @@ -28,9 +28,13 @@ namespace mindspore { namespace abstract { using StandardPrimitiveEvalImpl = AbstractBasePtr (*)(const abstract::AnalysisEnginePtr &, const PrimitivePtr &, const AbstractBasePtrList &); +using InferValueEvalImpl = ValuePtr (*)(const PrimitivePtr &, const AbstractBasePtrList &, const AbstractBasePtr &); + struct StandardPrimitiveImplReg { - StandardPrimitiveEvalImpl impl_; // Implement function of Primitive. - bool in_white_list_; // true if this Primitive in white list, else false. + StandardPrimitiveEvalImpl impl_; // Implement function of Primitive + InferValueEvalImpl infer_value_func_; // infer value of primitive + // true means this primitive can be executed by vm backend else will be constant folded by frontend + bool in_white_list_; }; using PrimitiveEvalImplMap = @@ -48,15 +52,17 @@ void RegisterStandardPrimitiveImpl(const PrimitivePtr &primitive, const Standard class RegisterStandardPrimitiveEvalHelper { public: - RegisterStandardPrimitiveEvalHelper(const PrimitivePtr &primitive, const StandardPrimitiveEvalImpl &impl) { - const StandardPrimitiveImplReg impl_reg{impl, true}; + RegisterStandardPrimitiveEvalHelper(const PrimitivePtr &primitive, const StandardPrimitiveEvalImpl &impl, + const InferValueEvalImpl &infer_value_impl, const bool is_wight_list = true) { + const StandardPrimitiveImplReg impl_reg{impl, infer_value_impl, is_wight_list}; RegisterStandardPrimitiveImpl(primitive, impl_reg); } ~RegisterStandardPrimitiveEvalHelper() = default; }; -#define REGISTER_PRIMITIVE_EVAL_IMPL(name, primitive, impl) \ - static auto helper_##name = abstract::RegisterStandardPrimitiveEvalHelper(primitive, impl) +#define REGISTER_PRIMITIVE_EVAL_IMPL(name, primitive, impl, infer_value_impl, is_wight_list) \ + static auto helper_##name = \ + abstract::RegisterStandardPrimitiveEvalHelper(primitive, impl, infer_value_impl, is_wight_list) } // namespace abstract } // namespace mindspore #endif // MINDSPORE_CORE_ABSTRACT_PRIMITIVE_INFER_MAP_H_ diff --git a/mindspore/core/base/core_ops.h b/mindspore/core/base/core_ops.h index 3e275a5d68..d4fd66d8e4 100644 --- a/mindspore/core/base/core_ops.h +++ b/mindspore/core/base/core_ops.h @@ -539,6 +539,40 @@ inline const PrimitivePtr kPrimReduceFusion = std::make_shared("Reduc inline const PrimitivePtr kPrimLayerNormFusion = std::make_shared("LayerNormFusion"); inline const PrimitivePtr kPrimDType = std::make_shared("DType"); +// Type introspection +inline const PrimitivePtr kPrimTypeOf = std::make_shared("typeof"); +inline const PrimitivePtr kPrimHasType = std::make_shared("hastype"); + +inline const PrimitivePtr kPrimResolve = std::make_shared("resolve"); +inline const PrimitivePtr kPrimEmbed = std::make_shared("embed"); +inline const PrimitivePtr kPrimRefToEmbed = std::make_shared("RefToEmbed"); +inline const PrimitivePtr kPrimCreateInstance = std::make_shared("create_instance"); + +// Other miscellaneous +inline const PrimitivePtr kPrimGetRefOrigin = std::make_shared("get_ref_origin"); +inline const PrimitivePtr kPrimInsertGradientOf = std::make_shared("InsertGradientOf"); +inline const PrimitivePtr kPrimCheckBprop = std::make_shared("CheckBprop"); +inline const PrimitivePtr kPrimMixedPrecisionCast = std::make_shared("mixed_precision_cast"); +inline const PrimitivePtr kPrimMakeRecord = std::make_shared("make_record"); + +// Structures +inline const PrimitivePtr kPrimListMap = std::make_shared("list_map"); +inline const PrimitivePtr kPrimListReduce = std::make_shared("list_reduce"); +inline const PrimitivePtr kPrimTupleReversed = std::make_shared("tuple_reversed"); +inline const PrimitivePtr kPrimReducedShape = std::make_shared("reduced_shape"); +inline const PrimitivePtr kPrimTupleDiv = std::make_shared("tuple_div"); +inline const PrimitivePtr kPrimTupleToArray = std::make_shared("tuple_to_array"); +inline const PrimitivePtr kPrimShapeMul = std::make_shared("shape_mul"); +inline const PrimitivePtr kPrimTupleEqual = std::make_shared("tuple_equal"); +inline const PrimitivePtr kPrimListEqual = std::make_shared("list_equal"); +inline const PrimitivePtr kPrimMakeRange = std::make_shared("make_range"); +inline const PrimitivePtr kPrimStopGradient = std::make_shared("stop_gradient"); +inline const PrimitivePtr kPrimStringEqual = std::make_shared("string_equal"); +inline const PrimitivePtr kPrimStringConcat = std::make_shared("string_concat"); +inline const PrimitivePtr kPrimDictLen = std::make_shared("dict_len"); +inline const PrimitivePtr kPrimFakeBprop = std::make_shared("fake_bprop"); +inline const PrimitivePtr kPrimBroadcastGradientArgs = std::make_shared("BroadcastGradientArgs"); + class DoSignaturePrimitive : public Primitive { public: explicit DoSignaturePrimitive(const std::string &name, const ValuePtr &function) diff --git a/mindspore/core/c_ops/add.cc b/mindspore/core/c_ops/add.cc index bb0c06723e..54fb29abf0 100644 --- a/mindspore/core/c_ops/add.cc +++ b/mindspore/core/c_ops/add.cc @@ -49,6 +49,5 @@ AbstractBasePtr AddInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr return std::make_shared(InferType(primitive, input_args), InferShape(primitive, input_args)->shape()); } -REGISTER_PRIMITIVE_EVAL_IMPL(Add, prim::kPrimAdd, AddInfer); REGISTER_PRIMITIVE_C(kNameAdd, Add); } // namespace mindspore diff --git a/mindspore/core/ops/scalar_summary.cc b/mindspore/core/ops/scalar_summary.cc index b0024d7d42..342c4d4a43 100644 --- a/mindspore/core/ops/scalar_summary.cc +++ b/mindspore/core/ops/scalar_summary.cc @@ -42,7 +42,7 @@ AbstractBasePtr ScalarSummaryInfer(const abstract::AnalysisEnginePtr &, const Pr CheckAndConvertUtils::CheckInteger("v rank", v_shape.size(), kLessEqual, 1, prim_name); return std::make_shared(kInt32, std::make_shared(ShapeVector(1))); } -REGISTER_PRIMITIVE_EVAL_IMPL(ScalarSummary, prim::kPrimScalarSummary, ScalarSummaryInfer); +REGISTER_PRIMITIVE_EVAL_IMPL(ScalarSummary, prim::kPrimScalarSummary, ScalarSummaryInfer, nullptr, true); REGISTER_PRIMITIVE_C(kNameScalarSummary, ScalarSummary); } // namespace ops } // namespace mindspore diff --git a/mindspore/core/ops/tensor_summary.cc b/mindspore/core/ops/tensor_summary.cc index 77ae216f24..4b17b528e8 100644 --- a/mindspore/core/ops/tensor_summary.cc +++ b/mindspore/core/ops/tensor_summary.cc @@ -42,7 +42,7 @@ AbstractBasePtr TensorSummaryInfer(const abstract::AnalysisEnginePtr &, const Pr CheckAndConvertUtils::CheckInteger("v rank", v_shape.size(), kGreaterEqual, 1, prim_name); return std::make_shared(kInt32, std::make_shared(ShapeVector(1))); } -REGISTER_PRIMITIVE_EVAL_IMPL(TensorSummary, prim::kPrimTensorSummary, TensorSummaryInfer); +REGISTER_PRIMITIVE_EVAL_IMPL(TensorSummary, prim::kPrimTensorSummary, TensorSummaryInfer, nullptr, true); REGISTER_PRIMITIVE_C(kNameTensorSummary, TensorSummary); } // namespace ops } // namespace mindspore diff --git a/tests/ut/cpp/pre_activate/common/restore_abs_input_in_backed_infer_test.cc b/tests/ut/cpp/pre_activate/common/restore_abs_input_in_backed_infer_test.cc index 8cb97d3a5f..ff9f191c6a 100644 --- a/tests/ut/cpp/pre_activate/common/restore_abs_input_in_backed_infer_test.cc +++ b/tests/ut/cpp/pre_activate/common/restore_abs_input_in_backed_infer_test.cc @@ -36,7 +36,7 @@ AbstractBasePtr InferImplAttrTest(const abstract::AnalysisEnginePtr &, const Pri EXPECT_EQ(args_spec_list[1]->isa(), true); return args_spec_list[0]; } -REGISTER_PRIMITIVE_EVAL_IMPL(TestAttr,kPrimAttrConvertTest,InferImplAttrTest); +REGISTER_PRIMITIVE_EVAL_IMPL(TestAttr, kPrimAttrConvertTest, InferImplAttrTest, nullptr, true); AbstractBasePtr InferImplDynamicInputTest(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list) { EXPECT_EQ(args_spec_list.size(), 3); @@ -45,7 +45,7 @@ AbstractBasePtr InferImplDynamicInputTest(const abstract::AnalysisEnginePtr &, c auto item = args_spec_list[1]->cast(); return args_spec_list[0]; } -REGISTER_PRIMITIVE_EVAL_IMPL(TestDynamicInput,kPrimDynamicInputTest,InferImplDynamicInputTest); +REGISTER_PRIMITIVE_EVAL_IMPL(TestDynamicInput, kPrimDynamicInputTest, InferImplDynamicInputTest, nullptr, true); class TestAttrAndDynamicBackendInfer : public UT::Common { public: TestAttrAndDynamicBackendInfer() {}