From b3d76c6e3ed66a9af87c9e8645eb3614fa3e3546 Mon Sep 17 00:00:00 2001 From: lingyunli63 Date: Thu, 5 Nov 2020 16:16:57 +0800 Subject: [PATCH] exclude unused attrs and fusion_type in cse cmp --- .../graph_kernel/graph_kernel_cse.cc | 77 ++++++++++++++++++- .../optimizer/graph_kernel/graph_kernel_cse.h | 1 + .../pass/common_subexpression_elimination.cc | 42 +++++----- .../pass/common_subexpression_elimination.h | 1 + 4 files changed, 100 insertions(+), 21 deletions(-) diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_cse.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_cse.cc index 9f04d97888..97f38f6587 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_cse.cc +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_cse.cc @@ -17,15 +17,62 @@ #include "backend/optimizer/graph_kernel/graph_kernel_cse.h" #include +#include +#include +#include #include "backend/session/anf_runtime_algorithm.h" #include "runtime/device/kernel_info.h" namespace mindspore { namespace opt { +namespace { +bool IsCNodePrimitveEqual(const CNodePtr &main, const CNodePtr &node) { + auto main_primitive = AnfAlgo::GetCNodePrimitive(main); + auto node_primitive = AnfAlgo::GetCNodePrimitive(node); + if (main_primitive != nullptr && node_primitive != nullptr) { + if (main_primitive->name() != node_primitive->name()) { + return false; + } + + auto main_attrs = main_primitive->attrs(); + auto node_attrs = node_primitive->attrs(); + + std::vector exclude_attrs{"IsFeatureMapOutput", "IsFeatureMapInputList", "pri_format"}; + for (auto &attr : exclude_attrs) { + main_attrs.erase(attr); + node_attrs.erase(attr); + } + + if (main_attrs.size() != node_attrs.size()) { + return false; + } + + auto all = std::all_of(main_attrs.begin(), main_attrs.end(), + [&node_attrs](const std::pair &item) -> bool { + if (item.second == nullptr) { + return false; + } + auto iter = node_attrs.find(item.first); + if (iter == node_attrs.end()) { + return false; + } + return *item.second == *iter->second; + }); + return all; + } + + return *main->inputs()[0] == *node->inputs()[0]; +} +} // namespace bool GraphKernelBackendCSE::CheckEqualKernelBuildInfo(const AnfNodePtr &main, const AnfNodePtr &node) const { MS_EXCEPTION_IF_NULL(main); MS_EXCEPTION_IF_NULL(node); + + if (!AnfAlgo::IsNodeInGraphKernel(main)) { + return BackendCSE::CheckEqualKernelBuildInfo(main, node); + } + auto main_kernel_info = dynamic_cast(main->kernel_info()); auto node_kernel_info = dynamic_cast(node->kernel_info()); if (main_kernel_info == nullptr && node_kernel_info == nullptr) { @@ -43,8 +90,7 @@ bool GraphKernelBackendCSE::CheckEqualKernelBuildInfo(const AnfNodePtr &main, co return false; } - if (main_build_info->fusion_type() != node_build_info->fusion_type() || - main_build_info->processor() != node_build_info->processor()) { + if (main_build_info->processor() != node_build_info->processor()) { return false; } @@ -53,6 +99,33 @@ bool GraphKernelBackendCSE::CheckEqualKernelBuildInfo(const AnfNodePtr &main, co return false; } +bool GraphKernelBackendCSE::CheckEqualCnodeInputs(const AnfNodePtr &main, const AnfNodePtr &node) const { + auto c_main = main->cast(); + MS_EXCEPTION_IF_NULL(c_main); + auto c_node = node->cast(); + MS_EXCEPTION_IF_NULL(c_node); + + if (!AnfAlgo::IsNodeInGraphKernel(c_main)) { + return BackendCSE::CheckEqualCnodeInputs(main, node); + } + + const auto &inp1 = c_main->inputs(); + const auto &inp2 = c_node->inputs(); + if (inp1.size() != inp2.size()) { + return false; + } + for (size_t j = 1; j < inp1.size(); j++) { + auto inp1_j = inp1[j]; + auto inp2_j = inp2[j]; + MS_EXCEPTION_IF_NULL(inp1_j); + MS_EXCEPTION_IF_NULL(inp2_j); + if (!(*inp1_j == *inp2_j)) { + return false; + } + } + return IsCNodePrimitveEqual(c_main, c_node); +} + bool GraphKernelCSE::Run(const FuncGraphPtr &func_graph) { MS_EXCEPTION_IF_NULL(func_graph); auto graphkernel_backend_cse = std::make_shared(); diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_cse.h b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_cse.h index b37b301bcf..9336b297fb 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_cse.h +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_cse.h @@ -32,6 +32,7 @@ class GraphKernelBackendCSE : public BackendCSE { GraphKernelBackendCSE() = default; ~GraphKernelBackendCSE() override = default; bool CheckEqualKernelBuildInfo(const AnfNodePtr &main, const AnfNodePtr &node) const override; + bool CheckEqualCnodeInputs(const AnfNodePtr &main, const AnfNodePtr &node) const override; }; } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/pass/common_subexpression_elimination.cc b/mindspore/ccsrc/backend/optimizer/pass/common_subexpression_elimination.cc index 95eecc500e..c469ca6c0c 100644 --- a/mindspore/ccsrc/backend/optimizer/pass/common_subexpression_elimination.cc +++ b/mindspore/ccsrc/backend/optimizer/pass/common_subexpression_elimination.cc @@ -48,6 +48,28 @@ bool BackendCSE::CheckEqualKernelBuildInfo(const AnfNodePtr &main, const AnfNode return false; } +bool BackendCSE::CheckEqualCnodeInputs(const AnfNodePtr &main, const AnfNodePtr &node) const { + auto c_main = main->cast(); + MS_EXCEPTION_IF_NULL(c_main); + auto c_node = node->cast(); + MS_EXCEPTION_IF_NULL(c_node); + const auto &inp1 = c_main->inputs(); + const auto &inp2 = c_node->inputs(); + if (inp1.size() != inp2.size()) { + return false; + } + for (size_t j = 0; j < inp1.size(); j++) { + auto inp1_j = inp1[j]; + auto inp2_j = inp2[j]; + MS_EXCEPTION_IF_NULL(inp1_j); + MS_EXCEPTION_IF_NULL(inp2_j); + if (!(*inp1_j == *inp2_j)) { + return false; + } + } + return true; +} + bool BackendCSE::CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node, bool check_side_effect) const { MS_EXCEPTION_IF_NULL(main); MS_EXCEPTION_IF_NULL(node); @@ -69,25 +91,7 @@ bool BackendCSE::CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node, bo if (!CheckEqualKernelBuildInfo(main, node)) { return false; } - auto c_main = main->cast(); - MS_EXCEPTION_IF_NULL(c_main); - auto c_node = node->cast(); - MS_EXCEPTION_IF_NULL(c_node); - const auto &inp1 = c_main->inputs(); - const auto &inp2 = c_node->inputs(); - if (inp1.size() != inp2.size()) { - return false; - } - for (size_t j = 0; j < inp1.size(); j++) { - auto inp1_j = inp1[j]; - auto inp2_j = inp2[j]; - MS_EXCEPTION_IF_NULL(inp1_j); - MS_EXCEPTION_IF_NULL(inp2_j); - if (!(*inp1_j == *inp2_j)) { - return false; - } - } - return true; + return CheckEqualCnodeInputs(main, node); } return false; } diff --git a/mindspore/ccsrc/backend/optimizer/pass/common_subexpression_elimination.h b/mindspore/ccsrc/backend/optimizer/pass/common_subexpression_elimination.h index a5ba14b63a..6a50fb0c92 100644 --- a/mindspore/ccsrc/backend/optimizer/pass/common_subexpression_elimination.h +++ b/mindspore/ccsrc/backend/optimizer/pass/common_subexpression_elimination.h @@ -31,6 +31,7 @@ class BackendCSE : public CSE { public: BackendCSE() = default; ~BackendCSE() override = default; + virtual bool CheckEqualCnodeInputs(const AnfNodePtr &main, const AnfNodePtr &node) const; bool CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node, bool check_side_effect = true) const override; virtual bool CheckEqualKernelBuildInfo(const AnfNodePtr &main, const AnfNodePtr &node) const; };