[Speed]Refine ParallelExecutor (#16190)

* refine parallelExecutor
test=develop

* Polish op_handle
test=develop

* Remove unnecessary op_handle
test=develop

* Fix Travis CI
test=develop

* Fix fetch bug
test=develop

* Remove WaitInputVarGenerated

* Fix OpHandleBase::Run
test=develop

* debug
test=develop

* use origin fetch_op_handle
test=develop

* Revert op_handle_base.cc
test=develop

* Polish code
test=develop

* Fix OpHandleBase::Run
test=develop

* code refine

* test CI and CE
test=develop

* fix OpHandle::Run
test=develop

* refine AllReduceOpHandle
test=develop

* Polish code
test=develop
revert-16190-refine_parallel_executor
chengduo 6 years ago committed by GitHub
parent 33965527fd
commit a6a3b2fbbc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -51,9 +51,7 @@ else()
cc_library(fused_broadcast_op_handle SRCS fused_broadcast_op_handle.cc DEPS broadcast_op_handle)
endif()
cc_library(data_balance_op_handle SRCS data_balance_op_handle.cc DEPS op_handle_base scope lod_tensor)
cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor)
cc_library(fuse_vars_op_handle SRCS fuse_vars_op_handle.cc DEPS op_handle_base scope)
if(WITH_GPU)
cc_library(memory_optimize_helper SRCS memory_optimize_helper.cc DEPS graph graph_helper gpu_info)
@ -74,7 +72,7 @@ cc_library(sequential_execution_pass SRCS sequential_execution_pass.cc DEPS grap
cc_library(all_reduce_deps_pass SRCS all_reduce_deps_pass.cc DEPS graph graph_helper pass)
cc_library(multi_devices_graph_pass SRCS multi_devices_graph_pass.cc DEPS multi_devices_helper computation_op_handle
scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle fused_broadcast_op_handle)
scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle fused_broadcast_op_handle)
cc_library(fuse_all_reduce_op_pass SRCS fuse_all_reduce_op_pass.cc DEPS graph graph_helper fused_all_reduce_op_handle)

