Add GetStreamIdList in ge_runtime

pull/37/head
caifubi 5 years ago
parent 5369646b48
commit 9ad993ae2c

@ -38,6 +38,8 @@ class ModelRunner {
const std::vector<uint32_t> &GetTaskIdList(uint32_t model_id) const;
const std::vector<uint32_t> &GetStreamIdList(uint32_t model_id) const;
bool UnloadModel(uint32_t model_id);
bool RunModel(uint32_t model_id, const InputData &input_data, OutputData *output_data);

@ -60,6 +60,17 @@ const std::vector<uint32_t> &ModelRunner::GetTaskIdList(uint32_t model_id) const
return model_iter->second->GetTaskIdList();
}
const std::vector<uint32_t> &ModelRunner::GetStreamIdList(uint32_t model_id) const {
auto model_iter = runtime_models_.find(model_id);
if (model_iter == runtime_models_.end()) {
GELOGE(PARAM_INVALID, "Model id %u not found.", model_id);
static const std::vector<uint32_t> empty_ret;
return empty_ret;
}
return model_iter->second->GetStreamIdList();
}
bool ModelRunner::UnloadModel(uint32_t model_id) {
auto iter = runtime_models_.find(model_id);
if (iter != runtime_models_.end()) {

@ -207,6 +207,7 @@ bool RuntimeModel::LoadTask() {
return false;
}
task_id_list_.push_back(task_id);
stream_id_list_.push_back(stream_id);
}
GELOGI("Distribute task succ.");
@ -486,5 +487,6 @@ void RuntimeModel::CreateOutput(uint32_t index, const OpInfo &op_info, InputOutp
const std::vector<uint32_t> &RuntimeModel::GetTaskIdList() const { return task_id_list_; }
const std::vector<uint32_t> &RuntimeModel::GetStreamIdList() const { return stream_id_list_; }
} // namespace model_runner
} // namespace ge

@ -36,6 +36,7 @@ class RuntimeModel {
bool Load(uint32_t device_id, uint64_t session_id, std::shared_ptr<DavinciModel> &davinci_model);
const std::vector<uint32_t> &GetTaskIdList() const;
const std::vector<uint32_t> &GetStreamIdList() const;
bool Run();
bool CopyInputData(const InputData &input_data);
bool GetInputOutputDescInfo(bool zero_copy, std::vector<InputOutputDescInfo> *input_desc,
@ -77,6 +78,7 @@ class RuntimeModel {
std::vector<std::shared_ptr<OpInfo>> constant_info_list_{};
std::vector<uint32_t> task_id_list_{};
std::vector<uint32_t> stream_id_list_{};
};
} // namespace model_runner

Loading…
Cancel
Save