From 7d41812b98002d600d5487fac4c8500a44df9fc8 Mon Sep 17 00:00:00 2001
From: chenfei
Date: Sat, 16 May 2020 20:57:09 +0800
Subject: [PATCH] split-graph-for-control-sink

---
 mindspore/ccsrc/operator/ops.cc            |   1 +
 mindspore/ccsrc/operator/ops.h             |   1 +
 .../ccsrc/session/anf_runtime_algorithm.cc |  63 +++++-
 .../ccsrc/session/anf_runtime_algorithm.h  |   3 +
 mindspore/ccsrc/session/ascend_session.cc  | 187 +++++++++++++++++-
 mindspore/ccsrc/session/ascend_session.h   |   5 +-
 mindspore/ccsrc/session/kernel_graph.cc    |  59 +++++-
 mindspore/ccsrc/session/kernel_graph.h     |  16 ++
 mindspore/ccsrc/session/session_basic.cc   |  54 +++--
 mindspore/ccsrc/session/session_basic.h    |   1 +
 10 files changed, 350 insertions(+), 40 deletions(-)

diff --git a/mindspore/ccsrc/operator/ops.cc b/mindspore/ccsrc/operator/ops.cc
index 2b0eeb26f2..da4e8983cb 100755
--- a/mindspore/ccsrc/operator/ops.cc
+++ b/mindspore/ccsrc/operator/ops.cc
@@ -65,6 +65,7 @@ const PrimitivePtr kPrimAssign = std::make_shared<Primitive>("Assign");
 const PrimitivePtr kPrimAssignAdd = std::make_shared<Primitive>("AssignAdd");
 const PrimitivePtr kPrimAssignSub = std::make_shared<Primitive>("AssignSub");
 const PrimitivePtr kPrimSelect = std::make_shared<Primitive>("Select");
+const PrimitivePtr kPrimCall = std::make_shared<Primitive>("call");
 
 const PrimitivePtr kPrimDistribute = std::make_shared<Primitive>("distribute");
 const PrimitivePtr kPrimDot = std::make_shared<Primitive>("dot");
diff --git a/mindspore/ccsrc/operator/ops.h b/mindspore/ccsrc/operator/ops.h
index e6e065f076..8b63c876ed 100755
--- a/mindspore/ccsrc/operator/ops.h
+++ b/mindspore/ccsrc/operator/ops.h
@@ -71,6 +71,7 @@ extern const PrimitivePtr kPrimAssign;
 extern const PrimitivePtr kPrimAssignAdd;
 extern const PrimitivePtr kPrimAssignSub;
 extern const PrimitivePtr kPrimSelect;
+extern const PrimitivePtr kPrimCall;
 
 extern const PrimitivePtr kPrimDistribute;
 extern const PrimitivePtr kPrimDot;
