|
|
|
@ -17,15 +17,16 @@
|
|
|
|
|
#define MINDSPORE_CCSRC_BACKEND_SESSION_KERNEL_GRAPH_H
|
|
|
|
|
|
|
|
|
|
#include <vector>
|
|
|
|
|
#include <unordered_map>
|
|
|
|
|
#include <memory>
|
|
|
|
|
#include <utility>
|
|
|
|
|
#include <string>
|
|
|
|
|
#include <queue>
|
|
|
|
|
#include <map>
|
|
|
|
|
#include <unordered_map>
|
|
|
|
|
#include <set>
|
|
|
|
|
#include <stack>
|
|
|
|
|
#include <unordered_set>
|
|
|
|
|
#include <stack>
|
|
|
|
|
#include <atomic>
|
|
|
|
|
#include "ir/func_graph.h"
|
|
|
|
|
#include "ir/anf.h"
|
|
|
|
|
#include "ir/graph_utils.h"
|
|
|
|
@ -50,6 +51,51 @@ class KernelGraph : public FuncGraph {
|
|
|
|
|
summary_node_exist_ = false;
|
|
|
|
|
stream_distinction_label_ = kInvalidDistincLabel;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
KernelGraph(const KernelGraph &graph) : FuncGraph(graph) {
|
|
|
|
|
inputs_ = graph.inputs_;
|
|
|
|
|
child_graph_result_ = graph.child_graph_result_;
|
|
|
|
|
execution_order_ = graph.execution_order_;
|
|
|
|
|
graph_id_ = graph.graph_id_;
|
|
|
|
|
stream_distinction_label_ = graph.stream_distinction_label_;
|
|
|
|
|
front_backend_anf_map_ = graph.front_backend_anf_map_;
|
|
|
|
|
backend_front_anf_map_ = graph.backend_front_anf_map_;
|
|
|
|
|
tensor_to_value_node_map_ = graph.tensor_to_value_node_map_;
|
|
|
|
|
graph_value_nodes_ = graph.graph_value_nodes_;
|
|
|
|
|
node_input_num_ = graph.node_input_num_;
|
|
|
|
|
node_input_edges_ = graph.node_input_edges_;
|
|
|
|
|
ref_out_in_map_ = graph.ref_out_in_map_;
|
|
|
|
|
node_output_edges_ = graph.node_output_edges_;
|
|
|
|
|
summary_nodes_ = graph.summary_nodes_;
|
|
|
|
|
executable_ = graph.executable_;
|
|
|
|
|
summary_node_exist_ = graph.summary_node_exist_;
|
|
|
|
|
valid_inputs_ = graph.valid_inputs_;
|
|
|
|
|
child_graph_order_ = graph.child_graph_order_;
|
|
|
|
|
input_ctrl_tensors_ = graph.input_ctrl_tensors_;
|
|
|
|
|
parent_graph_ = graph.parent_graph_;
|
|
|
|
|
start_label_ = graph.start_label_;
|
|
|
|
|
end_goto_ = graph.end_goto_;
|
|
|
|
|
null_output_ = graph.null_output_;
|
|
|
|
|
front_to_internal_outputs_map_ = graph.front_to_internal_outputs_map_;
|
|
|
|
|
internal_outputs_to_front_map_ = graph.internal_outputs_to_front_map_;
|
|
|
|
|
internal_outputs_tensor_map_ = graph.internal_outputs_tensor_map_;
|
|
|
|
|
current_epoch_ = graph.current_epoch_;
|
|
|
|
|
tuple_parameter_to_make_tuple_map_ = graph.tuple_parameter_to_make_tuple_map_;
|
|
|
|
|
visited_nodes_ = graph.visited_nodes_;
|
|
|
|
|
edge_to_ = graph.edge_to_;
|
|
|
|
|
loop_nodes_ = graph.loop_nodes_;
|
|
|
|
|
input_nodes_ = graph.input_nodes_;
|
|
|
|
|
pre_graphs_ = graph.pre_graphs_;
|
|
|
|
|
post_graphs_ = graph.post_graphs_;
|
|
|
|
|
size_t pre_graph_finished_count = graph.pre_graph_finished_count_;
|
|
|
|
|
pre_graph_finished_count_ = pre_graph_finished_count;
|
|
|
|
|
size_t post_graph_finished_count = graph.post_graph_finished_count_;
|
|
|
|
|
post_graph_finished_count_ = post_graph_finished_count;
|
|
|
|
|
first_step_ = graph.first_step_;
|
|
|
|
|
has_optimizer_ = graph.has_optimizer_;
|
|
|
|
|
is_dynamic_shape_ = graph.is_dynamic_shape_;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
~KernelGraph() override;
|
|
|
|
|
|
|
|
|
|
MS_DECLARE_PARENT(KernelGraph, FuncGraph);
|
|
|
|
@ -189,6 +235,47 @@ class KernelGraph : public FuncGraph {
|
|
|
|
|
void SetInputNodes();
|
|
|
|
|
// Read-only access to input_nodes_; the vector is returned by const reference, not copied.
const std::vector<AnfNodePtr> &input_nodes() const {
  return input_nodes_;
}
|
|
|
|
|
// Returns the current value of the has_optimizer_ flag.
bool has_optimizer() const {
  return has_optimizer_;
}
|
|
|
|
|
// handle graph dependency
|
|
|
|
|
// Records `graph` in pre_graphs_, keyed by graph->graph_id().
// A null `graph` is ignored. The map stores a weak_ptr, so this graph does not take
// ownership; an existing entry with the same id is replaced.
void AddPreGraph(const std::shared_ptr<session::KernelGraph> &graph) {
  if (graph == nullptr) {
    return;
  }
  pre_graphs_[graph->graph_id()] = graph;
}
|
|
|
|
|
// Records `graph` in post_graphs_, keyed by graph->graph_id().
// A null `graph` is ignored. The map stores a weak_ptr, so this graph does not take
// ownership; an existing entry with the same id is replaced.
void AddPostGraph(const std::shared_ptr<session::KernelGraph> &graph) {
  if (graph == nullptr) {
    return;
  }
  post_graphs_[graph->graph_id()] = graph;
}
|
|
|
|
|
|
|
|
|
|
bool IsPreGraphFinished() { return pre_graphs_.size() == pre_graph_finished_count_; }
|
|
|
|
|
bool IsPostGraphFinished() {
|
|
|
|
|
if (first_step_) {
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
return post_graphs_.size() == post_graph_finished_count_;
|
|
|
|
|
}
|
|
|
|
|
// Atomically adds one to pre_graph_finished_count_.
void IncPreGraphFinishedCount() { ++pre_graph_finished_count_; }
|
|
|
|
|
// Atomically adds one to post_graph_finished_count_.
void IncPostGraphFinishedCount() { ++post_graph_finished_count_; }
|
|
|
|
|
// Clears first_step_ and resets both finished counters to zero.
void ResetGraphRunningStatus() {
  first_step_ = false;
  post_graph_finished_count_.store(0);
  pre_graph_finished_count_.store(0);
}
|
|
|
|
|
void OnRunGraphFinished() {
|
|
|
|
|
for (auto post_graph : post_graphs_) {
|
|
|
|
|
auto post_graph_ptr = post_graph.second.lock();
|
|
|
|
|
if (post_graph_ptr != nullptr) {
|
|
|
|
|
post_graph_ptr->IncPreGraphFinishedCount();
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
for (auto pre_graph : pre_graphs_) {
|
|
|
|
|
auto pre_graph_ptr = pre_graph.second.lock();
|
|
|
|
|
if (pre_graph_ptr != nullptr) {
|
|
|
|
|
pre_graph_ptr->IncPostGraphFinishedCount();
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
// end of handle graph dependency
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
// remove value node from graph
|
|
|
|
@ -218,6 +305,7 @@ class KernelGraph : public FuncGraph {
|
|
|
|
|
uint32_t GetLoopNum(std::map<AnfNodePtr, size_t> none_zero_nodes);
|
|
|
|
|
void GetLoopNodesByDFS(AnfNodePtr node, uint32_t *loop_num);
|
|
|
|
|
|
|
|
|
|
// members
|
|
|
|
|
std::shared_ptr<std::vector<AnfNodePtr>> inputs_;
|
|
|
|
|
std::vector<AnfNodePtr> child_graph_result_;
|
|
|
|
|
std::vector<CNodePtr> execution_order_;
|
|
|
|
@ -265,6 +353,11 @@ class KernelGraph : public FuncGraph {
|
|
|
|
|
std::map<AnfNodePtr, AnfNodePtr> edge_to_;
|
|
|
|
|
std::stack<AnfNodePtr> loop_nodes_;
|
|
|
|
|
std::vector<AnfNodePtr> input_nodes_;
|
|
|
|
|
std::unordered_map<uint32_t, std::weak_ptr<session::KernelGraph>> pre_graphs_;
|
|
|
|
|
std::unordered_map<uint32_t, std::weak_ptr<session::KernelGraph>> post_graphs_;
|
|
|
|
|
std::atomic<size_t> pre_graph_finished_count_{0};
|
|
|
|
|
std::atomic<size_t> post_graph_finished_count_{0};
|
|
|
|
|
bool first_step_{true};
|
|
|
|
|
bool has_optimizer_{false};
|
|
|
|
|
bool is_dynamic_shape_{false};
|
|
|
|
|
};
|
|
|
|
|