|
|
|
@ -51,8 +51,8 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
|
|
|
|
|
prim::kPrimIdentity, prim::kPrimMomentum, prim::kPrimMul});
|
|
|
|
|
special_op_eliminate_ =
|
|
|
|
|
MakeSubstitution(SpecialOpEliminater(), "special_op_eliminate",
|
|
|
|
|
{prim::kPrimInsertGradientOf, prim::kPrimHookBackward, prim::kPrimPrintShapeType,
|
|
|
|
|
prim::kPrimGetRefKey, prim::kPrimMirror, prim::kPrimVirtualDiv});
|
|
|
|
|
{prim::kPrimInsertGradientOf, prim::kPrimStopGradient, prim::kPrimHookBackward,
|
|
|
|
|
prim::kPrimPrintShapeType, prim::kPrimGetRefKey, prim::kPrimMirror, prim::kPrimVirtualDiv});
|
|
|
|
|
zero_like_fill_zero_ = MakeSubstitution(ZeroLikeFillZero(), "zero_like_fill_zero", prim::kPrimZerosLike);
|
|
|
|
|
adjust_all_reduce_mul_add_ = MakeSubstitution(AdjustAllReduceMulAdd(), "adjust_all_reduce_mul_add", prim::kPrimAddN);
|
|
|
|
|
|
|
|
|
@ -72,9 +72,8 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
|
|
|
|
|
reset_defer_inline_ = MakeSubstitution(ResetDeferInline(), "reset_defer_inline", IsValueNode<FuncGraph>);
|
|
|
|
|
|
|
|
|
|
// Env Item Eliminate
|
|
|
|
|
env_get_item_eliminate_ = MakeSubstitution(EnvGetItemEliminater(), "env_get_item_eliminate", prim::kPrimEnvGetItem);
|
|
|
|
|
new_env_get_item_ = MakeSubstitution(NewEnvGetItem(), "new_env_get_item", prim::kPrimEnvGetItem);
|
|
|
|
|
add_env_get_item_ = MakeSubstitution(AddEnvGetItem(), "add_env_get_item", prim::kPrimEnvGetItem);
|
|
|
|
|
env_get_set_item_ = MakeSubstitution(EnvGetSetItem(), "env_get_set_item", prim::kPrimEnvGetItem);
|
|
|
|
|
incorporate_env_getitem_ =
|
|
|
|
|
MakeSubstitution(IncorporateEnvGetitem(), "incorporate_env_get_item", prim::kPrimEnvGetItem);
|
|
|
|
|
incorporate_env_getitem_switch_ =
|
|
|
|
@ -91,8 +90,6 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
|
|
|
|
|
|
|
|
|
|
// Gradient transforms
|
|
|
|
|
expand_jprim_ = MakeSubstitution(ExpandJPrim(), "expand_jprim", prim::kPrimJ);
|
|
|
|
|
stop_gradient_eliminate_ =
|
|
|
|
|
MakeSubstitution(StopGradientEliminater(), "stop_gradient_eliminate", prim::kPrimStopGradient);
|
|
|
|
|
minmaximum_grad_ = MakeSubstitution(MinMaximumGrad(), "minmaximum_grad", prim::kPrimTupleGetItem);
|
|
|
|
|
|
|
|
|
|
// branch culling
|
|
|
|
@ -113,9 +110,8 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
|
|
|
|
|
specialize_transform_ = MakeSubstitution(SpecializeOnGraphArguments(), "specialize_transform", IsCNodeGraph);
|
|
|
|
|
|
|
|
|
|
// Incorporation
|
|
|
|
|
incorporate_getitem_ = MakeSubstitution(IncorporateGetitem(), "incorporate_getitem", prim::kPrimTupleGetItem);
|
|
|
|
|
incorporate_getitem_switch_ =
|
|
|
|
|
MakeSubstitution(IncorporateGetitemSwitch(), "incorporate_getitem_switch", prim::kPrimTupleGetItem);
|
|
|
|
|
incorporate_getitem_set_ =
|
|
|
|
|
MakeSubstitution(IncorporateGetitemSet(), "incorporate_getitem_set", prim::kPrimTupleGetItem);
|
|
|
|
|
incorporate_call_ = MakeSubstitution(IncorporateCall(), "incorporate_call", IsCNodeDup);
|
|
|
|
|
incorporate_call_switch_ = MakeSubstitution(IncorporateCallSwitch(), "incorporate_call_switch", IsCNodeDup);
|
|
|
|
|
|
|
|
|
|