diff --git a/mindspore/ccsrc/session/anf_runtime_algorithm.cc b/mindspore/ccsrc/session/anf_runtime_algorithm.cc
index 7e7f2fd408..fe2b2eb3f1 100644
--- a/mindspore/ccsrc/session/anf_runtime_algorithm.cc
+++ b/mindspore/ccsrc/session/anf_runtime_algorithm.cc
@@ -271,7 +271,9 @@ size_t AnfRuntimeAlgorithm::GetInputTensorNum(const AnfNodePtr &node) {
 size_t AnfRuntimeAlgorithm::GetOutputTensorNum(const AnfNodePtr &node) {
   MS_EXCEPTION_IF_NULL(node);
   TypePtr type = node->Type();
-  MS_EXCEPTION_IF_NULL(type);
+  if (type == nullptr) {
+    return 0;
+  }
   if (type->isa<Tuple>()) {
     auto tuple_type = type->cast<TuplePtr>();
     MS_EXCEPTION_IF_NULL(tuple_type);
@@ -913,11 +915,66 @@ bool AnfRuntimeAlgorithm::IsGetNext(const NotNull<AnfNodePtr> &node) {
 FuncGraphPtr AnfRuntimeAlgorithm::GetValueNodeFuncGraph(const AnfNodePtr &node) {
   MS_EXCEPTION_IF_NULL(node);
   auto value_node = node->cast<ValueNodePtr>();
-  MS_EXCEPTION_IF_NULL(value_node);
+  if (value_node == nullptr) {
+    return nullptr;
+  }
   auto value = value_node->value();
-  MS_EXCEPTION_IF_NULL(value);
+  if (value == nullptr) {
+    return nullptr;
+  }
   auto func_graph = value->cast<FuncGraphPtr>();
   return func_graph;
 }
+
+std::vector<KernelGraphPtr> AnfRuntimeAlgorithm::GetCallNodeKernelGraph(const CNodePtr &call_node) {
+  MS_EXCEPTION_IF_NULL(call_node);
+  if (!AnfAlgo::CheckPrimitiveType(call_node, std::make_shared<Primitive>("call"))) {
+    MS_LOG(EXCEPTION) << "anf node: " << call_node->DebugString() << " is not a call node.";
+  }
+  auto input1 = call_node->input(1);
+  MS_EXCEPTION_IF_NULL(input1);
+  if (input1->isa<ValueNode>()) {
+    auto value_node = input1->cast<ValueNodePtr>();
+    MS_EXCEPTION_IF_NULL(value_node);
+    auto kernel_graph = value_node->value();
+    MS_EXCEPTION_IF_NULL(kernel_graph);
+    return {kernel_graph->cast<KernelGraphPtr>()};
+  } else if (input1->isa<CNode>() && AnfAlgo::CheckPrimitiveType(input1, prim::kPrimSwitch)) {
+    auto switch_node = input1->cast<CNodePtr>();
+    MS_EXCEPTION_IF_NULL(switch_node);
+    MS_LOG(INFO) << "switch: " << switch_node->DebugString();
+    auto get_switch_kernel_graph = [&](size_t input_index) -> KernelGraphPtr {
+      auto partial = switch_node->input(input_index);
+      MS_EXCEPTION_IF_NULL(partial);
+      auto partial_cnode = partial->cast<CNodePtr>();
+      MS_EXCEPTION_IF_NULL(partial_cnode);
+      auto graph_node = partial_cnode->input(1);
+      MS_EXCEPTION_IF_NULL(graph_node);
+      MS_LOG(INFO) << graph_node->DebugString();
+      auto graph_value_node = graph_node->cast<ValueNodePtr>();
+      MS_EXCEPTION_IF_NULL(graph_value_node);
+      auto graph_value = graph_value_node->value();
+      MS_EXCEPTION_IF_NULL(graph_value);
+      auto child_graph = graph_value->cast<KernelGraphPtr>();
+      return child_graph;
+    };
+    return {get_switch_kernel_graph(2), get_switch_kernel_graph(3)};
+  }
+  return {};
+}
+
+bool AnfRuntimeAlgorithm::IsSwitchCall(const CNodePtr &call_node) {
+  MS_EXCEPTION_IF_NULL(call_node);
+  if (!CheckPrimitiveType(call_node, prim::kPrimCall)) {
+    MS_LOG(EXCEPTION) << "call node should be a 'call', but is a " << call_node->DebugString();
+  }
+  auto input1 = call_node->input(1);
+  if (input1->isa<ValueNode>()) {
+    return false;
+  } else if (input1->isa<CNode>() && AnfAlgo::CheckPrimitiveType(input1, prim::kPrimSwitch)) {
+    return true;
+  }
+  MS_LOG(EXCEPTION) << "Unexpected input1 of call node, input1: " << input1->DebugString();
+}
 }  // namespace session
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/session/anf_runtime_algorithm.h b/mindspore/ccsrc/session/anf_runtime_algorithm.h
index f9b426261d..10ae5282e0 100644
--- a/mindspore/ccsrc/session/anf_runtime_algorithm.h
+++ b/mindspore/ccsrc/session/anf_runtime_algorithm.h
@@ -32,6 +32,7 @@
 #include "kernel/kernel_build_info.h"
 #include "operator/ops.h"
 #include "utils/contract.h"
+#include "session/kernel_graph.h"
 
 namespace mindspore {
 namespace session {
@@ -182,6 +183,8 @@
   static bool IsCommunicationOp(const AnfNodePtr &node);
   static bool IsGetNext(const NotNull<AnfNodePtr> &node);
   static FuncGraphPtr GetValueNodeFuncGraph(const AnfNodePtr &node);
+  static std::vector<KernelGraphPtr> GetCallNodeKernelGraph(const CNodePtr &call_node);
+  static bool IsSwitchCall(const CNodePtr &call_node);
 };
 }  // namespace session
 using AnfAlgo = session::AnfRuntimeAlgorithm;
diff --git a/mindspore/ccsrc/session/ascend_session.cc b/mindspore/ccsrc/session/ascend_session.cc
index 99ce090b42..862d66c25d 100644
--- a/mindspore/ccsrc/session/ascend_session.cc
+++ b/mindspore/ccsrc/session/ascend_session.cc
@@ -156,6 +156,89 @@ void ClearRunOpMemoryResource(const KernelGraphPtr &kernel_graph) {
     }
   }
 }
+
+std::vector<CNodePtr> GetCNodes(const std::vector<AnfNodePtr> &anf_nodes) {
+  std::vector<CNodePtr> cnodes = {};
+  size_t i = 0;
+  for (const auto &anf : anf_nodes) {
+    MS_EXCEPTION_IF_NULL(anf);
+    MS_LOG(INFO) << "apply_list[" << i++ << "] = " << anf->DebugString();
+    if (anf->isa<CNode>()) {
+      cnodes.push_back(anf->cast<CNodePtr>());
+    }
+  }
+  return cnodes;
+}
+
+std::vector<std::vector<CNodePtr>> GetChildList(const KernelGraph &cur_graph, const std::vector<CNodePtr> &cnodes) {
+  size_t after_call_index = 0;
+  std::vector<std::vector<CNodePtr>> ret;
+  for (size_t i = 0; i < cnodes.size(); i++) {
+    if (AnfAlgo::CheckPrimitiveType(cnodes[i], prim::kPrimCall) && !AnfAlgo::IsSwitchCall(cnodes[i])) {
+      auto call_kernel_graph = AnfAlgo::GetCallNodeKernelGraph(cnodes[i]);
+      // if the called graph is the true branch of a while loop, there is no need to split the graph
+      if (call_kernel_graph.size() == 1 && call_kernel_graph[0] == cur_graph.parent_graph()) {
+        continue;
+      }
+      auto prev_call_list = std::vector<CNodePtr>(cnodes.begin() + after_call_index, cnodes.begin() + i);
+      auto call_list = std::vector<CNodePtr>(1, cnodes[i]);
+      after_call_index = i + 1;
+      ret.push_back(prev_call_list);
+      ret.push_back(call_list);
+    } else if (AnfAlgo::CheckPrimitiveType(cnodes[i], prim::kPrimReturn)) {
+      ret.push_back(std::vector<CNodePtr>(cnodes.begin() + after_call_index, cnodes.end()));
+    }
+  }
+  return ret;
+}
+
+void UpdateRealInput(KernelGraph *graph) {
+  auto call_nodes = graph->FindNodeByPrimitive(prim::kPrimCall);
+  auto bind_call_partial_with_parameter = [&](const std::vector<AnfNodePtr> &parameters,
+                                              const std::vector<AnfNodePtr> &args, KernelGraph *child_graph) -> void {
+    MS_EXCEPTION_IF_NULL(child_graph);
+    MS_LOG(INFO) << "start bind parameter of child graph:" << child_graph->graph_id();
+    if (args.empty()) {
+      return;
+    }
+    if (parameters.size() != args.size()) {
+      MS_LOG(EXCEPTION) << "graph:" << child_graph->graph_id() << " parameters size:" << parameters.size()
+                        << " and args size:" << args.size() << " not equal!";
+    }
+    for (size_t i = 0; i < parameters.size(); i++) {
+      MS_LOG(INFO) << "bind parameter:" << parameters[i]->DebugString() << ", arg:" << args[i]->DebugString();
+      child_graph->SetRealInput(parameters[i], args[i]);
+    }
+  };
+  for (auto &call_node : call_nodes) {
+    MS_EXCEPTION_IF_NULL(call_node);
+    auto child_graphs = AnfAlgo::GetCallNodeKernelGraph(call_node);
+    if (child_graphs.size() == 1) {
+      MS_EXCEPTION_IF_NULL(child_graphs[0]);
+      bind_call_partial_with_parameter(
+        child_graphs[0]->inputs(), std::vector<AnfNodePtr>(call_node->inputs().begin() + 2, call_node->inputs().end()),
+        child_graphs[0].get());
+      call_node->set_inputs(std::vector<AnfNodePtr>(call_node->inputs().begin(), call_node->inputs().begin() + 2));
+    } else if (child_graphs.size() == 2) {
+      auto get_partial_args = [&](size_t input_index) -> std::vector<AnfNodePtr> {
+        auto switch_node = call_node->input(1);
+        MS_EXCEPTION_IF_NULL(switch_node);
+        auto switch_cnode = switch_node->cast<CNodePtr>();
+        MS_EXCEPTION_IF_NULL(switch_cnode);
+        auto partial = switch_cnode->input(input_index);
+        MS_EXCEPTION_IF_NULL(partial);
+        auto partial_cnode = partial->cast<CNodePtr>();
+        MS_EXCEPTION_IF_NULL(partial_cnode);
+        auto ret = std::vector<AnfNodePtr>(partial_cnode->inputs().begin() + 2, partial_cnode->inputs().end());
+        partial_cnode->set_inputs(
+          std::vector<AnfNodePtr>(partial_cnode->inputs().begin(), partial_cnode->inputs().begin() + 2));
+        return ret;
+      };
+      bind_call_partial_with_parameter(child_graphs[0]->inputs(), get_partial_args(2), child_graphs[0].get());
+      bind_call_partial_with_parameter(child_graphs[1]->inputs(), get_partial_args(3), child_graphs[1].get());
+    }
+  }
+}
 }  // namespace
 
 GraphId AscendSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) {
@@ -171,7 +254,7 @@ GraphId AscendSession::CompileGraph(NotNull<FuncGraphPtr> func_graph) {
   MS_LOG(INFO) << "start";
   auto graph = ConstructKernelGraph(func_graph);
   // split switch
-  SplitSwitch(graph.get());
+  SplitGraph(graph);
   // insert goto labels and label_sets
   LinkChildGraphs(graph.get());
   // resource initialize
@@ -1297,5 +1380,107 @@ void AscendSession::SyncInitialTenosrToDevice() {
     }
   }
 }
