!1064 dynamic shape supprot pipeline
From: @isaacxr Reviewed-by: @sheng-nan,@xchu42 Signed-off-by: @ji_chenpull/1064/MERGE
commit
dfd119571e
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,88 @@
|
|||||||
|
#ifndef GE_HYBRID_EXECUTOR_HYBRID_MODEL_PIPELINE_EXECUTOR_H_
|
||||||
|
#define GE_HYBRID_EXECUTOR_HYBRID_MODEL_PIPELINE_EXECUTOR_H_
|
||||||
|
|
||||||
|
#include "common/blocking_queue.h"
|
||||||
|
#include "common/thread_pool.h"
|
||||||
|
#include "hybrid/executor/hybrid_execution_context.h"
|
||||||
|
#include "hybrid/executor/rt_callback_manager.h"
|
||||||
|
#include "hybrid/executor/subgraph_executor.h"
|
||||||
|
#include "hybrid_model_executor.h"
|
||||||
|
|
||||||
|
namespace ge {
|
||||||
|
namespace hybrid {
|
||||||
|
|
||||||
|
struct PipeExecutionConfig {
|
||||||
|
uint32_t device_id;
|
||||||
|
rtContext_t rt_context;
|
||||||
|
int num_executors;
|
||||||
|
int num_stages;
|
||||||
|
long iteration_end;
|
||||||
|
};
|
||||||
|
|
||||||
|
class StageExecutor {
|
||||||
|
public:
|
||||||
|
struct StageTask {
|
||||||
|
rtEvent_t event = nullptr;
|
||||||
|
int stage = 0;
|
||||||
|
long iteration = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
StageExecutor(int id, HybridModel *model, PipeExecutionConfig *config);
|
||||||
|
|
||||||
|
~StageExecutor();
|
||||||
|
|
||||||
|
Status Init();
|
||||||
|
|
||||||
|
void Reset();
|
||||||
|
|
||||||
|
Status Start(const std::vector<TensorValue> &inputs, const std::vector<ConstGeTensorDescPtr> &input_desc,
|
||||||
|
int loop_count);
|
||||||
|
|
||||||
|
Status SetInputs(const std::vector<TensorValue> &inputs, const std::vector<ConstGeTensorDescPtr> &input_desc);
|
||||||
|
|
||||||
|
Status ExecuteAsync(const StageTask &args);
|
||||||
|
|
||||||
|
Status GetOutputs(std::vector<TensorValue> &outputs, std::vector<ConstGeTensorDescPtr> &output_desc);
|
||||||
|
|
||||||
|
Status Synchronize();
|
||||||
|
|
||||||
|
void SetNext(StageExecutor *next_executor) { next_executor_ = next_executor; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
friend class HybridModelPipelineExecutor;
|
||||||
|
static Status ResetExecutionContext(GraphExecutionContext &context);
|
||||||
|
Status InitExecutionContext();
|
||||||
|
|
||||||
|
int id_;
|
||||||
|
HybridModel *model_;
|
||||||
|
|
||||||
|
PipeExecutionConfig *pipe_config_;
|
||||||
|
BlockingQueue<StageTask> task_queue_;
|
||||||
|
std::unique_ptr<SubgraphExecutor> root_graph_executor_;
|
||||||
|
GraphExecutionContext context_;
|
||||||
|
StageExecutor *next_executor_;
|
||||||
|
|
||||||
|
rtStream_t stream_ = nullptr;
|
||||||
|
};
|
||||||
|
|
||||||
|
class HybridModelPipelineExecutor {
|
||||||
|
public:
|
||||||
|
HybridModelPipelineExecutor(HybridModel *model, uint32_t device_id);
|
||||||
|
~HybridModelPipelineExecutor();
|
||||||
|
Status Init();
|
||||||
|
Status InitStageExecutors();
|
||||||
|
Status Execute(HybridModelExecutor::ExecuteArgs &args);
|
||||||
|
|
||||||
|
private:
|
||||||
|
HybridModel *model_;
|
||||||
|
uint32_t device_id_;
|
||||||
|
|
||||||
|
std::vector<std::unique_ptr<StageExecutor>> stage_executors_;
|
||||||
|
PipeExecutionConfig config_;
|
||||||
|
GraphExecutionContext context_;
|
||||||
|
long iteration_ = 0;
|
||||||
|
};
|
||||||
|
} // namespace hybrid
|
||||||
|
} // namespace ge
|
||||||
|
|
||||||
|
#endif // GE_HYBRID_EXECUTOR_HYBRID_MODEL_PIPELINE_EXECUTOR_H_
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in new issue