|
|
|
@ -51,6 +51,8 @@
|
|
|
|
|
#include "backend/kernel_compiler/cpu/ps/sparse_apply_ftrl_ps_kernel.h"
|
|
|
|
|
#include "backend/kernel_compiler/cpu/ps/apply_momentum_ps_kernel.h"
|
|
|
|
|
#include "backend/kernel_compiler/cpu/ps/embedding_look_up_ps_kernel.h"
|
|
|
|
|
#include "ps/ps_cache/ps_data/ps_data_prefetch.h"
|
|
|
|
|
#include "ps/random_normal/random_normal.h"
|
|
|
|
|
|
|
|
|
|
namespace mindspore {
|
|
|
|
|
namespace ps {
|
|
|
|
@ -100,6 +102,7 @@ class ParameterServer {
|
|
|
|
|
void HandleCheckReadyForPush(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res);
|
|
|
|
|
void HandleCheckReadyForPull(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res);
|
|
|
|
|
void HandleEmbeddingLookup(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res);
|
|
|
|
|
void HandleUpdateEmbeddings(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res);
|
|
|
|
|
void HandleFinalize(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res);
|
|
|
|
|
|
|
|
|
|
ParameterServer *ps_;
|
|
|
|
@ -118,13 +121,15 @@ class ParameterServer {
|
|
|
|
|
void InitWeight(const Key &key, const WeightPtr &weight);
|
|
|
|
|
void InitGrad(const Key &key, const GradPtr &grad);
|
|
|
|
|
void InitEmbeddingTable(const Key &key,
|
|
|
|
|
const std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> &shapes);
|
|
|
|
|
const std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> &shapes,
|
|
|
|
|
const ParamInitInfo ¶m_init_info);
|
|
|
|
|
bool HasWeight(const Key &key);
|
|
|
|
|
void Finalize();
|
|
|
|
|
void UpdateWeights();
|
|
|
|
|
void AccumGrad(const Keys &key, const Values &values, const Lengths &lengths);
|
|
|
|
|
WeightPtr weight(const Key &key);
|
|
|
|
|
void DoEmbeddingLookup(Key key, const LookupIds &lookup_ids, ::ps::KVPairs<T> *res);
|
|
|
|
|
void UpdateEmbeddings(const Key &key, const LookupIds &lookup_ids, const Values &vals);
|
|
|
|
|
bool ReadyForUpdateWeights();
|
|
|
|
|
bool ReadyForPush(const Key &key);
|
|
|
|
|
bool ReadyForPull(const Key &key);
|
|
|
|
@ -193,6 +198,7 @@ void ParameterServer<T>::ServerHandler::Init() {
|
|
|
|
|
handlers_[kCheckReadyForPushCmd] = &ServerHandler::HandleCheckReadyForPush;
|
|
|
|
|
handlers_[kCheckReadyForPullCmd] = &ServerHandler::HandleCheckReadyForPull;
|
|
|
|
|
handlers_[kEmbeddingLookupCmd] = &ServerHandler::HandleEmbeddingLookup;
|
|
|
|
|
handlers_[kUpdateEmbeddingsCmd] = &ServerHandler::HandleUpdateEmbeddings;
|
|
|
|
|
handlers_[kFinalizeCmd] = &ServerHandler::HandleFinalize;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -302,7 +308,17 @@ void ParameterServer<T>::ServerHandler::HandleInitEmbeddings(const ::ps::KVMeta
|
|
|
|
|
for (int64_t k = 0; k < lens[2]; k++) {
|
|
|
|
|
output_shape->push_back(static_cast<size_t>(req_data.vals[index++]));
|
|
|
|
|
}
|
|
|
|
|
ps_->InitEmbeddingTable(key, shapes);
|
|
|
|
|
ParamInitInfo param_init_info;
|
|
|
|
|
if (ps::PsDataPrefetch::GetInstance().cache_enable()) {
|
|
|
|
|
param_init_info.param_type_ = static_cast<ParamType>(lens[3]);
|
|
|
|
|
if (param_init_info.param_type_ == kWeight) {
|
|
|
|
|
param_init_info.global_seed_ = static_cast<size_t>(lens[4]);
|
|
|
|
|
param_init_info.op_seed_ = static_cast<size_t>(lens[5]);
|
|
|
|
|
} else if (param_init_info.param_type_ == kAccumulation) {
|
|
|
|
|
param_init_info.init_val_ = req_data.vals[index];
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
ps_->InitEmbeddingTable(key, shapes, param_init_info);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
@ -338,6 +354,18 @@ void ParameterServer<T>::ServerHandler::HandleEmbeddingLookup(const ::ps::KVMeta
|
|
|
|
|
ps_->DoEmbeddingLookup(key, req_data.keys.segment(1, req_data.keys.size()), res);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
void ParameterServer<T>::ServerHandler::HandleUpdateEmbeddings(const ::ps::KVMeta &req_meta,
|
|
|
|
|
const ::ps::KVPairs<T> &req_data,
|
|
|
|
|
::ps::KVPairs<T> *res) {
|
|
|
|
|
std::unique_lock<std::mutex> lock(ps_->mutex());
|
|
|
|
|
MS_EXCEPTION_IF_NULL(res);
|
|
|
|
|
const Key &key = req_data.keys[0];
|
|
|
|
|
const LookupIds &lookup_ids = req_data.keys.segment(1, req_data.keys.size());
|
|
|
|
|
const Values &update_vals = req_data.vals;
|
|
|
|
|
ps_->UpdateEmbeddings(key, lookup_ids, update_vals);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
void ParameterServer<T>::ServerHandler::HandleFinalize(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data,
|
|
|
|
|
::ps::KVPairs<T> *res) {
|
|
|
|
@ -476,7 +504,8 @@ void ParameterServer<T>::InitGrad(const Key &key, const GradPtr &grad) {
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
void ParameterServer<T>::InitEmbeddingTable(
|
|
|
|
|
const Key &key, const std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> &shapes) {
|
|
|
|
|
const Key &key, const std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> &shapes,
|
|
|
|
|
const ParamInitInfo ¶m_init_info) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(shapes);
|
|
|
|
|
if (weights_.count(key) == 0) {
|
|
|
|
|
std::shared_ptr<PServerKernel> lookup =
|
|
|
|
@ -493,9 +522,19 @@ void ParameterServer<T>::InitEmbeddingTable(
|
|
|
|
|
T *embedding_data = embedding->data();
|
|
|
|
|
std::default_random_engine engine;
|
|
|
|
|
std::normal_distribution<float> random(0, 0.01);
|
|
|
|
|
if (ps::PsDataPrefetch::GetInstance().cache_enable()) {
|
|
|
|
|
if (param_init_info.param_type_ == kWeight) {
|
|
|
|
|
InitRandomNormal(0, 0.01, input_shapes, param_init_info.global_seed_, param_init_info.op_seed_, embedding_data);
|
|
|
|
|
} else if (param_init_info.param_type_ == kAccumulation) {
|
|
|
|
|
for (size_t i = 0; i < total_dims; i++) {
|
|
|
|
|
embedding_data[i] = param_init_info.init_val_;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
for (size_t i = 0; i < total_dims; i++) {
|
|
|
|
|
embedding_data[i] = random(engine);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
weights_[key] = embedding;
|
|
|
|
|
tokens_[key] = 0;
|
|
|
|
|
is_embedding_[key] = true;
|
|
|
|
@ -673,6 +712,23 @@ void ParameterServer<T>::DoEmbeddingLookup(Key key, const LookupIds &lookup_ids,
|
|
|
|
|
res->lens.push_back(res->vals.size());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
void ParameterServer<T>::UpdateEmbeddings(const Key &key, const LookupIds &lookup_ids, const Values &vals) {
|
|
|
|
|
if (weights_.count(key) == 0) {
|
|
|
|
|
MS_LOG(ERROR) << "Invalid embedding table key " << key;
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
if (embedding_lookup_ops_.count(key) == 0) {
|
|
|
|
|
MS_LOG(ERROR) << "Invalid embedding lookup op key " << key;
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
WeightPtr table_ptr = weights_[key];
|
|
|
|
|
MS_EXCEPTION_IF_NULL(table_ptr);
|
|
|
|
|
std::shared_ptr<PServerKernel> table_lookup_op = embedding_lookup_ops_[key];
|
|
|
|
|
MS_EXCEPTION_IF_NULL(table_lookup_op);
|
|
|
|
|
table_lookup_op->UpdateEmbeddings(table_ptr->data(), lookup_ids.data(), vals.data(), lookup_ids.size());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
inline bool ParameterServer<T>::ReadyForUpdateWeights() {
|
|
|
|
|
return grads_accum_counter_.size() > 0 && grad_accum_count_ == grads_accum_counter_.size();
|
|
|
|
|