PS cache: let the data-processing thread exit cleanly when exceptions occur

pull/10190/head
lizhenyu 4 years ago
parent d988e13fb5
commit 1f99cd7d86

@ -298,7 +298,9 @@ Status DeviceQueueOp::PushDataToGPU() {
// Data prefetch only when PS mode enables cache.
if (items.size() > 0) {
ps::PsDataPrefetch::GetInstance().PrefetchData(channel_name_, items[0].data_ptr_, items[0].data_len_);
if (!ps::PsDataPrefetch::GetInstance().PrefetchData(channel_name_, items[0].data_ptr_, items[0].data_len_)) {
return Status(StatusCode::kTimeOut, __LINE__, __FILE__, "Failed to prefetch data.");
}
}
while (!GpuBufferMgr::GetInstance().IsClosed() && !TaskManager::FindMe()->Interrupted()) {
BlockQueueStatus_T ret = GpuBufferMgr::GetInstance().Push(handle, items, WAIT_TIME);

@ -55,7 +55,9 @@ TdtStatus TdtPlugin::hostPush(TensorRow ts_row, bool is_wait, std::string channe
#if ENABLE_D
// Data prefetch only when PS mode enables cache.
if (items.size() > 0) {
ps::PsDataPrefetch::GetInstance().PrefetchData(channel_name, items[0].dataPtr_.get(), items[0].dataLen_);
if (!ps::PsDataPrefetch::GetInstance().PrefetchData(channel_name, items[0].dataPtr_.get(), items[0].dataLen_)) {
return FAILED;
}
}
#endif
if (tdt::TdtHostPushData(channel_name, items) != 0) {

@ -53,6 +53,7 @@
#include "ps/util.h"
#include "ps/worker.h"
#include "ps/ps_cache/ps_data/ps_data_prefetch.h"
#include "ps/ps_cache/ps_cache_manager.h"
#endif
#if (ENABLE_GE || ENABLE_D)
@ -1083,9 +1084,10 @@ void ClearResAtexit() {
pynative::ClearPyNativeSession();
session::ClearPythonParasMap();
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
if (ps::Util::IsParamServerMode()) {
if (ps::Util::IsRoleOfWorker()) {
ps::worker.Finalize();
if (ps::Util::IsParamServerMode() && ps::Util::IsRoleOfWorker()) {
ps::worker.Finalize();
if (ps::PsDataPrefetch::GetInstance().cache_enable()) {
ps::ps_cache_instance.Finalize();
}
}
#endif

File diff suppressed because it is too large Load Diff

@ -49,17 +49,17 @@ class AscendPsCache : public PsCacheBasic {
public:
AscendPsCache() = default;
~AscendPsCache() override = default;
void InitDevice(uint32_t device_id, const void *context) override;
bool InitDevice(uint32_t device_id, const void *context) override;
void *MallocMemory(size_t size) override;
void MallocConstantMemory(size_t constant_value) override;
void RecordEvent() override;
void SynchronizeEvent() override;
void SynchronizeStream() override;
void CopyHostMemToDevice(void *dst, void *src, size_t size) override;
void CopyDeviceMemToHost(void *dst, void *src, size_t size) override;
void HashSwapOut(void *hash_table_addr, void *swap_out_value_addr, void *swap_out_index_addr, size_t hash_table_size,
bool MallocConstantMemory(size_t constant_value) override;
bool RecordEvent() override;
bool SynchronizeEvent() override;
bool SynchronizeStream() override;
bool CopyHostMemToDevice(void *dst, void *src, size_t size) override;
bool CopyDeviceMemToHost(void *dst, void *src, size_t size) override;
bool HashSwapOut(void *hash_table_addr, void *swap_out_value_addr, void *swap_out_index_addr, size_t hash_table_size,
size_t embedding_size, size_t swap_out_size) override;
void HashSwapIn(void *hash_table_addr, void *swap_in_value_addr, void *swap_in_index_addr, size_t hash_table_size,
bool HashSwapIn(void *hash_table_addr, void *swap_in_value_addr, void *swap_in_index_addr, size_t hash_table_size,
size_t embedding_size, size_t swap_in_size) override;
private:

@ -25,67 +25,75 @@ namespace mindspore {
namespace ps {
namespace gpu {
MS_REG_PS_CACHE(kGPUDevice, GPUPsCache);
void GPUPsCache::InitDevice(uint32_t device_id, const void *) {
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaSetDevice(device_id), "Cuda set device failed")
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaStreamCreate(reinterpret_cast<CUstream_st **>(&stream_)),
"Cuda create stream failed");
bool GPUPsCache::InitDevice(uint32_t device_id, const void *) {
CHECK_CUDA_RET_WITH_RETURN_ERROR_NOTRACE(cudaSetDevice(device_id), "Cuda set device failed")
CHECK_CUDA_RET_WITH_RETURN_ERROR_NOTRACE(cudaStreamCreate(reinterpret_cast<CUstream_st **>(&stream_)),
"Cuda create stream failed");
return true;
}
// Requests `size` bytes from the process-wide GPUMemoryAllocator and hands back whatever
// AllocTensorMem returns. NOTE(review): the failure behaviour (e.g. a null result) is decided
// by the allocator and is not visible here -- callers should confirm before dereferencing.
void *GPUPsCache::MallocMemory(size_t size) {
  auto &&allocator = device::gpu::GPUMemoryAllocator::GetInstance();
  return allocator.AllocTensorMem(size);
}
void GPUPsCache::RecordEvent() {
bool GPUPsCache::RecordEvent() {
event_.reset(new cudaEvent_t());
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaEventCreate(&(*event_)), "Cuda create event failed");
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaEventRecord(*event_, reinterpret_cast<cudaStream_t>(stream_)),
"Cuda record event failed");
CHECK_CUDA_RET_WITH_RETURN_ERROR_NOTRACE(cudaEventCreate(&(*event_)), "Cuda create event failed");
CHECK_CUDA_RET_WITH_RETURN_ERROR_NOTRACE(cudaEventRecord(*event_, reinterpret_cast<cudaStream_t>(stream_)),
"Cuda record event failed");
return true;
}
void GPUPsCache::SynchronizeEvent() {
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaEventSynchronize(*event_), "Cuda sync event failed");
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaEventDestroy(*event_), "Cuda destroy event failed");
bool GPUPsCache::SynchronizeEvent() {
CHECK_CUDA_RET_WITH_RETURN_ERROR_NOTRACE(cudaEventSynchronize(*event_), "Cuda sync event failed");
CHECK_CUDA_RET_WITH_RETURN_ERROR_NOTRACE(cudaEventDestroy(*event_), "Cuda destroy event failed");
return true;
}
void GPUPsCache::SynchronizeStream() {
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaStreamSynchronize(reinterpret_cast<cudaStream_t>(stream_)),
"Cuda sync stream failed");
bool GPUPsCache::SynchronizeStream() {
CHECK_CUDA_RET_WITH_RETURN_ERROR_NOTRACE(cudaStreamSynchronize(reinterpret_cast<cudaStream_t>(stream_)),
"Cuda sync stream failed");
return true;
}
void GPUPsCache::CopyHostMemToDevice(void *dst, void *src, size_t size) {
MS_EXCEPTION_IF_NULL(dst);
MS_EXCEPTION_IF_NULL(src);
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
bool GPUPsCache::CopyHostMemToDevice(void *dst, void *src, size_t size) {
MS_ERROR_IF_NULL(dst);
MS_ERROR_IF_NULL(src);
CHECK_CUDA_RET_WITH_RETURN_ERROR_NOTRACE(
cudaMemcpyAsync(dst, src, size, cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_)),
"Cuda memcpy failed");
return true;
}
void GPUPsCache::CopyDeviceMemToHost(void *dst, void *src, size_t size) {
MS_EXCEPTION_IF_NULL(dst);
MS_EXCEPTION_IF_NULL(src);
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
bool GPUPsCache::CopyDeviceMemToHost(void *dst, void *src, size_t size) {
MS_ERROR_IF_NULL(dst);
MS_ERROR_IF_NULL(src);
CHECK_CUDA_RET_WITH_RETURN_ERROR_NOTRACE(
cudaMemcpyAsync(dst, src, size, cudaMemcpyDeviceToHost, reinterpret_cast<cudaStream_t>(stream_)),
"Cuda memcpy failed");
return true;
}
void GPUPsCache::HashSwapOut(void *hash_table_addr, void *swap_out_value_addr, void *swap_out_index_addr, size_t,
bool GPUPsCache::HashSwapOut(void *hash_table_addr, void *swap_out_value_addr, void *swap_out_index_addr, size_t,
size_t embedding_size, size_t swap_out_size) {
MS_EXCEPTION_IF_NULL(hash_table_addr);
MS_EXCEPTION_IF_NULL(swap_out_value_addr);
MS_EXCEPTION_IF_NULL(swap_out_index_addr);
MS_ERROR_IF_NULL(hash_table_addr);
MS_ERROR_IF_NULL(swap_out_value_addr);
MS_ERROR_IF_NULL(swap_out_index_addr);
DoHashSwapOut(reinterpret_cast<float *>(hash_table_addr), reinterpret_cast<float *>(swap_out_value_addr),
reinterpret_cast<int *>(swap_out_index_addr), swap_out_size, embedding_size,
reinterpret_cast<cudaStream_t>(stream_));
return true;
}
void GPUPsCache::HashSwapIn(void *hash_table_addr, void *swap_in_value_addr, void *swap_in_index_addr, size_t,
bool GPUPsCache::HashSwapIn(void *hash_table_addr, void *swap_in_value_addr, void *swap_in_index_addr, size_t,
size_t embedding_size, size_t swap_in_size) {
MS_EXCEPTION_IF_NULL(hash_table_addr);
MS_EXCEPTION_IF_NULL(swap_in_value_addr);
MS_EXCEPTION_IF_NULL(swap_in_index_addr);
MS_ERROR_IF_NULL(hash_table_addr);
MS_ERROR_IF_NULL(swap_in_value_addr);
MS_ERROR_IF_NULL(swap_in_index_addr);
DoHashSwapIn(reinterpret_cast<float *>(hash_table_addr), reinterpret_cast<float *>(swap_in_value_addr),
reinterpret_cast<int *>(swap_in_index_addr), swap_in_size, embedding_size,
reinterpret_cast<cudaStream_t>(stream_));
return true;
}
} // namespace gpu
} // namespace ps

@ -28,16 +28,16 @@ class GPUPsCache : public PsCacheBasic {
public:
GPUPsCache() = default;
~GPUPsCache() override = default;
void InitDevice(uint32_t device_id, const void *context) override;
bool InitDevice(uint32_t device_id, const void *context) override;
void *MallocMemory(size_t size) override;
void RecordEvent() override;
void SynchronizeEvent() override;
void SynchronizeStream() override;
void CopyHostMemToDevice(void *dst, void *src, size_t size) override;
void CopyDeviceMemToHost(void *dst, void *src, size_t size) override;
void HashSwapOut(void *hash_table_addr, void *swap_out_value_addr, void *swap_out_index_addr, size_t hash_table_size,
bool RecordEvent() override;
bool SynchronizeEvent() override;
bool SynchronizeStream() override;
bool CopyHostMemToDevice(void *dst, void *src, size_t size) override;
bool CopyDeviceMemToHost(void *dst, void *src, size_t size) override;
bool HashSwapOut(void *hash_table_addr, void *swap_out_value_addr, void *swap_out_index_addr, size_t hash_table_size,
size_t embedding_size, size_t swap_out_size) override;
void HashSwapIn(void *hash_table_addr, void *swap_in_value_addr, void *swap_in_index_addr, size_t hash_table_size,
bool HashSwapIn(void *hash_table_addr, void *swap_in_value_addr, void *swap_in_index_addr, size_t hash_table_size,
size_t embedding_size, size_t swap_in_size) override;
private:

@ -21,21 +21,28 @@
namespace mindspore {
namespace ps {
// Early-exit helper for functions that report failure by returning `false`.
// Evaluates `condition` exactly once; when it is falsy, the ENCLOSING function returns false.
// Wrapped in do/while so that `RETURN_IF_FALSE(x);` behaves as one statement, including
// inside an unbraced if/else.
#define RETURN_IF_FALSE(condition)  \
  do {                              \
    if (!(condition)) return false; \
  } while (0)
class PsCacheBasic {
public:
PsCacheBasic() = default;
virtual ~PsCacheBasic() = default;
virtual void InitDevice(uint32_t device_id, const void *context) = 0;
virtual bool InitDevice(uint32_t device_id, const void *context) = 0;
virtual void *MallocMemory(size_t size) = 0;
virtual void MallocConstantMemory(size_t constant_value) {}
virtual void RecordEvent() = 0;
virtual void SynchronizeEvent() = 0;
virtual void SynchronizeStream() = 0;
virtual void CopyHostMemToDevice(void *dst, void *src, size_t size) = 0;
virtual void CopyDeviceMemToHost(void *dst, void *src, size_t size) = 0;
virtual void HashSwapOut(void *hash_table_addr, void *swap_out_value_addr, void *swap_out_index_addr,
virtual bool MallocConstantMemory(size_t constant_value) { return true; }
virtual bool RecordEvent() = 0;
virtual bool SynchronizeEvent() = 0;
virtual bool SynchronizeStream() = 0;
virtual bool CopyHostMemToDevice(void *dst, void *src, size_t size) = 0;
virtual bool CopyDeviceMemToHost(void *dst, void *src, size_t size) = 0;
virtual bool HashSwapOut(void *hash_table_addr, void *swap_out_value_addr, void *swap_out_index_addr,
size_t hash_table_size, size_t embedding_size, size_t swap_out_size) = 0;
virtual void HashSwapIn(void *hash_table_addr, void *swap_in_value_addr, void *swap_in_index_addr,
virtual bool HashSwapIn(void *hash_table_addr, void *swap_in_value_addr, void *swap_in_index_addr,
size_t hash_table_size, size_t embedding_size, size_t swap_in_size) = 0;
protected:

File diff suppressed because it is too large Load Diff

@ -126,6 +126,8 @@ class PsCacheManager {
bool initialized_ps_cache() const { return initialized_ps_cache_; }
void DoProcessData(uint32_t device_id, void *context);
void IncreaseGraphStep(const std::string &channel_name);
bool terminated() const { return terminated_; }
void Finalize();
void DumpHashTables(bool dump_device_tables = false) const;
private:
@ -133,7 +135,7 @@ class PsCacheManager {
~PsCacheManager() = default;
PsCacheManager(const PsCacheManager &) = delete;
PsCacheManager &operator=(const PsCacheManager &) = delete;
void IncreaseStep();
bool IncreaseStep();
void set_current_graph_step() { graph_running_step_ = graph_step_; }
std::string channel_name();
void set_channel_name(const std::string channel_name);
@ -141,23 +143,23 @@ class PsCacheManager {
void AllocMemForHashTable();
void SetLocalIdRank();
void ProcessDataTask(uint32_t device_id, void *context);
void ProcessData();
void ParseData(const int *batch_ids, const size_t batch_ids_len, int *hash_index);
void WaitGraphRun();
bool ProcessData();
bool ParseData(const int *batch_ids, const size_t batch_ids_len, int *hash_index);
bool WaitGraphRun();
int ParseDeviceData(size_t id, bool *need_swap_device_to_host, bool *need_swap_host_to_device);
void ParseHostDataHostToDevice(size_t id);
void ParseHostDataDeviceToHost(size_t id);
void HashSwapDeviceOut(int *swap_out_index, ::ps::SArray<float> *swap_out_data, const HashTableInfo &hash_info);
void HashSwapDeviceIn(int *swap_in_ids, int *swap_in_index, const HashTableInfo &hash_info, size_t key);
void HashSwapHostToDevice(const HashTableInfo &hash_info);
void HashSwapDeviceToHost(const HashTableInfo &hash_info);
void HashSwapHostToServer(size_t key, const HashTableInfo &hash_info);
void HashSwapServerToHost(size_t key, const HashTableInfo &hash_info);
void InsertHostHashTable(size_t embedding_size, size_t insert_indices_size, int *insert_indices, float *insert_data,
bool ParseHostDataHostToDevice(size_t id);
bool ParseHostDataDeviceToHost(size_t id);
bool HashSwapDeviceOut(int *swap_out_index, ::ps::SArray<float> *swap_out_data, const HashTableInfo &hash_info);
bool HashSwapDeviceIn(int *swap_in_ids, int *swap_in_index, const HashTableInfo &hash_info, size_t key);
bool HashSwapHostToDevice(const HashTableInfo &hash_info);
bool HashSwapDeviceToHost(const HashTableInfo &hash_info);
bool HashSwapHostToServer(size_t key, const HashTableInfo &hash_info);
bool HashSwapServerToHost(size_t key, const HashTableInfo &hash_info);
bool InsertHostHashTable(size_t embedding_size, size_t insert_indices_size, int *insert_indices, float *insert_data,
float *hash_table_addr);
void LookUpHostHashTable(size_t embedding_size, size_t indices_lens, const float *hash_table_addr,
bool LookUpHostHashTable(size_t embedding_size, size_t indices_lens, const float *hash_table_addr,
const int *indices_addr, float *output_addr);
void UpdataEmbeddingTable(const ::ps::SArray<float> &swap_out_data, int *swap_out_ids, size_t key);
bool UpdataEmbeddingTable(const ::ps::SArray<float> &swap_out_data, int *swap_out_ids, size_t key);
void LookUpTableTask(size_t indices_lens, size_t outer_dim_size, size_t first_dim_size, const float *input_addr,
const int *indices_addr, float *output_addr);
bool CheckFinishInsertInitInfo() const;
@ -172,6 +174,7 @@ class PsCacheManager {
std::mutex data_mutex_;
std::condition_variable data_prase_;
std::condition_variable insert_init_info_;
std::thread process_data_thread_;
std::map<std::string, HashTableInfo> hash_tables_;
std::shared_ptr<EmbeddingDeviceCache> embedding_device_cache_;
@ -185,6 +188,8 @@ class PsCacheManager {
std::pair<size_t, size_t> range_bound_;
std::atomic_bool finish_insert_init_info_{false};
std::atomic_bool finish_init_parameter_server_{false};
std::atomic_bool running_{false};
std::atomic_bool terminated_{false};
};
// Shorthand reference bound to the object returned by PsCacheManager::GetInstance().
// NOTE(review): being `static`, every translation unit that sees this declaration gets its own
// reference variable; they presumably all alias the same singleton -- confirm GetInstance()
// returns one shared object.
static PsCacheManager &ps_cache_instance = PsCacheManager::GetInstance();

@ -28,11 +28,9 @@ void PsDataPrefetch::CreateDataChannel(const std::string &channel_name, size_t s
if (iter != ps_data_channel_map_.end()) {
MS_LOG(WARNING) << "The ps data channel already exists, channel name:" << channel_name;
auto channel = iter->second;
MS_EXCEPTION_IF_NULL(channel);
channel->set_step_num(step_num);
} else {
auto channel = std::make_shared<PsDataChannel>(channel_name, step_num);
MS_EXCEPTION_IF_NULL(channel);
(void)ps_data_channel_map_.emplace(channel_name, channel);
}
}
@ -40,71 +38,95 @@ void PsDataPrefetch::CreateDataChannel(const std::string &channel_name, size_t s
std::shared_ptr<PsDataChannel> PsDataPrefetch::ps_data_channel(const std::string &channel_name) const {
auto iter = ps_data_channel_map_.find(channel_name);
if (iter == ps_data_channel_map_.end()) {
MS_LOG(EXCEPTION) << "The ps data channel does not exist, channel name:" << channel_name;
MS_LOG(ERROR) << "The ps data channel does not exist, channel name:" << channel_name;
return nullptr;
}
return iter->second;
}
void PsDataPrefetch::PrefetchData(const std::string &channel_name, void *data, const size_t data_size) {
bool PsDataPrefetch::PrefetchData(const std::string &channel_name, void *data, const size_t data_size) {
if (cache_enable_ == false) {
return;
return true;
}
if (data == nullptr) {
MS_LOG(WARNING) << "No data prefetch.";
return;
return true;
}
auto channel = ps_data_channel(channel_name);
MS_EXCEPTION_IF_NULL(channel);
MS_ERROR_IF_NULL(channel);
channel->set_data(data, data_size);
std::unique_lock<std::mutex> locker(data_mutex_);
data_ready_ = true;
data_process_.notify_one();
if (!need_wait_) {
return true;
}
for (int i = 0; i < 10; i++) {
if (data_prefetch_.wait_for(locker, std::chrono::seconds(30), [this] { return data_ready_ == false; })) {
return;
if (data_prefetch_.wait_for(locker, std::chrono::seconds(30),
[this] { return data_ready_ == false || need_wait_ == false; })) {
return true;
} else {
MS_LOG(INFO) << "Waiting for ps data process, channel name:" << channel_name << "...(" << i << " / 10)";
}
}
MS_LOG(EXCEPTION) << "Ps cache data process timeout, suggest to enlarge the cache size.";
MS_LOG(ERROR) << "Ps cache data process timeout, suggest to enlarge the cache size.";
return false;
}
void PsDataPrefetch::FinalizeData(const std::string &channel_name) {
bool PsDataPrefetch::FinalizeData(const std::string &channel_name) {
if (cache_enable_ == false) {
return;
return true;
}
auto channel = ps_data_channel(channel_name);
MS_EXCEPTION_IF_NULL(channel);
MS_ERROR_IF_NULL(channel);
channel->ResetData();
std::unique_lock<std::mutex> locker(data_mutex_);
data_ready_ = false;
data_prefetch_.notify_one();
if (!need_wait_) {
return true;
}
for (int i = 0; i < 10; i++) {
if (data_process_.wait_for(locker, std::chrono::seconds(30), [this] { return data_ready_ == true; })) {
return;
if (data_process_.wait_for(locker, std::chrono::seconds(30),
[this] { return data_ready_ == true || need_wait_ == false; })) {
return true;
} else {
MS_LOG(INFO) << "Waiting for ps data prefetch, channel name:" << channel_name << "...(" << i << " / 10)";
}
}
MS_LOG(EXCEPTION) << "Ps cache data prefetch timeout.";
MS_LOG(ERROR) << "Ps cache data prefetch timeout.";
return false;
}
void *PsDataPrefetch::data(const std::string &channel_name) const {
auto channel = ps_data_channel(channel_name);
MS_EXCEPTION_IF_NULL(channel);
if (channel == nullptr) {
return nullptr;
}
return channel->data();
}
size_t PsDataPrefetch::data_size(const std::string &channel_name) const {
auto channel = ps_data_channel(channel_name);
MS_EXCEPTION_IF_NULL(channel);
if (channel == nullptr) {
return 0;
}
return channel->data_size();
}
void PsDataPrefetch::TryWakeChannel(const std::string &channel_name) {
void PsDataPrefetch::NotifyFinalize() {
need_wait_ = false;
data_prefetch_.notify_one();
data_process_.notify_one();
}
bool PsDataPrefetch::TryWakeChannel(const std::string &channel_name) {
auto channel = ps_data_channel(channel_name);
MS_EXCEPTION_IF_NULL(channel);
if (channel == nullptr) {
return false;
}
channel->TryWakeChannel();
return true;
}
} // namespace ps
} // namespace mindspore

@ -19,6 +19,7 @@
#include <map>
#include <string>
#include <memory>
#include <atomic>
#include <condition_variable>
#include "ps/ps_cache/ps_data/ps_data_channel.h"
@ -36,11 +37,12 @@ class EXPORT PsDataPrefetch {
EXPORT bool cache_enable() const { return cache_enable_; }
EXPORT void set_cache_enable(bool cache_enable) { cache_enable_ = cache_enable; }
EXPORT void CreateDataChannel(const std::string &channel_name, size_t step_num);
EXPORT void PrefetchData(const std::string &channel_name, void *data, const size_t data_size);
EXPORT void FinalizeData(const std::string &channel_name);
EXPORT bool PrefetchData(const std::string &channel_name, void *data, const size_t data_size);
EXPORT bool FinalizeData(const std::string &channel_name);
EXPORT void NotifyFinalize();
EXPORT void *data(const std::string &channel_name) const;
EXPORT size_t data_size(const std::string &channel_name) const;
EXPORT void TryWakeChannel(const std::string &channel_name);
EXPORT bool TryWakeChannel(const std::string &channel_name);
private:
PsDataPrefetch() : cache_enable_(false), data_ready_(false) {}
@ -54,6 +56,7 @@ class EXPORT PsDataPrefetch {
std::mutex data_mutex_;
std::condition_variable data_prefetch_;
std::condition_variable data_process_;
std::atomic_bool need_wait_{true};
};
} // namespace ps
} // namespace mindspore

@ -17,10 +17,10 @@
#include "ps/ps_context.h"
#include "utils/log_adapter.h"
#include "utils/ms_utils.h"
#include "ps/ps_cache/ps_data/ps_data_prefetch.h"
#include "backend/kernel_compiler/kernel.h"
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
#include "ps/ps_cache/ps_cache_manager.h"
#include "ps/ps_cache/ps_data/ps_data_prefetch.h"
#endif
namespace mindspore {
@ -62,7 +62,12 @@ void PSContext::Reset() {
is_worker_ = false;
is_pserver_ = false;
is_sched_ = false;
set_cache_enable(false);
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
if (ps::PsDataPrefetch::GetInstance().cache_enable()) {
ps_cache_instance.Finalize();
set_cache_enable(false);
}
#endif
}
std::string PSContext::ms_role() const {

@ -62,6 +62,16 @@ namespace gpu {
} \
}
// Non-throwing CUDA status check: evaluates `expression` once, and if the result is not
// cudaSuccess, logs `message` together with the numeric status and cudaGetErrorString() at
// ERROR level and makes the ENCLOSING function `return false`.
// Only usable inside functions whose return type accepts `false`.
// Expands to a plain brace block (not do/while), so it also compiles when invoked without a
// trailing semicolon; it declares a local named `status` inside that block.
#define CHECK_CUDA_RET_WITH_RETURN_ERROR_NOTRACE(expression, message) \
{ \
cudaError_t status = (expression); \
if (status != cudaSuccess) { \
MS_LOG(ERROR) << "CUDA Error: " << message << " | Error Number: " << status << " " \
<< cudaGetErrorString(status); \
return false; \
} \
}
#define CHECK_CUDA_RET_WITH_EXCEPT(node, expression, message) \
{ \
cudaError_t status = (expression); \

@ -199,6 +199,14 @@ class LogWriter {
} \
} while (0)
// Null-pointer guard that reports instead of throwing: when `ptr` compares equal to nullptr,
// logs the stringified argument at ERROR level and makes the ENCLOSING function `return false`.
// Only usable inside functions whose return type accepts `false`.
#define MS_ERROR_IF_NULL(ptr) \
do { \
if ((ptr) == nullptr) { \
MS_LOG(ERROR) << ": The pointer[" << #ptr << "] is null."; \
return false; \
} \
} while (0)
#ifdef DEBUG
#include <cassert>
#define MS_ASSERT(f) assert(f)

Loading…
Cancel
Save