|
|
|
@ -90,6 +90,22 @@ bool CSE::BuildOrderGroupAndDoReplace(const FuncGraphManagerPtr manager) const {
|
|
|
|
|
return changed;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool CSE::CheckRandomEffect(const AnfNodePtr &main, const AnfNodePtr &node) const {
|
|
|
|
|
bool has_random_effect = false;
|
|
|
|
|
auto prim_main = GetCNodePrimitive(main);
|
|
|
|
|
auto prim_node = GetCNodePrimitive(node);
|
|
|
|
|
if (prim_main == prim_node) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
if (prim_main != nullptr) {
|
|
|
|
|
auto effect_val = prim_main->GetAttr(GRAPH_FLAG_RANDOM_EFFECT);
|
|
|
|
|
if (effect_val != nullptr && effect_val->isa<BoolImm>()) {
|
|
|
|
|
has_random_effect = GetValue<bool>(effect_val);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return has_random_effect;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool CSE::CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node) const {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(main);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(node);
|
|
|
|
@ -122,7 +138,7 @@ bool CSE::CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node) const {
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
if (IsPrimitiveCNode(c_main, prim::kPrimDropoutGenMask)) {
|
|
|
|
|
if (CheckRandomEffect(c_main, c_node)) {
|
|
|
|
|
appsame = false;
|
|
|
|
|
}
|
|
|
|
|
replace = appsame;
|
|
|
|
|