You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
graphengine/ge/hybrid/executor/hybrid_model_pipeline_execu...

89 lines
2.3 KiB

#ifndef GE_HYBRID_EXECUTOR_HYBRID_MODEL_PIPELINE_EXECUTOR_H_
#define GE_HYBRID_EXECUTOR_HYBRID_MODEL_PIPELINE_EXECUTOR_H_
#include "common/blocking_queue.h"
#include "common/thread_pool.h"
#include "hybrid/executor/hybrid_execution_context.h"
#include "hybrid/executor/rt_callback_manager.h"
#include "hybrid/executor/subgraph_executor.h"
#include "hybrid_model_executor.h"
namespace ge {
namespace hybrid {
/// Settings for a pipelined execution of a hybrid model.
///
/// HybridModelPipelineExecutor holds one of these as `config_`. Each StageExecutor is
/// constructed with a `PipeExecutionConfig *` and keeps it as `pipe_config_` (presumably
/// pointing at that same instance — confirm in the .cc).
///
/// Every field has a default, matching StageTask in this file, so that a config read
/// before it is fully populated yields zero/nullptr rather than indeterminate values.
struct PipeExecutionConfig {
  uint32_t device_id = 0;             // device the pipeline runs on
  rtContext_t rt_context = nullptr;   // runtime context handle; NOTE(review): presumably bound by stage threads — confirm
  int num_executors = 0;              // presumably the number of StageExecutor instances — confirm
  int num_stages = 0;                 // presumably the number of pipeline stages — confirm
  long iteration_end = 0;             // NOTE(review): presumably the iteration index at which the run stops — confirm bound semantics
};
/// Executes work items (StageTask) for one position in a chain of executors.
///
/// Executors are linked through SetNext(); each one keeps a non-owning pointer to its
/// successor. NOTE(review): the "one pipeline stage per executor" reading is inferred
/// from names and from HybridModelPipelineExecutor owning a vector of these — confirm.
class StageExecutor {
 public:
  /// A unit of work queued to a StageExecutor.
  struct StageTask {
    rtEvent_t event = nullptr;  // runtime event handle; presumably used to order work across stages — confirm
    int stage = 0;              // stage index this task refers to
    long iteration = 0;         // pipeline iteration this task belongs to
  };

  /// @param id     identifier of this executor (presumably its position in the chain)
  /// @param model  model to execute; stored as a non-owning pointer
  /// @param config shared pipeline settings; stored as a non-owning pointer
  StageExecutor(int id, HybridModel *model, PipeExecutionConfig *config);
  ~StageExecutor();

  // The executor holds a std::unique_ptr member, so the implicit copy operations are
  // already deleted; declaring them explicitly documents that it is not copyable.
  StageExecutor(const StageExecutor &) = delete;
  StageExecutor &operator=(const StageExecutor &) = delete;

  /// One-time setup before the executor is used.
  Status Init();
  /// NOTE(review): presumably returns the executor to a reusable state between runs — confirm.
  void Reset();
  /// Starts execution with the given inputs; `loop_count` presumably is the number of
  /// iterations to run — confirm in the .cc.
  Status Start(const std::vector<TensorValue> &inputs, const std::vector<ConstGeTensorDescPtr> &input_desc,
               int loop_count);
  /// Binds input tensors and their descriptors.
  Status SetInputs(const std::vector<TensorValue> &inputs, const std::vector<ConstGeTensorDescPtr> &input_desc);
  /// Executes a single task; the "Async" suffix suggests completion must be awaited via Synchronize().
  Status ExecuteAsync(const StageTask &args);
  /// Fills `outputs` and `output_desc` (out-parameters).
  Status GetOutputs(std::vector<TensorValue> &outputs, std::vector<ConstGeTensorDescPtr> &output_desc);
  /// Waits for previously issued work to finish.
  Status Synchronize();
  /// Links this executor to its successor (non-owning; the pointee must outlive this object).
  void SetNext(StageExecutor *next_executor) { next_executor_ = next_executor; }

 private:
  friend class HybridModelPipelineExecutor;

  static Status ResetExecutionContext(GraphExecutionContext &context);
  Status InitExecutionContext();

  int id_ = 0;
  HybridModel *model_ = nullptr;                 // non-owning
  PipeExecutionConfig *pipe_config_ = nullptr;   // non-owning
  BlockingQueue<StageTask> task_queue_;
  std::unique_ptr<SubgraphExecutor> root_graph_executor_;
  GraphExecutionContext context_;
  // Non-owning link to the following executor. SetNext() is the only writer, so this must
  // default to nullptr: an executor that is never linked (e.g. the tail of the chain) would
  // otherwise hold an indeterminate pointer, and comparing it against nullptr is UB.
  StageExecutor *next_executor_ = nullptr;
  rtStream_t stream_ = nullptr;
};
/// Drives execution of a HybridModel through a chain of StageExecutor instances.
///
/// Owns the executors (`stage_executors_`) and the shared PipeExecutionConfig.
/// NOTE(review): the pipelining semantics are inferred from names — confirm in the .cc.
class HybridModelPipelineExecutor {
 public:
  /// @param model     model to execute; stored as a non-owning pointer
  /// @param device_id device to run on
  HybridModelPipelineExecutor(HybridModel *model, uint32_t device_id);
  ~HybridModelPipelineExecutor();

  // Owns a vector of std::unique_ptr, so the implicit copy operations are already
  // deleted; declaring them explicitly documents that the executor is not copyable.
  HybridModelPipelineExecutor(const HybridModelPipelineExecutor &) = delete;
  HybridModelPipelineExecutor &operator=(const HybridModelPipelineExecutor &) = delete;

  /// One-time setup before Execute().
  Status Init();
  /// Creates and links the StageExecutor chain (presumably called from Init() — confirm).
  Status InitStageExecutors();
  /// Runs the model; `args` is taken by non-const reference, presumably so results can be
  /// written back into it — confirm.
  Status Execute(HybridModelExecutor::ExecuteArgs &args);

 private:
  HybridModel *model_ = nullptr;  // non-owning
  uint32_t device_id_ = 0;
  std::vector<std::unique_ptr<StageExecutor>> stage_executors_;
  // Value-initialized so that any field not yet assigned reads as zero/nullptr instead of
  // an indeterminate value. StageExecutors receive a raw PipeExecutionConfig pointer,
  // presumably to this object.
  PipeExecutionConfig config_{};
  GraphExecutionContext context_;
  long iteration_ = 0;  // iteration counter, starts at 0
};
} // namespace hybrid
} // namespace ge
#endif // GE_HYBRID_EXECUTOR_HYBRID_MODEL_PIPELINE_EXECUTOR_H_