|
|
|
@ -77,7 +77,7 @@ struct OpHandle {
|
|
|
|
|
// Virtual destructor so that handles can be destroyed through an OpHandle
// pointer; there is no cleanup logic of its own.
virtual ~OpHandle() = default;
|
|
|
|
|
|
|
|
|
|
// Default Run: unconditionally reports "Not implemented" via PADDLE_THROW.
// Presumably every concrete handle overrides this -- confirm against the
// derived handle types.
virtual void Run() {
  PADDLE_THROW("Not implemented");
}
|
|
|
|
|
// Argument-less wait hook. The default implementation does nothing.
virtual void Wait() {
}
|
|
|
|
|
// Wait hook that receives the device context of the caller (`waited_dev`).
// The default implementation does nothing and does not look at `waited_dev`.
virtual void Wait(platform::DeviceContext *waited_dev) {
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
struct ComputationOpHandle : public OpHandle {
|
|
|
|
@ -97,13 +97,17 @@ struct ComputationOpHandle : public OpHandle {
|
|
|
|
|
auto *cur_ctx = dev_ctx_[place_];
|
|
|
|
|
for (auto *in : inputs_) {
|
|
|
|
|
if (in->generated_op_ && in->generated_op_->dev_ctx_[place_] != cur_ctx) {
|
|
|
|
|
in->generated_op_->Wait();
|
|
|
|
|
in->generated_op_->Wait(cur_ctx);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
op_->Run(*scope_, place_);
|
|
|
|
|
LOG(INFO) << "Done " << this;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Calls Wait() on the device context stored in dev_ctx_ under this handle's
// own place_. `waited_dev` is accepted to match the base signature but is not
// consulted here.
void Wait(platform::DeviceContext *waited_dev) override {
  auto *own_ctx = dev_ctx_.at(place_);
  own_ctx->Wait();
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
struct ScaleLossGradOpHandle : public OpHandle {
|
|
|
|
@ -136,6 +140,10 @@ struct ScaleLossGradOpHandle : public OpHandle {
|
|
|
|
|
->stream());
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Calls Wait() on the device context registered in dev_ctx_ for place_.
// The `waited_dev` argument is part of the overridden signature only; this
// implementation never reads it.
void Wait(platform::DeviceContext *waited_dev) override {
  auto &local_ctx = *dev_ctx_.at(place_);
  local_ctx.Wait();
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class ParallelExecutorPrivate {
|
|
|
|
@ -276,6 +284,10 @@ struct NCCLAllReduceOpHandle : public OpHandle {
|
|
|
|
|
platform::dynload::ncclGroupEnd();
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Calls Wait() on the device context that dev_ctx_ holds for the place
// reported by `waited_dev`, i.e. the lookup key comes from the caller's
// context rather than from a member of this handle.
// NOTE(review): `waited_dev` is dereferenced without a null check -- callers
// are assumed to pass a valid context; confirm.
void Wait(platform::DeviceContext *waited_dev) override {
  const auto &target_place = waited_dev->GetPlace();
  auto *matching_ctx = dev_ctx_.at(target_place);
  matching_ctx->Wait();
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
ParallelExecutor::ParallelExecutor(
|
|
|
|
|