!8304 refine GraphKerneCSE by excluding fusion_type and unused attrs in primitive

From: @lingyunli63
Reviewed-by: @gaoxiong1,@ckey_dou
Signed-off-by: @ckey_dou
pull/8304/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit e27104763a

@ -17,15 +17,62 @@
#include "backend/optimizer/graph_kernel/graph_kernel_cse.h"
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "backend/session/anf_runtime_algorithm.h"
#include "runtime/device/kernel_info.h"
namespace mindspore {
namespace opt {
namespace {
bool IsCNodePrimitveEqual(const CNodePtr &main, const CNodePtr &node) {
auto main_primitive = AnfAlgo::GetCNodePrimitive(main);
auto node_primitive = AnfAlgo::GetCNodePrimitive(node);
if (main_primitive != nullptr && node_primitive != nullptr) {
if (main_primitive->name() != node_primitive->name()) {
return false;
}
auto main_attrs = main_primitive->attrs();
auto node_attrs = node_primitive->attrs();
std::vector<std::string> exclude_attrs{"IsFeatureMapOutput", "IsFeatureMapInputList", "pri_format"};
for (auto &attr : exclude_attrs) {
main_attrs.erase(attr);
node_attrs.erase(attr);
}
if (main_attrs.size() != node_attrs.size()) {
return false;
}
auto all = std::all_of(main_attrs.begin(), main_attrs.end(),
[&node_attrs](const std::pair<std::string, ValuePtr> &item) -> bool {
if (item.second == nullptr) {
return false;
}
auto iter = node_attrs.find(item.first);
if (iter == node_attrs.end()) {
return false;
}
return *item.second == *iter->second;
});
return all;
}
return *main->inputs()[0] == *node->inputs()[0];
}
} // namespace
bool GraphKernelBackendCSE::CheckEqualKernelBuildInfo(const AnfNodePtr &main, const AnfNodePtr &node) const {
MS_EXCEPTION_IF_NULL(main);
MS_EXCEPTION_IF_NULL(node);
if (!AnfAlgo::IsNodeInGraphKernel(main)) {
return BackendCSE::CheckEqualKernelBuildInfo(main, node);
}
auto main_kernel_info = dynamic_cast<device::KernelInfo *>(main->kernel_info());
auto node_kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
if (main_kernel_info == nullptr && node_kernel_info == nullptr) {
@ -43,8 +90,7 @@ bool GraphKernelBackendCSE::CheckEqualKernelBuildInfo(const AnfNodePtr &main, co
return false;
}
if (main_build_info->fusion_type() != node_build_info->fusion_type() ||
main_build_info->processor() != node_build_info->processor()) {
if (main_build_info->processor() != node_build_info->processor()) {
return false;
}
@ -53,6 +99,33 @@ bool GraphKernelBackendCSE::CheckEqualKernelBuildInfo(const AnfNodePtr &main, co
return false;
}
bool GraphKernelBackendCSE::CheckEqualCnodeInputs(const AnfNodePtr &main, const AnfNodePtr &node) const {
auto c_main = main->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(c_main);
auto c_node = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(c_node);
if (!AnfAlgo::IsNodeInGraphKernel(c_main)) {
return BackendCSE::CheckEqualCnodeInputs(main, node);
}
const auto &inp1 = c_main->inputs();
const auto &inp2 = c_node->inputs();
if (inp1.size() != inp2.size()) {
return false;
}
for (size_t j = 1; j < inp1.size(); j++) {
auto inp1_j = inp1[j];
auto inp2_j = inp2[j];
MS_EXCEPTION_IF_NULL(inp1_j);
MS_EXCEPTION_IF_NULL(inp2_j);
if (!(*inp1_j == *inp2_j)) {
return false;
}
}
return IsCNodePrimitveEqual(c_main, c_node);
}
bool GraphKernelCSE::Run(const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL(func_graph);
auto graphkernel_backend_cse = std::make_shared<GraphKernelBackendCSE>();

@ -32,6 +32,7 @@ class GraphKernelBackendCSE : public BackendCSE {
GraphKernelBackendCSE() = default;
~GraphKernelBackendCSE() override = default;
bool CheckEqualKernelBuildInfo(const AnfNodePtr &main, const AnfNodePtr &node) const override;
bool CheckEqualCnodeInputs(const AnfNodePtr &main, const AnfNodePtr &node) const override;
};
} // namespace opt
} // namespace mindspore

@ -48,6 +48,28 @@ bool BackendCSE::CheckEqualKernelBuildInfo(const AnfNodePtr &main, const AnfNode
return false;
}
bool BackendCSE::CheckEqualCnodeInputs(const AnfNodePtr &main, const AnfNodePtr &node) const {
auto c_main = main->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(c_main);
auto c_node = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(c_node);
const auto &inp1 = c_main->inputs();
const auto &inp2 = c_node->inputs();
if (inp1.size() != inp2.size()) {
return false;
}
for (size_t j = 0; j < inp1.size(); j++) {
auto inp1_j = inp1[j];
auto inp2_j = inp2[j];
MS_EXCEPTION_IF_NULL(inp1_j);
MS_EXCEPTION_IF_NULL(inp2_j);
if (!(*inp1_j == *inp2_j)) {
return false;
}
}
return true;
}
bool BackendCSE::CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node, bool check_side_effect) const {
MS_EXCEPTION_IF_NULL(main);
MS_EXCEPTION_IF_NULL(node);
@ -69,25 +91,7 @@ bool BackendCSE::CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node, bo
if (!CheckEqualKernelBuildInfo(main, node)) {
return false;
}
auto c_main = main->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(c_main);
auto c_node = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(c_node);
const auto &inp1 = c_main->inputs();
const auto &inp2 = c_node->inputs();
if (inp1.size() != inp2.size()) {
return false;
}
for (size_t j = 0; j < inp1.size(); j++) {
auto inp1_j = inp1[j];
auto inp2_j = inp2[j];
MS_EXCEPTION_IF_NULL(inp1_j);
MS_EXCEPTION_IF_NULL(inp2_j);
if (!(*inp1_j == *inp2_j)) {
return false;
}
}
return true;
return CheckEqualCnodeInputs(main, node);
}
return false;
}

@ -31,6 +31,7 @@ class BackendCSE : public CSE {
public:
BackendCSE() = default;
~BackendCSE() override = default;
virtual bool CheckEqualCnodeInputs(const AnfNodePtr &main, const AnfNodePtr &node) const;
bool CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node, bool check_side_effect = true) const override;
virtual bool CheckEqualKernelBuildInfo(const AnfNodePtr &main, const AnfNodePtr &node) const;
};

Loading…
Cancel
Save