@@ -95,10 +95,14 @@ template <typename KeyType, typename ValType, typename GradType>
HeterComm<KeyType, ValType, GradType>::HeterComm(
    size_t capacity, std::shared_ptr<HeterPsResource> resource) {
  resource_ = resource;
  storage_.resize(resource_->total_gpu());
  for (int i = 0; i < resource_->total_gpu(); ++i) {
    platform::CUDADeviceGuard guard(resource_->dev_id(i));
    auto table = new Table(capacity / load_factor_);
    tables_.push_back(table);
    if (multi_node_) {
      storage_[i].init(feanum_, resource_->dev_id(i));
    }
  }
  init_path();
}
@@ -595,6 +599,186 @@ void HeterComm<KeyType, ValType, GradType>::push_sparse(int gpu_num,
  }
}

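// update_one_table: applies the already-merged gradients to this GPU's local
// hash table with the given optimizer on the table's remote stream, and waits
// for the update to finish before returning.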
template <typename KeyType, typename ValType, typename GradType>
template <typename Sgd>
void HeterComm<KeyType, ValType, GradType>::update_one_table(
    int gpu_num, KeyType* d_keys, GradType* d_grads, size_t len, Sgd& sgd) {
  if (len == 0) {
    return;
  }

  int dev_id = resource_->dev_id(gpu_num);
  platform::CUDADeviceGuard guard(dev_id);
  tables_[gpu_num]->update(d_keys, d_grads, len, sgd,
                           resource_->remote_stream(gpu_num));
  cudaStreamSynchronize(resource_->remote_stream(gpu_num));
}

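// push_sparse_multi_node: multi-node push path. Gradients are first merged
// locally by key, then gathered across the GPUs of this node, then across
// nodes, and finally applied to this GPU's shard of the table.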
template <typename KeyType, typename ValType, typename GradType>
template <typename Sgd>
void HeterComm<KeyType, ValType, GradType>::push_sparse_multi_node(
    int gpu_num, KeyType* d_keys, GradType* d_grads, size_t len, Sgd& sgd) {
  if (len == 0) {
    return;
  }

  int uniq_len = len;
  merge_grad(gpu_num, d_keys, d_grads, len, uniq_len);

  uniq_len = gather_one_node_grad(gpu_num, d_keys, d_grads, uniq_len);

  uniq_len = gather_multi_node_grad(gpu_num, storage_[gpu_num].local_keys,
                                    storage_[gpu_num].local_grads, uniq_len);

  update_one_table(gpu_num, storage_[gpu_num].local_keys,
                   storage_[gpu_num].local_grads, uniq_len, sgd);
}

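// gather_one_node_grad: gathers the merged gradients from all GPUs within this
// node over the inner (intra-node) NCCL communicator, keeps only the keys and
// grads that fall into this GPU's shard, merges them again by key, and returns
// the merged length.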
template <typename KeyType, typename ValType, typename GradType>
int HeterComm<KeyType, ValType, GradType>::gather_one_node_grad(
    int gpu_num, KeyType* d_keys, GradType* d_grads, int len) {
  int total_gpu = resource_->total_gpu();
  int dev_id = resource_->dev_id(gpu_num);
  auto& storage = storage_[gpu_num];
  platform::CUDAPlace place = platform::CUDAPlace(dev_id);
  platform::CUDADeviceGuard guard(dev_id);
  auto stream = resource_->local_stream(gpu_num, 0);
  int max_size = 0;

  ncclComm_t nccl_inner_comm = nccl_inner_comms_[gpu_num];
  // alloc for size
  int h_node_len[total_gpu];
  auto d_node_len_mem = memory::AllocShared(place, total_gpu * sizeof(int));
  int* d_node_len = reinterpret_cast<int*>(d_node_len_mem->ptr());
  h_node_len[gpu_num] = len;

  cudaMemcpy(d_node_len + gpu_num, h_node_len + gpu_num, sizeof(int),
             cudaMemcpyHostToDevice);

  // allgather grad len
  PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclGroupStart());
  PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclAllGather(
      (const void*)(d_node_len + gpu_num), (void*)d_node_len, 1, ncclInt,
      nccl_inner_comm, stream));
  PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclGroupEnd());
  PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream));
  cudaMemcpy(h_node_len, d_node_len, sizeof(int) * total_gpu,
             cudaMemcpyDeviceToHost);

  for (int i = 0; i < total_gpu; ++i) {
    if (h_node_len[i] > max_size) {
      max_size = h_node_len[i];
    }
  }
  storage.alloc(max_size * total_gpu);

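  // Each rank contributes max_size elements so every segment of all_keys and
  // all_grads has a fixed stride; only the first h_node_len[i] entries of
  // segment i are valid. Grads are gathered as ncclUint8 with a byte count
  // because GradType is not a built-in NCCL datatype.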
  // allgather keys and grads
  PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclGroupStart());
  PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclAllGather(
      d_keys, storage.all_keys, max_size, ncclUint64, nccl_inner_comm, stream));

  PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclAllGather(
      d_grads, storage.all_grads, max_size * sizeof(GradType), ncclUint8,
      nccl_inner_comm, stream));
  PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclGroupEnd());
  PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream));

  int h_left[total_gpu];
  int h_right[total_gpu];
  auto d_left = memory::AllocShared(place, total_gpu * sizeof(int));
  auto d_right = memory::AllocShared(place, total_gpu * sizeof(int));
  int* d_left_ptr = reinterpret_cast<int*>(d_left->ptr());
  int* d_right_ptr = reinterpret_cast<int*>(d_right->ptr());

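  // For each GPU's gathered segment, compute which keys hash to this GPU's
  // shard (split_input_to_shard fills d_left/d_right per destination GPU) and
  // pack only that slice contiguously into local_keys/local_grads.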
  int merge_num = 0;
  for (int i = 0; i < total_gpu; ++i) {
    int index = i * max_size;
    auto d_idx = memory::AllocShared(place, h_node_len[i] * sizeof(int));
    int* d_idx_ptr = reinterpret_cast<int*>(d_idx->ptr());

    cudaMemset(d_left_ptr, -1, total_gpu * sizeof(int));
    cudaMemset(d_right_ptr, -1, total_gpu * sizeof(int));

    split_input_to_shard(storage.all_keys + index, d_idx_ptr, h_node_len[i],
                         d_left_ptr, d_right_ptr, gpu_num);
    cudaMemcpy(h_left, d_left_ptr, total_gpu * sizeof(int),
               cudaMemcpyDeviceToHost);
    cudaMemcpy(h_right, d_right_ptr, total_gpu * sizeof(int),
               cudaMemcpyDeviceToHost);

    int grid_size = (h_node_len[i] - 1) / block_size_ + 1;
    fill_shard_grads<<<grid_size, block_size_, 0, stream>>>(
        storage.local_keys + merge_num, storage.all_keys + index,
        storage.local_grads + merge_num, storage.all_grads + index,
        d_idx_ptr + h_left[gpu_num], h_right[gpu_num] - h_left[gpu_num] + 1);
    merge_num = merge_num + h_right[gpu_num] - h_left[gpu_num] + 1;
  }

  int ret = merge_num;
  merge_grad(gpu_num, storage.local_keys, storage.local_grads, merge_num, ret);
  return ret;
}

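// gather_multi_node_grad: gathers this node's merged gradients across nodes
// over the inter-node NCCL communicator, concatenates the valid prefix of each
// node's segment, merges the result by key, and returns the merged length.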
template <typename KeyType, typename ValType, typename GradType>
int HeterComm<KeyType, ValType, GradType>::gather_multi_node_grad(
    int gpu_num, KeyType* d_keys, GradType* d_grads, int len) {
  int dev_id = resource_->dev_id(gpu_num);
  auto& storage = storage_[gpu_num];
  platform::CUDAPlace place = platform::CUDAPlace(dev_id);
  platform::CUDADeviceGuard guard(dev_id);
  auto stream = resource_->local_stream(gpu_num, 0);
  int max_size = 0;
  ncclComm_t nccl_inter_comm = nccl_inter_comms_[gpu_num];
  // alloc for size
  int h_node_len[node_size_];
  auto d_node_len_mem = memory::AllocShared(place, node_size_ * sizeof(int));
  int* d_node_len = reinterpret_cast<int*>(d_node_len_mem->ptr());
  h_node_len[0] = len;

  cudaMemcpy(d_node_len, h_node_len, sizeof(int), cudaMemcpyHostToDevice);

  // allgather grad len
  PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclGroupStart());
  PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclAllGather(
      d_node_len, d_node_len, 1, ncclInt, nccl_inter_comm, stream));
  PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclGroupEnd());
  PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream));
  cudaMemcpy(h_node_len, d_node_len, sizeof(int) * node_size_,
             cudaMemcpyDeviceToHost);

  for (int i = 0; i < node_size_; ++i) {
    if (h_node_len[i] > max_size) {
      max_size = h_node_len[i];
    }
  }
  storage.alloc(max_size * node_size_);

  // allgather keys and grads
  PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclGroupStart());
  PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclAllGather(
      d_keys, storage.all_keys, max_size, ncclUint64, nccl_inter_comm, stream));

  PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclAllGather(
      d_grads, storage.all_grads, max_size * sizeof(GradType), ncclUint8,
      nccl_inter_comm, stream));
  PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclGroupEnd());
  PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream));

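  // Copy the valid prefix of each node's segment back-to-back into
  // local_keys/local_grads before the final merge; copy sizes are in bytes.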
  int merge_num = 0;
  for (int i = 0; i < node_size_; ++i) {
    int index = i * max_size;
    cudaMemcpyAsync(storage.local_keys + merge_num, storage.all_keys + index,
                    h_node_len[i] * sizeof(KeyType), cudaMemcpyDefault, stream);
    cudaMemcpyAsync(storage.local_grads + merge_num, storage.all_grads + index,
                    h_node_len[i] * sizeof(GradType), cudaMemcpyDefault, stream);
    merge_num += h_node_len[i];
  }

  int ret = merge_num;
  merge_grad(gpu_num, storage.local_keys, storage.local_grads, merge_num, ret);
  return ret;
}

template <typename KeyType, typename ValType, typename GradType>
void HeterComm<KeyType, ValType, GradType>::end_pass() {
  int total_gpu = resource_->total_gpu();