|
|
|
@ -379,17 +379,24 @@ vector<void *> ModelUtils::GetInputDataAddrs(const RuntimeParam &model_param, Co
|
|
|
|
|
///
|
|
|
|
|
Status ModelUtils::GetVarAddr(const RuntimeParam &model_param, const ConstOpDescPtr &op_desc, int64_t offset,
|
|
|
|
|
uint8_t *&var_addr) {
|
|
|
|
|
if (ge::VarManager::Instance(model_param.session_id)->GetVarMemType(offset) == RT_MEMORY_RDMA_HBM) {
|
|
|
|
|
if (offset < 0) {
|
|
|
|
|
GELOGE(PARAM_INVALID, "rdma var addr is invalid, addr=%p", reinterpret_cast<uint8_t *>(offset));
|
|
|
|
|
rtMemType_t mem_type = ge::VarManager::Instance(model_param.session_id)->GetVarMemType(offset);
|
|
|
|
|
switch (mem_type) {
|
|
|
|
|
case RT_MEMORY_RDMA_HBM:
|
|
|
|
|
if (offset < 0) {
|
|
|
|
|
GELOGE(PARAM_INVALID, "rdma var addr is invalid, addr=%p", reinterpret_cast<uint8_t *>(offset));
|
|
|
|
|
return PARAM_INVALID;
|
|
|
|
|
}
|
|
|
|
|
var_addr = reinterpret_cast<uint8_t *>(offset);
|
|
|
|
|
break;
|
|
|
|
|
case RT_MEMORY_HBM:
|
|
|
|
|
VALIDATE_MEM_RANGE(op_desc, model_param.var_size, offset - model_param.logic_var_base);
|
|
|
|
|
var_addr = model_param.var_base + offset - model_param.logic_var_base;
|
|
|
|
|
break;
|
|
|
|
|
default:
|
|
|
|
|
GELOGE(PARAM_INVALID, "unsupported memory type %u", mem_type);
|
|
|
|
|
return PARAM_INVALID;
|
|
|
|
|
}
|
|
|
|
|
var_addr = reinterpret_cast<uint8_t *>(offset);
|
|
|
|
|
GE_CHECK_NOTNULL(var_addr);
|
|
|
|
|
} else {
|
|
|
|
|
VALIDATE_MEM_RANGE(op_desc, model_param.var_size, offset - model_param.logic_var_base);
|
|
|
|
|
var_addr = model_param.var_base + offset - model_param.logic_var_base;
|
|
|
|
|
}
|
|
|
|
|
GE_CHECK_NOTNULL(var_addr);
|
|
|
|
|
return SUCCESS;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|