// Patch summary (review notes; this text precedes the first "diff --git" header and belongs to no hunk):
// - matmul_biasadd_fusion.cc, merge_addn.h: the call sites stop building two MonadState values with
//   GetMonadState() and instead call the new node-based overload IsStateEquivalent(outer, inner).
//   merge_addn.h also binds inputs.at(1) / inputs.back() once (first_input / last_input); note that
//   the unchecked inputs[1] becomes the bounds-checked inputs.at(1).
// - anf.cc: adds IsStateStrictEquivalent() (the GetMonadState-based comparison the call sites used
//   before), GetLoadInputs() (collects Load CNode inputs starting at input index 1, recursing into
//   MakeTuple inputs), and the node-based IsStateEquivalent(), which returns true when either node has
//   no Load inputs and otherwise requires every collected Load to share the same input at index 2
//   (kMonadInput).
// - anf.cc: moves the primitive_target attribute handling out of GetCNodeTarget() into the new helper
//   GetAttrTarget(); the moved statements are the same as the removed ones.
// - anf.h: declares IsStateStrictEquivalent() and the new IsStateEquivalent() overload.
// NOTE(review): std::set::merge is a C++17 member function - confirm the build standard.
// NOTE(review): several expressions appear without template arguments (`std::set GetLoadInputs`,
//   `dyn_cast(node)`, `input->cast()`, `att_target->isa()`, `GetValue(att_target)`, `template static`);
//   presumably angle-bracketed text was lost when this patch was extracted - verify against the
//   original commit before applying.
// NOTE(review): hunk lines are collapsed onto a few physical lines here, so the patch will not apply
//   as-is; the line structure must be restored first.
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/matmul_biasadd_fusion.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/matmul_biasadd_fusion.cc index db3e1f9822..c36dde8cfa 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/matmul_biasadd_fusion.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/matmul_biasadd_fusion.cc @@ -39,9 +39,7 @@ const AnfNodePtr MatmulBiasaddFusion::Process(const FuncGraphPtr &graph, const A } // If there is a side-effect operator in the fusion, do not merge - MonadState state_matmul = GetMonadState(matmul); - MonadState state_node = GetMonadState(node, matmul); - if (!IsStateEquivalent(state_matmul, state_node)) { + if (!IsStateEquivalent(node, matmul)) { return node; } diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/merge_addn.h b/mindspore/ccsrc/frontend/optimizer/irpass/merge_addn.h index 2072c08ef3..800c66b6fe 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass/merge_addn.h +++ b/mindspore/ccsrc/frontend/optimizer/irpass/merge_addn.h @@ -71,16 +71,15 @@ class MergeAddN : public AnfVisitor { is_inner_ = true; // {prim::kPrimMakeTuple, {PrimAddN, {prim::kPrimMakeTuple, Xs}}, Ys} - AnfVisitor::Match(prim::kPrimAddN, {IsCNode})(inputs[1]); + const auto &first_input = inputs.at(1); + AnfVisitor::Match(prim::kPrimAddN, {IsCNode})(first_input); if (is_match_) { - if (!is_unique(inputs[1])) { + if (!is_unique(first_input)) { is_match_ = false; return; } - MonadState state_input = GetMonadState(inputs[1]); - MonadState state_cnode = GetMonadState(cnode, inputs[1]); - if (!IsStateEquivalent(state_cnode, state_input)) { + if (!IsStateEquivalent(cnode, first_input)) { is_match_ = false; return; } @@ -92,16 +91,15 @@ class MergeAddN : public AnfVisitor { } // {prim::kPrimMakeTuple, Ys, {PrimAddN, {prim::kPrimMakeTuple, Xs}}} - AnfVisitor::Match(prim::kPrimAddN, {IsCNode})(inputs.back()); + const auto &last_input = inputs.back(); + AnfVisitor::Match(prim::kPrimAddN, {IsCNode})(last_input); if 
(is_match_) { - if (!is_unique(inputs.back())) { + if (!is_unique(last_input)) { is_match_ = false; return; } - MonadState state_input = GetMonadState(inputs.back()); - MonadState state_cnode = GetMonadState(cnode, inputs.back()); - if (!IsStateEquivalent(state_cnode, state_input)) { + if (!IsStateEquivalent(cnode, last_input)) { is_match_ = false; return; } diff --git a/mindspore/core/ir/anf.cc b/mindspore/core/ir/anf.cc index 33e50fa347..379d0919a2 100644 --- a/mindspore/core/ir/anf.cc +++ b/mindspore/core/ir/anf.cc @@ -300,6 +300,46 @@ bool IsStateEquivalent(const MonadState &state1, const MonadState &state2) { (state1.io == nullptr || state2.io == nullptr || state1.io == state2.io); } +bool IsStateStrictEquivalent(const AnfNodePtr &outer, const AnfNodePtr &inner) { + MonadState state_matmul = GetMonadState(inner); + MonadState state_node = GetMonadState(outer, inner); + return IsStateEquivalent(state_matmul, state_node); +} + +std::set GetLoadInputs(const AnfNodePtr &node) { + std::set loads; + auto cnode = dyn_cast(node); + if (cnode == nullptr) { + return loads; + } + auto &inputs = cnode->inputs(); + for (size_t i = 1; i < inputs.size(); ++i) { + auto &input = inputs.at(i); + if (IsPrimitiveCNode(input, prim::kPrimLoad)) { + loads.insert(input->cast()); + } else if (IsPrimitiveCNode(input, prim::kPrimMakeTuple)) { + loads.merge(GetLoadInputs(input)); + } + } + return loads; +} + +bool IsStateEquivalent(const AnfNodePtr &outer, const AnfNodePtr &inner) { + constexpr size_t kMonadInput = 2; + auto outer_loads = GetLoadInputs(outer); + if (outer_loads.empty()) { + return true; + } + auto inner_loads = GetLoadInputs(inner); + if (inner_loads.empty()) { + return true; + } + outer_loads.merge(inner_loads); + auto &monad = (*outer_loads.begin())->inputs().at(kMonadInput); + return std::all_of(++outer_loads.begin(), outer_loads.end(), + [&monad](const CNodePtr &load) { return load->inputs().at(kMonadInput) == monad; }); +} + size_t NewSeenGeneration() { static 
size_t seen_generation = 0; return ++seen_generation; @@ -353,6 +393,26 @@ std::string GetMaketupleNodeTarget(const CNodePtr &cnode) { std::string default_target = context_ptr->get_param(MS_CTX_DEVICE_TARGET); return default_target; } + +std::string GetAttrTarget(const PrimitivePtr &primitive, const ValuePtr &att_target, const AnfNodePtr &attr_input, + const std::string &primitive_target, const std::string &default_target) { + if (IsPrimitive(attr_input, prim::kPrimImageSummary) || IsPrimitive(attr_input, prim::kPrimScalarSummary) || + IsPrimitive(attr_input, prim::kPrimTensorSummary) || IsPrimitive(attr_input, prim::kPrimHistogramSummary) || + IsPrimitive(attr_input, prim::kPrimStateSetItem) || IsPrimitive(attr_input, prim::kPrimDepend) || + IsPrimitive(attr_input, prim::kPrimControlDepend) || IsPrimitive(attr_input, prim::kPrimReturn) || + IsPrimitive(attr_input, prim::kPrimPartial)) { + primitive->EraseAttr(primitive_target); + return default_target; + } + if (!att_target->isa()) { + MS_LOG(EXCEPTION) << "Only support string CPU|GPU|Ascend for primitive_target"; + } + auto target = GetValue(att_target); + if (kTargetSet.find(target) == kTargetSet.end()) { + MS_LOG(EXCEPTION) << "Only support string CPU|GPU|Ascend for primitive_target, but get " << target; + } + return target; +} } // namespace std::string GetCNodeTarget(const AnfNodePtr &node) { @@ -387,22 +447,7 @@ std::string GetCNodeTarget(const AnfNodePtr &node) { auto primitive = value->cast(); auto att_target = primitive->GetAttr(primitive_target); if (att_target != nullptr) { - if (IsPrimitive(attr_input, prim::kPrimImageSummary) || IsPrimitive(attr_input, prim::kPrimScalarSummary) || - IsPrimitive(attr_input, prim::kPrimTensorSummary) || IsPrimitive(attr_input, prim::kPrimHistogramSummary) || - IsPrimitive(attr_input, prim::kPrimStateSetItem) || IsPrimitive(attr_input, prim::kPrimDepend) || - IsPrimitive(attr_input, prim::kPrimControlDepend) || IsPrimitive(attr_input, prim::kPrimReturn) || - 
IsPrimitive(attr_input, prim::kPrimPartial)) { - primitive->EraseAttr(primitive_target); - return default_target; - } - if (!att_target->isa()) { - MS_LOG(EXCEPTION) << "Only support string CPU|GPU|Ascend for primitive_target"; - } - auto target = GetValue(att_target); - if (kTargetSet.find(target) == kTargetSet.end()) { - MS_LOG(EXCEPTION) << "Only support string CPU|GPU|Ascend for primitive_target, but get " << target; - } - return target; + return GetAttrTarget(primitive, att_target, attr_input, primitive_target, default_target); } if (IsPrimitiveCNode(node, prim::kPrimDepend)) { auto &inputs = cnode->inputs(); diff --git a/mindspore/core/ir/anf.h b/mindspore/core/ir/anf.h index 27b2399077..08c4a60b1b 100644 --- a/mindspore/core/ir/anf.h +++ b/mindspore/core/ir/anf.h @@ -530,6 +530,12 @@ MonadState GetMonadState(const AnfNodePtr &node, const AnfNodePtr &skip_input = // Check if two state is equivalent. bool IsStateEquivalent(const MonadState &state1, const MonadState &state2); +// Check if monad state is strictly equivalent for the two connected nodes. +bool IsStateStrictEquivalent(const AnfNodePtr &outer, const AnfNodePtr &inner); + +// Check if monad state is equivalent for the two connected nodes; less strict but faster. +bool IsStateEquivalent(const AnfNodePtr &outer, const AnfNodePtr &inner); + // used to check whether a ValueNode has some kind of value template static bool IsValueNode(const AnfNodePtr &node) {