From cc54bb565db38a2c48079e65e280b907bd0d1a79 Mon Sep 17 00:00:00 2001
From: chenfei
Date: Fri, 17 Apr 2020 19:37:17 +0800
Subject: [PATCH] move opt to build graph

---
 mindspore/ccsrc/kernel/kernel_build_info.cc  |  12 +-
 mindspore/ccsrc/kernel/kernel_build_info.h   |   3 +
 mindspore/ccsrc/kernel/kernel_query.cc       |   6 +-
 mindspore/ccsrc/kernel/mng/rt_kernel_info.cc |  32 +-
 .../ascend/ascend_backend_optimization.cc    |  12 +-
 .../common/common_backend_optimization.cc    |   1 +
 .../ccsrc/session/anf_runtime_algorithm.cc   |  28 +-
 mindspore/ccsrc/session/ascend_session.cc    | 332 ++++++++++++------
 mindspore/ccsrc/session/ascend_session.h     |  23 +-
 mindspore/ccsrc/session/kernel_graph.cc      |  49 ++-
 mindspore/ccsrc/session/kernel_graph.h       |  13 +-
 mindspore/ccsrc/session/session_basic.cc     |  19 +-
 mindspore/ccsrc/session/session_basic.h      |   2 +
 mindspore/ops/_op_impl/tbe/assign.py         |   1 +
 14 files changed, 381 insertions(+), 152 deletions(-)

