|
|
|
|
@ -29,7 +29,7 @@ static void CUDART_CB StreamCallbackFunc(cudaStream_t stream,
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
#if PADDLE_WITH_ASCEND_CL
|
|
|
|
|
static void StreamCallbackFunc(void *user_data)
|
|
|
|
|
static void StreamCallbackFunc(void *user_data)
|
|
|
|
|
#endif
|
|
|
|
|
{
|
|
|
|
|
std::unique_ptr<std::function<void()>> func(
|
|
|
|
|
@ -42,7 +42,8 @@ StreamCallbackManager<Stream>::StreamCallbackManager(const Stream stream)
|
|
|
|
|
: stream_(stream), thread_pool_(1) {}
|
|
|
|
|
|
|
|
|
|
template <typename Stream>
|
|
|
|
|
void StreamCallbackManager<Stream>::AddCallback(std::function<void()> callback) const {
|
|
|
|
|
void StreamCallbackManager<Stream>::AddCallback(
|
|
|
|
|
std::function<void()> callback) const {
|
|
|
|
|
auto *callback_func = new std::function<void()>(std::move(callback));
|
|
|
|
|
auto *func = new std::function<void()>([this, callback_func] {
|
|
|
|
|
std::lock_guard<std::mutex> lock(mtx_);
|
|
|
|
|
@ -62,6 +63,8 @@ void StreamCallbackManager<Stream>::AddCallback(std::function<void()> callback)
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
#if PADDLE_WITH_ASCEND_CL
|
|
|
|
|
VLOG(3) << "aclrtLaunchCallback at stream: " << stream_;
|
|
|
|
|
// TODO(zhiqiu): failed to call aclrtLaunchCallback
|
|
|
|
|
PADDLE_ENFORCE_NPU_SUCCESS(aclrtLaunchCallback(StreamCallbackFunc, func,
|
|
|
|
|
ACL_CALLBACK_BLOCK, stream_));
|
|
|
|
|
#endif
|
|
|
|
|
|