+
+KernelGraphPtr AscendSession::SplitKernelGraph(const KernelGraphPtr &new_kernel_graph,
+                                               const std::vector<CNodePtr> &list) {
+  MS_EXCEPTION_IF_NULL(new_kernel_graph);
+  MS_LOG(INFO) << "start split kernel graph:" << new_kernel_graph->graph_id();
+  // record every anf node that is used as an input of another node
+  std::set<AnfNodePtr> has_output_nodes;
+  for (auto &anf_node : list) {
+    for (auto &input : anf_node->inputs()) {
+      (void)has_output_nodes.insert(input);
+    }
+    if (AnfAlgo::CheckPrimitiveType(anf_node, prim::kPrimReturn)) {
+      new_kernel_graph->set_return(anf_node->cast<CNodePtr>());
+    }
+  }
+  MS_LOG(INFO) << "Construct input of kernel graph:" << new_kernel_graph->graph_id();
+  // create new parameters for inputs that come from other graphs
+  for (auto &anf_node : list) {
+    auto cnode = anf_node->cast<CNodePtr>();
+    for (size_t input_idx = 1; input_idx < cnode->inputs().size(); input_idx++) {
+      auto input = cnode->inputs()[input_idx];
+      if (!input->isa<CNode>()) {
+        cnode->set_input(input_idx, input);
+        continue;
+      }
+      if (AnfAlgo::GetGraphId(input.get()) != new_kernel_graph->graph_id()) {
+        auto new_parameter = CreateNewParameterFromCNode(input, true, new_kernel_graph.get());
+        cnode->set_input(input_idx, new_parameter);
+        new_kernel_graph->SetRealInput(new_parameter, input);
+      }
+    }
+  }
+  MS_LOG(INFO) << "Construct output of kernel graph:" << new_kernel_graph->graph_id();
+  auto make_tuple_primitive = NewValueNode(std::make_shared<Primitive>(prim::kPrimMakeTuple->name()));
+  std::vector<AnfNodePtr> make_tuple_inputs = {make_tuple_primitive};
+  int output_idx = 0;
+  for (auto &anf_node : list) {
+    if (AnfAlgo::CheckPrimitiveType(anf_node, prim::kPrimReturn)) {
+      new_kernel_graph->set_return(anf_node);
+    }
+    if (has_output_nodes.find(anf_node) == has_output_nodes.end()) {
+      MS_LOG(INFO) << "output[" << output_idx++ << "]:" << anf_node->DebugString();
+      make_tuple_inputs.push_back(anf_node);
+    }
+  }
+  if (new_kernel_graph->get_return() == nullptr) {
+    new_kernel_graph->set_output(new_kernel_graph->NewCNode(make_tuple_inputs));
+  }
+  MS_LOG(INFO) << "end";
+  return new_kernel_graph;
+}
+
+void AscendSession::SplitGraph(const KernelGraphPtr &graph) {
+  MS_EXCEPTION_IF_NULL(graph);
+  MS_LOG(INFO) << "start, graph_id:" << graph->graph_id();
+  auto apply_list = GetCNodes(TopoSort(graph->get_return()));
+  // update the child graph order of the root graph
+  graph->UpdateChildGraphOrder();
+  // get the child node lists from the current graph
+  std::vector<std::vector<CNodePtr>> child_graph_lists = GetChildList(*graph, apply_list);
+  auto bind_new_call_to_new_graph = [&](const std::vector<CNodePtr> &child_graph_list) -> AnfNodePtr {
+    if (child_graph_list.size() == 1 && AnfAlgo::CheckPrimitiveType(child_graph_list[0], prim::kPrimCall)) {
+      return child_graph_list[0];
+    }
+    // create a new child graph
+    auto child_graph = NewKernelGraph();
+    MS_EXCEPTION_IF_NULL(child_graph);
+    // create a new value node to bind the child graph
+    auto graph_value_node = graph->NewValueNode(NewValueNode(child_graph));
+    std::vector<AnfNodePtr> new_call_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimCall->name())),
+                                              graph_value_node};
+    // set the graph id on every node of the child graph
+    for (auto &child_graph_node : child_graph_list) {
+      AnfAlgo::SetGraphId(child_graph->graph_id(), child_graph_node.get());
+    }
+    SplitKernelGraph(child_graph, child_graph_list);
+    auto new_call = graph->NewCNode(new_call_input);
+    AnfAlgo::SetNodeAttr("graph id", MakeValue(graph->graph_id()), new_call);
+    return new_call;
+  };
+  if (child_graph_lists.size() > 1) {
+    for (size_t call_index = 0; call_index < child_graph_lists.size(); call_index++) {
+      auto call_node = bind_new_call_to_new_graph(child_graph_lists[call_index]);
+      if (call_index == 0) {
+        auto new_return_primitive =
+          graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimReturn->name())));
+        graph->set_return(graph->NewCNode({new_return_primitive, call_node}));
+        continue;
+      }
+      InsertDependToGraph(graph->graph_id(), call_node);
+    }
+  }
+  graph->UpdateChildGraphOrder();
+  UpdateRealInput(graph.get());
std::string("./kernel-graph-").append(std::to_string(graph->graph_id())); + DumpIR(graph_name, graph); + MS_LOG(INFO) << "split graph[" << graph->graph_id() << "] end"; + // recurse to split child graph + for (auto &child_graph : graph->child_graph_order()) { + SplitGraph(child_graph); + } +} } // namespace session } // namespace mindspore diff --git a/mindspore/ccsrc/session/ascend_session.h b/mindspore/ccsrc/session/ascend_session.h index 6430691462..a916263e05 100755 --- a/mindspore/ccsrc/session/ascend_session.h +++ b/mindspore/ccsrc/session/ascend_session.h @@ -95,13 +95,16 @@ class AscendSession : public SessionBasic { void SetFinalGraphOutput(const ValuePtr &value); void SetFinalGraphOutput(const VectorRef &vec_output); - void SplitSwitch(KernelGraph *graph) {} + void SplitGraph(const KernelGraphPtr &graph); void LinkChildGraphs(KernelGraph *graph) {} void IRFusion(const KernelGraphPtr &graph) {} void SelectKernelGraphKernel(const KernelGraph &graph) {} void ConvertPredictModel(const KernelGraphPtr graph) {} void HardwareOptimizeGraphs(const KernelGraphPtr graph) {} void RootGraphExecutorValidate(KernelGraph *graph) {} + void RecurseUpdateAllChildGraohOrder(KernelGraph *root_graph); + KernelGraphPtr SplitKernelGraph(const KernelGraphPtr &new_kernel_graph, const std::vector &list); + void ChildGraphCommunicationDecrease(std::vector> *anf_node_lists); // merge execution order list of child graphs void MergeGraphExecOrder(); diff --git a/mindspore/ccsrc/session/kernel_graph.cc b/mindspore/ccsrc/session/kernel_graph.cc index aebf738419..c6b84d57ad 100644 --- a/mindspore/ccsrc/session/kernel_graph.cc +++ b/mindspore/ccsrc/session/kernel_graph.cc @@ -16,9 +16,8 @@ #include "session/kernel_graph.h" #include #include -#include #include -#include "common/utils.h" +#include #include "operator/ops.h" #include "ir/param_value_py.h" #include "session/anf_runtime_algorithm.h" @@ -311,9 +310,10 @@ ValueNodePtr KernelGraph::NewValueNode(const ValueNodePtr &value_node) { // create kernel_build_info for new value node auto kernel_build_info_builder = std::make_shared(); // set the format of value_node to DEFAULT_FORMAT - kernel_build_info_builder->SetOutputsFormat(std::vector{kOpFormat_DEFAULT}); + auto output_tensor_num = AnfAlgo::GetOutputTensorNum(value_node); + kernel_build_info_builder->SetOutputsFormat(std::vector(output_tensor_num, kOpFormat_DEFAULT)); // set value node initial device data type = infer data type - std::vector types = std::vector(AnfAlgo::GetOutputTensorNum(value_node), kTypeUnknown); + std::vector types = std::vector(output_tensor_num, kTypeUnknown); kernel_build_info_builder->SetOutputsDeviceType(types); AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), new_value_node.get()); AnfAlgo::SetGraphId(graph_id_, new_value_node.get()); @@ -584,7 +584,25 @@ void KernelGraph::UpdateExecuteKernelStreamLabel() { } } -void KernelGraph::UpdateChildGraphOrder() {} +void KernelGraph::UpdateChildGraphOrder() { + MS_LOG(INFO) << "graph id:" << graph_id_; + auto call_nodes = FindNodeByPrimitive(std::make_shared(prim::kPrimCall->name())); + child_graph_order_.clear(); + for (auto &call_node : call_nodes) { + MS_EXCEPTION_IF_NULL(call_node); + auto call_child_graphs = AnfAlgo ::GetCallNodeKernelGraph(call_node->cast()); + for (const auto &child_graph : call_child_graphs) { + MS_EXCEPTION_IF_NULL(child_graph); + if (child_graph != parent_graph()) { + child_graph->set_parent_graph(shared_from_this()->cast>()); + child_graph_order_.push_back(child_graph); + } + } + } + for 
+  for (size_t i = 0; i < child_graph_order_.size(); i++) {
+    MS_LOG(INFO) << "child graph[" << i << "][id:" << child_graph_order_[i]->graph_id() << "]";
+  }
+}
 
 std::vector<std::vector<KernelGraphPtr>> KernelGraph::GetLeafGraphOrder() {
   std::vector<std::vector<KernelGraphPtr>> leaf_graph_order;
@@ -601,5 +619,36 @@
 }
 
 bool KernelGraph::IsLeafGraph() const { return child_graph_order_.empty(); }
