/** * Copyright 2019-2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef GE_GRAPH_MANAGER_GRAPH_MANAGER_UTILS_H_ #define GE_GRAPH_MANAGER_GRAPH_MANAGER_UTILS_H_ #include #include #include #include #include #include #include #include #include "common/blocking_queue.h" #include "common/ge_types.h" #include "common/types.h" #include "common/util.h" #include "framework/common/debug/ge_log.h" #include "framework/common/ge_inner_error_codes.h" #include "graph/compute_graph.h" #include "graph/graph.h" #include "graph/model.h" #include "model/ge_model.h" #include "model/ge_root_model.h" #include "register/register_fmk_types.h" #include "external/ge/ge_api_types.h" namespace ge { // state for graph task in life cycle enum GraphNodeState { GRAPH_NODE_INIT = 0, GRAPH_NODE_READY, }; using GraphId = uint32_t; using ConstModelPtr = std::shared_ptr; using GeModelPtr = std::shared_ptr; using ConstGraphPtr = std::shared_ptr; using GraphPtr = std::shared_ptr; const uint64_t INVALID_SESSION_ID = 0xffffffffffffffffULL; struct ModelIdInfo { uint32_t model_id{INVALID_MODEL_ID}; }; class SubGraphInfo { public: SubGraphInfo(); ~SubGraphInfo(); void SetSubGraph(const ComputeGraphPtr &sub_graph_ptr) { subgraph_ptr_ = sub_graph_ptr; } ComputeGraphPtr GetSubGraph() const { return subgraph_ptr_; } void SetEngineName(const std::string &engine_name) { engine_name_ = engine_name; } const std::string &GetEngineName() const { return engine_name_; } void SetInputFlag(const std::vector &input_flag) { input_flag_ = input_flag; } const std::vector &GetInputFlag() const { return input_flag_; } void SetOutputFlag(const std::vector &output_flag) { output_flag_ = output_flag; } const std::vector &GetOutputFlag() const { return output_flag_; } void SetModelIdInfo(const ModelIdInfo &model_id_info) { model_id_info_ = model_id_info; } ModelIdInfo GetModelIdInfo() const { return model_id_info_; } void SetGeModelPtr(const GeModelPtr &ge_model_ptr) { ge_model_ptr_ = ge_model_ptr; } bool GeModelIsValid() const { return ge_model_ptr_ != nullptr; } Status FreeInOutBuffer(); void SetOutputContext(const std::string &output) { output_names_ = output; } std::string GetOutputContext() const { return output_names_; } void SetStreamLabel(const std::string &stream_label) { stream_label_ = stream_label; } const std::string &GetStreamLabel() const { return stream_label_; } void SetEnd2PldMap(std::unordered_map &end_map) { end_to_pld_ = end_map; } const std::unordered_map &GetEnd2PldMap() const { return end_to_pld_; } void SetPld2EndMap(std::unordered_map &pld_map) { pld_to_end_ = pld_map; } const std::unordered_map &GetPld2EndMap() const { return pld_to_end_; } private: ComputeGraphPtr subgraph_ptr_; std::string engine_name_; std::vector input_flag_; std::vector output_flag_; ModelIdInfo model_id_info_; GeModelPtr ge_model_ptr_; bool malloc_flag_; std::vector buffer_addr_; std::string output_names_; std::vector buffer_size_; std::string stream_label_; std::unordered_map end_to_pld_; std::unordered_map pld_to_end_; }; using SubGraphInfoPtr = std::shared_ptr; using Graph2SubGraphInfoList = std::unordered_map>; using Graph2InputNodesSubGraphInfo = std::unordered_map; // for run graph async listener class RunAsyncListener : public ge::ModelListener { public: RunAsyncListener() : sem_(1) {} ~RunAsyncListener() = default; void SetCallback(const RunAsyncCallback &callback); // callback Status OnComputeDone(uint32_t model_id, uint32_t task_id, uint32_t result, std::vector &outputs) override; private: RunAsyncCallback callback_; BlockingQueue sem_; }; // single graph node info class GraphNode { public: explicit GraphNode(GraphId graph_id); ~GraphNode(); GraphId GetGraphId() const { return graph_id_; } ConstGraphPtr GetGraph() const { return graph_; } void SetGraph(const GraphPtr &graph) { graph_ = graph; } ComputeGraphPtr GetComputeGraph() const { return compute_graph_; } void SetComputeGraph(const ComputeGraphPtr &compute_graph) { compute_graph_ = compute_graph; } bool GetRunFlag() const { return run_flag_; } void SetRunFlag(bool flag) { run_flag_ = flag; } bool IsAsync() const { return async_; } void SetAsync(bool flag) { async_ = flag; } void SetSubGraph(std::vector &subgraph_ptr_list) { subgraph_ptr_list_ = subgraph_ptr_list; } const std::vector &GetAllSubGraph() const { return subgraph_ptr_list_; } bool GetBuildFlag() const { return build_flag_; } void SetBuildFlag(bool buildFlag) { build_flag_ = buildFlag; } bool GetLoadFlag() const { return load_flag_; } void SetLoadFlag(bool load_flag) { load_flag_ = load_flag; } void SetGeModel(const GeModelPtr &ge_model) { ge_model_ = ge_model; } GeModelPtr GetGeModel() const { return ge_model_; } void SetGeRootModel(const GeRootModelPtr &ge_root_model) { ge_root_model_ = ge_root_model; } GeRootModelPtr GetGeRootModel() const { return ge_root_model_; } const std::map& GetOptions() const { return options_; } void SetOptions(const std::map &options) { options_ = options; } void Lock(); void Unlock(); // run graph asynchronous listener std::shared_ptr graph_run_async_listener_; private: GraphId graph_id_; std::map options_; bool run_flag_; std::vector subgraph_ptr_list_; GraphPtr graph_; ComputeGraphPtr compute_graph_; bool build_flag_; bool load_flag_; bool async_; GeModelPtr ge_model_; GeRootModelPtr ge_root_model_; BlockingQueue sem_; }; using GraphNodePtr = std::shared_ptr; using ConstGraphNodePtr = shared_ptr; class GraphModelListener : public ge::ModelListener { public: GraphModelListener(std::mutex &mutex, std::condition_variable &cond); ~GraphModelListener() = default; // callback Status OnComputeDone(uint32_t model_id, uint32_t task_id, uint32_t result, std::vector &outputs) override; Status ResetResult(); // need lock by caller uint32_t GetResultCode() const; bool IsFinished() const { return is_finished_; } private: uint32_t result_code_; bool is_finished_; // not owner std::mutex &mutex_; // not owner std::condition_variable &condition_; }; struct GraphManagerOptions { int32_t stream_num; int32_t perf_level; int32_t encrypt_mode; int32_t framework_type; std::string ek_file; std::string cert_file; std::string hw_key_file; std::string private_key_file; std::string calibration_conf_file; std::string insert_op_file; std::string output_node_name; std::string func_bin_path; std::string input_nodes_set_fp16; std::string core_type; bool compress_flag; bool run_graph_flag; bool train_graph_flag; bool local_fmk_op_flag; bool hcom_parallel; bool enable_print_op_pass; bool is_single_op; std::map stream_max_parallel_num; std::string output_datatype; std::string original_model_file; std::string save_original_model; std::string build_mode; std::string build_step; std::string tuning_path; std::string input_shape; std::string dynamic_dims; int32_t dynamic_node_type = -1; GraphManagerOptions() : stream_num(1), perf_level(domi::GEN_TASK_WITHOUT_FUSION), encrypt_mode(-1), framework_type(domi::TENSORFLOW), ek_file(""), cert_file(""), hw_key_file(""), private_key_file(""), calibration_conf_file(""), insert_op_file(""), output_node_name(""), func_bin_path(""), core_type(""), compress_flag(false), run_graph_flag(false), train_graph_flag(false), local_fmk_op_flag(false), hcom_parallel(false), enable_print_op_pass(true), is_single_op(false), save_original_model("false"), build_mode(""), build_step(""), tuning_path(""){} }; } // namespace ge #endif // GE_GRAPH_MANAGER_GRAPH_MANAGER_UTILS_H_