!7159 optimize graph cut for depend

Merge pull request !7159 from kisnwang/optimize-graph-depend-cut
pull/7159/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 8c329605d2

@ -516,8 +516,7 @@ void SessionBasic::GetNewCNodeInputs(const CNodePtr &cnode, KernelGraph *graph,
MS_EXCEPTION_IF_NULL(other_graph_cnode);
MS_EXCEPTION_IF_NULL(cnode_inputs);
auto origin_inputs = cnode->inputs();
bool optimize_depend = IsPrimitiveCNode(cnode, prim::kPrimDepend) && origin_inputs.size() == 3 &&
origin_inputs[kRealInputIndexInDepend]->isa<ValueNode>();
bool optimize_depend = IsPrimitiveCNode(cnode, prim::kPrimDepend) && origin_inputs.size() == 3;
bool optimize_control_depend = IsPrimitiveCNode(cnode, prim::kPrimControlDepend) && origin_inputs.size() == 3;
// if has multiple depends,only select first depend as parameter
for (size_t input_idx = 1; input_idx < origin_inputs.size(); input_idx++) {
@ -527,6 +526,9 @@ void SessionBasic::GetNewCNodeInputs(const CNodePtr &cnode, KernelGraph *graph,
if (graph->GetBackendAnfByFrontAnf(anf) != nullptr) {
cnode_inputs->emplace_back(graph->GetBackendAnfByFrontAnf(anf));
continue;
} else if (optimize_depend && input_idx == kDependAttachNodeIndex) {
cnode_inputs->push_back(NewValueNode(MakeValue(SizeToInt(input_idx))));
continue;
} else if (other_graph_cnode->find(anf) != other_graph_cnode->end()) {
cnode_inputs->push_back((*other_graph_cnode)[anf]);
continue;
@ -546,9 +548,6 @@ void SessionBasic::GetNewCNodeInputs(const CNodePtr &cnode, KernelGraph *graph,
(*other_graph_cnode)[anf] = new_parameter;
}
continue;
} else if (optimize_depend && input_idx == kDependAttachNodeIndex) {
cnode_inputs->push_back(origin_inputs[kRealInputIndexInDepend]);
continue;
} else if (optimize_control_depend) {
cnode_inputs->push_back(NewValueNode(MakeValue(SizeToInt(input_idx))));
} else {

@ -141,10 +141,9 @@ std::tuple<FuncGraphPtr, AnfNodePtrList, AnfNodePtrList> TransformSegmentToAnfGr
}
auto fn = inps[0];
std::vector<AnfNodePtr> args{fn};
if (IsPrimitive(fn, prim::kPrimDepend) && inps.size() == 3 && inps[kRealInputIndexInDepend]->isa<ValueNode>() &&
eqv.find(inps[kDependAttachNodeIndex]) == eqv.end()) {
args.emplace_back(inps[kRealInputIndexInDepend]);
args.emplace_back(inps[kRealInputIndexInDepend]);
if (IsPrimitive(fn, prim::kPrimDepend) && inps.size() == 3 && eqv.find(inps[kDependAttachNodeIndex]) == eqv.end()) {
args.emplace_back(RefSubGraphNode(fg, inps[kRealInputIndexInDepend], &inputs, &eqv));
args.emplace_back(NewValueNode(MakeValue(0)));
} else if (IsPrimitive(fn, prim::kPrimControlDepend) && inps.size() == 3) {
for (size_t i = 1; i < inps.size(); ++i) {
if (inps[i]->isa<CNode>() && std::find(lst.begin(), lst.end(), inps[i]) == lst.end()) {

@ -289,10 +289,14 @@ std::string GetCNodeTarget(const AnfNodePtr &node) {
}
return target;
}
if (IsPrimitiveCNode(node, prim::kPrimMakeTuple)) {
if (IsPrimitiveCNode(node, prim::kPrimDepend)) {
auto &inputs = cnode->inputs();
if (inputs.size() == 3 && !IsPrimitiveCNode(inputs[1], prim::kPrimMakeTuple)) {
return GetCNodeTarget(inputs[1]);
}
} else if (IsPrimitiveCNode(node, prim::kPrimMakeTuple)) {
return GetMaketupleNodeTarget(cnode);
}
if (IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) {
} else if (IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) {
return GetTupleGetItemTarget(cnode, primitive);
}
return default_target;

Loading…
Cancel
Save