|
|
|
@ -369,6 +369,10 @@ bool PsCacheManager::ProcessData() {
|
|
|
|
|
// Get hash swap in/out index and ids.
|
|
|
|
|
RETURN_IF_FALSE(ParseData(batch_ids, batch_ids_len, hash_index.get()));
|
|
|
|
|
DumpStatisticsInfo();
|
|
|
|
|
if ((device_need_wait_graph_ || host_need_wait_graph_) && (!WaitGraphRun())) {
|
|
|
|
|
MS_LOG(ERROR) << "Ps cache wait graph finish failed.";
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
for (const auto &item : hash_tables_) {
|
|
|
|
|
auto key = worker.GetParamKey(item.first);
|
|
|
|
|
auto hash_info = item.second;
|
|
|
|
@ -454,6 +458,20 @@ bool PsCacheManager::CheckIDInDevice(const int *batch_ids, const size_t batch_id
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Clears the bookkeeping of both the device-side and host-side embedding hash maps and
// drops any pending "wait for graph" request.
//
// Nothing is reset unless all four pointers (both caches and both hash maps) are non-null;
// the checks run in the order device cache, device map, host cache, host map.
// NOTE(review): MS_ERROR_IF_NULL is defined outside this view; it presumably logs and makes
// this function return false when its argument is null — confirm against the macro.
//
// Returns true once both maps have been reset and both wait flags cleared.
bool PsCacheManager::ResetEmbeddingHashMap() {
  // Validation phase: resolve and check every pointer before touching any state.
  MS_ERROR_IF_NULL(embedding_device_cache_);
  const auto &device_hash_map = embedding_device_cache_->device_hash_map_;
  MS_ERROR_IF_NULL(device_hash_map);

  MS_ERROR_IF_NULL(embedding_host_cache_);
  const auto &host_hash_map = embedding_host_cache_->host_hash_map_;
  MS_ERROR_IF_NULL(host_hash_map);

  // Reset phase: device map first, then host map.
  device_hash_map->Reset();
  host_hash_map->Reset();

  // Clear both wait-for-graph flags now that the maps have been reset.
  device_need_wait_graph_ = false;
  host_need_wait_graph_ = false;

  return true;
}
|
|
|
|
|
|
|
|
|
|
bool PsCacheManager::ParseData(const int *batch_ids, const size_t batch_ids_len, int *hash_index) {
|
|
|
|
|
MS_ERROR_IF_NULL(batch_ids);
|
|
|
|
|
MS_ERROR_IF_NULL(hash_index);
|
|
|
|
@ -463,6 +481,7 @@ bool PsCacheManager::ParseData(const int *batch_ids, const size_t batch_ids_len,
|
|
|
|
|
MS_LOG(EXCEPTION) << "Data in device memset failed.";
|
|
|
|
|
}
|
|
|
|
|
CheckIDInDevice(batch_ids, batch_ids_len, hash_index, in_device.get());
|
|
|
|
|
RETURN_IF_FALSE(ResetEmbeddingHashMap());
|
|
|
|
|
for (size_t i = 0; i < batch_ids_len; i++) {
|
|
|
|
|
if (in_device[i]) {
|
|
|
|
|
continue;
|
|
|
|
@ -529,7 +548,7 @@ bool PsCacheManager::ParseDeviceData(size_t id, bool *need_swap_device_to_host,
|
|
|
|
|
auto tmp_device_to_host_size = statistics_info_.device_to_host_size_;
|
|
|
|
|
while (true) {
|
|
|
|
|
index = device_hash_map->ParseData(id, device_to_host_index, device_to_host_ids, data_step_, graph_running_step_,
|
|
|
|
|
&(statistics_info_.device_to_host_size_));
|
|
|
|
|
&(statistics_info_.device_to_host_size_), &device_need_wait_graph_);
|
|
|
|
|
if (index == INVALID_INDEX_VALUE) {
|
|
|
|
|
if (!WaitGraphRun()) {
|
|
|
|
|
return false;
|
|
|
|
@ -570,8 +589,9 @@ bool PsCacheManager::ParseHostDataHostToDevice(size_t id) {
|
|
|
|
|
MS_ERROR_IF_NULL(server_to_host_index);
|
|
|
|
|
MS_ERROR_IF_NULL(server_to_host_ids);
|
|
|
|
|
while (true) {
|
|
|
|
|
auto index = host_hash_map->ParseData(id, host_to_server_index, host_to_server_ids, data_step_,
|
|
|
|
|
graph_running_step_, &statistics_info_.host_to_server_size_);
|
|
|
|
|
auto index =
|
|
|
|
|
host_hash_map->ParseData(id, host_to_server_index, host_to_server_ids, data_step_, graph_running_step_,
|
|
|
|
|
&statistics_info_.host_to_server_size_, &host_need_wait_graph_);
|
|
|
|
|
if (index == INVALID_INDEX_VALUE) {
|
|
|
|
|
RETURN_IF_FALSE(WaitGraphRun());
|
|
|
|
|
continue;
|
|
|
|
@ -607,8 +627,9 @@ bool PsCacheManager::ParseHostDataDeviceToHost() {
|
|
|
|
|
int *host_to_server_index = embedding_host_cache_->host_to_server_index.get();
|
|
|
|
|
int *host_to_server_ids = embedding_host_cache_->host_to_server_ids.get();
|
|
|
|
|
while (true) {
|
|
|
|
|
auto index = host_hash_map->ParseData(swap_device_to_host_id, host_to_server_index, host_to_server_ids,
|
|
|
|
|
data_step_, graph_running_step_, &statistics_info_.host_to_server_size_);
|
|
|
|
|
auto index =
|
|
|
|
|
host_hash_map->ParseData(swap_device_to_host_id, host_to_server_index, host_to_server_ids, data_step_,
|
|
|
|
|
graph_running_step_, &statistics_info_.host_to_server_size_, &host_need_wait_graph_);
|
|
|
|
|
if (index == INVALID_INDEX_VALUE) {
|
|
|
|
|
RETURN_IF_FALSE(WaitGraphRun());
|
|
|
|
|
continue;
|
|
|
|
|