!14548 change macro of infer function reg

From: @lianliguang
Reviewed-by: @ginfung,@chujinjin
Signed-off-by: @chujinjin
pull/14548/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit bb3a19363c

@ -30,45 +30,6 @@ namespace prim {
ValuePtr GetPythonOps(const std::string &op_name,
const std::string &module_name = "mindspore._extends.parse.standard_method",
bool use_signature = false);
// Primitives only used by frontend;
// Type introspection
inline const PrimitivePtr kPrimTypeOf = std::make_shared<Primitive>("typeof");
inline const PrimitivePtr kPrimHasType = std::make_shared<Primitive>("hastype");
inline const PrimitivePtr kPrimResolve = std::make_shared<Primitive>("resolve");
inline const PrimitivePtr kPrimEmbed = std::make_shared<Primitive>("embed");
inline const PrimitivePtr kPrimRefToEmbed = std::make_shared<Primitive>("RefToEmbed");
inline const PrimitivePtr kPrimCreateInstance = std::make_shared<Primitive>("create_instance");
// Other miscellaneous
inline const PrimitivePtr kPrimGetRefOrigin = std::make_shared<Primitive>("get_ref_origin");
inline const PrimitivePtr kPrimInsertGradientOf = std::make_shared<Primitive>("InsertGradientOf");
inline const PrimitivePtr kPrimCheckBprop = std::make_shared<Primitive>("CheckBprop");
inline const PrimitivePtr kPrimMixedPrecisionCast = std::make_shared<Primitive>("mixed_precision_cast");
inline const PrimitivePtr kPrimMakeRecord = std::make_shared<Primitive>("make_record");
// Structures
inline const PrimitivePtr kPrimListMap = std::make_shared<Primitive>("list_map");
inline const PrimitivePtr kPrimListReduce = std::make_shared<Primitive>("list_reduce");
inline const PrimitivePtr kPrimTupleReversed = std::make_shared<Primitive>("tuple_reversed");
inline const PrimitivePtr kPrimReducedShape = std::make_shared<Primitive>("reduced_shape");
inline const PrimitivePtr kPrimTupleDiv = std::make_shared<Primitive>("tuple_div");
inline const PrimitivePtr kPrimTupleToArray = std::make_shared<Primitive>("tuple_to_array");
inline const PrimitivePtr kPrimShapeMul = std::make_shared<Primitive>("shape_mul");
inline const PrimitivePtr kPrimTupleEqual = std::make_shared<Primitive>("tuple_equal");
inline const PrimitivePtr kPrimListEqual = std::make_shared<Primitive>("list_equal");
inline const PrimitivePtr kPrimMakeRange = std::make_shared<Primitive>("make_range");
inline const PrimitivePtr kPrimStopGradient = std::make_shared<Primitive>("stop_gradient");
inline const PrimitivePtr kPrimStringEqual = std::make_shared<Primitive>("string_equal");
inline const PrimitivePtr kPrimStringConcat = std::make_shared<Primitive>("string_concat");
inline const PrimitivePtr kPrimDictLen = std::make_shared<Primitive>("dict_len");
inline const PrimitivePtr kPrimFakeBprop = std::make_shared<Primitive>("fake_bprop");
inline const PrimitivePtr kPrimBroadcastGradientArgs = std::make_shared<Primitive>("BroadcastGradientArgs");
class UnpackGraphPrimitive : public Primitive {
public:
explicit UnpackGraphPrimitive(const std::string &name, const bool &with_sens, const bool &need_unpack_args)

@ -639,55 +639,26 @@ AbstractBasePtr InferImplMakeRecord(const AnalysisEnginePtr &, const PrimitivePt
return std::make_shared<AbstractClass>(cls->tag(), abs_attributes, cls->methods());
}
AbstractBasePtr InferImplAssign(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// Inputs: Ref, value, [universal]
CheckRequiredArgsSize(primitive->name(), args_spec_list, 2);
MS_LOG(DEBUG) << "InferImplAssign " << args_spec_list[0];
auto type = args_spec_list[0]->BuildType();
if (type->type_id() == kObjectTypeRefKey) {
return args_spec_list[1]->Broaden();
} else {
return args_spec_list[0];
}
}
AbstractBasePtr InferImplLoad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// Inputs: Ref/Tensor, universal
CheckArgsSize(primitive->name(), args_spec_list, 2);
auto ref_abs = dyn_cast<abstract::AbstractRef>(args_spec_list[0]);
if (ref_abs != nullptr) {
// Return tensor value if input is Ref.
return ref_abs->CloneAsTensor();
}
return args_spec_list[0]->Broaden();
}
REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(TypeOf, prim::kPrimTypeOf, InferImplTypeof);
REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(HasType, prim::kPrimHasType, InferImplHasType);
REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(MakeRecord, prim::kPrimMakeRecord, InferImplMakeRecord);
REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(ListMap, prim::kPrimListMap, InferImplListMap);
REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(ListReduce, prim::kPrimListReduce, InferImplListReduce);
REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(TupleReversed, prim::kPrimTupleReversed, InferImplTupleReversed);
REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(ReducedShape, prim::kPrimReducedShape, InferImplReduceShape);
REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(TupleDiv, prim::kPrimTupleDiv, InferImplTupleDiv);
REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(TupleToArray, prim::kPrimTupleToArray, InferImplTuple2Array);
REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(ShapeMul, prim::kPrimShapeMul, InferImplShapeMul);
REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(TupleEqual, prim::kPrimTupleEqual, InferImplTupleEqual);
REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(ListEqual, prim::kPrimListEqual, InferImplListEqual);
REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(MakeRange, prim::kPrimMakeRange, InferImplMakeRange);
REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(StopGradient, prim::kPrimStopGradient, InferImplStopGradient);
REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(StringEqual, prim::kPrimStringEqual, InferImplStringEqual);
REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(StringConcat, prim::kPrimStringConcat, InferImplStringConcat);
REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(DictLen, prim::kPrimDictLen, InferImplDictLen);
REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(FakeBprop, prim::kPrimFakeBprop, InferImplFakeBprop);
REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(J, prim::kPrimJ, InferImplJ);
REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(BroadcastGradientArgs, prim::kPrimBroadcastGradientArgs,
InferImplBroadcastGradientArgs);
REGISTER_PRIMITIVE_EVAL_IMPL(Assign, prim::kPrimAssign, InferImplAssign);
REGISTER_PRIMITIVE_EVAL_IMPL(Load, prim::kPrimLoad, InferImplLoad);
REGISTER_PRIMITIVE_EVAL_IMPL(TypeOf, prim::kPrimTypeOf, InferImplTypeof, nullptr, false);
REGISTER_PRIMITIVE_EVAL_IMPL(HasType, prim::kPrimHasType, InferImplHasType, nullptr, false);
REGISTER_PRIMITIVE_EVAL_IMPL(MakeRecord, prim::kPrimMakeRecord, InferImplMakeRecord, nullptr, false);
REGISTER_PRIMITIVE_EVAL_IMPL(ListMap, prim::kPrimListMap, InferImplListMap, nullptr, false);
REGISTER_PRIMITIVE_EVAL_IMPL(ListReduce, prim::kPrimListReduce, InferImplListReduce, nullptr, false);
REGISTER_PRIMITIVE_EVAL_IMPL(TupleReversed, prim::kPrimTupleReversed, InferImplTupleReversed, nullptr, false);
REGISTER_PRIMITIVE_EVAL_IMPL(ReducedShape, prim::kPrimReducedShape, InferImplReduceShape, nullptr, false);
REGISTER_PRIMITIVE_EVAL_IMPL(TupleDiv, prim::kPrimTupleDiv, InferImplTupleDiv, nullptr, false);
REGISTER_PRIMITIVE_EVAL_IMPL(TupleToArray, prim::kPrimTupleToArray, InferImplTuple2Array, nullptr, false);
REGISTER_PRIMITIVE_EVAL_IMPL(ShapeMul, prim::kPrimShapeMul, InferImplShapeMul, nullptr, false);
REGISTER_PRIMITIVE_EVAL_IMPL(TupleEqual, prim::kPrimTupleEqual, InferImplTupleEqual, nullptr, false);
REGISTER_PRIMITIVE_EVAL_IMPL(ListEqual, prim::kPrimListEqual, InferImplListEqual, nullptr, false);
REGISTER_PRIMITIVE_EVAL_IMPL(MakeRange, prim::kPrimMakeRange, InferImplMakeRange, nullptr, false);
REGISTER_PRIMITIVE_EVAL_IMPL(StopGradient, prim::kPrimStopGradient, InferImplStopGradient, nullptr, false);
REGISTER_PRIMITIVE_EVAL_IMPL(StringEqual, prim::kPrimStringEqual, InferImplStringEqual, nullptr, false);
REGISTER_PRIMITIVE_EVAL_IMPL(StringConcat, prim::kPrimStringConcat, InferImplStringConcat, nullptr, false);
REGISTER_PRIMITIVE_EVAL_IMPL(DictLen, prim::kPrimDictLen, InferImplDictLen, nullptr, false);
REGISTER_PRIMITIVE_EVAL_IMPL(FakeBprop, prim::kPrimFakeBprop, InferImplFakeBprop, nullptr, false);
REGISTER_PRIMITIVE_EVAL_IMPL(J, prim::kPrimJ, InferImplJ, nullptr, false);
REGISTER_PRIMITIVE_EVAL_IMPL(BroadcastGradientArgs, prim::kPrimBroadcastGradientArgs, InferImplBroadcastGradientArgs,
nullptr, false);
} // namespace abstract
} // namespace mindspore

