|
|
|
@ -32,19 +32,23 @@ using mindspore::abstract::AbstractBase;
|
|
|
|
|
using mindspore::abstract::AbstractFunction;
|
|
|
|
|
using mindspore::abstract::AbstractFunctionPtr;
|
|
|
|
|
|
|
|
|
|
BasePtr AbsOf(const AnfNodePtr &node) {
|
|
|
|
|
BasePtr AbsOf(const AnfNodePtr &node, bool ignore_fg_abs_tracking_id) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(node);
|
|
|
|
|
auto node_abs = node->abstract();
|
|
|
|
|
// in testcase: TestOptOpt.CSE, node->abstract() is null;
|
|
|
|
|
// In testcase: TestOptOpt.CSE, node->abstract() is null.
|
|
|
|
|
if (node_abs == nullptr) {
|
|
|
|
|
return kAnyValue;
|
|
|
|
|
}
|
|
|
|
|
// Ignore the tracking_id and prim pointer hash;
|
|
|
|
|
if (node_abs->isa<abstract::PrimitiveAbstractClosure>()) {
|
|
|
|
|
// Ignore the tracking_id and prim pointer hash.
|
|
|
|
|
auto prim_abs = node_abs->cast<abstract::PrimitiveAbstractClosurePtr>();
|
|
|
|
|
return prim_abs->prim();
|
|
|
|
|
} else if (ignore_fg_abs_tracking_id && node_abs->isa<abstract::FuncGraphAbstractClosure>()) {
|
|
|
|
|
// Ignore the tracking_id.
|
|
|
|
|
auto new_fg_abs = node_abs->cast<abstract::AbstractFunctionPtr>()->Copy();
|
|
|
|
|
new_fg_abs->set_tracking_id(nullptr);
|
|
|
|
|
return new_fg_abs;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return node_abs;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -68,7 +72,7 @@ bool CSE::BuildOrderGroupAndDoReplace(const FuncGraphManagerPtr manager) const {
|
|
|
|
|
ValueNodePtr value_node = node->cast<ValueNodePtr>();
|
|
|
|
|
auto value = value_node->value();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(value);
|
|
|
|
|
h = hash_combine(value->hash(), (AbsOf(value_node)->hash()));
|
|
|
|
|
h = hash_combine(value->hash(), (AbsOf(value_node, true)->hash()));
|
|
|
|
|
} else if (node->isa<CNode>()) {
|
|
|
|
|
auto cnode = node->cast<CNodePtr>();
|
|
|
|
|
auto &inputs = cnode->inputs();
|
|
|
|
@ -134,7 +138,7 @@ bool CSE::CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node, bool chec
|
|
|
|
|
if (main->isa<ValueNode>() && node->isa<ValueNode>()) {
|
|
|
|
|
auto main_value = GetValueNode(main);
|
|
|
|
|
auto node_value = GetValueNode(node);
|
|
|
|
|
return (AbsOf(main) == AbsOf(node)) && (*main_value == *node_value);
|
|
|
|
|
return (AbsOf(main, true) == AbsOf(node, true)) && (*main_value == *node_value);
|
|
|
|
|
} else if (main->isa<CNode>() && node->isa<CNode>()) {
|
|
|
|
|
auto c_main = main->cast<CNodePtr>();
|
|
|
|
|
auto c_node = node->cast<CNodePtr>();
|
|
|
|
|