+
+std::vector<CNodePtr> KernelGraph::FindNodeByPrimitive(const PrimitivePtr &primitive) const {
+  auto anf_list = TopoSort(get_return());
+  std::vector<CNodePtr> result;
+  for (const auto &anf : anf_list) {
+    if (AnfAlgo::CheckPrimitiveType(anf, primitive) && AnfAlgo::GetGraphId(anf.get()) == graph_id_) {
+      result.push_back(anf->cast<CNodePtr>());
+    }
+  }
+  return result;
+}
+
+std::set<AnfNodePtr> KernelGraph::GetRealInput(const AnfNodePtr &parameter) {
+  MS_EXCEPTION_IF_NULL(parameter);
+  if (real_inputs_.find(parameter) == real_inputs_.end()) {
+    return {};
+  }
+  return real_inputs_[parameter];
+}
+
+void KernelGraph::SetRealInput(const AnfNodePtr &parameter, const AnfNodePtr &arg) {
+  MS_EXCEPTION_IF_NULL(parameter);
+  MS_EXCEPTION_IF_NULL(arg);
+  if (real_inputs_.find(parameter) == real_inputs_.end()) {
+    real_inputs_[parameter] = std::set<AnfNodePtr>();
+  }
+  auto &args = real_inputs_[parameter];
+  (void)args.insert(arg);
+}
+
+std::string KernelGraph::ToString() const { return std::string("kernel_graph_").append(std::to_string(graph_id_)); }
 }  // namespace session
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/session/kernel_graph.h b/mindspore/ccsrc/session/kernel_graph.h
index 2fe9a9517b..9a3823532a 100644
--- a/mindspore/ccsrc/session/kernel_graph.h
+++ b/mindspore/ccsrc/session/kernel_graph.h
@@ -23,6 +23,7 @@
 #include
 #include
 #include
