|
|
|
@ -17,15 +17,62 @@
|
|
|
|
|
#include "backend/optimizer/graph_kernel/graph_kernel_cse.h"
|
|
|
|
|
|
|
|
|
|
#include <memory>
|
|
|
|
|
#include <string>
|
|
|
|
|
#include <utility>
|
|
|
|
|
#include <vector>
|
|
|
|
|
#include "backend/session/anf_runtime_algorithm.h"
|
|
|
|
|
#include "runtime/device/kernel_info.h"
|
|
|
|
|
|
|
|
|
|
namespace mindspore {
|
|
|
|
|
namespace opt {
|
|
|
|
|
namespace {
|
|
|
|
|
bool IsCNodePrimitveEqual(const CNodePtr &main, const CNodePtr &node) {
|
|
|
|
|
auto main_primitive = AnfAlgo::GetCNodePrimitive(main);
|
|
|
|
|
auto node_primitive = AnfAlgo::GetCNodePrimitive(node);
|
|
|
|
|
if (main_primitive != nullptr && node_primitive != nullptr) {
|
|
|
|
|
if (main_primitive->name() != node_primitive->name()) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto main_attrs = main_primitive->attrs();
|
|
|
|
|
auto node_attrs = node_primitive->attrs();
|
|
|
|
|
|
|
|
|
|
std::vector<std::string> exclude_attrs{"IsFeatureMapOutput", "IsFeatureMapInputList", "pri_format"};
|
|
|
|
|
for (auto &attr : exclude_attrs) {
|
|
|
|
|
main_attrs.erase(attr);
|
|
|
|
|
node_attrs.erase(attr);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (main_attrs.size() != node_attrs.size()) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto all = std::all_of(main_attrs.begin(), main_attrs.end(),
|
|
|
|
|
[&node_attrs](const std::pair<std::string, ValuePtr> &item) -> bool {
|
|
|
|
|
if (item.second == nullptr) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
auto iter = node_attrs.find(item.first);
|
|
|
|
|
if (iter == node_attrs.end()) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
return *item.second == *iter->second;
|
|
|
|
|
});
|
|
|
|
|
return all;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return *main->inputs()[0] == *node->inputs()[0];
|
|
|
|
|
}
|
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
|
|
bool GraphKernelBackendCSE::CheckEqualKernelBuildInfo(const AnfNodePtr &main, const AnfNodePtr &node) const {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(main);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(node);
|
|
|
|
|
|
|
|
|
|
if (!AnfAlgo::IsNodeInGraphKernel(main)) {
|
|
|
|
|
return BackendCSE::CheckEqualKernelBuildInfo(main, node);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto main_kernel_info = dynamic_cast<device::KernelInfo *>(main->kernel_info());
|
|
|
|
|
auto node_kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
|
|
|
|
|
if (main_kernel_info == nullptr && node_kernel_info == nullptr) {
|
|
|
|
@ -43,8 +90,7 @@ bool GraphKernelBackendCSE::CheckEqualKernelBuildInfo(const AnfNodePtr &main, co
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (main_build_info->fusion_type() != node_build_info->fusion_type() ||
|
|
|
|
|
main_build_info->processor() != node_build_info->processor()) {
|
|
|
|
|
if (main_build_info->processor() != node_build_info->processor()) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -53,6 +99,33 @@ bool GraphKernelBackendCSE::CheckEqualKernelBuildInfo(const AnfNodePtr &main, co
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool GraphKernelBackendCSE::CheckEqualCnodeInputs(const AnfNodePtr &main, const AnfNodePtr &node) const {
|
|
|
|
|
auto c_main = main->cast<CNodePtr>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(c_main);
|
|
|
|
|
auto c_node = node->cast<CNodePtr>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(c_node);
|
|
|
|
|
|
|
|
|
|
if (!AnfAlgo::IsNodeInGraphKernel(c_main)) {
|
|
|
|
|
return BackendCSE::CheckEqualCnodeInputs(main, node);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
const auto &inp1 = c_main->inputs();
|
|
|
|
|
const auto &inp2 = c_node->inputs();
|
|
|
|
|
if (inp1.size() != inp2.size()) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
for (size_t j = 1; j < inp1.size(); j++) {
|
|
|
|
|
auto inp1_j = inp1[j];
|
|
|
|
|
auto inp2_j = inp2[j];
|
|
|
|
|
MS_EXCEPTION_IF_NULL(inp1_j);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(inp2_j);
|
|
|
|
|
if (!(*inp1_j == *inp2_j)) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return IsCNodePrimitveEqual(c_main, c_node);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool GraphKernelCSE::Run(const FuncGraphPtr &func_graph) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(func_graph);
|
|
|
|
|
auto graphkernel_backend_cse = std::make_shared<GraphKernelBackendCSE>();
|
|
|
|
|