|
|
|
@ -16,6 +16,8 @@
|
|
|
|
|
#include "backend/optimizer/pass/common_subexpression_elimination.h"
|
|
|
|
|
#include <memory>
|
|
|
|
|
#include "runtime/device/kernel_info.h"
|
|
|
|
|
#include "backend/session/anf_runtime_algorithm.h"
|
|
|
|
|
#include "utils/flags.h"
|
|
|
|
|
|
|
|
|
|
namespace mindspore {
|
|
|
|
|
namespace opt {
|
|
|
|
@ -33,48 +35,60 @@ bool CheckEqualKernelBuildInfo(const AnfNodePtr &main, const AnfNodePtr &node) {
|
|
|
|
|
}
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool HasSideEffectAttr(const AnfNodePtr &node) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(node);
|
|
|
|
|
auto cnode = node->cast<CNodePtr>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(cnode);
|
|
|
|
|
if (!AnfAlgo::HasNodeAttr(GRAPH_FLAG_SIDE_EFFECT, cnode)) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
return AnfAlgo::GetNodeAttr<bool>(cnode, GRAPH_FLAG_SIDE_EFFECT);
|
|
|
|
|
}
|
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
|
|
bool BackendCSE::CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node, bool) const {
|
|
|
|
|
bool BackendCSE::CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node, bool check_side_effect) const {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(main);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(node);
|
|
|
|
|
|
|
|
|
|
bool replace = false;
|
|
|
|
|
if (main->isa<ValueNode>() && node->isa<ValueNode>()) {
|
|
|
|
|
auto main_value = GetValueNode(main);
|
|
|
|
|
auto node_value = GetValueNode(node);
|
|
|
|
|
if (main_value->isa<Primitive>() && node_value->isa<Primitive>()) {
|
|
|
|
|
replace = false;
|
|
|
|
|
return false;
|
|
|
|
|
} else if (main_value->isa<tensor::Tensor>() && node_value->isa<tensor::Tensor>()) {
|
|
|
|
|
replace = (AbsOf(main) == AbsOf(node)) && CheckEqualKernelBuildInfo(main, node);
|
|
|
|
|
return (AbsOf(main) == AbsOf(node)) && CheckEqualKernelBuildInfo(main, node);
|
|
|
|
|
} else {
|
|
|
|
|
replace = (AbsOf(main) == AbsOf(node)) && (*main_value == *node_value);
|
|
|
|
|
return (AbsOf(main) == AbsOf(node)) && (*main_value == *node_value);
|
|
|
|
|
}
|
|
|
|
|
} else if (main->isa<CNode>() && node->isa<CNode>()) {
|
|
|
|
|
if (check_side_effect && HasSideEffectAttr(main)) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
if (!CheckEqualKernelBuildInfo(main, node)) {
|
|
|
|
|
replace = false;
|
|
|
|
|
} else {
|
|
|
|
|
auto c_main = main->cast<CNodePtr>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(c_main);
|
|
|
|
|
auto c_node = node->cast<CNodePtr>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(c_node);
|
|
|
|
|
const auto &inp1 = c_main->inputs();
|
|
|
|
|
const auto &inp2 = c_node->inputs();
|
|
|
|
|
if (inp1.size() == inp2.size()) {
|
|
|
|
|
bool appsame = true;
|
|
|
|
|
for (size_t j = 0; j < inp1.size(); j++) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(inp1[j]);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(inp2[j]);
|
|
|
|
|
if (!(*inp1[j] == *inp2[j])) {
|
|
|
|
|
appsame = false;
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
replace = appsame;
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
auto c_main = main->cast<CNodePtr>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(c_main);
|
|
|
|
|
auto c_node = node->cast<CNodePtr>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(c_node);
|
|
|
|
|
const auto &inp1 = c_main->inputs();
|
|
|
|
|
const auto &inp2 = c_node->inputs();
|
|
|
|
|
if (inp1.size() != inp2.size()) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
for (size_t j = 0; j < inp1.size(); j++) {
|
|
|
|
|
auto inp1_j = inp1[j];
|
|
|
|
|
auto inp2_j = inp2[j];
|
|
|
|
|
MS_EXCEPTION_IF_NULL(inp1_j);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(inp2_j);
|
|
|
|
|
if (!(*inp1_j == *inp2_j)) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
return replace;
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool CommonSubexpressionElimination::Run(const FuncGraphPtr &func_graph) {
|
|
|
|
|