|
|
|
@ -523,12 +523,22 @@ void AscendControlParser::LinkParentGraph(NotNull<KernelGraphPtr> kg, const CNod
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void AscendControlParser::AttachOriginalInputsToGraph(NotNull<KernelGraphPtr> graph,
|
|
|
|
|
const std::vector<AnfNodePtr> orig_inputs) {
|
|
|
|
|
std::vector<AnfNodePtr> make_tuple_inputs = {
|
|
|
|
|
mindspore::NewValueNode(std::make_shared<Primitive>(prim::kPrimMakeTuple->name()))};
|
|
|
|
|
std::copy(orig_inputs.begin(), orig_inputs.end(), std::back_inserter(make_tuple_inputs));
|
|
|
|
|
auto make_tuple = graph->NewCNode(make_tuple_inputs);
|
|
|
|
|
|
|
|
|
|
InsertDependToGraph(graph, NOT_NULL(make_tuple));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void AscendControlParser::RecurseCall(NotNull<KernelGraphPtr> kg, NotNull<CNodePtr> cur_node, const CNodePtr &next_node,
|
|
|
|
|
const NotNull<std::set<KernelGraphPtr> *> memo) {
|
|
|
|
|
MS_LOG(INFO) << "Process call func " << cur_node->DebugString();
|
|
|
|
|
|
|
|
|
|
// 1 get kernel graph
|
|
|
|
|
const std::vector<AnfNodePtr> &origin_inputs = cur_node->inputs();
|
|
|
|
|
std::vector<AnfNodePtr> origin_inputs = cur_node->inputs();
|
|
|
|
|
if (kCNodeCallArg >= origin_inputs.size()) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Index out of range,size:" << origin_inputs.size();
|
|
|
|
|
}
|
|
|
|
@ -555,6 +565,8 @@ void AscendControlParser::RecurseCall(NotNull<KernelGraphPtr> kg, NotNull<CNodeP
|
|
|
|
|
cur_node->set_inputs(new_inputs);
|
|
|
|
|
cur_node->set_abstract(nullptr);
|
|
|
|
|
AnfAlgo::SetNodeAttr(kAttrChildGraph, MakeValue<std::vector<KernelGraphPtr>>({call_kg}), cur_node.get());
|
|
|
|
|
origin_inputs.assign(origin_inputs.begin() + kCNodeCallArg + 1, origin_inputs.end());
|
|
|
|
|
AttachOriginalInputsToGraph(kg, origin_inputs);
|
|
|
|
|
MS_LOG(INFO) << "Succeed processing call func " << cur_node->DebugString();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -587,11 +599,13 @@ void AscendControlParser::RecurseSwitch(NotNull<KernelGraphPtr> kg, NotNull<CNod
|
|
|
|
|
for (size_t i = kCNodeSwitchCond + 1; i < kCNodeSwitchLength; ++i) {
|
|
|
|
|
// 3.1 branch kernel graph and args
|
|
|
|
|
KernelGraphPtr branch_fg;
|
|
|
|
|
std::tie(branch_fg, std::ignore) = ParsePartial(NOT_NULL(origin_switch_inputs[i]));
|
|
|
|
|
std::vector<AnfNodePtr> origin_inputs;
|
|
|
|
|
std::tie(branch_fg, origin_inputs) = ParsePartial(NOT_NULL(origin_switch_inputs[i]));
|
|
|
|
|
child_graphs.push_back(branch_fg);
|
|
|
|
|
// 3.2 recurse sub graph
|
|
|
|
|
CNodePtr branch_label = ProcessKernelGraph(NOT_NULL(branch_fg), cur_node, back_label, memo);
|
|
|
|
|
new_switch_inputs.push_back(branch_label);
|
|
|
|
|
AttachOriginalInputsToGraph(kg, origin_inputs);
|
|
|
|
|
}
|
|
|
|
|
std::swap(new_switch_inputs[kCNodeSwitchTrue], new_switch_inputs[kCNodeSwitchFalse]);
|
|
|
|
|
|
|
|
|
@ -635,11 +649,13 @@ void AscendControlParser::RecurseSwitchLayer(NotNull<KernelGraphPtr> kg, NotNull
|
|
|
|
|
for (size_t i = 0; i < branch_partial.size(); ++i) {
|
|
|
|
|
// 3.1 branch kernel graph and args
|
|
|
|
|
KernelGraphPtr branch_fg;
|
|
|
|
|
std::tie(branch_fg, std::ignore) = ParsePartial(NOT_NULL(origin_switch_inputs[i]));
|
|
|
|
|
std::vector<AnfNodePtr> origin_inputs;
|
|
|
|
|
std::tie(branch_fg, origin_inputs) = ParsePartial(NOT_NULL(origin_switch_inputs[i]));
|
|
|
|
|
child_graphs.push_back(branch_fg);
|
|
|
|
|
// 3.2 recurse sub graph
|
|
|
|
|
CNodePtr branch_label = ProcessKernelGraph(NOT_NULL(branch_fg), cur_node, back_label, memo);
|
|
|
|
|
new_switch_inputs.push_back(branch_label);
|
|
|
|
|
AttachOriginalInputsToGraph(kg, origin_inputs);
|
|
|
|
|
}
|
|
|
|
|
new_switch_inputs.insert(new_switch_inputs.end(), branch_partial.begin(), branch_partial.end());
|
|
|
|
|
cur_node->set_inputs(new_switch_inputs);
|
|
|
|
|