ps cache sync before release res

pull/10456/head
limingqi107 4 years ago
parent df15227cd3
commit 951570a089

@ -331,12 +331,7 @@ void PsCacheManager::ProcessDataTask(uint32_t device_id, void *context) {
void PsCacheManager::Finalize() { void PsCacheManager::Finalize() {
if (running_) { if (running_) {
if (!SyncHostEmbeddingTable()) { SyncEmbeddingTable();
MS_LOG(ERROR) << "SyncHostEmbeddingTable failed.";
}
if (!SyncDeviceEmbeddingTable()) {
MS_LOG(ERROR) << "SyncDeviceEmbeddingTable failed.";
}
} }
running_ = false; running_ = false;
PsDataPrefetch::GetInstance().NotifyFinalize(); PsDataPrefetch::GetInstance().NotifyFinalize();
@ -846,6 +841,19 @@ bool PsCacheManager::UpdataEmbeddingTable(const ::ps::SArray<float> &swap_out_da
return true; return true;
} }
void PsCacheManager::SyncEmbeddingTable() {
if (finish_embedding_table_sync_) {
return;
}
if (!SyncHostEmbeddingTable()) {
MS_LOG(ERROR) << "SyncHostEmbeddingTable failed.";
}
if (!SyncDeviceEmbeddingTable()) {
MS_LOG(ERROR) << "SyncDeviceEmbeddingTable failed.";
}
finish_embedding_table_sync_ = true;
}
bool PsCacheManager::SyncHostEmbeddingTable() { bool PsCacheManager::SyncHostEmbeddingTable() {
MS_ERROR_IF_NULL(embedding_host_cache_); MS_ERROR_IF_NULL(embedding_host_cache_);
const auto &hash_id_to_index = embedding_host_cache_->host_hash_map_->hash_id_to_index(); const auto &hash_id_to_index = embedding_host_cache_->host_hash_map_->hash_id_to_index();

@ -127,6 +127,7 @@ class PsCacheManager {
bool initialized_ps_cache() const { return initialized_ps_cache_; } bool initialized_ps_cache() const { return initialized_ps_cache_; }
void DoProcessData(uint32_t device_id, void *context); void DoProcessData(uint32_t device_id, void *context);
void IncreaseGraphStep(const std::string &channel_name); void IncreaseGraphStep(const std::string &channel_name);
void SyncEmbeddingTable();
void Finalize(); void Finalize();
void DumpHashTables(bool dump_device_tables = false) const; void DumpHashTables(bool dump_device_tables = false) const;
@ -193,6 +194,7 @@ class PsCacheManager {
std::atomic_bool finish_insert_init_info_{false}; std::atomic_bool finish_insert_init_info_{false};
std::atomic_bool finish_init_parameter_server_{false}; std::atomic_bool finish_init_parameter_server_{false};
std::atomic_bool running_{false}; std::atomic_bool running_{false};
bool finish_embedding_table_sync_{false};
}; };
static PsCacheManager &ps_cache_instance = PsCacheManager::GetInstance(); static PsCacheManager &ps_cache_instance = PsCacheManager::GetInstance();

@ -16,10 +16,16 @@
#include "runtime/device/kernel_runtime_manager.h" #include "runtime/device/kernel_runtime_manager.h"
#include "utils/log_adapter.h" #include "utils/log_adapter.h"
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
#include "ps/ps_cache/ps_cache_manager.h"
#endif
namespace mindspore { namespace mindspore {
namespace device { namespace device {
void KernelRuntimeManager::ClearRuntimeResource() { void KernelRuntimeManager::ClearRuntimeResource() {
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
ps::ps_cache_instance.SyncEmbeddingTable();
#endif
std::lock_guard<std::mutex> guard(lock_); std::lock_guard<std::mutex> guard(lock_);
for (auto &iter : runtime_map_) { for (auto &iter : runtime_map_) {
MS_LOG(INFO) << "Release device " << iter.first; MS_LOG(INFO) << "Release device " << iter.first;

Loading…
Cancel
Save