diff --git a/mindspore/ccsrc/ps/ps_cache/ps_cache_manager.cc b/mindspore/ccsrc/ps/ps_cache/ps_cache_manager.cc
index d55c56dab9..69982e2ba3 100644
--- a/mindspore/ccsrc/ps/ps_cache/ps_cache_manager.cc
+++ b/mindspore/ccsrc/ps/ps_cache/ps_cache_manager.cc
@@ -331,12 +331,7 @@ void PsCacheManager::ProcessDataTask(uint32_t device_id, void *context) {
 
 void PsCacheManager::Finalize() {
   if (running_) {
-    if (!SyncHostEmbeddingTable()) {
-      MS_LOG(ERROR) << "SyncHostEmbeddingTable failed.";
-    }
-    if (!SyncDeviceEmbeddingTable()) {
-      MS_LOG(ERROR) << "SyncDeviceEmbeddingTable failed.";
-    }
+    SyncEmbeddingTable();
   }
   running_ = false;
   PsDataPrefetch::GetInstance().NotifyFinalize();
@@ -846,6 +841,28 @@ bool PsCacheManager::UpdataEmbeddingTable(const ::ps::SArray &swap_out_da
   return true;
 }
 
+// Runs SyncHostEmbeddingTable() followed by SyncDeviceEmbeddingTable() and logs
+// an error for each step that reports failure. Once a pass has completed the
+// finish flag is set, so every later call returns immediately.
+void PsCacheManager::SyncEmbeddingTable() {
+  if (finish_embedding_table_sync_) {
+    return;
+  }
+  // KernelRuntimeManager::ClearRuntimeResource() calls this unconditionally, so
+  // do nothing (and leave the finish flag clear) when the ps cache was never
+  // initialized, rather than running the sync helpers against an unset cache.
+  if (!initialized_ps_cache_) {
+    return;
+  }
+  if (!SyncHostEmbeddingTable()) {
+    MS_LOG(ERROR) << "SyncHostEmbeddingTable failed.";
+  }
+  if (!SyncDeviceEmbeddingTable()) {
+    MS_LOG(ERROR) << "SyncDeviceEmbeddingTable failed.";
+  }
+  finish_embedding_table_sync_ = true;
+}
+
 bool PsCacheManager::SyncHostEmbeddingTable() {
   MS_ERROR_IF_NULL(embedding_host_cache_);
   const auto &hash_id_to_index = embedding_host_cache_->host_hash_map_->hash_id_to_index();
diff --git a/mindspore/ccsrc/ps/ps_cache/ps_cache_manager.h b/mindspore/ccsrc/ps/ps_cache/ps_cache_manager.h
index 5e1732d08d..7dca6b0159 100644
--- a/mindspore/ccsrc/ps/ps_cache/ps_cache_manager.h
+++ b/mindspore/ccsrc/ps/ps_cache/ps_cache_manager.h
@@ -127,6 +127,7 @@ class PsCacheManager {
   bool initialized_ps_cache() const { return initialized_ps_cache_; }
   void DoProcessData(uint32_t device_id, void *context);
   void IncreaseGraphStep(const std::string &channel_name);
+  void SyncEmbeddingTable();
   void Finalize();
   void DumpHashTables(bool dump_device_tables = false) const;
 
@@ -193,6 +194,7 @@ class PsCacheManager {
   std::atomic_bool finish_insert_init_info_{false};
   std::atomic_bool finish_init_parameter_server_{false};
   std::atomic_bool running_{false};
+  bool finish_embedding_table_sync_{false};
 };
 
 static PsCacheManager &ps_cache_instance = PsCacheManager::GetInstance();
diff --git a/mindspore/ccsrc/runtime/device/kernel_runtime_manager.cc b/mindspore/ccsrc/runtime/device/kernel_runtime_manager.cc
index c0395b3a43..249d557dc7 100644
--- a/mindspore/ccsrc/runtime/device/kernel_runtime_manager.cc
+++ b/mindspore/ccsrc/runtime/device/kernel_runtime_manager.cc
@@ -16,10 +16,17 @@
 
 #include "runtime/device/kernel_runtime_manager.h"
 #include "utils/log_adapter.h"
+#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
+#include "ps/ps_cache/ps_cache_manager.h"
+#endif
 
 namespace mindspore {
 namespace device {
 void KernelRuntimeManager::ClearRuntimeResource() {
+#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
+  // Sync the ps cache embedding tables before the device runtimes below are released.
+  ps::ps_cache_instance.SyncEmbeddingTable();
+#endif
   std::lock_guard guard(lock_);
   for (auto &iter : runtime_map_) {
     MS_LOG(INFO) << "Release device " << iter.first;