From 13126653ec054d77444f42a4a2032af64800e0bb Mon Sep 17 00:00:00 2001
From: tronzhang <6517937+tronzhang@user.noreply.gitee.com>
Date: Thu, 3 Dec 2020 15:27:36 +0800
Subject: [PATCH] process cast when activate graph kernel in amp

---
 .../optimizer/graph_kernel/graph_kernel_cse.cc | 11 +++++++----
 .../optimizer/graph_kernel/graph_kernel_cse.h  | 18 +++++++++++++-----
 .../graph_kernel/shape_ops_splitter.cc         |  2 +-
 mindspore/ccsrc/backend/session/gpu_session.cc | 10 +++++++---
 4 files changed, 28 insertions(+), 13 deletions(-)

diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_cse.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_cse.cc
index b401d1fa44..12ffe5ee75 100644
--- a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_cse.cc
+++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_cse.cc
@@ -16,6 +16,7 @@
 #include "backend/optimizer/graph_kernel/graph_kernel_cse.h"
 
+#include <algorithm>
 #include 
 #include 
 #include 
 
@@ -26,13 +27,15 @@
 namespace mindspore {
 namespace opt {
 namespace {
-bool IsCNodePrimitveEqual(const CNodePtr &main, const CNodePtr &node) {
+bool IsCNodePrimitveEqual(const CNodePtr &main, const CNodePtr &node, const std::vector<PrimitivePtr> &black_list) {
   auto main_primitive = AnfAlgo::GetCNodePrimitive(main);
   auto node_primitive = AnfAlgo::GetCNodePrimitive(node);
   if (main_primitive != nullptr && node_primitive != nullptr) {
     // Some ops such as Reshape is not real op, cse these type will not get gain. And for ops fusion, keep these op
     // alone can prevent some redundant output case (input -> reshape -> output).
-    if (main_primitive->name() != node_primitive->name() || IsPrimitiveCNode(node, prim::kPrimReshape)) {
+    if (main_primitive->name() != node_primitive->name() ||
+        std::any_of(black_list.begin(), black_list.end(),
+                    [&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); })) {
       return false;
     }
 
@@ -125,12 +128,12 @@ bool GraphKernelBackendCSE::CheckEqualCnodeInputs(const AnfNodePtr &main, const
       return false;
     }
   }
-  return IsCNodePrimitveEqual(c_main, c_node);
+  return IsCNodePrimitveEqual(c_main, c_node, black_list_);
 }
 
 bool GraphKernelCSE::Run(const FuncGraphPtr &func_graph) {
   MS_EXCEPTION_IF_NULL(func_graph);
-  auto graphkernel_backend_cse = std::make_shared<GraphKernelBackendCSE>();
+  auto graphkernel_backend_cse = std::make_shared<GraphKernelBackendCSE>(black_list_);
   return graphkernel_backend_cse->Cse(func_graph, func_graph->manager());
 }
 }  // namespace opt
diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_cse.h b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_cse.h
index 9336b297fb..cafc21c2b5 100644
--- a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_cse.h
+++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_cse.h
@@ -13,27 +13,35 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_CSE_H_
-#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_CSE_H_
+#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_GRAPH_KERNEL_CSE_H_
+#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_GRAPH_KERNEL_CSE_H_
 
+#include <vector>
 #include "backend/optimizer/pass/common_subexpression_elimination.h"
 
 namespace mindspore {
 namespace opt {
 class GraphKernelCSE : public Pass {
  public:
-  GraphKernelCSE() : Pass("graph_kernel_cse") {}
+  explicit GraphKernelCSE(const std::vector<PrimitivePtr> &black_list = {})
+      : Pass("graph_kernel_cse"), black_list_(black_list) {}
   ~GraphKernelCSE() override = default;
   bool Run(const FuncGraphPtr &func_graph) override;
+
+ private:
+  std::vector<PrimitivePtr> black_list_;
 };
 
 class GraphKernelBackendCSE : public BackendCSE {
  public:
-  GraphKernelBackendCSE() = default;
+  explicit GraphKernelBackendCSE(const std::vector<PrimitivePtr> &black_list = {}) : black_list_(black_list) {}
   ~GraphKernelBackendCSE() override = default;
   bool CheckEqualKernelBuildInfo(const AnfNodePtr &main, const AnfNodePtr &node) const override;
   bool CheckEqualCnodeInputs(const AnfNodePtr &main, const AnfNodePtr &node) const override;
+
+ private:
+  std::vector<PrimitivePtr> black_list_;
 };
 }  // namespace opt
 }  // namespace mindspore
-#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_CSE_H_
+#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_GRAPH_KERNEL_CSE_H_
diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/shape_ops_splitter.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/shape_ops_splitter.cc
index 21a31e9719..006d0e4961 100644
--- a/mindspore/ccsrc/backend/optimizer/graph_kernel/shape_ops_splitter.cc
+++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/shape_ops_splitter.cc
@@ -34,7 +34,7 @@ namespace mindspore {
 namespace opt {
 namespace {
 bool IsMultiUserShapeOps(AnfNodePtr node, const FuncGraphManagerPtr &mng) {
-  std::vector<PrimitivePtr> shape_ops = {prim::kPrimReshape};
+  std::vector<PrimitivePtr> shape_ops = {prim::kPrimReshape, prim::kPrimCast};
   auto &users = mng->node_users();
   return std::any_of(shape_ops.begin(), shape_ops.end(),
                      [&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); }) &&
diff --git a/mindspore/ccsrc/backend/session/gpu_session.cc b/mindspore/ccsrc/backend/session/gpu_session.cc
index cda7a03848..d70a570d4f 100644
--- a/mindspore/ccsrc/backend/session/gpu_session.cc
+++ b/mindspore/ccsrc/backend/session/gpu_session.cc
@@ -120,7 +120,9 @@ void GPUSession::Optimize(const std::shared_ptr &kernel_graph) {
   pm->AddPass(std::make_shared());
   pm->AddPass(std::make_shared());
   pm->AddPass(std::make_shared());
-  pm->AddPass(std::make_shared("cast_all"));
+  if (!(context_ptr->get_param<bool>(MS_CTX_ENABLE_GRAPH_KERNEL))) {
+    pm->AddPass(std::make_shared("cast_all"));
+  }
   pm->AddPass(std::make_shared("combine_momentum"));
   pm->AddPass(std::make_shared());
   pm->AddPass(std::make_shared());
@@ -165,15 +167,17 @@ void GPUSession::GraphKernelOptimize(const std::shared_ptr &kernel_
   }
   auto optimizer = std::make_shared();
   auto pm = std::make_shared("graph_kernel_pm");
+  std::vector<PrimitivePtr> black_list = {prim::kPrimReshape, prim::kPrimCast};
   pm->AddPass(std::make_shared());
   pm->AddPass(std::make_shared());
   pm->AddPass(std::make_shared());
   pm->AddPass(std::make_shared());
-  pm->AddPass(std::make_shared());
+  pm->AddPass(std::make_shared<opt::GraphKernelCSE>(black_list));
   pm->AddPass(std::make_shared());
-  pm->AddPass(std::make_shared());
+  pm->AddPass(std::make_shared<opt::GraphKernelCSE>(black_list));
   pm->AddPass(std::make_shared());
   pm->AddPass(std::make_shared());
+  pm->AddPass(std::make_shared());
   // After Simplify and Splitter, a lot of redundant getitem/maketuple
   // will be exposed, use GetitemTuple Pass to delete them.
   pm->AddPass(std::make_shared());
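
Note for reviewers: the behavioral core of this patch is the primitive black list threaded into the CSE pass, so that Cast nodes inserted by AMP (like Reshape before them) are never merged by common-subexpression elimination and can instead be duplicated per user by ShapeOpsSplitter. The sketch below is a minimal standalone illustration of that std::any_of check, not MindSpore code: primitive names are modeled as plain strings, and MatchesPrimitive is a hypothetical stand-in for IsPrimitiveCNode.

// Standalone sketch (assumption: NOT MindSpore code). Primitive names are modeled
// as plain strings, and MatchesPrimitive is a hypothetical stand-in for
// IsPrimitiveCNode. It only illustrates the black-list check added above.
#include <algorithm>
#include <iostream>
#include <string>
#include <vector>

// Hypothetical stand-in for IsPrimitiveCNode(node, prim).
bool MatchesPrimitive(const std::string &node_prim, const std::string &prim) { return node_prim == prim; }

// Mirrors the new check in IsCNodePrimitveEqual: two nodes are CSE candidates only
// if their primitives match AND the node's primitive is not on the black list.
bool PrimitiveEqualWithBlackList(const std::string &main_prim, const std::string &node_prim,
                                 const std::vector<std::string> &black_list) {
  if (main_prim != node_prim ||
      std::any_of(black_list.begin(), black_list.end(),
                  [&node_prim](const std::string &prim) { return MatchesPrimitive(node_prim, prim); })) {
    return false;
  }
  return true;
}

int main() {
  const std::vector<std::string> black_list = {"Reshape", "Cast"};
  std::cout << PrimitiveEqualWithBlackList("Add", "Add", black_list) << "\n";    // 1: eligible for CSE
  std::cout << PrimitiveEqualWithBlackList("Cast", "Cast", black_list) << "\n";  // 0: Cast is black-listed
  std::cout << PrimitiveEqualWithBlackList("Add", "Mul", black_list) << "\n";    // 0: different primitives
  return 0;
}

As the header change shows, the black list defaults to empty, and the GPU session opts in explicitly with {prim::kPrimReshape, prim::kPrimCast} when graph kernel is enabled.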