@ -11,9 +11,8 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include "paddle/fluid/framework/details/all_reduce_op_handle.h"
#include <algorithm>
#include "paddle/fluid/framework/details/container_cast.h"
#include "paddle/fluid/framework/details/reduce_and_gather.h"
#include "paddle/fluid/framework/details/variable_visitor.h"
@ -56,6 +55,7 @@ void AllReduceOpHandle::RunImpl() {
platform::RecordEvent record_event(Name());
WaitInputVarGenerated();
auto in_var_handles = DynamicCast<VarHandle>(this->Inputs());
auto out_var_handles = DynamicCast<VarHandle>(this->Outputs());
PADDLE_ENFORCE_EQ(

@ -57,7 +57,7 @@ struct BroadcastOpHandle : public OpHandleBase {
std::string Name() const override;
bool IsMultiDeviceTransfer() override { return false; };
bool IsMultiDeviceTransfer() override { return true; };
protected:
void RunImpl() override;

@ -147,6 +147,10 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
// Verify that the graph is correct for multi-device executor.
AppendPass("multi_devices_check_pass");
if (VLOG_IS_ON(2)) {
AppendPass("all_reduce_deps_pass");
}
if (SeqOnlyAllReduceOps(strategy)) {
VLOG(10) << "Add all_reduce_deps_pass";
AppendPass("all_reduce_deps_pass");

@ -1,154 +0,0 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/data_balance_op_handle.h"
#include <algorithm>
#include "paddle/fluid/framework/details/container_cast.h"
namespace paddle {
namespace framework {
namespace details {
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
DataBalanceOpHandle::DataBalanceOpHandle(
ir::Node *node, const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places,
const platform::NCCLContextMap *ctxs)
: OpHandleBase(node), local_scopes_(local_scopes), places_(places) {
if (ctxs) {
for (auto &p : places_) {
this->SetDeviceContext(p, ctxs->DevCtx(p));
}
}
}
#else
DataBalanceOpHandle::DataBalanceOpHandle(
ir::Node *node, const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places)
: OpHandleBase(node), local_scopes_(local_scopes), places_(places) {}
#endif
std::string DataBalanceOpHandle::Name() const { return "data balance"; }
std::vector<std::array<int, 3>> DataBalanceOpHandle::GetBalancePlan(
const std::vector<int> &device_sizes) {
int device_num = device_sizes.size();
int total_size = 0;
int empty_num = 0;
std::vector<std::array<int, 2>> size_device_vec;
size_device_vec.reserve(device_num);
for (int i = 0; i < device_num; ++i) {
if (device_sizes[i] == 0) {
++empty_num;
}
total_size += device_sizes[i];
size_device_vec.push_back({{device_sizes[i], i}});
}
std::vector<std::array<int, 3>> res;
if (empty_num == 0) {
// No need to do data balance.
return res;
}
if (total_size < device_num) {
// No enough data.
PADDLE_THROW_EOF();
}
std::sort(size_device_vec.begin(), size_device_vec.end(),
[](const std::array<int, 2> &a, const std::array<int, 2> &b) {
return a[0] > b[0];
});
int expected_device_size = total_size / device_num;
int src_idx = 0;
for (int dst_idx = device_num - empty_num; dst_idx < device_num; ++dst_idx) {
if (size_device_vec[src_idx][0] <= expected_device_size) {
++src_idx;
PADDLE_ENFORCE_LT(
src_idx, device_num - empty_num,
"In current srategy an empty tensor should not be copy source.");
}
size_device_vec[src_idx][0] -= expected_device_size;
size_device_vec[dst_idx][0] += expected_device_size;
res.push_back({{size_device_vec[src_idx][1], size_device_vec[dst_idx][1],
expected_device_size}});
}
return res;
}
void DataBalanceOpHandle::RunImpl() {
PADDLE_ENFORCE_GT(places_.size(), 1UL,
"Data balance can only be enabled when the number of "
"places to run larger than 1.");
auto in_var_handles = DynamicCast<VarHandle>(this->Inputs());
auto out_var_handles = DynamicCast<VarHandle>(this->Outputs());
PADDLE_ENFORCE(in_var_handles.size() % places_.size() == 0);
PADDLE_ENFORCE_EQ(
in_var_handles.size(), out_var_handles.size(),
"The NoDummyInputSize and NoDummyOutputSize should be equal.");
int data_num = in_var_handles.size() / places_.size();
WaitInputVarGenerated();
std::vector<std::vector<LoDTensor *>> lod_tensors(data_num);
std::vector<int> device_sizes;
for (int i = 0; i < static_cast<int>(in_var_handles.size()); ++i) {
PADDLE_ENFORCE_EQ(in_var_handles[i]->name(), out_var_handles[i]->name(),
"The name of input and output should be equal.");
int place_idx = i / data_num;
int data_idx = i % data_num;
auto *local_scope =
local_scopes_[place_idx]->FindVar(kLocalExecScopeName)->Get<Scope *>();
auto *tensor_var = local_scope->FindVar(in_var_handles[i]->name());
PADDLE_ENFORCE(tensor_var->IsType<LoDTensor>());
auto *tensor = tensor_var->GetMutable<LoDTensor>();
lod_tensors[data_idx].push_back(tensor);
int ins_size =
tensor->lod().empty() ? tensor->dims()[0] : tensor->NumElements();
if (data_idx == 0) {
device_sizes.emplace_back(ins_size);
} else {
PADDLE_ENFORCE_EQ(
ins_size, device_sizes.at(place_idx),
"All data on the same device shall have the same batch size.");
}
}
const auto &balance_plan = GetBalancePlan(device_sizes);
for (const auto &trans : balance_plan) {
for (int data_idx = 0; data_idx < data_num; ++data_idx) {
LoDTensor *src_tensor = lod_tensors[data_idx][trans[0]];
LoDTensor *dst_tensor = lod_tensors[data_idx][trans[1]];
int trans_ins_size = trans[2];
LoD src_lod = src_tensor->lod();
int src_ins_size =
src_lod.empty() ? src_tensor->dims()[0] : src_tensor->NumElements();
int cut_point = src_ins_size - trans_ins_size;
if (!src_lod.empty()) {
for (auto &level : src_lod) {
cut_point = level[cut_point];
}
}
TensorCopySync(src_tensor->Slice(cut_point, src_tensor->dims()[0]),
dst_tensor->place(), dst_tensor);
src_tensor->ShareDataWith(src_tensor->Slice(0, cut_point));
if (!src_lod.empty()) {
dst_tensor->set_lod(SliceInLevel(
src_lod, 0, src_ins_size - trans_ins_size, src_ins_size));
src_tensor->set_lod(
SliceInLevel(src_lod, 0, 0, src_ins_size - trans_ins_size));
}
}
}
}
} // namespace details
} // namespace framework
} // namespace paddle

@ -1,59 +0,0 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/framework/details/op_handle_base.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
#include "paddle/fluid/platform/nccl_helper.h"
#endif
namespace paddle {
namespace framework {
namespace details {
struct DataBalanceOpHandle : public OpHandleBase {
public:
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
DataBalanceOpHandle(ir::Node *node, const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places,
const platform::NCCLContextMap *ctxs);
#else
DataBalanceOpHandle(ir::Node *node, const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places);
#endif
std::string Name() const override;
bool IsMultiDeviceTransfer() override { return false; };
protected:
void RunImpl() override;
private:
// std::vector<(src_dev_id, dst_dev_id, trans_size)>
std::vector<std::array<int, 3>> GetBalancePlan(
const std::vector<int> &batch_size_per_device);
const std::vector<Scope *> local_scopes_;
const std::vector<platform::Place> places_;
};
} // namespace details
} // namespace framework
} // namespace paddle

@ -82,6 +82,8 @@ void FetchOpHandle::WaitInputVarGenerated(const platform::Place &place) {
}
}
bool FetchOpHandle::IsMultiDeviceTransfer() { return true; }
std::string FetchOpHandle::Name() const { return "Fetch"; }
} // namespace details

@ -39,6 +39,8 @@ struct FetchOpHandle : public OpHandleBase {
std::string Name() const override;
bool IsMultiDeviceTransfer() override;
protected:
void RunImpl() override;

@ -1,51 +0,0 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/fuse_vars_op_handle.h"
namespace paddle {
namespace framework {
namespace details {
void FuseVarsOpHandle::RunImpl() {
WaitInputVarGenerated(place_);
auto in_var_handles = DynamicCast<VarHandle>(this->Inputs());
auto out_var_handles = DynamicCast<VarHandle>(this->Outputs());
PADDLE_ENFORCE_EQ(in_var_handles.size(), 0UL);
PADDLE_ENFORCE_EQ(out_var_handles.size() - 1, inputs_numel_.size(), "");
auto scope = local_scope_->FindVar(kLocalExecScopeName)->Get<Scope *>();
auto out_var_handle = out_var_handles[0];
auto out_var = scope->Var(out_var_handle->name());
auto out_tensor = out_var->GetMutable<LoDTensor>();
out_tensor->Resize({total_numel_}).mutable_data(this->place_, type_);
int64_t s = 0;
for (size_t i = 1; i < out_var_handles.size(); ++i) {
auto out_name = out_var_handles[i]->name();
auto out_t = scope->Var(out_name)->GetMutable<LoDTensor>();
auto numel = this->inputs_numel_.at(out_name);
out_t->ShareDataWith(out_tensor->Slice(s, s + numel));
s += numel;
}
this->RunAndRecordEvent([] {});
}
std::string FuseVarsOpHandle::Name() const { return "fuse vars"; }
} // namespace details
} // namespace framework
} // namespace paddle

@ -1,65 +0,0 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <map>
#include <string>
#include <vector>
#include "paddle/fluid/framework/details/container_cast.h"
#include "paddle/fluid/framework/details/op_handle_base.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/platform/device_context.h"
namespace paddle {
namespace framework {
namespace details {
struct FuseVarsOpHandle : public OpHandleBase {
public:
FuseVarsOpHandle(ir::Node *node, Scope *local_scope,
const platform::Place &place,
const std::unordered_map<std::string, int64_t> &inputs_numel,
const proto::VarType::Type var_type)
: OpHandleBase(node),
local_scope_(local_scope),
place_(place),
inputs_numel_(inputs_numel),
type_(var_type) {
total_numel_ = 0;
for (auto in_numel : inputs_numel) {
PADDLE_ENFORCE_GT(in_numel.second, 0);
total_numel_ += in_numel.second;
}
}
std::string Name() const override;
bool IsMultiDeviceTransfer() override { return false; };
protected:
void RunImpl() override;
private:
Scope *local_scope_;
const platform::Place place_;
const std::unordered_map<std::string, int64_t> inputs_numel_;
const proto::VarType::Type type_;
int64_t total_numel_;
};
} // namespace details
} // namespace framework
} // namespace paddle

@ -14,13 +14,15 @@
#include "paddle/fluid/framework/details/multi_devices_graph_pass.h"
#include <algorithm>
#include <fstream>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/details/all_reduce_op_handle.h"
#include "paddle/fluid/framework/details/broadcast_op_handle.h"
#include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/data_balance_op_handle.h"
#include "paddle/fluid/framework/details/fused_broadcast_op_handle.h"
#include "paddle/fluid/framework/details/reduce_op_handle.h"
#include "paddle/fluid/framework/details/rpc_op_handle.h"

@ -13,6 +13,7 @@
// limitations under the License.
#include "paddle/fluid/framework/details/op_handle_base.h"
#include <map>
#include <unordered_set>
namespace paddle {
namespace framework {
@ -41,15 +42,42 @@ OpHandleBase::~OpHandleBase() {
void OpHandleBase::Run(bool use_cuda) {
#ifdef PADDLE_WITH_CUDA
if (events_.empty() && use_cuda) {
if (events_.empty() && use_cuda && dev_ctxes_.size() > 0) {
for (auto &p : dev_ctxes_) {
int dev_id = boost::get<platform::CUDAPlace>(p.first).device;
PADDLE_ENFORCE(cudaSetDevice(dev_id));
PADDLE_ENFORCE(
cudaEventCreateWithFlags(&events_[dev_id], cudaEventDisableTiming));
}
if (IsMultiDeviceTransfer() && dev_ctxes_.size() > 0) {
for (auto &out_var : outputs_) {
auto *out_var_handle = dynamic_cast<VarHandle *>(out_var);
if (out_var_handle) {
int dev_id =
boost::get<platform::CUDAPlace>(out_var_handle->place()).device;
out_var_handle->SetGenerateEvent(events_[dev_id]);
}
}
} else {
PADDLE_ENFORCE_EQ(dev_ctxes_.size(), 1UL,
"%s should have only one dev_ctx.", Name());
auto &place = dev_ctxes_.begin()->first;
int dev_id = boost::get<platform::CUDAPlace>(place).device;
for (auto &out_var : outputs_) {
auto *out_var_handle = dynamic_cast<VarHandle *>(out_var);
if (out_var_handle) {
PADDLE_ENFORCE(
platform::is_same_place(place, out_var_handle->place()),
"The place of input(%s) is not consistent with the "
"place of current op(%s).",
out_var_handle->Name(), Name());
out_var_handle->SetGenerateEvent(events_[dev_id]);
}
}
}
}
#else
PADDLE_ENFORCE(!use_cuda);
#endif
@ -93,17 +121,48 @@ void OpHandleBase::AddOutput(VarHandleBase *out) {
void OpHandleBase::WaitInputVarGenerated() {
for (auto in_var : inputs_) {
if (NeedWait(in_var)) {
for (auto &pair : dev_ctxes_) {
in_var->GeneratedOp()->RecordWaitEventOnCtx(pair.second);
// Dummy Variable is used to represent dependencies between operators, so
// there doesn't add event for it.
auto *in_var_handle = dynamic_cast<VarHandle *>(in_var);
if (in_var_handle) {
auto &place = in_var_handle->place();
if (platform::is_gpu_place(place)) {
#ifdef PADDLE_WITH_CUDA
auto stream =
static_cast<platform::CUDADeviceContext *>(dev_ctxes_.at(place))
->stream();
PADDLE_ENFORCE(
cudaStreamWaitEvent(stream, in_var_handle->GetEvent(), 0));
#else
PADDLE_THROW("Doesn't compile the GPU.");
#endif
}
// There are nothing to do when the place is CPUPlace.
}
}
}
}
void OpHandleBase::WaitInputVarGenerated(const platform::Place &place) {
for (auto *in : inputs_) {
if (NeedWait(in)) {
in->GeneratedOp()->RecordWaitEventOnCtx(dev_ctxes_.at(place));
for (auto in_var : inputs_) {
if (NeedWait(in_var)) {
// Dummy Variable is used to represent dependencies between operators, so
// there doesn't add event for it.
auto *in_var_handle = dynamic_cast<VarHandle *>(in_var);
if (in_var_handle) {
if (platform::is_gpu_place(in_var_handle->place())) {
#ifdef PADDLE_WITH_CUDA
auto stream = static_cast<platform::CUDADeviceContext *>(
dev_ctxes_.at(in_var_handle->place()))
->stream();
PADDLE_ENFORCE(
cudaStreamWaitEvent(stream, in_var_handle->GetEvent(), 0));
#else
PADDLE_THROW("Doesn't compile the GPU.");
#endif
}
// There are nothing to do when the place is CPUPlace.
}
}
}
}

@ -15,18 +15,20 @@
#pragma once
#include <deque>
#include <functional>
#include <list>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include <functional>
#include "ThreadPool.h" // ThreadPool in thrird party
#include "paddle/fluid/framework/blocking_queue.h"
#include "paddle/fluid/framework/details/exception_holder.h"
#include "paddle/fluid/framework/details/execution_strategy.h"
#include "paddle/fluid/framework/details/fetch_op_handle.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/details/ssa_graph_executor.h"
#include "paddle/fluid/framework/ir/graph.h"
@ -36,6 +38,12 @@ class Scope;
namespace details {
struct OpDependentData {
std::unordered_map<OpHandleBase *, size_t> pending_ops_;
std::unordered_set<VarHandleBase *> pending_vars_;
std::unordered_set<OpHandleBase *> ready_ops_;
};
class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
public:
ThreadedSSAGraphExecutor(const ExecutionStrategy &strategy,
@ -57,29 +65,35 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
private:
ir::Graph *graph_;
std::unique_ptr<::ThreadPool> pool_;
::ThreadPool prepare_pool_;
std::vector<Scope *> local_scopes_;
std::vector<platform::Place> places_;
platform::DeviceContextPool fetch_ctxs_;
ExceptionHolder exception_holder_;
std::atomic<int> running_ops_;
void InsertPendingOp(std::unordered_map<OpHandleBase *, size_t> *pending_ops,
OpHandleBase *op_instance) const;
void InsertPendingVar(std::unordered_set<VarHandleBase *> *pending_vars,
BlockingQueue<VarHandleBase *> *ready_vars,
std::unordered_set<VarHandleBase *> *ready_vars,
VarHandleBase *var) const;
void InsertFetchOps(const std::vector<std::string> &fetch_tensors,
std::vector<FetchOpHandle *> *fetch_ops,
std::unordered_set<VarHandleBase *> *fetch_dependencies,
std::unordered_set<OpHandleBase *> *ready_ops,
std::unordered_map<OpHandleBase *, size_t> *pending_ops,
std::unordered_set<VarHandleBase *> *pending_vars,
BlockingQueue<VarHandleBase *> *ready_vars,
FeedFetchList *fetch_data);
void PrepareOpDeps();
void CopyOpDeps();
private:
std::future<std::unique_ptr<OpDependentData>> op_deps_futures_;
ExecutionStrategy strategy_;
std::unique_ptr<OpDependentData> op_deps_;
// use std::list because clear(), push_back, and for_each are O(1)
std::list<std::future<void>> run_op_futures_;
};

@ -43,6 +43,7 @@ struct VarHandleBase {
virtual ~VarHandleBase();
virtual std::string DebugString() const = 0;
virtual const std::string& Name() const = 0;
void AddInput(OpHandleBase* in, ir::Node* node) {
node_->inputs.clear();
@ -95,8 +96,6 @@ struct VarHandleBase {
//
// NOTE: runtime variables have place.
struct VarHandle : public VarHandleBase {
explicit VarHandle(ir::Node* node) : VarHandleBase(node) {}
virtual ~VarHandle();
std::string DebugString() const override;
@ -109,6 +108,20 @@ struct VarHandle : public VarHandleBase {
name_(std::move(name)),
place_(std::move(place)) {}
#ifdef PADDLE_WITH_CUDA
bool HasEvent() { return has_event_; }
const cudaEvent_t& GetEvent() {
PADDLE_ENFORCE(HasEvent(), "The event is not set.");
return event_;
}
void SetGenerateEvent(const cudaEvent_t& event) {
has_event_ = true;
event_ = event;
}
#endif
// version field currently is not used, however, just store the version to
// debug easily.
private:
@ -116,6 +129,11 @@ struct VarHandle : public VarHandleBase {
size_t scope_idx_;
std::string name_;
platform::Place place_;
#ifdef PADDLE_WITH_CUDA
// Only when this event is triggered, var is generated.
cudaEvent_t event_;
bool has_event_{false};
#endif
public:
bool IsTheSameVar(const VarHandle& o) const {
@ -125,6 +143,7 @@ struct VarHandle : public VarHandleBase {
size_t version() const { return version_; }
size_t scope_idx() const { return scope_idx_; }
const std::string& Name() const override { return name_; }
const std::string& name() const { return name_; }
const platform::Place& place() const { return place_; }
};
@ -136,6 +155,10 @@ struct DummyVarHandle : public VarHandleBase {
virtual ~DummyVarHandle();
std::string DebugString() const override;
public:
const std::string& Name() const override { return name_; }
std::string name_{"DummyVar"};
};
} // namespace details

Loading…
Cancel
Save