#ifndef GE_HYBRID_EXECUTOR_HYBRID_MODEL_PIPELINE_EXECUTOR_H_ #define GE_HYBRID_EXECUTOR_HYBRID_MODEL_PIPELINE_EXECUTOR_H_ #include "common/blocking_queue.h" #include "common/thread_pool.h" #include "hybrid/executor/hybrid_execution_context.h" #include "hybrid/executor/rt_callback_manager.h" #include "hybrid/executor/subgraph_executor.h" #include "hybrid_model_executor.h" namespace ge { namespace hybrid { struct PipeExecutionConfig { uint32_t device_id; rtContext_t rt_context; int num_executors; int num_stages; long iteration_end; }; class StageExecutor { public: struct StageTask { rtEvent_t event = nullptr; int stage = 0; long iteration = 0; }; StageExecutor(int id, HybridModel *model, PipeExecutionConfig *config); ~StageExecutor(); Status Init(); void Reset(); Status Start(const std::vector &inputs, const std::vector &input_desc, int loop_count); Status SetInputs(const std::vector &inputs, const std::vector &input_desc); Status ExecuteAsync(const StageTask &args); Status GetOutputs(std::vector &outputs, std::vector &output_desc); Status Synchronize(); void SetNext(StageExecutor *next_executor) { next_executor_ = next_executor; } private: friend class HybridModelPipelineExecutor; static Status ResetExecutionContext(GraphExecutionContext &context); Status InitExecutionContext(); int id_; HybridModel *model_; PipeExecutionConfig *pipe_config_; BlockingQueue task_queue_; std::unique_ptr root_graph_executor_; GraphExecutionContext context_; StageExecutor *next_executor_; rtStream_t stream_ = nullptr; }; class HybridModelPipelineExecutor { public: HybridModelPipelineExecutor(HybridModel *model, uint32_t device_id); ~HybridModelPipelineExecutor(); Status Init(); Status InitStageExecutors(); Status Execute(HybridModelExecutor::ExecuteArgs &args); private: HybridModel *model_; uint32_t device_id_; std::vector> stage_executors_; PipeExecutionConfig config_; GraphExecutionContext context_; long iteration_ = 0; }; } // namespace hybrid } // namespace ge #endif // GE_HYBRID_EXECUTOR_HYBRID_MODEL_PIPELINE_EXECUTOR_H_