/** * Copyright 2019-2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef GE_GRAPH_MANAGER_GRAPH_MANAGER_H_ #define GE_GRAPH_MANAGER_GRAPH_MANAGER_H_ #include #include #include #include #include #include #include #include "common/blocking_queue.h" #include "common/ge_inner_error_codes.h" #include "common/helper/model_cache_helper.h" #include "external/graph/types.h" #include "ge/ge_api_types.h" #include "graph/build/graph_builder.h" #include "graph/execute/graph_execute.h" #include "graph/ge_local_context.h" #include "graph/load/graph_loader.h" #include "graph/manager/graph_manager_utils.h" #include "graph/manager/util/variable_accelerate_ctrl.h" #include "graph/optimize/graph_optimize.h" #include "graph/partition/graph_partition.h" #include "graph/preprocess/graph_preprocess.h" #include "graph/tuning_utils.h" #include "model/ge_model.h" namespace ge { class GraphManager { public: GraphManager(); ~GraphManager() = default; /// /// @ingroup ge_graph /// @brief graph manager init /// @param [in] options user config params /// @return Status result of function /// Status Initialize(const std::map &options); /// /// @ingroup ge_graph /// @brief graph manager finalize /// @return Status result of function /// Status Finalize(); /// /// @ingroup ge_graph /// @brief add specific graph /// @param [in] graph_id graph id /// @param [out] Graph output graph /// @return Status result of function /// Status AddGraph(const GraphId &graph_id, const Graph &graph, const std::map &options, const OmgContext &omg_context); Status InitDynamicParams(ComputeGraphPtr &compute_graph); /// /// @ingroup ge_graph /// @brief add a copy graph /// @param [in] graph_id graph id /// @param [out] Graph output graph /// @return Status result of function /// Status AddGraphWithCopy(const GraphId &graph_id, const Graph &graph, const std::map &options, const OmgContext &omg_context); /// /// @ingroup ge_graph /// @brief remove specific graph /// @param [in] graph_id graph id /// @return Status result of function /// Status RemoveGraph(const GraphId &graph_id); /// /// @ingroup ge_graph /// @brief run specific graph /// @param [in] graph_id graph id /// @param [in] inputs input data /// @param [out] outputs output data /// @return Status result of function /// Status RunGraph(const GraphId &graph_id, const std::vector &inputs, std::vector &outputs, uint64_t session_id = INVALID_SESSION_ID); /// /// @ingroup ge_graph /// @brief build specific graph /// @param [in] graph_id graph id /// @param [in] inputs input data /// @param [out] models build result /// @return Status result of function /// ge::Status BuildGraph(const GraphId &graph_id, const std::vector &inputs, GeRootModelPtr &models, uint64_t session_id = 0, bool async = false); Status BuildGraphForUnregisteredOp(const GraphId &graph_id, const std::vector &inputs, GeRootModelPtr &ge_root_model, uint64_t session_id); /// /// @ingroup ge_graph /// @brief Save extra attribute to Model /// @param [in] model: Model attribues will save to. /// @param [in] type: type of OpDesc. /// @param [in] attrs: attributes of OpDesc /// @param [in] inputs: input tensor /// @param [in] outputs: output tensor /// @return: Status /// Status SaveParams(ge::GeModel &model, const std::string &type, const std::map &attrs, const std::vector &inputs, const std::vector &outputs); /// /// @ingroup ge_graph /// @brief get variable value from the session with specific session id /// @param [in] sessionId session id /// @param [in] name op name /// @param [out] val out value tensor /// @return Status result of function /// Status GetVariable(const std::string &name, Tensor &val); /// /// @ingroup ge_graph /// @brief run graph async on session with specific session id /// @param [in] graph_id graph id /// @param [in] inputs input data /// @param [out] callback: callback while run graph async finish /// @return Status result of function /// Status RunGraphAsync(const GraphId &graph_id, const std::vector &inputs, uint64_t session_id, RunAsyncCallback callback); /// /// @ingroup ge_graph /// @brief me register the callback function to get the result of summary or checkpoin /// @param [in] key: summary or checkpoint /// @param [in] callbak: The real callback object of me /// @return Status result of function /// Status RegisterCallBackFunc( const std::string &key, const std::function &)> &callback); Status RegisterCallBackFunc( const std::string &key, const std::function &)> &callback); const bool GetTrainFlag() const { return options_.train_graph_flag; } bool IsGraphNeedRebuild(uint32_t graph_id); Status GenerateInfershapeGraph(GraphId &graph_id); const std::map *GetGraphOptions(uint32_t graph_id); void SetOptionsRunGraphFlag(bool run_graph_flag); Status GenCheckPointGraph(const std::map &all_variables, Graph &graph); Status SaveVariables(const Graph &graph, const std::vector &var_names, const std::vector &outputs, std::vector &var_values); Status SaveCheckPointResult(const Graph &graph, const std::vector &outputs, map &var_results); private: struct CompilerStages { GraphPrepare preparer; GraphOptimize optimizer; GraphPartitioner partitioner; GraphBuilder builder; }; struct PreRunArgs { GraphId graph_id; std::vector input_tensor; uint64_t session_id; struct ErrorMessage::Context error_context; GEThreadLocalContext context; RunAsyncCallback callback; }; struct RunArgs { GraphNodePtr graph_node; GraphId graph_id; uint64_t session_id; struct ErrorMessage::Context error_context; std::vector input_tensor; GeRootModelPtr ge_root_model; GEThreadLocalContext context; RunAsyncCallback callback; }; void AddGraphNode(GraphId graph_id, const GraphNodePtr &graph_node); void RemoveGraphNode(GraphId graph_id); bool HasGraphNode(GraphId graph_id); Status GetGraphNode(const GraphId &graph_id, GraphNodePtr &out); std::shared_ptr GetModelListener() const { return graph_run_listener_; } static Status ProcessSubGraphWithMultiThreads(GraphManager *graph_manager, GraphId root_graph_id, const SubGraphInfoPtr &sub_graph_info_ptr, const std::string &root_graph_name, uint64_t session_id, const struct ErrorMessage::Context &error_context, const GEThreadLocalContext &ge_context); Status ParseInputsDims(const std::vector &input_tensor); void ParseInputsDimsForData(const std::vector &input_tensor); Status ParseInputsDimsForGetNexNosinkAndData(const vector &dynamic_nodes, const std::vector &input_tensor); Status RunCustomPass(const GraphNodePtr &graph_node); Status PreRun(const GraphNodePtr &graph_node, const std::vector &inputs, GeRootModelPtr &ge_root_model, uint64_t session_id = INVALID_SESSION_ID); Status OptimizeSubgraph(const GraphNodePtr &graph_node, ComputeGraphPtr &compute_graph, uint64_t session_id); Status Build(const GraphNodePtr &graph_node, ComputeGraphPtr &compute_graph, GeRootModelPtr &ge_root_model, uint64_t session_id); Status StartForRunGraph(const GraphNodePtr &graph_node, const std::vector &inputs, GeRootModelPtr &ge_root_model, uint64_t session_id = INVALID_SESSION_ID); Status InnerRunGraph(GraphNodePtr &graph_node, const GraphId &graph_id, const std::vector &inputs, std::vector &outputs); Status ParseOptions(const std::map &options); static void ParseOption(const std::map &options, const std::string &key, std::string &option); static Status ParseOption(const std::map &options, const std::string &key, bool &option); static Status ParseOption(const std::map &options, const std::string &key, int &option); static Status ParseOption(const std::map &options, const std::string &key, std::map &option); static void Trim(std::string &str); static Status CheckEngineName(const std::string &engine_name, const std::string &key, const std::map &option); static Status ParseParallelNum(const std::string ¶llel_num, const std::string &key, int &num); static Status ParseTrainGraphFlag(bool &options, bool &option); static bool IsPerfLevelInvalid(int32_t perf_level); Status SummaryHandle(const GraphId &graph_id, std::vector &outputs); Status CheckpointHandle(const GraphId &graph_id, const ComputeGraphPtr &compute_graph, const std::vector &outputs); // call the callback function of ME to push summary result data to ME Status PushSummaryData2ME(const GraphId &graph_id, const std::map &summary_data); // call the callback function of ME to push save result data to ME Status PushSaveData2ME(const GraphId &graph_id, const std::map &save_data); bool IsCheckpointGraph(ComputeGraphPtr &compute_graph); bool CheckNetOutputForCheckpointGraph(NodePtr &node); bool CheckVariableForCheckpointGraph(NodePtr &node); bool CheckTransOpForCheckpointGraph(NodePtr &node); Status MergeSubGraph(ComputeGraphPtr &compute_graph, const ge::ComputeGraphPtr &original_compute_graph, GraphId root_graph_id); Status ConvertGraphToFile(ComputeGraphPtr &compute_graph, GraphPartitioner &partitioner, std::string file_path, bool exe_flag = false); Status SetSubgraph(uint64_t session_id, ComputeGraphPtr compute_graph, GraphPartitioner &partitioner); void SetAttrForHcomBroadCastOp(ge::ComputeGraphPtr &compute_graph); bool IsBroadCastOpData(const ge::NodePtr &var_node); void AdjustBroadCastOpData(const ge::NodePtr &var_node); bool IsAssignOpData(const ge::NodePtr &var_node); void AdjustAssignOpData(const ge::NodePtr &var_node); bool ConfirmUseOpAndIndexByAnchor(const ge::InDataAnchorPtr &in_anchor, const map> &confirm_ops, ge::NodePtr &use_node); bool ConfirmUseOpAndIndexByNode(const ge::NodePtr &var_node, const map> &confirm_ops, ge::NodePtr &use_node); // graph context std::shared_ptr GetGraphContext() const { return graph_context_; } Status RemoveIsolatedConst(ge::ComputeGraphPtr &compute_graph); Status RemoveIsolatedConstInThisGraph(ge::ComputeGraphPtr &compute_graph); Status OptimizeStage1(ComputeGraphPtr &compute_graph); Status OptimizeStage2(ComputeGraphPtr &compute_graph); Status SubexpressionMigration(ComputeGraphPtr &compute_graph); Status LoadGraphAsync(const GeRootModelPtr &ge_root_model, const GraphNodePtr &graph_node); Status CheckAndReleaseMemory(const GeModelPtr &ge_model, const GraphNodePtr &graph_node); bool CheckModelLoad(const GeRootModelPtr &ge_model, bool load_flag); Status LoadGraph(const GeRootModelPtr &ge_root_model, const GraphNodePtr &graph_node); bool IsGraphNeedBuild(const GraphNodePtr &graph_node); Status LoadFromCache(const GraphNodePtr &graph_node, const ModelCacheHelperPtr &cache_helper, GeModelPtr &ge_model); Status SaveCacheBeforeBuild(uint32_t graph_id, const ModelCacheHelperPtr &cache_helper); Status SaveCacheAfterBuild(uint32_t graph_id, ComputeGraphPtr graph, GeModelPtr &ge_model); void AddModelCacheHelperToMap(const GraphId &graph_id, uint64_t session_id, ComputeGraphPtr &compute_graph); Status IncreBuild(const GraphNodePtr &graph_node, GeModelPtr &ge_model); void RemoveModelCacheHelper(const GraphId &graph_id); ModelCacheHelperPtr FindModelCacheHelper(GraphId graph_id); static void ConstructGeInput(const std::vector &inputs, std::vector &ge_inputs); static void PreRunThread(GraphManager *graph_manager); static void RunThread(GraphManager *graph_manager); static void StopQueue(GraphManager *graph_manager); static void ReturnError(GraphManager *graph_manager, RunAsyncCallback callback, Status ret, const string &log); static void ReturnError(GraphManager *graph_manager, GraphNodePtr &graph_node, RunAsyncCallback callback, Status ret, const string &log); void ChangeConstTypeWhenTraining(const ComputeGraphPtr &compute_graph); Status PreRunOptimizeOriginalGraph(const GraphNodePtr &graph_node, const std::vector &inputs, ge::ComputeGraphPtr &compute_graph, uint64_t session_id); Status PreRunOptimizeSubGraph(const GraphNodePtr &graph_node, ge::ComputeGraphPtr &compute_graph, uint64_t session_id); Status PreRunAfterOptimizeSubGraph(const GraphNodePtr &graph_node, ComputeGraphPtr &compute_graph, GeRootModelPtr &ge_root_model, uint64_t session_id); Status SetFuzzCompileFlag(ComputeGraphPtr &compute_graph); Status CopySubGraphAndMarkFusion(const ComputeGraphPtr &compute_graph, Graph2SubGraphInfoList &sub_graph_map, std::unordered_map ©_graphs); Status OptimizeSubGraphWithMultiThreads(ComputeGraphPtr compute_graph, Graph2SubGraphInfoList &sub_graph_map, uint64_t session_id); bool CheckAllFusionOptimizeSuccess(const ComputeGraphPtr &compute_graph, Graph2SubGraphInfoList &sub_graph_map); Status ReplaceSubgraphWithOriGraph(const ComputeGraphPtr &compute_graph, Graph2SubGraphInfoList &sub_graph_map, std::unordered_map ©_graphs); Status SetRtContext(rtContext_t rt_context, rtCtxMode_t mode, uint64_t session_id, uint32_t graph_id); void AddLocalOmgContext(GraphId graph_id, const OmgContext &omg_context); void UpdateLocalOmgContext(GraphId graph_id); CompilerStages &GetCompilerStages(GraphId graph_id); void RemoveCompilerStages(GraphId graph_id); std::atomic_bool thread_run_flag_; BlockingQueue prerun_args_q_{}; BlockingQueue run_args_q_{}; std::thread prerun_thread_; std::thread run_thread_; ComputeGraphPtr compute_graph_; std::map graph_map_; std::map cache_helper_map_; // for run graph synchronous return std::mutex sync_run_mutex_; std::condition_variable condition_; // run graph synchronization call back listener std::shared_ptr graph_run_listener_; // summary and checkpoint callback function list for ME, key is summary or checkpoint std::map &)>> me_callback_map_; std::map &)>> callback_map_; bool init_flag_; GraphManagerOptions options_; GraphContextPtr graph_context_ = nullptr; map omg_contexts_; map compiler_stages_; GraphExecutor graph_executor_; VarAccelerateCtrl var_acc_ctrl_; std::mutex run_mutex_; std::mutex member_mutex_; std::mutex unload_model_mutex_; }; } // namespace ge #endif // GE_GRAPH_MANAGER_GRAPH_MANAGER_H_