|
|
@ -91,6 +91,8 @@ class ParameterServer {
|
|
|
|
::ps::KVPairs<T> *res);
|
|
|
|
::ps::KVPairs<T> *res);
|
|
|
|
void HandleInitInputsShape(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res);
|
|
|
|
void HandleInitInputsShape(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res);
|
|
|
|
void HandleInitEmbeddings(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res);
|
|
|
|
void HandleInitEmbeddings(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res);
|
|
|
|
|
|
|
|
void HandleCheckReadyForPush(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res);
|
|
|
|
|
|
|
|
void HandleCheckReadyForPull(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res);
|
|
|
|
void HandleEmbeddingLookup(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res);
|
|
|
|
void HandleEmbeddingLookup(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res);
|
|
|
|
void HandleFinalize(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res);
|
|
|
|
void HandleFinalize(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res);
|
|
|
|
|
|
|
|
|
|
|
@ -98,6 +100,9 @@ class ParameterServer {
|
|
|
|
typedef void (ServerHandler::*RequestHandler)(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data,
|
|
|
|
typedef void (ServerHandler::*RequestHandler)(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data,
|
|
|
|
::ps::KVPairs<T> *res);
|
|
|
|
::ps::KVPairs<T> *res);
|
|
|
|
std::unordered_map<int, RequestHandler> handlers_;
|
|
|
|
std::unordered_map<int, RequestHandler> handlers_;
|
|
|
|
|
|
|
|
std::unordered_map<Key, bool> init_weights_;
|
|
|
|
|
|
|
|
std::unordered_map<Key, bool> init_weight_to_optim_;
|
|
|
|
|
|
|
|
std::unordered_map<Key, bool> init_optim_info_;
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
bool Init(const FuncGraphPtr &func_graph);
|
|
|
|
bool Init(const FuncGraphPtr &func_graph);
|
|
|
@ -115,9 +120,11 @@ class ParameterServer {
|
|
|
|
void DoEmbeddingLookup(Key key, const LookupIds &lookup_ids, ::ps::KVPairs<T> *res);
|
|
|
|
void DoEmbeddingLookup(Key key, const LookupIds &lookup_ids, ::ps::KVPairs<T> *res);
|
|
|
|
int SumOfShapes(const std::vector<int> &shapes) const;
|
|
|
|
int SumOfShapes(const std::vector<int> &shapes) const;
|
|
|
|
bool ReadyForUpdateWeights();
|
|
|
|
bool ReadyForUpdateWeights();
|
|
|
|
bool ReadyForAccumGrads();
|
|
|
|
bool ReadyForPush(const Key &key);
|
|
|
|
|
|
|
|
bool ReadyForPull(const Key &key);
|
|
|
|
void ResetGradAccumCount();
|
|
|
|
void ResetGradAccumCount();
|
|
|
|
const CNodePtr GetCNode(const std::string &name) const;
|
|
|
|
const CNodePtr GetCNode(const std::string &name) const;
|
|
|
|
|
|
|
|
std::mutex &mutex();
|
|
|
|
|
|
|
|
|
|
|
|
size_t pserver_num_;
|
|
|
|
size_t pserver_num_;
|
|
|
|
size_t worker_num_;
|
|
|
|
size_t worker_num_;
|
|
|
@ -136,13 +143,14 @@ class ParameterServer {
|
|
|
|
std::unordered_map<Key, std::string> weight_key_to_optims_;
|
|
|
|
std::unordered_map<Key, std::string> weight_key_to_optims_;
|
|
|
|
std::unordered_map<Key, std::string> weight_key_to_optim_op_;
|
|
|
|
std::unordered_map<Key, std::string> weight_key_to_optim_op_;
|
|
|
|
std::unordered_map<Key, WeightPtr> weights_;
|
|
|
|
std::unordered_map<Key, WeightPtr> weights_;
|
|
|
|
|
|
|
|
std::unordered_map<Key, bool> is_embedding_;
|
|
|
|
std::unordered_map<Key, WeightPtr> grads_;
|
|
|
|
std::unordered_map<Key, WeightPtr> grads_;
|
|
|
|
std::unordered_map<Key, size_t> grads_accum_counter_;
|
|
|
|
std::unordered_map<Key, size_t> grads_accum_counter_;
|
|
|
|
std::unordered_map<Key, std::shared_ptr<PServerKernel>> embedding_lookup_ops_;
|
|
|
|
std::unordered_map<Key, std::shared_ptr<PServerKernel>> embedding_lookup_ops_;
|
|
|
|
|
|
|
|
std::unordered_map<Key, uint64_t> tokens_;
|
|
|
|
|
|
|
|
|
|
|
|
std::mutex mutex_;
|
|
|
|
std::mutex mutex_;
|
|
|
|
std::condition_variable apply_grads_cv_;
|
|
|
|
std::condition_variable apply_grads_cv_;
|
|
|
|
std::condition_variable accum_grads_cv_;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
std::unique_ptr<std::thread> thread_;
|
|
|
|
std::unique_ptr<std::thread> thread_;
|
|
|
|
|
|
|
|
|
|
|
@ -171,6 +179,8 @@ void ParameterServer<T>::ServerHandler::Init() {
|
|
|
|
handlers_[kInitWeightToOptimIdCmd] = &ServerHandler::HandleInitWeightToOptimId;
|
|
|
|
handlers_[kInitWeightToOptimIdCmd] = &ServerHandler::HandleInitWeightToOptimId;
|
|
|
|
handlers_[kInitOptimInputsShapeCmd] = &ServerHandler::HandleInitInputsShape;
|
|
|
|
handlers_[kInitOptimInputsShapeCmd] = &ServerHandler::HandleInitInputsShape;
|
|
|
|
handlers_[kInitEmbeddingsCmd] = &ServerHandler::HandleInitEmbeddings;
|
|
|
|
handlers_[kInitEmbeddingsCmd] = &ServerHandler::HandleInitEmbeddings;
|
|
|
|
|
|
|
|
handlers_[kCheckReadyForPushCmd] = &ServerHandler::HandleCheckReadyForPush;
|
|
|
|
|
|
|
|
handlers_[kCheckReadyForPullCmd] = &ServerHandler::HandleCheckReadyForPull;
|
|
|
|
handlers_[kEmbeddingLookupCmd] = &ServerHandler::HandleEmbeddingLookup;
|
|
|
|
handlers_[kEmbeddingLookupCmd] = &ServerHandler::HandleEmbeddingLookup;
|
|
|
|
handlers_[kFinalizeCmd] = &ServerHandler::HandleFinalize;
|
|
|
|
handlers_[kFinalizeCmd] = &ServerHandler::HandleFinalize;
|
|
|
|
}
|
|
|
|
}
|
|
|
@ -192,11 +202,17 @@ void ParameterServer<T>::ServerHandler::HandlePullReq(const ::ps::KVMeta &req_me
|
|
|
|
template <typename T>
|
|
|
|
template <typename T>
|
|
|
|
void ParameterServer<T>::ServerHandler::HandleInitWeights(const ::ps::KVMeta &req_meta,
|
|
|
|
void ParameterServer<T>::ServerHandler::HandleInitWeights(const ::ps::KVMeta &req_meta,
|
|
|
|
const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res) {
|
|
|
|
const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res) {
|
|
|
|
|
|
|
|
std::unique_lock<std::mutex> lock(ps_->mutex());
|
|
|
|
size_t key_num = req_data.keys.size();
|
|
|
|
size_t key_num = req_data.keys.size();
|
|
|
|
T *data_ptr = req_data.vals.data();
|
|
|
|
T *data_ptr = req_data.vals.data();
|
|
|
|
size_t pos = 0;
|
|
|
|
size_t pos = 0;
|
|
|
|
for (size_t i = 0; i < key_num; i++) {
|
|
|
|
for (size_t i = 0; i < key_num; i++) {
|
|
|
|
Key key = req_data.keys[i];
|
|
|
|
Key key = req_data.keys[i];
|
|
|
|
|
|
|
|
if (init_weights_[key]) {
|
|
|
|
|
|
|
|
continue;
|
|
|
|
|
|
|
|
} else {
|
|
|
|
|
|
|
|
init_weights_[key] = true;
|
|
|
|
|
|
|
|
}
|
|
|
|
size_t data_len = req_data.lens.size() != key_num ? req_data.vals.size() / key_num : req_data.lens[i];
|
|
|
|
size_t data_len = req_data.lens.size() != key_num ? req_data.vals.size() / key_num : req_data.lens[i];
|
|
|
|
|
|
|
|
|
|
|
|
WeightPtr weight_ptr = std::make_shared<::ps::SArray<T>>();
|
|
|
|
WeightPtr weight_ptr = std::make_shared<::ps::SArray<T>>();
|
|
|
@ -213,10 +229,16 @@ template <typename T>
|
|
|
|
void ParameterServer<T>::ServerHandler::HandleInitWeightToOptimId(const ::ps::KVMeta &req_meta,
|
|
|
|
void ParameterServer<T>::ServerHandler::HandleInitWeightToOptimId(const ::ps::KVMeta &req_meta,
|
|
|
|
const ::ps::KVPairs<T> &req_data,
|
|
|
|
const ::ps::KVPairs<T> &req_data,
|
|
|
|
::ps::KVPairs<T> *res) {
|
|
|
|
::ps::KVPairs<T> *res) {
|
|
|
|
|
|
|
|
std::unique_lock<std::mutex> lock(ps_->mutex());
|
|
|
|
size_t key_num = req_data.keys.size();
|
|
|
|
size_t key_num = req_data.keys.size();
|
|
|
|
for (size_t i = 0; i < key_num; i++) {
|
|
|
|
for (size_t i = 0; i < key_num; i++) {
|
|
|
|
Key key = req_data.keys[i];
|
|
|
|
Key key = req_data.keys[i];
|
|
|
|
T val = req_data.vals[i];
|
|
|
|
T val = req_data.vals[i];
|
|
|
|
|
|
|
|
if (init_weight_to_optim_[key]) {
|
|
|
|
|
|
|
|
continue;
|
|
|
|
|
|
|
|
} else {
|
|
|
|
|
|
|
|
init_weight_to_optim_[key] = true;
|
|
|
|
|
|
|
|
}
|
|
|
|
ps_->InitWeightKeyToOptims(key, val);
|
|
|
|
ps_->InitWeightKeyToOptims(key, val);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
@ -224,12 +246,26 @@ void ParameterServer<T>::ServerHandler::HandleInitWeightToOptimId(const ::ps::KV
|
|
|
|
template <typename T>
|
|
|
|
template <typename T>
|
|
|
|
void ParameterServer<T>::ServerHandler::HandleInitInputsShape(const ::ps::KVMeta &req_meta,
|
|
|
|
void ParameterServer<T>::ServerHandler::HandleInitInputsShape(const ::ps::KVMeta &req_meta,
|
|
|
|
const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res) {
|
|
|
|
const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res) {
|
|
|
|
|
|
|
|
std::unique_lock<std::mutex> lock(ps_->mutex());
|
|
|
|
|
|
|
|
const Key &key = req_data.keys[0];
|
|
|
|
|
|
|
|
if (init_optim_info_[key]) {
|
|
|
|
|
|
|
|
return;
|
|
|
|
|
|
|
|
} else {
|
|
|
|
|
|
|
|
init_optim_info_[key] = true;
|
|
|
|
|
|
|
|
}
|
|
|
|
ps_->InitOptimInputsShape(req_data.keys, req_data.vals, req_data.lens);
|
|
|
|
ps_->InitOptimInputsShape(req_data.keys, req_data.vals, req_data.lens);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
template <typename T>
|
|
|
|
void ParameterServer<T>::ServerHandler::HandleInitEmbeddings(const ::ps::KVMeta &req_meta,
|
|
|
|
void ParameterServer<T>::ServerHandler::HandleInitEmbeddings(const ::ps::KVMeta &req_meta,
|
|
|
|
const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res) {
|
|
|
|
const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res) {
|
|
|
|
|
|
|
|
std::unique_lock<std::mutex> lock(ps_->mutex());
|
|
|
|
|
|
|
|
const Key &key = req_data.keys[0];
|
|
|
|
|
|
|
|
if (init_weights_[key]) {
|
|
|
|
|
|
|
|
return;
|
|
|
|
|
|
|
|
} else {
|
|
|
|
|
|
|
|
init_weights_[key] = true;
|
|
|
|
|
|
|
|
}
|
|
|
|
std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> shapes =
|
|
|
|
std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> shapes =
|
|
|
|
std::make_shared<std::vector<std::shared_ptr<std::vector<size_t>>>>();
|
|
|
|
std::make_shared<std::vector<std::shared_ptr<std::vector<size_t>>>>();
|
|
|
|
std::shared_ptr<std::vector<size_t>> input_shape = std::make_shared<std::vector<size_t>>();
|
|
|
|
std::shared_ptr<std::vector<size_t>> input_shape = std::make_shared<std::vector<size_t>>();
|
|
|
@ -239,7 +275,6 @@ void ParameterServer<T>::ServerHandler::HandleInitEmbeddings(const ::ps::KVMeta
|
|
|
|
shapes->push_back(indices_shape);
|
|
|
|
shapes->push_back(indices_shape);
|
|
|
|
shapes->push_back(output_shape);
|
|
|
|
shapes->push_back(output_shape);
|
|
|
|
|
|
|
|
|
|
|
|
const Key &key = req_data.keys[0];
|
|
|
|
|
|
|
|
const Lengths &lens = req_data.lens;
|
|
|
|
const Lengths &lens = req_data.lens;
|
|
|
|
size_t index = 0;
|
|
|
|
size_t index = 0;
|
|
|
|
for (int i = 0; i < lens[0]; i++) {
|
|
|
|
for (int i = 0; i < lens[0]; i++) {
|
|
|
@ -254,6 +289,26 @@ void ParameterServer<T>::ServerHandler::HandleInitEmbeddings(const ::ps::KVMeta
|
|
|
|
ps_->InitEmbeddingTable(key, shapes);
|
|
|
|
ps_->InitEmbeddingTable(key, shapes);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
|
|
|
void ParameterServer<T>::ServerHandler::HandleCheckReadyForPush(const ::ps::KVMeta &req_meta,
|
|
|
|
|
|
|
|
const ::ps::KVPairs<T> &req_data,
|
|
|
|
|
|
|
|
::ps::KVPairs<T> *res) {
|
|
|
|
|
|
|
|
const Key &key = req_data.keys[0];
|
|
|
|
|
|
|
|
bool ready = ps_->ReadyForPush(key);
|
|
|
|
|
|
|
|
res->keys.push_back(key);
|
|
|
|
|
|
|
|
res->vals.push_back(ready);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
|
|
|
void ParameterServer<T>::ServerHandler::HandleCheckReadyForPull(const ::ps::KVMeta &req_meta,
|
|
|
|
|
|
|
|
const ::ps::KVPairs<T> &req_data,
|
|
|
|
|
|
|
|
::ps::KVPairs<T> *res) {
|
|
|
|
|
|
|
|
const Key &key = req_data.keys[0];
|
|
|
|
|
|
|
|
bool ready = ps_->ReadyForPull(key);
|
|
|
|
|
|
|
|
res->keys.push_back(key);
|
|
|
|
|
|
|
|
res->vals.push_back(ready);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
template <typename T>
|
|
|
|
void ParameterServer<T>::ServerHandler::HandleEmbeddingLookup(const ::ps::KVMeta &req_meta,
|
|
|
|
void ParameterServer<T>::ServerHandler::HandleEmbeddingLookup(const ::ps::KVMeta &req_meta,
|
|
|
|
const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res) {
|
|
|
|
const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res) {
|
|
|
@ -365,6 +420,8 @@ void ParameterServer<T>::InitWeight(const Key &key, const WeightPtr &weight) {
|
|
|
|
MS_LOG(INFO) << "Initializing weight for key " << key;
|
|
|
|
MS_LOG(INFO) << "Initializing weight for key " << key;
|
|
|
|
if (weights_.count(key) == 0) {
|
|
|
|
if (weights_.count(key) == 0) {
|
|
|
|
weights_[key] = weight;
|
|
|
|
weights_[key] = weight;
|
|
|
|
|
|
|
|
tokens_[key] = 0;
|
|
|
|
|
|
|
|
is_embedding_[key] = false;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
@ -399,6 +456,8 @@ void ParameterServer<T>::InitEmbeddingTable(
|
|
|
|
embedding_data[i] = random(engine);
|
|
|
|
embedding_data[i] = random(engine);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
weights_[key] = embedding;
|
|
|
|
weights_[key] = embedding;
|
|
|
|
|
|
|
|
tokens_[key] = 0;
|
|
|
|
|
|
|
|
is_embedding_[key] = true;
|
|
|
|
|
|
|
|
|
|
|
|
grads_accum_counter_[key] = 0;
|
|
|
|
grads_accum_counter_[key] = 0;
|
|
|
|
}
|
|
|
|
}
|
|
|
@ -439,17 +498,17 @@ void ParameterServer<T>::UpdateWeights() {
|
|
|
|
optim_info->ComputeMean(worker_num_);
|
|
|
|
optim_info->ComputeMean(worker_num_);
|
|
|
|
optimizer->Execute(inputs, workspaces, outputs);
|
|
|
|
optimizer->Execute(inputs, workspaces, outputs);
|
|
|
|
optim_info->Reset();
|
|
|
|
optim_info->Reset();
|
|
|
|
|
|
|
|
if (!is_embedding_[key]) {
|
|
|
|
|
|
|
|
tokens_[key] = worker_num_;
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
ResetGradAccumCount();
|
|
|
|
ResetGradAccumCount();
|
|
|
|
accum_grads_cv_.notify_all();
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
template <typename T>
|
|
|
|
void ParameterServer<T>::AccumGrad(const Keys &keys, const Values &values, const Lengths &lengths) {
|
|
|
|
void ParameterServer<T>::AccumGrad(const Keys &keys, const Values &values, const Lengths &lengths) {
|
|
|
|
std::unique_lock<std::mutex> lock(mutex_);
|
|
|
|
std::unique_lock<std::mutex> lock(mutex_);
|
|
|
|
accum_grads_cv_.wait(lock, [this] { return this->ReadyForAccumGrads(); });
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
const Key &key = keys[0];
|
|
|
|
const Key &key = keys[0];
|
|
|
|
std::shared_ptr<OptimizerInfo> optim_info = optim_infos_[key];
|
|
|
|
std::shared_ptr<OptimizerInfo> optim_info = optim_infos_[key];
|
|
|
|
|
|
|
|
|
|
|
@ -482,14 +541,13 @@ void ParameterServer<T>::AccumGrad(const Keys &keys, const Values &values, const
|
|
|
|
template <typename T>
|
|
|
|
template <typename T>
|
|
|
|
WeightPtr ParameterServer<T>::weight(const Key &key) {
|
|
|
|
WeightPtr ParameterServer<T>::weight(const Key &key) {
|
|
|
|
std::unique_lock<std::mutex> lock(mutex_);
|
|
|
|
std::unique_lock<std::mutex> lock(mutex_);
|
|
|
|
|
|
|
|
|
|
|
|
if (weights_.count(key) == 0) {
|
|
|
|
if (weights_.count(key) == 0) {
|
|
|
|
MS_LOG(ERROR) << "Invalid weight key " << key;
|
|
|
|
MS_LOG(EXCEPTION) << "Invalid weight key " << key;
|
|
|
|
return nullptr;
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
WeightPtr weight_ptr = weights_[key];
|
|
|
|
WeightPtr weight_ptr = weights_[key];
|
|
|
|
WeightPtr copy_weight_ptr = std::make_shared<::ps::SArray<T>>(weight_ptr->size(), 0);
|
|
|
|
WeightPtr copy_weight_ptr = std::make_shared<::ps::SArray<T>>(weight_ptr->size(), 0);
|
|
|
|
copy_weight_ptr->CopyFrom(weight_ptr->data(), weight_ptr->size());
|
|
|
|
copy_weight_ptr->CopyFrom(weight_ptr->data(), weight_ptr->size());
|
|
|
|
|
|
|
|
tokens_[key] -= 1;
|
|
|
|
return copy_weight_ptr;
|
|
|
|
return copy_weight_ptr;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
@ -560,12 +618,22 @@ inline bool ParameterServer<T>::ReadyForUpdateWeights() {
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
template <typename T>
|
|
|
|
inline bool ParameterServer<T>::ReadyForAccumGrads() {
|
|
|
|
inline bool ParameterServer<T>::ReadyForPush(const Key &key) {
|
|
|
|
|
|
|
|
std::unique_lock<std::mutex> lock(mutex_);
|
|
|
|
if (weights_.empty()) {
|
|
|
|
if (weights_.empty()) {
|
|
|
|
MS_LOG(EXCEPTION) << "The weights in server is empty. Many reasons could cause this: 1.The Worker didn't send "
|
|
|
|
MS_LOG(EXCEPTION) << "The weights in server is empty. Many reasons could cause this: 1.The Worker didn't send "
|
|
|
|
"kInitWeightsCmd command. 2.The Server failed to initialize weights.";
|
|
|
|
"kInitWeightsCmd command. 2.The Server failed to initialize weights.";
|
|
|
|
}
|
|
|
|
}
|
|
|
|
return grad_accum_count_ < weights_.size();
|
|
|
|
return grad_accum_count_ < weights_.size() && tokens_[key] <= 0;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
|
|
|
inline bool ParameterServer<T>::ReadyForPull(const Key &key) {
|
|
|
|
|
|
|
|
std::unique_lock<std::mutex> lock(mutex_);
|
|
|
|
|
|
|
|
if (tokens_.count(key) == 0 || weights_[key] == 0) {
|
|
|
|
|
|
|
|
MS_LOG(EXCEPTION) << "Invalid weight key " << key;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
return tokens_[key] > 0;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
template <typename T>
|
|
|
@ -576,6 +644,11 @@ inline void ParameterServer<T>::ResetGradAccumCount() {
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
|
|
|
inline std::mutex &ParameterServer<T>::mutex() {
|
|
|
|
|
|
|
|
return mutex_;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
template <typename T>
|
|
|
|
void ParameterServer<T>::Run(const FuncGraphPtr &func_graph) {
|
|
|
|
void ParameterServer<T>::Run(const FuncGraphPtr &func_graph) {
|
|
|
|
::ps::Start(0);
|
|
|
|
::ps::Start(0);
|
|
|
|