|
|
@ -16,20 +16,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
#include "graph/load/model_manager/task_info/label_switch_by_index_task_info.h"
|
|
|
|
#include "graph/load/model_manager/task_info/label_switch_by_index_task_info.h"
|
|
|
|
|
|
|
|
|
|
|
|
#include "graph/debug/ge_attr_define.h"
|
|
|
|
|
|
|
|
#include "graph/load/model_manager/davinci_model.h"
|
|
|
|
#include "graph/load/model_manager/davinci_model.h"
|
|
|
|
|
|
|
|
|
|
|
|
namespace ge {
|
|
|
|
namespace ge {
|
|
|
|
constexpr uint8_t kLabelSwitchIndexNum = 1;
|
|
|
|
constexpr uint8_t kLabelSwitchIndexNum = 1;
|
|
|
|
|
|
|
|
|
|
|
|
LabelSwitchByIndexTaskInfo::~LabelSwitchByIndexTaskInfo() {
|
|
|
|
LabelSwitchByIndexTaskInfo::~LabelSwitchByIndexTaskInfo() {
|
|
|
|
if (args_ != nullptr) {
|
|
|
|
GE_FREE_RT_LOG(args_);
|
|
|
|
rtError_t ret = rtFree(args_);
|
|
|
|
|
|
|
|
if (ret != RT_ERROR_NONE) {
|
|
|
|
|
|
|
|
GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", ret);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
args_ = nullptr;
|
|
|
|
|
|
|
|
index_value_ = nullptr;
|
|
|
|
index_value_ = nullptr;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
@ -37,13 +30,12 @@ Status LabelSwitchByIndexTaskInfo::Init(const domi::TaskDef &task_def, DavinciMo
|
|
|
|
GELOGI("LabelSwitchByIndexTaskInfo Init Start.");
|
|
|
|
GELOGI("LabelSwitchByIndexTaskInfo Init Start.");
|
|
|
|
GE_CHECK_NOTNULL(davinci_model);
|
|
|
|
GE_CHECK_NOTNULL(davinci_model);
|
|
|
|
|
|
|
|
|
|
|
|
const vector<rtLabel_t> &label_list = davinci_model->GetLabelList();
|
|
|
|
|
|
|
|
Status ret = SetStream(task_def.stream_id(), davinci_model->GetStreamList());
|
|
|
|
Status ret = SetStream(task_def.stream_id(), davinci_model->GetStreamList());
|
|
|
|
if (ret != SUCCESS) {
|
|
|
|
if (ret != SUCCESS) {
|
|
|
|
return FAILED;
|
|
|
|
return FAILED;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// Get LabelSwitch task def
|
|
|
|
// Get LabelSwitchByIndex task def
|
|
|
|
const domi::LabelSwitchByIndexDef &label_switch = task_def.label_switch_by_index();
|
|
|
|
const domi::LabelSwitchByIndexDef &label_switch = task_def.label_switch_by_index();
|
|
|
|
OpDescPtr op_desc = davinci_model->GetOpByIndex(label_switch.op_index());
|
|
|
|
OpDescPtr op_desc = davinci_model->GetOpByIndex(label_switch.op_index());
|
|
|
|
if (op_desc == nullptr) {
|
|
|
|
if (op_desc == nullptr) {
|
|
|
@ -68,7 +60,7 @@ Status LabelSwitchByIndexTaskInfo::Init(const domi::TaskDef &task_def, DavinciMo
|
|
|
|
|
|
|
|
|
|
|
|
davinci_model->DisableZeroCopy(index_value_);
|
|
|
|
davinci_model->DisableZeroCopy(index_value_);
|
|
|
|
|
|
|
|
|
|
|
|
std::vector<uint32_t> label_idx_list;
|
|
|
|
vector<uint32_t> label_idx_list;
|
|
|
|
if (!AttrUtils::GetListInt(op_desc, ATTR_NAME_LABEL_SWITCH_LIST, label_idx_list)) {
|
|
|
|
if (!AttrUtils::GetListInt(op_desc, ATTR_NAME_LABEL_SWITCH_LIST, label_idx_list)) {
|
|
|
|
GELOGE(INTERNAL_ERROR, "LabelSwitchByIndexTaskInfo: %s Get attr %s failed.", op_desc->GetName().c_str(),
|
|
|
|
GELOGE(INTERNAL_ERROR, "LabelSwitchByIndexTaskInfo: %s Get attr %s failed.", op_desc->GetName().c_str(),
|
|
|
|
ATTR_NAME_LABEL_SWITCH_LIST.c_str());
|
|
|
|
ATTR_NAME_LABEL_SWITCH_LIST.c_str());
|
|
|
@ -81,7 +73,8 @@ Status LabelSwitchByIndexTaskInfo::Init(const domi::TaskDef &task_def, DavinciMo
|
|
|
|
return INTERNAL_ERROR;
|
|
|
|
return INTERNAL_ERROR;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
label_list_.resize(branch_max_, nullptr);
|
|
|
|
vector<rtLabel_t> label_used(branch_max_, nullptr);
|
|
|
|
|
|
|
|
const vector<rtLabel_t> &label_list = davinci_model->GetLabelList();
|
|
|
|
for (size_t idx = 0; idx < label_idx_list.size(); ++idx) {
|
|
|
|
for (size_t idx = 0; idx < label_idx_list.size(); ++idx) {
|
|
|
|
uint32_t label_id = label_idx_list[idx];
|
|
|
|
uint32_t label_id = label_idx_list[idx];
|
|
|
|
if (label_id >= label_list.size()) {
|
|
|
|
if (label_id >= label_list.size()) {
|
|
|
@ -90,8 +83,7 @@ Status LabelSwitchByIndexTaskInfo::Init(const domi::TaskDef &task_def, DavinciMo
|
|
|
|
return INTERNAL_ERROR;
|
|
|
|
return INTERNAL_ERROR;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
GE_CHECK_NOTNULL(label_list[label_id]);
|
|
|
|
GE_CHECK_NOTNULL(label_list[label_id]);
|
|
|
|
|
|
|
|
label_used[idx] = label_list[label_id];
|
|
|
|
label_list_[idx] = label_list[label_id];
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
rtMemType_t memory_type = op_desc->HasAttr(ATTR_NAME_MEMORY_TYPE_RANGE) ? RT_MEMORY_TS_4G : RT_MEMORY_HBM;
|
|
|
|
rtMemType_t memory_type = op_desc->HasAttr(ATTR_NAME_MEMORY_TYPE_RANGE) ? RT_MEMORY_TS_4G : RT_MEMORY_HBM;
|
|
|
@ -103,7 +95,7 @@ Status LabelSwitchByIndexTaskInfo::Init(const domi::TaskDef &task_def, DavinciMo
|
|
|
|
return RT_ERROR_TO_GE_STATUS(rt_ret);
|
|
|
|
return RT_ERROR_TO_GE_STATUS(rt_ret);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
rt_ret = rtLabelListCpy(label_list_.data(), label_list_.size(), args_, args_size_);
|
|
|
|
rt_ret = rtLabelListCpy(label_used.data(), label_used.size(), args_, args_size_);
|
|
|
|
if (rt_ret != RT_ERROR_NONE) {
|
|
|
|
if (rt_ret != RT_ERROR_NONE) {
|
|
|
|
GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret);
|
|
|
|
GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret);
|
|
|
|
return RT_ERROR_TO_GE_STATUS(rt_ret);
|
|
|
|
return RT_ERROR_TO_GE_STATUS(rt_ret);
|
|
|
@ -125,7 +117,7 @@ Status LabelSwitchByIndexTaskInfo::Distribute() {
|
|
|
|
rtError_t rt_ret = rtLabelSwitchByIndex(index_value_, branch_max_, args_, stream_);
|
|
|
|
rtError_t rt_ret = rtLabelSwitchByIndex(index_value_, branch_max_, args_, stream_);
|
|
|
|
if (rt_ret != RT_ERROR_NONE) {
|
|
|
|
if (rt_ret != RT_ERROR_NONE) {
|
|
|
|
GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret);
|
|
|
|
GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret);
|
|
|
|
return RT_FAILED;
|
|
|
|
return RT_ERROR_TO_GE_STATUS(rt_ret);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
GELOGI("LabelSwitchByIndexTaskInfo Distribute Success.");
|
|
|
|
GELOGI("LabelSwitchByIndexTaskInfo Distribute Success.");
|
|
|
|