From 31dd182a49e8aaab6101d84babeae63940fb67ea Mon Sep 17 00:00:00 2001 From: limingqi107 Date: Fri, 11 Dec 2020 10:12:02 +0800 Subject: [PATCH] add ps cache init info --- .../engine/datasetops/device_queue_op.h | 2 +- .../minddata/dataset/engine/tdt/tdt_plugin.cc | 4 ++ .../ccsrc/ps/ps_cache/ps_cache_manager.cc | 47 +++++++++++++------ .../ccsrc/ps/ps_cache/ps_cache_manager.h | 2 +- mindspore/nn/layer/embedding.py | 7 +++ 5 files changed, 46 insertions(+), 16 deletions(-) diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/device_queue_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/device_queue_op.h index 7804e123a3..bf8dfc9037 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/device_queue_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/device_queue_op.h @@ -24,7 +24,6 @@ #include "minddata/dataset/engine/datasetops/pipeline_op.h" #include "minddata/dataset/engine/datasetops/repeat_op.h" #include "minddata/dataset/util/status.h" -#include "ps/ps_cache/ps_data/ps_data_prefetch.h" #ifdef ENABLE_TDTQUE #include "minddata/dataset/util/queue.h" @@ -34,6 +33,7 @@ #ifdef ENABLE_GPUQUE #include "minddata/dataset/util/circular_pool.h" #include "runtime/device/gpu/gpu_buffer_mgr.h" +#include "ps/ps_cache/ps_data/ps_data_prefetch.h" using mindspore::device::BlockQueueStatus_T; using mindspore::device::GpuBufferMgr; #endif diff --git a/mindspore/ccsrc/minddata/dataset/engine/tdt/tdt_plugin.cc b/mindspore/ccsrc/minddata/dataset/engine/tdt/tdt_plugin.cc index 78c8964e86..8760e8c69b 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/tdt/tdt_plugin.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/tdt/tdt_plugin.cc @@ -17,7 +17,9 @@ #include "utils/ms_utils.h" #include "minddata/dataset/engine/perf/profiling.h" #include "minddata/dataset/util/log_adapter.h" +#if ENABLE_D #include "ps/ps_cache/ps_data/ps_data_prefetch.h" +#endif namespace mindspore { namespace dataset { @@ -50,10 +52,12 @@ TdtStatus TdtPlugin::hostPush(TensorRow 
ts_row, bool is_wait, std::string channe if (profiling) { start_time = ProfilingTime::GetCurMilliSecond(); } +#if ENABLE_D // Data prefetch only when PS mode enables cache. if (items.size() > 0) { ps::PsDataPrefetch::GetInstance().PrefetchData(channel_name, items[0].dataPtr_.get(), items[0].dataLen_); } +#endif if (tdt::TdtHostPushData(channel_name, items) != 0) { return FAILED; } diff --git a/mindspore/ccsrc/ps/ps_cache/ps_cache_manager.cc b/mindspore/ccsrc/ps/ps_cache/ps_cache_manager.cc index f70539485e..0d725ddaf8 100644 --- a/mindspore/ccsrc/ps/ps_cache/ps_cache_manager.cc +++ b/mindspore/ccsrc/ps/ps_cache/ps_cache_manager.cc @@ -73,6 +73,8 @@ void PsCacheManager::InsertWeightInitInfo(const std::string ¶m_name, size_t if (hash_table_info.param_init_info_.param_type_ != kUnKnown) { return; } + MS_LOG(INFO) << "Insert embedding table init info:" << param_name << ", global seed:" << global_seed + << ", op seed:" << op_seed; hash_table_info.param_init_info_.param_type_ = kWeight; hash_table_info.param_init_info_.global_seed_ = global_seed; hash_table_info.param_init_info_.op_seed_ = op_seed; @@ -91,6 +93,7 @@ void PsCacheManager::InsertAccumuInitInfo(const std::string ¶m_name, float i if (hash_table_info.param_init_info_.param_type_ != kUnKnown) { return; } + MS_LOG(INFO) << "Insert accumulation init info:" << param_name << ", init value:" << init_val; hash_table_info.param_init_info_.param_type_ = kAccumulation; hash_table_info.param_init_info_.init_val_ = init_val; if (CheckFinishInsertInitInfo()) { @@ -107,6 +110,7 @@ bool PsCacheManager::CheckFinishInsertInitInfo() const { return false; } } + MS_LOG(INFO) << "Finish inserting embedding table init info."; return true; } @@ -141,6 +145,7 @@ void PsCacheManager::Initialize() { AddEmbeddingTable(); AllocMemForHashTable(); SetLocalIdRank(); + DumpHashTables(); initialized_ps_cache_ = true; } @@ -155,6 +160,7 @@ void PsCacheManager::AddEmbeddingTable() const { } void PsCacheManager::InitParameterServer() { + 
MS_LOG(INFO) << "Embedding table init begin:" << finish_insert_init_info_; std::unique_lock locker(data_mutex_); insert_init_info_.wait(locker, [this] { return finish_insert_init_info_ == true; }); @@ -181,6 +187,7 @@ void PsCacheManager::InitParameterServer() { finish_init_parameter_server_ = true; data_prase_.notify_one(); + MS_LOG(INFO) << "Embedding table init end."; } void PsCacheManager::AllocMemForHashTable() { @@ -237,10 +244,14 @@ void PsCacheManager::set_channel_name(const std::string channel_name) { void PsCacheManager::IncreaseStep() { if (data_step_ >= UINT64_MAX) { - MS_LOG(EXCEPTION) << "The data step (" << data_step_ << ") << will exceed the maximum value of uint64_t."; + MS_LOG(EXCEPTION) << "The data step (" << data_step_ << ") will exceed the maximum value of uint64_t."; } data_step_++; set_current_graph_step(); + if (graph_running_step_ > data_step_) { + MS_LOG(EXCEPTION) << "The graph running step (" << graph_running_step_ << ") exceeds the data step (" + << data_step_ << ")."; + } } void PsCacheManager::IncreaseGraphStep(const std::string &channel_name) { @@ -248,8 +259,10 @@ void PsCacheManager::IncreaseGraphStep(const std::string &channel_name) { MS_LOG(EXCEPTION) << "The graph step(" << graph_step_ << ") << will exceed the maximum value of uint64_t."; } if (graph_step_ == 0) { + MS_LOG(INFO) << "Graph running waiting embedding table init begin:" << finish_init_parameter_server_; std::unique_lock locker(data_mutex_); data_prase_.wait(locker, [this] { return finish_init_parameter_server_ == true; }); + MS_LOG(INFO) << "Graph running waiting embedding table init end."; } graph_step_++; set_channel_name(channel_name); @@ -755,29 +768,35 @@ void PsCacheManager::UpdataEmbeddingTable(const ::ps::SArray &swap_out_da worker.UpdateEmbeddingTable({key}, lookup_ids, swap_out_data); } -void PsCacheManager::DumpHashTables() const { +void PsCacheManager::DumpHashTables(bool dump_device_tables) const { MS_EXCEPTION_IF_NULL(embedding_device_cache_); 
MS_EXCEPTION_IF_NULL(embedding_device_cache_->cache_); for (const auto &item : hash_tables_) { const auto ¶m_name = item.first; size_t cache_vocab_size = item.second.cache_vocab_size; + size_t host_cache_vocab_size = item.second.host_cache_vocab_size; size_t embedding_size = item.second.embedding_size; size_t vocab_size = item.second.vocab_size; - MS_LOG(INFO) << "Dump hash tables: " << param_name << " || " << cache_vocab_size << " || " << embedding_size - << " || " << vocab_size << " || " << reinterpret_cast(item.second.device_address.addr) - << " || " << reinterpret_cast(item.second.host_address.get()); - float *output = new float[item.second.device_address.size / 4]; - embedding_device_cache_->cache_->CopyDeviceMemToHost(output, item.second.device_address.addr, - item.second.device_address.size); - embedding_device_cache_->cache_->SynchronizeStream(); - for (size_t i = 0; i < cache_vocab_size; i++) { - for (size_t j = 0; j < embedding_size; j++) { - std::cout << output[i * embedding_size + j] << " "; + MS_LOG(INFO) << "Hash table info:" + << " embedding table name:" << param_name << ", vocab size:" << vocab_size + << ", embedding size:" << embedding_size << ", device cache size:" << cache_vocab_size + << ", host cache size:" << host_cache_vocab_size + << ", device cache address:" << reinterpret_cast(item.second.device_address.addr) + << ", host cache address:" << reinterpret_cast(item.second.host_address.get()); + if (dump_device_tables) { + float *output = new float[item.second.device_address.size / 4]; + embedding_device_cache_->cache_->CopyDeviceMemToHost(output, item.second.device_address.addr, + item.second.device_address.size); + embedding_device_cache_->cache_->SynchronizeStream(); + for (size_t i = 0; i < cache_vocab_size; i++) { + for (size_t j = 0; j < embedding_size; j++) { + std::cout << output[i * embedding_size + j] << " "; + } + std::cout << std::endl; } std::cout << std::endl; + delete[] output; } - std::cout << std::endl; - delete[] output; } } 
} // namespace ps diff --git a/mindspore/ccsrc/ps/ps_cache/ps_cache_manager.h b/mindspore/ccsrc/ps/ps_cache/ps_cache_manager.h index a97459511e..07c7949d7b 100644 --- a/mindspore/ccsrc/ps/ps_cache/ps_cache_manager.h +++ b/mindspore/ccsrc/ps/ps_cache/ps_cache_manager.h @@ -125,7 +125,7 @@ class PsCacheManager { bool initialized_ps_cache() const { return initialized_ps_cache_; } void DoProcessData(uint32_t device_id, void *context); void IncreaseGraphStep(const std::string &channel_name); - void DumpHashTables() const; + void DumpHashTables(bool dump_device_tables = false) const; private: PsCacheManager() = default; diff --git a/mindspore/nn/layer/embedding.py b/mindspore/nn/layer/embedding.py index 3f36ea5c18..69cd292c8a 100755 --- a/mindspore/nn/layer/embedding.py +++ b/mindspore/nn/layer/embedding.py @@ -14,6 +14,8 @@ # ============================================================================ """embedding""" import mindspore.common.dtype as mstype +import mindspore.context as context +from mindspore import log as logger from mindspore.common.tensor import Tensor from mindspore.ops import operations as P from mindspore.ops import functional as F @@ -195,6 +197,11 @@ class EmbeddingLookup(Cell): + str(target) + ', should be one of values in \'CPU\', \'DEVICE\'.') if not sparse and target == 'CPU': raise ValueError('When target is CPU, embedding_lookup must be sparse.') + enable_ps = context.get_ps_context("enable_ps") + if not enable_ps and vocab_cache_size > 0: + logger.warning("The configuration of 'vocab_cache_size' is valid only in parameter server training mode, " + "current mode is not parameter server training mode, so it will be ignored.") + vocab_cache_size = 0 if sparse: self.gatherv2 = P.SparseGatherV2() else: