|
|
|
@ -35,11 +35,11 @@ void PsCacheManager::InsertHashTableSize(const std::string &param_name, size_t c
|
|
|
|
|
if (vocab_size_ == 0) {
|
|
|
|
|
vocab_size_ = vocab_size;
|
|
|
|
|
}
|
|
|
|
|
if (cache_vocab_size_ == 0) {
|
|
|
|
|
cache_vocab_size_ = cache_vocab_size;
|
|
|
|
|
if (vocab_cache_size_ == 0) {
|
|
|
|
|
vocab_cache_size_ = cache_vocab_size;
|
|
|
|
|
}
|
|
|
|
|
if (host_cache_vocab_size_ == 0) {
|
|
|
|
|
host_cache_vocab_size_ = cache_vocab_size * kHostCacheScaleFactor;
|
|
|
|
|
if (host_vocab_cache_size_ == 0) {
|
|
|
|
|
host_vocab_cache_size_ = cache_vocab_size * kHostCacheScaleFactor;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -148,8 +148,8 @@ void PsCacheManager::Initialize() {
|
|
|
|
|
Util::SetInternalEnvVar();
|
|
|
|
|
worker.Run();
|
|
|
|
|
}
|
|
|
|
|
embedding_device_cache_ = std::make_shared<EmbeddingDeviceCache>(batch_elements_, cache_vocab_size_);
|
|
|
|
|
embedding_host_cache_ = std::make_shared<EmbeddingHostCache>(batch_elements_, host_cache_vocab_size_);
|
|
|
|
|
embedding_device_cache_ = std::make_shared<EmbeddingDeviceCache>(batch_elements_, vocab_cache_size_);
|
|
|
|
|
embedding_host_cache_ = std::make_shared<EmbeddingHostCache>(batch_elements_, host_vocab_cache_size_);
|
|
|
|
|
AddEmbeddingTable();
|
|
|
|
|
AllocMemForHashTable();
|
|
|
|
|
SetLocalIdRank();
|
|
|
|
@ -220,13 +220,13 @@ void PsCacheManager::AllocMemForHashTable() {
|
|
|
|
|
for (auto &item : hash_tables_) {
|
|
|
|
|
size_t embedding_size = item.second.embedding_size;
|
|
|
|
|
auto &device_address = item.second.device_address;
|
|
|
|
|
device_address.size = cache_vocab_size_ * embedding_size * sizeof(float);
|
|
|
|
|
device_address.size = vocab_cache_size_ * embedding_size * sizeof(float);
|
|
|
|
|
auto addr = embedding_device_cache_->cache_->MallocMemory(device_address.size);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(addr);
|
|
|
|
|
device_address.addr = addr;
|
|
|
|
|
|
|
|
|
|
auto &host_address = item.second.host_address;
|
|
|
|
|
auto host_address_ptr = new float[host_cache_vocab_size_ * embedding_size];
|
|
|
|
|
auto host_address_ptr = new float[host_vocab_cache_size_ * embedding_size];
|
|
|
|
|
MS_EXCEPTION_IF_NULL(host_address_ptr);
|
|
|
|
|
host_address = std::shared_ptr<float[]>(host_address_ptr, std::default_delete<float[]>());
|
|
|
|
|
MS_EXCEPTION_IF_NULL(host_address);
|
|
|
|
@ -239,21 +239,28 @@ void PsCacheManager::AllocMemForHashTable() {
|
|
|
|
|
embedding_device_cache_->hash_swap_value_addr_ = reinterpret_cast<float *>(
|
|
|
|
|
embedding_device_cache_->cache_->MallocMemory(max_embedding_size * batch_elements_ * sizeof(float)));
|
|
|
|
|
MS_EXCEPTION_IF_NULL(embedding_device_cache_->hash_swap_value_addr_);
|
|
|
|
|
if (!(embedding_device_cache_->cache_->MallocConstantMemory(cache_vocab_size_))) {
|
|
|
|
|
if (!(embedding_device_cache_->cache_->MallocConstantMemory(vocab_cache_size_))) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "MallocConstantMemory failed.";
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void PsCacheManager::SetLocalIdRank() {
|
|
|
|
|
auto worker_num = ::ps::NumWorkers();
|
|
|
|
|
auto worker_id = ::ps::MyRank();
|
|
|
|
|
auto local_shard_size = FloatToSize(std::ceil(SizeToFloat(vocab_size_) / worker_num));
|
|
|
|
|
range_bound_.first = local_shard_size * worker_id;
|
|
|
|
|
range_bound_.second = std::min(range_bound_.first + local_shard_size, vocab_size_);
|
|
|
|
|
MS_LOG(INFO) << "Worker num:" << worker_num << ", worker id:" << worker_id << ", rank id begin:" << range_bound_.first
|
|
|
|
|
<< ", rank id end:" << range_bound_.second;
|
|
|
|
|
auto local_shard_size = FloatToInt(std::ceil(SizeToFloat(vocab_size_) / worker_num));
|
|
|
|
|
vocab_cache_size_diff_ = local_shard_size - SizeToInt(vocab_cache_size_);
|
|
|
|
|
emb_table_slice_bounds_.first = local_shard_size * rank_id_;
|
|
|
|
|
emb_table_slice_bounds_.second = std::min(emb_table_slice_bounds_.first + local_shard_size, SizeToInt(vocab_size_));
|
|
|
|
|
cache_indices_bounds_.first = SizeToInt(vocab_cache_size_) * rank_id_;
|
|
|
|
|
cache_indices_bounds_.second = cache_indices_bounds_.first + SizeToInt(vocab_cache_size_);
|
|
|
|
|
MS_LOG(INFO) << "Worker num:" << worker_num << ", rank id:" << rank_id_
|
|
|
|
|
<< ", id begin:" << emb_table_slice_bounds_.first << ", id end:" << emb_table_slice_bounds_.second
|
|
|
|
|
<< ", cache indices begin: " << cache_indices_bounds_.first
|
|
|
|
|
<< ", cache indices end: " << cache_indices_bounds_.second
|
|
|
|
|
<< ", vocab_cache_size_diff: " << vocab_cache_size_diff_;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Returns the lower bound (the `first` member) of cache_indices_bounds_.
int PsCacheManager::cache_indices_lower_bound() const {
  const auto &bounds = cache_indices_bounds_;
  return bounds.first;
}
|
|
|
|
|
|
|
|
|
|
std::string PsCacheManager::channel_name() {
|
|
|
|
|
std::lock_guard<std::mutex> locker(channel_mutex_);
|
|
|
|
|
return channel_name_;
|
|
|
|
@ -398,8 +405,8 @@ bool PsCacheManager::ProcessData() {
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool PsCacheManager::CheckIDInDeviceTask(const int *batch_ids, const size_t batch_ids_len, int *hash_index,
|
|
|
|
|
bool *in_device, size_t *hash_hit_count) {
|
|
|
|
|
bool PsCacheManager::CheckCacheHitOrOutRangeTask(const int *batch_ids, const size_t batch_ids_len, int *hash_index,
|
|
|
|
|
bool *in_device, bool *out_range, size_t *hash_hit_count) {
|
|
|
|
|
MS_ERROR_IF_NULL(batch_ids);
|
|
|
|
|
MS_ERROR_IF_NULL(hash_index);
|
|
|
|
|
MS_ERROR_IF_NULL(in_device);
|
|
|
|
@ -410,9 +417,19 @@ bool PsCacheManager::CheckIDInDeviceTask(const int *batch_ids, const size_t batc
|
|
|
|
|
const auto &hash_id_to_index = device_hash_map->hash_id_to_index();
|
|
|
|
|
|
|
|
|
|
for (size_t i = 0; i < batch_ids_len; ++i) {
|
|
|
|
|
if (batch_ids[i] < emb_table_slice_bounds_.first) {
|
|
|
|
|
hash_index[i] = batch_ids[i] - vocab_cache_size_diff_;
|
|
|
|
|
out_range[i] = true;
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
if (batch_ids[i] >= emb_table_slice_bounds_.second) {
|
|
|
|
|
hash_index[i] = batch_ids[i] + cache_indices_bounds_.second;
|
|
|
|
|
out_range[i] = true;
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
auto iter = hash_id_to_index.find(batch_ids[i]);
|
|
|
|
|
if (iter != hash_id_to_index.end()) {
|
|
|
|
|
hash_index[i] = iter->second;
|
|
|
|
|
hash_index[i] = iter->second + cache_indices_bounds_.first;
|
|
|
|
|
if (device_hash_map->hash_step(iter->second) != data_step_) {
|
|
|
|
|
++(*hash_hit_count);
|
|
|
|
|
device_hash_map->set_hash_step(iter->second, data_step_);
|
|
|
|
@ -423,11 +440,12 @@ bool PsCacheManager::CheckIDInDeviceTask(const int *batch_ids, const size_t batc
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool PsCacheManager::CheckIDInDevice(const int *batch_ids, const size_t batch_ids_len, int *hash_index,
|
|
|
|
|
bool *in_device) {
|
|
|
|
|
bool PsCacheManager::CheckCacheHitOrOutRange(const int *batch_ids, const size_t batch_ids_len, int *hash_index,
|
|
|
|
|
bool *in_device, bool *out_range) {
|
|
|
|
|
MS_ERROR_IF_NULL(batch_ids);
|
|
|
|
|
MS_ERROR_IF_NULL(hash_index);
|
|
|
|
|
MS_ERROR_IF_NULL(in_device);
|
|
|
|
|
MS_ERROR_IF_NULL(out_range);
|
|
|
|
|
|
|
|
|
|
size_t thread_num = batch_ids_len / kMinIdsPerThread + 1;
|
|
|
|
|
thread_num = thread_num > kMaxThreadNum ? kMaxThreadNum : thread_num;
|
|
|
|
@ -441,8 +459,9 @@ bool PsCacheManager::CheckIDInDevice(const int *batch_ids, const size_t batch_id
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
size_t task_proc_lens = batch_ids_len / thread_num + (i < (batch_ids_len % thread_num) ? 1 : 0);
|
|
|
|
|
threads[i] = std::thread(&PsCacheManager::CheckIDInDeviceTask, this, batch_ids + task_offset, task_proc_lens,
|
|
|
|
|
hash_index + task_offset, in_device + task_offset, hash_hit_count + i);
|
|
|
|
|
threads[i] =
|
|
|
|
|
std::thread(&PsCacheManager::CheckCacheHitOrOutRangeTask, this, batch_ids + task_offset, task_proc_lens,
|
|
|
|
|
hash_index + task_offset, in_device + task_offset, out_range + task_offset, hash_hit_count + i);
|
|
|
|
|
task_offset += task_proc_lens;
|
|
|
|
|
}
|
|
|
|
|
if (task_offset != batch_ids_len) {
|
|
|
|
@ -477,27 +496,26 @@ bool PsCacheManager::ParseData(const int *batch_ids, const size_t batch_ids_len,
|
|
|
|
|
MS_ERROR_IF_NULL(hash_index);
|
|
|
|
|
statistics_info_.batch_id_count_ = batch_ids_len;
|
|
|
|
|
std::unique_ptr<bool[]> in_device(new bool[batch_ids_len]);
|
|
|
|
|
std::unique_ptr<bool[]> out_range(new bool[batch_ids_len]);
|
|
|
|
|
if (memset_s(in_device.get(), batch_ids_len * sizeof(bool), 0, batch_ids_len * sizeof(bool))) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Data in device memset failed.";
|
|
|
|
|
MS_LOG(EXCEPTION) << "Initialize in_device array failed.";
|
|
|
|
|
}
|
|
|
|
|
if (memset_s(out_range.get(), batch_ids_len * sizeof(bool), 0, batch_ids_len * sizeof(bool))) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Initialize out_range array failed.";
|
|
|
|
|
}
|
|
|
|
|
CheckIDInDevice(batch_ids, batch_ids_len, hash_index, in_device.get());
|
|
|
|
|
RETURN_IF_FALSE(CheckCacheHitOrOutRange(batch_ids, batch_ids_len, hash_index, in_device.get(), out_range.get()));
|
|
|
|
|
RETURN_IF_FALSE(ResetEmbeddingHashMap());
|
|
|
|
|
for (size_t i = 0; i < batch_ids_len; i++) {
|
|
|
|
|
if (in_device[i]) {
|
|
|
|
|
if (in_device[i] || out_range[i]) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
bool need_swap_host_to_device = true;
|
|
|
|
|
bool need_swap_device_to_host = true;
|
|
|
|
|
auto id = batch_ids[i];
|
|
|
|
|
if ((id < SizeToInt(range_bound_.first)) || (id >= SizeToInt(range_bound_.second))) {
|
|
|
|
|
hash_index[i] = -1;
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
int index = INVALID_INDEX_VALUE;
|
|
|
|
|
RETURN_IF_FALSE(ParseDeviceData(id, &need_swap_device_to_host, &need_swap_host_to_device, &index));
|
|
|
|
|
hash_index[i] = index;
|
|
|
|
|
RETURN_IF_FALSE(ParseDeviceData(batch_ids[i], &need_swap_device_to_host, &need_swap_host_to_device, &index));
|
|
|
|
|
hash_index[i] = index + cache_indices_bounds_.first;
|
|
|
|
|
if (need_swap_host_to_device) {
|
|
|
|
|
RETURN_IF_FALSE(ParseHostDataHostToDevice(id));
|
|
|
|
|
RETURN_IF_FALSE(ParseHostDataHostToDevice(batch_ids[i]));
|
|
|
|
|
}
|
|
|
|
|
if (need_swap_device_to_host) {
|
|
|
|
|
RETURN_IF_FALSE(ParseHostDataDeviceToHost());
|
|
|
|
@ -667,7 +685,7 @@ void PsCacheManager::LookUpTableTask(size_t indices_lens, size_t outer_dim_size,
|
|
|
|
|
|
|
|
|
|
bool PsCacheManager::LookUpHostHashTable(size_t embedding_size, size_t indices_lens, const float *hash_table_addr,
|
|
|
|
|
const int *indices_addr, float *output_addr) {
|
|
|
|
|
size_t first_dim_size = host_cache_vocab_size_;
|
|
|
|
|
size_t first_dim_size = host_vocab_cache_size_;
|
|
|
|
|
size_t outer_dim_size = embedding_size;
|
|
|
|
|
|
|
|
|
|
size_t thread_num = indices_lens / 10000 + 1;
|
|
|
|
@ -697,7 +715,7 @@ bool PsCacheManager::LookUpHostHashTable(size_t embedding_size, size_t indices_l
|
|
|
|
|
|
|
|
|
|
bool PsCacheManager::InsertHostHashTable(size_t embedding_size, size_t insert_indices_size, int *insert_indices,
|
|
|
|
|
float *insert_data, float *hash_table_addr) {
|
|
|
|
|
size_t first_dim_size = host_cache_vocab_size_;
|
|
|
|
|
size_t first_dim_size = host_vocab_cache_size_;
|
|
|
|
|
size_t thread_num = insert_indices_size / 10000 + 1;
|
|
|
|
|
thread_num = thread_num > kMaxThreadNum ? kMaxThreadNum : thread_num;
|
|
|
|
|
std::thread threads[kMaxThreadNum];
|
|
|
|
|