From 2974b906d35e1b01df92347f417acd9883098738 Mon Sep 17 00:00:00 2001 From: Kang Date: Tue, 16 Jun 2020 21:59:53 +0800 Subject: [PATCH] Optimization for opt --- mindspore/ccsrc/optimizer/irpass.cc | 14 +++---- mindspore/ccsrc/optimizer/irpass.h | 42 ++++--------------- .../optimizer/irpass/env_item_eliminate.h | 27 ++++++++++++ .../optimizer/irpass/gradient_eliminate.h | 15 ------- .../optimizer/irpass/incorporate_getitem.h | 25 +++++++++++ .../optimizer/irpass/special_op_eliminate.h | 5 ++- mindspore/ccsrc/optimizer/opt.cc | 13 +++++- mindspore/ccsrc/pipeline/pass.cc | 21 +++++----- tests/ut/cpp/optimizer/lib_test.cc | 4 +- 9 files changed, 93 insertions(+), 73 deletions(-) diff --git a/mindspore/ccsrc/optimizer/irpass.cc b/mindspore/ccsrc/optimizer/irpass.cc index 13c6841604..a159dbfbce 100644 --- a/mindspore/ccsrc/optimizer/irpass.cc +++ b/mindspore/ccsrc/optimizer/irpass.cc @@ -51,8 +51,8 @@ OptimizeIRPassLib::OptimizeIRPassLib() { prim::kPrimIdentity, prim::kPrimMomentum, prim::kPrimMul}); special_op_eliminate_ = MakeSubstitution(SpecialOpEliminater(), "special_op_eliminate", - {prim::kPrimInsertGradientOf, prim::kPrimHookBackward, prim::kPrimPrintShapeType, - prim::kPrimGetRefKey, prim::kPrimMirror, prim::kPrimVirtualDiv}); + {prim::kPrimInsertGradientOf, prim::kPrimStopGradient, prim::kPrimHookBackward, + prim::kPrimPrintShapeType, prim::kPrimGetRefKey, prim::kPrimMirror, prim::kPrimVirtualDiv}); zero_like_fill_zero_ = MakeSubstitution(ZeroLikeFillZero(), "zero_like_fill_zero", prim::kPrimZerosLike); adjust_all_reduce_mul_add_ = MakeSubstitution(AdjustAllReduceMulAdd(), "adjust_all_reduce_mul_add", prim::kPrimAddN); @@ -72,9 +72,8 @@ OptimizeIRPassLib::OptimizeIRPassLib() { reset_defer_inline_ = MakeSubstitution(ResetDeferInline(), "reset_defer_inline", IsValueNode); // Env Item Eliminate + env_get_item_eliminate_ = MakeSubstitution(EnvGetItemEliminater(), "env_get_item_eliminate", prim::kPrimEnvGetItem); new_env_get_item_ = MakeSubstitution(NewEnvGetItem(), "new_env_get_item", prim::kPrimEnvGetItem); - add_env_get_item_ = MakeSubstitution(AddEnvGetItem(), "add_env_get_item", prim::kPrimEnvGetItem); - env_get_set_item_ = MakeSubstitution(EnvGetSetItem(), "env_get_set_item", prim::kPrimEnvGetItem); incorporate_env_getitem_ = MakeSubstitution(IncorporateEnvGetitem(), "incorporate_env_get_item", prim::kPrimEnvGetItem); incorporate_env_getitem_switch_ = @@ -91,8 +90,6 @@ OptimizeIRPassLib::OptimizeIRPassLib() { // Gradient transforms expand_jprim_ = MakeSubstitution(ExpandJPrim(), "expand_jprim", prim::kPrimJ); - stop_gradient_eliminate_ = - MakeSubstitution(StopGradientEliminater(), "stop_gradient_eliminate", prim::kPrimStopGradient); minmaximum_grad_ = MakeSubstitution(MinMaximumGrad(), "minmaximum_grad", prim::kPrimTupleGetItem); // branch culling @@ -113,9 +110,8 @@ OptimizeIRPassLib::OptimizeIRPassLib() { specialize_transform_ = MakeSubstitution(SpecializeOnGraphArguments(), "specialize_transform", IsCNodeGraph); // Incorporation - incorporate_getitem_ = MakeSubstitution(IncorporateGetitem(), "incorporate_getitem", prim::kPrimTupleGetItem); - incorporate_getitem_switch_ = - MakeSubstitution(IncorporateGetitemSwitch(), "incorporate_getitem_switch", prim::kPrimTupleGetItem); + incorporate_getitem_set_ = + MakeSubstitution(IncorporateGetitemSet(), "incorporate_getitem_set", prim::kPrimTupleGetItem); incorporate_call_ = MakeSubstitution(IncorporateCall(), "incorporate_call", IsCNodeDup); incorporate_call_switch_ = MakeSubstitution(IncorporateCallSwitch(), "incorporate_call_switch", IsCNodeDup); diff --git a/mindspore/ccsrc/optimizer/irpass.h b/mindspore/ccsrc/optimizer/irpass.h index e834d69b69..d8e0cb67df 100644 --- a/mindspore/ccsrc/optimizer/irpass.h +++ b/mindspore/ccsrc/optimizer/irpass.h @@ -50,9 +50,8 @@ class OptimizeIRPassLib { SubstitutionPtr reset_defer_inline_; // Env Item Eliminate + SubstitutionPtr env_get_item_eliminate_; SubstitutionPtr new_env_get_item_; - SubstitutionPtr add_env_get_item_; - SubstitutionPtr env_get_set_item_; SubstitutionPtr incorporate_env_getitem_; SubstitutionPtr incorporate_env_getitem_switch_; @@ -74,7 +73,6 @@ class OptimizeIRPassLib { // Gradient irpasses SubstitutionPtr expand_jprim_; - SubstitutionPtr stop_gradient_eliminate_; SubstitutionPtr minmaximum_grad_; // inline @@ -83,8 +81,7 @@ class OptimizeIRPassLib { SubstitutionPtr specialize_transform_; // Incorporation - SubstitutionPtr incorporate_getitem_; - SubstitutionPtr incorporate_getitem_switch_; + SubstitutionPtr incorporate_getitem_set_; SubstitutionPtr incorporate_call_; SubstitutionPtr incorporate_call_switch_; @@ -115,51 +112,30 @@ class InferenceOptPrepareLib { // predicate functions inline bool IsNode(const AnfNodePtr &) { return true; } -inline bool IsCNode(const AnfNodePtr &node) { - if (node != nullptr) { - return node->isa(); - } - return false; -} +inline bool IsCNode(const AnfNodePtr &node) { return node->isa(); } -inline bool IsVNode(const AnfNodePtr &node) { - if (node != nullptr) { - return node->isa(); - } - return false; -} +inline bool IsVNode(const AnfNodePtr &node) { return node->isa(); } -inline bool IsParam(const AnfNodePtr &node) { - if (node != nullptr) { - return node->isa(); - } - return false; -} +inline bool IsParam(const AnfNodePtr &node) { return node->isa(); } // Check if CNode Input 0 is Func Graph inline bool IsCNodeGraph(const AnfNodePtr &node) { - if (node == nullptr || !node->isa()) { + if (!node->isa()) { return false; } auto inp0 = node->cast()->input(0); - if (IsValueNode(inp0)) { - return true; - } - return false; + return IsValueNode(inp0); } // Check if CNode Input 0 is CNode inline bool IsCNodeDup(const AnfNodePtr &node) { - if (node == nullptr || !node->isa()) { + if (!node->isa()) { return false; } auto inp0 = node->cast()->input(0); - if (inp0 != nullptr && inp0->isa()) { - return true; - } - return false; + return (inp0 != nullptr) && inp0->isa(); } } // namespace irpass } // namespace opt diff --git a/mindspore/ccsrc/optimizer/irpass/env_item_eliminate.h b/mindspore/ccsrc/optimizer/irpass/env_item_eliminate.h index ce29b32d14..0f59c69fef 100644 --- a/mindspore/ccsrc/optimizer/irpass/env_item_eliminate.h +++ b/mindspore/ccsrc/optimizer/irpass/env_item_eliminate.h @@ -225,6 +225,33 @@ class EnvGetSetItem : public AnfVisitor { bool is_match_{false}; }; +class EnvGetItemEliminater { + public: + EnvGetItemEliminater() : new_env_get_item_(), add_env_get_item_(), env_get_set_item_() { + eliminaters_.emplace_back(new_env_get_item_); + eliminaters_.emplace_back(add_env_get_item_); + eliminaters_.emplace_back(env_get_set_item_); + } + ~EnvGetItemEliminater() = default; + + AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) { + AnfNodePtr new_node; + for (auto &eliminater : eliminaters_) { + new_node = eliminater(optimizer, node); + if (new_node != nullptr) { + return new_node; + } + } + return nullptr; + } + + private: + NewEnvGetItem new_env_get_item_; + AddEnvGetItem add_env_get_item_; + EnvGetSetItem env_get_set_item_; + std::vector eliminaters_{}; +}; + // {prim::kPrimEnvGetItem, {G, Xs}, C, Y} class IncorporateEnvGetitem : public AnfVisitor { public: diff --git a/mindspore/ccsrc/optimizer/irpass/gradient_eliminate.h b/mindspore/ccsrc/optimizer/irpass/gradient_eliminate.h index 651dc3a2f2..671d9bde49 100644 --- a/mindspore/ccsrc/optimizer/irpass/gradient_eliminate.h +++ b/mindspore/ccsrc/optimizer/irpass/gradient_eliminate.h @@ -55,21 +55,6 @@ class ExpandJPrim : public AnfVisitor { private: ValueNodePtr x_{nullptr}; }; - -// stop_gradient(x) ==> x -class StopGradientEliminater : public AnfVisitor { - public: - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { - x_ = nullptr; - AnfVisitor::Match(prim::kPrimStopGradient)(node); - return x_; - } - - void Visit(const AnfNodePtr &node) override { x_ = node; } - - private: - AnfNodePtr x_{nullptr}; -}; } // namespace irpass } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/optimizer/irpass/incorporate_getitem.h b/mindspore/ccsrc/optimizer/irpass/incorporate_getitem.h index 77f3fa7b36..5b973dc334 100644 --- a/mindspore/ccsrc/optimizer/irpass/incorporate_getitem.h +++ b/mindspore/ccsrc/optimizer/irpass/incorporate_getitem.h @@ -197,6 +197,31 @@ class IncorporateGetitemSwitch : public AnfVisitor { std::vector args_{}; internal::GetitemTransform getitem_transform_; }; + +class IncorporateGetitemSet { + public: + IncorporateGetitemSet() : incorporate_getitem_(), incorporate_getitem_switch_() { + eliminaters_.emplace_back(incorporate_getitem_); + eliminaters_.emplace_back(incorporate_getitem_switch_); + } + ~IncorporateGetitemSet() = default; + + AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) { + AnfNodePtr new_node; + for (auto &eliminater : eliminaters_) { + new_node = eliminater(optimizer, node); + if (new_node != nullptr) { + return new_node; + } + } + return nullptr; + } + + private: + IncorporateGetitem incorporate_getitem_; + IncorporateGetitemSwitch incorporate_getitem_switch_; + std::vector eliminaters_{}; +}; } // namespace irpass } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/optimizer/irpass/special_op_eliminate.h b/mindspore/ccsrc/optimizer/irpass/special_op_eliminate.h index 93f24f0abe..cfefed40c7 100644 --- a/mindspore/ccsrc/optimizer/irpass/special_op_eliminate.h +++ b/mindspore/ccsrc/optimizer/irpass/special_op_eliminate.h @@ -35,12 +35,14 @@ class SpecialOpEliminater { public: SpecialOpEliminater() : insert_gradient_of_(prim::kPrimInsertGradientOf), + stop_gradient_(prim::kPrimStopGradient), hook_backward_(prim::kPrimHookBackward), print_shape_type_(prim::kPrimPrintShapeType), get_ref_value_(prim::kPrimGetRefValue), mirror_(prim::kPrimMirror), virtual_div_(prim::kPrimVirtualDiv) { eliminaters_.emplace_back(insert_gradient_of_); + eliminaters_.emplace_back(stop_gradient_); eliminaters_.emplace_back(hook_backward_); eliminaters_.emplace_back(print_shape_type_); eliminaters_.emplace_back(get_ref_value_); @@ -61,7 +63,8 @@ class SpecialOpEliminater { } private: - PrimEliminater insert_gradient_of_, hook_backward_, print_shape_type_, get_ref_value_, mirror_, virtual_div_; + PrimEliminater insert_gradient_of_, stop_gradient_, hook_backward_, print_shape_type_, get_ref_value_, mirror_, + virtual_div_; std::vector eliminaters_{}; }; diff --git a/mindspore/ccsrc/optimizer/opt.cc b/mindspore/ccsrc/optimizer/opt.cc index c9f5803902..ef3a33b109 100644 --- a/mindspore/ccsrc/optimizer/opt.cc +++ b/mindspore/ccsrc/optimizer/opt.cc @@ -44,8 +44,17 @@ SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std:: return false; } + auto cnode = node->cast(); + auto inp0 = cnode->input(0); + auto prim0 = GetValueNode(inp0); + if (prim0 == nullptr) { + return false; + } + + auto hash = prim0->Hash(); + auto const &name = prim0->name(); for (auto &prim : prims) { - if (IsPrimitiveCNode(node, prim)) { + if (hash == prim->Hash() && name == prim->name()) { return true; } } @@ -171,7 +180,7 @@ bool SubstitutionList::ApplyTransform(const OptimizerPtr &optimizer, const AnfNo } #ifdef ENABLE_PROFILE - MsProfile::StatTime("opt.transform", GetTime() - start); + MsProfile::StatTime("opt.transform." + optimizer->name(), GetTime() - start); #endif return changes; } diff --git a/mindspore/ccsrc/pipeline/pass.cc b/mindspore/ccsrc/pipeline/pass.cc index 7ee8a4ecb0..28ec23f7f2 100644 --- a/mindspore/ccsrc/pipeline/pass.cc +++ b/mindspore/ccsrc/pipeline/pass.cc @@ -79,16 +79,9 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) { // Specialization irpass.specialize_transform_, - // Arithmetic simplifications - irpass.arithmetic_simplify_, - irpass.addn_zero_filter_, - irpass.adjust_all_reduce_mul_add_, - // Miscellaneous irpass.item_tuple_eliminate_, - irpass.env_get_set_item_, - irpass.new_env_get_item_, - irpass.add_env_get_item_, + irpass.env_get_item_eliminate_, irpass.cast_eliminate_, irpass.reshape_eliminate_, irpass.reduce_eliminate_, @@ -96,13 +89,20 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) { irpass.transpose_eliminate_, irpass.minmaximum_grad_, irpass.get_make_ref_eliminate_, + + // Arithmetic simplifications + irpass.arithmetic_simplify_, + irpass.addn_zero_filter_, + irpass.adjust_all_reduce_mul_add_, + + // Safe inlining + irpass.inline_, }); opt::OptPassConfig a_2 = opt::OptPassConfig({ irpass.merge_addn_, irpass.float_tuple_getitem_switch_, irpass.float_env_getitem_switch_, - irpass.incorporate_getitem_, - irpass.incorporate_getitem_switch_, + irpass.incorporate_getitem_set_, irpass.incorporate_call_, irpass.incorporate_call_switch_, irpass.incorporate_env_getitem_, @@ -145,7 +145,6 @@ OptPassGroupMap GetOptPassesB(const opt::irpass::OptimizeIRPassLib &irpass) { irpass.reset_defer_inline_, irpass.inline_, irpass.special_op_eliminate_, - irpass.stop_gradient_eliminate_, irpass.get_make_ref_eliminate_, }); opt::OptPassConfig b_2 = opt::OptPassConfig({ diff --git a/tests/ut/cpp/optimizer/lib_test.cc b/tests/ut/cpp/optimizer/lib_test.cc index 804a1f3aa3..3febf049a6 100644 --- a/tests/ut/cpp/optimizer/lib_test.cc +++ b/tests/ut/cpp/optimizer/lib_test.cc @@ -401,7 +401,7 @@ TEST_F(TestOptLib, test_incorporate_getitem) { FuncGraphPtr after1 = getPyFun.CallAndParseRet("test_incorporate_getitem", "after1"); FuncGraphPtr after2 = getPyFun.CallAndParseRet("test_incorporate_getitem", "after2"); - auto patterns = std::vector({irpass.incorporate_getitem_}); + auto patterns = std::vector({irpass.incorporate_getitem_set_}); ASSERT_TRUE(CheckOpt(before1, after1, patterns)); ASSERT_TRUE(CheckOpt(before2, after2, patterns)); @@ -411,7 +411,7 @@ TEST_F(TestOptLib, test_incorporate_getitem_through_switch) { FuncGraphPtr before = getPyFun.CallAndParseRet("test_incorporate_getitem_through_switch", "before"); FuncGraphPtr after = getPyFun.CallAndParseRet("test_incorporate_getitem_through_switch", "after"); - auto patterns = std::vector({irpass.incorporate_getitem_switch_}); + auto patterns = std::vector({irpass.incorporate_getitem_set_}); ASSERT_TRUE(CheckOpt(before, after, patterns)); }