|
|
@ -30,23 +30,45 @@ operator run on each GPU, it will automatically sync with different streams when
|
|
|
|
// if op's input is params' grad:
|
|
|
|
// if op's input is params' grad:
|
|
|
|
// sync with allreduce stream
|
|
|
|
// sync with allreduce stream
|
|
|
|
// e.g. sgd should wait for allreduce to be finished
|
|
|
|
// e.g. sgd should wait for allreduce to be finished
|
|
|
|
SyncMultipleStreams(op);
|
|
|
|
CallBack->BeforeOp(op);
|
|
|
|
|
|
|
|
|
|
|
|
op->Run(*local_scope, place_);
|
|
|
|
op->Run(*local_scope, place_);
|
|
|
|
|
|
|
|
|
|
|
|
// if op's output is params' grad:
|
|
|
|
// if op's output is params' grad:
|
|
|
|
// sync with computation stream
|
|
|
|
// sync with computation stream
|
|
|
|
// e.g. allreduce shoudl wait for fc_grad to be finished.
|
|
|
|
// e.g. allreduce shoudl wait for fc_grad to be finished.
|
|
|
|
SyncMultipleStreams(op);
|
|
|
|
CallBack->AfterOp(op);
|
|
|
|
```
|
|
|
|
```
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
And the `Callback` object can be implemented as the following
|
|
|
|
|
|
|
|
|
|
|
|
## API
|
|
|
|
```c++
|
|
|
|
|
|
|
|
struct AllReduceCallBack {
|
|
|
|
|
|
|
|
void BeforeOp(framework::OperatorBase* op);
|
|
|
|
|
|
|
|
void AfterOp(framework::OperatorBase* op);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
std::unordered_set<std::string> reduced_param_grad_names;
|
|
|
|
|
|
|
|
std::unordered_set<std::string> param_grad_names_;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
platform::DeviceContext* computation_dev_ctx; // computation device context
|
|
|
|
|
|
|
|
platform::DeviceContext* communication_dev_ctx; // communication device context
|
|
|
|
|
|
|
|
|
|
|
|
The `ParallelExecutor.run` has similar interface as `Executor.run`. Besides
|
|
|
|
framework::Scope* scope;
|
|
|
|
1. Scope: we don't expose `scope` in `ParallelExecutor.run` since `ParallelExecutor` has its
|
|
|
|
platform::NCCL::Communicator* nccl_com;
|
|
|
|
own scope to maintain NCCL.
|
|
|
|
};
|
|
|
|
1. Feed: we don't expose `feed` in the API either, because the whole point of implementing
|
|
|
|
|
|
|
|
parallel_executor is the speed. The input for NN should be implemented in an reader OP.
|
|
|
|
AllReduceCallBack::BeforeOp(framework::OperatorBase* op) {
|
|
|
|
1. Fetch: we return the fetched value on all GPUs as a list. (e.g. `exe.run(..., fetch=loss)`
|
|
|
|
if (op->Input() in reduced_param_grad_names) {
|
|
|
|
with return `[loss_on_gpu0, loss_on_gpu1]`)
|
|
|
|
communication_dev_ctx->Wait();
|
|
|
|
|
|
|
|
reduced_param_grad_names.erase(op->Input())
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
AllReduceCallBack::AfterOp(framework::OperatorBase* op) {
|
|
|
|
|
|
|
|
if (op->Output() in param_grad_names) {
|
|
|
|
|
|
|
|
computation_dev_ctx->Wait();
|
|
|
|
|
|
|
|
reduced_param_grad_names.insert(op->Output());
|
|
|
|
|
|
|
|
ncclAllreduce(scope, op->Output(), communication_dev_ctx);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
```
|
|
|
|