add ps cache init info

pull/9816/head
limingqi107 4 years ago
parent 42b01afc19
commit 31dd182a49

@ -24,7 +24,6 @@
#include "minddata/dataset/engine/datasetops/pipeline_op.h" #include "minddata/dataset/engine/datasetops/pipeline_op.h"
#include "minddata/dataset/engine/datasetops/repeat_op.h" #include "minddata/dataset/engine/datasetops/repeat_op.h"
#include "minddata/dataset/util/status.h" #include "minddata/dataset/util/status.h"
#include "ps/ps_cache/ps_data/ps_data_prefetch.h"
#ifdef ENABLE_TDTQUE #ifdef ENABLE_TDTQUE
#include "minddata/dataset/util/queue.h" #include "minddata/dataset/util/queue.h"
@ -34,6 +33,7 @@
#ifdef ENABLE_GPUQUE #ifdef ENABLE_GPUQUE
#include "minddata/dataset/util/circular_pool.h" #include "minddata/dataset/util/circular_pool.h"
#include "runtime/device/gpu/gpu_buffer_mgr.h" #include "runtime/device/gpu/gpu_buffer_mgr.h"
#include "ps/ps_cache/ps_data/ps_data_prefetch.h"
using mindspore::device::BlockQueueStatus_T; using mindspore::device::BlockQueueStatus_T;
using mindspore::device::GpuBufferMgr; using mindspore::device::GpuBufferMgr;
#endif #endif

