|
|
|
@ -89,15 +89,28 @@ bool CSE::BuildOrderGroupAndDoReplace(const FuncGraphManagerPtr manager) const {
|
|
|
|
|
|
|
|
|
|
return changed;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// The op like print, summary, or the op do not has true output, and always as a depend node input.
|
|
|
|
|
static bool HasSideEffect(const AnfNodePtr &node) {
|
|
|
|
|
auto prim = GetCNodePrimitive(node);
|
|
|
|
|
if (prim == nullptr) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
auto side_effect_v = prim->GetAttr(GRAPH_FLAG_SIDE_EFFECT);
|
|
|
|
|
if (side_effect_v != nullptr && side_effect_v->isa<BoolImm>()) {
|
|
|
|
|
return GetValue<bool>(side_effect_v);
|
|
|
|
|
}
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
// If true do not merge the node.
|
|
|
|
|
bool CSE::CheckRandomEffect(const AnfNodePtr &main, const AnfNodePtr &node) const {
|
|
|
|
|
bool has_random_effect = false;
|
|
|
|
|
auto prim_main = GetCNodePrimitive(main);
|
|
|
|
|
auto prim_node = GetCNodePrimitive(node);
|
|
|
|
|
if (prim_main == prim_node) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
// if has random effect, when generate by different op (not same object), do not merge.
|
|
|
|
|
if (prim_main != nullptr) {
|
|
|
|
|
if (prim_main == prim_node) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
auto effect_val = prim_main->GetAttr(GRAPH_FLAG_RANDOM_EFFECT);
|
|
|
|
|
if (effect_val != nullptr && effect_val->isa<BoolImm>()) {
|
|
|
|
|
has_random_effect = GetValue<bool>(effect_val);
|
|
|
|
@ -106,45 +119,58 @@ bool CSE::CheckRandomEffect(const AnfNodePtr &main, const AnfNodePtr &node) cons
|
|
|
|
|
return has_random_effect;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool CSE::CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node) const {
|
|
|
|
|
bool CSE::CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node, bool check_side_effect) const {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(main);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(node);
|
|
|
|
|
|
|
|
|
|
bool replace = false;
|
|
|
|
|
if (main->isa<ValueNode>() && node->isa<ValueNode>()) {
|
|
|
|
|
auto main_value = GetValueNode(main);
|
|
|
|
|
auto node_value = GetValueNode(node);
|
|
|
|
|
replace = (AbsOf(main) == AbsOf(node)) && (*main_value == *node_value);
|
|
|
|
|
return (AbsOf(main) == AbsOf(node)) && (*main_value == *node_value);
|
|
|
|
|
} else if (main->isa<CNode>() && node->isa<CNode>()) {
|
|
|
|
|
auto c_main = main->cast<CNodePtr>();
|
|
|
|
|
auto c_node = node->cast<CNodePtr>();
|
|
|
|
|
// When appsame is true, check if has side effect, do not merge.
|
|
|
|
|
if (check_side_effect && HasSideEffect(main)) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
const auto &inp1 = c_main->inputs();
|
|
|
|
|
const auto &inp2 = c_node->inputs();
|
|
|
|
|
if (inp1.size() == inp2.size()) {
|
|
|
|
|
bool appsame = true;
|
|
|
|
|
for (size_t j = 0; j < inp1.size(); j++) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(inp1[j]);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(inp2[j]);
|
|
|
|
|
if (!(*inp1[j] == *inp2[j])) {
|
|
|
|
|
// Handle the case of two different Tensor, but with the same value
|
|
|
|
|
if (IsValueNode<tensor::Tensor>(inp1[j]) && IsValueNode<tensor::Tensor>(inp2[j])) {
|
|
|
|
|
auto tensor1 = GetValueNode<tensor::TensorPtr>(inp1[j]);
|
|
|
|
|
auto tensor2 = GetValueNode<tensor::TensorPtr>(inp2[j]);
|
|
|
|
|
if (tensor1->ValueEqual(*tensor2)) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
if (inp1.size() != inp2.size()) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
for (size_t j = 0; j < inp1.size(); j++) {
|
|
|
|
|
auto inp1_j = inp1[j];
|
|
|
|
|
auto inp2_j = inp2[j];
|
|
|
|
|
MS_EXCEPTION_IF_NULL(inp1_j);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(inp2_j);
|
|
|
|
|
if (!(*inp1_j == *inp2_j)) {
|
|
|
|
|
// Handle the case of two different Tensor, but with the same value
|
|
|
|
|
if (IsValueNode<tensor::Tensor>(inp1_j) && IsValueNode<tensor::Tensor>(inp2_j)) {
|
|
|
|
|
auto tensor1 = GetValueNode<tensor::TensorPtr>(inp1_j);
|
|
|
|
|
auto tensor2 = GetValueNode<tensor::TensorPtr>(inp2_j);
|
|
|
|
|
if (tensor1->ValueEqual(*tensor2)) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
} else if (HasSideEffect(inp1_j) && HasSideEffect(inp2_j)) {
|
|
|
|
|
// When the same side effect node as another two nodes' inputs, we still merge the node.
|
|
|
|
|
// Because the node only can be the inputs of `depend`, when the `depend` is duplicated merge the depend the
|
|
|
|
|
// node.
|
|
|
|
|
if (CheckReplace(inp1_j, inp2_j, false)) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
appsame = false;
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
if (CheckRandomEffect(c_main, c_node)) {
|
|
|
|
|
appsame = false;
|
|
|
|
|
}
|
|
|
|
|
replace = appsame;
|
|
|
|
|
}
|
|
|
|
|
// When appsame is true, check if has random effect do not merge
|
|
|
|
|
if (CheckRandomEffect(c_main, c_node)) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
return replace;
|
|
|
|
|
// a parameter node.
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool CSE::DoReplace(const FuncGraphManagerPtr manager, const std::vector<std::size_t> &order_group,
|
|
|
|
|