fix ps cache error print

pull/10477/head
limingqi107 4 years ago
parent 951570a089
commit 5227d6d16d

@ -856,6 +856,7 @@ void PsCacheManager::SyncEmbeddingTable() {
bool PsCacheManager::SyncHostEmbeddingTable() {
MS_ERROR_IF_NULL(embedding_host_cache_);
MS_ERROR_IF_NULL(embedding_host_cache_->host_hash_map_);
const auto &hash_id_to_index = embedding_host_cache_->host_hash_map_->hash_id_to_index();
size_t swap_indices_lens = hash_id_to_index.size();
if (swap_indices_lens == 0) {
@ -899,6 +900,7 @@ bool PsCacheManager::SyncHostEmbeddingTable() {
bool PsCacheManager::SyncDeviceEmbeddingTable() {
MS_ERROR_IF_NULL(embedding_device_cache_);
const auto &device_hash_map = embedding_device_cache_->device_hash_map_;
MS_ERROR_IF_NULL(device_hash_map);
const auto &hash_id_to_index = device_hash_map->hash_id_to_index();
size_t swap_indices_lens = hash_id_to_index.size();
if (swap_indices_lens == 0) {

@ -24,7 +24,9 @@ namespace mindspore {
namespace device {
void KernelRuntimeManager::ClearRuntimeResource() {
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
ps::ps_cache_instance.SyncEmbeddingTable();
if (ps::PsDataPrefetch::GetInstance().cache_enable()) {
ps::ps_cache_instance.SyncEmbeddingTable();
}
#endif
std::lock_guard<std::mutex> guard(lock_);
for (auto &iter : runtime_map_) {

Loading…
Cancel
Save