/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef GE_GRAPH_EXECUTE_GRAPH_EXECUTE_H_
#define GE_GRAPH_EXECUTE_GRAPH_EXECUTE_H_

// NOTE(review): in this copy of the file the header names of the following
// five #include directives are missing. The template arguments of every
// std::vector / std::shared_ptr below are also missing (e.g. "std::vector &x").
// As written the header cannot compile. Restore the exact text from version
// control rather than guessing the types.
#include
#include
#include
#include
#include

#include "common/debug/log.h"
#include "common/debug/memory_dumper.h"
#include "common/ge_types.h"
#include "common/properties_manager.h"
#include "common/string_util.h"
#include "common/types.h"
#include "common/util.h"
#include "ge/ge_api_types.h"
#include "graph/compute_graph.h"
#include "graph/manager/graph_context.h"
#include "graph/manager/graph_manager_utils.h"
#include "graph/model.h"
#include "graph/utils/graph_utils.h"
#include "graph/utils/tensor_utils.h"

namespace ge {
///
/// @ingroup ge
/// @brief Executor for built graphs.
///
/// Declares the synchronous and asynchronous execution entry points
/// (ExecuteGraph / ExecuteGraphAsync). It also declares a set of static
/// queries keyed by a model_id (descriptors, AIPP info, dynamic batch/shape
/// info, op info).
/// Only declarations are visible in this header. The per-member notes below
/// are inferred from names and signatures and should be confirmed against
/// the implementation file.
///
/// NOTE(review): the class declares a destructor and holds raw pointers
/// (sync_run_mutex_, condition_) plus a buffer list (buffer_addr_ with
/// malloc_flag_). It does not delete or define the copy/move operations.
/// If the destructor or FreeInOutBuffer() releases buffer_addr_, an implicit
/// copy would lead to a double free. Consider
/// `GraphExecutor(const GraphExecutor &) = delete;` and
/// `GraphExecutor &operator=(const GraphExecutor &) = delete;`
/// after confirming that no caller copies an executor.
///
class GraphExecutor {
 public:
  // Member initialization is not visible here; no data member below has an
  // in-class initializer, so this constructor is presumably responsible for it.
  GraphExecutor();

  // Declared virtual, so deleting a derived executor through a
  // GraphExecutor pointer is well-defined.
  virtual ~GraphExecutor();

  /// @brief Entry point for executing graph @p graph_id, whose built model is @p ge_root_model.
  /// @param [in] graph_id id of the graph to run
  /// @param [in] ge_root_model built root model of that graph
  /// @param [in] input_tensor input tensors
  /// @param [out] output_tensor non-const reference, presumably filled with the results (TODO confirm)
  /// @return execution status
  Status ExecuteGraph(GraphId graph_id, const GeRootModelPtr &ge_root_model, const std::vector &input_tensor,
                      std::vector &output_tensor);

  /// @brief Asynchronous counterpart of ExecuteGraph; it takes no output parameter.
  /// Results are presumably delivered through the listener passed to SetCondition() (TODO confirm).
  /// NOTE(review): spelled `ge::Status` while every other declaration uses `Status`. Inside
  /// namespace ge both name the same type; consider unifying the spelling.
  ge::Status ExecuteGraphAsync(GraphId graph_id, const GeRootModelPtr &ge_root_model,
                               const std::vector &input_tensor);

  /// @brief Supplies the synchronization objects and run listener used for graph runs.
  /// Presumably stored in sync_run_mutex_, condition_ and graph_run_listener_ (TODO confirm).
  /// Ownership of @p mutex and @p cond is not visible here. Confirm whether the caller
  /// must keep them alive for the executor's lifetime.
  Status SetCondition(std::mutex *mutex, std::condition_variable *cond, std::shared_ptr listener);

  /// @brief Sets the graph context (presumably stored in graph_context_).
  Status SetGraphContext(GraphContextPtr graph_context_ptr);

  /// @brief Sets the dynamic sizes of model @p model_id.
  /// @param [in] batch_num dynamic size values
  /// @param [in] dynamic_type kind of dynamic dimension (encoding not visible here)
  static Status SetDynamicSize(uint32_t model_id, const std::vector &batch_num, int32_t dynamic_type);

  /// @brief Marks whether the executed graph is a training graph (presumably stored in train_graph_flag_).
  void SetTrainFlag(bool is_train_graph);

  /// @brief Returns a const reference to outputs_desc_. The reference stays valid while this executor is alive.
  const std::vector &GetOutputsDesc() const { return outputs_desc_; }

  /// @brief Releases memory held for execution (implementation not visible; TODO confirm what is freed).
  Status FreeExecuteMemory();

  /// @brief Feeds @p input_data to a model. @p output_data is a non-const reference, presumably an out-param.
  static Status DataInput(const InputData &input_data, OutputData &output_data);

  /// @brief Gets the input and output descriptors of model @p model_id.
  /// NOTE(review): `vector` is used unqualified here and in the overload and
  /// GetInputOutputDescInfoForZeroCopy below. That compiles only if an included project header
  /// makes std::vector visible unqualified; prefer `std::vector` for consistency with the rest
  /// of the class.
  /// NOTE(review): top-level `const` on a by-value parameter (`const uint32_t model_id`) has no
  /// effect in a declaration.
  static Status GetInputOutputDescInfo(const uint32_t model_id, vector &input_desc, vector &output_desc);

