add ps cache init info

pull/9816/head
limingqi107 4 years ago
parent 42b01afc19
commit 31dd182a49

@@ -24,7 +24,6 @@
#include "minddata/dataset/engine/datasetops/pipeline_op.h"
#include "minddata/dataset/engine/datasetops/repeat_op.h"
#include "minddata/dataset/util/status.h"
#include "ps/ps_cache/ps_data/ps_data_prefetch.h"
#ifdef ENABLE_TDTQUE
#include "minddata/dataset/util/queue.h"
@@ -34,6 +33,7 @@
#ifdef ENABLE_GPUQUE
#include "minddata/dataset/util/circular_pool.h"
#include "runtime/device/gpu/gpu_buffer_mgr.h"
#include "ps/ps_cache/ps_data/ps_data_prefetch.h"
using mindspore::device::BlockQueueStatus_T;
using mindspore::device::GpuBufferMgr;
#endif

@@ -17,7 +17,9 @@
#include "utils/ms_utils.h"
#include "minddata/dataset/engine/perf/profiling.h"
#include "minddata/dataset/util/log_adapter.h"
#if ENABLE_D
#include "ps/ps_cache/ps_data/ps_data_prefetch.h"
#endif
namespace mindspore {
namespace dataset {
@@ -50,10 +52,12 @@ TdtStatus TdtPlugin::hostPush(TensorRow ts_row, bool is_wait, std::string channe
if (profiling) {
start_time = ProfilingTime::GetCurMilliSecond();
}
#if ENABLE_D
// Data prefetch only when PS mode enables cache.
if (items.size() > 0) {
ps::PsDataPrefetch::GetInstance().PrefetchData(channel_name, items[0].dataPtr_.get(), items[0].dataLen_);
}
#endif
if (tdt::TdtHostPushData(channel_name, items) != 0) {
return FAILED;
}

@@ -73,6 +73,8 @@ void PsCacheManager::InsertWeightInitInfo(const std::string &param_name, size_t
if (hash_table_info.param_init_info_.param_type_ != kUnKnown) {
return;
}
MS_LOG(INFO) << "Insert embedding table init info:" << param_name << ", global seed:" << global_seed
<< ", op seed:" << op_seed;
hash_table_info.param_init_info_.param_type_ = kWeight;
hash_table_info.param_init_info_.global_seed_ = global_seed;
hash_table_info.param_init_info_.op_seed_ = op_seed;
@@ -91,6 +93,7 @@ void PsCacheManager::InsertAccumuInitInfo(const std::string &param_name, float i
if (hash_table_info.param_init_info_.param_type_ != kUnKnown) {
return;
}
MS_LOG(INFO) << "Insert accumulation init info:" << param_name << ", init value:" << init_val;
hash_table_info.param_init_info_.param_type_ = kAccumulation;
hash_table_info.param_init_info_.init_val_ = init_val;
if (CheckFinishInsertInitInfo()) {
@@ -107,6 +110,7 @@ bool PsCacheManager::CheckFinishInsertInitInfo() const {
return false;
}
}
MS_LOG(INFO) << "Finish inserting embedding table init info.";
return true;
}
@@ -141,6 +145,7 @@ void PsCacheManager::Initialize() {
AddEmbeddingTable();
AllocMemForHashTable();
SetLocalIdRank();
DumpHashTables();
initialized_ps_cache_ = true;
}
@@ -155,6 +160,7 @@ void PsCacheManager::AddEmbeddingTable() const {
}
void PsCacheManager::InitParameterServer() {
MS_LOG(INFO) << "Embedding table init begin:" << finish_insert_init_info_;
std::unique_lock<std::mutex> locker(data_mutex_);
insert_init_info_.wait(locker, [this] { return finish_insert_init_info_ == true; });
@@ -181,6 +187,7 @@ void PsCacheManager::InitParameterServer() {
finish_init_parameter_server_ = true;
data_prase_.notify_one();
MS_LOG(INFO) << "Embedding table init end.";
}
void PsCacheManager::AllocMemForHashTable() {
@@ -237,10 +244,14 @@ void PsCacheManager::set_channel_name(const std::string channel_name) {
void PsCacheManager::IncreaseStep() {
if (data_step_ >= UINT64_MAX) {
MS_LOG(EXCEPTION) << "The data step (" << data_step_ << ") << will exceed the maximum value of uint64_t.";
MS_LOG(EXCEPTION) << "The data step (" << data_step_ << ") << will exceed the maximum value of uint64_t.";
}
data_step_++;
set_current_graph_step();
if (graph_running_step_ > data_step_) {
MS_LOG(EXCEPTION) << "The graph running step (" << graph_running_step_ << ") << exceed the data step ("
<< data_step_ << ").";
}
}
void PsCacheManager::IncreaseGraphStep(const std::string &channel_name) {
@@ -248,8 +259,10 @@ void PsCacheManager::IncreaseGraphStep(const std::string &channel_name) {
MS_LOG(EXCEPTION) << "The graph step(" << graph_step_ << ") << will exceed the maximum value of uint64_t.";
}
if (graph_step_ == 0) {
MS_LOG(INFO) << "Graph running waiting embedding table init begin:" << finish_init_parameter_server_;
std::unique_lock<std::mutex> locker(data_mutex_);
data_prase_.wait(locker, [this] { return finish_init_parameter_server_ == true; });
MS_LOG(INFO) << "Graph running waiting embedding table init end.";
}
graph_step_++;
set_channel_name(channel_name);
@@ -755,29 +768,35 @@ void PsCacheManager::UpdataEmbeddingTable(const ::ps::SArray<float> &swap_out_da
worker.UpdateEmbeddingTable({key}, lookup_ids, swap_out_data);
}
void PsCacheManager::DumpHashTables() const {
void PsCacheManager::DumpHashTables(bool dump_device_tables) const {
MS_EXCEPTION_IF_NULL(embedding_device_cache_);
MS_EXCEPTION_IF_NULL(embedding_device_cache_->cache_);
for (const auto &item : hash_tables_) {
const auto &param_name = item.first;
size_t cache_vocab_size = item.second.cache_vocab_size;
size_t host_cache_vocab_size = item.second.host_cache_vocab_size;
size_t embedding_size = item.second.embedding_size;
size_t vocab_size = item.second.vocab_size;
MS_LOG(INFO) << "Dump hash tables: " << param_name << " || " << cache_vocab_size << " || " << embedding_size
<< " || " << vocab_size << " || " << reinterpret_cast<void *>(item.second.device_address.addr)
<< " || " << reinterpret_cast<void *>(item.second.host_address.get());
float *output = new float[item.second.device_address.size / 4];
embedding_device_cache_->cache_->CopyDeviceMemToHost(output, item.second.device_address.addr,
item.second.device_address.size);
embedding_device_cache_->cache_->SynchronizeStream();
for (size_t i = 0; i < cache_vocab_size; i++) {
for (size_t j = 0; j < embedding_size; j++) {
std::cout << output[i * embedding_size + j] << " ";
MS_LOG(INFO) << "Hash table info:"
<< " embedding table name:" << param_name << ", vocab size:" << vocab_size
<< ", embedding size:" << embedding_size << ", device cache size:" << cache_vocab_size
<< ", host cache size:" << host_cache_vocab_size
<< ", device cache address:" << reinterpret_cast<void *>(item.second.device_address.addr)
<< ", host cache address:" << reinterpret_cast<void *>(item.second.host_address.get());
if (dump_device_tables) {
float *output = new float[item.second.device_address.size / 4];
embedding_device_cache_->cache_->CopyDeviceMemToHost(output, item.second.device_address.addr,
item.second.device_address.size);
embedding_device_cache_->cache_->SynchronizeStream();
for (size_t i = 0; i < cache_vocab_size; i++) {
for (size_t j = 0; j < embedding_size; j++) {
std::cout << output[i * embedding_size + j] << " ";
}
std::cout << std::endl;
}
std::cout << std::endl;
delete[] output;
}
std::cout << std::endl;
delete[] output;
}
}
} // namespace ps

@@ -125,7 +125,7 @@ class PsCacheManager {
bool initialized_ps_cache() const { return initialized_ps_cache_; }
void DoProcessData(uint32_t device_id, void *context);
void IncreaseGraphStep(const std::string &channel_name);
void DumpHashTables() const;
void DumpHashTables(bool dump_device_tables = false) const;
private:
PsCacheManager() = default;

@@ -14,6 +14,8 @@
# ============================================================================
"""embedding"""
import mindspore.common.dtype as mstype
import mindspore.context as context
from mindspore import log as logger
from mindspore.common.tensor import Tensor
from mindspore.ops import operations as P
from mindspore.ops import functional as F
@@ -195,6 +197,11 @@ class EmbeddingLookup(Cell):
+ str(target) + ', should be one of values in \'CPU\', \'DEVICE\'.')
if not sparse and target == 'CPU':
raise ValueError('When target is CPU, embedding_lookup must be sparse.')
enable_ps = context.get_ps_context("enable_ps")
if not enable_ps and vocab_cache_size > 0:
logger.warning("The configuration of 'vocab_cache_size' is valid only in parameter server trainning mode, "
"current mode is not parameter server trainning mode, so it will be ignored.")
vocab_cache_size = 0
if sparse:
self.gatherv2 = P.SparseGatherV2()
else:
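
For reference, a minimal usage sketch of the guard added above, assuming the public mindspore.context and mindspore.nn.EmbeddingLookup APIs; the vocabulary and cache sizes are illustrative. Outside parameter server mode, a non-zero vocab_cache_size is ignored and the new warning is logged.

import mindspore.context as context
import mindspore.nn as nn

# PS mode is off, so vocab_cache_size falls back to 0 and the warning
# added in this commit is emitted when the cell is constructed.
context.set_ps_context(enable_ps=False)
embedding = nn.EmbeddingLookup(vocab_size=8000, embedding_size=16,
                               vocab_cache_size=4000)

# With context.set_ps_context(enable_ps=True) and a parameter server cluster
# configured (MS_ROLE and related environment variables), vocab_cache_size
# keeps its value and enables the PS embedding cache.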
