parent 400cf19f14
commit 096673f675
@@ -0,0 +1,117 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/details/eager_deletion_op_handle.h"
#include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"

namespace paddle {
namespace framework {
namespace details {

EagerDeletionOpHandle::EagerDeletionOpHandle(
    ir::Node *node, const Scope *scope, const platform::Place &place,
    const std::vector<std::string> &var_names, GarbageCollector<Tensor> *gc,
    AtomicReferenceCountMap *ref_cnts)
    : OpHandleBase(node), scope_(scope), gc_(gc), ref_cnts_(ref_cnts) {
#ifdef PADDLE_WITH_CUDA
  if (platform::is_gpu_place(place)) {
    dev_ctx_ = static_cast<platform::CUDADeviceContext *>(
        platform::DeviceContextPool::Instance().Get(place));
    if (dynamic_cast<StreamGarbageCollector<Tensor> *>(gc_)) {
      platform::SetDeviceId(boost::get<platform::CUDAPlace>(place).device);
      PADDLE_ENFORCE(
          cudaEventCreateWithFlags(&event_, cudaEventDisableTiming));
    }
  }
#endif

  for (auto &name : var_names) AddVar(name);
}

EagerDeletionOpHandle::~EagerDeletionOpHandle() {
#ifdef PADDLE_WITH_CUDA
  if (event_) {
    auto gpu_place = boost::get<platform::CUDAPlace>(dev_ctx_->GetPlace());
    platform::SetDeviceId(gpu_place.device);
    PADDLE_ENFORCE(cudaEventDestroy(event_));
  }
#endif
}

std::string EagerDeletionOpHandle::Name() const { return "eager_deletion"; }

void EagerDeletionOpHandle::AddVar(const std::string &name) {
  var_names_.insert(name);
}

void EagerDeletionOpHandle::RunImpl() {
  auto *exec_scope = scope_->FindVar(kLocalExecScopeName)->Get<Scope *>();
  std::vector<Tensor *> tensors;
  for (auto &name : var_names_) {
    auto it = ref_cnts_->find(name);
    if (it == ref_cnts_->end()) {
      continue;
    }

    auto *var = exec_scope->FindVar(name);
    if (var == nullptr) {
      continue;
    }

    if (var->IsType<LoDTensor>()) {
      if (it->second.fetch_sub(1) == 1) {
        tensors.emplace_back(var->GetMutable<LoDTensor>());
      }
    } else if (var->IsType<SelectedRows>()) {
      if (it->second.fetch_sub(1) == 1) {
        tensors.emplace_back(var->GetMutable<SelectedRows>()->mutable_value());
      }
    } else if (var->IsType<LoDTensorArray>()) {
      if (it->second.fetch_sub(1) == 1) {
        auto *tensor_arr = var->GetMutable<LoDTensorArray>();
        for (auto &t : *tensor_arr) {
          tensors.emplace_back(&t);
        }
      }
    }
  }

  if (!tensors.empty()) {
    ClearTensors(tensors);
  }
}

void EagerDeletionOpHandle::ClearTensors(const std::vector<Tensor *> &tensors) {
#ifdef PADDLE_WITH_CUDA
  if (event_) {
    auto compute_stream = dev_ctx_->stream();
    auto callback_stream =
        static_cast<StreamGarbageCollector<Tensor> *>(gc_)->stream();
    auto callback_func = [=]() {
      PADDLE_ENFORCE(cudaEventRecord(event_, compute_stream));
      PADDLE_ENFORCE(cudaStreamWaitEvent(callback_stream, event_, 0));
    };
    gc_->Add(tensors, callback_func);
  } else {
#endif
    gc_->Add(tensors);
#ifdef PADDLE_WITH_CUDA
  }
#endif
}

}  // namespace details
}  // namespace framework
}  // namespace paddle
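The deletion trigger above hinges on fetch_sub(1) returning the counter's value before the decrement: only the op that observes a prior value of 1 has released the last reference, so each tensor is collected exactly once even when several eager-deletion ops race on the same map entry. A minimal standalone sketch of that idiom (names are illustrative, not Paddle's):

#include <atomic>
#include <cstdio>
#include <thread>
#include <vector>

int main() {
  // Three "ops" share one variable; the op whose fetch_sub sees a prior
  // value of 1 released the last reference and may free the tensor.
  std::atomic<std::size_t> ref_cnt{3};
  std::vector<std::thread> ops;
  for (int i = 0; i < 3; ++i) {
    ops.emplace_back([&ref_cnt, i] {
      if (ref_cnt.fetch_sub(1) == 1) {
        std::printf("op %d frees the tensor\n", i);  // runs exactly once
      }
    });
  }
  for (auto &t : ops) t.join();
  return 0;
}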
@@ -0,0 +1,64 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <string>
#include <unordered_set>  // needed for var_names_ below
#include <vector>
#include "paddle/fluid/framework/details/op_handle_base.h"
#include "paddle/fluid/framework/details/reference_count_pass_helper.h"

namespace paddle {
namespace framework {
class Scope;

namespace details {

class EagerDeletionPass;

class EagerDeletionOpHandle : public OpHandleBase {
 public:
  EagerDeletionOpHandle(ir::Node *node, const Scope *scope,
                        const platform::Place &place,
                        const std::vector<std::string> &var_names,
                        GarbageCollector<Tensor> *gc,
                        AtomicReferenceCountMap *ref_cnts);

  ~EagerDeletionOpHandle();

  std::string Name() const override;

 protected:
  void RunImpl() override;

 private:
  void ClearTensors(const std::vector<Tensor *> &tensors);

  void AddVar(const std::string &name);

  const Scope *scope_;
  std::unordered_set<std::string> var_names_;
  GarbageCollector<Tensor> *gc_;       // not owned
  AtomicReferenceCountMap *ref_cnts_;  // not owned
#ifdef PADDLE_WITH_CUDA
  platform::CUDADeviceContext *dev_ctx_{nullptr};
  cudaEvent_t event_{nullptr};
#endif

  friend class EagerDeletionPass;
};

}  // namespace details
}  // namespace framework
}  // namespace paddle
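ClearTensors in the .cc above never frees on the compute stream: it records event_ on the compute stream and makes the collector's stream wait on it, so the deferred free cannot overtake pending kernels. A standalone sketch of that cudaEventRecord/cudaStreamWaitEvent ordering idiom, with a cudaMemsetAsync standing in for real compute (error handling reduced to a local CHECK_CUDA macro):

#include <cuda_runtime.h>
#include <cstdio>

#define CHECK_CUDA(call)                              \
  do {                                                \
    cudaError_t err = (call);                         \
    if (err != cudaSuccess) {                         \
      std::printf("%s\n", cudaGetErrorString(err));   \
      return 1;                                       \
    }                                                 \
  } while (0)

int main() {
  cudaStream_t compute_stream, gc_stream;
  CHECK_CUDA(cudaStreamCreate(&compute_stream));
  CHECK_CUDA(cudaStreamCreate(&gc_stream));

  cudaEvent_t event;
  CHECK_CUDA(cudaEventCreateWithFlags(&event, cudaEventDisableTiming));

  void *buf = nullptr;
  CHECK_CUDA(cudaMalloc(&buf, 1 << 20));
  // Stand-in for the op's kernels on the compute stream.
  CHECK_CUDA(cudaMemsetAsync(buf, 0, 1 << 20, compute_stream));

  // Record marks "all compute submitted so far"; the wait makes any work
  // later enqueued on gc_stream start only after that point is reached.
  CHECK_CUDA(cudaEventRecord(event, compute_stream));
  CHECK_CUDA(cudaStreamWaitEvent(gc_stream, event, 0));

  // A free enqueued on gc_stream here could not overtake the memset above.
  CHECK_CUDA(cudaStreamSynchronize(gc_stream));

  CHECK_CUDA(cudaFree(buf));
  CHECK_CUDA(cudaEventDestroy(event));
  CHECK_CUDA(cudaStreamDestroy(gc_stream));
  CHECK_CUDA(cudaStreamDestroy(compute_stream));
  return 0;
}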
@@ -0,0 +1,96 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <algorithm>      // std::find_if
#include <memory>         // std::unique_ptr
#include <queue>
#include <string>
#include <unordered_map>  // op_map below
#include <vector>

#include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/eager_deletion_op_handle.h"
#include "paddle/fluid/framework/details/eager_deletion_pass.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/ir/graph_helper.h"

namespace paddle {
namespace framework {
namespace details {

static void AddDependencyBetween(OpHandleBase *in, OpHandleBase *out,
                                 ir::Graph *graph) {
  auto it = std::find_if(
      in->Outputs().begin(), in->Outputs().end(), [](VarHandleBase *var) {
        return dynamic_cast<DummyVarHandle *>(var) != nullptr;
      });

  if (it != in->Outputs().end()) {
    out->AddInput(*it);
  } else {
    auto *dep_var = new DummyVarHandle(graph->CreateControlDepVar());
    graph->Get<GraphDepVars>(kGraphDepVars).emplace(dep_var);
    in->AddOutput(dep_var);
    out->AddInput(dep_var);
  }

  // Give the eager deletion op a dummy leaf output if it has none.
  if (out->Outputs().empty()) {
    auto *dummy_leaf = new DummyVarHandle(graph->CreateControlDepVar());
    graph->Get<GraphDepVars>(kGraphDepVars).emplace(dummy_leaf);
    out->AddOutput(dummy_leaf);
  }
}

std::unique_ptr<ir::Graph> EagerDeletionPass::ApplyImpl(
    std::unique_ptr<ir::Graph> graph) const {
  auto &vars = graph->Get<GraphVars>(kGraphVars);

  auto &ref_cnts =
      Get<std::vector<AtomicReferenceCountMap>>(kCurReferenceCount);
  auto &last_live_ops = Get<std::vector<LastLiveOpsOfVars>>(kLastLiveOpsOfVars);
  auto &gcs = Get<GarbageCollectorList>(kGarbageCollector);

  ref_cnts = std::vector<AtomicReferenceCountMap>(vars.size());

  std::unordered_map<ComputationOpHandle *, EagerDeletionOpHandle *> op_map;
  for (auto &var_ops_map : last_live_ops) {
    for (auto &var_ops_pair : var_ops_map) {
      const std::string &var_name = var_ops_pair.first;
      for (ComputationOpHandle *op : var_ops_pair.second) {
        auto it = op_map.find(op);
        if (it != op_map.end()) {
          it->second->AddVar(var_name);
        } else {
          auto *eager_deletion_node = graph->CreateEmptyNode(
              "eager_deletion", ir::Node::Type::kOperation);
          auto *eager_deletion_op = new EagerDeletionOpHandle(
              eager_deletion_node, op->GetScope(), op->GetPlace(), {var_name},
              gcs[op->GetScopeIdx()].get(), &(ref_cnts[op->GetScopeIdx()]));
          AddDependencyBetween(op, eager_deletion_op, graph.get());
          op_map[op] = eager_deletion_op;
        }
      }
    }
  }
  VLOG(10) << "Create " << op_map.size() << " EagerDeletionOpHandle(s)";
  return graph;
}

}  // namespace details
}  // namespace framework
}  // namespace paddle

REGISTER_PASS(eager_deletion_pass,
              paddle::framework::details::EagerDeletionPass)
    .RequirePassAttr(paddle::framework::details::kCurReferenceCount)
    .RequirePassAttr(paddle::framework::details::kLastLiveOpsOfVars)
    .RequirePassAttr(paddle::framework::details::kGarbageCollector);
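AddDependencyBetween above needs only an ordering edge, not data flow, so it reuses any existing dummy output of the upstream op and threads a fresh control-dep variable otherwise. A simplified sketch of that pattern with stand-in types (Op and Var below are illustrative, not Paddle's handle classes):

#include <algorithm>
#include <memory>
#include <vector>

struct Var {
  bool is_dummy;  // control-dep vars carry no data, only ordering
};

struct Op {
  std::vector<Var *> inputs, outputs;
};

// Make out run after in by sharing a dummy Var, creating one on demand.
void AddDependency(Op *in, Op *out, std::vector<std::unique_ptr<Var>> *owned) {
  auto it = std::find_if(in->outputs.begin(), in->outputs.end(),
                         [](Var *v) { return v->is_dummy; });
  Var *dep;
  if (it != in->outputs.end()) {
    dep = *it;  // reuse: one dummy output serves every downstream op
  } else {
    owned->push_back(std::unique_ptr<Var>(new Var{true}));
    dep = owned->back().get();
    in->outputs.push_back(dep);
  }
  out->inputs.push_back(dep);
}

int main() {
  std::vector<std::unique_ptr<Var>> owned;
  Op compute, deletion1, deletion2;
  AddDependency(&compute, &deletion1, &owned);
  AddDependency(&compute, &deletion2, &owned);  // reuses the first dummy
  return owned.size() == 1 ? 0 : 1;
}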
@@ -0,0 +1,32 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <memory>  // std::unique_ptr

#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"

namespace paddle {
namespace framework {
namespace details {

class EagerDeletionPass : public ir::Pass {
 protected:
  std::unique_ptr<ir::Graph> ApplyImpl(
      std::unique_ptr<ir::Graph> graph) const override;
};

}  // namespace details
}  // namespace framework
}  // namespace paddle
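The pass declares three required attributes at registration (see the RequirePassAttr calls above), so a caller must attach them before applying it. A hedged sketch of that wiring, assuming the ir::Pass attribute API of this Paddle version (PassRegistry::Instance().Get, SetNotOwned, Apply); this fragment is not part of the commit:

// Hedged sketch: assumes ir::PassRegistry / Pass::SetNotOwned / Pass::Apply
// as found in Paddle of this period; not taken from this commit.
std::unique_ptr<ir::Graph> RunEagerDeletion(
    std::unique_ptr<ir::Graph> graph,
    std::vector<details::AtomicReferenceCountMap> *ref_cnts,
    std::vector<details::LastLiveOpsOfVars> *last_live_ops,
    details::GarbageCollectorList *gcs) {
  auto pass = ir::PassRegistry::Instance().Get("eager_deletion_pass");
  // SetNotOwned attaches borrowed pointers; the pass reads them via Get<T>.
  pass->SetNotOwned(details::kCurReferenceCount, ref_cnts);
  pass->SetNotOwned(details::kLastLiveOpsOfVars, last_live_ops);
  pass->SetNotOwned(details::kGarbageCollector, gcs);
  return pass->Apply(std::move(graph));
}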
@@ -1,138 +0,0 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <atomic>
#include <string>
#include <unordered_map>
#include <vector>

#include "paddle/fluid/framework/details/op_handle_base.h"
#include "paddle/fluid/framework/garbage_collector.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/tensor.h"

namespace paddle {
namespace framework {
namespace details {

using ReferenceCountMap = std::unordered_map<std::string, int>;
using AtomicReferenceCountMap =
    std::unordered_map<std::string, std::atomic<int>>;
using DeviceReferenceCountMap =
    std::unordered_map<int, std::unique_ptr<ReferenceCountMap>>;
using AtomicDeviceReferenceCountMap =
    std::unordered_map<int, std::unique_ptr<AtomicReferenceCountMap>>;
using DeviceGarbageCollectorMap =
    std::unordered_map<int,
                       std::unique_ptr<GarbageCollector<framework::Tensor>>>;

class ReferenceCountOpHandle : public OpHandleBase {
 public:
  ReferenceCountOpHandle(ir::Node *node, const Scope *scope,
                         const platform::CUDAPlace &place,
                         const std::vector<std::string> &var_names,
                         GarbageCollector<Tensor> *gc,
                         AtomicReferenceCountMap *ref_cnts)
      : OpHandleBase(node), scope_(scope), gc_(gc), ref_cnts_(ref_cnts) {
    dev_ctx_ = static_cast<platform::CUDADeviceContext *>(
        platform::DeviceContextPool::Instance().Get(place));
    if (IsStreamGarabageCollector()) {
      platform::SetDeviceId(place.device);
      PADDLE_ENFORCE(
          cudaEventCreateWithFlags(&event_, cudaEventDisableTiming));
    }

    for (auto &name : var_names) AddVar(name);
  }

  ~ReferenceCountOpHandle() {
    if (IsStreamGarabageCollector()) {
      auto gpu_place = boost::get<platform::CUDAPlace>(dev_ctx_->GetPlace());
      platform::SetDeviceId(gpu_place.device);
      PADDLE_ENFORCE(cudaEventDestroy(event_));
    }
  }

  std::string Name() const override { return "reference_count"; }

  void AddVar(const std::string &name) {
    auto it = var_names_.find(name);
    if (it != var_names_.end())
      ++(it->second);
    else
      var_names_[name] = 1;
  }

 protected:
  void RunImpl() override {
    auto *exec_scope = scope_->FindVar(kLocalExecScopeName)->Get<Scope *>();
    std::vector<Tensor *> tensors;
    for (auto &pair : var_names_) {
      auto &name = pair.first;
      auto it = ref_cnts_->find(name);
      if (it == ref_cnts_->end()) continue;

      auto *var = exec_scope->FindVar(name);
      if (var == nullptr) continue;

      if (var->IsType<LoDTensor>()) {
        if (it->second.fetch_sub(pair.second) <= pair.second) {
          tensors.emplace_back(var->GetMutable<LoDTensor>());
        }
      } else if (var->IsType<SelectedRows>()) {
        if (it->second.fetch_sub(pair.second) <= pair.second) {
          tensors.emplace_back(
              var->GetMutable<SelectedRows>()->mutable_value());
        }
      }
    }

    if (!tensors.empty()) {
      ClearTensors(tensors);
    }
  }

 private:
  void ClearTensors(const std::vector<Tensor *> &tensors) {
    auto *gc = dynamic_cast<StreamGarbageCollector<Tensor> *>(gc_);
    if (gc != nullptr) {
      auto compute_stream = dev_ctx_->stream();
      auto callback_stream = gc->stream();
      auto callback_func = [=]() {
        PADDLE_ENFORCE(cudaEventRecord(event_, compute_stream));
        PADDLE_ENFORCE(cudaStreamWaitEvent(callback_stream, event_, 0));
      };
      gc_->Add(tensors, callback_func);
    } else {
      gc_->Add(tensors);
    }
  }

  bool IsStreamGarabageCollector() const {
    return dynamic_cast<const StreamGarbageCollector<Tensor> *>(gc_) != nullptr;
  }

  const Scope *scope_;
  platform::CUDADeviceContext *dev_ctx_;
  std::unordered_map<std::string, int> var_names_;
  GarbageCollector<Tensor> *gc_;       // not owned
  AtomicReferenceCountMap *ref_cnts_;  // not owned
  cudaEvent_t event_;
};

}  // namespace details
}  // namespace framework
}  // namespace paddle
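Note that the deleted handle batched its decrements: AddVar counted duplicate names per op, and fetch_sub(pair.second) <= pair.second asks whether the count before releasing n references was at most n, the n-at-a-time analogue of fetch_sub(1) == 1 in the new handle. A small sketch of the semantics (illustrative names):

#include <atomic>
#include <cstdio>

// True iff releasing n references frees the last one: fetch_sub returns the
// value held *before* subtraction, so prior <= n means no other user remains.
bool ReleaseN(std::atomic<int> *cnt, int n) {
  return cnt->fetch_sub(n) <= n;
}

int main() {
  std::atomic<int> cnt{5};
  std::printf("%d\n", ReleaseN(&cnt, 2));  // 0: three references remain
  std::printf("%d\n", ReleaseN(&cnt, 3));  // 1: last user, safe to free
  return 0;
}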
File diff suppressed because it is too large
@@ -0,0 +1,49 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <atomic>
#include <memory>         // std::unique_ptr in GarbageCollectorList
#include <string>
#include <unordered_map>
#include <unordered_set>  // std::unordered_set in LastLiveOpsOfVars
#include <vector>

#include "paddle/fluid/framework/garbage_collector.h"
#include "paddle/fluid/framework/tensor.h"

namespace paddle {
namespace framework {
namespace details {

class ComputationOpHandle;

using ReferenceCountMap = std::unordered_map<std::string, size_t>;

using AtomicReferenceCountMap =
    std::unordered_map<std::string, std::atomic<size_t>>;

using GarbageCollectorList =
    std::vector<std::unique_ptr<GarbageCollector<Tensor>>>;

const char kGlobalReferenceCount[] = "reference_count";
const char kCurReferenceCount[] = "current_reference_count";
const char kGarbageCollector[] = "garbage_collector";

using LastLiveOpsOfVars =
    std::unordered_map<std::string, std::unordered_set<ComputationOpHandle *>>;
const char kLastLiveOpsOfVars[] = "last_live_ops_of_var";

}  // namespace details
}  // namespace framework
}  // namespace paddle
@@ -0,0 +1,70 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/platform/stream_callback_manager.h"
#include "paddle/fluid/platform/enforce.h"

namespace paddle {
namespace platform {

struct StreamCallbackContext {
  inline StreamCallbackContext(const StreamCallbackManager *manager,
                               std::function<void()> callback)
      : manager_(manager), callback_(std::move(callback)) {}

  const StreamCallbackManager *manager_;  // not owned
  std::function<void()> callback_;
};

StreamCallbackManager::StreamCallbackManager(const cudaStream_t stream)
    : stream_(stream), thread_pool_(new ::ThreadPool(1)) {}

void StreamCallbackManager::AddCallback(std::function<void()> callback) const {
  auto *stream_callback_context =
      new StreamCallbackContext(this, std::move(callback));
#if CUDA_VERSION >= 10000
  PADDLE_ENFORCE(cudaLaunchHostFunc(stream_,
                                    StreamCallbackManager::StreamCallbackFunc,
                                    stream_callback_context));
#else
  PADDLE_ENFORCE(
      cudaStreamAddCallback(stream_, StreamCallbackManager::StreamCallbackFunc,
                            stream_callback_context, 0));
#endif
}

void StreamCallbackManager::Wait() const {
  thread_pool_.reset(new ::ThreadPool(1));
}

#if CUDA_VERSION >= 10000
void CUDART_CB StreamCallbackManager::StreamCallbackFunc(void *user_data)
#else
void CUDART_CB StreamCallbackManager::StreamCallbackFunc(cudaStream_t stream,
                                                         cudaError_t status,
                                                         void *user_data)
#endif
{
  auto *callback_context_ptr =
      reinterpret_cast<StreamCallbackContext *>(user_data);
  callback_context_ptr->manager_->thread_pool_->enqueue(
      [callback_context_ptr]() {
        std::unique_ptr<StreamCallbackContext> callback_context(
            callback_context_ptr);
        callback_context->callback_();
      });
}

}  // namespace platform
}  // namespace paddle
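A hedged usage sketch of the manager above, assuming only the interface visible in this diff (a constructor taking a cudaStream_t, AddCallback, Wait): callbacks fire after all work already queued on the stream and run on the single worker thread, and Wait drains queued callbacks by recreating the pool, since the ThreadPool destructor joins outstanding tasks. The stream synchronize ensures the CUDA-side callback has handed the task to the pool before Wait runs:

#include <cstdio>
#include <cuda_runtime.h>
#include "paddle/fluid/platform/stream_callback_manager.h"

int main() {
  cudaStream_t stream;
  if (cudaStreamCreate(&stream) != cudaSuccess) return 1;

  paddle::platform::StreamCallbackManager manager(stream);

  // Enqueued behind everything already submitted to the stream; executed
  // on the manager's worker thread, not on the CUDA driver thread.
  manager.AddCallback([] { std::printf("stream work done\n"); });

  cudaStreamSynchronize(stream);  // the host callback has now been issued
  manager.Wait();                 // blocks until queued callbacks have run

  cudaStreamDestroy(stream);
  return 0;
}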