From 730869f49517b9ac79b1371d41941cb0be4d60de Mon Sep 17 00:00:00 2001 From: limingqi107 Date: Thu, 1 Apr 2021 11:33:02 +0800 Subject: [PATCH] connect the process of actor runtime --- mindspore/ccsrc/pipeline/jit/action.cc | 52 +++++++++++++ .../runtime/framework/actor/kernel_actor.cc | 4 +- .../runtime/framework/host_tensor_queue.h | 3 +- mindspore/ccsrc/vm/CMakeLists.txt | 2 + mindspore/ccsrc/vm/backend.cc | 53 +++++++++++++ mindspore/ccsrc/vm/backend.h | 24 ++++++ mindspore/ccsrc/vm/transform.cc | 75 ++++++++++++++++++- mindspore/ccsrc/vm/transform.h | 24 ++++++ 8 files changed, 233 insertions(+), 4 deletions(-) diff --git a/mindspore/ccsrc/pipeline/jit/action.cc b/mindspore/ccsrc/pipeline/jit/action.cc index dcc4b785c9..d2b4e7b18e 100644 --- a/mindspore/ccsrc/pipeline/jit/action.cc +++ b/mindspore/ccsrc/pipeline/jit/action.cc @@ -50,6 +50,43 @@ namespace mindspore { namespace pipeline { +namespace { +void TaskEmitActionForMindRT(const ResourcePtr &res) { + MS_EXCEPTION_IF_NULL(res); + // Get the mindRT backend. + auto bc_ptr = res->results()[kBackend].cast(); + auto mindrt_bc_ptr = std::dynamic_pointer_cast(bc_ptr); + MS_EXCEPTION_IF_NULL(mindrt_bc_ptr); + + auto cut_list = compile::GetMsNonlinearOps(); + auto mindrt_compile = std::make_shared(mindrt_bc_ptr, cut_list); + // The output of graph compiler is graph id. + res->results()[kOutput] = mindrt_compile->CompileGraphs(res->func_graph()); +} + +void ExecuteActionForMindRT(const ResourcePtr &res) { + MS_EXCEPTION_IF_NULL(res); + if (!res->results()[kOutput].is()) { + MS_LOG(EXCEPTION) << "Execute args error"; + } + auto graph_id = res->results()[kOutput].cast(); + + // Get the mindRT backend. + std::shared_ptr bc_ptr = res->results()[kBackend].cast>(); + auto mindrt_bc_ptr = (std::dynamic_pointer_cast(bc_ptr)).get(); + MS_EXCEPTION_IF_NULL(mindrt_bc_ptr); + + // Construct the graph run function ptr. 
+ compile::VmEvalFuncPtr run = + std::make_shared([mindrt_bc_ptr, graph_id](const VectorRef &args) -> BaseRef { + MS_LOG(INFO) << "Execute args size " << args.size(); + auto outs = mindrt_bc_ptr->RunGraph(graph_id, args); + MS_LOG(DEBUG) << "out size " << outs.size(); + return outs[0]; + }); + res->results()[kOutput] = run; +} +} // namespace using CompileGraphs = compile::CompileGraphs; using abstract::AnalysisResult; using mindspore::abstract::AnalysisContextPtr; @@ -488,6 +525,13 @@ bool TaskEmitAction(const ResourcePtr &res) { } } + // The graph compiling of mindRT. + if ((backend == kMsConvert) && compile::IsMindRTUsed()) { + TaskEmitActionForMindRT(res); + return true; + } + + // The graph compiling of control sink. if (IsCtrlSink() && backend == kMsConvert) { res->results()[kOutput] = bc_ptr->CompileGraph(NOT_NULL(func_graph)); return true; @@ -510,6 +554,14 @@ bool ExecuteAction(const ResourcePtr &res) { MS_LOG(EXCEPTION) << "Execute args error"; } std::string backend = MsContext::GetInstance()->backend_policy(); + + // The graph running of mindRT. + if ((backend == kMsConvert) && compile::IsMindRTUsed()) { + ExecuteActionForMindRT(res); + return true; + } + + // The graph running of control sink. if (IsCtrlSink() && backend == kMsConvert) { if (!res->results()[kOutput].is()) { MS_LOG(EXCEPTION) << "Execute args error"; diff --git a/mindspore/ccsrc/runtime/framework/actor/kernel_actor.cc b/mindspore/ccsrc/runtime/framework/actor/kernel_actor.cc index 939b5a2bd6..383458bc6e 100644 --- a/mindspore/ccsrc/runtime/framework/actor/kernel_actor.cc +++ b/mindspore/ccsrc/runtime/framework/actor/kernel_actor.cc @@ -25,7 +25,7 @@ void KernelActor::RunOpData(OpDataPtr input_data, OpContextsequential_num_; input_op_datas_[sequential_num].emplace_back(input_data); - // When all the input data are collected, then allocate memory and callback launch. + // When all the inputs are collected, then allocate memory and callback launch. 
if (CheckLaunchCondition(context)) { FetchInputDeviceTensor(context); FetchOutputDeviceTensor(); @@ -38,7 +38,7 @@ void KernelActor::RunOpControl(AID *input_control, OpContext *cont MS_EXCEPTION_IF_NULL(context); auto sequential_num = context->sequential_num_; input_op_controls_[sequential_num].emplace_back(input_control); - // When all the input data are collected, then allocate memory and callback launch. + // When all the inputs are collected, then allocate memory and callback launch. if (CheckLaunchCondition(context)) { FetchInputDeviceTensor(context); FetchOutputDeviceTensor(); diff --git a/mindspore/ccsrc/runtime/framework/host_tensor_queue.h b/mindspore/ccsrc/runtime/framework/host_tensor_queue.h index 5618f5df2f..ae9284600e 100644 --- a/mindspore/ccsrc/runtime/framework/host_tensor_queue.h +++ b/mindspore/ccsrc/runtime/framework/host_tensor_queue.h @@ -26,7 +26,8 @@ namespace mindspore { namespace runtime { using mindspore::tensor::TensorPtr; -// Host tensor queue is used to store host tensors, and its data will be fetched by the host queue data source actor. +// Host tensor queue is used to store host tensors (such as non-weighted parameters of graph), and its data will be +// fetched by the host queue data source actor. 
class HostTensorQueue { public: HostTensorQueue() = default; diff --git a/mindspore/ccsrc/vm/CMakeLists.txt b/mindspore/ccsrc/vm/CMakeLists.txt index 9031c1515d..600a589df2 100644 --- a/mindspore/ccsrc/vm/CMakeLists.txt +++ b/mindspore/ccsrc/vm/CMakeLists.txt @@ -1,3 +1,5 @@ +include_directories(${CMAKE_SOURCE_DIR}/mindspore/core/mindrt/include) + file(GLOB_RECURSE _VM_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") set_property(SOURCE ${_VM_SRC_LIST} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_VM) add_library(_mindspore_vm_obj OBJECT ${_VM_SRC_LIST}) diff --git a/mindspore/ccsrc/vm/backend.cc b/mindspore/ccsrc/vm/backend.cc index 689aff264d..fa8c257a2e 100644 --- a/mindspore/ccsrc/vm/backend.cc +++ b/mindspore/ccsrc/vm/backend.cc @@ -26,6 +26,9 @@ #include "utils/convert_utils.h" #include "utils/log_adapter.h" #include "utils/ms_utils.h" +#include "runtime/hardware/device_context_manager.h" +#include "runtime/framework/graph_compiler.h" +#include "runtime/framework/graph_scheduler.h" #ifdef ENABLE_GE #include "utils/callbacks_ge.h" #endif @@ -221,8 +224,58 @@ void MsBackend::ClearSessionGraphs() { target_sess_->ClearGraph(); } } + #ifdef ENABLE_DEBUGGER void MsBackend::SetDebugger() { target_sess_->SetDebugger(); } #endif + +MindRTBackend::MindRTBackend(const std::string &backend_name, const std::string &device_name, uint32_t device_id) + : Backend(backend_name), device_name_(device_name), device_id_(device_id) {} + +GraphId MindRTBackend::CompileGraph(const AnfNodePtrList &nodes) { + // Get and set the device context. + const auto &cur_device_name = GetCNodeTarget(nodes[0]); + const auto &device_context = + device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext({cur_device_name, device_id_}); + runtime::GraphCompiler::GetInstance().set_device_context(device_context); + + // Transform nodes to inputs and outputs. 
+ FuncGraphPtr fg; + AnfNodePtrList inputs; + AnfNodePtrList outputs; + std::tie(fg, inputs, outputs) = TransformSegmentToAnfGraph(nodes); + + // Compile graph. + return runtime::GraphCompiler::GetInstance().CompileGraph(inputs, outputs); +} + +VectorRef MindRTBackend::RunGraph(GraphId graph_id, const VectorRef &args) { + const auto &context_ptr = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context_ptr); + if (context_ptr->get_param(MS_CTX_PRECOMPILE_ONLY)) { + MS_LOG(INFO) << "PrecompileOnly, stop run graph"; + return VectorRef(); + } + + // Transform args to input tensors. + std::vector inputs; + for (const auto &arg : args) { + PushInputTensor(arg, &inputs); + } + + // Fetch the kernel graph. + const auto &kernel_graph = runtime::GraphCompiler::GetInstance().Fetch(graph_id); + MS_EXCEPTION_IF_NULL(kernel_graph); + + // Fetch the actor DAG. + const auto &actor_set = runtime::GraphScheduler::GetInstance().Fetch(kernel_graph); + MS_EXCEPTION_IF_NULL(actor_set); + + // Run the actor DAG. TODO: outputs are not populated yet, pending a GraphScheduler interface to create them. 
+ VectorRef outputs; + runtime::GraphScheduler::GetInstance().Run(actor_set); + + return outputs; +} } // namespace compile } // namespace mindspore diff --git a/mindspore/ccsrc/vm/backend.h b/mindspore/ccsrc/vm/backend.h index 45f1d7f996..07876df33f 100644 --- a/mindspore/ccsrc/vm/backend.h +++ b/mindspore/ccsrc/vm/backend.h @@ -21,6 +21,7 @@ #include #include #include +#include #include "utils/contract.h" #include "ir/anf.h" @@ -31,6 +32,8 @@ namespace mindspore { namespace compile { +using OpRunInfo = session::OpRunInfo; + enum SwitchCondStatus { kCondOk = 0, kCondAlreadyRun, @@ -85,6 +88,27 @@ class MsBackend : public Backend { std::string other_device_; std::unordered_map graph_id_map_; }; + +class MindRTBackend : public Backend { + public: + MindRTBackend(const std::string &backend_name, const std::string &device_name, uint32_t device_id); + ~MindRTBackend() override = default; + + // Compile kernel graph from anf nodes list in the graph mode. + GraphId CompileGraph(const AnfNodePtrList &nodes); + // Compile single op kernel graph in the pyNative mode. + GraphId CompileGraph(const OpRunInfo &op_run_info, const GraphInfo &graph_info, + const std::vector &input_tensors, const std::vector &tensors_mask); + + // Run Graph in the graph mode. + VectorRef RunGraph(GraphId graph_id, const VectorRef &args); + // Run Graph in the pyNative mode. 
+ VectorRef RunGraph(const GraphInfo &graph_info, const VectorRef &args); + + private: + std::string device_name_; + uint32_t device_id_; +}; } // namespace compile } // namespace mindspore #endif diff --git a/mindspore/ccsrc/vm/transform.cc b/mindspore/ccsrc/vm/transform.cc index 11f090de1d..aada50449e 100644 --- a/mindspore/ccsrc/vm/transform.cc +++ b/mindspore/ccsrc/vm/transform.cc @@ -521,6 +521,73 @@ FinalVMPtr CompileGraphs::CompileAndLink(const FuncGraphPtr &graph) { return rt; } +GraphCompiler::GraphCompiler(const std::shared_ptr &backend, const std::vector &cut_list) + : backend_(backend) { + MS_EXCEPTION_IF_NULL(backend_); + if (backend_ == nullptr) { + MS_LOG(ERROR) << "The backend isn't created."; + return; + } + graph_partition_ = std::make_shared(cut_list, backend->name()); +} + +uint32_t GraphCompiler::CompileGraphs(const FuncGraphPtr &func_graph) { + MS_EXCEPTION_IF_NULL(func_graph); + FuncGraphPtr root_graph = WrapPrimitives(func_graph); + MS_EXCEPTION_IF_NULL(root_graph); + + // Compile root graph. + auto root_graph_id = CompileGraph(root_graph); + + // Compile sub graphs. + FuncGraphSet sub_graphs = root_graph->manager()->func_graphs(); + for (auto sub_graph : sub_graphs) { + if (sub_graph != func_graph && sub_graph != nullptr && !(sub_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL))) { + (void)CompileGraph(sub_graph); + } + } + + return root_graph_id; +} + +uint32_t GraphCompiler::CompileGraph(const FuncGraphPtr &func_graph) { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(graph_partition_); + MS_EXCEPTION_IF_NULL(backend_); + + // Split the graph into segments. + const auto &segments = graph_partition_->Partition(func_graph); + MS_LOG(INFO) << "Compile graph: " << func_graph->ToString() << ", Split segments size:" << segments.size(); + + // Iterate over the segments to compile the graph. 
+ std::vector graph_ids; + for (const auto &segment : segments) { + MS_EXCEPTION_IF_NULL(segment); + // Compile the normal nodes, which don't contain the cut node. + if (!segment->is_cut_) { + if (segment->nodes_.size() == 0) { + MS_LOG(EXCEPTION) << "The segments size is 0."; + } + MS_LOG(INFO) << "Compile normal segment, the first node: " << segment->nodes_[0]->fullname_with_scope(); + + // Compile the anfNodes list to kernelGraph, return the graph id of kernelGraph. + auto graph_id = backend_->CompileGraph(segment->nodes_); + graph_ids.emplace_back(graph_id); + } else { + // Compile the cut node. + auto cut_node = segment->nodes_[0]; + MS_EXCEPTION_IF_NULL(cut_node); + MS_LOG(INFO) << "Compile cut segment, the cut node: " << cut_node->fullname_with_scope(); + } + } + + return graph_ids[0]; +} + +// Judge whether to use mindRT. GPU and CPU use mindRT currently, and other hardware will use it in the future. +// Return false in the transitional stage. +bool IsMindRTUsed() { return false; } + BackendPtr CreateBackend() { auto context_ptr = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(context_ptr); @@ -533,7 +600,13 @@ BackendPtr CreateBackend() { if (name == kMsConvert) { std::string target = context_ptr->get_param(MS_CTX_DEVICE_TARGET); uint32_t device_id = context_ptr->get_param(MS_CTX_DEVICE_ID); - auto backend = std::make_shared(name, target, device_id); + BackendPtr backend = nullptr; + // Create MindRTBackend or MsBackend according to whether mindrt is used. 
+ if (IsMindRTUsed()) { + backend = std::make_shared(name, target, device_id); + } else { + backend = std::make_shared(name, target, device_id); + } std::string device_target = MsContext::GetInstance()->get_param(MS_CTX_DEVICE_TARGET); if (device_target == kAscendDevice) { if (MsContext::GetInstance()->get_param(MS_CTX_EXECUTION_MODE) == kPynativeMode) { diff --git a/mindspore/ccsrc/vm/transform.h b/mindspore/ccsrc/vm/transform.h index 249ae94d58..f4d098b5e0 100644 --- a/mindspore/ccsrc/vm/transform.h +++ b/mindspore/ccsrc/vm/transform.h @@ -131,6 +131,30 @@ class CompileGraphs { BackendPtr backend_; }; +// The graph compiler used with mindRT, which transforms the funcGraph to kernelGraph and returns the graph id of +// kernelGraph. +class GraphCompiler { + public: + GraphCompiler(const std::shared_ptr &backend, + const std::vector &cut_list = nonlinear_ops); + ~GraphCompiler() = default; + + // The parameter root_graph is a root graph, and the root graph may contain multiple sub graphs, + // the return is the kernelGraph id of the root graph. It will traverse all subgraphs to call CompileGraph. + uint32_t CompileGraphs(const FuncGraphPtr &root_graph); + + private: + // The parameter func_graph is a graph, it can be either a root graph or a sub graph, + // the return is the corresponding kernelGraph id of the graph. + uint32_t CompileGraph(const FuncGraphPtr &func_graph); + + std::shared_ptr backend_; + GraphPartitionPtr graph_partition_; +}; + +// Judge whether to use mindRT. GPU and CPU use mindRT currently, and other hardware will use it in the future. +bool IsMindRTUsed(); + BackendPtr CreateBackend(); } // namespace compile