/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <thread>
#include <vector>
#include "cub/cub.cuh"
#include "hashtable.h"
#include "heter_resource.h"
#include "paddle/fluid/framework/fleet/heter_ps/optimizer.cuh.h"
#include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/platform/cuda_device_guard.h"
#include "paddle/fluid/platform/dynload/nccl.h"
#include "paddle/fluid/platform/place.h"
#include "thrust/pair.h"

#ifdef PADDLE_WITH_PSLIB

namespace paddle {
namespace framework {

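// CustomGradMerger is the binary reduction functor used when gradients that
// share a feature key are aggregated: the slot is carried over, while show,
// clk, lr_g and the MF_DIM embedding gradients are summed element-wise. A
// rough sketch of how such a functor is typically consumed (illustrative
// only; the real call site lives in heter_comm_inl.h, and the d_* buffers,
// `merger`, `place` and `stream` below are hypothetical) is CUB's key-grouped
// reduction over key-sorted input:
//
//   size_t temp_bytes = 0;
//   cub::DeviceReduce::ReduceByKey(nullptr, temp_bytes, d_keys_sorted,
//                                  d_keys_out, d_grads_sorted, d_grads_out,
//                                  d_num_unique, merger, len, stream);
//   auto d_temp = memory::AllocShared(place, temp_bytes);
//   cub::DeviceReduce::ReduceByKey(d_temp->ptr(), temp_bytes, d_keys_sorted,
//                                  d_keys_out, d_grads_sorted, d_grads_out,
//                                  d_num_unique, merger, len, stream);
//
// Keys must be sorted beforehand (e.g. with cub::DeviceRadixSort::SortPairs)
// so that duplicates are adjacent when the reduction runs.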
struct CustomGradMerger {
  template <typename T>
  CUB_RUNTIME_FUNCTION __forceinline__ __device__ T
  operator()(const T& a, const T& b) const {
    T out;
    out.slot = a.slot;
    out.show = a.show + b.show;
    out.clk = a.clk + b.clk;
    out.lr_g = a.lr_g + b.lr_g;
    for (int i = 0; i < MF_DIM; ++i) {
      out.mf_g[i] = a.mf_g[i] + b.mf_g[i];
    }
    return out;
  }
};

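// HeterComm shards a (KeyType -> ValType) hash table across the local GPUs
// and moves keys, values and gradients between them (and, in multi-node mode,
// between nodes over NCCL). A rough usage sketch follows; it is illustrative
// only, since the semantics are defined in heter_comm_inl.h, and the value and
// gradient type names, `opt`, `dev_ids` and the buffer variables below are
// hypothetical:
//
//   auto resource = std::make_shared<HeterPsResource>(dev_ids);
//   HeterComm<uint64_t, FeatureValue, FeaturePushValue> comm(capacity,
//                                                            resource);
//   comm.build_ps(gpu_id, h_keys, h_vals, len, chunk_size, stream_num);
//   comm.pull_sparse(gpu_id, d_keys, d_vals, len);        // lookup
//   comm.push_sparse(gpu_id, d_keys, d_grads, len, opt);  // apply gradients
//   comm.end_pass();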
template <typename KeyType, typename ValType, typename GradType>
class HeterComm {
 public:
  HeterComm(size_t capacity, std::shared_ptr<HeterPsResource> resource);
  virtual ~HeterComm();
  HeterComm(const HeterComm&) = delete;
  HeterComm& operator=(const HeterComm&) = delete;

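  // Key routing and gradient merging helpers: split_input_to_shard orders the
  // input keys by destination GPU through d_idx_ptr and records, per GPU, the
  // [left, right] index range routed to it; merge_grad de-duplicates keys and
  // aggregates their gradients with CustomGradMerger, returning the unique
  // count in uniq_len. (Summary inferred from the declarations; the
  // definitions live in heter_comm_inl.h.)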
  void split_input_to_shard(KeyType* d_keys, int* d_idx_ptr, size_t len,
                            int* left, int* right, int gpu_num);
  void merge_grad(int gpu_num, KeyType* d_keys, GradType* d_grads, size_t len,
                  int& uniq_len);
  void pull_sparse(int num, KeyType* d_keys, ValType* d_vals, size_t len);
  void build_ps(int num, KeyType* h_keys, ValType* h_vals, size_t len,
                size_t chunk_size, int stream_num);
  void dump();
  void show_one_table(int gpu_num);
  int get_index_by_devid(int devid);

  template <typename Sgd>
  void push_sparse(int num, KeyType* d_keys, GradType* d_grads, size_t len,
                   Sgd& sgd);

  template <typename Sgd>
  void push_sparse_multi_node(int num, KeyType* d_keys, GradType* d_grads,
                              size_t len, Sgd& sgd);

  template <typename Sgd>
  void update_one_table(int num, KeyType* d_keys, GradType* d_grads,
                        size_t len, Sgd& sgd);

  int gather_one_node_grad(int num, KeyType* d_keys, GradType* d_grads,
                           int len);

  int gather_multi_node_grad(int num, KeyType* d_keys, GradType* d_grads,
                             int len);

  int log2i(int x);

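  // NCCL setup for multi-node training. The naming suggests inner_comms are
  // the intra-node communicators (one per local GPU) and inter_comms span the
  // nodes, with comm_size being the number of nodes; this is inferred from the
  // member names, see heter_comm_inl.h for the actual usage.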
  void set_nccl_comm_and_size(const std::vector<ncclComm_t>& inner_comms,
                              const std::vector<ncclComm_t>& inter_comms,
                              int comm_size) {
    nccl_inner_comms_ = inner_comms;
    nccl_inter_comms_ = inter_comms;
    node_size_ = comm_size;
  }

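  // Cross-GPU routing helpers. Assuming an 8-GPU machine split into two
  // groups of four (two interconnect domains), a copy from send_id to
  // receive_id has to be staged through an intermediate device whenever the
  // two GPUs sit in different groups and receive_id is not send_id's direct
  // peer; that peer is (send_id + 4) % 8, as returned by get_transfer_devid()
  // below.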
  bool need_transfer(int send_id, int receive_id) {
    return ((send_id / 4 != receive_id / 4) && (send_id + 4) % 8 != receive_id);
  }

  // void dump_to_cpu(int index);

  void end_pass();

  int get_transfer_devid(int send_id) { return (send_id + 4) % 8; }

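  // A Node is one hop of a copy path between GPUs: the CUDA streams used for
  // the transfer plus the staging buffers and byte lengths for keys and
  // values; a Path is the ordered sequence of hops from a source GPU to a
  // destination GPU.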
  struct Node {
    cudaStream_t in_stream;
    cudaStream_t out_stream;
    char* key_storage;
    char* val_storage;
    int sync;
    int key_bytes_len;
    int val_bytes_len;
    int gpu_num;
  };

  struct Path {
    std::vector<Node> nodes_;
  };

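  // Per-GPU scratch buffers: local_* hold this device's own keys/gradients
  // and all_* hold what is gathered from peers (the naming suggests use by
  // the gather_*_grad paths). alloc() only reallocates when forced or when
  // the requested size outgrows the current buffer; the non-forced branch
  // dereferences the existing allocations, so init(), which calls
  // alloc(size, true), must run before any non-forced alloc().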
  struct LocalStorage {
    LocalStorage() {}
    void init(int size, int dev_id) {
      place_ = platform::CUDAPlace(dev_id);
      alloc(size, true);
    }

    void alloc(int size, bool force = false) {
      if (force || size > all_keys_mem->size()) {
        all_keys_mem.reset();
        all_grads_mem.reset();
        all_keys_mem = memory::AllocShared(place_, size * sizeof(KeyType));
        all_grads_mem = memory::AllocShared(place_, size * sizeof(GradType));
        all_keys = reinterpret_cast<KeyType*>(all_keys_mem->ptr());
        all_grads = reinterpret_cast<GradType*>(all_grads_mem->ptr());
      }
      if (force || size > local_keys_mem->size()) {
        local_keys_mem.reset();
        local_grads_mem.reset();
        local_keys_mem = memory::AllocShared(place_, size * sizeof(KeyType));
        local_grads_mem = memory::AllocShared(place_, size * sizeof(GradType));
        local_keys = reinterpret_cast<KeyType*>(local_keys_mem->ptr());
        local_grads = reinterpret_cast<GradType*>(local_grads_mem->ptr());
      }
    }

    platform::CUDAPlace place_;
    std::shared_ptr<memory::Allocation> all_keys_mem;
    std::shared_ptr<memory::Allocation> all_grads_mem;
    KeyType* all_keys;
    GradType* all_grads;

    std::shared_ptr<memory::Allocation> local_keys_mem;
    std::shared_ptr<memory::Allocation> local_grads_mem;
    KeyType* local_keys;
    GradType* local_grads;
  };

  void init_path();
  void create_storage(
      int start_index, int end_index, int keylen, int vallen,
      std::vector<std::shared_ptr<memory::Allocation>>& local_storage);
  void walk_to_src(int start_index, int end_index, char* src_val);
  void walk_to_dest(int start_index, int end_index, char* src_key,
                    char* src_val);

 private:
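  // One HashTable per local GPU (tables_), the device/stream pool shared
  // through resource_, the copy routes in path_[src][dst] and their staging
  // buffers in storage_. block_size_ is the CUDA block size for the table
  // kernels, load_factor_ sizes the hash tables relative to capacity, and
  // feanum_ is the default per-pass feature budget; these descriptions are
  // inferred from the names and default values.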
  using Table = HashTable<KeyType, ValType>;
  int block_size_{256};
  float load_factor_{0.75};
  std::vector<Table*> tables_;
  std::shared_ptr<HeterPsResource> resource_;
  CustomGradMerger merger_;
  int topo_aware_{1};
  std::vector<std::vector<Path>> path_;
  std::vector<LocalStorage> storage_;
  int feanum_{1800 * 2048};
  int multi_node_{1};
  std::vector<ncclComm_t> nccl_inner_comms_;
  std::vector<ncclComm_t> nccl_inter_comms_;
  int node_size_;
};

}  // end namespace framework
}  // end namespace paddle
#include "paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h"
#endif