Wait by stream

helinwang-patch-1
Yu Yang 7 years ago
parent e8a7e5d1e6
commit b2c7a9b828

@ -77,7 +77,7 @@ struct OpHandle {
virtual ~OpHandle() {} virtual ~OpHandle() {}
virtual void Run() { PADDLE_THROW("Not implemented"); } virtual void Run() { PADDLE_THROW("Not implemented"); }
virtual void Wait() {} virtual void Wait(platform::DeviceContext *waited_dev) {}
}; };
struct ComputationOpHandle : public OpHandle { struct ComputationOpHandle : public OpHandle {
@ -97,13 +97,17 @@ struct ComputationOpHandle : public OpHandle {
auto *cur_ctx = dev_ctx_[place_]; auto *cur_ctx = dev_ctx_[place_];
for (auto *in : inputs_) { for (auto *in : inputs_) {
if (in->generated_op_ && in->generated_op_->dev_ctx_[place_] != cur_ctx) { if (in->generated_op_ && in->generated_op_->dev_ctx_[place_] != cur_ctx) {
in->generated_op_->Wait(); in->generated_op_->Wait(cur_ctx);
} }
} }
op_->Run(*scope_, place_); op_->Run(*scope_, place_);
LOG(INFO) << "Done " << this; LOG(INFO) << "Done " << this;
} }
void Wait(platform::DeviceContext *waited_dev) override {
this->dev_ctx_.at(place_)->Wait();
}
}; };
struct ScaleLossGradOpHandle : public OpHandle { struct ScaleLossGradOpHandle : public OpHandle {
@ -136,6 +140,10 @@ struct ScaleLossGradOpHandle : public OpHandle {
->stream()); ->stream());
} }
} }
void Wait(platform::DeviceContext *waited_dev) override {
this->dev_ctx_.at(place_)->Wait();
}
}; };
class ParallelExecutorPrivate { class ParallelExecutorPrivate {
@ -276,6 +284,10 @@ struct NCCLAllReduceOpHandle : public OpHandle {
platform::dynload::ncclGroupEnd(); platform::dynload::ncclGroupEnd();
} }
} }
void Wait(platform::DeviceContext *waited_dev) override {
this->dev_ctx_.at(waited_dev->GetPlace())->Wait();
}
}; };
ParallelExecutor::ParallelExecutor( ParallelExecutor::ParallelExecutor(

Loading…
Cancel
Save