@ -17,7 +17,9 @@
#include "utils/ms_utils.h" #include "utils/ms_utils.h"
#include "minddata/dataset/engine/perf/profiling.h" #include "minddata/dataset/engine/perf/profiling.h"
#include "minddata/dataset/util/log_adapter.h" #include "minddata/dataset/util/log_adapter.h"
#if ENABLE_D
#include "ps/ps_cache/ps_data/ps_data_prefetch.h" #include "ps/ps_cache/ps_data/ps_data_prefetch.h"
#endif
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
@ -50,10 +52,12 @@ TdtStatus TdtPlugin::hostPush(TensorRow ts_row, bool is_wait, std::string channe
if (profiling) { if (profiling) {
start_time = ProfilingTime::GetCurMilliSecond(); start_time = ProfilingTime::GetCurMilliSecond();
} }
#if ENABLE_D
// Data prefetch only when PS mode enables cache. // Data prefetch only when PS mode enables cache.
if (items.size() > 0) { if (items.size() > 0) {
ps::PsDataPrefetch::GetInstance().PrefetchData(channel_name, items[0].dataPtr_.get(), items[0].dataLen_); ps::PsDataPrefetch::GetInstance().PrefetchData(channel_name, items[0].dataPtr_.get(), items[0].dataLen_);
} }
#endif
if (tdt::TdtHostPushData(channel_name, items) != 0) { if (tdt::TdtHostPushData(channel_name, items) != 0) {
return FAILED; return FAILED;
} }

@ -73,6 +73,8 @@ void PsCacheManager::InsertWeightInitInfo(const std::string &param_name, size_t
if (hash_table_info.param_init_info_.param_type_ != kUnKnown) { if (hash_table_info.param_init_info_.param_type_ != kUnKnown) {
return; return;
} }
MS_LOG(INFO) << "Insert embedding table init info:" << param_name << ", global seed:" << global_seed
<< ", op seed:" << op_seed;
hash_table_info.param_init_info_.param_type_ = kWeight; hash_table_info.param_init_info_.param_type_ = kWeight;
hash_table_info.param_init_info_.global_seed_ = global_seed; hash_table_info.param_init_info_.global_seed_ = global_seed;
hash_table_info.param_init_info_.op_seed_ = op_seed; hash_table_info.param_init_info_.op_seed_ = op_seed;
@ -91,6 +93,7 @@ void PsCacheManager::InsertAccumuInitInfo(const std::string &param_name, float i
if (hash_table_info.param_init_info_.param_type_ != kUnKnown) { if (hash_table_info.param_init_info_.param_type_ != kUnKnown) {
return; return;
} }
MS_LOG(INFO) << "Insert accumulation init info:" << param_name << ", init value:" << init_val;
hash_table_info.param_init_info_.param_type_ = kAccumulation; hash_table_info.param_init_info_.param_type_ = kAccumulation;
hash_table_info.param_init_info_.init_val_ = init_val; hash_table_info.param_init_info_.init_val_ = init_val;
if (CheckFinishInsertInitInfo()) { if (CheckFinishInsertInitInfo()) {
@ -107,6 +110,7 @@ bool PsCacheManager::CheckFinishInsertInitInfo() const {
return false; return false;
} }
} }
MS_LOG(INFO) << "Finish inserting embedding table init info.";
return true; return true;
} }
@ -141,6 +145,7 @@ void PsCacheManager::Initialize() {
AddEmbeddingTable(); AddEmbeddingTable();
AllocMemForHashTable(); AllocMemForHashTable();
SetLocalIdRank(); SetLocalIdRank();
DumpHashTables();
initialized_ps_cache_ = true; initialized_ps_cache_ = true;
} }
@ -155,6 +160,7 @@ void PsCacheManager::AddEmbeddingTable() const {
} }
void PsCacheManager::InitParameterServer() { void PsCacheManager::InitParameterServer() {
MS_LOG(INFO) << "Embedding table init begin:" << finish_insert_init_info_;
std::unique_lock<std::mutex> locker(data_mutex_); std::unique_lock<std::mutex> locker(data_mutex_);
insert_init_info_.wait(locker, [this] { return finish_insert_init_info_ == true; }); insert_init_info_.wait(locker, [this] { return finish_insert_init_info_ == true; });
@ -181,6 +187,7 @@ void PsCacheManager::InitParameterServer() {
finish_init_parameter_server_ = true; finish_init_parameter_server_ = true;
data_prase_.notify_one(); data_prase_.notify_one();
MS_LOG(INFO) << "Embedding table init end.";
} }
void PsCacheManager::AllocMemForHashTable() { void PsCacheManager::AllocMemForHashTable() {
@ -237,10 +244,14 @@ void PsCacheManager::set_channel_name(const std::string channel_name) {
void PsCacheManager::IncreaseStep() { void PsCacheManager::IncreaseStep() {
if (data_step_ >= UINT64_MAX) { if (data_step_ >= UINT64_MAX) {
MS_LOG(EXCEPTION) << "The data step (" << data_step_ << ") << will exceed the maximum value of uint64_t."; MS_LOG(EXCEPTION) << "The data step (" << data_step_ << ") << will exceed the maximum value of uint64_t.";
} }
data_step_++; data_step_++;
set_current_graph_step(); set_current_graph_step();
if (graph_running_step_ > data_step_) {
MS_LOG(EXCEPTION) << "The graph running step (" << graph_running_step_ << ") << exceed the data step ("
<< data_step_ << ").";
}
} }
void PsCacheManager::IncreaseGraphStep(const std::string &channel_name) { void PsCacheManager::IncreaseGraphStep(const std::string &channel_name) {
@ -248,8 +259,10 @@ void PsCacheManager::IncreaseGraphStep(const std::string &channel_name) {
MS_LOG(EXCEPTION) << "The graph step(" << graph_step_ << ") << will exceed the maximum value of uint64_t."; MS_LOG(EXCEPTION) << "The graph step(" << graph_step_ << ") << will exceed the maximum value of uint64_t.";
} }
if (graph_step_ == 0) { if (graph_step_ == 0) {
MS_LOG(INFO) << "Graph running waiting embedding table init begin:" << finish_init_parameter_server_;
std::unique_lock<std::mutex> locker(data_mutex_); std::unique_lock<std::mutex> locker(data_mutex_);
data_prase_.wait(locker, [this] { return finish_init_parameter_server_ == true; }); data_prase_.wait(locker, [this] { return finish_init_parameter_server_ == true; });
MS_LOG(INFO) << "Graph running waiting embedding table init end.";
} }
graph_step_++; graph_step_++;
set_channel_name(channel_name); set_channel_name(channel_name);
@ -755,29 +768,35 @@ void PsCacheManager::UpdataEmbeddingTable(const ::ps::SArray<float> &swap_out_da
worker.UpdateEmbeddingTable({key}, lookup_ids, swap_out_data); worker.UpdateEmbeddingTable({key}, lookup_ids, swap_out_data);
} }
void PsCacheManager::DumpHashTables() const { void PsCacheManager::DumpHashTables(bool dump_device_tables) const {
MS_EXCEPTION_IF_NULL(embedding_device_cache_); MS_EXCEPTION_IF_NULL(embedding_device_cache_);
MS_EXCEPTION_IF_NULL(embedding_device_cache_->cache_); MS_EXCEPTION_IF_NULL(embedding_device_cache_->cache_);
for (const auto &item : hash_tables_) { for (const auto &item : hash_tables_) {
const auto &param_name = item.first; const auto &param_name = item.first;
size_t cache_vocab_size = item.second.cache_vocab_size; size_t cache_vocab_size = item.second.cache_vocab_size;
size_t host_cache_vocab_size = item.second.host_cache_vocab_size;
size_t embedding_size = item.second.embedding_size; size_t embedding_size = item.second.embedding_size;
size_t vocab_size = item.second.vocab_size; size_t vocab_size = item.second.vocab_size;
MS_LOG(INFO) << "Dump hash tables: " << param_name << " || " << cache_vocab_size << " || " << embedding_size MS_LOG(INFO) << "Hash table info:"
<< " || " << vocab_size << " || " << reinterpret_cast<void *>(item.second.device_address.addr) << " embedding table name:" << param_name << ", vocab size:" << vocab_size
<< " || " << reinterpret_cast<void *>(item.second.host_address.get()); << ", embedding size:" << embedding_size << ", device cache size:" << cache_vocab_size
float *output = new float[item.second.device_address.size / 4]; << ", host cache size:" << host_cache_vocab_size
embedding_device_cache_->cache_->CopyDeviceMemToHost(output, item.second.device_address.addr, << ", device cache address:" << reinterpret_cast<void *>(item.second.device_address.addr)
item.second.device_address.size); << ", host cache address:" << reinterpret_cast<void *>(item.second.host_address.get());
embedding_device_cache_->cache_->SynchronizeStream(); if (dump_device_tables) {
for (size_t i = 0; i < cache_vocab_size; i++) { float *output = new float[item.second.device_address.size / 4];
for (size_t j = 0; j < embedding_size; j++) { embedding_device_cache_->cache_->CopyDeviceMemToHost(output, item.second.device_address.addr,
std::cout << output[i * embedding_size + j] << " "; item.second.device_address.size);
embedding_device_cache_->cache_->SynchronizeStream();
for (size_t i = 0; i < cache_vocab_size; i++) {
for (size_t j = 0; j < embedding_size; j++) {
std::cout << output[i * embedding_size + j] << " ";
}
std::cout << std::endl;
} }
std::cout << std::endl; std::cout << std::endl;
delete[] output;
} }
std::cout << std::endl;
delete[] output;
} }
} }
} // namespace ps } // namespace ps

@ -125,7 +125,7 @@ class PsCacheManager {
bool initialized_ps_cache() const { return initialized_ps_cache_; } bool initialized_ps_cache() const { return initialized_ps_cache_; }
void DoProcessData(uint32_t device_id, void *context); void DoProcessData(uint32_t device_id, void *context);
void IncreaseGraphStep(const std::string &channel_name); void IncreaseGraphStep(const std::string &channel_name);
void DumpHashTables() const; void DumpHashTables(bool dump_device_tables = false) const;
private: private:
PsCacheManager() = default; PsCacheManager() = default;

@ -14,6 +14,8 @@
# ============================================================================ # ============================================================================
"""embedding""" """embedding"""
import mindspore.common.dtype as mstype import mindspore.common.dtype as mstype
import mindspore.context as context
from mindspore import log as logger
from mindspore.common.tensor import Tensor from mindspore.common.tensor import Tensor
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore.ops import functional as F from mindspore.ops import functional as F
@ -195,6 +197,11 @@ class EmbeddingLookup(Cell):
+ str(target) + ', should be one of values in \'CPU\', \'DEVICE\'.') + str(target) + ', should be one of values in \'CPU\', \'DEVICE\'.')
if not sparse and target == 'CPU': if not sparse and target == 'CPU':
raise ValueError('When target is CPU, embedding_lookup must be sparse.') raise ValueError('When target is CPU, embedding_lookup must be sparse.')
enable_ps = context.get_ps_context("enable_ps")
if not enable_ps and vocab_cache_size > 0:
logger.warning("The configuration of 'vocab_cache_size' is valid only in parameter server trainning mode, "
"current mode is not parameter server trainning mode, so it will be ignored.")
vocab_cache_size = 0
if sparse: if sparse:
self.gatherv2 = P.SparseGatherV2() self.gatherv2 = P.SparseGatherV2()
else: else:

Loading…
Cancel
Save