|
|
|
@ -17,9 +17,15 @@
|
|
|
|
|
#include "graph/load/model_manager/task_info/label_goto_ex_task_info.h"
|
|
|
|
|
|
|
|
|
|
#include "graph/load/model_manager/davinci_model.h"
|
|
|
|
|
#include "graph/debug/ge_attr_define.h"
|
|
|
|
|
|
|
|
|
|
namespace ge {
|
|
|
|
|
constexpr uint8_t kGotoBranchMax = 1;
|
|
|
|
|
|
|
|
|
|
LabelGotoExTaskInfo::~LabelGotoExTaskInfo() {
|
|
|
|
|
GE_FREE_RT_LOG(args_);
|
|
|
|
|
GE_FREE_RT_LOG(index_value_);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Status LabelGotoExTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davinci_model) {
|
|
|
|
|
GELOGI("LabelGotoExTaskInfo Init Start.");
|
|
|
|
|
GE_CHECK_NOTNULL(davinci_model);
|
|
|
|
@ -28,7 +34,7 @@ Status LabelGotoExTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *da
|
|
|
|
|
return FAILED;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Get LabelGoto task def
|
|
|
|
|
// Get LabelGotoEx task def
|
|
|
|
|
const domi::LabelGotoExDef &label_goto = task_def.label_goto_ex();
|
|
|
|
|
OpDescPtr op_desc = davinci_model->GetOpByIndex(label_goto.op_index());
|
|
|
|
|
if (op_desc == nullptr) {
|
|
|
|
@ -48,15 +54,51 @@ Status LabelGotoExTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *da
|
|
|
|
|
GELOGE(PARAM_INVALID, "LabelGotoExTaskInfo: Invalid label id:%u, label size:%zu", label_index, label_list.size());
|
|
|
|
|
return INTERNAL_ERROR;
|
|
|
|
|
}
|
|
|
|
|
label_ = label_list[label_index];
|
|
|
|
|
GE_CHECK_NOTNULL(label_list[label_index]);
|
|
|
|
|
vector<rtLabel_t> label_used = { label_list[label_index] };
|
|
|
|
|
|
|
|
|
|
rtMemType_t memory_type = op_desc->HasAttr(ATTR_NAME_MEMORY_TYPE_RANGE) ? RT_MEMORY_TS_4G : RT_MEMORY_HBM;
|
|
|
|
|
GELOGI("memory_type: %u", memory_type);
|
|
|
|
|
args_size_ = kGotoBranchMax * sizeof(rtLabelDevInfo);
|
|
|
|
|
rtError_t rt_ret = rtMalloc(&args_, args_size_, memory_type);
|
|
|
|
|
if (rt_ret != RT_ERROR_NONE) {
|
|
|
|
|
GELOGE(RT_FAILED, "Call rtMalloc failed, error: %#x", rt_ret);
|
|
|
|
|
return RT_ERROR_TO_GE_STATUS(rt_ret);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
rt_ret = rtLabelListCpy(label_used.data(), label_used.size(), args_, args_size_);
|
|
|
|
|
if (rt_ret != RT_ERROR_NONE) {
|
|
|
|
|
GELOGE(RT_FAILED, "Call rtLabelListCpy failed, error: %#x", rt_ret);
|
|
|
|
|
return RT_ERROR_TO_GE_STATUS(rt_ret);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
GELOGI("LabelGotoExTaskInfo Init Success, label id:%u, label:%p.", label_index, label_);
|
|
|
|
|
rt_ret = rtMalloc(&index_value_, sizeof(uint64_t), memory_type);
|
|
|
|
|
if (rt_ret != RT_ERROR_NONE) {
|
|
|
|
|
GELOGE(RT_FAILED, "Call rtMalloc failed, error: %#x", rt_ret);
|
|
|
|
|
return RT_ERROR_TO_GE_STATUS(rt_ret);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
uint64_t branch_index = 0;
|
|
|
|
|
rt_ret = rtMemcpy(index_value_, sizeof(uint64_t), &branch_index, sizeof(uint64_t), RT_MEMCPY_HOST_TO_DEVICE);
|
|
|
|
|
if (rt_ret != RT_ERROR_NONE) {
|
|
|
|
|
GELOGE(RT_FAILED, "Call rtMemcpy failed, error: %#x", rt_ret);
|
|
|
|
|
return RT_ERROR_TO_GE_STATUS(rt_ret);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
GELOGI("LabelGotoExTaskInfo Init Success, label id:%u, label:%p.", label_index, label_list[label_index]);
|
|
|
|
|
return SUCCESS;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Status LabelGotoExTaskInfo::Distribute() {
|
|
|
|
|
GELOGI("LabelGotoExTaskInfo Distribute Start.");
|
|
|
|
|
rtError_t rt_ret = rtLabelGotoEx(label_, stream_);
|
|
|
|
|
GE_CHECK_NOTNULL(args_);
|
|
|
|
|
GE_CHECK_NOTNULL(index_value_);
|
|
|
|
|
if (args_size_ == 0) {
|
|
|
|
|
GELOGE(PARAM_INVALID, "branch max: %u, args size: %u invalid.", kGotoBranchMax, args_size_);
|
|
|
|
|
return PARAM_INVALID;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
rtError_t rt_ret = rtLabelSwitchByIndex(index_value_, kGotoBranchMax, args_, stream_);
|
|
|
|
|
if (rt_ret != RT_ERROR_NONE) {
|
|
|
|
|
GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret);
|
|
|
|
|
return RT_ERROR_TO_GE_STATUS(rt_ret);
|
|
|
|
|