+#include
 #include
 #include "ir/func_graph.h"
 #include "ir/anf.h"
@@ -113,6 +114,17 @@
   }
   // get input_tensors pointer of control parameter
   std::shared_ptr<std::vector<tensor::TensorPtr>> input_ctrl_tensors() const { return input_ctrl_tensors_; }
+  // get parent kernel graph
+  std::shared_ptr<KernelGraph> parent_graph() const { return parent_graph_; }
+  // set parent kernel graph
+  void set_parent_graph(const std::shared_ptr<KernelGraph> &parent_graph) { parent_graph_ = parent_graph; }
+  // find anf node in graph
+  std::vector<CNodePtr> FindNodeByPrimitive(const PrimitivePtr &primitive) const;
+  // get real inputs
+  std::set<AnfNodePtr> GetRealInput(const AnfNodePtr &parameter);
+  void SetRealInput(const AnfNodePtr &parameter, const AnfNodePtr &arg);
+  // used to dump ir
+  std::string ToString() const override;
 
  private:
   // remove value node form graph
@@ -158,6 +170,10 @@
   std::vector<std::shared_ptr<KernelGraph>> child_graph_order_;
   // input_tensors of control parameter
   std::shared_ptr<std::vector<tensor::TensorPtr>> input_ctrl_tensors_;
