|
|
|
@ -84,8 +84,8 @@ struct OpHandle {
|
|
|
|
|
|
|
|
|
|
virtual ~OpHandle() {}
|
|
|
|
|
|
|
|
|
|
virtual void Run() { PADDLE_THROW("Not implemented"); }
|
|
|
|
|
virtual void Wait(platform::DeviceContext *waited_dev) {}
|
|
|
|
|
virtual void Run() = 0;
|
|
|
|
|
virtual void Wait(platform::DeviceContext *waited_dev) = 0;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
struct ComputationOpHandle : public OpHandle {
|
|
|
|
@ -382,7 +382,6 @@ struct NCCLAllReduceOpHandle : public OpHandle {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto &var_name = static_cast<VarHandle *>(this->inputs_[0])->name_;
|
|
|
|
|
VLOG(3) << "Invoke NCCL AllReduce";
|
|
|
|
|
int dtype = -1;
|
|
|
|
|
size_t numel = 0;
|
|
|
|
|
|
|
|
|
@ -848,7 +847,8 @@ void ParallelExecutor::RunOp(
|
|
|
|
|
LOG(FATAL) << "Unknown exception catched";
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
member_->pool_.enqueue(op_run);
|
|
|
|
|
op_run();
|
|
|
|
|
// member_->pool_.enqueue(op_run);
|
|
|
|
|
}
|
|
|
|
|
} // namespace framework
|
|
|
|
|
} // namespace paddle
|
|
|
|
|