!12732 fix switch layer

From: @youui
Reviewed-by: @guoqi1024,@zhoufeng54
Signed-off-by: @guoqi1024
pull/12732/MERGE
mindspore-ci-bot, 4 years ago (committed by Gitee)
commit a5af03f8ca

@@ -270,8 +270,9 @@ class AscendAutoMonadConverter {
         MS_LOG(EXCEPTION) << "Invalid CNode: " << cnode->DebugString() << std::endl;
       }
       if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimCall) ||
-          AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitch)) {
-        // Found call/switch node, set it as the tail call node.
+          AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitch) ||
+          AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitchLayer)) {
+        // Found call/switch/switchlayer node, set it as the tail call node.
         tail_call_node_ = cnode;
         call_switch_nodes_.emplace_back(cnode);
         monad_map_.emplace(cnode, last_monad);
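The check above now treats Call, Switch and SwitchLayer nodes uniformly when recording tail calls, and the dispatch in the next hunk mirrors it. As a minimal editor's sketch only (the helper name is hypothetical and not part of this patch; AnfAlgo::CheckPrimitiveType and the prim::kPrim* constants are the framework APIs already used in the hunk), the combined predicate is:

// Editor's sketch, not part of the patch: the three-way check written as one helper.
bool IsCallSwitchOrSwitchLayer(const CNodePtr &cnode) {
  return AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimCall) ||
         AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitch) ||
         AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitchLayer);
}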
@@ -292,8 +293,10 @@ class AscendAutoMonadConverter {
         HandleCall(cnode);
       } else if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitch)) {
         HandleSwitch(cnode);
+      } else if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitchLayer)) {
+        HandleSwitchLayer(cnode);
       } else {
-        MS_LOG(EXCEPTION) << "Not a call/switch node: " << cnode->DebugString();
+        MS_LOG(EXCEPTION) << "Not a call/switch/switchlayer node: " << cnode->DebugString();
       }
     }
     // If no tail call, assign output value to output parameter,
@@ -413,6 +416,60 @@ class AscendAutoMonadConverter {
     }
   }
+  //
+  // Convert switch_layer node:
+  //   branch1 = Partial(graph1, arg)
+  //   branch2 = Partial(graph2, arg)
+  //   out = SwitchLayer(index, branch1, branch2)
+  // to:
+  //   r = link_args(graph1, arg)
+  //   c = UpdateState(c, r)
+  //   r = link_args(graph2, arg)
+  //   c = UpdateState(c, r)
+  //   c = LabelSwitch(index, c) : L1, L2
+  //   c = LabelSet(c) : <return label>
+  //
+  void HandleSwitchLayer(const CNodePtr &cnode) {
+    // Update last_monad_.
+    last_monad_ = monad_map_[cnode];
+    // Get all branches of the switch_layer.
+    auto branches = GetSwitchBranches(cnode);
+    // Link arguments and generate labels for branches.
+    std::vector<KernelGraphPtr> graphes;
+    std::vector<uint32_t> labels;
+    graphes.reserve(branches.size());
+    labels.reserve(branches.size());
+    for (auto &[graph, args] : branches) {
+      if (graph == nullptr) {
+        MS_LOG(EXCEPTION) << "Invalid switch: " << cnode->DebugString();
+      }
+      auto linked_args = LinkArguments(args, graph);
+      if (linked_args != nullptr) {
+        monad_ = UpdateState(GetMonad(), linked_args);
+      }
+      graphes.push_back(graph);
+      labels.push_back(GetOrCreateGraphLabel(graph));
+    }
+    // Add the LabelSwitch node.
+    auto switch_node = LabelSwitch(cnode->input(1), labels);
+    // Set child graph attribute for switch node.
+    SetChildGrapAttr(switch_node, graphes);
+    // Setup return label if required.
+    const bool is_tail_call = (cnode == tail_call_node_);
+    const bool need_return = (return_label_ == kNoLabel || !is_tail_call);
+    auto [para_pool, output_para, return_label] = MakeReturn(cnode, need_return);
+    // Handle sub-graphs recursively.
+    for (auto &graph : graphes) {
+      HandleSubGraph(graph, para_pool, output_para, return_label);
+    }
+  }
   ParameterPoolPtr GetParameterPool(bool is_last_call) {
     if (!is_last_call) {
       // There are multiple calls in this graph, use a new parameter pool
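HandleSwitchLayer follows the same pattern as the existing HandleSwitch, but loops over however many branches the SwitchLayer node carries. Purely as an illustration (an editor's sketch, not text from the patch), the comment diagram above generalizes to three branches as follows:

//   branch1 = Partial(graph1, arg)
//   branch2 = Partial(graph2, arg)
//   branch3 = Partial(graph3, arg)
//   out = SwitchLayer(index, branch1, branch2, branch3)
// to:
//   r = link_args(graph1, arg);  c = UpdateState(c, r)
//   r = link_args(graph2, arg);  c = UpdateState(c, r)
//   r = link_args(graph3, arg);  c = UpdateState(c, r)
//   c = LabelSwitch(index, c) : L1, L2, L3
//   c = LabelSet(c) : <return label>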
@@ -483,10 +540,13 @@ class AscendAutoMonadConverter {
   }
   std::vector<GraphArgPair> GetSwitchBranches(const CNodePtr &cnode) {
-    constexpr size_t true_index = 2;
-    constexpr size_t false_index = 3;
-    // True branch first, then false branch.
-    return {GetSwitchBranch(cnode, true_index), GetSwitchBranch(cnode, false_index)};
+    constexpr size_t cond_start_index = 2;
+    // switch branches
+    std::vector<GraphArgPair> switch_branches;
+    for (size_t index = cond_start_index; index < cnode->inputs().size(); ++index) {
+      switch_branches.emplace_back(GetSwitchBranch(cnode, index));
+    }
+    return switch_branches;
   }
   //
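GetSwitchBranches now walks every branch input instead of hard-coding a true/false pair. A minimal sketch of the assumption it relies on (the helper below is hypothetical, written only to illustrate the input layout used by the loop above: input 0 is the primitive, input 1 the condition or layer index, and inputs 2 onward the branches, for both Switch and SwitchLayer):

// Editor's sketch, not part of the patch.
std::vector<AnfNodePtr> CollectBranchInputs(const CNodePtr &cnode) {
  constexpr size_t kBranchStartIndex = 2;  // inputs: {prim, cond_or_index, branch_0, ..., branch_n-1}
  std::vector<AnfNodePtr> branches;
  for (size_t i = kBranchStartIndex; i < cnode->inputs().size(); ++i) {
    branches.push_back(cnode->input(i));
  }
  return branches;
}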

@@ -928,9 +928,8 @@ std::vector<AnfNodePtr> SessionBasic::CreateCallSwitchInputs(const CNodePtr &cno
   return cnode_inputs;
 }
-void SessionBasic::CreateCallNodeReturnFunction(const CNodePtr &cnode, const AnfNodePtr &real_input) {
+void SessionBasic::CreateCallNodeReturnFunction(const CNodePtr &cnode, const std::vector<AnfNodePtr> &real_inputs) {
   MS_EXCEPTION_IF_NULL(cnode);
-  MS_EXCEPTION_IF_NULL(real_input);
   if (!(AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimPartial))) {
     MS_LOG(EXCEPTION) << "Node: " << cnode->DebugString() << "is not a partial node.";
   }
@@ -940,24 +939,37 @@ void SessionBasic::CreateCallNodeReturnFunction(const CNodePtr &cnode, const Anf
   auto ret = partial_kernel_graph->get_return();
   MS_EXCEPTION_IF_NULL(ret);
   auto return_input = ret->input(kFirstDataInputIndex);
-  // if kernel graph return node is a function
+  // return node is a function
+  std::vector<AnfNodePtr> call_inputs = {
+    partial_kernel_graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimCall->name())))};
+  AnfNodePtr real_kernel_graph;
   if (AnfAlgo::CheckPrimitiveType(return_input, prim::kPrimPartial)) {
-    std::vector<AnfNodePtr> call_inputs = {
-      partial_kernel_graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimCall->name())))};
     auto return_input_cnode = return_input->cast<CNodePtr>();
     auto partial_inputs = return_input_cnode->inputs();
     call_inputs.insert(call_inputs.end(), partial_inputs.begin() + kFirstDataInputIndex, partial_inputs.end());
+    real_kernel_graph = partial_inputs[kFirstDataInputIndex];
+  } else { // return node is kernel graph
+    call_inputs.emplace_back(return_input);
+    real_kernel_graph = return_input;
+  }
+  // new call node inputs
+  for (auto real_input : real_inputs) {
     auto parameter_for_input = CreateNewParameterFromCNode(real_input, partial_kernel_graph.get());
     call_inputs.emplace_back(parameter_for_input);
-    auto call_node = partial_kernel_graph->NewCNode(call_inputs);
-    // update abstract
-    KernelGraphPtr sub_partial_kernel_graph = GetValueNode<KernelGraphPtr>(partial_inputs[kFirstDataInputIndex]);
+  }
+  auto call_node = partial_kernel_graph->NewCNode(call_inputs);
+  // update abstract
+  MS_EXCEPTION_IF_NULL(real_kernel_graph);
+  if (real_kernel_graph->isa<ValueNode>() && IsValueNode<FuncGraph>(real_kernel_graph)) {
+    KernelGraphPtr sub_partial_kernel_graph = GetValueNode<KernelGraphPtr>(real_kernel_graph);
     MS_EXCEPTION_IF_NULL(sub_partial_kernel_graph);
     auto ret_partial = sub_partial_kernel_graph->get_return();
     call_node->set_abstract(ret_partial->abstract());
-    // update return input
-    ret->set_input(kFirstDataInputIndex, call_node);
   }
+  // update return input
+  ret->set_input(kFirstDataInputIndex, call_node);
 }
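After this change, CreateCallNodeReturnFunction handles two shapes of return node and forwards every real input as a fresh parameter. A condensed summary, written as an editor's sketch in the comment style of this file (not text from the patch):

// Case 1: return_input is Partial(sub_graph, args...):
//   call_inputs = {Call, sub_graph, args..., one new parameter per element of real_inputs}
//   real_kernel_graph = sub_graph
// Case 2: return_input is the kernel graph itself:
//   call_inputs = {Call, return_input, one new parameter per element of real_inputs}
//   real_kernel_graph = return_input
// In both cases the new Call node replaces the Return node's first data input; its abstract is
// copied from the sub graph's return only when real_kernel_graph is a graph value node.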
std::vector<AnfNodePtr> SessionBasic::CreateCallSwitchLayerInputs(const CNodePtr &cnode, KernelGraph *graph) {
@@ -977,9 +989,11 @@ std::vector<AnfNodePtr> SessionBasic::CreateCallSwitchLayerInputs(const CNodePtr
   auto node = make_tuple_node->cast<CNodePtr>();
   MS_EXCEPTION_IF_NULL(node);
   auto make_tuple_inputs = node->inputs();
-  // there is real input in call, should put it to make_tuple in switch_layer
-  auto real_input = cnode->input(kFirstDataInputIndex);
-  auto real_input_back = graph->GetBackendAnfByFrontAnf(real_input);
+  // there are real inputs in the call; put them into the make_tuple of switch_layer
+  std::vector<AnfNodePtr> real_inputs;
+  for (size_t idx = kFirstDataInputIndex; idx < cnode->inputs().size(); ++idx) {
+    real_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(cnode->input(idx)));
+  }
   std::vector<AnfNodePtr> new_make_tuple_inputs = {
     graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimMakeTuple->name())))};
   for (size_t idx = kFirstDataInputIndex; idx < make_tuple_inputs.size(); idx++) {
@@ -990,10 +1004,18 @@ std::vector<AnfNodePtr> SessionBasic::CreateCallSwitchLayerInputs(const CNodePtr
       auto partial_node = partial_idx->cast<CNodePtr>();
       MS_EXCEPTION_IF_NULL(partial_node);
       // update kernel graph when switch_layer node return function
-      CreateCallNodeReturnFunction(partial_node, real_input_back);
+      auto partial_input = partial_node->input(kFirstDataInputIndex);
+      KernelGraphPtr partial_kernel_graph = GetValueNode<KernelGraphPtr>(partial_input);
+      MS_EXCEPTION_IF_NULL(partial_kernel_graph);
+      auto ret = partial_kernel_graph->get_return();
+      MS_EXCEPTION_IF_NULL(ret);
+      auto return_input = ret->input(kFirstDataInputIndex);
+      if (AnfAlgo::CheckPrimitiveType(return_input, prim::kPrimPartial) || IsValueNode<KernelGraph>(return_input)) {
+        CreateCallNodeReturnFunction(partial_node, real_inputs);
+      }
       std::vector<AnfNodePtr> new_partial_inputs = partial_node->inputs();
-      new_partial_inputs.emplace_back(real_input_back);
+      new_partial_inputs.insert(new_partial_inputs.end(), real_inputs.begin(), real_inputs.end());
       auto new_partial = graph->NewCNode(new_partial_inputs);
       new_make_tuple_inputs.emplace_back(new_partial);
     }
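Sketching the effect of this branch of the loop (an editor's illustration with placeholder names x and y for the call's real inputs; not text from the patch):

// For a front-end node  call(switch_layer(index, make_tuple(Partial(g_i, ...), ...)), x, y):
//   each Partial branch is rebuilt with the backend nodes of x and y appended to its inputs
//   (via graph->GetBackendAnfByFrontAnf), and CreateCallNodeReturnFunction(partial_node, real_inputs)
//   is invoked first when g_i's return is itself a Partial or a kernel graph, so the inner call
//   also receives parameters for the forwarded inputs.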
@@ -1003,7 +1025,7 @@ std::vector<AnfNodePtr> SessionBasic::CreateCallSwitchLayerInputs(const CNodePtr
       std::vector<AnfNodePtr> new_partial_inputs;
       new_partial_inputs.emplace_back(NewValueNode(std::make_shared<Primitive>(prim::kPrimPartial->name())));
       new_partial_inputs.emplace_back(partial_idx);
-      new_partial_inputs.emplace_back(real_input_back);
+      new_partial_inputs.insert(new_partial_inputs.end(), real_inputs.begin(), real_inputs.end());
       auto new_partial = graph->NewCNode(new_partial_inputs);
       new_make_tuple_inputs.emplace_back(new_partial);
     }
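This block handles branches that are bare kernel-graph value nodes rather than Partial nodes; the note below is an editor's sketch, not from the patch:

// branch_i = g_i (a kernel graph value node)  ->  new Partial(g_i, x', y', ...), so plain graph
// branches receive the same forwarded real inputs as the Partial branches handled above.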

@@ -147,7 +147,7 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
   void GetNewCNodeInputs(const CNodePtr &cnode, KernelGraph *graph, std::vector<AnfNodePtr> *cnode_inputs,
                          std::unordered_map<AnfNodePtr, AnfNodePtr> *other_graph_cnode);
   std::vector<AnfNodePtr> CreateCallSwitchLayerInputs(const CNodePtr &cnode, KernelGraph *graph);
-  void CreateCallNodeReturnFunction(const CNodePtr &cnode, const AnfNodePtr &real_input);
+  void CreateCallNodeReturnFunction(const CNodePtr &cnode, const std::vector<AnfNodePtr> &real_inputs);
  protected:
   friend class Executor;
