!8581 add graph dependency

From: @kisnwang
Reviewed-by: 
Signed-off-by:
pull/8581/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit a6679511ed

@ -106,7 +106,11 @@ void BuildGraphTask::Run() {
void RunGraphTask::Run() {
MS_EXCEPTION_IF_NULL(session_);
try {
auto graph = session_->GetGraph(graph_id_);
MS_EXCEPTION_IF_NULL(graph);
graph->ResetGraphRunningStatus();
session_->RunGraphImpl(graph_id_, input_tensors_, &outputs_);
graph->OnRunGraphFinished();
UpdateOutputTensors(&outputs_, tensor_to_node_);
} catch (const std::exception &e) {
MsException::GetInstance().SetException();
@ -205,6 +209,7 @@ void Executor::OnRunGraphFinished() {
if (new_ready_tasks.size() > 0) {
task_cond_var_.notify_all();
}
reenter_cond_var_.notify_all();
}
bool Executor::IsTaskReady(const std::shared_ptr<RunGraphTask> &task) {
@ -215,6 +220,12 @@ bool Executor::IsTaskReady(const std::shared_ptr<RunGraphTask> &task) {
return false;
}
}
auto session = task->session_;
MS_EXCEPTION_IF_NULL(session);
auto graph = session->GetGraph(task->graph_id_);
if (graph != nullptr) {
return graph->IsPreGraphFinished();
}
return true;
}
@ -300,6 +311,14 @@ void Executor::RunGraphAsync(const SessionPtr &session, const GraphId &graph_id,
SyncRunTask(task);
return;
}
auto graph = session->GetGraph(task->graph_id_);
if (graph != nullptr) {
if (!graph->IsPostGraphFinished()) {
mindspore::ScopedLongRunning long_running;
std::unique_lock<std::mutex> lock(reenter_mutex_);
reenter_cond_var_.wait(lock, [graph] { return graph->IsPostGraphFinished(); });
}
}
bool ready = IsTaskReady(task);
if (!ready) {

@ -179,8 +179,10 @@ class Executor {
std::string device_name_;
std::mutex task_mutex_;
std::mutex pending_task_mutex_;
std::mutex reenter_mutex_;
std::condition_variable task_cond_var_;
std::condition_variable sync_cond_var_;
std::condition_variable reenter_cond_var_;
std::queue<std::shared_ptr<Task>> ready_tasks_;
std::list<std::shared_ptr<RunGraphTask>> pending_tasks_;
std::vector<std::shared_ptr<Task>> done_tasks_;

@ -17,15 +17,16 @@
#define MINDSPORE_CCSRC_BACKEND_SESSION_KERNEL_GRAPH_H
#include <vector>
#include <unordered_map>
#include <memory>
#include <utility>
#include <string>
#include <queue>
#include <map>
#include <unordered_map>
#include <set>
#include <stack>
#include <unordered_set>
#include <stack>
#include <atomic>
#include "ir/func_graph.h"
#include "ir/anf.h"
#include "ir/graph_utils.h"
@ -50,6 +51,51 @@ class KernelGraph : public FuncGraph {
summary_node_exist_ = false;
stream_distinction_label_ = kInvalidDistincLabel;
}
// Copy constructor. The FuncGraph base is copied through FuncGraph(graph); every KernelGraph member listed
// below is then assigned from the corresponding member of `graph`.
KernelGraph(const KernelGraph &graph) : FuncGraph(graph) {
// inputs_ is a std::shared_ptr<std::vector<AnfNodePtr>>, so this shares the input vector with `graph`
// rather than duplicating it.
inputs_ = graph.inputs_;
child_graph_result_ = graph.child_graph_result_;
execution_order_ = graph.execution_order_;
graph_id_ = graph.graph_id_;
stream_distinction_label_ = graph.stream_distinction_label_;
front_backend_anf_map_ = graph.front_backend_anf_map_;
backend_front_anf_map_ = graph.backend_front_anf_map_;
tensor_to_value_node_map_ = graph.tensor_to_value_node_map_;
graph_value_nodes_ = graph.graph_value_nodes_;
node_input_num_ = graph.node_input_num_;
node_input_edges_ = graph.node_input_edges_;
ref_out_in_map_ = graph.ref_out_in_map_;
node_output_edges_ = graph.node_output_edges_;
summary_nodes_ = graph.summary_nodes_;
executable_ = graph.executable_;
summary_node_exist_ = graph.summary_node_exist_;
valid_inputs_ = graph.valid_inputs_;
child_graph_order_ = graph.child_graph_order_;
input_ctrl_tensors_ = graph.input_ctrl_tensors_;
parent_graph_ = graph.parent_graph_;
start_label_ = graph.start_label_;
end_goto_ = graph.end_goto_;
null_output_ = graph.null_output_;
front_to_internal_outputs_map_ = graph.front_to_internal_outputs_map_;
internal_outputs_to_front_map_ = graph.internal_outputs_to_front_map_;
internal_outputs_tensor_map_ = graph.internal_outputs_tensor_map_;
current_epoch_ = graph.current_epoch_;
tuple_parameter_to_make_tuple_map_ = graph.tuple_parameter_to_make_tuple_map_;
visited_nodes_ = graph.visited_nodes_;
edge_to_ = graph.edge_to_;
loop_nodes_ = graph.loop_nodes_;
input_nodes_ = graph.input_nodes_;
// Graph dependency state. pre_graphs_/post_graphs_ map a graph id to a std::weak_ptr, so the copy refers to
// the same neighbouring graph objects as `graph`.
pre_graphs_ = graph.pre_graphs_;
post_graphs_ = graph.post_graphs_;
// The finished counters are std::atomic<size_t>, which has no copy assignment. Each value is first read into
// a plain size_t and then stored into this object's counter.
size_t pre_graph_finished_count = graph.pre_graph_finished_count_;
pre_graph_finished_count_ = pre_graph_finished_count;
size_t post_graph_finished_count = graph.post_graph_finished_count_;
post_graph_finished_count_ = post_graph_finished_count;
first_step_ = graph.first_step_;
has_optimizer_ = graph.has_optimizer_;
is_dynamic_shape_ = graph.is_dynamic_shape_;
}
~KernelGraph() override;
MS_DECLARE_PARENT(KernelGraph, FuncGraph);
@ -189,6 +235,47 @@ class KernelGraph : public FuncGraph {
void SetInputNodes();
const std::vector<AnfNodePtr> &input_nodes() const { return input_nodes_; }
bool has_optimizer() const { return has_optimizer_; }
// handle graph dependency
// Registers `graph` as a predecessor of this graph, keyed by its graph id.
// A null `graph` is ignored. An existing entry with the same id is overwritten.
void AddPreGraph(const std::shared_ptr<session::KernelGraph> &graph) {
  if (graph == nullptr) {
    return;
  }
  auto &pre_graph_slot = pre_graphs_[graph->graph_id()];
  pre_graph_slot = graph;
}
// Registers `graph` as a successor of this graph, keyed by its graph id.
// A null `graph` is ignored. An existing entry with the same id is overwritten.
void AddPostGraph(const std::shared_ptr<session::KernelGraph> &graph) {
  if (graph == nullptr) {
    return;
  }
  auto &post_graph_slot = post_graphs_[graph->graph_id()];
  post_graph_slot = graph;
}
// True when the pre-graph finished counter equals the number of registered pre graphs.
bool IsPreGraphFinished() {
  const size_t finished_count = pre_graph_finished_count_.load();
  return pre_graphs_.size() == finished_count;
}
// True while first_step_ is still set; otherwise true when the post-graph finished counter equals the
// number of registered post graphs.
bool IsPostGraphFinished() {
  return first_step_ || (post_graphs_.size() == post_graph_finished_count_.load());
}
// Atomically adds one to the pre-graph finished counter.
void IncPreGraphFinishedCount() {
  (void)pre_graph_finished_count_.fetch_add(1);
}
// Atomically adds one to the post-graph finished counter.
void IncPostGraphFinishedCount() {
  (void)post_graph_finished_count_.fetch_add(1);
}
// Clears the first-step flag and zeroes both finished counters (post first, then pre).
void ResetGraphRunningStatus() {
  first_step_ = false;
  post_graph_finished_count_.store(0);
  pre_graph_finished_count_.store(0);
}
// Tells the neighbouring graphs that this graph has finished a run:
//  - every still-alive post graph gets its pre-graph finished counter incremented;
//  - every still-alive pre graph gets its post-graph finished counter incremented.
// Neighbours are held as weak_ptr; entries whose graph has already been destroyed are skipped.
void OnRunGraphFinished() {
  // Iterate by const reference: iterating by value would copy each (id, weak_ptr) pair, which costs an
  // extra weak reference count update per entry for no benefit.
  for (const auto &post_graph : post_graphs_) {
    auto post_graph_ptr = post_graph.second.lock();
    if (post_graph_ptr != nullptr) {
      post_graph_ptr->IncPreGraphFinishedCount();
    }
  }
  for (const auto &pre_graph : pre_graphs_) {
    auto pre_graph_ptr = pre_graph.second.lock();
    if (pre_graph_ptr != nullptr) {
      pre_graph_ptr->IncPostGraphFinishedCount();
    }
  }
}
// end of handle graph dependency
private:
// remove value node from graph
@ -218,6 +305,7 @@ class KernelGraph : public FuncGraph {
uint32_t GetLoopNum(std::map<AnfNodePtr, size_t> none_zero_nodes);
void GetLoopNodesByDFS(AnfNodePtr node, uint32_t *loop_num);
// members
std::shared_ptr<std::vector<AnfNodePtr>> inputs_;
std::vector<AnfNodePtr> child_graph_result_;
std::vector<CNodePtr> execution_order_;
@ -265,6 +353,11 @@ class KernelGraph : public FuncGraph {
std::map<AnfNodePtr, AnfNodePtr> edge_to_;
std::stack<AnfNodePtr> loop_nodes_;
std::vector<AnfNodePtr> input_nodes_;
std::unordered_map<uint32_t, std::weak_ptr<session::KernelGraph>> pre_graphs_;
std::unordered_map<uint32_t, std::weak_ptr<session::KernelGraph>> post_graphs_;
std::atomic<size_t> pre_graph_finished_count_{0};
std::atomic<size_t> post_graph_finished_count_{0};
bool first_step_{true};
bool has_optimizer_{false};
bool is_dynamic_shape_{false};
};

@ -358,7 +358,7 @@ GraphId SessionBasic::GetGraphIdByNode(const AnfNodePtr &front_anf) const {
KernelGraphPtr SessionBasic::GetGraph(mindspore::GraphId graph_id) const {
auto it = graphs_.find(graph_id);
if (it == graphs_.end()) {
MS_LOG(WARNING) << "Can't find graph " << graph_id;
MS_LOG(INFO) << "Can't find graph " << graph_id;
return nullptr;
}
return it->second;

@ -57,11 +57,25 @@ LinConvertResult MsBackend::MsConvert(const GraphSegmentPtr &segment, const std:
result.outputs = outputs;
result.graph_id = kInvalidGraphId;
GraphId graph_id = kInvalidGraphId;
auto current_session = target_sess_;
if (target != target_device_ && !target.empty()) {
CreateOtherSession(target);
graph_id = other_sess_->CompileGraph(segment, outputs);
} else {
graph_id = target_sess_->CompileGraph(segment, outputs);
current_session = other_sess_;
}
MS_EXCEPTION_IF_NULL(current_session);
graph_id = current_session->CompileGraph(segment, outputs);
segment->graph_id_ = graph_id;
auto graph = current_session->GetGraph(graph_id);
MS_EXCEPTION_IF_NULL(graph);
for (auto &pre_segment : segment->pre_segments_) {
MS_EXCEPTION_IF_NULL(pre_segment);
auto pre_graph = target_sess_->GetGraph(pre_segment->graph_id_);
if (pre_graph == nullptr) {
pre_graph = other_sess_->GetGraph(pre_segment->graph_id_);
}
MS_EXCEPTION_IF_NULL(pre_graph);
pre_graph->AddPostGraph(graph);
graph->AddPreGraph(pre_graph);
}
if (MsContext::GetInstance()->get_param<bool>(MS_CTX_PRECOMPILE_ONLY)) {

@ -246,6 +246,55 @@ std::vector<AnfNodePtr> SplitSort(const FuncGraphPtr &graph, const std::string &
return result;
}
// Records, for every non-cut segment, which other non-cut segments produce its inputs.
//
// The graph is walked backwards from its return node. For each CNode that belongs to a non-cut segment, every
// CNode input (its own inputs plus any control-edge inputs found for it in `control_edges`) that belongs to a
// different non-cut segment is registered on the node's segment through AddPreSegment.
//
// graph            graph to walk; traversal starts at graph->get_return().
// default_target   not referenced by this function.
// node_to_segment  node -> owning segment lookup; a node missing from the map is treated as having no segment.
void AddSegmentDependency(const FuncGraphPtr &graph, const std::string &default_target,
                          const std::map<AnfNodePtr, GraphSegmentPtr> &node_to_segment) {
  std::stack<AnfNodePtr> to_visit;
  std::map<AnfNodePtr, size_t> nodes_ref;
  std::map<AnfNodePtr, std::vector<AnfNodePtr>> control_edges;
  // NOTE(review): CalcNodeRefCount is defined elsewhere; presumably it fills nodes_ref with per-node use counts
  // and control_edges with extra control dependencies - confirm against its definition.
  CalcNodeRefCount(graph, &nodes_ref, &control_edges);
  to_visit.push(graph->get_return());
  while (!to_visit.empty()) {
    // Take a copy, not a reference: pop() destroys the stack element, so a reference obtained from top()
    // would dangle for the rest of this iteration.
    auto node = to_visit.top();
    MS_EXCEPTION_IF_NULL(node);
    to_visit.pop();
    if (!node->isa<CNode>()) {
      continue;
    }
    auto cnode = node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    // Work on a copy of the inputs so the control-edge inputs can be appended without touching the node.
    auto node_inputs = cnode->inputs();
    auto ctrl_inputs = control_edges.find(node);
    if (ctrl_inputs != control_edges.end()) {
      node_inputs.insert(node_inputs.end(), ctrl_inputs->second.begin(), ctrl_inputs->second.end());
    }
    GraphSegmentPtr node_segment{nullptr};
    auto node_iter = node_to_segment.find(node);
    if (node_iter != node_to_segment.end()) {
      node_segment = node_iter->second;
    }
    for (auto &input : node_inputs) {
      if (node_segment != nullptr && !node_segment->is_cut_ && input->isa<CNode>()) {
        GraphSegmentPtr input_segment{nullptr};
        auto input_iter = node_to_segment.find(input);
        if (input_iter != node_to_segment.end()) {
          input_segment = input_iter->second;
        }
        if (input_segment != nullptr && input_segment != node_segment && !input_segment->is_cut_) {
          node_segment->AddPreSegment(input_segment);
        }
      }
      // An input tracked in nodes_ref is only queued once its remaining count drops to zero; an untracked
      // input is queued immediately.
      auto ref_iter = nodes_ref.find(input);
      if (ref_iter != nodes_ref.end()) {
        ref_iter->second--;
        if (ref_iter->second != 0) {
          continue;
        }
      }
      to_visit.push(input);
    }
  }
}
std::vector<AnfNodePtr> ParallelSplitSort(const FuncGraphPtr &graph, const std::string &default_target) {
std::vector<AnfNodePtr> result;
std::stack<AnfNodePtr> handle_nodes;
@ -404,10 +453,10 @@ std::vector<GraphSegmentPtr> GraphPartition::Partition(const FuncGraphPtr &graph
auto nodes = TopoSort(graph->get_return());
MS_LOG(DEBUG) << "Split all nodes size:" << nodes.size();
bool contain_multi_target = ContainMultiTarget(nodes);
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
std::string default_target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
if (contain_multi_target) {
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
std::string default_target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
if (graph != nullptr) {
nodes = SplitSort(graph, default_target);
} else {
@ -417,15 +466,22 @@ std::vector<GraphSegmentPtr> GraphPartition::Partition(const FuncGraphPtr &graph
}
std::vector<GraphSegmentPtr> segments;
std::vector<AnfNodePtr> segment_nodes;
std::map<AnfNodePtr, GraphSegmentPtr> node_to_segment;
auto new_segment = [&segments, &segment_nodes, &node_to_segment]() {
if (segment_nodes.size() != 0) {
auto segment = std::make_shared<GraphSegment>(segment_nodes, false);
segments.emplace_back(segment);
for (auto node : segment_nodes) {
node_to_segment[node] = segment;
}
segment_nodes.clear();
}
};
std::string last_target;
for (auto &node : nodes) {
MS_EXCEPTION_IF_NULL(node);
if (IsCut(node)) {
if (segment_nodes.size() != 0) {
auto segment = std::make_shared<GraphSegment>(segment_nodes, false);
segments.emplace_back(segment);
segment_nodes.clear();
}
new_segment();
segment_nodes.emplace_back(node);
auto segment = std::make_shared<GraphSegment>(segment_nodes, true);
segments.push_back(segment);
@ -433,10 +489,8 @@ std::vector<GraphSegmentPtr> GraphPartition::Partition(const FuncGraphPtr &graph
} else if (node->isa<CNode>()) {
if (contain_multi_target) {
std::string cur_target = GetCNodeTarget(node);
if (cur_target != last_target && !last_target.empty() && segment_nodes.size() != 0) {
auto segment = std::make_shared<GraphSegment>(segment_nodes, false);
segments.emplace_back(segment);
segment_nodes.clear();
if (cur_target != last_target && !last_target.empty()) {
new_segment();
}
last_target = cur_target;
}
@ -444,6 +498,9 @@ std::vector<GraphSegmentPtr> GraphPartition::Partition(const FuncGraphPtr &graph
}
}
MS_LOG(DEBUG) << "Segment size:" << segments.size();
if (contain_multi_target) {
AddSegmentDependency(graph, default_target, node_to_segment);
}
return segments;
}
} // namespace compile

@ -25,6 +25,7 @@
#include <memory>
#include <unordered_map>
#include <utility>
#include <set>
#include "base/base.h"
#include "base/user_data.h"
@ -490,8 +491,11 @@ std::string GetCNodeTarget(const AnfNodePtr &node);
bool ContainMultiTarget(const std::vector<AnfNodePtr> &nodes);
// A group of nodes together with the set of segments registered as its predecessors.
struct GraphSegment {
  // Stores a copy of `nodes` and the cut flag; pre_segments_ starts empty and graph_id_ starts at 0.
  GraphSegment(const std::vector<AnfNodePtr> &nodes, bool is_cut) : nodes_(nodes), is_cut_(is_cut) {}

  // Adds `segment` to the predecessor set. Inserting a segment that is already present has no effect.
  void AddPreSegment(const std::shared_ptr<GraphSegment> &segment) {
    (void)pre_segments_.insert(segment);
  }

  // Nodes that make up this segment.
  std::vector<AnfNodePtr> nodes_;
  // Segments registered through AddPreSegment.
  std::set<std::shared_ptr<GraphSegment>> pre_segments_;
  // Cut flag supplied at construction.
  bool is_cut_{false};
  // Graph id associated with this segment; 0 until assigned.
  uint32_t graph_id_{0};
};
using GraphSegmentPtr = std::shared_ptr<GraphSegment>;
} // namespace mindspore

Loading…
Cancel
Save