diff --git a/mindspore/ccsrc/device/ascend/ascend_device_address.h b/mindspore/ccsrc/device/ascend/ascend_device_address.h
index 93746082c1..02b708e444 100644
--- a/mindspore/ccsrc/device/ascend/ascend_device_address.h
+++ b/mindspore/ccsrc/device/ascend/ascend_device_address.h
@@ -35,6 +35,7 @@ class AscendDeviceAddress : public DeviceAddress {
   ~AscendDeviceAddress() override;
   bool SyncDeviceToHost(const std::vector<int> &shape, size_t size, TypeId type, void *host_ptr) const override;
   bool SyncHostToDevice(const std::vector<int> &shape, size_t size, TypeId type, const void *host_ptr) const override;
+  DeviceAddressType DeviceType() const override { return DeviceAddressType::kAscend; }
 #ifdef ENABLE_DUMP_E2E
   bool DumpMemToFile(bool dump_mode, const std::string &filepath, const std::string &host_fmt,
                      const std::vector<int> &host_shape, TypeId host_type) const;
diff --git a/mindspore/ccsrc/device/ascend/ascend_kernel_runtime.cc b/mindspore/ccsrc/device/ascend/ascend_kernel_runtime.cc
index 6ffa835204..b043f91a35 100644
--- a/mindspore/ccsrc/device/ascend/ascend_kernel_runtime.cc
+++ b/mindspore/ccsrc/device/ascend/ascend_kernel_runtime.cc
@@ -259,6 +259,15 @@ bool AscendKernelRuntime::DumpData(mindspore::session::KernelGraph *graph) {
   return true;
 }
 
+bool AscendKernelRuntime::NodeOutputDeviceAddressExist(const AnfNodePtr &kernel, size_t index) {
+  if (AnfAlgo::OutputAddrExist(kernel, index)) {
+    auto address = AnfAlgo::GetOutputAddr(kernel, index);
+    MS_EXCEPTION_IF_NULL(address);
+    return address->DeviceType() == DeviceAddressType::kAscend;
+  }
+  return false;
+}
+
 DeviceAddressPtr AscendKernelRuntime::CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format,
                                                           TypeId type_id) {
   return std::make_shared<AscendDeviceAddress>(device_ptr, device_size, format, type_id);
diff --git a/mindspore/ccsrc/device/ascend/ascend_kernel_runtime.h b/mindspore/ccsrc/device/ascend/ascend_kernel_runtime.h
index 336cfdc9f2..20526f66dc 100644
--- a/mindspore/ccsrc/device/ascend/ascend_kernel_runtime.h
+++ b/mindspore/ccsrc/device/ascend/ascend_kernel_runtime.h
@@ -45,6 +45,7 @@ class AscendKernelRuntime : public KernelRuntime {
  protected:
   DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format,
                                        TypeId type_id) override;
+  bool NodeOutputDeviceAddressExist(const AnfNodePtr &node, size_t index) override;
   bool SyncStream() override;
 
  private:
diff --git a/mindspore/ccsrc/device/cpu/cpu_device_address.h b/mindspore/ccsrc/device/cpu/cpu_device_address.h
index 9d51abe625..a041567f47 100644
--- a/mindspore/ccsrc/device/cpu/cpu_device_address.h
+++ b/mindspore/ccsrc/device/cpu/cpu_device_address.h
@@ -34,6 +34,7 @@ class CPUDeviceAddress : public DeviceAddress {
 
   bool SyncDeviceToHost(const std::vector<int> &shape, size_t size, TypeId type, void *host_ptr) const override;
   bool SyncHostToDevice(const std::vector<int> &shape, size_t size, TypeId type, const void *host_ptr) const override;
+  DeviceAddressType DeviceType() const override { return DeviceAddressType::kCPU; }
 };
 }  // namespace cpu
 }  // namespace device
diff --git a/mindspore/ccsrc/device/device_address.h b/mindspore/ccsrc/device/device_address.h
index fd3188e0f2..e02d231dd5 100644
--- a/mindspore/ccsrc/device/device_address.h
+++ b/mindspore/ccsrc/device/device_address.h
@@ -48,6 +48,7 @@ class GPUMemoryManager;
 namespace mindspore {
 namespace device {
 enum class DeviceAddressStatus { kInDevice, kInHost, kInDeviceToHost, kInHostToDevice };
+enum class DeviceAddressType { kUnknown, kAscend, kCPU, kGPU };
 
 class DeviceAddress {
  public:
@@ -64,6 +65,7 @@ class DeviceAddress {
   TypeId type_id() const { return type_id_; }
   virtual void set_status(DeviceAddressStatus status) {}
   virtual DeviceAddressStatus status() const { return DeviceAddressStatus::kInDevice; }
+  virtual DeviceAddressType DeviceType() const { return DeviceAddressType::kUnknown; }
 
  protected:
   const void *ptr() const { return ptr_; }
diff --git a/mindspore/ccsrc/device/gpu/gpu_device_address.h b/mindspore/ccsrc/device/gpu/gpu_device_address.h
index f5c6b6e36b..4074cb6ce9 100644
--- a/mindspore/ccsrc/device/gpu/gpu_device_address.h
+++ b/mindspore/ccsrc/device/gpu/gpu_device_address.h
@@ -35,6 +35,7 @@ class GPUDeviceAddress : public DeviceAddress {
   bool SyncHostToDevice(const std::vector<int> &shape, size_t size, TypeId type, const void *host_ptr) const override;
   void set_status(DeviceAddressStatus status) { status_ = status; }
   DeviceAddressStatus status() const { return status_; }
+  DeviceAddressType DeviceType() const override { return DeviceAddressType::kGPU; }
 
  private:
   DeviceAddressStatus status_{DeviceAddressStatus::kInDevice};
diff --git a/mindspore/ccsrc/device/kernel_runtime.cc b/mindspore/ccsrc/device/kernel_runtime.cc
index 9a8e65b474..d80c593e60 100644
--- a/mindspore/ccsrc/device/kernel_runtime.cc
+++ b/mindspore/ccsrc/device/kernel_runtime.cc
@@ -102,6 +102,13 @@ bool KernelRuntime::RunTask(const session::KernelGraph *graph) {
   return false;
 }
 
+bool KernelRuntime::NodeOutputDeviceAddressExist(const AnfNodePtr &kernel, size_t index) {
+  if (AnfAlgo::OutputAddrExist(kernel, index)) {
+    return true;
+  }
+  return false;
+}
+
 size_t KernelRuntime::CountNodeDeviceMemorySize(const mindspore::AnfNodePtr &node, size_t output_index) {
   MS_EXCEPTION_IF_NULL(node);
   if (output_index >= AnfAlgo::GetOutputTensorNum(node)) {
@@ -255,7 +262,7 @@ void KernelRuntime::AssignStaticMemoryInput(const session::KernelGraph *graph) {
     if (i < graph_valid_input.size() && !graph_valid_input[i]) {
       continue;
     }
-    if (AnfAlgo::OutputAddrExist(item, 0)) {
+    if (NodeOutputDeviceAddressExist(item, 0)) {
       continue;
     }
     auto output_size = AnfAlgo::GetOutputTensorNum(item);
@@ -431,7 +438,7 @@ void KernelRuntime::AssignNodeOutputMem(int flag, const AnfNodePtr &node, int index) {
     if ((kGetAllOuts != index) && (SizeToInt(i) != index)) {
       continue;
     }
-    if (AnfAlgo::OutputAddrExist(node, i)) {
+    if (NodeOutputDeviceAddressExist(node, i)) {
       MS_LOG(INFO) << "Already malloc index:" << i;
       continue;
     }
@@ -493,7 +500,7 @@ void KernelRuntime::AssignStaticMemoryValueNode(session::KernelGraph *graph) {
   MS_EXCEPTION_IF_NULL(ms_context);
   for (auto &value_node : graph->graph_value_nodes()) {
     MS_EXCEPTION_IF_NULL(value_node);
-    if (AnfAlgo::OutputAddrExist(value_node, 0)) {
+    if (NodeOutputDeviceAddressExist(value_node, 0)) {
       MS_LOG(INFO) << "value_node[" << value_node->DebugString() << "] address already exist";
       continue;
     }
diff --git a/mindspore/ccsrc/device/kernel_runtime.h b/mindspore/ccsrc/device/kernel_runtime.h
index 668fb2580f..9fab81cffe 100644
--- a/mindspore/ccsrc/device/kernel_runtime.h
+++ b/mindspore/ccsrc/device/kernel_runtime.h
@@ -67,6 +67,7 @@ class KernelRuntime {
  protected:
   virtual DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format,
                                                TypeId type_id) = 0;
+  virtual bool NodeOutputDeviceAddressExist(const AnfNodePtr &node, size_t index);
   virtual bool SyncStream() = 0;
   void AssignStaticMemory(session::KernelGraph *graph);
   void AssignDynamicMemory(session::KernelGraph *graph);
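Note on the hunks above: the new `DeviceType()` virtual gives every `DeviceAddress` a runtime type tag, and the Ascend runtime overrides `NodeOutputDeviceAddressExist` so that an output address produced by a different device (for example the CPU session in a mixed-target graph) is treated as absent and re-allocated, while the base `KernelRuntime` keeps the old existence-only check. A minimal standalone sketch of the pattern; the enum and class names follow the patch, everything else is illustrative:

```cpp
#include <iostream>
#include <memory>

// Names mirror the patch; the bodies are toy stand-ins.
enum class DeviceAddressType { kUnknown, kAscend, kCPU, kGPU };

class DeviceAddress {
 public:
  virtual ~DeviceAddress() = default;
  // Base class answers kUnknown so subclasses without an override still work.
  virtual DeviceAddressType DeviceType() const { return DeviceAddressType::kUnknown; }
};

class AscendDeviceAddress : public DeviceAddress {
 public:
  DeviceAddressType DeviceType() const override { return DeviceAddressType::kAscend; }
};

class CPUDeviceAddress : public DeviceAddress {
 public:
  DeviceAddressType DeviceType() const override { return DeviceAddressType::kCPU; }
};

// Same idea as AscendKernelRuntime::NodeOutputDeviceAddressExist: an address
// only counts as existing if it was allocated by this runtime's own device.
bool UsableOnAscend(const std::shared_ptr<DeviceAddress> &addr) {
  return addr != nullptr && addr->DeviceType() == DeviceAddressType::kAscend;
}

int main() {
  std::shared_ptr<DeviceAddress> from_cpu = std::make_shared<CPUDeviceAddress>();
  std::shared_ptr<DeviceAddress> from_npu = std::make_shared<AscendDeviceAddress>();
  std::cout << UsableOnAscend(from_cpu) << "\n";  // 0: Ascend must re-allocate
  std::cout << UsableOnAscend(from_npu) << "\n";  // 1: existing memory is reused
}
```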
diff --git a/mindspore/ccsrc/pipeline/action.cc b/mindspore/ccsrc/pipeline/action.cc
index 3e87000be7..709697c922 100644
--- a/mindspore/ccsrc/pipeline/action.cc
+++ b/mindspore/ccsrc/pipeline/action.cc
@@ -305,17 +305,27 @@ bool TaskEmitAction(const ResourcePtr &res) {
   }
   FuncGraphPtr func_graph = res->func_graph();
   auto bc_ptr = res->results()[kBackend].cast<compile::BackendPtr>();
-
   if (IsCtrlSink()) {
     res->results()[kOutput] = bc_ptr->CompileGraph(NOT_NULL(func_graph));
     return true;
   }
-
   std::vector<PrimitivePtr> cut_list = compile::nonlinear_ops;
   if (bc_ptr->name() == kMsConvert) {
     cut_list = compile::GetMsNonlinearOps();
   }
+  std::shared_ptr<compile::CompileGraphs> compile = std::make_shared<compile::CompileGraphs>(bc_ptr, cut_list);
+  auto context_ptr = MsContext::GetInstance();
+  MS_EXCEPTION_IF_NULL(context_ptr);
+  if (compile->ContainMixedTarget(func_graph)) {
+    bc_ptr->set_is_multi_graph_sink(false);
+    context_ptr->set_loop_sink_flag(false);
+  } else if (context_ptr->execution_mode() != kPynativeMode) {
+    std::string device_target = context_ptr->device_target();
+    if (device_target == kAscendDevice) {
+      bc_ptr->set_is_multi_graph_sink(true);
+    }
+  }
   res->results()[kOutput] = compile->CompileAndLink(func_graph);
   return true;
 }
diff --git a/mindspore/ccsrc/pipeline/pipeline.cc b/mindspore/ccsrc/pipeline/pipeline.cc
index 103477363f..a436f4ed3d 100644
--- a/mindspore/ccsrc/pipeline/pipeline.cc
+++ b/mindspore/ccsrc/pipeline/pipeline.cc
@@ -775,7 +775,7 @@ bool InitExecDatasetVm(const std::string &queue_name, int64_t size, int64_t batch_size,
   MS_EXCEPTION_IF_NULL(convert_fn);
   // Convert CNodeList to LinConvertResult.
   ConfigManager::GetInstance().set_iter_num(1);
-  auto runner = convert_fn({app_init});
+  auto runner = convert_fn({app_init}, "");
   if (MsContext::GetInstance()->execution_mode() != kPynativeMode) {
     backend->Link(runner.graph_id);
   }
diff --git a/mindspore/ccsrc/session/cpu_session.cc b/mindspore/ccsrc/session/cpu_session.cc
index 32e3d8b6cc..447845480d 100644
--- a/mindspore/ccsrc/session/cpu_session.cc
+++ b/mindspore/ccsrc/session/cpu_session.cc
@@ -28,6 +28,23 @@
 namespace mindspore {
 namespace session {
+ParameterPtr CPUSession::CreateNewParameterFromParameter(const AnfNodePtr &anf, bool valid_input, KernelGraph *graph) {
+  MS_EXCEPTION_IF_NULL(anf);
+  if (!anf->isa<Parameter>()) {
+    MS_LOG(EXCEPTION) << "anf[" << anf->DebugString() << "] is not a parameter";
+  }
+  auto valid_inputs = graph->MutableValidInputs();
+  MS_EXCEPTION_IF_NULL(valid_inputs);
+  auto graph_inputs = graph->MutableInputs();
+  MS_EXCEPTION_IF_NULL(graph_inputs);
+  TraceManager::DebugTrace(std::make_shared<TraceCopy>(anf->debug_info()));
+  ParameterPtr new_parameter = graph->NewParameter(anf->cast<ParameterPtr>());
+  TraceManager::EndTrace();
+  graph_inputs->push_back(new_parameter);
+  valid_inputs->push_back(valid_input);
+  return new_parameter;
+}
+
 GraphId CPUSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) {
   auto graph_id = graph_sum_;
   auto graph = ConstructKernelGraph(lst, outputs);
diff --git a/mindspore/ccsrc/session/cpu_session.h b/mindspore/ccsrc/session/cpu_session.h
index c53b0d2d8c..36b987e840 100644
--- a/mindspore/ccsrc/session/cpu_session.h
+++ b/mindspore/ccsrc/session/cpu_session.h
@@ -35,6 +35,9 @@ class CPUSession : public SessionBasic {
   GraphId CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) override;
   void RunGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) override;
 
+ protected:
+  ParameterPtr CreateNewParameterFromParameter(const AnfNodePtr &anf, bool valid_input, KernelGraph *graph) override;
+
  private:
   void SetKernelInfo(const KernelGraph *kernel_graph);
   void BuildKernel(const KernelGraph *kernel_graph);
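In `TaskEmitAction` above, a graph that mixes device targets disables both multi-graph sink and loop sink so the VM can run it segment by segment; otherwise, in graph mode on Ascend, multi-graph sink stays on. The toy function below restates that decision; `ExecutionPlan` and `PlanFor` are illustrative names, not MindSpore APIs:

```cpp
#include <iostream>
#include <string>

// Toy restatement of the TaskEmitAction logic: mixed-target graphs disable
// both sinks so the VM can split the graph into per-target segments.
struct ExecutionPlan {
  bool multi_graph_sink = false;
  bool loop_sink = true;
};

ExecutionPlan PlanFor(bool mixed_target, bool pynative, const std::string &device) {
  ExecutionPlan plan;
  if (mixed_target) {
    plan.multi_graph_sink = false;
    plan.loop_sink = false;          // graph must run segment by segment
  } else if (!pynative && device == "Ascend") {
    plan.multi_graph_sink = true;    // the whole graph can sink to the device
  }
  return plan;
}

int main() {
  auto p = PlanFor(/*mixed_target=*/true, /*pynative=*/false, "Ascend");
  std::cout << p.multi_graph_sink << " " << p.loop_sink << "\n";  // prints: 0 0
}
```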
diff --git a/mindspore/ccsrc/session/kernel_graph.cc b/mindspore/ccsrc/session/kernel_graph.cc
index c1992b7cc0..902375c155 100644
--- a/mindspore/ccsrc/session/kernel_graph.cc
+++ b/mindspore/ccsrc/session/kernel_graph.cc
@@ -482,7 +482,13 @@ void KernelGraph::UpdateControlDependRelations(const std::vector<AnfNodePtr> &depends) {
     depend_nodes = GetOutputNodes(depend_node);
   }
   for (auto &first_node : prior_nodes) {
+    if (AnfAlgo::CheckPrimitiveType(first_node, prim::kPrimControlDepend)) {
+      continue;
+    }
     for (auto &second_node : depend_nodes) {
+      if (AnfAlgo::CheckPrimitiveType(second_node, prim::kPrimControlDepend)) {
+        continue;
+      }
       MS_EXCEPTION_IF_NULL(first_node);
       MS_EXCEPTION_IF_NULL(second_node);
       MS_LOG(INFO) << "Add first node:" << first_node->DebugString() << ",second node:" << second_node->DebugString();
diff --git a/mindspore/ccsrc/session/session_basic.cc b/mindspore/ccsrc/session/session_basic.cc
index 886e409854..dfd509925e 100644
--- a/mindspore/ccsrc/session/session_basic.cc
+++ b/mindspore/ccsrc/session/session_basic.cc
@@ -311,7 +311,7 @@ ParameterPtr SessionBasic::CreateNewParameterFromParameter(const AnfNodePtr &anf, bool valid_input,
   if (python_paras_ == nullptr) {
     python_paras_ = std::make_shared<std::map<PyObject *, ParameterPtr>>();
   }
-  if (python_paras_->find(m_tensor) != python_paras_->end() && GetGraphIdByNode(anf) == kInvalidGraphId) {
+  if (python_paras_->find(m_tensor) != python_paras_->end()) {
     new_parameter = (*python_paras_)[m_tensor];
   } else {
     TraceManager::DebugTrace(std::make_shared<TraceCopy>(anf->debug_info()));
diff --git a/mindspore/ccsrc/session/session_basic.h b/mindspore/ccsrc/session/session_basic.h
index b2e8c8894f..7c1a3fe966 100755
--- a/mindspore/ccsrc/session/session_basic.h
+++ b/mindspore/ccsrc/session/session_basic.h
@@ -114,7 +114,7 @@ class SessionBasic {
   BaseRef TransformBaseRefListToTuple(const BaseRef &base_ref);
   // create a new kernel graph and update the graph sum
   KernelGraphPtr NewKernelGraph();
-  ParameterPtr CreateNewParameterFromParameter(const AnfNodePtr &anf, bool valid_input, KernelGraph *graph);
+  virtual ParameterPtr CreateNewParameterFromParameter(const AnfNodePtr &anf, bool valid_input, KernelGraph *graph);
   ValueNodePtr CreateValueNodeKernelGraph(const AnfNodePtr &anf, KernelGraph *graph);
   ParameterPtr CreateNewParameter(const AnfNodePtr &anf, KernelGraph *graph);
   AnfNodePtr CreateNewParameterFromCNode(const AnfNodePtr &anf, bool valid_input, KernelGraph *graph);
diff --git a/mindspore/ccsrc/utils/context/ms_context.h b/mindspore/ccsrc/utils/context/ms_context.h
index 9a91f391c9..e7be85dc82 100644
--- a/mindspore/ccsrc/utils/context/ms_context.h
+++ b/mindspore/ccsrc/utils/context/ms_context.h
@@ -92,7 +92,7 @@ class MsContext {
 
   bool ir_fusion_flag() const { return ir_fusion_flag_; }
   bool loop_sink_flag() const { return enable_loop_sink_; }
-
+  void set_loop_sink_flag(bool enable_loop_sink) { enable_loop_sink_ = enable_loop_sink; }
   void set_enable_mem_reuse(bool enable_mem_reuse) { enable_mem_reuse_ = enable_mem_reuse; }
   bool enable_mem_reuse() const { return enable_mem_reuse_; }
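The `kernel_graph.cc` hunk stops `UpdateControlDependRelations` from wiring a `ControlDepend` node itself into the dependency edges it expands: both the prior side and the depend side now skip such nodes. A self-contained sketch of that double filter, using toy types rather than MindSpore's:

```cpp
#include <iostream>
#include <string>
#include <vector>

// Sketch of the new guards: when expanding a ControlDepend's prior/behind
// node sets, other ControlDepend nodes must not become edge endpoints.
struct ToyNode {
  std::string name;
  bool is_control_depend = false;
};

int main() {
  std::vector<ToyNode> prior = {{"conv"}, {"ctrl_dep", true}};
  std::vector<ToyNode> depend = {{"relu"}, {"ctrl_dep2", true}};
  for (const auto &first : prior) {
    if (first.is_control_depend) continue;     // mirrors the first new guard
    for (const auto &second : depend) {
      if (second.is_control_depend) continue;  // mirrors the second new guard
      std::cout << "depend edge: " << first.name << " -> " << second.name << "\n";
    }
  }
  // Prints only: depend edge: conv -> relu
}
```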
diff --git a/mindspore/ccsrc/vm/backend.cc b/mindspore/ccsrc/vm/backend.cc
index 0fac84d901..e8af57d764 100644
--- a/mindspore/ccsrc/vm/backend.cc
+++ b/mindspore/ccsrc/vm/backend.cc
@@ -39,14 +39,14 @@ LinConvertResult MsBackend::GetMultiGraphRun(const FuncGraphPtr &g) {
   multi_result_.inputs = g->parameters();
   final_output_ = NewValueNode("fake_output");
   multi_result_.outputs = {final_output_};
-  GraphId final_g = sess_->GetFinalRunGraph();
+  GraphId final_g = target_sess_->GetFinalRunGraph();
   multi_result_.run = std::make_shared<RunFunc>(
-      [final_g, this](const VectorRef &args) -> VectorRef { return MsRunGraph(final_g, args); });
+      [final_g, this](const VectorRef &args) -> VectorRef { return MsRunGraph(final_g, args, ""); });
   return multi_result_;
 }
 
-LinConvertResult MsBackend::MsConvert(const AnfNodePtrList &lst) {
+LinConvertResult MsBackend::MsConvert(const AnfNodePtrList &lst, const std::string &target) {
   MS_LOG(DEBUG) << "MsConvert";
   MS_EXCEPTION_IF_NULL(MsContext::GetInstance());
   auto cached = g_ConvertCache.find(lst);
@@ -64,17 +64,24 @@ LinConvertResult MsBackend::MsConvert(const AnfNodePtrList &lst) {
   result.inputs = inputs;
   result.outputs = outputs;
   result.graph_id = kInvalidGraphId;
-  auto graph_id = sess_->CompileGraph(lst, outputs);
-  if (MsContext::GetInstance()->execution_mode() == kPynativeMode) {
-    sess_->BuildGraph(graph_id);
+  GraphId graph_id = kInvalidGraphId;
+  if (target == kCPUDevice) {
+    graph_id = cpu_sess_->CompileGraph(lst, outputs);
+  } else {
+    graph_id = target_sess_->CompileGraph(lst, outputs);
   }
+
   if (MsContext::GetInstance()->precompile_only()) {
     MS_LOG(INFO) << "PrecompileOnly, stop run graph";
     return result;
   }
-
+  if (target == kCPUDevice) {
+    cpu_sess_->BuildGraph(graph_id);
+  } else if (!is_multi_graph_sink_) {
+    target_sess_->BuildGraph(graph_id);
+  }
   result.run = std::make_shared<RunFunc>(
-      [graph_id, this](const VectorRef &args) -> VectorRef { return MsRunGraph(graph_id, args); });
+      [graph_id, target, this](const VectorRef &args) -> VectorRef { return MsRunGraph(graph_id, args, target); });
   MS_EXCEPTION_IF_NULL(result.run);
 
   result.simu_run = std::make_shared<RunFunc>(
@@ -92,7 +99,7 @@ void MsBackend::SetSwitchActive(const BaseRef &c, bool cond) {
 
   GraphId cond_g = kInvalidGraphId;
   if (utils::isa<AnfNodePtr>(c)) {
-    cond_g = sess_->GetGraphIdByNode(utils::cast<AnfNodePtr>(c));
+    cond_g = target_sess_->GetGraphIdByNode(utils::cast<AnfNodePtr>(c));
   } else {
     MS_LOG(EXCEPTION) << "cond not a anf node:" << c.ToString();
   }
@@ -116,7 +123,7 @@ void MsBackend::SetSwitchActive(const BaseRef &c, bool cond) {
     MS_LOG(DEBUG) << "invoke set active:" << active_g;
   }
   MS_LOG(DEBUG) << "switch set active:" << active_g << ", " << cond_g;
-  sess_->SetActive(active_g, cond_g);
+  target_sess_->SetActive(active_g, cond_g);
 }
 
 void MsBackend::SetSwitchGraph() {
@@ -135,12 +142,12 @@ void MsBackend::SetSwitchGraph() {
     }
     GraphId cond_g = kInvalidGraphId;
     if (utils::isa<AnfNodePtr>(curr_switch_)) {
-      cond_g = sess_->GetGraphIdByNode(utils::cast<AnfNodePtr>(curr_switch_));
+      cond_g = target_sess_->GetGraphIdByNode(utils::cast<AnfNodePtr>(curr_switch_));
     } else {
       MS_LOG(EXCEPTION) << "cond not a anf node:" << curr_switch_.ToString();
     }
     MS_LOG(DEBUG) << "switch compile:" << cond_g << ", " << true_g << ", " << false_g;
-    sess_->SwitchCompile(cond_g, true_g, false_g, utils::cast<AnfNodePtr>(curr_switch_));
+    target_sess_->SwitchCompile(cond_g, true_g, false_g, utils::cast<AnfNodePtr>(curr_switch_));
   }
   is_switch_call_ = false;
   MS_LOG(DEBUG) << "end SetSwitchGraph:" << curr_cond << ", " << is_switch_call_;
@@ -202,7 +209,7 @@ void MsBackend::RecallGraphInput(const FuncGraphPtr &func_graph, const VectorRef &args,
         old_args[i] = args[it->second];
       }
     }
-    sess_->SetChildGraphInput(graph, old_args);
+    target_sess_->SetChildGraphInput(graph, old_args);
   }
   graph_inputs_.erase(c);
 }
@@ -211,7 +218,7 @@ void MsBackend::RecallGraphInput(const FuncGraphPtr &func_graph, const VectorRef &args,
 VectorRef MsBackend::MsSimuRunGraph(const GraphId &g, const VectorRef &args) {
   MS_LOG(DEBUG) << "set graph input:" << g;
   // switch maybe twice
-  sess_->SetChildGraphInput(g, args);
+  target_sess_->SetChildGraphInput(g, args);
 
   if (is_switch_call_) {
     if (!curr_switch_.is_null()) {
@@ -236,7 +243,7 @@ VectorRef MsBackend::MsSimuRunGraph(const GraphId &g, const VectorRef &args) {
   return VectorRef(outputs);
 }
 
-VectorRef MsBackend::MsRunGraph(const GraphId &g, const VectorRef &args) {
+VectorRef MsBackend::MsRunGraph(const GraphId &g, const VectorRef &args, const std::string &target) {
   MS_LOG(DEBUG) << "start ms graph run:" << args.size() << ", g:" << g;
   // Run graph
   std::vector<tensor::TensorPtr> inputs;
@@ -271,7 +278,12 @@ VectorRef MsBackend::MsRunGraph(const GraphId &g, const VectorRef &args) {
 
   VectorRef outputs;
   // call ms rungraph (graphId, input ,output)
-  sess_->RunGraph(g, inputs, &outputs);
+  if (target == kCPUDevice) {
+    cpu_sess_->RunGraph(g, inputs, &outputs);
+  } else {
+    target_sess_->RunGraph(g, inputs, &outputs);
+  }
+
   MS_LOG(DEBUG) << "RunGraph finished:" << outputs.size();
   return outputs;
 }
@@ -300,17 +312,17 @@ void MsBackend::SimulateRun(FinalVMPtr rt, FuncGraphPtr root) {
   (void)std::transform(parameters.begin(), parameters.end(), std::back_inserter(args),
                        [](const AnfNodePtr &v) { return v; });
   MS_LOG(DEBUG) << "Simulate start";
-  (void)sess_->SetFinalGraphInput(parameters);
+  (void)target_sess_->SetFinalGraphInput(parameters);
   BaseRef output = rt->Eval(VectorRef(args));
-  sess_->SetFinalGraphOutput(output);
+  target_sess_->SetFinalGraphOutput(output);
   MS_LOG(DEBUG) << "Simulate Eval end";
 }
 
 void MsBackend::Link(GraphId graph_id) {
   if (graph_id == kInvalidGraphId) {
-    graph_id = sess_->GetFinalRunGraph();
+    graph_id = target_sess_->GetFinalRunGraph();
   }
-  sess_->BuildGraph(graph_id);
+  target_sess_->BuildGraph(graph_id);
 }
 
 Backend::Backend(const std::string &name) : name_(name) {
@@ -322,16 +334,26 @@ Backend::Backend(const std::string &name) : name_(name) {
 }
 
 MsBackend::MsBackend(const std::string &name, const std::string &target, uint32_t device_id) : Backend(name) {
-  convert_fn_ = std::bind(&MsBackend::MsConvert, this, std::placeholders::_1);
-  sess_ = session::SessionFactory::Get().Create(target);
-  if (sess_ == nullptr) {
+  convert_fn_ = std::bind(&MsBackend::MsConvert, this, std::placeholders::_1, std::placeholders::_2);
+  target_sess_ = session::SessionFactory::Get().Create(target);
+  if (target_sess_ == nullptr) {
     MS_LOG(EXCEPTION) << "Session create failed!, please make sure target device:" << target << " is available.";
   }
-  sess_->Init(device_id);
-  sess_->RegisterSummaryCallBackFunc(callbacks::SummarySaveCallback);
+  target_sess_->Init(device_id);
+  target_sess_->RegisterSummaryCallBackFunc(callbacks::SummarySaveCallback);
+  if (target == kCPUDevice) {
+    cpu_sess_ = target_sess_;
+  } else {
+    cpu_sess_ = session::SessionFactory::Get().Create(kCPUDevice);
+    if (cpu_sess_ == nullptr) {
+      MS_LOG(EXCEPTION) << "Create cpu session failed with target " << target << ".";
+    }
+    cpu_sess_->Init(0);
+    cpu_sess_->RegisterSummaryCallBackFunc(callbacks::SummarySaveCallback);
+  }
 }
 
-GraphId MsBackend::CompileGraph(NotNull<FuncGraphPtr> fg) { return sess_->CompileGraph(fg); }
+GraphId MsBackend::CompileGraph(NotNull<FuncGraphPtr> fg) { return target_sess_->CompileGraph(fg); }
 
 VectorRef MsBackend::RunGraph(GraphId graph_id, const VectorRef &args) { return MsRunGraph(graph_id, args); }
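`MsBackend` now carries two sessions: `target_sess_` for the configured device and `cpu_sess_` as a CPU fallback, aliased to `target_sess_` when the target already is CPU. Compile, build, and run all dispatch on the segment's target string. A minimal sketch of that wiring, with `Session` standing in for `session::SessionBasic`:

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <utility>

// Toy stand-in for session::SessionBasic; only the device name matters here.
struct Session {
  explicit Session(std::string d) : device(std::move(d)) {}
  std::string device;
};
using SessionPtr = std::shared_ptr<Session>;

struct MsBackendSketch {
  SessionPtr target_sess_;
  SessionPtr cpu_sess_;

  explicit MsBackendSketch(const std::string &target) {
    target_sess_ = std::make_shared<Session>(target);
    // Reuse the target session when it already is CPU, as the patch does.
    cpu_sess_ = (target == "CPU") ? target_sess_ : std::make_shared<Session>("CPU");
  }

  // Mirrors the dispatch in MsConvert/MsRunGraph: segments tagged "CPU" go to
  // the CPU session, everything else to the primary target session.
  const SessionPtr &SessionFor(const std::string &segment_target) const {
    return segment_target == "CPU" ? cpu_sess_ : target_sess_;
  }
};

int main() {
  MsBackendSketch backend("Ascend");
  assert(backend.SessionFor("CPU")->device == "CPU");
  assert(backend.SessionFor("")->device == "Ascend");  // default target
}
```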
diff --git a/mindspore/ccsrc/vm/backend.h b/mindspore/ccsrc/vm/backend.h
index 94b7a500e2..1ff345a43b 100644
--- a/mindspore/ccsrc/vm/backend.h
+++ b/mindspore/ccsrc/vm/backend.h
@@ -91,8 +91,8 @@ class MsBackend : public Backend {
   MsBackend(const std::string &name, const std::string &target, uint32_t device_id);
   ~MsBackend() override = default;
 
-  LinConvertResult MsConvert(const AnfNodePtrList &lst);
-  VectorRef MsRunGraph(const GraphId &g, const VectorRef &args);
+  LinConvertResult MsConvert(const AnfNodePtrList &lst, const std::string &target = "");
+  VectorRef MsRunGraph(const GraphId &g, const VectorRef &args, const std::string &target = "");
 
   VectorRef MsSimuRunGraph(const GraphId &g, const VectorRef &args);
   void SimulateRun(FinalVMPtr rt, FuncGraphPtr root) override;
@@ -109,7 +109,8 @@ class MsBackend : public Backend {
   VectorRef RunGraph(GraphId graph_id, const VectorRef &args);
 
  private:
-  session::SessionPtr sess_;
+  session::SessionPtr target_sess_;
+  session::SessionPtr cpu_sess_;
   std::unordered_map<BaseRef, bool, BaseRefHash> simu_cond_map_;
   std::unordered_map<GraphId, LinConvertResult> graph_id_map_;
   std::unordered_map<BaseRef, std::shared_ptr<std::map<size_t, size_t>>, BaseRefHash> graph_inputs_;
diff --git a/mindspore/ccsrc/vm/segment_runner.cc b/mindspore/ccsrc/vm/segment_runner.cc
index ae052770ff..dcd62a548d 100644
--- a/mindspore/ccsrc/vm/segment_runner.cc
+++ b/mindspore/ccsrc/vm/segment_runner.cc
@@ -148,7 +148,7 @@ std::tuple<FuncGraphPtr, AnfNodePtrList, AnfNodePtrList> TransformSegmentToAnfGraph(const AnfNodePtrList &lst) {
 // This implementation will convert the nodes into a subgraph
 // that will run using the MsVM.
 template <typename T>
-LinConvertResult Convert(const AnfNodePtrList &lst) {
+LinConvertResult Convert(const AnfNodePtrList &lst, const std::string &) {
   auto cached = g_ConvertCache.find(lst);
   if (cached != g_ConvertCache.end()) {
     return cached->second;
diff --git a/mindspore/ccsrc/vm/segment_runner.h b/mindspore/ccsrc/vm/segment_runner.h
index 8ea87da50c..c4458d4148 100644
--- a/mindspore/ccsrc/vm/segment_runner.h
+++ b/mindspore/ccsrc/vm/segment_runner.h
@@ -43,7 +43,7 @@ struct LinConvertResult {
   uint32_t graph_id;
 };
 
-using LinkFuncType = std::function<LinConvertResult(const AnfNodePtrList &)>;
+using LinkFuncType = std::function<LinConvertResult(const AnfNodePtrList &, const std::string &)>;
 using ConvertCache = std::unordered_map<BaseRef, LinConvertResult, BaseRefHash>;
 extern LinkFuncType MsVmConvert;
 extern LinkFuncType GeVmConvert;
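The `segment_runner` changes thread the segment target through `LinkFuncType`, so every convert function now takes a second string argument; call sites that predate mixed-target support pass `""` (the updated unit tests at the end of the patch do exactly that). A compilable sketch of the new shape, with `AnfNodePtrList` mocked as a plain vector:

```cpp
#include <functional>
#include <iostream>
#include <string>
#include <vector>

// Sketch of the LinkFuncType change: the segment's target rides along as a
// second argument. AnfNodePtrList is a stand-in here, not the real node list.
using AnfNodePtrList = std::vector<int>;

struct LinConvertResult {
  int graph_id = -1;
};

using LinkFuncType = std::function<LinConvertResult(const AnfNodePtrList &, const std::string &)>;

int main() {
  LinkFuncType convert = [](const AnfNodePtrList &lst, const std::string &target) {
    std::cout << "compiling " << lst.size() << " nodes for '"
              << (target.empty() ? "default device" : target) << "'\n";
    return LinConvertResult{0};
  };
  convert({1, 2, 3}, "");   // default target, as in the updated tests
  convert({4, 5}, "CPU");   // CPU segment from a mixed-target graph
}
```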
diff --git a/mindspore/ccsrc/vm/transform.cc b/mindspore/ccsrc/vm/transform.cc
index 636d36f931..90efc0ac5f 100644
--- a/mindspore/ccsrc/vm/transform.cc
+++ b/mindspore/ccsrc/vm/transform.cc
@@ -20,6 +20,8 @@
 #include <algorithm>
 #include <map>
+#include <queue>
+#include <set>
 #include <string>
 #include <vector>
 
@@ -47,6 +49,86 @@ const std::vector<PrimitivePtr> &GetMsNonlinearOps() {
   return ms_nonlinear_ops;
 }
 
+namespace {
+std::string GetCNodeTarget(const AnfNodePtr &node) {
+  auto context_ptr = MsContext::GetInstance();
+  MS_EXCEPTION_IF_NULL(context_ptr);
+  std::string default_target = context_ptr->device_target();
+  if (!node->isa<CNode>()) {
+    return default_target;
+  }
+  auto cnode = node->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(cnode);
+  auto attr_input = cnode->input(kAnfPrimitiveIndex);
+  if (attr_input == nullptr) {
+    return default_target;
+  }
+  auto value_node = attr_input->cast<ValueNodePtr>();
+  if (value_node == nullptr) {
+    return default_target;
+  }
+  auto value = value_node->value();
+  if (value == nullptr) {
+    return default_target;
+  }
+  if (!value->isa<Primitive>()) {
+    return default_target;
+  }
+  auto primitive = value->cast<PrimitivePtr>();
+  ValuePtr att_target = primitive->GetAttr("target");
+  if (att_target != nullptr) {
+    std::string target = GetValue<std::string>(att_target);
+    return target;
+  }
+  return default_target;
+}
+
+bool ContainMultiTarget(const std::vector<AnfNodePtr> &nodes) {
+  auto context_ptr = MsContext::GetInstance();
+  MS_EXCEPTION_IF_NULL(context_ptr);
+  std::string last_target = context_ptr->device_target();
+  for (auto &node : nodes) {
+    if (node->isa<CNode>()) {
+      std::string cur_target = GetCNodeTarget(node);
+      if (last_target != cur_target) {
+        return true;
+      }
+      last_target = cur_target;
+    }
+  }
+  return false;
+}
+
+void CalcNodeRefCount(const FuncGraphPtr &graph, std::map<AnfNodePtr, size_t> *nodes_ref) {
+  std::queue<AnfNodePtr> queue;
+  queue.push(graph->get_return());
+  std::set<AnfNodePtr> visited;
+  while (!queue.empty()) {
+    auto &node = queue.front();
+    queue.pop();
+    MS_EXCEPTION_IF_NULL(node);
+    if (!node->isa<CNode>()) {
+      continue;
+    }
+    auto cnode = node->cast<CNodePtr>();
+    MS_EXCEPTION_IF_NULL(cnode);
+    for (auto &input : cnode->inputs()) {
+      auto iter = nodes_ref->find(input);
+      if (iter != nodes_ref->end()) {
+        iter->second++;
+      } else {
+        (void)nodes_ref->insert(std::pair<AnfNodePtr, size_t>(input, 1));
+      }
+      if (visited.find(input) != visited.end()) {
+        continue;
+      }
+      visited.insert(input);
+      queue.push(input);
+    }
+  }
+}
+}  // namespace
+
 CompileGraph::CompileGraph(const BackendPtr &backend, const std::vector<PrimitivePtr> &cut_list)
     : backend_(backend), cut_list_(cut_list) {
   MS_EXCEPTION_IF_NULL(backend_);
@@ -98,12 +180,67 @@ bool CompileGraph::IsCut(const AnfNodePtr &node) {
   return false;
 }
 
+std::vector<AnfNodePtr> CompileGraph::SplitSort(const FuncGraphPtr &graph) {
+  std::vector<AnfNodePtr> result;
+  std::queue<AnfNodePtr> queue;
+  std::queue<AnfNodePtr> next_queue;
+  std::map<AnfNodePtr, size_t> nodes_ref;
+  CalcNodeRefCount(graph, &nodes_ref);
+  auto context_ptr = MsContext::GetInstance();
+  MS_EXCEPTION_IF_NULL(context_ptr);
+  std::string queue_target = context_ptr->device_target();
+  std::string next_target = "";
+  queue.push(graph->get_return());
+  while (!queue.empty() || !next_queue.empty()) {
+    if (queue.empty()) {
+      queue.swap(next_queue);
+      queue_target = next_target;
+    }
+    auto &node = queue.front();
+    queue.pop();
+    MS_EXCEPTION_IF_NULL(node);
+    result.emplace_back(node);
+    if (!node->isa<CNode>()) {
+      continue;
+    }
+    auto cnode = node->cast<CNodePtr>();
+    MS_EXCEPTION_IF_NULL(cnode);
+    for (auto &input : cnode->inputs()) {
+      auto iter = nodes_ref.find(input);
+      if (iter != nodes_ref.end()) {
+        iter->second--;
+        if (iter->second != 0) {
+          continue;
+        }
+      }
+      if (!input->isa<CNode>()) {
+        queue.push(input);
+        continue;
+      }
+      std::string input_target = GetCNodeTarget(input);
+      if (input_target == queue_target) {
+        queue.push(input);
+      } else if (next_queue.empty() || input_target == next_target) {
+        next_queue.push(input);
+        next_target = input_target;
+      } else {
+        MS_LOG(EXCEPTION) << "only support two different target";
+      }
+    }
+  }
+  std::reverse(result.begin(), result.end());
+  return result;
+}
+
 VectorRef CompileGraph::SplitNodes(const FuncGraphPtr &graph) {
   MS_EXCEPTION_IF_NULL(graph);
   VectorRef splits;
   VectorRef split;
-  std::vector<AnfNodePtr> nodes = TopoSort(graph->get_return());
-
+  auto nodes = TopoSort(graph->get_return());
+  if (ContainMultiTarget(nodes)) {
+    nodes = SplitSort(graph);
+  }
+  std::string last_target;
   MS_LOG(DEBUG) << "Split all nodes size:" << nodes.size();
   for (auto &node : nodes) {
     MS_EXCEPTION_IF_NULL(node);
@@ -114,7 +251,13 @@ VectorRef CompileGraph::SplitNodes(const FuncGraphPtr &graph) {
       }
       splits.push_back(node);
       split.clear();
-    } else if (!(node->isa<ValueNode>() || node->isa<Parameter>())) {
+    } else if (node->isa<CNode>()) {
+      std::string cur_target = GetCNodeTarget(node);
+      if (cur_target != last_target && !last_target.empty() && split.size() != 0) {
+        splits.push_back(split);
+        split.clear();
+      }
+      last_target = cur_target;
       split.push_back(node);
       MS_LOG(DEBUG) << "Insert node:" << node->DebugString(10) << ", size:" << split.size();
     }
@@ -200,14 +343,14 @@ void CompileGraph::PushParameters(const FuncGraphPtr &graph) {
   }
 }
 
-int CompileGraph::LinConvert(const FuncGraphPtr &graph, const AnfNodePtrList &node_list) {
+int CompileGraph::LinConvert(const FuncGraphPtr &graph, const AnfNodePtrList &node_list, const std::string &target) {
   MS_LOG(DEBUG) << "LinConvert start";
   LinConvertResult result;
 
   if (backend_->simu_flag()) {
     result = backend_->GetMultiGraphRun(graph);
   } else {
-    result = lin_convert_(node_list);
+    result = lin_convert_(node_list, target);
   }
 
   if (result.run == nullptr) {
@@ -316,7 +459,12 @@ bool CompileGraph::SplitGraph(const FuncGraphPtr &graph) {
       auto vec_ref = utils::cast<VectorRef>(split);
       (void)std::transform(vec_ref.begin(), vec_ref.end(), std::back_inserter(args),
                            [](const BaseRef &v) { return utils::cast<AnfNodePtr>(v); });
-      ret = LinConvert(graph, args);
+      if (args.size() > 0) {
+        std::string cur_target = GetCNodeTarget(args[0]);
+        ret = LinConvert(graph, args, cur_target);
+      } else {
+        ret = LinConvert(graph, args);
+      }
       MS_LOG(DEBUG) << "End a extern LinConvert";
       if (ret == RET_FAILED) {
         return false;
@@ -637,6 +785,19 @@ FinalVMPtr CompileGraphs::CompileAndLink(const FuncGraphPtr &graph) {
   return rt;
 }
 
+bool CompileGraphs::ContainMixedTarget(const FuncGraphPtr &graph) {
+  auto graph_manager = graph->manager();
+  MS_EXCEPTION_IF_NULL(graph_manager);
+  FuncGraphSet graphs = graph_manager->func_graphs();
+  for (auto &g : graphs) {
+    auto nodes = TopoSort(g->get_return());
+    if (ContainMultiTarget(nodes)) {
+      return true;
+    }
+  }
+  return false;
+}
+
 BackendPtr CreateBackend() {
   auto context_ptr = MsContext::GetInstance();
   MS_EXCEPTION_IF_NULL(context_ptr);
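`SplitSort` is the core of the scheduling change: a reverse BFS from the return node that drains one queue per target and swaps to a second queue whenever an input lives on the other target, then reverses the visit order so same-target nodes end up adjacent. The toy program below reproduces the two-queue walk on a hand-made four-node chain; it omits the reference counting from `CalcNodeRefCount` that the real code uses to hold a node back until all of its users have been visited:

```cpp
#include <iostream>
#include <queue>
#include <string>
#include <utility>
#include <vector>

// Hand-made graph: node -> (target, producer indices). Illustrative only.
struct Node {
  std::string name;
  std::string target;
  std::vector<int> inputs;
};

int main() {
  // ret(Ascend) <- c(CPU) <- b(Ascend) <- a(Ascend)
  std::vector<Node> g = {
      {"a", "Ascend", {}}, {"b", "Ascend", {0}}, {"c", "CPU", {1}}, {"ret", "Ascend", {2}}};

  std::queue<int> queue, next_queue;
  std::string queue_target = "Ascend", next_target;
  std::vector<int> order;
  queue.push(3);  // start from the return node, as SplitSort does
  while (!queue.empty() || !next_queue.empty()) {
    if (queue.empty()) {          // current target exhausted: switch targets
      std::swap(queue, next_queue);
      queue_target = next_target;
    }
    int n = queue.front();
    queue.pop();
    order.push_back(n);
    for (int in : g[n].inputs) {
      // Same-target inputs stay in the active queue; others wait their turn.
      (g[in].target == queue_target ? queue : next_queue).push(in);
      if (g[in].target != queue_target) next_target = g[in].target;
    }
  }
  for (auto it = order.rbegin(); it != order.rend(); ++it) {
    std::cout << g[*it].name << "(" << g[*it].target << ") ";
  }
  std::cout << "\n";  // a(Ascend) b(Ascend) c(CPU) ret(Ascend)
}
```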
diff --git a/mindspore/ccsrc/vm/transform.h b/mindspore/ccsrc/vm/transform.h
index 711c1777ab..7505a52ed1 100644
--- a/mindspore/ccsrc/vm/transform.h
+++ b/mindspore/ccsrc/vm/transform.h
@@ -79,8 +79,9 @@ class CompileGraph {
 
  private:
   void PushParameters(const FuncGraphPtr &func_graph);
+  std::vector<AnfNodePtr> SplitSort(const FuncGraphPtr &graph);
   bool SplitGraph(const FuncGraphPtr &func_graph);
-  int LinConvert(const FuncGraphPtr &func_graph, const AnfNodePtrList &node_list);
+  int LinConvert(const FuncGraphPtr &func_graph, const AnfNodePtrList &node_list, const std::string &target = "");
   int InterpretNode(const FuncGraphPtr &func_graph, const CNodePtr &node);
   int AddCall(const FuncGraphPtr &graph, const CNodePtr &node);
   void AddSinkSwitch(const CNodePtr &node);
@@ -124,6 +125,7 @@ class CompileGraphs {
   void Compile(const FuncGraphPtr &func_graph);
   FinalVMPtr Link(const FuncGraphPtr &func_graph);
   FinalVMPtr CompileAndLink(const FuncGraphPtr &func_graph);
+  bool ContainMixedTarget(const FuncGraphPtr &graph);
 
  private:
   InstSet insts_;
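Given that ordering, `SplitNodes` closes a segment whenever the node target changes, so every segment handed to `LinConvert` is single-target and `SplitGraph` can tag it with `GetCNodeTarget(args[0])`. A toy version of the grouping rule, with cut-op handling omitted:

```cpp
#include <iostream>
#include <string>
#include <utility>
#include <vector>

// Consecutive nodes stay in one segment until the target changes.
int main() {
  std::vector<std::pair<std::string, std::string>> nodes = {
      {"matmul", "Ascend"}, {"add", "Ascend"}, {"gather", "CPU"}, {"relu", "Ascend"}};
  std::vector<std::vector<std::string>> splits;
  std::vector<std::string> split;
  std::string last_target;
  for (const auto &[name, target] : nodes) {
    if (!last_target.empty() && target != last_target && !split.empty()) {
      splits.push_back(split);  // close the segment on a target switch
      split.clear();
    }
    last_target = target;
    split.push_back(name);
  }
  if (!split.empty()) splits.push_back(split);
  std::cout << splits.size() << " segments\n";  // prints: 3 segments
}
```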
diff --git a/tests/ut/cpp/vm/segment_runner_test.cc b/tests/ut/cpp/vm/segment_runner_test.cc
index f08272c728..b9bc552d90 100644
--- a/tests/ut/cpp/vm/segment_runner_test.cc
+++ b/tests/ut/cpp/vm/segment_runner_test.cc
@@ -65,7 +65,7 @@ TEST_F(TestCompileSegmentRunner, test_MsVmConvert1) {
   for (auto &item : utils::cast<VectorRef>(todos[0])) {
     anf_list.push_back(utils::cast<AnfNodePtr>(item));
   }
-  auto convertResult = MsVmConvert(anf_list);
+  auto convertResult = MsVmConvert(anf_list, "");
   auto runResult = (*(convertResult.run))(args);
   ASSERT_TRUE(runResult.size() == 1 && py::cast<double>(BaseRefToPyData(runResult[0])) == 3.0);
 }
@@ -89,7 +89,7 @@ TEST_F(TestCompileSegmentRunner, test_MsVmConvert2) {
   for (auto &item : utils::cast<VectorRef>(todos[0])) {
     anf_list.push_back(utils::cast<AnfNodePtr>(item));
   }
-  auto convertResult = MsVmConvert(anf_list);
+  auto convertResult = MsVmConvert(anf_list, "");
   auto runResult = (*(convertResult.run))(args);
   ASSERT_TRUE(runResult.size() == 1 && py::cast<double>(BaseRefToPyData(runResult[0])) == 2.0);
 }
@@ -113,7 +113,7 @@ TEST_F(TestCompileSegmentRunner, test_if) {
   for (auto &item : utils::cast<VectorRef>(todos[0])) {
     anf_list.push_back(utils::cast<AnfNodePtr>(item));
   }
-  auto convertResult = MsVmConvert(anf_list);
+  auto convertResult = MsVmConvert(anf_list, "");
   auto runResult = (*(convertResult.run))(args);
   auto result = py::cast<bool>(BaseRefToPyData(runResult[0]));