You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
123 lines
3.7 KiB
123 lines
3.7 KiB
/**
 * Copyright 2019-2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
|
|
|
|
#include "hybrid/executor/rt_callback_manager.h"
|
|
#include "framework/common/ge_inner_error_codes.h"
|
|
#include "framework/common/debug/ge_log.h"
|
|
#include "framework/common/util.h"
|
|
|
|
namespace ge {
|
|
namespace hybrid {
|
|
// Remembers the runtime stream; RegisterCallback() records its events on this stream.
CallbackManager::CallbackManager(rtStream_t stream) : stream_(stream) {}
|
|
|
|
// Queues `callback(user_data)` for invocation by the worker loop (CallbackProcess).
//
// A runtime event is created and recorded on `stream_`; the worker waits on that event
// before invoking the callback.
//
// @param callback   Function to invoke; must not be null (rejected via GE_CHECK_NOTNULL).
// @param user_data  Opaque pointer handed back to `callback` unchanged; may be null.
// @return SUCCESS once the entry is queued. RT_FAILED if recording the event fails,
//         INTERNAL_ERROR if the entry cannot be pushed onto the queue, or the status
//         produced by GE_CHK_RT_RET / GE_CHECK_NOTNULL for the remaining failures.
//         On every failure after the event was created, the event is destroyed here.
Status CallbackManager::RegisterCallback(rtCallback_t callback, void *user_data) {
  GELOGD("To register callback");
  // Reject a null callback up front: otherwise it would be queued and only blow up later,
  // on the worker thread, when it is called through a null function pointer.
  GE_CHECK_NOTNULL(callback);

  rtEvent_t event = nullptr;
  GE_CHK_RT_RET(rtEventCreate(&event));
  auto rt_ret = rtEventRecord(event, stream_);
  if (rt_ret != RT_ERROR_NONE) {
    GELOGE(RT_FAILED, "Failed to invoke rtEventRecord, error code = %d", rt_ret);
    (void) rtEventDestroy(event);
    return RT_FAILED;
  }

  std::pair<rtEvent_t, std::pair<rtCallback_t, void *>> entry(
      event, std::pair<rtCallback_t, void *>(callback, user_data));
  if (!callback_queue_.Push(entry)) {
    // Log the failure so it is not silent, matching the rtEventRecord failure path above.
    GELOGE(INTERNAL_ERROR, "Failed to push callback entry to queue");
    (void) rtEventDestroy(event);
    return INTERNAL_ERROR;
  }

  GELOGD("Registering callback successfully");
  return SUCCESS;
}
|
|
|
|
// Launches the asynchronous worker that runs CallbackProcess().
//
// The runtime context current on the calling thread is fetched here and passed to the
// worker, which installs it via rtCtxSetCurrent() before processing entries.
//
// @return SUCCESS when the worker future is valid, INTERNAL_ERROR otherwise, or the
//         status produced by GE_CHK_RT_RET if the current context cannot be obtained.
Status CallbackManager::Init() {
  rtContext_t current_ctx = nullptr;
  GE_CHK_RT_RET(rtCtxGetCurrent(&current_ctx));

  // Only `this` is needed inside the worker; the context travels as an argument.
  auto worker = [this](rtContext_t context) -> Status { return CallbackProcess(context); };
  ret_future_ = std::async(std::launch::async, worker, current_ctx);
  if (ret_future_.valid()) {
    return SUCCESS;
  }

  GELOGE(INTERNAL_ERROR, "Failed to init callback manager.");
  return INTERNAL_ERROR;
}
|
|
|
|
// Worker loop: pops queued entries, waits for each entry's event, then runs its callback.
//
// @param context  Runtime context made current on this thread before any entry is handled.
// @return SUCCESS when an entry with a null event (the stop marker pushed by Destroy())
//         is popped; INTERNAL_ERROR when Pop() reports failure; RT_FAILED when
//         rtEventSynchronize() fails; or the status produced by GE_CHK_RT_RET if the
//         context cannot be set.
Status CallbackManager::CallbackProcess(rtContext_t context) {
  GE_CHK_RT_RET(rtCtxSetCurrent(context));

  std::pair<rtEvent_t, std::pair<rtCallback_t, void *>> item;
  for (;;) {
    if (!callback_queue_.Pop(item)) {
      GELOGI("CallbackManager stopped");
      return INTERNAL_ERROR;
    }

    rtEvent_t pending_event = item.first;
    if (pending_event == nullptr) {
      // Null event marks end-of-stream.
      return SUCCESS;
    }

    auto sync_ret = rtEventSynchronize(pending_event);
    if (sync_ret != RT_ERROR_NONE) {
      GELOGE(RT_FAILED, "rtEventSynchronize failed. ret = %d", sync_ret);
      GE_CHK_RT(rtEventDestroy(pending_event));
      return RT_FAILED;
    }

    // The event has served its purpose; release it before running user code.
    GE_CHK_RT(rtEventDestroy(pending_event));

    rtCallback_t handler = item.second.first;
    void *handler_arg = item.second.second;
    handler(handler_arg);
  }
}
|
|
|
|
// Stops the worker started by Init() and reports its final status.
//
// Pushes an entry whose event is null onto the queue (CallbackProcess() returns when it
// pops such an entry) and then blocks on the worker's future.
//
// @return SUCCESS if Init() never produced a valid future; otherwise the Status returned
//         by the worker.
Status CallbackManager::Destroy() {
  GELOGI("To destroy callback manager.");
  if (ret_future_.valid()) {
    std::pair<rtEvent_t, std::pair<rtCallback_t, void *>> stop_marker;
    stop_marker.first = nullptr;
    callback_queue_.Push(stop_marker);

    auto worker_status = ret_future_.get();
    GELOGI("Callback manager ended. ret = %u", worker_status);
    return worker_status;
  }

  GELOGI("CallbackManager not initialized.");
  return SUCCESS;
}
|
|
|
|
void CallbackManager::RtCallbackFunc(void *data) {
|
|
GELOGD("To invoke callback function");
|
|
auto callback_func = reinterpret_cast<std::function<void()> *>(data);
|
|
(*callback_func)();
|
|
delete callback_func;
|
|
}
|
|
|
|
// Convenience overload: registers an arbitrary std::function<void()> as a callback.
//
// The functor is copied onto the heap and handed to the raw-pointer overload together
// with RtCallbackFunc, which invokes and deletes it on the worker thread.
//
// @param callback  Functor to run once the recorded event completes.
// @return The status of the underlying RegisterCallback(rtCallback_t, void *) call, or the
//         status produced by GE_CHECK_NOTNULL if the heap copy cannot be allocated.
Status CallbackManager::RegisterCallback(const std::function<void()> &callback) {
  auto func = std::unique_ptr<std::function<void()>>(new(std::nothrow) std::function<void()>(callback));
  GE_CHECK_NOTNULL(func);
  GELOGD("Callback registered");

  // Keep ownership until the entry is actually queued: on every failure path of the
  // underlying overload nothing was enqueued, so the functor must be freed here rather
  // than leaked.
  const Status ret = RegisterCallback(RtCallbackFunc, func.get());
  if (ret == SUCCESS) {
    // Ownership now belongs to RtCallbackFunc. release() only clears the smart pointer and
    // never dereferences it, so this is safe even if the worker has already run and
    // deleted the functor.
    (void) func.release();
  }
  return ret;
}
|
|
} // namespace hybrid
|
|
} // namespace ge
|