diff --git a/mindspore/ccsrc/session/session_basic.cc b/mindspore/ccsrc/session/session_basic.cc index 5aeda36230..11e44154c3 100644 --- a/mindspore/ccsrc/session/session_basic.cc +++ b/mindspore/ccsrc/session/session_basic.cc @@ -640,6 +640,16 @@ std::shared_ptr SessionBasic::ConstructKernelGraph(const FuncGraphP MS_EXCEPTION_IF_NULL(func_graph_node); auto sub_func_graph = AnfAlgo::GetValueNodeFuncGraph(func_graph_node); ConstructKernelGraph(sub_func_graph); + } else if (prim->name() == kReturnOpName) { + std::vector outputs; + auto inputs = cnode->inputs(); + if (inputs.size() < 2) { + MS_LOG(EXCEPTION) << "CNode[return] must have two inputs at least, actual inputs size is " << inputs.size(); + } + (void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(outputs)); + // add a make_tuple before return as graph output + graph->set_output(ConstructOutput(outputs, graph)); + continue; } } @@ -649,11 +659,6 @@ std::shared_ptr SessionBasic::ConstructKernelGraph(const FuncGraphP new_cnode->set_abstract(cnode->abstract()); new_cnode->set_scope(cnode->scope()); graph->FrontBackendlMapAdd(node, new_cnode); - - // set original return to kernel_graph - if (IsPrimitive(new_cnode->input(kAnfPrimitiveIndex), prim::kPrimReturn)) { - graph->set_return(new_cnode); - } } } diff --git a/mindspore/ccsrc/utils/utils.h b/mindspore/ccsrc/utils/utils.h index 5caa282e26..8eacd10be0 100644 --- a/mindspore/ccsrc/utils/utils.h +++ b/mindspore/ccsrc/utils/utils.h @@ -144,6 +144,7 @@ constexpr auto kBNInferGradOpName = "BNInferGrad"; constexpr auto kCallOpName = "call"; constexpr auto kPartialOpName = "partial"; constexpr auto kSwitchOpName = "switch"; +constexpr auto kReturnOpName = "return"; constexpr auto kLarsV2OpName = "LarsV2"; constexpr auto kLarsV2UpdateOpName = "LarsV2Update"; constexpr auto kSquareSumAllOpName = "SquareSumAll";