From 1ce4bebbf2d43766bae2bfaf490c4182283d4be9 Mon Sep 17 00:00:00 2001 From: kswang Date: Mon, 16 Nov 2020 16:29:36 +0800 Subject: [PATCH] add subgraph dependency --- mindspore/ccsrc/backend/session/executor.cc | 19 ++++ mindspore/ccsrc/backend/session/executor.h | 2 + .../ccsrc/backend/session/kernel_graph.h | 97 ++++++++++++++++++- .../ccsrc/backend/session/session_basic.cc | 2 +- mindspore/ccsrc/vm/backend.cc | 20 +++- mindspore/ccsrc/vm/graph_partition.cc | 81 +++++++++++++--- mindspore/core/ir/anf.h | 4 + 7 files changed, 207 insertions(+), 18 deletions(-) diff --git a/mindspore/ccsrc/backend/session/executor.cc b/mindspore/ccsrc/backend/session/executor.cc index c732513870..9da6390152 100644 --- a/mindspore/ccsrc/backend/session/executor.cc +++ b/mindspore/ccsrc/backend/session/executor.cc @@ -106,7 +106,11 @@ void BuildGraphTask::Run() { void RunGraphTask::Run() { MS_EXCEPTION_IF_NULL(session_); try { + auto graph = session_->GetGraph(graph_id_); + MS_EXCEPTION_IF_NULL(graph); + graph->ResetGraphRunningStatus(); session_->RunGraphImpl(graph_id_, input_tensors_, &outputs_); + graph->OnRunGraphFinished(); UpdateOutputTensors(&outputs_, tensor_to_node_); } catch (const std::exception &e) { MsException::GetInstance().SetException(); @@ -205,6 +209,7 @@ void Executor::OnRunGraphFinished() { if (new_ready_tasks.size() > 0) { task_cond_var_.notify_all(); } + reenter_cond_var_.notify_all(); } bool Executor::IsTaskReady(const std::shared_ptr &task) { @@ -215,6 +220,12 @@ bool Executor::IsTaskReady(const std::shared_ptr &task) { return false; } } + auto session = task->session_; + MS_EXCEPTION_IF_NULL(session); + auto graph = session->GetGraph(task->graph_id_); + if (graph != nullptr) { + return graph->IsPreGraphFinished(); + } return true; } @@ -300,6 +311,14 @@ void Executor::RunGraphAsync(const SessionPtr &session, const GraphId &graph_id, SyncRunTask(task); return; } + auto graph = session->GetGraph(task->graph_id_); + if (graph != nullptr) { + if (!graph->IsPostGraphFinished()) { + mindspore::ScopedLongRunning long_running; + std::unique_lock lock(reenter_mutex_); + reenter_cond_var_.wait(lock, [graph] { return graph->IsPostGraphFinished(); }); + } + } bool ready = IsTaskReady(task); if (!ready) { diff --git a/mindspore/ccsrc/backend/session/executor.h b/mindspore/ccsrc/backend/session/executor.h index d57193f8c6..48fc7553e8 100644 --- a/mindspore/ccsrc/backend/session/executor.h +++ b/mindspore/ccsrc/backend/session/executor.h @@ -179,8 +179,10 @@ class Executor { std::string device_name_; std::mutex task_mutex_; std::mutex pending_task_mutex_; + std::mutex reenter_mutex_; std::condition_variable task_cond_var_; std::condition_variable sync_cond_var_; + std::condition_variable reenter_cond_var_; std::queue> ready_tasks_; std::list> pending_tasks_; std::vector> done_tasks_; diff --git a/mindspore/ccsrc/backend/session/kernel_graph.h b/mindspore/ccsrc/backend/session/kernel_graph.h index 3d552cc228..266dcb767c 100644 --- a/mindspore/ccsrc/backend/session/kernel_graph.h +++ b/mindspore/ccsrc/backend/session/kernel_graph.h @@ -17,15 +17,16 @@ #define MINDSPORE_CCSRC_BACKEND_SESSION_KERNEL_GRAPH_H #include -#include #include #include #include #include #include +#include #include -#include #include +#include +#include #include "ir/func_graph.h" #include "ir/anf.h" #include "ir/graph_utils.h" @@ -50,6 +51,51 @@ class KernelGraph : public FuncGraph { summary_node_exist_ = false; stream_distinction_label_ = kInvalidDistincLabel; } + + KernelGraph(const KernelGraph &graph) : FuncGraph(graph) { + inputs_ = graph.inputs_; + child_graph_result_ = graph.child_graph_result_; + execution_order_ = graph.execution_order_; + graph_id_ = graph.graph_id_; + stream_distinction_label_ = graph.stream_distinction_label_; + front_backend_anf_map_ = graph.front_backend_anf_map_; + backend_front_anf_map_ = graph.backend_front_anf_map_; + tensor_to_value_node_map_ = graph.tensor_to_value_node_map_; + graph_value_nodes_ = graph.graph_value_nodes_; + node_input_num_ = graph.node_input_num_; + node_input_edges_ = graph.node_input_edges_; + ref_out_in_map_ = graph.ref_out_in_map_; + node_output_edges_ = graph.node_output_edges_; + summary_nodes_ = graph.summary_nodes_; + executable_ = graph.executable_; + summary_node_exist_ = graph.summary_node_exist_; + valid_inputs_ = graph.valid_inputs_; + child_graph_order_ = graph.child_graph_order_; + input_ctrl_tensors_ = graph.input_ctrl_tensors_; + parent_graph_ = graph.parent_graph_; + start_label_ = graph.start_label_; + end_goto_ = graph.end_goto_; + null_output_ = graph.null_output_; + front_to_internal_outputs_map_ = graph.front_to_internal_outputs_map_; + internal_outputs_to_front_map_ = graph.internal_outputs_to_front_map_; + internal_outputs_tensor_map_ = graph.internal_outputs_tensor_map_; + current_epoch_ = graph.current_epoch_; + tuple_parameter_to_make_tuple_map_ = graph.tuple_parameter_to_make_tuple_map_; + visited_nodes_ = graph.visited_nodes_; + edge_to_ = graph.edge_to_; + loop_nodes_ = graph.loop_nodes_; + input_nodes_ = graph.input_nodes_; + pre_graphs_ = graph.pre_graphs_; + post_graphs_ = graph.post_graphs_; + size_t pre_graph_finished_count = graph.pre_graph_finished_count_; + pre_graph_finished_count_ = pre_graph_finished_count; + size_t post_graph_finished_count = graph.post_graph_finished_count_; + post_graph_finished_count_ = post_graph_finished_count; + first_step_ = graph.first_step_; + has_optimizer_ = graph.has_optimizer_; + is_dynamic_shape_ = graph.is_dynamic_shape_; + } + ~KernelGraph() override; MS_DECLARE_PARENT(KernelGraph, FuncGraph); @@ -189,6 +235,47 @@ class KernelGraph : public FuncGraph { void SetInputNodes(); const std::vector &input_nodes() const { return input_nodes_; } bool has_optimizer() const { return has_optimizer_; } + // handle graph dependency + void AddPreGraph(const std::shared_ptr &graph) { + if (graph != nullptr) { + pre_graphs_[graph->graph_id()] = graph; + } + } + void AddPostGraph(const std::shared_ptr &graph) { + if (graph != nullptr) { + post_graphs_[graph->graph_id()] = graph; + } + } + + bool IsPreGraphFinished() { return pre_graphs_.size() == pre_graph_finished_count_; } + bool IsPostGraphFinished() { + if (first_step_) { + return true; + } + return post_graphs_.size() == post_graph_finished_count_; + } + void IncPreGraphFinishedCount() { pre_graph_finished_count_++; } + void IncPostGraphFinishedCount() { post_graph_finished_count_++; } + void ResetGraphRunningStatus() { + first_step_ = false; + post_graph_finished_count_ = 0; + pre_graph_finished_count_ = 0; + } + void OnRunGraphFinished() { + for (auto post_graph : post_graphs_) { + auto post_graph_ptr = post_graph.second.lock(); + if (post_graph_ptr != nullptr) { + post_graph_ptr->IncPreGraphFinishedCount(); + } + } + for (auto pre_graph : pre_graphs_) { + auto pre_graph_ptr = pre_graph.second.lock(); + if (pre_graph_ptr != nullptr) { + pre_graph_ptr->IncPostGraphFinishedCount(); + } + } + } + // end of handle graph dependency private: // remove value node form graph @@ -218,6 +305,7 @@ class KernelGraph : public FuncGraph { uint32_t GetLoopNum(std::map none_zero_nodes); void GetLoopNodesByDFS(AnfNodePtr node, uint32_t *loop_num); + // members std::shared_ptr> inputs_; std::vector child_graph_result_; std::vector execution_order_; @@ -265,6 +353,11 @@ class KernelGraph : public FuncGraph { std::map edge_to_; std::stack loop_nodes_; std::vector input_nodes_; + std::unordered_map> pre_graphs_; + std::unordered_map> post_graphs_; + std::atomic pre_graph_finished_count_{0}; + std::atomic post_graph_finished_count_{0}; + bool first_step_{true}; bool has_optimizer_{false}; bool is_dynamic_shape_{false}; }; diff --git a/mindspore/ccsrc/backend/session/session_basic.cc b/mindspore/ccsrc/backend/session/session_basic.cc index d507bf56e2..79167d3625 100644 --- a/mindspore/ccsrc/backend/session/session_basic.cc +++ b/mindspore/ccsrc/backend/session/session_basic.cc @@ -344,7 +344,7 @@ void SessionBasic::InitDevice(const std::string &device_name, uint32_t device_id KernelGraphPtr SessionBasic::GetGraph(mindspore::GraphId graph_id) const { auto it = graphs_.find(graph_id); if (it == graphs_.end()) { - MS_LOG(WARNING) << "Can't find graph " << graph_id; + MS_LOG(INFO) << "Can't find graph " << graph_id; return nullptr; } return it->second; diff --git a/mindspore/ccsrc/vm/backend.cc b/mindspore/ccsrc/vm/backend.cc index 3ea63e41fc..961905d76b 100644 --- a/mindspore/ccsrc/vm/backend.cc +++ b/mindspore/ccsrc/vm/backend.cc @@ -57,11 +57,25 @@ LinConvertResult MsBackend::MsConvert(const GraphSegmentPtr &segment, const std: result.outputs = outputs; result.graph_id = kInvalidGraphId; GraphId graph_id = kInvalidGraphId; + auto current_session = target_sess_; if (target != target_device_ && !target.empty()) { CreateOtherSession(target); - graph_id = other_sess_->CompileGraph(segment, outputs); - } else { - graph_id = target_sess_->CompileGraph(segment, outputs); + current_session = other_sess_; + } + MS_EXCEPTION_IF_NULL(current_session); + graph_id = current_session->CompileGraph(segment, outputs); + segment->graph_id_ = graph_id; + auto graph = current_session->GetGraph(graph_id); + MS_EXCEPTION_IF_NULL(graph); + for (auto &pre_segment : segment->pre_segments_) { + MS_EXCEPTION_IF_NULL(pre_segment); + auto pre_graph = target_sess_->GetGraph(pre_segment->graph_id_); + if (pre_graph == nullptr) { + pre_graph = other_sess_->GetGraph(pre_segment->graph_id_); + } + MS_EXCEPTION_IF_NULL(pre_graph); + pre_graph->AddPostGraph(graph); + graph->AddPreGraph(pre_graph); } if (MsContext::GetInstance()->get_param(MS_CTX_PRECOMPILE_ONLY)) { diff --git a/mindspore/ccsrc/vm/graph_partition.cc b/mindspore/ccsrc/vm/graph_partition.cc index 8384ee190f..d98970124c 100644 --- a/mindspore/ccsrc/vm/graph_partition.cc +++ b/mindspore/ccsrc/vm/graph_partition.cc @@ -246,6 +246,55 @@ std::vector SplitSort(const FuncGraphPtr &graph, const std::string & return result; } +void AddSegmentDependency(const FuncGraphPtr &graph, const std::string &default_target, + const std::map &node_to_segment) { + std::stack to_visit; + std::map nodes_ref; + std::map> control_edges; + CalcNodeRefCount(graph, &nodes_ref, &control_edges); + to_visit.push(graph->get_return()); + while (!to_visit.empty()) { + auto &node = to_visit.top(); + MS_EXCEPTION_IF_NULL(node); + to_visit.pop(); + if (!node->isa()) { + continue; + } + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + auto node_inputs = cnode->inputs(); + auto ctrl_inputs = control_edges.find(node); + if (ctrl_inputs != control_edges.end()) { + node_inputs.insert(node_inputs.end(), ctrl_inputs->second.begin(), ctrl_inputs->second.end()); + } + GraphSegmentPtr node_segment{nullptr}; + auto node_iter = node_to_segment.find(node); + if (node_iter != node_to_segment.end()) { + node_segment = node_iter->second; + } + for (auto &input : node_inputs) { + if (node_segment != nullptr && !node_segment->is_cut_ && input->isa()) { + GraphSegmentPtr input_segment{nullptr}; + auto input_iter = node_to_segment.find(input); + if (input_iter != node_to_segment.end()) { + input_segment = input_iter->second; + } + if (input_segment != nullptr && input_segment != node_segment && !input_segment->is_cut_) { + node_segment->AddPreSegment(input_segment); + } + } + auto ref_iter = nodes_ref.find(input); + if (ref_iter != nodes_ref.end()) { + ref_iter->second--; + if (ref_iter->second != 0) { + continue; + } + } + to_visit.push(input); + } + } +} + std::vector ParallelSplitSort(const FuncGraphPtr &graph, const std::string &default_target) { std::vector result; std::stack handle_nodes; @@ -404,10 +453,10 @@ std::vector GraphPartition::Partition(const FuncGraphPtr &graph auto nodes = TopoSort(graph->get_return()); MS_LOG(DEBUG) << "Split all nodes size:" << nodes.size(); bool contain_multi_target = ContainMultiTarget(nodes); + auto context_ptr = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context_ptr); + std::string default_target = context_ptr->get_param(MS_CTX_DEVICE_TARGET); if (contain_multi_target) { - auto context_ptr = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context_ptr); - std::string default_target = context_ptr->get_param(MS_CTX_DEVICE_TARGET); if (graph != nullptr) { nodes = SplitSort(graph, default_target); } else { @@ -417,15 +466,22 @@ std::vector GraphPartition::Partition(const FuncGraphPtr &graph } std::vector segments; std::vector segment_nodes; + std::map node_to_segment; + auto new_segment = [&segments, &segment_nodes, &node_to_segment]() { + if (segment_nodes.size() != 0) { + auto segment = std::make_shared(segment_nodes, false); + segments.emplace_back(segment); + for (auto node : segment_nodes) { + node_to_segment[node] = segment; + } + segment_nodes.clear(); + } + }; std::string last_target; for (auto &node : nodes) { MS_EXCEPTION_IF_NULL(node); if (IsCut(node)) { - if (segment_nodes.size() != 0) { - auto segment = std::make_shared(segment_nodes, false); - segments.emplace_back(segment); - segment_nodes.clear(); - } + new_segment(); segment_nodes.emplace_back(node); auto segment = std::make_shared(segment_nodes, true); segments.push_back(segment); @@ -433,10 +489,8 @@ std::vector GraphPartition::Partition(const FuncGraphPtr &graph } else if (node->isa()) { if (contain_multi_target) { std::string cur_target = GetCNodeTarget(node); - if (cur_target != last_target && !last_target.empty() && segment_nodes.size() != 0) { - auto segment = std::make_shared(segment_nodes, false); - segments.emplace_back(segment); - segment_nodes.clear(); + if (cur_target != last_target && !last_target.empty()) { + new_segment(); } last_target = cur_target; } @@ -444,6 +498,9 @@ std::vector GraphPartition::Partition(const FuncGraphPtr &graph } } MS_LOG(DEBUG) << "Segment size:" << segments.size(); + if (contain_multi_target) { + AddSegmentDependency(graph, default_target, node_to_segment); + } return segments; } } // namespace compile diff --git a/mindspore/core/ir/anf.h b/mindspore/core/ir/anf.h index 2ad81d61f0..d4539c86dc 100644 --- a/mindspore/core/ir/anf.h +++ b/mindspore/core/ir/anf.h @@ -25,6 +25,7 @@ #include #include #include +#include #include "base/base.h" #include "base/user_data.h" @@ -485,8 +486,11 @@ std::string GetCNodeTarget(const AnfNodePtr &node); bool ContainMultiTarget(const std::vector &nodes); struct GraphSegment { GraphSegment(const std::vector &nodes, bool is_cut) : nodes_(nodes), is_cut_(is_cut) {} + void AddPreSegment(const std::shared_ptr &segment) { (void)pre_segments_.insert(segment); } std::vector nodes_; + std::set> pre_segments_; bool is_cut_{false}; + uint32_t graph_id_{0}; }; using GraphSegmentPtr = std::shared_ptr; } // namespace mindspore