diff --git a/mindspore/ccsrc/kernel/kernel_build_info.cc b/mindspore/ccsrc/kernel/kernel_build_info.cc
index 279a62bad6..df855f5340 100644
--- a/mindspore/ccsrc/kernel/kernel_build_info.cc
+++ b/mindspore/ccsrc/kernel/kernel_build_info.cc
@@ -22,28 +22,32 @@ namespace mindspore {
 namespace kernel {
 std::string KernelBuildInfo::GetInputFormat(size_t input_index) const {
   if (input_index >= inputs_format_.size()) {
-    MS_LOG(EXCEPTION) << "The index [" << input_index << "] is exceed the number of input node";
+    MS_LOG(ERROR) << "The index [" << input_index << "] exceeds the number of inputs";
+    return kInvalidFormat;
   }
   return inputs_format_[input_index];
 }

 std::string KernelBuildInfo::GetOutputFormat(size_t output_index) const {
   if (output_index >= outputs_format_.size()) {
-    MS_LOG(EXCEPTION) << "The index [" << output_index << "] is exceed the number of input node";
+    MS_LOG(ERROR) << "The index [" << output_index << "] exceeds the number of outputs";
+    return kInvalidFormat;
   }
   return outputs_format_[output_index];
 }

 TypeId KernelBuildInfo::GetInputDeviceType(size_t input_index) const {
   if (input_index >= inputs_device_type_.size()) {
-    MS_LOG(EXCEPTION) << "The index [" << input_index << "] is exceed the number of input node";
+    MS_LOG(ERROR) << "The index [" << input_index << "] exceeds the number of inputs";
+    return TypeId::kNumberTypeEnd;
   }
   return inputs_device_type_[input_index];
 }

 TypeId KernelBuildInfo::GetOutputDeviceType(size_t output_index) const {
   if (output_index >= outputs_device_type_.size()) {
-    MS_LOG(EXCEPTION) << "The index [" << output_index << "] is exceed the number of input node";
+    MS_LOG(ERROR) << "The index [" << output_index << "] exceeds the number of outputs";
+    return TypeId::kNumberTypeEnd;
   }
   return outputs_device_type_[output_index];
 }
diff --git a/mindspore/ccsrc/kernel/kernel_build_info.h b/mindspore/ccsrc/kernel/kernel_build_info.h
index 76ebc7a572..779be057f6 100644
--- a/mindspore/ccsrc/kernel/kernel_build_info.h
+++ b/mindspore/ccsrc/kernel/kernel_build_info.h
@@ -82,6 +82,9 @@ class KernelBuildInfo {

   bool operator==(const KernelBuildInfo &other) const;

+ public:
+  static auto constexpr kInvalidFormat = "InvalidFormat";
+
  private:
   KernelType kernel_type_;
   std::vector<std::string> inputs_format_;
diff --git a/mindspore/ccsrc/kernel/kernel_query.cc b/mindspore/ccsrc/kernel/kernel_query.cc
index 3d3282e7b5..e4a1af7f50 100755
--- a/mindspore/ccsrc/kernel/kernel_query.cc
+++ b/mindspore/ccsrc/kernel/kernel_query.cc
@@ -26,7 +26,7 @@
 namespace mindspore {
 namespace kernel {
 namespace {
-void FilterInvaildKernelInfo(const CNodePtr &kernel_node,
+void FilterInvalidKernelInfo(const CNodePtr &kernel_node,
                              std::vector<std::shared_ptr<KernelBuildInfo>> *kernel_info_list) {
   MS_EXCEPTION_IF_NULL(kernel_info_list);
   std::vector<std::shared_ptr<KernelBuildInfo>> filtered_list;
@@ -63,9 +63,9 @@ void KernelQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<Kernel
   if (kernel_info_list->empty()) {
-    MS_LOG(EXCEPTION) << "op" << kernel_node->DebugString() << "kernel query fail!";
+    MS_LOG(EXCEPTION) << "Op " << kernel_node->DebugString() << " kernel query fail!";
   }
-  FilterInvaildKernelInfo(kernel_node, kernel_info_list);
+  FilterInvalidKernelInfo(kernel_node, kernel_info_list);
 }
 }  // namespace kernel
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/kernel/mng/rt_kernel_info.cc b/mindspore/ccsrc/kernel/mng/rt_kernel_info.cc
index a87bb4d514..cb230bc706 100755
--- a/mindspore/ccsrc/kernel/mng/rt_kernel_info.cc
+++ b/mindspore/ccsrc/kernel/mng/rt_kernel_info.cc
@@ -46,24 +46,40 @@ RtKerDescFactory &RtKerDescFactory::Get() {

 void GetRtKelInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<KernelBuildInfo>> *kernel_info_list) {
-  MS_LOG(INFO) << "Mng kernel Info.";
   MS_EXCEPTION_IF_NULL(kernel_info_list);
   MS_EXCEPTION_IF_NULL(kernel_node);
   std::string opNameLower = AnfAlgo::GetCNodeName(kernel_node);
   (void)std::transform(opNameLower.begin(), opNameLower.end(), opNameLower.begin(), ::tolower);
   auto ker_desc_ptr = RtKerDescFactory::Create(opNameLower);
-  if (ker_desc_ptr == nullptr) {
-    MS_LOG(DEBUG) << "Mng can't find op [" << opNameLower << "].";
+  if (ker_desc_ptr != nullptr && !ker_desc_ptr->GetKernelInfo().empty()) {
+    *kernel_info_list = ker_desc_ptr->GetKernelInfo();
     return;
   }
-  MS_EXCEPTION_IF_NULL(ker_desc_ptr);
-  auto kernel_info = ker_desc_ptr->GetKernelInfo();
-  if (kernel_info.empty()) {
-    MS_LOG(DEBUG) << "Rt dose not have op [" << opNameLower << "].";
+  // if the kernel info can't be found in the kernel info database, use the default kernel info
+  auto node_name = AnfAlgo::GetCNodeName(kernel_node);
+  if (node_name == "StreamSwitch" || node_name == "StreamActive") {
+    auto kernel_build_info_builder = std::make_shared<KernelBuildInfo::KernelBuildInfoBuilder>();
+    // set input infos
+    auto input_num = AnfAlgo::GetInputTensorNum(kernel_node);
+    kernel_build_info_builder->SetInputsFormat(std::vector<std::string>(input_num, kOpFormat_DEFAULT));
+    std::vector<TypeId> input_types = {};
+    for (size_t i = 0; i < input_num; i++) {
+      input_types.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, i));
+    }
+    kernel_build_info_builder->SetInputsDeviceType(input_types);
+    // set output info
+    auto output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
+    kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>(output_num, kOpFormat_DEFAULT));
+    kernel_build_info_builder->SetOutputsDeviceType(std::vector<TypeId>(output_num, TypeId::kTypeUnknown));
+    // set other info
+    kernel_build_info_builder->SetFusionType(kernel::FusionType::OPAQUE);
+    kernel_build_info_builder->SetProcessor(kernel::Processor::AICORE);
+    kernel_build_info_builder->SetKernelType(KernelType::RT_KERNEL);
+    kernel_info_list->push_back(kernel_build_info_builder->Build());
     return;
   }
-  *kernel_info_list = kernel_info;
+  MS_LOG(DEBUG) << "Rt does not have op [" << opNameLower << "].";
 }
 }  // namespace kernel
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc b/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc
index 1b152c8998..66ea5ee526 100644
--- a/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc
+++ b/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc
@@ -186,7 +186,8 @@ void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
   if (save_graphs) {
-    std::string file_path = save_graphs_path + "/" + "hwopt_d_ir_fusion_before.ir";
+    std::string file_path = save_graphs_path + "/" + "hwopt_d_ir_fusion_before" + "_graph_" +
+                            std::to_string(kernel_graph->graph_id()) + ".ir";
     DumpIR(file_path, kernel_graph);
     DumpIRProto(kernel_graph, "before_hwopt");
   }
@@ -208,7 +209,8 @@ void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
   (void)optimizer->Optimize(kernel_graph);
   kernel_graph->SetExecOrderByDefault();
   if (save_graphs) {
-    std::string file_path = save_graphs_path + "/" + "hwopt_d_ir_fusion_after.ir";
+    std::string file_path = save_graphs_path + "/" + "hwopt_d_ir_fusion_after" + "_graph_" +
+                            std::to_string(kernel_graph->graph_id()) + ".ir";
     DumpIR(file_path, kernel_graph);
   }
 }
@@ -252,7 +254,8 @@ void AscendBackendOptimization(const std::shared_ptr<session::KernelGraph> &kern
     save_graphs_path = ".";
   }
   if (save_graphs) {
-    std::string file_path = save_graphs_path + "/" + "hwopt_d_before.ir";
+    std::string file_path =
+      save_graphs_path + "/" + "hwopt_d_before" + "_graph_" + std::to_string(kernel_graph->graph_id()) + ".ir";
     DumpIR(file_path, kernel_graph);
   }
   // data layout optimization
@@ -278,7 +281,8 @@ void AscendBackendOptimization(const std::shared_ptr<session::KernelGraph> &kern
   (void)optimizer->Optimize(kernel_graph);
   kernel_graph->SetExecOrderByDefault();
   if (save_graphs) {
-    std::string file_path = save_graphs_path + "/" + "hwopt_d_end.ir";
+    std::string file_path =
+      save_graphs_path + "/" + "hwopt_d_end" + "_graph_" + std::to_string(kernel_graph->graph_id()) + ".ir";
     DumpIR(file_path, kernel_graph, true);
     DumpIRProto(kernel_graph, "after_hwopt");
   }
diff --git a/mindspore/ccsrc/pre_activate/common/common_backend_optimization.cc b/mindspore/ccsrc/pre_activate/common/common_backend_optimization.cc
index f622f2f06f..0383311122 100644
--- a/mindspore/ccsrc/pre_activate/common/common_backend_optimization.cc
+++ b/mindspore/ccsrc/pre_activate/common/common_backend_optimization.cc
@@ -27,6 +27,7 @@ namespace mindspore {
 namespace opt {
 void BackendCommonOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
+  MS_LOG(INFO) << "Start common opt for graph: " << kernel_graph->graph_id();
   auto context_ptr = MsContext::GetInstance();
   MS_EXCEPTION_IF_NULL(context_ptr);
   bool save_graphs = context_ptr->save_graphs_flag();
diff --git a/mindspore/ccsrc/session/anf_runtime_algorithm.cc b/mindspore/ccsrc/session/anf_runtime_algorithm.cc
index dbf7097970..45588052b0 100644
--- a/mindspore/ccsrc/session/anf_runtime_algorithm.cc
+++ b/mindspore/ccsrc/session/anf_runtime_algorithm.cc
@@ -300,7 +300,12 @@ std::string AnfRuntimeAlgorithm::GetOutputFormat(const AnfNodePtr &node, size_t
   MS_EXCEPTION_IF_NULL(kernel_info);
   auto build_info = kernel_info->select_kernel_build_info();
   MS_EXCEPTION_IF_NULL(build_info);
-  return build_info->GetOutputFormat(output_idx);
+  auto format = build_info->GetOutputFormat(output_idx);
+  if (format == kernel::KernelBuildInfo::kInvalidFormat) {
+    MS_LOG(EXCEPTION) << "Node [" << node->DebugString() << "]"
+                      << " has an invalid output format";
+  }
+  return format;
 }

 std::string AnfRuntimeAlgorithm::GetInputFormat(const AnfNodePtr &node, size_t input_idx) {
@@ -314,7 +319,12 @@ std::string AnfRuntimeAlgorithm::GetInputFormat(const AnfNodePtr &node, size_t i
   MS_EXCEPTION_IF_NULL(kernel_info);
   auto build_info = kernel_info->select_kernel_build_info();
   MS_EXCEPTION_IF_NULL(build_info);
-  return build_info->GetInputFormat(input_idx);
+  auto format = build_info->GetInputFormat(input_idx);
+  if (format == kernel::KernelBuildInfo::kInvalidFormat) {
+    MS_LOG(EXCEPTION) << "Node [" << node->DebugString() << "]"
+                      << " has an invalid input format";
+  }
+  return format;
 }

 KernelWithIndex AnfRuntimeAlgorithm::GetPrevNodeOutput(const AnfNodePtr &anf_node, size_t input_idx) {
@@ -481,7 +491,12 @@ TypeId AnfRuntimeAlgorithm::GetOutputDeviceDataType(const AnfNodePtr &node, size
   MS_EXCEPTION_IF_NULL(kernel_info);
   auto build_info = kernel_info->select_kernel_build_info();
   MS_EXCEPTION_IF_NULL(build_info);
-  return build_info->GetOutputDeviceType(output_idx);
+  auto dtype = build_info->GetOutputDeviceType(output_idx);
+  if (dtype == TypeId::kNumberTypeEnd) {
+    MS_LOG(EXCEPTION) << "Node [" << node->DebugString() << "]"
+                      << " has an invalid dtype";
+  }
+  return dtype;
 }

 TypeId AnfRuntimeAlgorithm::GetInputDeviceDataType(const AnfNodePtr &node, size_t input_idx) {
@@ -494,7 +509,12 @@ TypeId AnfRuntimeAlgorithm::GetInputDeviceDataType(const AnfNodePtr &node, size_
   MS_EXCEPTION_IF_NULL(kernel_info);
   auto build_info = kernel_info->select_kernel_build_info();
   MS_EXCEPTION_IF_NULL(build_info);
-  return build_info->GetInputDeviceType(input_idx);
+  auto dtype = build_info->GetInputDeviceType(input_idx);
+  if (dtype == TypeId::kNumberTypeEnd) {
+    MS_LOG(EXCEPTION) << "Node [" << node->DebugString() << "]"
+                      << " has an invalid dtype";
+  }
+  return dtype;
 }

 TypeId AnfRuntimeAlgorithm::GetPrevNodeOutputDeviceDataType(const AnfNodePtr &anf_node, size_t input_idx) {
diff --git a/mindspore/ccsrc/session/ascend_session.cc b/mindspore/ccsrc/session/ascend_session.cc
index bd5fba6d4b..a0a9a108cc 100755
--- a/mindspore/ccsrc/session/ascend_session.cc
+++ b/mindspore/ccsrc/session/ascend_session.cc
@@ -15,6 +15,9 @@
  */
 #include "session/ascend_session.h"
 #include
+#include
+#include
+#include
 #include "operator/ops.h"
 #include "ir/meta_tensor.h"
 #include "ir/anf.h"
@@ -75,28 +78,15 @@ void DumpGraphInputArgs(const VectorRef &args) {

 void SetStreamDistinctionLabel(const KernelGraphPtr &graph, uint32_t label, bool is_override) {
   MS_EXCEPTION_IF_NULL(graph);
-  for (auto &node : graph->execution_order()) {
-    if (is_override || AnfAlgo::GetStreamDistinctionLabel(node.get()) == kInvalidDistincLabel) {
-      MS_EXCEPTION_IF_NULL(node);
-      AnfAlgo::SetStreamDistinctionLabel(label, node.get());
-    }
-  }
-}
-
-GraphId GetDistinctionLabel(const KernelGraphPtr &graph) {
-  MS_EXCEPTION_IF_NULL(graph);
-  // if graph is empty,use graph id as distinction label
-  if (graph->execution_order().empty()) {
-    return graph->graph_id();
+  if (is_override || graph->stream_distinction_label() == kInvalidDistincLabel) {
+    graph->set_stream_distinction_label(label);
   }
-  // else use first node of execution order as label
-  return AnfAlgo::GetStreamDistinctionLabel(graph->execution_order()[0].get());
 }

 std::vector<BaseRef> GetRealArgs(const KernelGraphPtr graph, const VectorRef &args) {
   MS_EXCEPTION_IF_NULL(graph);
   std::vector<AnfNodePtr> graph_inputs = graph->inputs();
-  auto valid_inputs = graph->ValidInputs();
+  auto valid_inputs = graph->valid_inputs();
   size_t real_args_size = 0;
   std::vector<BaseRef> real_args = {};
   for (size_t i = 0; i < args.size(); i++) {
@@ -141,23 +131,9 @@ std::vector<BaseRef> GetRealArgs(const KernelGraphPtr graph, const VectorRef &ar

 GraphId AscendSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) {
   MS_LOG(INFO) << "start";
-  auto graph_id = graph_sum_;
   // construct graph, if successfully, graph_sum_ + 1
   auto graph = ConstructKernelGraph(lst, outputs);
-  MS_EXCEPTION_IF_NULL(graph);
-  opt::AscendBackendIRFusionOptimization(graph);
-  // select kernel build info
-  SelectKernel(*graph);
-  // convert kernel Graph to model
-  predictmodel::StepConvertGraph(graph);
-  // optimize graph
-  HardwareOptimize(graph);
-  // init runtime resource
-  InitRuntimeResource();
-  // assign static memory of parameters
-  auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_);
-  MS_EXCEPTION_IF_NULL(runtime_instance);
-  runtime_instance->AssignStaticMemoryInput(graph.get());
+  auto graph_id = graph->graph_id();
   MS_LOG(INFO) << "Compile graph " << graph_id << " success";
   return graph_id;
 }
@@ -166,16 +142,36 @@ void AscendSession::BuildGraph(GraphId graph_id) {
   MS_LOG(INFO) << "start";
   auto graph = GetGraph(graph_id);
   MS_EXCEPTION_IF_NULL(graph);
+  // resource initialize
+  InitRuntimeResource();
   // multiple graph handle
   if (graph_id == final_graph_id_) {
     if (!graph->executable()) {
       return;
     }
+    // insert assigns to child graph
+    InsertAllAssigns();
+    // insert switch and active to child graph
+    MergeSwitchCompile();
+    // OptChildGraphs
+    auto graph_order = GetGraphOrder(final_graph_id_);
+    auto &graph_type = GetGraphOrderType(final_graph_id_);
+    for (size_t i = 0; i < graph_order.size(); i++) {
+      if (graph_type[i] == BRANCH_END || graph_type[i] == BRANCH_START) {
+        continue;
+      }
+      MS_LOG(INFO) << "Start build child graph " << graph_order[i];
+      auto child_graph = GetGraph(graph_order[i]);
+      CompileChildGraph(child_graph);
+    }
     // merge child graph
     MergeGraphExecOrder();
   } else {
+    auto single_graph = GetGraph(graph_id);
+    CompileChildGraph(single_graph);
     // set the distinction label of single graph
-    SetStreamDistinctionLabel(GetGraph(graph_id), graph_id, false);
+    single_graph->set_stream_distinction_label(graph_id);
+    single_graph->UpdateExecuteKernelStreamLabel();
   }
   // adjust execution order because merge child graph and other special operations
   AdjustKernel(graph);
@@ -197,9 +193,26 @@ void AscendSession::BuildGraph(GraphId graph_id) {
     // load task info to device if it is sink mode
     LoadTask(graph);
   }
+  // sync the initial const tensor to device
+  SyncInitialTensorToDevice();
   MS_LOG(INFO) << "end";
 }

+void AscendSession::CompileChildGraph(const KernelGraphPtr &child_graph) {
+  MS_EXCEPTION_IF_NULL(child_graph);
+  opt::AscendBackendIRFusionOptimization(child_graph);
+  // select kernel build info
+  SelectKernel(*child_graph);
+  // convert kernel Graph to model
+  predictmodel::StepConvertGraph(child_graph);
+  // optimize graph
+  HardwareOptimize(child_graph);
+  // assign static memory of parameters
+  auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_);
+  MS_EXCEPTION_IF_NULL(runtime_instance);
+  runtime_instance->AssignStaticMemoryInput(child_graph.get());
+}
+
 void AscendSession::RunGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs,
                              VectorRef *const outputs) {
   MS_LOG(INFO) << "start";
@@ -458,11 +471,9 @@ void AscendSession::Dump(const std::shared_ptr<KernelGraph> &kernel_graph) const

 GraphId AscendSession::SetFinalGraphInput(const std::vector<AnfNodePtr> &args) {
   MS_LOG(INFO) << "Start! Args size " << args.size();
-  auto final_graph = std::make_shared<KernelGraph>();
-  final_graph_id_ = graph_sum_++;
-  graphs_[final_graph_id_] = final_graph;
-  final_graph->set_graph_id(final_graph_id_);
-  MS_LOG(INFO) << "Create a new final graph" << final_graph_id_ << "success";
+  auto final_graph = NewKernelGraph();
+  final_graph_id_ = final_graph->graph_id();
+  MS_LOG(INFO) << "Create a new final graph " << final_graph_id_ << " success";
   // init private variables and bind them with final_graph_id
   graph_execute_orders_[final_graph_id_] = std::vector<GraphId>();
   graph_order_types_[final_graph_id_] = std::vector<GraphType>();
@@ -498,6 +509,46 @@ GraphId AscendSession::SetFinalGraphInput(const std::vector<AnfNodePtr> &args) {
   return final_graph_id_;
 }

+AnfNodePtr AscendSession::CreateFakeOutput(GraphId fake_graph_id, const AnfNodePtr &true_output) {
+  auto fake_graph = GetGraph(fake_graph_id);
+  auto output_item_with_index = AnfAlgo::VisitKernelWithReturnType(true_output, 0);
+  auto create_parameter = [&](const AbstractBasePtr &abstract) -> AnfNodePtr {
+    auto parameter = fake_graph->NewParameter();
+    MS_EXCEPTION_IF_NULL(parameter);
+    parameter->set_abstract(abstract);
+    auto new_parameter = fake_graph->NewParameter(parameter);
+    // Add new parameter to the graph input of fake_graph to make sure that all parameters will be allocated memory.
+    auto graph_inputs = fake_graph->MutableInputs();
+    MS_EXCEPTION_IF_NULL(graph_inputs);
+    graph_inputs->push_back(new_parameter);
+    return new_parameter;
+  };
+  auto create_parameter_from_cnode = [&](const AnfNodePtr &cnode, size_t output_idx) -> AnfNodePtr {
+    MS_EXCEPTION_IF_NULL(cnode);
+    auto abstract = cnode->abstract();
+    MS_EXCEPTION_IF_NULL(abstract);
+    // create multiple parameters if is a tuple output real kernel
+    if (abstract->isa<abstract::AbstractTuple>()) {
+      auto tuple_abstract = abstract->cast<abstract::AbstractTuplePtr>();
+      MS_EXCEPTION_IF_NULL(tuple_abstract);
+      MS_LOG(INFO) << "tuple_size [" << tuple_abstract->size() << "]";
+      return create_parameter((*tuple_abstract)[output_idx]);
+    }
+    return create_parameter(cnode->abstract());
+  };
+  if (AnfAlgo::CheckPrimitiveType(output_item_with_index.first, prim::kPrimMakeTuple)) {
+    std::vector<AnfNodePtr> make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple)};
+    auto make_tuple = output_item_with_index.first->cast<CNodePtr>();
+    MS_EXCEPTION_IF_NULL(make_tuple);
+    for (size_t i = 1; i < make_tuple->inputs().size(); i++) {
+      auto input = make_tuple->inputs()[i];
+      make_tuple_inputs.push_back(CreateFakeOutput(fake_graph_id, input));
+    }
+    return fake_graph->NewCNode(make_tuple_inputs);
+  }
+  return create_parameter_from_cnode(output_item_with_index.first, output_item_with_index.second);
+}
+
 void AscendSession::SetFinalGraphOutput(const BaseRef &output) {
   auto final_graph = GetGraph(final_graph_id_);
   MS_EXCEPTION_IF_NULL(final_graph);
@@ -559,12 +610,6 @@ void AscendSession::InsertSwitchToGraph(GraphId condition_graph_id, GraphId true
   condition_graph->AddValueNodeToGraph(counter_const);
   // create a new switch op
   auto switch_primitive = std::make_shared<Primitive>("StreamSwitch");
-  auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
-  kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>{kOpFormat_DEFAULT});
-  kernel_build_info_builder->SetOutputsDeviceType(std::vector<TypeId>{kNumberTypeInt32});
-  kernel_build_info_builder->SetFusionType(kernel::FusionType::OPAQUE);
-  kernel_build_info_builder->SetProcessor(kernel::Processor::AICORE);
-  kernel_build_info_builder->SetKernelType(KernelType::RT_KERNEL);
   auto cond_output_it = condition_output_.find(condition_graph_id);
   if (cond_output_it == condition_output_.end()) {
MS_LOG(EXCEPTION) << "Can't find condition graph" << condition_graph_id; @@ -574,11 +619,9 @@ void AscendSession::InsertSwitchToGraph(GraphId condition_graph_id, GraphId true MS_EXCEPTION_IF_NULL(cond_output_kernel); std::vector inputs = {NewValueNode(switch_primitive), cond_output_kernel, counter_const}; CNodePtr switch_node = condition_graph->NewCNode(inputs); - AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), switch_node.get()); MS_EXCEPTION_IF_NULL(switch_node); switch_node->set_abstract(std::make_shared()); AnfAlgo::SetGraphId(condition_graph_id, switch_node.get()); - AnfAlgo::SetStreamDistinctionLabel(GetDistinctionLabel(GetGraph(condition_graph_id)), switch_node.get()); // set attr: cond_ RT_GREATER AnfAlgo::SetNodeAttr(kAttrSwitchCondition, MakeValue(static_cast(RT_GREATER)), switch_node); // set attr:data_type @@ -586,9 +629,9 @@ void AscendSession::InsertSwitchToGraph(GraphId condition_graph_id, GraphId true // set attr:true branch graph id ,which is same to stream distinction label AnfAlgo::SetNodeAttr(kAttrTrueBranchStream, MakeValue(true_graph_id), switch_node); // append switch at the end of condition graph - std::vector exec_order = condition_graph->execution_order(); - exec_order.push_back(switch_node); - condition_graph->set_execution_order(exec_order); + auto return_node = condition_graph->get_return(); + MS_EXCEPTION_IF_NULL(return_node); + InsertControlDependToGraph(condition_graph_id, return_node->input(1), switch_node); MS_LOG(INFO) << "Finish!"; } @@ -615,8 +658,14 @@ void AscendSession::CopyOutputOfIf(GraphId false_graph_id) { MS_EXCEPTION_IF_NULL(true_last); MS_EXCEPTION_IF_NULL(false_last); MS_LOG(INFO) << "The last graph of false branch is " << false_last_id; - // now only consider the single output - InsertMultipleAssignToGraph(true_last_id, true_last->output(), false_last->output()); + // create fake output + auto fake_output_graph = NewKernelGraph(); + graph_execute_order.push_back(fake_output_graph->graph_id()); + graph_order_type.push_back(COMMON_GRAPH); + fake_output_graph->set_output(CreateFakeOutput(fake_output_graph->graph_id(), final_graph->output())); + final_graph->set_output(fake_output_graph->output()); + InsertMultipleAssignToGraph(true_last_id, true_last->output(), final_graph->output()); + InsertMultipleAssignToGraph(false_last_id, false_last->output(), final_graph->output()); // insert stream active for loop sink auto context_ptr = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(context_ptr); @@ -650,14 +699,14 @@ void AscendSession::SwitchCompile(GraphId cond_graph_id, GraphId true_graph_id, if (false_graph_id != kInvalidGraphId) { // false graph and condition in graph same stream auto condition_graph = GetGraph(cond_graph_id); - SetStreamDistinctionLabel(GetGraph(false_graph_id), GetDistinctionLabel(condition_graph), true); + SetStreamDistinctionLabel(GetGraph(false_graph_id), condition_graph->stream_distinction_label(), true); // if false graph is a condition graph and has been switch compiled before,it's false should be updated again auto cond_it = switches_.find(false_graph_id); while (cond_it != switches_.end() && cond_it->second.second != kInvalidGraphId) { cond_graph_id = cond_it->first; false_graph_id = cond_it->second.second; condition_graph = GetGraph(cond_graph_id); - SetStreamDistinctionLabel(GetGraph(false_graph_id), GetDistinctionLabel(condition_graph), true); + SetStreamDistinctionLabel(GetGraph(false_graph_id), condition_graph->stream_distinction_label(), true); cond_it = switches_.find(false_graph_id); } } 
@@ -691,7 +740,7 @@ void AscendSession::MergeSwitchCompile() {
     }
     // insert stream active to common graph
     if (prev_graph_id != kInvalidGraphId) {
-      InsertStreamActiveToGraph(prev_graph_id, GetDistinctionLabel(condition_graph));
+      InsertStreamActiveToGraph(prev_graph_id, condition_graph->stream_distinction_label());
     }
     // if this is a 'if' condition
     auto it = while_condition_graphs_.find(cond_graph_id);
@@ -700,12 +749,39 @@ void AscendSession::MergeSwitchCompile() {
     } else {
      // if it is a while,insert a stream active to true graph
      GraphId from_graph = it->second;
-     InsertStreamActiveToGraph(from_graph, GetDistinctionLabel(condition_graph));
+     InsertStreamActiveToGraph(from_graph, condition_graph->stream_distinction_label());
    }
  }
  MS_LOG(INFO) << "Finish!";
}

+void AscendSession::InsertAllAssigns() {
+  std::set<std::pair<AnfNodePtr, AnfNodePtr>> assigns;
+  for (auto assign : assigns_) {
+    auto front_anf = std::get<0>(assign);
+    auto to_graph_id = std::get<1>(assign);
+    auto input_idx = std::get<2>(assign);
+    auto to_graph = GetGraph(to_graph_id);
+    MS_EXCEPTION_IF_NULL(to_graph);
+    std::vector<AnfNodePtr> graph_inputs = to_graph->inputs();
+    if (input_idx >= graph_inputs.size()) {
+      MS_LOG(EXCEPTION) << "input_index " << input_idx << " out of range size " << graph_inputs.size();
+    }
+    auto backend_parameter = graph_inputs[input_idx];
+    (void)assigns.insert(std::pair<AnfNodePtr, AnfNodePtr>(front_anf, backend_parameter));
+  }
+  // erase the repeated assigns
+  for (auto &assign : assigns) {
+    auto front_anf = assign.first;
+    auto backend_parameter = assign.second;
+    auto from_graph_id = GetGraphIdByNode(front_anf);
+    auto from_graph = GetGraph(from_graph_id);
+    MS_EXCEPTION_IF_NULL(from_graph);
+    auto backend_arg = from_graph->GetBackendAnfByFrontAnf(front_anf);
+    InsertAssignToGraph(from_graph_id, backend_arg, backend_parameter);
+  }
+}
+
 // insert active to graph
 void AscendSession::SetActive(GraphId from, GraphId to) {
   if (while_condition_graphs_.find(to) != while_condition_graphs_.end()) {
@@ -735,20 +811,21 @@ void AscendSession::SetActive(GraphId from, GraphId to) {
   while_condition_graphs_[to] = from;
 }

-void AscendSession::SetChildGraphParameter(const AnfNodePtr &front_anf, const AnfNodePtr &backend_parameter) {
+void AscendSession::SetChildGraphParameter(const AnfNodePtr &front_anf, GraphId to_graph_id, size_t input_idx) {
   MS_LOG(INFO) << "Start!";
-  MS_EXCEPTION_IF_NULL(backend_parameter);
   MS_EXCEPTION_IF_NULL(front_anf);
-  if (!backend_parameter->isa<Parameter>()) {
-    MS_LOG(EXCEPTION) << "Backend parameter's type is not a parameter,but is " << backend_parameter->ToString();
-  }
   auto from_graph_id = GetGraphIdByNode(front_anf);
   auto from_graph = GetGraph(from_graph_id);
   MS_EXCEPTION_IF_NULL(from_graph);
-  auto to_graph_id = AnfAlgo::GetGraphId(backend_parameter.get());
   auto to_graph = GetGraph(to_graph_id);
-  auto backend_arg = from_graph->GetBackendAnfByFrontAnf(front_anf);
   MS_EXCEPTION_IF_NULL(to_graph);
+  std::vector<AnfNodePtr> graph_inputs = to_graph->inputs();
+  if (input_idx >= graph_inputs.size()) {
+    MS_LOG(EXCEPTION) << "input_index " << input_idx << " out of range size " << graph_inputs.size();
+  }
+  auto backend_parameter = graph_inputs[input_idx];
+  MS_EXCEPTION_IF_NULL(backend_parameter);
+  auto backend_arg = from_graph->GetBackendAnfByFrontAnf(front_anf);
   MS_LOG(INFO) << "Set node[" << front_anf->DebugString() << "] of graph[" << from_graph_id << "]to node["
                << backend_parameter->DebugString() << "] of graph[" << AnfAlgo::GetGraphId(backend_parameter.get())
                << "]";
@@ -759,39 +836,21 @@ void AscendSession::SetChildGraphParameter(const AnfNodePtr &front_anf, const AnfNodePtr &backend_parameter) {
   // if arg is the parameter of child graph,it is parameter of final graph too
   if (front_anf->isa<Parameter>()) {
     MS_EXCEPTION_IF_NULL(backend_arg);
-    if (!AnfAlgo::OutputAddrExist(backend_arg, 0)) {
-      // set parameter's addr in child graph to parameter in final graph
-      AnfAlgo::SetOutputAddr(AnfAlgo::GetMutableOutputAddr(backend_parameter, 0), 0, backend_arg.get());
-      MS_LOG(INFO) << "Assign mem of node" << backend_parameter->DebugString() << " of graph "
-                   << AnfAlgo::GetGraphId(backend_parameter.get()) << " to node" << backend_arg->DebugString()
-                   << "of graph " << AnfAlgo::GetGraphId(backend_arg.get());
-      return;
-    }
-    // if a parameter is a weight and not linked to any executable node,device type will be kTypeUnknown,set it's device
-    // type same to arg
-    if (AnfAlgo::GetOutputDeviceDataType(backend_parameter, 0) == kTypeUnknown) {
-      AnfAlgo::SetSelectKernelBuildInfo(AnfAlgo::GetSelectKernelBuildInfo(backend_arg), backend_parameter.get());
-    }
-    // if front anf is a parameter,we can assign the value back,because backend_parameter won't be change in it's graph
-    // unless it's a weight.If backend_parameter is a weight,we should assign the value back.
-    AnfAlgo::SetOutputAddr(AnfAlgo::GetMutableOutputAddr(backend_arg, 0), 0, backend_parameter.get());
+    MS_LOG(INFO) << "Reuse node [" << backend_arg->DebugString() << "], old node[" << backend_parameter->DebugString()
+                 << "] will be replaced.";
+    to_graph->ReplaceNode(backend_parameter, backend_arg);
     return;
   }
-  InsertAssignToGraph(from_graph_id, backend_arg, backend_parameter);
-  MS_LOG(INFO) << "Finish!";
+  MS_LOG(INFO) << "Assign of node " << backend_arg->DebugString() << " of graph " << from_graph_id << " to node "
+               << backend_parameter->DebugString() << " of graph " << to_graph_id;
+  (void)assigns_.insert(std::tuple<AnfNodePtr, GraphId, size_t>(front_anf, to_graph_id, input_idx));
 }

-void AscendSession::SetChildGraphParameter(const tensor::TensorPtr &front_tensor, const AnfNodePtr &backend_parameter) {
+void AscendSession::SetChildGraphParameter(const tensor::TensorPtr &front_tensor, GraphId to_graph_id,
+                                           size_t input_idx) {
   MS_LOG(INFO) << "Start!";
-  // sync data from host to device
-  MS_EXCEPTION_IF_NULL(front_tensor);
-  size_t tensor_size = front_tensor->data().nbytes();
-  auto addr = AnfAlgo::GetOutputAddr(backend_parameter, 0);
-  MS_EXCEPTION_IF_NULL(addr);
-  if (!addr->SyncHostToDevice(trans::GetRuntimePaddingShape(backend_parameter, 0), tensor_size,
-                              front_tensor->data_type(), front_tensor->data_c(false))) {
-    MS_LOG(EXCEPTION) << "Tensor SyncHostToDevice fail!";
-  }
+  std::pair<GraphId, size_t> graph_input_pair(to_graph_id, input_idx);
+  initial_tensors_[graph_input_pair] = front_tensor;
   MS_LOG(INFO) << "Finish!";
 }
@@ -818,10 +877,9 @@ size_t AscendSession::SetChildGraphInput(const KernelGraphPtr &graph, const AnfN
   if (output_num > 1 && !AnfAlgo::CheckPrimitiveType(node, prim::kPrimTupleGetItem)) {
     return input_index + output_num;
   }
-  auto &graph_inputs = graph->inputs();
-  auto &valid_inputs = graph->ValidInputs();
+  auto valid_inputs = graph->valid_inputs();
   if (valid_inputs[input_index]) {
-    SetChildGraphParameter(node, graph_inputs[input_index]);
+    SetChildGraphParameter(node, graph->graph_id(), input_index);
   } else {
     MS_LOG(DEBUG) << "Invalid input arg: " << node->DebugString();
   }
@@ -833,8 +891,7 @@ size_t AscendSession::SetChildGraphInput(const KernelGraphPtr &graph, const Valu
   if (!value->isa<tensor::Tensor>()) {
     MS_LOG(EXCEPTION) << "Value Node should be a tensor, unexpected value: " << value->ToString();
   }
-  auto &graph_inputs = graph->inputs();
-  SetChildGraphParameter(value->cast<tensor::TensorPtr>(), graph_inputs[input_index]);
+  SetChildGraphParameter(value->cast<tensor::TensorPtr>(), graph->graph_id(), input_index);
   return ++input_index;
 }
@@ -905,8 +962,6 @@ GraphId AscendSession::GetGraphIdByNode(const AnfNodePtr &front_anf) const {

 void AscendSession::MergeGraphExecOrder() {
   MS_LOG(INFO) << "Start!";
-  // insert switch to graph
-  MergeSwitchCompile();
   // merge graph order
   auto &graph_order = GetGraphOrder(final_graph_id_);
   auto &graph_type = GetGraphOrderType(final_graph_id_);
@@ -916,6 +971,13 @@ void AscendSession::MergeGraphExecOrder() {
     MS_LOG(WARNING) << "Graph output is a lonely variable not linked to any op!";
     return;
   }
+  if (graph_order.size() > 1) {
+    auto context_ptr = MsContext::GetInstance();
+    MS_EXCEPTION_IF_NULL(context_ptr);
+    if (!context_ptr->enable_task_sink()) {
+      MS_LOG(INFO) << "Control sink network should run in task-sink mode!";
+    }
+  }
   // if first graph is common,the final graph has no label,then set the stream of final graph same with the first graph
   SetStreamDistinctionLabel(final_graph, graph_order[0], false);
   std::vector<CNodePtr> final_exec_order = final_graph->execution_order();
@@ -930,7 +992,11 @@ void AscendSession::MergeGraphExecOrder() {
     MS_EXCEPTION_IF_NULL(child_graph);
     auto exec_order = child_graph->execution_order();
     MS_LOG(INFO) << "Merge graph,graph_id " << graph_id;
-    (void)std::copy(exec_order.begin(), exec_order.end(), std::back_inserter(final_exec_order));
+    (void)std::transform(exec_order.begin(), exec_order.end(), std::back_inserter(final_exec_order),
+                         [&](CNodePtr node) -> CNodePtr {
+                           AnfAlgo::SetStreamDistinctionLabel(child_graph->stream_distinction_label(), node.get());
+                           return node;
+                         });
     // add all value nodes of child graphs to final graph
     for (auto &value_node : child_graph->graph_value_nodes()) {
       final_graph->AddValueNodeToGraph(value_node);
@@ -969,15 +1035,9 @@ void AscendSession::InsertAssignToGraph(GraphId graph_id, const AnfNodePtr &from
   // generate a new cnode
   auto assign_node = graph->NewCNode(inputs);
   MS_EXCEPTION_IF_NULL(assign_node);
-  assign_node->set_abstract(std::make_shared<abstract::AbstractNone>());
-  auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
-  kernel_build_info_builder->SetKernelType(KernelType::RT_KERNEL);
-  AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), assign_node.get());
-  AnfAlgo::SetStreamDistinctionLabel(GetDistinctionLabel(graph), assign_node.get());
+  assign_node->set_abstract(to->abstract());
   // append the assign at the end of from graph
-  auto exec_order = graph->execution_order();
-  exec_order.push_back(assign_node);
-  graph->set_execution_order(exec_order);
+  InsertDependToGraph(graph_id, assign_node);
 }

 void AscendSession::InsertMultipleAssignToGraph(GraphId graph_id, const AnfNodePtr &from, const AnfNodePtr &to) {
@@ -997,24 +1057,46 @@ void AscendSession::InsertStreamActiveToGraph(GraphId graph_id, uint32_t actived
   MS_LOG(INFO) << "Insert stream_active from " << graph_id << " to " << actived_stream;
-  auto from_graph = graphs_[graph_id];
+  auto from_graph = GetGraph(graph_id);
   MS_EXCEPTION_IF_NULL(from_graph);
   std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>("StreamActive"))};
   auto active_node = from_graph->NewCNode(inputs);
   MS_EXCEPTION_IF_NULL(active_node);
   active_node->set_abstract(std::make_shared<abstract::AbstractNone>());
-  auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
-  kernel_build_info_builder->SetKernelType(KernelType::RT_KERNEL);
-  AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), active_node.get());
   // set the active stream id into the attr of active node
   std::vector<uint32_t> active_index_value = {};
   active_index_value.push_back(actived_stream);
   AnfAlgo::SetNodeAttr(kAttrActiveStreamList, MakeValue<std::vector<uint32_t>>(active_index_value), active_node);
-  AnfAlgo::SetStreamDistinctionLabel(GetDistinctionLabel(from_graph), active_node.get());
   // append the active node at the end of from graph
-  auto exec_order = from_graph->execution_order();
-  exec_order.push_back(active_node);
-  from_graph->set_execution_order(exec_order);
+  auto return_node = from_graph->get_return();
+  MS_EXCEPTION_IF_NULL(return_node);
+  InsertControlDependToGraph(graph_id, return_node->input(1), active_node);
+}
+
+void AscendSession::InsertDependToGraph(GraphId graph_id, const AnfNodePtr &attach_node) {
+  MS_LOG(INFO) << "Insert depend at the end of graph, the attach node is " << attach_node->DebugString();
+  auto graph = GetGraph(graph_id);
+  MS_EXCEPTION_IF_NULL(graph);
+  std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>("depend"))};
+  auto return_node = graph->get_return();
+  MS_EXCEPTION_IF_NULL(return_node);
+  inputs.push_back(return_node->input(1));
+  inputs.push_back(attach_node);
+  auto depend_node = graph->NewCNode(inputs);
+  return_node->set_input(1, depend_node);
+}
+
+void AscendSession::InsertControlDependToGraph(GraphId graph_id, const AnfNodePtr &first_node,
+                                               const AnfNodePtr &second_node) {
+  MS_LOG(INFO) << "Insert control depend at the end of graph, the first node is " << first_node->DebugString()
+               << ", the second node is " << second_node->DebugString();
+  auto graph = GetGraph(graph_id);
+  MS_EXCEPTION_IF_NULL(graph);
+  std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>("ControlDepend"))};
+  inputs.push_back(first_node);
+  inputs.push_back(second_node);
+  auto control_depend = graph->NewCNode(inputs);
+  InsertDependToGraph(graph_id, control_depend);
 }

 size_t AscendSession::ExecOrderOfChildGraph(GraphId final_graph, GraphId child_graph) {
@@ -1043,5 +1125,29 @@ std::vector<GraphType> &AscendSession::GetGraphOrderType(GraphId final_graph_id)
   }
   return graph_type_iter->second;
 }
+
+void AscendSession::SyncInitialTensorToDevice() {
+  for (auto &item : initial_tensors_) {
+    auto to_graph_id = item.first.first;
+    auto input_idx = item.first.second;
+    auto front_tensor = item.second;
+    auto to_graph = GetGraph(to_graph_id);
+    MS_EXCEPTION_IF_NULL(to_graph);
+    std::vector<AnfNodePtr> graph_inputs = to_graph->inputs();
+    if (input_idx >= graph_inputs.size()) {
+      MS_LOG(EXCEPTION) << "input_index " << input_idx << " out of range size " << graph_inputs.size();
+    }
+    auto backend_parameter = graph_inputs[input_idx];
+    // sync data from host to device
+    MS_EXCEPTION_IF_NULL(front_tensor);
+    size_t tensor_size = front_tensor->data().nbytes();
+    auto addr = AnfAlgo::GetOutputAddr(backend_parameter, 0);
+    MS_EXCEPTION_IF_NULL(addr);
+    if (!addr->SyncHostToDevice(trans::GetRuntimePaddingShape(backend_parameter, 0), tensor_size,
+                                front_tensor->data_type(), front_tensor->data_c(false))) {
+      MS_LOG(EXCEPTION) << "Tensor SyncHostToDevice fail!";
+    }
+  }
+}
 }  // namespace session
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/session/ascend_session.h b/mindspore/ccsrc/session/ascend_session.h
index 1ce236c9c3..635003d97c 100755
--- a/mindspore/ccsrc/session/ascend_session.h
+++ b/mindspore/ccsrc/session/ascend_session.h
@@ -21,6 +21,9 @@
 #include
 #include
 #include
+#include
+#include
+#include
 #include "session/session_basic.h"
 #include "session/kernel_graph.h"
#include "kernel/kernel.h" @@ -60,6 +63,8 @@ class AscendSession : public SessionBasic { GraphId GetFinalRunGraph() const override { return final_graph_id_; } // insert active to graph void SetActive(GraphId, GraphId) override; + // compile child graph when session have multiple child graphs + void CompileChildGraph(const KernelGraphPtr &child_graph); private: void InitRuntimeResource(); @@ -95,12 +100,16 @@ class AscendSession : public SessionBasic { size_t ExecOrderOfChildGraph(GraphId final_graph, GraphId child_graph); // handle condition graph from vm void InsertSwitchToGraph(GraphId condition_graph_id, GraphId true_graph_id); + // insert depend to graph, used to attch control nodes to graph + void InsertDependToGraph(GraphId graph_id, const AnfNodePtr &attch_node); + // insert depend to graph, used to attch control nodes to graph + void InsertControlDependToGraph(GraphId graph_id, const AnfNodePtr &first_node, const AnfNodePtr &second_node); // Get graph by graph id ,if not exist return null ptr KernelGraphPtr GetGraph(GraphId graph_id); // set child graph parameter if front arg is a anf - void SetChildGraphParameter(const AnfNodePtr &front_anf, const AnfNodePtr &backend_parameter); + void SetChildGraphParameter(const AnfNodePtr &front_anf, GraphId to_graph_id, size_t input_idx); // set child graph parameter if front arg is a tensor - void SetChildGraphParameter(const tensor::TensorPtr &front_tensor, const AnfNodePtr &backend_parameter); + void SetChildGraphParameter(const tensor::TensorPtr &front_tensor, GraphId to_graph_id, size_t input_idx); // update the execution order of all child graphs void UpdateGraphOrder(GraphId to_graph); // handle switch when merge @@ -113,6 +122,12 @@ class AscendSession : public SessionBasic { void CopyOutputOfIf(GraphId false_graph_id); // check if graph cache exist bool GraphCacheExist(const GraphInfo &graph_info) const; + // insert all assign to child graph + void InsertAllAssigns(); + // create fake output of final graph + AnfNodePtr CreateFakeOutput(GraphId final_graph_id, const AnfNodePtr &true_output); + // sync intial tensors' data to device + void SyncInitialTenosrToDevice(); // member variables // key is final_graph_id,value is child graph execute order of final graph @@ -124,6 +139,10 @@ class AscendSession : public SessionBasic { // record all conditions std::unordered_map> switches_; std::unordered_map condition_output_; + // share parameters + std::set> assigns_; + // initial tensors, these tensor will sync data to device before run graph + std::map, tensor::TensorPtr> initial_tenosrs_; // final_graph_id is used in every root graph has it's own session situation GraphId final_graph_id_; }; diff --git a/mindspore/ccsrc/session/kernel_graph.cc b/mindspore/ccsrc/session/kernel_graph.cc index cdadf389a6..95ac38c405 100755 --- a/mindspore/ccsrc/session/kernel_graph.cc +++ b/mindspore/ccsrc/session/kernel_graph.cc @@ -295,10 +295,7 @@ ValueNodePtr KernelGraph::NewValueNode(const ValueNodePtr &value_node) { // set the format of value_node to DEFAULT_FORMAT kernel_build_info_builder->SetOutputsFormat(std::vector{kOpFormat_DEFAULT}); // set value node initial device data type = infer data type - std::vector types; - for (size_t index = 0; index < AnfAlgo::GetOutputTensorNum(value_node); ++index) { - types.push_back(kTypeUnknown); - } + std::vector types = std::vector(AnfAlgo::GetOutputTensorNum(value_node), kTypeUnknown); kernel_build_info_builder->SetOutputsDeviceType(types); AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), 
   AnfAlgo::SetGraphId(graph_id_, new_value_node.get());
@@ -330,10 +327,11 @@ void KernelGraph::FrontBackendlMapUpdate(const AnfNodePtr &old_backend_anf, cons
     MS_LOG(EXCEPTION) << "old can't be same with new";
   }
   if (backend_front_anf_map_.find(old_backend_anf) == backend_front_anf_map_.end()) {
-    MS_LOG(EXCEPTION) << "old_backend_anf " << old_backend_anf->DebugString() << " is not exist in the map";
+    MS_LOG(DEBUG) << "old_backend_anf " << old_backend_anf->DebugString() << " does not exist in the map";
+    return;
   }
   if (front_backend_anf_map_.find(backend_front_anf_map_[old_backend_anf]) == front_backend_anf_map_.end()) {
-    MS_LOG(EXCEPTION) << "anf is not exist in the mape ,old " << old_backend_anf->DebugString();
+    MS_LOG(EXCEPTION) << "anf does not exist in the map, old " << old_backend_anf->DebugString();
   }
   front_backend_anf_map_[backend_front_anf_map_[old_backend_anf]] = new_backend_anf;
   backend_front_anf_map_[new_backend_anf] = backend_front_anf_map_[old_backend_anf];
@@ -528,5 +526,44 @@ bool KernelGraph::RemoveValueNodeFromGraph(const ValueNodePtr &value_node) {
   }
   return false;
 }
+
+void KernelGraph::ReplaceNode(const AnfNodePtr &old_anf_node, AnfNodePtr new_anf_node) {
+  MS_EXCEPTION_IF_NULL(old_anf_node);
+  MS_EXCEPTION_IF_NULL(new_anf_node);
+  MS_EXCEPTION_IF_NULL(inputs_);
+  auto it = node_output_edges_.find(old_anf_node);
+  if (it == node_output_edges_.end()) {
+    MS_LOG(EXCEPTION) << "Can't find anf node in node_output_edges map";
+  }
+  auto &outputs = it->second;
+  for (auto &output_node : outputs) {
+    auto output_cnode = output_node.first->cast<CNodePtr>();
+    MS_EXCEPTION_IF_NULL(output_cnode);
+    auto &output_node_inputs = output_cnode->inputs();
+    for (size_t i = 1; i < output_node_inputs.size(); i++) {
+      if (output_node_inputs[i] == old_anf_node) {
+        output_cnode->set_input(i, new_anf_node);
+      }
+    }
+    // update graph inputs
+    for (size_t i = 0; i < inputs_->size(); i++) {
+      if ((*inputs_)[i] == old_anf_node) {
+        (*inputs_)[i] = new_anf_node;
+        break;
+      }
+    }
+  }
+  // update front to backend map
+  FrontBackendlMapUpdate(old_anf_node, new_anf_node);
+  // update output depend relations
+  node_output_edges_[new_anf_node] = it->second;
+  (void)node_output_edges_.erase(old_anf_node);
+}
+
+void KernelGraph::UpdateExecuteKernelStreamLabel() {
+  for (auto &kernel : execution_order_) {
+    AnfAlgo::SetStreamDistinctionLabel(stream_distinction_label_, kernel.get());
+  }
+}
 }  // namespace session
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/session/kernel_graph.h b/mindspore/ccsrc/session/kernel_graph.h
index 8cafcc2ebc..3425bde9c2 100755
--- a/mindspore/ccsrc/session/kernel_graph.h
+++ b/mindspore/ccsrc/session/kernel_graph.h
@@ -27,6 +27,7 @@
 #include "ir/func_graph.h"
 #include "ir/anf.h"
 #include "utils/graph_utils.h"
+#include "device/kernel_info.h"

 namespace mindspore {
 namespace session {
@@ -37,6 +38,7 @@ class KernelGraph : public FuncGraph {
     inputs_ = std::make_shared<std::vector<AnfNodePtr>>();
     execution_order_ = {};
     executable_ = true;
+    stream_distinction_label_ = kInvalidDistincLabel;
   }
   ~KernelGraph() override = default;
@@ -88,7 +90,15 @@ class KernelGraph : public FuncGraph {
   void set_executable(bool executable) { executable_ = executable; }
   // set invalid inputs for control sink
   std::vector<bool> *MutableValidInputs() { return &valid_inputs_; }
-  const std::vector<bool> &ValidInputs() const { return valid_inputs_; }
+  std::vector<bool> valid_inputs() const { return valid_inputs_; }
+  // replace node in graph
+  void ReplaceNode(const AnfNodePtr &old_anf_node, AnfNodePtr new_anf_node);
+  // set stream label of graph
+  void set_stream_distinction_label(uint32_t stream_label) { stream_distinction_label_ = stream_label; }
+  // get stream label of graph
+  uint32_t stream_distinction_label() { return stream_distinction_label_; }
+  // refresh execute kernel stream label
+  void UpdateExecuteKernelStreamLabel();

 private:
  // remove value node from graph
@@ -108,6 +118,7 @@ class KernelGraph : public FuncGraph {
   std::shared_ptr<std::vector<AnfNodePtr>> inputs_;
   std::vector<CNodePtr> execution_order_;
   uint32_t graph_id_;
+  uint32_t stream_distinction_label_;
   // record map between front anf and backend anf,use two maps to implement a bidirectional map
   std::unordered_map<AnfNodePtr, AnfNodePtr> front_backend_anf_map_;
diff --git a/mindspore/ccsrc/session/session_basic.cc b/mindspore/ccsrc/session/session_basic.cc
index 3436d68b81..5404ad6911 100755
--- a/mindspore/ccsrc/session/session_basic.cc
+++ b/mindspore/ccsrc/session/session_basic.cc
@@ -417,9 +417,8 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, bool valid_input, K

 KernelGraphPtr SessionBasic::ConstructKernelGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) {
   std::unordered_map<AnfNodePtr, AnfNodePtr> other_graph_cnode;
-  auto graph = std::make_shared<KernelGraph>();
-  graph->set_graph_id(graph_sum_);
-  MS_LOG(INFO) << "Create graph: " << graph_sum_;
+  auto graph = NewKernelGraph();
+  MS_LOG(INFO) << "Create graph: " << graph->graph_id();
   size_t from_other_graph_depend_num = 0;
   for (const auto &node : lst) {
     MS_EXCEPTION_IF_NULL(node);
@@ -456,7 +455,6 @@ KernelGraphPtr SessionBasic::ConstructKernelGraph(const AnfNodePtrList &lst, con
   }
   graph->SetExecOrderByDefault();
   opt::BackendCommonOptimization(graph);
-  graphs_[graph_sum_++] = graph;
   return graph;
 }
@@ -588,14 +586,14 @@ void SessionBasic::Summary(KernelGraph *graph) {
 CNodePtr SessionBasic::ConstructOutput(const AnfNodePtrList &outputs, const std::shared_ptr<KernelGraph> &graph) {
   MS_EXCEPTION_IF_NULL(graph);
   std::vector<AnfNodePtr> output_args;
+  for (const auto &output : outputs) {
+    MS_LOG(INFO) << "output:" << output->DebugString();
+  }
   auto FindEqu = [graph, outputs](const AnfNodePtr &out) -> AnfNodePtr {
     auto backend_anf = graph->GetBackendAnfByFrontAnf(out);
     if (backend_anf != nullptr) {
       return backend_anf;
     }
-    for (const auto &output : outputs) {
-      MS_LOG(INFO) << "output:" << output->DebugString();
-    }
     MS_LOG(EXCEPTION) << "Can't find the node in the equiv map!";
   };
   output_args.push_back(NewValueNode(prim::kPrimMakeTuple));
@@ -695,5 +693,12 @@ BaseRef SessionBasic::TransformBaseRefListToTuple(const BaseRef &base_ref) {
     MS_LOG(EXCEPTION) << "The output is not a base ref list or a tensor!";
   }
 }
+
+KernelGraphPtr SessionBasic::NewKernelGraph() {
+  auto graph = std::make_shared<KernelGraph>();
+  graph->set_graph_id(graph_sum_);
+  graphs_[graph_sum_++] = graph;
+  return graph;
+}
 }  // namespace session
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/session/session_basic.h b/mindspore/ccsrc/session/session_basic.h
index 0fd0003cc9..de443833d6 100755
--- a/mindspore/ccsrc/session/session_basic.h
+++ b/mindspore/ccsrc/session/session_basic.h
@@ -104,6 +104,8 @@ class SessionBasic {
                                  const std::vector &tensors_mask);
   // trans BaseRef list to py::tuple
   BaseRef TransformBaseRefListToTuple(const BaseRef &base_ref);
+  // create a new kernel graph and update the graph sum
+  KernelGraphPtr NewKernelGraph();

   std::unordered_map<GraphId, std::shared_ptr<KernelGraph>> graphs_;
   std::unordered_map<GraphInfo, std::shared_ptr<KernelGraph>> run_op_graphs_;
diff --git a/mindspore/ops/_op_impl/tbe/assign.py b/mindspore/ops/_op_impl/tbe/assign.py
index 41a9a0fecd..2fbd152c78 100644
--- a/mindspore/ops/_op_impl/tbe/assign.py
+++ b/mindspore/ops/_op_impl/tbe/assign.py
@@ -27,6 +27,7 @@ assign_op_info = TBERegOp("Assign") \
     .input(1, "value", False, "required", "all") \
     .output(0, "y", False, "required", "all") \
     .dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.I8_Default) \
+    .dtype_format(DataType.BOOL_Default, DataType.BOOL_Default, DataType.BOOL_Default) \
     .dtype_format(DataType.I8_5HD, DataType.I8_5HD, DataType.I8_5HD) \
     .dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.U8_Default) \
     .dtype_format(DataType.U8_5HD, DataType.U8_5HD, DataType.U8_5HD) \