From 8bddfba9e24d1b3890efabeac4c1151b75f2945e Mon Sep 17 00:00:00 2001
From: lizhenyu
Date: Mon, 8 Feb 2021 19:08:49 +0800
Subject: [PATCH] fix ps cache process data thread can not exit

---
 mindspore/ccsrc/ps/ps_cache/ps_cache_manager.cc       |  5 ++++-
 .../ccsrc/ps/ps_cache/ps_data/ps_data_prefetch.cc     | 15 ++++++++++++---
 .../ccsrc/ps/ps_cache/ps_data/ps_data_prefetch.h      |  3 ++-
 3 files changed, 18 insertions(+), 5 deletions(-)

diff --git a/mindspore/ccsrc/ps/ps_cache/ps_cache_manager.cc b/mindspore/ccsrc/ps/ps_cache/ps_cache_manager.cc
index d82d0c3ce3..96e4e23c30 100644
--- a/mindspore/ccsrc/ps/ps_cache/ps_cache_manager.cc
+++ b/mindspore/ccsrc/ps/ps_cache/ps_cache_manager.cc
@@ -353,7 +353,10 @@ bool PsCacheManager::ProcessData() {
   struct timeval start_time, end_time;
   const uint64_t kUSecondInSecond = 1000000;
   (void)gettimeofday(&start_time, nullptr);
-  auto data = PsDataPrefetch::GetInstance().data(channel_name_);
+  void *data = nullptr;
+  if (!PsDataPrefetch::GetInstance().QueryData(channel_name_, &data)) {
+    return false;
+  }
   if (data == nullptr) {
     MS_LOG(INFO) << "No data process, channel name:" << channel_name_;
     std::unique_lock locker(data_mutex_);
diff --git a/mindspore/ccsrc/ps/ps_cache/ps_data/ps_data_prefetch.cc b/mindspore/ccsrc/ps/ps_cache/ps_data/ps_data_prefetch.cc
index e0a9d0ab01..25ef413751 100644
--- a/mindspore/ccsrc/ps/ps_cache/ps_data/ps_data_prefetch.cc
+++ b/mindspore/ccsrc/ps/ps_cache/ps_data/ps_data_prefetch.cc
@@ -53,6 +53,7 @@ bool PsDataPrefetch::PrefetchData(const std::string &channel_name, void *data, c
   const std::string supported_data_type = "int32";
   if (data_type != supported_data_type) {
     MS_LOG(ERROR) << "Parameter server cache mode need input id with data type[int32], but got[" << data_type << "]";
+    invalid_data_type_ = true;
     return false;
   }
   if (data == nullptr) {
@@ -105,12 +106,20 @@ bool PsDataPrefetch::FinalizeData(const std::string &channel_name) {
   return false;
 }
 
-void *PsDataPrefetch::data(const std::string &channel_name) const {
+bool PsDataPrefetch::QueryData(const std::string &channel_name, void **data_ptr) const {
+  if (invalid_data_type_) {
+    return false;
+  }
+  if (data_ptr == nullptr) {
+    return false;
+  }
   auto channel = ps_data_channel(channel_name);
   if (channel == nullptr) {
-    return nullptr;
+    *data_ptr = nullptr;
+    return true;
   }
-  return channel->data();
+  *data_ptr = channel->data();
+  return true;
 }
 
 size_t PsDataPrefetch::data_size(const std::string &channel_name) const {
diff --git a/mindspore/ccsrc/ps/ps_cache/ps_data/ps_data_prefetch.h b/mindspore/ccsrc/ps/ps_cache/ps_data/ps_data_prefetch.h
index 256a6475dd..02d3fa7887 100644
--- a/mindspore/ccsrc/ps/ps_cache/ps_data/ps_data_prefetch.h
+++ b/mindspore/ccsrc/ps/ps_cache/ps_data/ps_data_prefetch.h
@@ -41,7 +41,7 @@ class EXPORT PsDataPrefetch {
                            const std::string &data_type);
   EXPORT bool FinalizeData(const std::string &channel_name);
   EXPORT void NotifyFinalize();
-  EXPORT void *data(const std::string &channel_name) const;
+  EXPORT bool QueryData(const std::string &channel_name, void **data_ptr) const;
   EXPORT size_t data_size(const std::string &channel_name) const;
   EXPORT bool TryWakeChannel(const std::string &channel_name);
 
@@ -59,6 +59,7 @@ class EXPORT PsDataPrefetch {
   std::condition_variable data_prefetch_;
   std::condition_variable data_process_;
   std::atomic_bool need_wait_{true};
+  std::atomic_bool invalid_data_type_{false};
 };
 }  // namespace ps
 }  // namespace mindspore