|
|
|
@ -146,6 +146,7 @@ struct ComputationOpHandle : public OpHandle {
|
|
|
|
|
auto *cur_ctx = dev_ctx_[place_];
|
|
|
|
|
for (auto *in : inputs_) {
|
|
|
|
|
if (in->generated_op_ && in->generated_op_->dev_ctx_[place_] != cur_ctx) {
|
|
|
|
|
VLOG(3) << "Wait " << in->generated_op_->DebugString();
|
|
|
|
|
in->generated_op_->Wait(cur_ctx);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -163,13 +164,9 @@ struct ScaleLossGradOpHandle : public OpHandle {
|
|
|
|
|
platform::Place place)
|
|
|
|
|
: coeff_(static_cast<float>(1.0 / num_dev)),
|
|
|
|
|
scope_(scope),
|
|
|
|
|
place_(place) {
|
|
|
|
|
cudaSetDevice(boost::get<platform::CUDAPlace>(place_).device);
|
|
|
|
|
}
|
|
|
|
|
place_(place) {}
|
|
|
|
|
|
|
|
|
|
// Destructor. Its only action is a cudaSetDevice() call that selects the
// device index read from place_ (accessed as a platform::CUDAPlace).
// NOTE(review): the status returned by cudaSetDevice() is discarded here.
// NOTE(review): boost::get presumably throws if place_ does not hold a
// CUDAPlace, which would escape a destructor — confirm against callers.
~ScaleLossGradOpHandle() {
|
|
|
|
|
cudaSetDevice(boost::get<platform::CUDAPlace>(place_).device);
|
|
|
|
|
}
|
|
|
|
|
// Destructor with an empty body: it performs no explicit work of its own.
~ScaleLossGradOpHandle() {}
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
void RunImpl() override {
|
|
|
|
|