diff --git a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc index f95b7f40d6..5662fdb446 100644 --- a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc +++ b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc @@ -1220,6 +1220,5 @@ bool AnfRuntimeAlgorithm::IsIndependentNode(const CNodePtr &node) { } return true; } - } // namespace session } // namespace mindspore diff --git a/mindspore/ccsrc/backend/session/executor.cc b/mindspore/ccsrc/backend/session/executor.cc index e4a983c2fe..76341fbf5a 100644 --- a/mindspore/ccsrc/backend/session/executor.cc +++ b/mindspore/ccsrc/backend/session/executor.cc @@ -21,7 +21,7 @@ namespace mindspore { namespace session { namespace { -void UpdateOutputTensors(VectorRef *outputs, +void UpdateOutputTensors(const VectorRef *outputs, const std::map &tensor_to_node) { MS_EXCEPTION_IF_NULL(outputs); for (auto item : *outputs) { diff --git a/mindspore/ccsrc/backend/session/infer_session.cc b/mindspore/ccsrc/backend/session/infer_session.cc index 068a378f75..e955f70034 100644 --- a/mindspore/ccsrc/backend/session/infer_session.cc +++ b/mindspore/ccsrc/backend/session/infer_session.cc @@ -35,7 +35,6 @@ using std::vector; namespace py = pybind11; namespace mindspore { namespace inference { - std::shared_ptr InferSession::CreateSession(const std::string &device, uint32_t device_id) { try { auto session = std::make_shared(); @@ -271,36 +270,18 @@ void MSInferSession::RegAllOp() { MsContext::GetInstance()->set_param(MS_CTX_EXECUTION_MODE, kGraphMode); Py_Initialize(); auto c_expression = PyImport_ImportModule("mindspore._c_expression"); - if (c_expression == nullptr) { - MS_LOG(EXCEPTION) << "Failed to import mindspore._c_expression module."; - return; - } + MS_EXCEPTION_IF_NULL(c_expression); PyObject *c_expression_dict = PyModule_GetDict(c_expression); - if (c_expression_dict == nullptr) { - MS_LOG(EXCEPTION) << "Failed to get dict from mindspore._c_expression module."; - return; - } + MS_EXCEPTION_IF_NULL(c_expression_dict); PyObject *op_info_loader_class = PyDict_GetItemString(c_expression_dict, "OpInfoLoaderPy"); - if (op_info_loader_class == nullptr) { - MS_LOG(EXCEPTION) << "Failed to get op_info_loader_class from mindspore._c_expression."; - return; - } + MS_EXCEPTION_IF_NULL(op_info_loader_class); PyObject *op_info_loader = PyInstanceMethod_New(op_info_loader_class); - if (op_info_loader == nullptr) { - MS_LOG(EXCEPTION) << "Failed to create op_info_loader instance."; - return; - } + MS_EXCEPTION_IF_NULL(op_info_loader); PyObject *op_info_loader_ins = PyObject_CallObject(op_info_loader, nullptr); - if (op_info_loader_ins == nullptr) { - MS_LOG(EXCEPTION) << "Failed to call op_info_loader instance."; - return; - } + MS_EXCEPTION_IF_NULL(op_info_loader_ins); auto all_ops_info_vector_addr_ul = PyObject_CallMethod(op_info_loader_ins, "get_all_ops_info", nullptr); - if (all_ops_info_vector_addr_ul == nullptr) { - MS_LOG(EXCEPTION) << "Failed to call get_all_ops_addr."; - return; - } + MS_EXCEPTION_IF_NULL(all_ops_info_vector_addr_ul); auto all_ops_info_vector_addr = PyLong_AsVoidPtr(all_ops_info_vector_addr_ul); auto all_ops_info = static_cast *>(all_ops_info_vector_addr); for (auto op_info : *all_ops_info) { diff --git a/mindspore/ccsrc/backend/session/session_basic.cc b/mindspore/ccsrc/backend/session/session_basic.cc index c4f04c1bfe..ea918c15c0 100644 --- a/mindspore/ccsrc/backend/session/session_basic.cc +++ b/mindspore/ccsrc/backend/session/session_basic.cc @@ -494,54 +494,52 @@ AnfNodePtr SessionBasic::CreateNewParameterFromCNode(const AnfNodePtr &anf, Kern return make_tuple; } -CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph, - std::unordered_map *other_graph_cnode) { +void SessionBasic::GetCNodeInfo(const CNodePtr &cnode, std::vector *cnode_inputs) { MS_EXCEPTION_IF_NULL(cnode); - MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(other_graph_cnode); - // get primitive of old node - std::vector cnode_inputs; + MS_EXCEPTION_IF_NULL(cnode_inputs); auto prim = AnfAlgo::GetCNodePrimitive(cnode); if (prim != nullptr) { // push attr to inputs[0] of new cnode - cnode_inputs.push_back(std::make_shared(std::make_shared(*prim))); + cnode_inputs->push_back(std::make_shared(std::make_shared(*prim))); } else { auto fg = AnfAlgo::GetCNodeFuncGraphPtr(cnode); MS_EXCEPTION_IF_NULL(fg); auto new_fg = BasicClone(fg); - cnode_inputs.push_back(std::make_shared(new_fg)); + cnode_inputs->push_back(std::make_shared(new_fg)); } +} + +void SessionBasic::GetNewCNodeInputs(const CNodePtr &cnode, KernelGraph *graph, std::vector *cnode_inputs, + std::unordered_map *other_graph_cnode) { + MS_EXCEPTION_IF_NULL(cnode); + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(other_graph_cnode); + MS_EXCEPTION_IF_NULL(cnode_inputs); auto origin_inputs = cnode->inputs(); - bool optimize_depend = false; - bool optimize_control_depend = false; - if (IsPrimitiveCNode(cnode, prim::kPrimDepend) && origin_inputs.size() == 3 && - origin_inputs[kRealInputIndexInDepend]->isa()) { - optimize_depend = true; - } - if (IsPrimitiveCNode(cnode, prim::kPrimControlDepend) && origin_inputs.size() == 3) { - optimize_control_depend = true; - } + bool optimize_depend = IsPrimitiveCNode(cnode, prim::kPrimDepend) && origin_inputs.size() == 3 && + origin_inputs[kRealInputIndexInDepend]->isa(); + bool optimize_control_depend = IsPrimitiveCNode(cnode, prim::kPrimControlDepend) && origin_inputs.size() == 3; // if has multiple depends,only select first depend as parameter for (size_t input_idx = 1; input_idx < origin_inputs.size(); input_idx++) { auto anf = origin_inputs[input_idx]; MS_EXCEPTION_IF_NULL(anf); // anf has been created before if (graph->GetBackendAnfByFrontAnf(anf) != nullptr) { - cnode_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(anf)); + cnode_inputs->emplace_back(graph->GetBackendAnfByFrontAnf(anf)); continue; } else if (other_graph_cnode->find(anf) != other_graph_cnode->end()) { - cnode_inputs.push_back((*other_graph_cnode)[anf]); + cnode_inputs->push_back((*other_graph_cnode)[anf]); continue; } else if (anf->isa() && !IsValueNode(anf)) { // if input is a value node, auto new_value_node = CreateNewValueNode(anf, graph); if (new_value_node != nullptr) { - cnode_inputs.emplace_back(new_value_node); + cnode_inputs->emplace_back(new_value_node); } continue; } else if (anf->isa()) { auto new_parameter = CreateNewParameterFromParameter(anf, graph); - cnode_inputs.push_back(new_parameter); + cnode_inputs->push_back(new_parameter); if (GetGraphIdByNode(anf) == kInvalidGraphId) { graph->FrontBackendlMapAdd(anf, new_parameter); } else { @@ -549,20 +547,31 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph, } continue; } else if (optimize_depend && input_idx == kDependAttachNodeIndex) { - cnode_inputs.push_back(origin_inputs[kRealInputIndexInDepend]); + cnode_inputs->push_back(origin_inputs[kRealInputIndexInDepend]); continue; } else if (optimize_control_depend) { - cnode_inputs.push_back(NewValueNode(MakeValue(SizeToInt(input_idx)))); + cnode_inputs->push_back(NewValueNode(MakeValue(SizeToInt(input_idx)))); } else { // the input node is a cnode from other graph auto parameter_from_cnode = CreateNewParameterFromCNode(anf, graph); if (parameter_from_cnode == nullptr) { parameter_from_cnode = NewValueNode(MakeValue(SizeToInt(input_idx))); } - cnode_inputs.push_back(parameter_from_cnode); + cnode_inputs->push_back(parameter_from_cnode); (*other_graph_cnode)[anf] = parameter_from_cnode; } } +} + +CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph, + std::unordered_map *other_graph_cnode) { + MS_EXCEPTION_IF_NULL(cnode); + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(other_graph_cnode); + // get primitive of old node + std::vector cnode_inputs; + GetCNodeInfo(cnode, &cnode_inputs); + GetNewCNodeInputs(cnode, graph, &cnode_inputs, other_graph_cnode); TraceManager::DebugTrace(std::make_shared(cnode->debug_info())); auto new_cnode = graph->NewCNode(cnode_inputs); TraceManager::EndTrace(); @@ -593,6 +602,42 @@ CNodePtr SessionBasic::CreateSwitchInput(const AnfNodePtr &node_input, KernelGra return partial_node; } +std::vector SessionBasic::CreateCallSwitchInputs(const CNodePtr &cnode, KernelGraph *graph) { + MS_EXCEPTION_IF_NULL(cnode); + MS_EXCEPTION_IF_NULL(graph); + std::vector cnode_inputs = { + graph->NewValueNode(NewValueNode(std::make_shared(prim::kPrimCall->name())))}; + auto attr_input = cnode->input(kAnfPrimitiveIndex); + MS_EXCEPTION_IF_NULL(attr_input); + auto cnode_input = graph->GetBackendAnfByFrontAnf(attr_input); + auto switch_cnode = cnode_input->cast(); + MS_EXCEPTION_IF_NULL(switch_cnode); + if (cnode->inputs().size() < 2) { + cnode_inputs = switch_cnode->inputs(); + return cnode_inputs; + } + std::vector switch_inputs = {switch_cnode->input(kAnfPrimitiveIndex), + switch_cnode->input(kFirstDataInputIndex)}; + for (size_t index = kFirstBranchInSwitch; index < switch_cnode->inputs().size(); index++) { + auto node = switch_cnode->input(index); + // there is real input in call, should put it to true and false branch in switch + if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimPartial)) { + auto partial_node = node->cast(); + MS_EXCEPTION_IF_NULL(partial_node); + std::vector partial_inputs = partial_node->inputs(); + partial_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(cnode->input(kFirstDataInputIndex))); + auto new_partial = graph->NewCNode(partial_inputs); + switch_inputs.emplace_back(new_partial); + } + } + if (switch_inputs.size() < kSwitchInputSize) { + MS_LOG(EXCEPTION) << "Switch inputs size: " << switch_inputs.size() << "less than " << kSwitchInputSize; + } + auto switch_node = graph->NewCNode(switch_inputs); + cnode_inputs.emplace_back(switch_node); + return cnode_inputs; +} + std::vector SessionBasic::CreateSwitchOrPartialNode(const CNodePtr &cnode, KernelGraph *graph) { MS_EXCEPTION_IF_NULL(cnode); MS_EXCEPTION_IF_NULL(graph); @@ -618,32 +663,7 @@ std::vector SessionBasic::CreateSwitchOrPartialNode(const CNodePtr & }); return cnode_inputs; } else if (AnfAlgo::CheckPrimitiveType(cnode_input, prim::kPrimSwitch)) { - auto switch_cnode = cnode_input->cast(); - MS_EXCEPTION_IF_NULL(switch_cnode); - if (cnode->inputs().size() < 2) { - cnode_inputs = switch_cnode->inputs(); - return cnode_inputs; - } - std::vector switch_inputs = {switch_cnode->input(kAnfPrimitiveIndex), - switch_cnode->input(kFirstDataInputIndex)}; - for (size_t index = kFirstBranchInSwitch; index < switch_cnode->inputs().size(); index++) { - auto node = switch_cnode->input(index); - // there is real input in call, should put it to true and false branch in switch - if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimPartial)) { - auto partial_node = node->cast(); - MS_EXCEPTION_IF_NULL(partial_node); - std::vector partial_inputs = partial_node->inputs(); - partial_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(cnode->input(kFirstDataInputIndex))); - auto new_partial = graph->NewCNode(partial_inputs); - switch_inputs.emplace_back(new_partial); - } - } - if (switch_inputs.size() < kSwitchInputSize) { - MS_LOG(EXCEPTION) << "Switch inputs size: " << switch_inputs.size() << "less than " << kSwitchInputSize; - } - auto switch_node = graph->NewCNode(switch_inputs); - cnode_inputs.emplace_back(switch_node); - return cnode_inputs; + return CreateCallSwitchInputs(cnode, graph); } MS_LOG(EXCEPTION) << "CNode input[0] must be partial or switch."; } diff --git a/mindspore/ccsrc/backend/session/session_basic.h b/mindspore/ccsrc/backend/session/session_basic.h index 7e1f518dca..33864f57e1 100644 --- a/mindspore/ccsrc/backend/session/session_basic.h +++ b/mindspore/ccsrc/backend/session/session_basic.h @@ -131,6 +131,10 @@ class SessionBasic : public std::enable_shared_from_this { std::vector CreateSwitchOrPartialNode(const CNodePtr &cnode, KernelGraph *graph); std::vector CreateValueNode(const CNodePtr &cnode, KernelGraph *graph); void CreateCNodeInputs(const CNodePtr &cnode, KernelGraph *graph, std::vector *cnode_inputs); + std::vector CreateCallSwitchInputs(const CNodePtr &cnode, KernelGraph *graph); + void GetCNodeInfo(const CNodePtr &cnode, std::vector *cnode_inputs); + void GetNewCNodeInputs(const CNodePtr &cnode, KernelGraph *graph, std::vector *cnode_inputs, + std::unordered_map *other_graph_cnode); protected: void RunInfer(NotNull func_graph, const std::vector &inputs);