You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
144 lines
3.8 KiB
144 lines
3.8 KiB
/**
|
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*/
|
|
#include "tbe_handle_store.h"
|
|
|
|
#include <limits>
|
|
#include "common/ge_inner_error_codes.h"
|
|
#include "framework/common/debug/ge_log.h"
|
|
#include "runtime/kernel.h"
|
|
|
|
namespace ge {
|
|
void TbeHandleInfo::used_inc(uint32_t num) {
|
|
if (used_ > std::numeric_limits<uint32_t>::max() - num) {
|
|
GELOGE(INTERNAL_ERROR, "Used[%u] reach numeric max.", used_);
|
|
return;
|
|
}
|
|
|
|
used_ += num;
|
|
}
|
|
|
|
void TbeHandleInfo::used_dec(uint32_t num) {
|
|
if (used_ < std::numeric_limits<uint32_t>::min() + num) {
|
|
GELOGE(INTERNAL_ERROR, "Used[%u] reach numeric min.", used_);
|
|
return;
|
|
}
|
|
|
|
used_ -= num;
|
|
}
|
|
|
|
uint32_t TbeHandleInfo::used_num() const {
|
|
return used_;
|
|
}
|
|
|
|
void *TbeHandleInfo::handle() const {
|
|
return handle_;
|
|
}
|
|
|
|
|
|
///
/// @ingroup ge
/// @brief Access the process-wide TBE handle store.
/// @return reference to the single TBEHandleStore instance.
///
TBEHandleStore &TBEHandleStore::GetInstance() {
  // Function-local static: constructed on first call; since C++11 its
  // initialization is guaranteed to happen exactly once even under concurrency.
  static TBEHandleStore store;
  return store;
}
|
|
|
|
///
|
|
/// @ingroup ge
|
|
/// @brief Find Registered TBE handle by name.
|
|
/// @param [in] name: TBE handle name to find.
|
|
/// @param [out] handle: handle names record.
|
|
/// @return true: found / false: not found.
|
|
///
|
|
bool TBEHandleStore::FindTBEHandle(const std::string &name, void *&handle) {
|
|
std::lock_guard<std::mutex> lock(mutex_);
|
|
auto it = kernels_.find(name);
|
|
if (it == kernels_.end()) {
|
|
return false;
|
|
} else {
|
|
TbeHandleInfo &info = it->second;
|
|
handle = info.handle();
|
|
return true;
|
|
}
|
|
}
|
|
|
|
///
|
|
/// @ingroup ge
|
|
/// @brief Store registered TBE handle info.
|
|
/// @param [in] name: TBE handle name to store.
|
|
/// @param [in] handle: TBE handle addr to store.
|
|
/// @param [in] kernel: TBE kernel bin to store.
|
|
/// @return NA
|
|
///
|
|
void TBEHandleStore::StoreTBEHandle(const std::string &name, void *handle,
|
|
std::shared_ptr<OpKernelBin> &kernel) {
|
|
std::lock_guard<std::mutex> lock(mutex_);
|
|
auto it = kernels_.find(name);
|
|
if (it == kernels_.end()) {
|
|
TbeHandleInfo info(handle, kernel);
|
|
info.used_inc();
|
|
kernels_.emplace(name, info);
|
|
} else {
|
|
TbeHandleInfo &info = it->second;
|
|
info.used_inc();
|
|
}
|
|
}
|
|
|
|
///
|
|
/// @ingroup ge
|
|
/// @brief Increase reference of registered TBE handle info.
|
|
/// @param [in] name: handle name increase reference.
|
|
/// @return NA
|
|
///
|
|
void TBEHandleStore::ReferTBEHandle(const std::string &name) {
|
|
std::lock_guard<std::mutex> lock(mutex_);
|
|
auto it = kernels_.find(name);
|
|
if (it == kernels_.end()) {
|
|
GELOGE(INTERNAL_ERROR, "Kernel[%s] not found in stored.", name.c_str());
|
|
return;
|
|
}
|
|
|
|
TbeHandleInfo &info = it->second;
|
|
info.used_inc();
|
|
}
|
|
|
|
///
|
|
/// @ingroup ge
|
|
/// @brief Erase TBE registered handle record.
|
|
/// @param [in] names: handle names erase.
|
|
/// @return NA
|
|
///
|
|
void TBEHandleStore::EraseTBEHandle(const std::map<std::string, uint32_t> &names) {
|
|
std::lock_guard<std::mutex> lock(mutex_);
|
|
for (auto &item : names) {
|
|
auto it = kernels_.find(item.first);
|
|
if (it == kernels_.end()) {
|
|
GELOGE(INTERNAL_ERROR, "Kernel[%s] not found in stored.", item.first.c_str());
|
|
continue;
|
|
}
|
|
|
|
TbeHandleInfo &info = it->second;
|
|
if (info.used_num() > item.second) {
|
|
info.used_dec(item.second);
|
|
} else {
|
|
rtError_t rt_ret = rtDevBinaryUnRegister(info.handle());
|
|
if (rt_ret != RT_ERROR_NONE) {
|
|
GELOGE(INTERNAL_ERROR, "Kernel[%s] UnRegister handle fail:%u.", item.first.c_str(), rt_ret);
|
|
}
|
|
kernels_.erase(it);
|
|
}
|
|
}
|
|
}
|
|
} // namespace ge
|