You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
graphengine/ge/graph/load/model_manager/tbe_handle_store.cc

144 lines
3.8 KiB

/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tbe_handle_store.h"
#include <limits>
#include "common/ge_inner_error_codes.h"
#include "framework/common/debug/ge_log.h"
#include "runtime/kernel.h"
namespace ge {
///
/// @ingroup ge
/// @brief Increase the use count of this handle record.
/// @param [in] num: amount added to the use count.
/// @return NA (on would-be overflow an error is logged and the count is left unchanged).
///
void TbeHandleInfo::used_inc(uint32_t num) {
  // Largest current count that can still absorb `num` without wrapping around.
  const uint32_t ceiling = std::numeric_limits<uint32_t>::max() - num;
  if (used_ <= ceiling) {
    used_ += num;
    return;
  }

  GELOGE(INTERNAL_ERROR, "Used[%u] reach numeric max.", used_);
}
///
/// @ingroup ge
/// @brief Decrease the use count of this handle record.
/// @param [in] num: amount subtracted from the use count.
/// @return NA (on would-be underflow an error is logged and the count is left unchanged).
///
void TbeHandleInfo::used_dec(uint32_t num) {
  // Smallest current count from which `num` can be subtracted without wrapping around.
  const uint32_t lowest_allowed = std::numeric_limits<uint32_t>::min() + num;
  if (used_ >= lowest_allowed) {
    used_ -= num;
    return;
  }

  GELOGE(INTERNAL_ERROR, "Used[%u] reach numeric min.", used_);
}
///
/// @ingroup ge
/// @brief Get the current use count of this handle record.
/// @return value of the use counter.
///
uint32_t TbeHandleInfo::used_num() const { return used_; }
///
/// @ingroup ge
/// @brief Get the handle address held by this record.
/// @return stored handle pointer.
///
void *TbeHandleInfo::handle() const { return handle_; }
///
/// @ingroup ge
/// @brief Access the shared TBE handle store.
/// @return reference to a function-local static TBEHandleStore object.
///
TBEHandleStore &TBEHandleStore::GetInstance() {
  // Every call returns the same object.
  static TBEHandleStore store;
  return store;
}
///
/// @ingroup ge
/// @brief Look up a stored TBE handle by name.
/// @param [in] name: TBE handle name to look up.
/// @param [out] handle: receives the stored handle address when found; left untouched otherwise.
/// @return true: found / false: not found.
///
bool TBEHandleStore::FindTBEHandle(const std::string &name, void *&handle) {
  // mutex_ is held for the whole lookup.
  std::lock_guard<std::mutex> guard(mutex_);

  const auto entry = kernels_.find(name);
  if (entry != kernels_.end()) {
    handle = entry->second.handle();
    return true;
  }
  return false;
}
///
/// @ingroup ge
/// @brief Store registered TBE handle info.
/// @param [in] name: TBE handle name to store.
/// @param [in] handle: TBE handle addr to store.
/// @param [in] kernel: TBE kernel bin to store.
/// @return NA
///
void TBEHandleStore::StoreTBEHandle(const std::string &name, void *handle,
std::shared_ptr<OpKernelBin> &kernel) {
std::lock_guard<std::mutex> lock(mutex_);
auto it = kernels_.find(name);
if (it == kernels_.end()) {
TbeHandleInfo info(handle, kernel);
info.used_inc();
kernels_.emplace(name, info);
} else {
TbeHandleInfo &info = it->second;
info.used_inc();
}
}
///
/// @ingroup ge
/// @brief Add a reference to an already stored TBE handle record.
/// @param [in] name: name of the handle record whose use count is increased.
/// @return NA (an error is logged when no record with that name exists).
///
void TBEHandleStore::ReferTBEHandle(const std::string &name) {
  std::lock_guard<std::mutex> guard(mutex_);

  auto entry = kernels_.find(name);
  if (entry != kernels_.end()) {
    entry->second.used_inc();
    return;
  }

  GELOGE(INTERNAL_ERROR, "Kernel[%s] not found in stored.", name.c_str());
}
///
/// @ingroup ge
/// @brief Erase TBE registered handle record.
/// @param [in] names: handle names erase.
/// @return NA
///
void TBEHandleStore::EraseTBEHandle(const std::map<std::string, uint32_t> &names) {
std::lock_guard<std::mutex> lock(mutex_);
for (auto &item : names) {
auto it = kernels_.find(item.first);
if (it == kernels_.end()) {
GELOGE(INTERNAL_ERROR, "Kernel[%s] not found in stored.", item.first.c_str());
continue;
}
TbeHandleInfo &info = it->second;
if (info.used_num() > item.second) {
info.used_dec(item.second);
} else {
rtError_t rt_ret = rtDevBinaryUnRegister(info.handle());
if (rt_ret != RT_ERROR_NONE) {
GELOGE(INTERNAL_ERROR, "Kernel[%s] UnRegister handle fail:%u.", item.first.c_str(), rt_ret);
}
kernels_.erase(it);
}
}
}
} // namespace ge