+  // parent kernel graph
+  std::shared_ptr<KernelGraph> parent_graph_;
+  // record the real parameters; inputs_ holds the formal parameters
+  std::map<AnfNodePtr, std::set<AnfNodePtr>> real_inputs_;
 };
 }  // namespace session
 using KernelGraphPtr = std::shared_ptr<KernelGraph>;
diff --git a/mindspore/ccsrc/session/session_basic.cc b/mindspore/ccsrc/session/session_basic.cc
index 16c1ab16ef..5aeda36230 100644
--- a/mindspore/ccsrc/session/session_basic.cc
+++ b/mindspore/ccsrc/session/session_basic.cc
@@ -247,27 +247,6 @@ std::vector<AnfNodePtr> CreateParameterFromTuple(const AnfNodePtr &node, bool va
   return parameters;
 }
 
-AnfNodePtr CreateNewParameterFromCNode(const AnfNodePtr &anf, bool valid_input, KernelGraph *graph) {
-  MS_EXCEPTION_IF_NULL(anf);
-  if (!anf->isa<CNode>()) {
-    MS_LOG(EXCEPTION) << "Anf[" << anf->DebugString() << "] is not a cnode";
-  }
-  MS_LOG(INFO) << "Create a new parameter from cnode[" << anf->DebugString() << "]";
-  auto parameters = CreateParameterFromTuple(anf, valid_input, graph);
-  if (parameters.empty()) {
-    MS_LOG(EXCEPTION) << "No parameter exist!!";
-  }
-  if (parameters.size() == 1) {
-    return parameters[0];
-  }
-  std::vector<AnfNodePtr> make_tuple_input = {NewValueNode(prim::kPrimMakeTuple)};
-  (void)std::copy(parameters.begin(), parameters.end(), std::back_inserter(make_tuple_input));
-  auto make_tuple = graph->NewCNode(make_tuple_input);
-  MS_EXCEPTION_IF_NULL(make_tuple);
-  MS_LOG(INFO) << "New make tuple [" << make_tuple->DebugString() << "] of parameters";
-  return make_tuple;
-}
-
 size_t LoadCtrlInputTensor(const std::shared_ptr<KernelGraph> &graph, std::vector<tensor::TensorPtr> *inputs) {
   MS_LOG(INFO) << "Load kInputCtrlTensors";
   auto inputs_params = graph->input_ctrl_tensors();
@@ -390,6 +369,24 @@ ParameterPtr SessionBasic::CreateNewParameterFromParameter(const AnfNodePtr &anf
   return new_parameter;
 }
 
+AnfNodePtr SessionBasic::CreateNewParameterFromCNode(const AnfNodePtr &anf, bool valid_input, KernelGraph *graph) {
+  MS_EXCEPTION_IF_NULL(anf);
+  MS_LOG(INFO) << "Create a new parameter from cnode[" << anf->DebugString() << "]";
+  auto parameters = CreateParameterFromTuple(anf, valid_input, graph);
+  if (parameters.empty()) {
+    MS_LOG(EXCEPTION) << "No parameter exist!!";
+  }
+  if (parameters.size() == 1) {
+    return parameters[0];
+  }
+  std::vector<AnfNodePtr> make_tuple_input = {NewValueNode(prim::kPrimMakeTuple)};
+  (void)std::copy(parameters.begin(), parameters.end(), std::back_inserter(make_tuple_input));
+  auto make_tuple = graph->NewCNode(make_tuple_input);
+  MS_EXCEPTION_IF_NULL(make_tuple);
+  MS_LOG(INFO) << "New make tuple [" << make_tuple->DebugString() << "] of parameters";
+  return make_tuple;
+}
+
 CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, bool valid_input, KernelGraph *graph,
                                       bool *from_other_graph,
                                       std::unordered_map<AnfNodePtr, AnfNodePtr> *other_graph_cnode) {
@@ -454,7 +451,7 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph)
   MS_EXCEPTION_IF_NULL(attr_input);
   if (IsValueNode<FuncGraph>(attr_input)) {
     // create primitive of cnode:call
-    cnode_inputs = {std::make_shared<ValueNode>(std::make_shared<Primitive>(kCallOpName))};
+    cnode_inputs = {graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimCall->name())))};
     // create a ValueNode as input of cnode:call
     if (graph->GetBackendAnfByFrontAnf(attr_input) != nullptr) {
       cnode_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(attr_input));
