|
|
|
@ -411,8 +411,7 @@ void AscendControlParser::RecurseSwitch(NotNull<KernelGraphPtr> kg, NotNull<CNod
|
|
|
|
|
origin_switch_inputs[kCNodeSwitchCond]};
|
|
|
|
|
for (size_t i = kCNodeSwitchCond + 1; i < kCNodeSwitchLength; ++i) {
|
|
|
|
|
// 3.1 branch kernel graph and args
|
|
|
|
|
KernelGraphPtr branch_fg;
|
|
|
|
|
std::tie(std::ignore, branch_fg) = ParsePartial(NOT_NULL(origin_switch_inputs[i]));
|
|
|
|
|
KernelGraphPtr branch_fg = ParsePartial(NOT_NULL(origin_switch_inputs[i]));
|
|
|
|
|
// 3.2 recurse sub graph
|
|
|
|
|
CNodePtr branch_label = ProcessKernelGraph(NOT_NULL(branch_fg), cur_node, back_label, memo);
|
|
|
|
|
new_switch_inputs.push_back(branch_label);
|
|
|
|
@ -456,8 +455,7 @@ void AscendControlParser::RecurseSwitchLayer(NotNull<KernelGraphPtr> kg, NotNull
|
|
|
|
|
origin_switch_inputs[kCNodeSwitchCond]};
|
|
|
|
|
for (size_t i = 0; i < branch_partial.size(); ++i) {
|
|
|
|
|
// 3.1 branch kernel graph and args
|
|
|
|
|
KernelGraphPtr branch_fg;
|
|
|
|
|
std::tie(std::ignore, branch_fg) = ParsePartial(NOT_NULL(origin_switch_inputs[i]));
|
|
|
|
|
KernelGraphPtr branch_fg = ParsePartial(NOT_NULL(origin_switch_inputs[i]));
|
|
|
|
|
// 3.2 recurse sub graph
|
|
|
|
|
CNodePtr branch_label = ProcessKernelGraph(NOT_NULL(branch_fg), cur_node, back_label, memo);
|
|
|
|
|
new_switch_inputs.push_back(branch_label);
|
|
|
|
@ -468,8 +466,11 @@ void AscendControlParser::RecurseSwitchLayer(NotNull<KernelGraphPtr> kg, NotNull
|
|
|
|
|
MS_LOG(INFO) << "Succeed processing switch layer " << cur_node->DebugString();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::tuple<CNodePtr, KernelGraphPtr> AscendControlParser::ParsePartial(NotNull<AnfNodePtr> node) {
|
|
|
|
|
KernelGraphPtr AscendControlParser::ParsePartial(NotNull<AnfNodePtr> node) {
|
|
|
|
|
if (!node.get()->isa<CNode>()) {
|
|
|
|
|
if (IsValueNode<KernelGraph>(node)) {
|
|
|
|
|
return GetValueNode<KernelGraphPtr>(node);
|
|
|
|
|
}
|
|
|
|
|
MS_LOG(EXCEPTION) << "Switch branches must be partial, node: " << node->DebugString();
|
|
|
|
|
}
|
|
|
|
|
// 2.1 branch kernel graph and args
|
|
|
|
@ -484,7 +485,7 @@ std::tuple<CNodePtr, KernelGraphPtr> AscendControlParser::ParsePartial(NotNull<A
|
|
|
|
|
MS_LOG(EXCEPTION) << "Index out of range:" << partial_inputs.size() << ".";
|
|
|
|
|
}
|
|
|
|
|
auto branch_kg = GetValueNode<KernelGraphPtr>(partial_inputs[kCNodePartialFunc]);
|
|
|
|
|
return {partial_cnode, branch_kg};
|
|
|
|
|
return branch_kg;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void AscendControlParser::InsertMultipleAssignToGraph(NotNull<KernelGraphPtr> from_graph,
|
|
|
|
|