diff --git a/ge/graph/load/model_manager/davinci_model.cc b/ge/graph/load/model_manager/davinci_model.cc index 0aac173e..fc861a24 100755 --- a/ge/graph/load/model_manager/davinci_model.cc +++ b/ge/graph/load/model_manager/davinci_model.cc @@ -2924,6 +2924,14 @@ Status DavinciModel::InitTaskInfo(domi::ModelTaskDef &model_task_def) { return SUCCESS; } +Status DavinciModel::CheckCapability(rtFeatureType_t featureType, int32_t featureInfo, bool &is_support) const { + int64_t value = RT_CAPABILITY_SUPPORT; + auto rt_ret = rtGetRtCapability(featureType, featureInfo, &value); + GE_CHK_BOOL_RET_STATUS(rt_ret == RT_ERROR_NONE, FAILED, "call rtGetRtCapability failed!"); + is_support = (value == RT_CAPABILITY_SUPPORT) ? true : false; + return SUCCESS; +} + Status DavinciModel::MallocKnownArgs() { GELOGI("DavinciModel::MallocKnownArgs in"); const auto &model_task_def = ge_model_->GetModelTaskDefPtr(); @@ -2944,8 +2952,12 @@ Status DavinciModel::MallocKnownArgs() { } rtError_t rt_ret; // malloc args memory + bool is_support = false; + GE_CHK_STATUS_RET_NOLOG(CheckCapability(FEATURE_TYPE_MEMORY, MEMORY_INFO_TS_4G_LIMITED, is_support)); + auto mem_type = is_support ? RT_MEMORY_TS_4G : RT_MEMORY_HBM; + if (total_args_size_ != 0) { - rt_ret = rtMalloc(&args_, total_args_size_, RT_MEMORY_HBM); + rt_ret = rtMalloc(&args_, total_args_size_, mem_type); if (rt_ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rtMalloc failed, ret: 0x%X", rt_ret); return RT_ERROR_TO_GE_STATUS(rt_ret); @@ -2953,7 +2965,7 @@ Status DavinciModel::MallocKnownArgs() { } // malloc dynamic and static hybrid memory if (total_hybrid_args_size_ != 0) { - rt_ret = rtMalloc(&hybrid_addrs_, total_hybrid_args_size_, RT_MEMORY_HBM); + rt_ret = rtMalloc(&hybrid_addrs_, total_hybrid_args_size_, mem_type); if (rt_ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rtMalloc failed, ret: 0x%X", rt_ret); return RT_ERROR_TO_GE_STATUS(rt_ret); @@ -2962,7 +2974,7 @@ Status DavinciModel::MallocKnownArgs() { // malloc fixed addr memory, eg: rts op if (total_fixed_addr_size_ != 0) { GELOGI("Begin to allocate fixed addr."); - rt_ret = rtMalloc(&fixed_addrs_, total_fixed_addr_size_, RT_MEMORY_HBM); + rt_ret = rtMalloc(&fixed_addrs_, total_fixed_addr_size_, mem_type); if (rt_ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rtMalloc failed, ret: 0x%X", rt_ret); return RT_ERROR_TO_GE_STATUS(rt_ret); diff --git a/ge/graph/load/model_manager/davinci_model.h b/ge/graph/load/model_manager/davinci_model.h index 93f968ee..e4b73d7e 100755 --- a/ge/graph/load/model_manager/davinci_model.h +++ b/ge/graph/load/model_manager/davinci_model.h @@ -530,6 +530,7 @@ class DavinciModel { } void SetKnownNode(bool known_node) { known_node_ = known_node; } bool IsKnownNode() { return known_node_; } + Status CheckCapability(rtFeatureType_t featureType, int32_t featureInfo, bool &is_support) const; Status MallocKnownArgs(); Status UpdateKnownNodeArgs(const vector &inputs, const vector &outputs); Status CreateKnownZeroCopyMap(const vector &inputs, const vector &outputs); diff --git a/tests/ut/ge/graph/load/davinci_model_unittest.cc b/tests/ut/ge/graph/load/davinci_model_unittest.cc index 18cc622b..55f418d6 100644 --- a/tests/ut/ge/graph/load/davinci_model_unittest.cc +++ b/tests/ut/ge/graph/load/davinci_model_unittest.cc @@ -141,6 +141,12 @@ TEST_F(UtestDavinciModel, init_success) { ProfilingManager::Instance().is_load_profiling_ = false; } +TEST_F(UtestDavinciModel, CheckCapability) { + DavinciModel model(0, nullptr); + bool is_support = false; + (void)model.CheckCapability(FEATURE_TYPE_MEMORY, MEMORY_INFO_TS_4G_LIMITED, is_support); +} + TEST_F(UtestDavinciModel, init_data_op) { DavinciModel model(0, nullptr); model.ge_model_ = make_shared(); diff --git a/third_party/fwkacllib/inc/runtime/dev.h b/third_party/fwkacllib/inc/runtime/dev.h index 49f6a3f6..b028a5f4 100644 --- a/third_party/fwkacllib/inc/runtime/dev.h +++ b/third_party/fwkacllib/inc/runtime/dev.h @@ -59,6 +59,7 @@ typedef enum tagRtAicpuDeployType { typedef enum tagRtFeatureType { FEATURE_TYPE_MEMCPY = 0, + FEATURE_TYPE_MEMORY = 1, FEATURE_TYPE_RSV } rtFeatureType_t; @@ -67,6 +68,10 @@ typedef enum tagMemcpyInfo { MEMCPY_INFO_RSV } rtMemcpyInfo_t; +typedef enum tagMemoryInfo { + MEMORY_INFO_TS_4G_LIMITED = 0, + MEMORY_INFO_RSV +} rtMemoryInfo_t; /** * @ingroup dvrt_dev * @brief get total device number.