cache support

pull/949/head
chenyemeng 4 years ago
parent 2fc8c77a01
commit a892b2bf90

@ -379,17 +379,24 @@ vector<void *> ModelUtils::GetInputDataAddrs(const RuntimeParam &model_param, Co
/// ///
Status ModelUtils::GetVarAddr(const RuntimeParam &model_param, const ConstOpDescPtr &op_desc, int64_t offset, Status ModelUtils::GetVarAddr(const RuntimeParam &model_param, const ConstOpDescPtr &op_desc, int64_t offset,
uint8_t *&var_addr) { uint8_t *&var_addr) {
if (ge::VarManager::Instance(model_param.session_id)->GetVarMemType(offset) == RT_MEMORY_RDMA_HBM) { rtMemType_t mem_type = ge::VarManager::Instance(model_param.session_id)->GetVarMemType(offset);
switch (mem_type) {
case RT_MEMORY_RDMA_HBM:
if (offset < 0) { if (offset < 0) {
GELOGE(PARAM_INVALID, "rdma var addr is invalid, addr=%p", reinterpret_cast<uint8_t *>(offset)); GELOGE(PARAM_INVALID, "rdma var addr is invalid, addr=%p", reinterpret_cast<uint8_t *>(offset));
return PARAM_INVALID; return PARAM_INVALID;
} }
var_addr = reinterpret_cast<uint8_t *>(offset); var_addr = reinterpret_cast<uint8_t *>(offset);
GE_CHECK_NOTNULL(var_addr); break;
} else { case RT_MEMORY_HBM:
VALIDATE_MEM_RANGE(op_desc, model_param.var_size, offset - model_param.logic_var_base); VALIDATE_MEM_RANGE(op_desc, model_param.var_size, offset - model_param.logic_var_base);
var_addr = model_param.var_base + offset - model_param.logic_var_base; var_addr = model_param.var_base + offset - model_param.logic_var_base;
break;
default:
GELOGE(PARAM_INVALID, "unsupported memory type %u", mem_type);
return PARAM_INVALID;
} }
GE_CHECK_NOTNULL(var_addr);
return SUCCESS; return SUCCESS;
} }

@ -212,7 +212,7 @@ rtMemType_t VarResource::GetVarMemType(const int64_t &offset) {
if (var_offset_map_.count(offset) > 0) { if (var_offset_map_.count(offset) > 0) {
return var_offset_map_[offset]; return var_offset_map_[offset];
} }
return RT_MEMORY_HBM; return RT_MEMORY_RESERVED;
} }
VarTransRoad *VarResource::GetTransRoad(const std::string &var_name) { VarTransRoad *VarResource::GetTransRoad(const std::string &var_name) {
@ -660,7 +660,7 @@ rtMemType_t VarManager::GetVarMemType(const int64_t &offset) {
std::lock_guard<std::recursive_mutex> lock(mutex_); std::lock_guard<std::recursive_mutex> lock(mutex_);
if (var_resource_ == nullptr) { if (var_resource_ == nullptr) {
GELOGW("VarManager has not been init."); GELOGW("VarManager has not been init.");
return RT_MEMORY_HBM; return RT_MEMORY_RESERVED;
} }
return var_resource_->GetVarMemType(offset); return var_resource_->GetVarMemType(offset);
} }

Loading…
Cancel
Save