@@ -466,12 +463,10 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph)
     }
   } else if (attr_input->isa<CNode>()) {
     // create primitive of cnode:call(switch)
-    cnode_inputs = {std::make_shared<ValueNode>(std::make_shared<Primitive>(kCallOpName))};
+    cnode_inputs = {graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimCall->name())))};
     if (graph->GetBackendAnfByFrontAnf(attr_input) != nullptr) {
       auto cnode_input = graph->GetBackendAnfByFrontAnf(attr_input);
-      auto prim = GetCNodePrimitive(cnode_input);
-      MS_EXCEPTION_IF_NULL(prim);
-      if (prim->name() != kSwitchOpName) {
+      if (!AnfAlgo::CheckPrimitiveType(cnode_input, prim::kPrimSwitch)) {
         MS_LOG(EXCEPTION) << "CNode input[0] must be switch.";
       }
       cnode_inputs.emplace_back(cnode_input);
@@ -484,7 +479,7 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph)
     auto prim = AnfAlgo::GetCNodePrimitive(cnode);
     MS_EXCEPTION_IF_NULL(prim);
     // push attr to inputs[0] of new cnode
-    cnode_inputs = {std::make_shared<ValueNode>(std::make_shared<Primitive>(*prim))};
+    cnode_inputs = {graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(*prim)))};
   }
   for (size_t input_idx = 1; input_idx < cnode->inputs().size(); input_idx++) {
@@ -545,7 +540,6 @@ ValueNodePtr SessionBasic::CreateValueNodeKernelGraph(const AnfNodePtr &anf, Ker
   AnfAlgo::SetGraphId(graph->graph_id(), new_value_node.get());
   graph->FrontBackendlMapAdd(anf, new_value_node);
-  graph->AddValueNodeToGraph(new_value_node);
   return new_value_node;
 }
@@ -555,11 +549,11 @@ ParameterPtr SessionBasic::CreateNewParameter(const AnfNodePtr &anf, KernelGraph
   if (!anf->isa<Parameter>()) {
     MS_LOG(EXCEPTION) << "anf[" << anf->DebugString() << "] is not a parameter";
   }
-
   auto graph_inputs = graph->MutableInputs();
   MS_EXCEPTION_IF_NULL(graph_inputs);
-
+  TraceManager::DebugTrace(std::make_shared<TraceCopy>(anf->debug_info()));
   auto new_parameter = graph->NewParameter(anf->cast<ParameterPtr>());
+  TraceManager::EndTrace();
   graph_inputs->push_back(new_parameter);
   graph->FrontBackendlMapAdd(anf, new_parameter);
diff --git a/mindspore/ccsrc/session/session_basic.h b/mindspore/ccsrc/session/session_basic.h
index d172f89bbf..2719c9b67d 100755
--- a/mindspore/ccsrc/session/session_basic.h
+++ b/mindspore/ccsrc/session/session_basic.h
@@ -114,6 +114,7 @@
   ParameterPtr CreateNewParameterFromParameter(const AnfNodePtr &anf, bool valid_input, KernelGraph *graph);
   ValueNodePtr CreateValueNodeKernelGraph(const AnfNodePtr &anf, KernelGraph *graph);
   ParameterPtr CreateNewParameter(const AnfNodePtr &anf, KernelGraph *graph);
+  AnfNodePtr CreateNewParameterFromCNode(const AnfNodePtr &anf, bool valid_input, KernelGraph *graph);
 
   std::unordered_map<GraphId, std::shared_ptr<KernelGraph>> graphs_;
   std::unordered_map<GraphInfo, std::shared_ptr<KernelGraph>> run_op_graphs_;