!3614 Update Convert Switch to use PCNode

Merge pull request !3614 from Giancarlo/update_convert_sw
pull/3614/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit 9e1244934c

@ -98,19 +98,12 @@ class ConvertSwitchReplacement : public OptimizerCaller {
return nullptr; return nullptr;
} }
auto cnode_ = node->cast<CNodePtr>();
if (cnode_->size() < 1) {
return nullptr;
}
auto node_ = cnode_->input(0);
PatternNode<AnfNodePtr> cond, true_br, false_br; PatternNode<AnfNodePtr> cond, true_br, false_br;
auto ConvertSwitchLambda = [&node_, &cond, &true_br, &false_br]() -> AnfNodePtr { auto ConvertSwitchLambda = [&node, &cond, &true_br, &false_br]() -> AnfNodePtr {
auto g1_ = GetValueNode<FuncGraphPtr>(true_br.GetNode(node_)); auto g1_ = GetValueNode<FuncGraphPtr>(true_br.GetNode(node));
auto g2_ = GetValueNode<FuncGraphPtr>(false_br.GetNode(node_)); auto g2_ = GetValueNode<FuncGraphPtr>(false_br.GetNode(node));
auto x_ = cond.GetNode(node_); auto x_ = cond.GetNode(node);
// for switch replace method, only graphs without graph inside can be replaced // for switch replace method, only graphs without graph inside can be replaced
for (auto &item : g1_->value_nodes()) { for (auto &item : g1_->value_nodes()) {
@ -133,7 +126,7 @@ class ConvertSwitchReplacement : public OptimizerCaller {
auto trans_g2 = internal::TransformGraphCondFalseBranchNodes(g2_, x_); auto trans_g2 = internal::TransformGraphCondFalseBranchNodes(g2_, x_);
std::vector<AnfNodePtr> params; std::vector<AnfNodePtr> params;
auto fg = node_->func_graph(); auto fg = node->func_graph();
auto cloned_g1 = InlineClone(trans_g1, fg, params); auto cloned_g1 = InlineClone(trans_g1, fg, params);
auto cloned_g2 = InlineClone(trans_g2, fg, params); auto cloned_g2 = InlineClone(trans_g2, fg, params);
auto nnode = internal::TransformMergeBranches(cloned_g1, cloned_g2, true_output, false_output, x_, fg); auto nnode = internal::TransformMergeBranches(cloned_g1, cloned_g2, true_output, false_output, x_, fg);
@ -142,8 +135,8 @@ class ConvertSwitchReplacement : public OptimizerCaller {
}; };
MATCH_REPLACE_LAMBDA_IF( MATCH_REPLACE_LAMBDA_IF(
node_, PPrimitive(prim::kPrimSwitch, cond, true_br, false_br), ConvertSwitchLambda, node, PCNode(PPrimitive(prim::kPrimSwitch, cond, true_br, false_br)).MinExtraNodes(0), ConvertSwitchLambda,
true_br.CheckFunc(IsValueNode<FuncGraph>, node_) && false_br.CheckFunc(IsValueNode<FuncGraph>, node_)); true_br.CheckFunc(IsValueNode<FuncGraph>, node) && false_br.CheckFunc(IsValueNode<FuncGraph>, node));
return nullptr; return nullptr;
} }

@ -443,7 +443,6 @@ class PConstant : public PBase<PConstant<T> > {
} }
bool TryCapture_(const AnfNodePtr &node) const { bool TryCapture_(const AnfNodePtr &node) const {
// if (IsValueNode<Value>(node)) {
if (node->isa<ValueNode>()) { if (node->isa<ValueNode>()) {
// If any_value_ is set don't check for the node's value. Just capture it. // If any_value_ is set don't check for the node's value. Just capture it.
if (any_value_) { if (any_value_) {
@ -726,7 +725,7 @@ class PConstant : public PBase<PConstant<T> > {
ret = memcpy_s(data, mem_size, data_out, mem_size); ret = memcpy_s(data, mem_size, data_out, mem_size);
delete[] reinterpret_cast<int *>(data_out); delete[] reinterpret_cast<int *>(data_out);
} else { } else {
// Un-support data types // Unsupported data types
return nullptr; return nullptr;
} }
} }

Loading…
Cancel
Save