|
|
|
@ -73,6 +73,8 @@ void PsCacheManager::InsertWeightInitInfo(const std::string ¶m_name, size_t
|
|
|
|
|
if (hash_table_info.param_init_info_.param_type_ != kUnKnown) {
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
MS_LOG(INFO) << "Insert embedding table init info:" << param_name << ", global seed:" << global_seed
|
|
|
|
|
<< ", op seed:" << op_seed;
|
|
|
|
|
hash_table_info.param_init_info_.param_type_ = kWeight;
|
|
|
|
|
hash_table_info.param_init_info_.global_seed_ = global_seed;
|
|
|
|
|
hash_table_info.param_init_info_.op_seed_ = op_seed;
|
|
|
|
@ -91,6 +93,7 @@ void PsCacheManager::InsertAccumuInitInfo(const std::string ¶m_name, float i
|
|
|
|
|
if (hash_table_info.param_init_info_.param_type_ != kUnKnown) {
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
MS_LOG(INFO) << "Insert accumulation init info:" << param_name << ", init value:" << init_val;
|
|
|
|
|
hash_table_info.param_init_info_.param_type_ = kAccumulation;
|
|
|
|
|
hash_table_info.param_init_info_.init_val_ = init_val;
|
|
|
|
|
if (CheckFinishInsertInitInfo()) {
|
|
|
|
@ -107,6 +110,7 @@ bool PsCacheManager::CheckFinishInsertInitInfo() const {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
MS_LOG(INFO) << "Finish inserting embedding table init info.";
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -141,6 +145,7 @@ void PsCacheManager::Initialize() {
|
|
|
|
|
AddEmbeddingTable();
|
|
|
|
|
AllocMemForHashTable();
|
|
|
|
|
SetLocalIdRank();
|
|
|
|
|
DumpHashTables();
|
|
|
|
|
initialized_ps_cache_ = true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -155,6 +160,7 @@ void PsCacheManager::AddEmbeddingTable() const {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void PsCacheManager::InitParameterServer() {
|
|
|
|
|
MS_LOG(INFO) << "Embedding table init begin:" << finish_insert_init_info_;
|
|
|
|
|
std::unique_lock<std::mutex> locker(data_mutex_);
|
|
|
|
|
insert_init_info_.wait(locker, [this] { return finish_insert_init_info_ == true; });
|
|
|
|
|
|
|
|
|
@ -181,6 +187,7 @@ void PsCacheManager::InitParameterServer() {
|
|
|
|
|
|
|
|
|
|
finish_init_parameter_server_ = true;
|
|
|
|
|
data_prase_.notify_one();
|
|
|
|
|
MS_LOG(INFO) << "Embedding table init end.";
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void PsCacheManager::AllocMemForHashTable() {
|
|
|
|
@ -237,10 +244,14 @@ void PsCacheManager::set_channel_name(const std::string channel_name) {
|
|
|
|
|
|
|
|
|
|
void PsCacheManager::IncreaseStep() {
|
|
|
|
|
if (data_step_ >= UINT64_MAX) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "The data step (" << data_step_ << ") << will exceed the maximum value of uint64_t.";
|
|
|
|
|
MS_LOG(EXCEPTION) << "The data step (" << data_step_ << ") << will exceed the maximum value of uint64_t.";
|
|
|
|
|
}
|
|
|
|
|
data_step_++;
|
|
|
|
|
set_current_graph_step();
|
|
|
|
|
if (graph_running_step_ > data_step_) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "The graph running step (" << graph_running_step_ << ") << exceed the data step ("
|
|
|
|
|
<< data_step_ << ").";
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void PsCacheManager::IncreaseGraphStep(const std::string &channel_name) {
|
|
|
|
@ -248,8 +259,10 @@ void PsCacheManager::IncreaseGraphStep(const std::string &channel_name) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "The graph step(" << graph_step_ << ") << will exceed the maximum value of uint64_t.";
|
|
|
|
|
}
|
|
|
|
|
if (graph_step_ == 0) {
|
|
|
|
|
MS_LOG(INFO) << "Graph running waiting embedding table init begin:" << finish_init_parameter_server_;
|
|
|
|
|
std::unique_lock<std::mutex> locker(data_mutex_);
|
|
|
|
|
data_prase_.wait(locker, [this] { return finish_init_parameter_server_ == true; });
|
|
|
|
|
MS_LOG(INFO) << "Graph running waiting embedding table init end.";
|
|
|
|
|
}
|
|
|
|
|
graph_step_++;
|
|
|
|
|
set_channel_name(channel_name);
|
|
|
|
@ -755,29 +768,35 @@ void PsCacheManager::UpdataEmbeddingTable(const ::ps::SArray<float> &swap_out_da
|
|
|
|
|
worker.UpdateEmbeddingTable({key}, lookup_ids, swap_out_data);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void PsCacheManager::DumpHashTables() const {
|
|
|
|
|
void PsCacheManager::DumpHashTables(bool dump_device_tables) const {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(embedding_device_cache_);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(embedding_device_cache_->cache_);
|
|
|
|
|
for (const auto &item : hash_tables_) {
|
|
|
|
|
const auto ¶m_name = item.first;
|
|
|
|
|
size_t cache_vocab_size = item.second.cache_vocab_size;
|
|
|
|
|
size_t host_cache_vocab_size = item.second.host_cache_vocab_size;
|
|
|
|
|
size_t embedding_size = item.second.embedding_size;
|
|
|
|
|
size_t vocab_size = item.second.vocab_size;
|
|
|
|
|
MS_LOG(INFO) << "Dump hash tables: " << param_name << " || " << cache_vocab_size << " || " << embedding_size
|
|
|
|
|
<< " || " << vocab_size << " || " << reinterpret_cast<void *>(item.second.device_address.addr)
|
|
|
|
|
<< " || " << reinterpret_cast<void *>(item.second.host_address.get());
|
|
|
|
|
float *output = new float[item.second.device_address.size / 4];
|
|
|
|
|
embedding_device_cache_->cache_->CopyDeviceMemToHost(output, item.second.device_address.addr,
|
|
|
|
|
item.second.device_address.size);
|
|
|
|
|
embedding_device_cache_->cache_->SynchronizeStream();
|
|
|
|
|
for (size_t i = 0; i < cache_vocab_size; i++) {
|
|
|
|
|
for (size_t j = 0; j < embedding_size; j++) {
|
|
|
|
|
std::cout << output[i * embedding_size + j] << " ";
|
|
|
|
|
MS_LOG(INFO) << "Hash table info:"
|
|
|
|
|
<< " embedding table name:" << param_name << ", vocab size:" << vocab_size
|
|
|
|
|
<< ", embedding size:" << embedding_size << ", device cache size:" << cache_vocab_size
|
|
|
|
|
<< ", host cache size:" << host_cache_vocab_size
|
|
|
|
|
<< ", device cache address:" << reinterpret_cast<void *>(item.second.device_address.addr)
|
|
|
|
|
<< ", host cache address:" << reinterpret_cast<void *>(item.second.host_address.get());
|
|
|
|
|
if (dump_device_tables) {
|
|
|
|
|
float *output = new float[item.second.device_address.size / 4];
|
|
|
|
|
embedding_device_cache_->cache_->CopyDeviceMemToHost(output, item.second.device_address.addr,
|
|
|
|
|
item.second.device_address.size);
|
|
|
|
|
embedding_device_cache_->cache_->SynchronizeStream();
|
|
|
|
|
for (size_t i = 0; i < cache_vocab_size; i++) {
|
|
|
|
|
for (size_t j = 0; j < embedding_size; j++) {
|
|
|
|
|
std::cout << output[i * embedding_size + j] << " ";
|
|
|
|
|
}
|
|
|
|
|
std::cout << std::endl;
|
|
|
|
|
}
|
|
|
|
|
std::cout << std::endl;
|
|
|
|
|
delete[] output;
|
|
|
|
|
}
|
|
|
|
|
std::cout << std::endl;
|
|
|
|
|
delete[] output;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
} // namespace ps
|
|
|
|
|