|
|
|
@ -17,10 +17,8 @@ class NCCLManager {
|
|
|
|
|
~NCCLManager() {}
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
// clang-format off
|
|
|
|
|
std::vector<ncclComm_t> _comms;
|
|
|
|
|
std::vector<int> _gpu_worlds;
|
|
|
|
|
// clang-format on
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class NCCLContext : public DeviceContext {
|
|
|
|
@ -29,11 +27,9 @@ class NCCLContext : public DeviceContext {
|
|
|
|
|
virtual ~NCCLContext();
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
// clang-format off
|
|
|
|
|
std::vector<int> _gpu_ids;
|
|
|
|
|
std::vector<cudaStream_t> _streams;
|
|
|
|
|
int root_gpu;
|
|
|
|
|
// clang-format on
|
|
|
|
|
};
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|