|
|
|
@ -21,6 +21,9 @@
|
|
|
|
|
#include <vector>
|
|
|
|
|
#include <utility>
|
|
|
|
|
#include <stack>
|
|
|
|
|
#include <map>
|
|
|
|
|
#include <tuple>
|
|
|
|
|
#include <set>
|
|
|
|
|
#include "session/session_basic.h"
|
|
|
|
|
#include "session/kernel_graph.h"
|
|
|
|
|
#include "kernel/kernel.h"
|
|
|
|
@ -60,6 +63,8 @@ class AscendSession : public SessionBasic {
|
|
|
|
|
GraphId GetFinalRunGraph() const override { return final_graph_id_; }
|
|
|
|
|
// insert active to graph
|
|
|
|
|
void SetActive(GraphId, GraphId) override;
|
|
|
|
|
// compile child graph when session have multiple child graphs
|
|
|
|
|
void CompileChildGraph(const KernelGraphPtr &child_graph);
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
void InitRuntimeResource();
|
|
|
|
@ -95,12 +100,16 @@ class AscendSession : public SessionBasic {
|
|
|
|
|
size_t ExecOrderOfChildGraph(GraphId final_graph, GraphId child_graph);
|
|
|
|
|
// handle condition graph from vm
|
|
|
|
|
void InsertSwitchToGraph(GraphId condition_graph_id, GraphId true_graph_id);
|
|
|
|
|
// insert depend to graph, used to attch control nodes to graph
|
|
|
|
|
void InsertDependToGraph(GraphId graph_id, const AnfNodePtr &attch_node);
|
|
|
|
|
// insert depend to graph, used to attch control nodes to graph
|
|
|
|
|
void InsertControlDependToGraph(GraphId graph_id, const AnfNodePtr &first_node, const AnfNodePtr &second_node);
|
|
|
|
|
// Get graph by graph id ,if not exist return null ptr
|
|
|
|
|
KernelGraphPtr GetGraph(GraphId graph_id);
|
|
|
|
|
// set child graph parameter if front arg is a anf
|
|
|
|
|
void SetChildGraphParameter(const AnfNodePtr &front_anf, const AnfNodePtr &backend_parameter);
|
|
|
|
|
void SetChildGraphParameter(const AnfNodePtr &front_anf, GraphId to_graph_id, size_t input_idx);
|
|
|
|
|
// set child graph parameter if front arg is a tensor
|
|
|
|
|
void SetChildGraphParameter(const tensor::TensorPtr &front_tensor, const AnfNodePtr &backend_parameter);
|
|
|
|
|
void SetChildGraphParameter(const tensor::TensorPtr &front_tensor, GraphId to_graph_id, size_t input_idx);
|
|
|
|
|
// update the execution order of all child graphs
|
|
|
|
|
void UpdateGraphOrder(GraphId to_graph);
|
|
|
|
|
// handle switch when merge
|
|
|
|
@ -113,6 +122,12 @@ class AscendSession : public SessionBasic {
|
|
|
|
|
void CopyOutputOfIf(GraphId false_graph_id);
|
|
|
|
|
// check if graph cache exist
|
|
|
|
|
bool GraphCacheExist(const GraphInfo &graph_info) const;
|
|
|
|
|
// insert all assign to child graph
|
|
|
|
|
void InsertAllAssigns();
|
|
|
|
|
// create fake output of final graph
|
|
|
|
|
AnfNodePtr CreateFakeOutput(GraphId final_graph_id, const AnfNodePtr &true_output);
|
|
|
|
|
// sync intial tensors' data to device
|
|
|
|
|
void SyncInitialTenosrToDevice();
|
|
|
|
|
|
|
|
|
|
// member variables
|
|
|
|
|
// key is final_graph_id,value is child graph execute order of final graph
|
|
|
|
@ -124,6 +139,10 @@ class AscendSession : public SessionBasic {
|
|
|
|
|
// record all conditions
|
|
|
|
|
std::unordered_map<GraphId, std::pair<GraphId, GraphId>> switches_;
|
|
|
|
|
std::unordered_map<GraphId, AnfNodePtr> condition_output_;
|
|
|
|
|
// share parameters
|
|
|
|
|
std::set<std::tuple<AnfNodePtr, GraphId, size_t>> assigns_;
|
|
|
|
|
// initial tensors, these tensor will sync data to device before run graph
|
|
|
|
|
std::map<std::pair<GraphId, size_t>, tensor::TensorPtr> initial_tenosrs_;
|
|
|
|
|
// final_graph_id is used in every root graph has it's own session situation
|
|
|
|
|
GraphId final_graph_id_;
|
|
|
|
|
};
|
|
|
|
|