Rename code

helinwang-patch-1
Yu Yang 7 years ago
parent 9f4a98f397
commit 084cdd1f4f

@@ -24,10 +24,10 @@ ComputationOpHandle::ComputationOpHandle(const OpDesc &op_desc, Scope *scope,
       place_(place) {}
 void ComputationOpHandle::RunImpl() {
-  auto *cur_ctx = dev_ctx_[place_];
+  auto *cur_ctx = dev_ctxes_[place_];
   for (auto *in : inputs_) {
     bool need_wait =
-        in->generated_op_ && in->generated_op_->dev_ctx_[place_] != cur_ctx;
+        in->generated_op_ && in->generated_op_->dev_ctxes_[place_] != cur_ctx;
     if (need_wait) {
       in->generated_op_->Wait(cur_ctx);
     }

@@ -60,8 +60,8 @@ void FetchOpHandle::RunImpl() {
     auto &t = scope->FindVar(var_name)->Get<framework::LoDTensor>();
     if (platform::is_gpu_place(var->place_)) {
 #ifdef PADDLE_WITH_CUDA
-      TensorCopy(t, cpu, *dev_ctx_[t.place()], &tensors_[i]);
-      dev_ctx_[t.place()]->Wait();
+      TensorCopy(t, cpu, *dev_ctxes_[t.place()], &tensors_[i]);
+      dev_ctxes_[t.place()]->Wait();
 #endif
     } else {
       tensors_[i].ShareDataWith(t);

@@ -74,7 +74,7 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
       result.ops_.emplace_back(new ComputationOpHandle(*op, s, p));
       auto *op_handle = result.ops_.back().get();
-      op_handle->dev_ctx_[p] = const_cast<platform::DeviceContext *>(
+      op_handle->dev_ctxes_[p] = const_cast<platform::DeviceContext *>(
           platform::DeviceContextPool::Instance().Get(p));
       auto var_names = op->InputArgumentNames();

@@ -23,7 +23,7 @@ NCCLAllReduceOpHandle::NCCLAllReduceOpHandle(
     const platform::NCCLContextMap &ctxs)
     : local_scopes_(local_scopes), places_(places), nccl_ctxs_(ctxs) {
   for (auto &p : places_) {
-    this->dev_ctx_[p] = nccl_ctxs_.DevCtx(p);
+    this->dev_ctxes_[p] = nccl_ctxs_.DevCtx(p);
   }
 }
@@ -34,7 +34,7 @@ void NCCLAllReduceOpHandle::RunImpl() {
   // Wait input done
   for (auto *in : inputs_) {
     auto &p = static_cast<VarHandle *>(in)->place_;
-    in->generated_op_->Wait(dev_ctx_[p]);
+    in->generated_op_->Wait(dev_ctxes_[p]);
   }
   auto &var_name = static_cast<VarHandle *>(this->inputs_[0])->name_;

@@ -42,7 +42,7 @@ OpHandleBase::~OpHandleBase() {
 void OpHandleBase::Run(bool use_event) {
 #ifdef PADDLE_WITH_CUDA
   if (events_.empty() && use_event) {
-    for (auto &p : dev_ctx_) {
+    for (auto &p : dev_ctxes_) {
       int dev_id = boost::get<platform::CUDAPlace>(p.first).device;
       PADDLE_ENFORCE(cudaSetDevice(dev_id));
       PADDLE_ENFORCE(
@@ -57,7 +57,7 @@ void OpHandleBase::Run(bool use_event) {
 #ifdef PADDLE_WITH_CUDA
   if (use_event) {
-    for (auto &p : dev_ctx_) {
+    for (auto &p : dev_ctxes_) {
       int dev_id = boost::get<platform::CUDAPlace>(p.first).device;
       auto stream =
           static_cast<platform::CUDADeviceContext *>(p.second)->stream();
@@ -70,7 +70,7 @@ void OpHandleBase::Run(bool use_event) {
 void OpHandleBase::Wait(platform::DeviceContext *waited_dev) {
 #ifdef PADDLE_WITH_CUDA
   if (platform::is_cpu_place(waited_dev->GetPlace()) || events_.empty()) {
-    for (auto &dev_ctx : dev_ctx_) {
+    for (auto &dev_ctx : dev_ctxes_) {
       dev_ctx.second->Wait();
     }
   } else {
@@ -81,7 +81,7 @@ void OpHandleBase::Wait(platform::DeviceContext *waited_dev) {
     }
   }
 #else
-  for (auto &dev_ctx : dev_ctx_) {
+  for (auto &dev_ctx : dev_ctxes_) {
     dev_ctx.second->Wait();
   }
 #endif

@@ -31,7 +31,7 @@ class OpHandleBase {
   std::vector<VarHandleBase *> outputs_;
   std::unordered_map<platform::Place, platform::DeviceContext *,
                      platform::PlaceHash>
-      dev_ctx_;
+      dev_ctxes_;
 #ifdef PADDLE_WITH_CUDA
   std::unordered_map<int, cudaEvent_t> events_;

@@ -21,7 +21,7 @@ ScaleLossGradOpHandle::ScaleLossGradOpHandle(size_t num_dev, Scope *scope,
                                              platform::Place place,
                                              platform::DeviceContext *dev_ctx)
     : coeff_(static_cast<float>(1.0 / num_dev)), scope_(scope), place_(place) {
-  dev_ctx_[place_] = dev_ctx;
+  dev_ctxes_[place_] = dev_ctx;
 }
 ScaleLossGradOpHandle::~ScaleLossGradOpHandle() {}
@@ -38,7 +38,7 @@ void ScaleLossGradOpHandle::RunImpl() {
   } else {
 #ifdef PADDLE_WITH_CUDA
     auto stream =
-        static_cast<platform::CUDADeviceContext *>(this->dev_ctx_[place_])
+        static_cast<platform::CUDADeviceContext *>(this->dev_ctxes_[place_])
             ->stream();
     memory::Copy(boost::get<platform::CUDAPlace>(place_), tmp,
                  platform::CPUPlace(), &coeff_, sizeof(float), stream);

@@ -96,7 +96,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
     // FIXME: Use new device context
     for (auto &p : places_) {
-      op->dev_ctx_[p] = fetch_ctxs_.Get(p);
+      op->dev_ctxes_[p] = fetch_ctxs_.Get(p);
     }
     for (auto *var : vars) {
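For context, every hunk above edits the same member: a map that holds one DeviceContext * per platform::Place, which each op handle either indexes by a single place or iterates in full. Below is a minimal, self-contained sketch of that per-place map pattern; the Place, DeviceContext, and OpHandle types are hypothetical stand-ins for illustration, not Paddle's real classes.

// Sketch of the per-place device-context map that the plural name dev_ctxes_
// describes. All types here are illustrative stand-ins, not PaddlePaddle's.
#include <iostream>
#include <string>
#include <unordered_map>

using Place = std::string;  // stand-in for platform::Place

struct DeviceContext {      // stand-in for platform::DeviceContext
  Place place;
  void Wait() const { std::cout << "wait on " << place << "\n"; }
};

struct OpHandle {           // stand-in for OpHandleBase
  // One context per place; a map of contexts, hence a plural member name.
  std::unordered_map<Place, DeviceContext *> dev_ctxes_;

  // Roughly the shape of the fallback path in OpHandleBase::Wait above:
  // block until every owned context has drained its pending work.
  void WaitAll() {
    for (auto &dev_ctx : dev_ctxes_) {
      dev_ctx.second->Wait();
    }
  }
};

int main() {
  DeviceContext gpu0{"CUDAPlace(0)"};
  DeviceContext cpu{"CPUPlace"};

  OpHandle op;
  op.dev_ctxes_[gpu0.place] = &gpu0;  // e.g. filled in by a graph builder
  op.dev_ctxes_[cpu.place] = &cpu;

  op.WaitAll();
  return 0;
}

The diff itself changes no behavior; it only renames dev_ctx_ to dev_ctxes_ so the identifier reads as the container it is.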
