diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/branch_culling.h b/mindspore/ccsrc/frontend/optimizer/irpass/branch_culling.h index 0eea40dfab..56a3c63cd8 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass/branch_culling.h +++ b/mindspore/ccsrc/frontend/optimizer/irpass/branch_culling.h @@ -98,19 +98,12 @@ class ConvertSwitchReplacement : public OptimizerCaller { return nullptr; } - auto cnode_ = node->cast(); - if (cnode_->size() < 1) { - return nullptr; - } - - auto node_ = cnode_->input(0); - PatternNode cond, true_br, false_br; - auto ConvertSwitchLambda = [&node_, &cond, &true_br, &false_br]() -> AnfNodePtr { - auto g1_ = GetValueNode(true_br.GetNode(node_)); - auto g2_ = GetValueNode(false_br.GetNode(node_)); - auto x_ = cond.GetNode(node_); + auto ConvertSwitchLambda = [&node, &cond, &true_br, &false_br]() -> AnfNodePtr { + auto g1_ = GetValueNode(true_br.GetNode(node)); + auto g2_ = GetValueNode(false_br.GetNode(node)); + auto x_ = cond.GetNode(node); // for switch replace method, only graphs without graph inside can be replaced for (auto &item : g1_->value_nodes()) { @@ -133,7 +126,7 @@ class ConvertSwitchReplacement : public OptimizerCaller { auto trans_g2 = internal::TransformGraphCondFalseBranchNodes(g2_, x_); std::vector params; - auto fg = node_->func_graph(); + auto fg = node->func_graph(); auto cloned_g1 = InlineClone(trans_g1, fg, params); auto cloned_g2 = InlineClone(trans_g2, fg, params); auto nnode = internal::TransformMergeBranches(cloned_g1, cloned_g2, true_output, false_output, x_, fg); @@ -142,8 +135,8 @@ class ConvertSwitchReplacement : public OptimizerCaller { }; MATCH_REPLACE_LAMBDA_IF( - node_, PPrimitive(prim::kPrimSwitch, cond, true_br, false_br), ConvertSwitchLambda, - true_br.CheckFunc(IsValueNode, node_) && false_br.CheckFunc(IsValueNode, node_)); + node, PCNode(PPrimitive(prim::kPrimSwitch, cond, true_br, false_br)).MinExtraNodes(0), ConvertSwitchLambda, + true_br.CheckFunc(IsValueNode, node) && false_br.CheckFunc(IsValueNode, node)); return nullptr; } diff --git a/mindspore/core/ir/pattern_matcher.h b/mindspore/core/ir/pattern_matcher.h index 2ff6d3ceb7..327e47e20d 100644 --- a/mindspore/core/ir/pattern_matcher.h +++ b/mindspore/core/ir/pattern_matcher.h @@ -443,7 +443,6 @@ class PConstant : public PBase > { } bool TryCapture_(const AnfNodePtr &node) const { - // if (IsValueNode(node)) { if (node->isa()) { // If any_value_ is set don't check for the node's value. Just capture it. if (any_value_) { @@ -726,7 +725,7 @@ class PConstant : public PBase > { ret = memcpy_s(data, mem_size, data_out, mem_size); delete[] reinterpret_cast(data_out); } else { - // Un-support data types + // Unsupported data types return nullptr; } }