!12286 [bugfix] PS cache data-processing thread cannot exit when the input data type is invalid

From: @zyli2020
Reviewed-by: @cristoval,@jjfeing
Signed-off-by: @cristoval
pull/12286/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit cbe4cc403e

@@ -353,7 +353,10 @@ bool PsCacheManager::ProcessData() {
struct timeval start_time, end_time; struct timeval start_time, end_time;
const uint64_t kUSecondInSecond = 1000000; const uint64_t kUSecondInSecond = 1000000;
(void)gettimeofday(&start_time, nullptr); (void)gettimeofday(&start_time, nullptr);
auto data = PsDataPrefetch::GetInstance().data(channel_name_); void *data = nullptr;
if (!PsDataPrefetch::GetInstance().QueryData(channel_name_, &data)) {
return false;
}
if (data == nullptr) { if (data == nullptr) {
MS_LOG(INFO) << "No data process, channel name:" << channel_name_; MS_LOG(INFO) << "No data process, channel name:" << channel_name_;
std::unique_lock<std::mutex> locker(data_mutex_); std::unique_lock<std::mutex> locker(data_mutex_);

@@ -53,6 +53,7 @@ bool PsDataPrefetch::PrefetchData(const std::string &channel_name, void *data, c
const std::string supported_data_type = "int32"; const std::string supported_data_type = "int32";
if (data_type != supported_data_type) { if (data_type != supported_data_type) {
MS_LOG(ERROR) << "Parameter server cache mode need input id with data type[int32], but got[" << data_type << "]"; MS_LOG(ERROR) << "Parameter server cache mode need input id with data type[int32], but got[" << data_type << "]";
invalid_data_type_ = true;
return false; return false;
} }
if (data == nullptr) { if (data == nullptr) {
@@ -105,12 +106,20 @@ bool PsDataPrefetch::FinalizeData(const std::string &channel_name) {
return false; return false;
} }
void *PsDataPrefetch::data(const std::string &channel_name) const { bool PsDataPrefetch::QueryData(const std::string &channel_name, void **data_ptr) const {
if (invalid_data_type_) {
return false;
}
if (data_ptr == nullptr) {
return false;
}
auto channel = ps_data_channel(channel_name); auto channel = ps_data_channel(channel_name);
if (channel == nullptr) { if (channel == nullptr) {
return nullptr; *data_ptr = nullptr;
return true;
} }
return channel->data(); *data_ptr = channel->data();
return true;
} }
size_t PsDataPrefetch::data_size(const std::string &channel_name) const { size_t PsDataPrefetch::data_size(const std::string &channel_name) const {

@@ -41,7 +41,7 @@ class EXPORT PsDataPrefetch {
const std::string &data_type); const std::string &data_type);
EXPORT bool FinalizeData(const std::string &channel_name); EXPORT bool FinalizeData(const std::string &channel_name);
EXPORT void NotifyFinalize(); EXPORT void NotifyFinalize();
EXPORT void *data(const std::string &channel_name) const; EXPORT bool QueryData(const std::string &channel_name, void **data_ptr) const;
EXPORT size_t data_size(const std::string &channel_name) const; EXPORT size_t data_size(const std::string &channel_name) const;
EXPORT bool TryWakeChannel(const std::string &channel_name); EXPORT bool TryWakeChannel(const std::string &channel_name);
@@ -59,6 +59,7 @@ class EXPORT PsDataPrefetch {
std::condition_variable data_prefetch_; std::condition_variable data_prefetch_;
std::condition_variable data_process_; std::condition_variable data_process_;
std::atomic_bool need_wait_{true}; std::atomic_bool need_wait_{true};
std::atomic_bool invalid_data_type_{false};
}; };
} // namespace ps } // namespace ps
} // namespace mindspore } // namespace mindspore

Loading…
Cancel
Save