|
|
@ -331,12 +331,7 @@ void PsCacheManager::ProcessDataTask(uint32_t device_id, void *context) {
|
|
|
|
|
|
|
|
|
|
|
|
void PsCacheManager::Finalize() {
|
|
|
|
void PsCacheManager::Finalize() {
|
|
|
|
if (running_) {
|
|
|
|
if (running_) {
|
|
|
|
if (!SyncHostEmbeddingTable()) {
|
|
|
|
SyncEmbeddingTable();
|
|
|
|
MS_LOG(ERROR) << "SyncHostEmbeddingTable failed.";
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
if (!SyncDeviceEmbeddingTable()) {
|
|
|
|
|
|
|
|
MS_LOG(ERROR) << "SyncDeviceEmbeddingTable failed.";
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
running_ = false;
|
|
|
|
running_ = false;
|
|
|
|
PsDataPrefetch::GetInstance().NotifyFinalize();
|
|
|
|
PsDataPrefetch::GetInstance().NotifyFinalize();
|
|
|
@ -846,6 +841,19 @@ bool PsCacheManager::UpdataEmbeddingTable(const ::ps::SArray<float> &swap_out_da
|
|
|
|
return true;
|
|
|
|
return true;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
void PsCacheManager::SyncEmbeddingTable() {
|
|
|
|
|
|
|
|
if (finish_embedding_table_sync_) {
|
|
|
|
|
|
|
|
return;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
if (!SyncHostEmbeddingTable()) {
|
|
|
|
|
|
|
|
MS_LOG(ERROR) << "SyncHostEmbeddingTable failed.";
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
if (!SyncDeviceEmbeddingTable()) {
|
|
|
|
|
|
|
|
MS_LOG(ERROR) << "SyncDeviceEmbeddingTable failed.";
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
finish_embedding_table_sync_ = true;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
bool PsCacheManager::SyncHostEmbeddingTable() {
|
|
|
|
bool PsCacheManager::SyncHostEmbeddingTable() {
|
|
|
|
MS_ERROR_IF_NULL(embedding_host_cache_);
|
|
|
|
MS_ERROR_IF_NULL(embedding_host_cache_);
|
|
|
|
const auto &hash_id_to_index = embedding_host_cache_->host_hash_map_->hash_id_to_index();
|
|
|
|
const auto &hash_id_to_index = embedding_host_cache_->host_hash_map_->hash_id_to_index();
|
|
|
|