From 9ad993ae2c04954af885034917757d89bcb5bf2f Mon Sep 17 00:00:00 2001 From: caifubi Date: Thu, 11 Jun 2020 10:48:34 +0800 Subject: [PATCH] Add GetStreamIdList in ge_runtime --- inc/framework/ge_runtime/model_runner.h | 2 ++ src/ge/ge_runtime/model_runner.cc | 11 +++++++++++ src/ge/ge_runtime/runtime_model.cc | 2 ++ src/ge/ge_runtime/runtime_model.h | 2 ++ 4 files changed, 17 insertions(+) diff --git a/inc/framework/ge_runtime/model_runner.h b/inc/framework/ge_runtime/model_runner.h index 6e7abcb9..26e1b739 100644 --- a/inc/framework/ge_runtime/model_runner.h +++ b/inc/framework/ge_runtime/model_runner.h @@ -38,6 +38,8 @@ class ModelRunner { const std::vector &GetTaskIdList(uint32_t model_id) const; + const std::vector &GetStreamIdList(uint32_t model_id) const; + bool UnloadModel(uint32_t model_id); bool RunModel(uint32_t model_id, const InputData &input_data, OutputData *output_data); diff --git a/src/ge/ge_runtime/model_runner.cc b/src/ge/ge_runtime/model_runner.cc index 59952e39..0a7aea22 100644 --- a/src/ge/ge_runtime/model_runner.cc +++ b/src/ge/ge_runtime/model_runner.cc @@ -60,6 +60,17 @@ const std::vector &ModelRunner::GetTaskIdList(uint32_t model_id) const return model_iter->second->GetTaskIdList(); } +const std::vector &ModelRunner::GetStreamIdList(uint32_t model_id) const { + auto model_iter = runtime_models_.find(model_id); + if (model_iter == runtime_models_.end()) { + GELOGE(PARAM_INVALID, "Model id %u not found.", model_id); + static const std::vector empty_ret; + return empty_ret; + } + + return model_iter->second->GetStreamIdList(); +} + bool ModelRunner::UnloadModel(uint32_t model_id) { auto iter = runtime_models_.find(model_id); if (iter != runtime_models_.end()) { diff --git a/src/ge/ge_runtime/runtime_model.cc b/src/ge/ge_runtime/runtime_model.cc index 5573fa89..0ef90035 100644 --- a/src/ge/ge_runtime/runtime_model.cc +++ b/src/ge/ge_runtime/runtime_model.cc @@ -207,6 +207,7 @@ bool RuntimeModel::LoadTask() { return false; } task_id_list_.push_back(task_id); + stream_id_list_.push_back(stream_id); } GELOGI("Distribute task succ."); @@ -486,5 +487,6 @@ void RuntimeModel::CreateOutput(uint32_t index, const OpInfo &op_info, InputOutp const std::vector &RuntimeModel::GetTaskIdList() const { return task_id_list_; } +const std::vector &RuntimeModel::GetStreamIdList() const { return stream_id_list_; } } // namespace model_runner } // namespace ge diff --git a/src/ge/ge_runtime/runtime_model.h b/src/ge/ge_runtime/runtime_model.h index e8ff4057..fad3848b 100644 --- a/src/ge/ge_runtime/runtime_model.h +++ b/src/ge/ge_runtime/runtime_model.h @@ -36,6 +36,7 @@ class RuntimeModel { bool Load(uint32_t device_id, uint64_t session_id, std::shared_ptr &davinci_model); const std::vector &GetTaskIdList() const; + const std::vector &GetStreamIdList() const; bool Run(); bool CopyInputData(const InputData &input_data); bool GetInputOutputDescInfo(bool zero_copy, std::vector *input_desc, @@ -77,6 +78,7 @@ class RuntimeModel { std::vector> constant_info_list_{}; std::vector task_id_list_{}; + std::vector stream_id_list_{}; }; } // namespace model_runner