@ -59,18 +59,6 @@ AbstractBasePtr InferImplFakeBprop(const AnalysisEnginePtr &, const PrimitivePtr
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplMakeRecord(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
class RegisterFrontendPrimitiveEvalHelper {
public:
RegisterFrontendPrimitiveEvalHelper(const PrimitivePtr &primitive, const StandardPrimitiveEvalImpl &impl) {
const StandardPrimitiveImplReg impl_reg{impl, false};
RegisterStandardPrimitiveImpl(primitive, impl_reg);
}
~RegisterFrontendPrimitiveEvalHelper() = default;
};
#define REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(name, primitive, impl) \
static auto helper_##name = RegisterFrontendPrimitiveEvalHelper(primitive, impl)
} // namespace abstract
} // namespace mindspore

@ -308,6 +308,10 @@ AbstractBasePtr InferImplSparseSoftmaxCrossEntropyWithLogits(const AnalysisEngin
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplDType(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplLoad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplAssign(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
template <typename T>
AbstractBasePtr InferTupleOrListOrDictLen(const std::string &op_name, const AbstractBasePtrList &args_spec_list) {
// Inputs: a tuple or list or dict.

@ -577,5 +577,31 @@ AbstractBasePtr InferImplDType(const AnalysisEnginePtr &, const PrimitivePtr &pr
abstract->set_value(value);
return abstract;
}
AbstractBasePtr InferImplLoad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// Inputs: Ref/Tensor, universal
CheckArgsSize(primitive->name(), args_spec_list, 2);
auto ref_abs = dyn_cast<abstract::AbstractRef>(args_spec_list[0]);
if (ref_abs != nullptr) {
// Return tensor value if input is Ref.
return ref_abs->CloneAsTensor();
}
return args_spec_list[0]->Broaden();
}
AbstractBasePtr InferImplAssign(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// Inputs: Ref, value, [universal]
CheckRequiredArgsSize(primitive->name(), args_spec_list, 2);
MS_LOG(DEBUG) << "InferImplAssign " << args_spec_list[0];
auto type = args_spec_list[0]->BuildType();
if (type->type_id() == kObjectTypeRefKey) {
return args_spec_list[1]->Broaden();
} else {
return args_spec_list[0];
}
}
} // namespace abstract
} // namespace mindspore

File diff suppressed because it is too large Load Diff

@ -28,9 +28,13 @@ namespace mindspore {
namespace abstract {
using StandardPrimitiveEvalImpl = AbstractBasePtr (*)(const abstract::AnalysisEnginePtr &, const PrimitivePtr &,
const AbstractBasePtrList &);
using InferValueEvalImpl = ValuePtr (*)(const PrimitivePtr &, const AbstractBasePtrList &, const AbstractBasePtr &);
struct StandardPrimitiveImplReg {
StandardPrimitiveEvalImpl impl_; // Implement function of Primitive.
bool in_white_list_; // true if this Primitive in white list, else false.
StandardPrimitiveEvalImpl impl_; // Implement function of Primitive
InferValueEvalImpl infer_value_func_; // infer value of primitive
// true means this primitive can be executed by vm backend else will be constant folded by frontend
bool in_white_list_;
};
using PrimitiveEvalImplMap =
@ -48,15 +52,17 @@ void RegisterStandardPrimitiveImpl(const PrimitivePtr &primitive, const Standard
class RegisterStandardPrimitiveEvalHelper {
public:
RegisterStandardPrimitiveEvalHelper(const PrimitivePtr &primitive, const StandardPrimitiveEvalImpl &impl) {
const StandardPrimitiveImplReg impl_reg{impl, true};
RegisterStandardPrimitiveEvalHelper(const PrimitivePtr &primitive, const StandardPrimitiveEvalImpl &impl,
const InferValueEvalImpl &infer_value_impl, const bool is_wight_list = true) {
const StandardPrimitiveImplReg impl_reg{impl, infer_value_impl, is_wight_list};
RegisterStandardPrimitiveImpl(primitive, impl_reg);
}
~RegisterStandardPrimitiveEvalHelper() = default;
};
#define REGISTER_PRIMITIVE_EVAL_IMPL(name, primitive, impl) \
static auto helper_##name = abstract::RegisterStandardPrimitiveEvalHelper(primitive, impl)
#define REGISTER_PRIMITIVE_EVAL_IMPL(name, primitive, impl, infer_value_impl, is_wight_list) \
static auto helper_##name = \
abstract::RegisterStandardPrimitiveEvalHelper(primitive, impl, infer_value_impl, is_wight_list)
} // namespace abstract
} // namespace mindspore
#endif // MINDSPORE_CORE_ABSTRACT_PRIMITIVE_INFER_MAP_H_

@ -541,6 +541,40 @@ inline const PrimitivePtr kPrimDType = std::make_shared<Primitive>("DType");
inline const PrimitivePtr kPrimDivFusion = std::make_shared<Primitive>("DivFusion");
inline const PrimitivePtr kPrimErf = std::make_shared<Primitive>("Erf");
// Type introspection
inline const PrimitivePtr kPrimTypeOf = std::make_shared<Primitive>("typeof");
inline const PrimitivePtr kPrimHasType = std::make_shared<Primitive>("hastype");
inline const PrimitivePtr kPrimResolve = std::make_shared<Primitive>("resolve");
inline const PrimitivePtr kPrimEmbed = std::make_shared<Primitive>("embed");
inline const PrimitivePtr kPrimRefToEmbed = std::make_shared<Primitive>("RefToEmbed");
inline const PrimitivePtr kPrimCreateInstance = std::make_shared<Primitive>("create_instance");
// Other miscellaneous
inline const PrimitivePtr kPrimGetRefOrigin = std::make_shared<Primitive>("get_ref_origin");
inline const PrimitivePtr kPrimInsertGradientOf = std::make_shared<Primitive>("InsertGradientOf");
inline const PrimitivePtr kPrimCheckBprop = std::make_shared<Primitive>("CheckBprop");
inline const PrimitivePtr kPrimMixedPrecisionCast = std::make_shared<Primitive>("mixed_precision_cast");
inline const PrimitivePtr kPrimMakeRecord = std::make_shared<Primitive>("make_record");
// Structures
inline const PrimitivePtr kPrimListMap = std::make_shared<Primitive>("list_map");
inline const PrimitivePtr kPrimListReduce = std::make_shared<Primitive>("list_reduce");
inline const PrimitivePtr kPrimTupleReversed = std::make_shared<Primitive>("tuple_reversed");
inline const PrimitivePtr kPrimReducedShape = std::make_shared<Primitive>("reduced_shape");
inline const PrimitivePtr kPrimTupleDiv = std::make_shared<Primitive>("tuple_div");
inline const PrimitivePtr kPrimTupleToArray = std::make_shared<Primitive>("tuple_to_array");
inline const PrimitivePtr kPrimShapeMul = std::make_shared<Primitive>("shape_mul");
inline const PrimitivePtr kPrimTupleEqual = std::make_shared<Primitive>("tuple_equal");
inline const PrimitivePtr kPrimListEqual = std::make_shared<Primitive>("list_equal");
inline const PrimitivePtr kPrimMakeRange = std::make_shared<Primitive>("make_range");
inline const PrimitivePtr kPrimStopGradient = std::make_shared<Primitive>("stop_gradient");
inline const PrimitivePtr kPrimStringEqual = std::make_shared<Primitive>("string_equal");
inline const PrimitivePtr kPrimStringConcat = std::make_shared<Primitive>("string_concat");
inline const PrimitivePtr kPrimDictLen = std::make_shared<Primitive>("dict_len");
inline const PrimitivePtr kPrimFakeBprop = std::make_shared<Primitive>("fake_bprop");
inline const PrimitivePtr kPrimBroadcastGradientArgs = std::make_shared<Primitive>("BroadcastGradientArgs");
class DoSignaturePrimitive : public Primitive {
public:
explicit DoSignaturePrimitive(const std::string &name, const ValuePtr &function)

@ -49,6 +49,5 @@ AbstractBasePtr AddInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr
return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
InferShape(primitive, input_args)->shape());
}
REGISTER_PRIMITIVE_EVAL_IMPL(Add, prim::kPrimAdd, AddInfer);
REGISTER_PRIMITIVE_C(kNameAdd, Add);
} // namespace mindspore

@ -42,7 +42,7 @@ AbstractBasePtr ScalarSummaryInfer(const abstract::AnalysisEnginePtr &, const Pr
CheckAndConvertUtils::CheckInteger("v rank", v_shape.size(), kLessEqual, 1, prim_name);
return std::make_shared<abstract::AbstractTensor>(kInt32, std::make_shared<abstract::Shape>(ShapeVector(1)));
}
REGISTER_PRIMITIVE_EVAL_IMPL(ScalarSummary, prim::kPrimScalarSummary, ScalarSummaryInfer);
REGISTER_PRIMITIVE_EVAL_IMPL(ScalarSummary, prim::kPrimScalarSummary, ScalarSummaryInfer, nullptr, true);
REGISTER_PRIMITIVE_C(kNameScalarSummary, ScalarSummary);
} // namespace ops
} // namespace mindspore

