fix ps cache process data thread can not exit

pull/12286/head
lizhenyu 4 years ago
parent c9381e0448
commit 8bddfba9e2

@ -353,7 +353,10 @@ bool PsCacheManager::ProcessData() {
struct timeval start_time, end_time;
const uint64_t kUSecondInSecond = 1000000;
(void)gettimeofday(&start_time, nullptr);
auto data = PsDataPrefetch::GetInstance().data(channel_name_);
void *data = nullptr;
if (!PsDataPrefetch::GetInstance().QueryData(channel_name_, &data)) {
return false;
}
if (data == nullptr) {
MS_LOG(INFO) << "No data process, channel name:" << channel_name_;
std::unique_lock<std::mutex> locker(data_mutex_);

@ -53,6 +53,7 @@ bool PsDataPrefetch::PrefetchData(const std::string &channel_name, void *data, c
const std::string supported_data_type = "int32";
if (data_type != supported_data_type) {
MS_LOG(ERROR) << "Parameter server cache mode need input id with data type[int32], but got[" << data_type << "]";
invalid_data_type_ = true;
return false;
}
if (data == nullptr) {
@ -105,12 +106,20 @@ bool PsDataPrefetch::FinalizeData(const std::string &channel_name) {
return false;
}
void *PsDataPrefetch::data(const std::string &channel_name) const {
bool PsDataPrefetch::QueryData(const std::string &channel_name, void **data_ptr) const {
if (invalid_data_type_) {
return false;
}
if (data_ptr == nullptr) {
return false;
}
auto channel = ps_data_channel(channel_name);
if (channel == nullptr) {
return nullptr;
*data_ptr = nullptr;
return true;
}
return channel->data();
*data_ptr = channel->data();
return true;
}
size_t PsDataPrefetch::data_size(const std::string &channel_name) const {

@ -41,7 +41,7 @@ class EXPORT PsDataPrefetch {
const std::string &data_type);
EXPORT bool FinalizeData(const std::string &channel_name);
EXPORT void NotifyFinalize();
EXPORT void *data(const std::string &channel_name) const;
EXPORT bool QueryData(const std::string &channel_name, void **data_ptr) const;
EXPORT size_t data_size(const std::string &channel_name) const;
EXPORT bool TryWakeChannel(const std::string &channel_name);
@ -59,6 +59,7 @@ class EXPORT PsDataPrefetch {
std::condition_variable data_prefetch_;
std::condition_variable data_process_;
std::atomic_bool need_wait_{true};
std::atomic_bool invalid_data_type_{false};
};
} // namespace ps
} // namespace mindspore

Loading…
Cancel
Save