|
|
|
@ -394,11 +394,79 @@ bool PsCacheManager::ProcessData() {
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool PsCacheManager::CheckIDInDeviceTask(const int *batch_ids, const size_t batch_ids_len, int *hash_index,
|
|
|
|
|
bool *in_device, size_t *hash_hit_count) {
|
|
|
|
|
MS_ERROR_IF_NULL(batch_ids);
|
|
|
|
|
MS_ERROR_IF_NULL(hash_index);
|
|
|
|
|
MS_ERROR_IF_NULL(in_device);
|
|
|
|
|
MS_ERROR_IF_NULL(hash_hit_count);
|
|
|
|
|
MS_ERROR_IF_NULL(embedding_device_cache_);
|
|
|
|
|
auto &device_hash_map = embedding_device_cache_->device_hash_map_;
|
|
|
|
|
MS_ERROR_IF_NULL(device_hash_map);
|
|
|
|
|
const auto &hash_id_to_index = device_hash_map->hash_id_to_index();
|
|
|
|
|
|
|
|
|
|
for (size_t i = 0; i < batch_ids_len; ++i) {
|
|
|
|
|
auto iter = hash_id_to_index.find(batch_ids[i]);
|
|
|
|
|
if (iter != hash_id_to_index.end()) {
|
|
|
|
|
hash_index[i] = iter->second;
|
|
|
|
|
if (device_hash_map->hash_step(iter->second) != data_step_) {
|
|
|
|
|
++(*hash_hit_count);
|
|
|
|
|
device_hash_map->set_hash_step(iter->second, data_step_);
|
|
|
|
|
}
|
|
|
|
|
in_device[i] = true;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool PsCacheManager::CheckIDInDevice(const int *batch_ids, const size_t batch_ids_len, int *hash_index,
|
|
|
|
|
bool *in_device) {
|
|
|
|
|
MS_ERROR_IF_NULL(batch_ids);
|
|
|
|
|
MS_ERROR_IF_NULL(hash_index);
|
|
|
|
|
MS_ERROR_IF_NULL(in_device);
|
|
|
|
|
|
|
|
|
|
size_t thread_num = batch_ids_len / kMinIdsPerThread + 1;
|
|
|
|
|
thread_num = thread_num > kMaxThreadNum ? kMaxThreadNum : thread_num;
|
|
|
|
|
std::thread threads[kMaxThreadNum];
|
|
|
|
|
size_t hash_hit_count[kMaxThreadNum] = {0};
|
|
|
|
|
size_t i = 0;
|
|
|
|
|
size_t task_offset = 0;
|
|
|
|
|
|
|
|
|
|
for (; i < thread_num; ++i) {
|
|
|
|
|
if (task_offset >= batch_ids_len) {
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
size_t task_proc_lens = batch_ids_len / thread_num + (i < (batch_ids_len % thread_num) ? 1 : 0);
|
|
|
|
|
threads[i] = std::thread(&PsCacheManager::CheckIDInDeviceTask, this, batch_ids + task_offset, task_proc_lens,
|
|
|
|
|
hash_index + task_offset, in_device + task_offset, hash_hit_count + i);
|
|
|
|
|
task_offset += task_proc_lens;
|
|
|
|
|
}
|
|
|
|
|
if (task_offset != batch_ids_len) {
|
|
|
|
|
MS_LOG(WARNING) << "Ps cache check id in device inadequate, total:" << batch_ids_len << " checked:" << task_offset;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for (size_t j = 0; j < i; j++) {
|
|
|
|
|
threads[j].join();
|
|
|
|
|
}
|
|
|
|
|
for (size_t j = 0; j < i; j++) {
|
|
|
|
|
statistics_info_.hash_hit_count_ += hash_hit_count[j];
|
|
|
|
|
}
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool PsCacheManager::ParseData(const int *batch_ids, const size_t batch_ids_len, int *hash_index) {
|
|
|
|
|
MS_ERROR_IF_NULL(batch_ids);
|
|
|
|
|
MS_ERROR_IF_NULL(hash_index);
|
|
|
|
|
statistics_info_.batch_id_count_ = batch_ids_len;
|
|
|
|
|
std::unique_ptr<bool[]> in_device(new bool[batch_ids_len]);
|
|
|
|
|
if (memset_s(in_device.get(), batch_ids_len * sizeof(bool), 0, batch_ids_len * sizeof(bool))) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Data in device memset failed.";
|
|
|
|
|
}
|
|
|
|
|
CheckIDInDevice(batch_ids, batch_ids_len, hash_index, in_device.get());
|
|
|
|
|
for (size_t i = 0; i < batch_ids_len; i++) {
|
|
|
|
|
if (in_device[i]) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
bool need_swap_host_to_device = true;
|
|
|
|
|
bool need_swap_device_to_host = true;
|
|
|
|
|
auto id = batch_ids[i];
|
|
|
|
@ -585,10 +653,10 @@ bool PsCacheManager::LookUpHostHashTable(size_t embedding_size, size_t indices_l
|
|
|
|
|
thread_num = thread_num > kMaxThreadNum ? kMaxThreadNum : thread_num;
|
|
|
|
|
std::thread threads[kMaxThreadNum];
|
|
|
|
|
size_t task_proc_lens = (indices_lens + thread_num - 1) / thread_num;
|
|
|
|
|
size_t i;
|
|
|
|
|
size_t i = 0;
|
|
|
|
|
size_t task_offset = 0;
|
|
|
|
|
MS_LOG(DEBUG) << "Indices lens: " << indices_lens << ", one task proc lens:" << task_proc_lens;
|
|
|
|
|
for (i = 0; i < thread_num; i++) {
|
|
|
|
|
for (; i < thread_num; i++) {
|
|
|
|
|
if (task_offset >= indices_lens) {
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
@ -613,7 +681,7 @@ bool PsCacheManager::InsertHostHashTable(size_t embedding_size, size_t insert_in
|
|
|
|
|
thread_num = thread_num > kMaxThreadNum ? kMaxThreadNum : thread_num;
|
|
|
|
|
std::thread threads[kMaxThreadNum];
|
|
|
|
|
size_t task_proc_lens = (insert_indices_size + thread_num - 1) / thread_num;
|
|
|
|
|
size_t i;
|
|
|
|
|
size_t i = 0;
|
|
|
|
|
size_t task_offset = 0;
|
|
|
|
|
|
|
|
|
|
auto insert_hash_table_task = [this](size_t insert_indices_size, size_t outer_dim_size, size_t first_dim_size,
|
|
|
|
@ -632,7 +700,7 @@ bool PsCacheManager::InsertHostHashTable(size_t embedding_size, size_t insert_in
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
for (i = 0; i < thread_num; i++) {
|
|
|
|
|
for (; i < thread_num; i++) {
|
|
|
|
|
if (task_offset >= insert_indices_size) {
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|