@ -42,7 +42,7 @@ AbstractBasePtr TensorSummaryInfer(const abstract::AnalysisEnginePtr &, const Pr
CheckAndConvertUtils::CheckInteger("v rank", v_shape.size(), kGreaterEqual, 1, prim_name);
return std::make_shared<abstract::AbstractTensor>(kInt32, std::make_shared<abstract::Shape>(ShapeVector(1)));
}
REGISTER_PRIMITIVE_EVAL_IMPL(TensorSummary, prim::kPrimTensorSummary, TensorSummaryInfer);
REGISTER_PRIMITIVE_EVAL_IMPL(TensorSummary, prim::kPrimTensorSummary, TensorSummaryInfer, nullptr, true);
REGISTER_PRIMITIVE_C(kNameTensorSummary, TensorSummary);
} // namespace ops
} // namespace mindspore

@ -36,7 +36,7 @@ AbstractBasePtr InferImplAttrTest(const abstract::AnalysisEnginePtr &, const Pri
EXPECT_EQ(args_spec_list[1]->isa<abstract::AbstractTuple>(), true);
return args_spec_list[0];
}
REGISTER_PRIMITIVE_EVAL_IMPL(TestAttr,kPrimAttrConvertTest,InferImplAttrTest);
REGISTER_PRIMITIVE_EVAL_IMPL(TestAttr, kPrimAttrConvertTest, InferImplAttrTest, nullptr, true);
AbstractBasePtr InferImplDynamicInputTest(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
EXPECT_EQ(args_spec_list.size(), 3);
@ -45,7 +45,7 @@ AbstractBasePtr InferImplDynamicInputTest(const abstract::AnalysisEnginePtr &, c
auto item = args_spec_list[1]->cast<abstract::AbstractTuplePtr>();
return args_spec_list[0];
}
REGISTER_PRIMITIVE_EVAL_IMPL(TestDynamicInput,kPrimDynamicInputTest,InferImplDynamicInputTest);
REGISTER_PRIMITIVE_EVAL_IMPL(TestDynamicInput, kPrimDynamicInputTest, InferImplDynamicInputTest, nullptr, true);
class TestAttrAndDynamicBackendInfer : public UT::Common {
public:
TestAttrAndDynamicBackendInfer() {}

Loading…
Cancel
Save