  /// @brief Overload that also reports the input and output formats.
  /// @param [in] new_model_desc defaults to false; its effect is not visible in this header
  static Status GetInputOutputDescInfo(const uint32_t model_id, vector &input_desc, vector &output_desc,
                                       std::vector &input_formats, std::vector &output_formats,
                                       bool new_model_desc = false);

  /// @brief Gets the AIPP configuration of input @p index of model @p model_id into @p aipp_info.
  static Status GetAIPPInfo(uint32_t model_id, uint32_t index, AippConfigInfo &aipp_info);

  /// @brief Gets the AIPP type of input @p index of model @p model_id, plus the related @p aipp_index.
  static Status GetAippType(uint32_t model_id, uint32_t index, InputAippType &type, size_t &aipp_index);

  ///
  /// @ingroup ge
  /// @brief Get dynamic batch_info
  /// @param [in] model_id
  /// @param [out] batch_info
  /// @param [out] dynamic_type
  /// @return execute result
  ///
  static Status GetDynamicBatchInfo(uint32_t model_id, std::vector> &batch_info, int32_t &dynamic_type);

  ///
  /// @ingroup ge
  /// @brief Get combined dynamic dims info
  /// @param [in] model_id
  /// @param [out] batch_info
  /// @return execute result
  ///
  static Status GetCombinedDynamicDims(uint32_t model_id, std::vector> &batch_info);

  ///
  /// @ingroup ge
  /// @brief Get the user-designated input shape order
  /// @param [in] model_id
  /// @param [out] user_input_shape_order
  /// @return execute result
  ///
  static Status GetUserDesignateShapeOrder(uint32_t model_id, std::vector &user_input_shape_order);

  /// @brief Gets the current shape of model @p model_id into @p batch_info, and its @p dynamic_type.
  static Status GetCurShape(const uint32_t model_id, std::vector &batch_info, int32_t &dynamic_type);

  /// @brief Gets model attributes describing dynamic output shapes into @p dynamic_output_shape_info.
  static Status GetModelAttr(uint32_t model_id, std::vector &dynamic_output_shape_info);

  /// @brief Zero-copy variant of GetInputOutputDescInfo (the difference is not visible here).
  static Status GetInputOutputDescInfoForZeroCopy(uint32_t model_id, vector &input_desc, vector &output_desc,
                                                  std::vector &input_formats, std::vector &output_formats);

  /// @brief Gets the original input info of input @p index of model @p model_id.
  static Status GetOrigInputInfo(uint32_t model_id, uint32_t index, OriginInputInfo &orig_input_info);

  /// @brief Gets the AIPP input and output dims of input @p index of model @p model_id.
  static Status GetAllAippInputOutputDims(uint32_t model_id, uint32_t index, std::vector &input_dims,
                                          std::vector &output_dims);

  /// @brief Looks up the op description for the task identified by (@p device_id, @p stream_id, @p task_id).
  static Status GetOpDescInfo(uint32_t device_id, uint32_t stream_id, uint32_t task_id, OpDescInfo &op_desc_info);

 private:
  /// @brief Builds the model input/output data from @p input_tensor (presumably used by SyncExecuteModel).
  Status PrepareInputData(const std::vector &input_tensor, InputData &graph_input_data,
                          OutputData &graph_output_data, std::vector &output_desc);

  /// @brief Runs model @p model_id and returns outputs through @p output_tensor (TODO confirm blocking semantics).
  Status SyncExecuteModel(uint32_t model_id, const std::vector &input_tensor, std::vector &output_tensor);

  /// @brief Submits @p input_tensor to model @p model_id without an output parameter.
  Status AsyncExecuteModel(uint32_t model_id, const std::vector &input_tensor);

  /// @brief Fills @p out_model_id_info from @p sub_graph_vec; the role of @p output_size is not visible here.
  void InitModelIdInfo(std::vector &out_model_id_info, std::vector &sub_graph_vec, uint32_t output_size);

  /// @brief Frees the buffers recorded in buffer_addr_ (presumably, together with malloc_flag_).
  Status FreeInOutBuffer();

  /// @brief Allocates buffers of the sizes in @p buffer_size and returns their addresses in @p data_addr.
  Status MallocInOutBuffer(const std::vector &buffer_size, std::vector &data_addr);

  bool init_flag_;

  // Whether the executed graph is a training graph (see SetTrainFlag).
  bool train_graph_flag_;
  // Synchronization objects used to return from a synchronous graph run.
  // Raw pointers; ownership is not visible here (see SetCondition).
  std::mutex *sync_run_mutex_;
  std::condition_variable *condition_;

  // Callback listener for asynchronous graph runs.
  std::shared_ptr graph_run_listener_;

  GraphContextPtr graph_context_;

  // Output descriptors exposed through GetOutputsDesc().
  std::vector outputs_desc_;
  GraphId last_graph_id_;

  // Bookkeeping for buffers from MallocInOutBuffer / FreeInOutBuffer
  // (presumably: whether buffers are currently allocated, plus their addresses and sizes).
  bool malloc_flag_;
  std::vector buffer_addr_;
  std::vector buffer_size_;
};
}  // namespace ge

#endif  // GE_GRAPH_EXECUTE_GRAPH_EXECUTE_H_