@@ -30,7 +30,6 @@
#include <random>
#include "ir/func_graph.h"
#include "backend/session/session_basic.h"
#include "backend/session/kernel_graph.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "backend/session/session_factory.h"
#include "frontend/parallel/ps/common.h"
@@ -70,24 +69,32 @@ class ParameterServer {
ps_(new ::ps::KVServer<T>(0)),
handler_(nullptr),
func_graph_(nullptr),
kernel_graph_(nullptr),
sess_(nullptr),
thread_(nullptr) {}
~ParameterServer() = default;
ParameterServer(const ParameterServer &) = delete;
ParameterServer &operator=(const ParameterServer &) = delete;

struct ServerHandler {
class ServerHandler {
public:
explicit ServerHandler(ParameterServer *ps) : ps_(ps) {}
void Init();
void operator()(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data, ::ps::KVServer<T> *server);
void HandlePushReq(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data);

private:
void HandlePushReq(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res);
void HandlePullReq(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res);
void HandleInitWeights(const ::ps::KVPairs<T> &req_data);
void HandleInitWeightToOptimId(const ::ps::KVPairs<T> &req_data);
void HandleInitInputsShape(const ::ps::KVPairs<T> &req_data);
void HandleInitEmbeddings(const ::ps::KVPairs<T> &req_data);
void HandleInitWeights(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res);
void HandleInitWeightToOptimId(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data,
::ps::KVPairs<T> *res);
void HandleInitInputsShape(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res);
void HandleInitEmbeddings(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res);
void HandleEmbeddingLookup(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res);

ParameterServer *ps_;
typedef void (ServerHandler::*RequestHandler)(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data,
::ps::KVPairs<T> *res);
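// RequestHandler is a pointer to a ServerHandler member function; operator() dispatches a
// registered command as:
//   if (handlers_.count(req_meta.cmd) > 0) {
//     (this->*handlers_[req_meta.cmd])(req_meta, req_data, &res);
//   }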
std::unordered_map<int, RequestHandler> handlers_;
};

bool Init(const FuncGraphPtr &func_graph);
@@ -103,7 +110,6 @@ class ParameterServer {
WeightPtr weight(const Key &key);
void DoEmbeddingLookup(Key key, const LookupIds &lookup_ids, ::ps::KVPairs<T> *res);
int SumOfShapes(const std::vector<int> &shapes) const;
size_t PreComputeCapacity(const Keys &keys, const Lengths &lens);
bool ReadyForUpdateWeights();
bool ReadyForAccumGrads();
void ResetGradAccumCount();
@@ -115,7 +121,6 @@ class ParameterServer {
std::unique_ptr<::ps::KVServer<T>> ps_;
std::unique_ptr<ServerHandler> handler_;
FuncGraphPtr func_graph_;
std::shared_ptr<session::KernelGraph> kernel_graph_;
std::shared_ptr<session::SessionBasic> sess_;

std::unordered_map<Key, std::shared_ptr<PServerKernel>> optimizers_;
@@ -126,12 +131,7 @@ class ParameterServer {
std::unordered_map<Key, WeightPtr> weights_;
std::unordered_map<Key, WeightPtr> grads_;
std::unordered_map<Key, size_t> grads_accum_counter_;
// std::unordered_map<Key, EmbeddingTablePtr> embeddings_;
std::unordered_map<Key, std::shared_ptr<PServerKernel>> embedding_lookup_ops_;
std::unordered_map<Key, size_t> embedding_row_lens_;

T learning_rate_;
T momentum_;

std::mutex mutex_;
std::condition_variable apply_grads_cv_;
@@ -139,7 +139,7 @@ class ParameterServer {
std::unique_ptr<std::thread> thread_;

friend struct ServerHandler;
friend class ServerHandler;
};

class FuncGraph;
@@ -147,33 +147,29 @@ template <typename T>
void ParameterServer<T>::ServerHandler::operator()(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data,
::ps::KVServer<T> *server) {
::ps::KVPairs<T> res;
if (req_meta.cmd == kInitWeightsCmd) {
MS_LOG(ERROR) << "handle init weights cmd" << std::endl;
HandleInitWeights(req_data);
} else if (req_meta.cmd == kInitWeightToOptimIdCmd) {
MS_LOG(ERROR) << "handle init weight optim id mapping cmd" << std::endl;
HandleInitWeightToOptimId(req_data);
} else if (req_meta.cmd == kInitOptimInputsShapeCmd) {
MS_LOG(ERROR) << "handle init inputs shape cmd" << std::endl;
HandleInitInputsShape(req_data);
} else if (req_meta.cmd == kInitEmbeddingsCmd) {
MS_LOG(ERROR) << "handle init embedding cmd" << std::endl;
HandleInitEmbeddings(req_data);
} else if (req_meta.cmd == kEmbeddingLookupCmd) {
MS_LOG(ERROR) << "handle embedding lookup cmd" << std::endl;
HandleEmbeddingLookup(req_meta, req_data, &res);
if (handlers_.count(req_meta.cmd) > 0) {
auto &handler_ptr = handlers_[req_meta.cmd];
(this->*handler_ptr)(req_meta, req_data, &res);
} else if (req_meta.push) {
MS_LOG(ERROR) << "handle push req cmd" << std::endl;
HandlePushReq(req_meta, req_data);
HandlePushReq(req_meta, req_data, &res);
} else {
MS_LOG(ERROR) << "handle pull req cmd" << std::endl;
HandlePullReq(req_meta, req_data, &res);
}
server->Response(req_meta, res);
}

template <typename T>
void ParameterServer<T>::ServerHandler::HandlePushReq(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data) {
void ParameterServer<T>::ServerHandler::Init() {
handlers_[kInitWeightsCmd] = &ServerHandler::HandleInitWeights;
handlers_[kInitWeightToOptimIdCmd] = &ServerHandler::HandleInitWeightToOptimId;
handlers_[kInitOptimInputsShapeCmd] = &ServerHandler::HandleInitInputsShape;
handlers_[kInitEmbeddingsCmd] = &ServerHandler::HandleInitEmbeddings;
handlers_[kEmbeddingLookupCmd] = &ServerHandler::HandleEmbeddingLookup;
}
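// Only the five commands above are registered in handlers_; any other request falls through to
// the generic branches in operator(), which route on req_meta.push to HandlePushReq or HandlePullReq.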

template <typename T>
void ParameterServer<T>::ServerHandler::HandlePushReq(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data,
::ps::KVPairs<T> *res) {
ps_->AccumGrad(req_data.keys, req_data.vals, req_data.lens);
}
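// HandlePushReq now takes the same (req_meta, req_data, *res) parameter list as the other
// handlers; res is left untouched here since a push only accumulates gradients, and operator()
// still acknowledges the request via server->Response(req_meta, res).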
@@ -186,7 +182,8 @@ void ParameterServer<T>::ServerHandler::HandlePullReq(const ::ps::KVMeta &req_me
}

template <typename T>
void ParameterServer<T>::ServerHandler::HandleInitWeights(const ::ps::KVPairs<T> &req_data) {
void ParameterServer<T>::ServerHandler::HandleInitWeights(const ::ps::KVMeta &req_meta,
const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res) {
size_t key_num = req_data.keys.size();
T *data_ptr = req_data.vals.data();
size_t pos = 0;
@@ -205,7 +202,9 @@ void ParameterServer<T>::ServerHandler::HandleInitWeights(const ::ps::KVPairs<T>
}

template <typename T>
void ParameterServer<T>::ServerHandler::HandleInitWeightToOptimId(const ::ps::KVPairs<T> &req_data) {
void ParameterServer<T>::ServerHandler::HandleInitWeightToOptimId(const ::ps::KVMeta &req_meta,
const ::ps::KVPairs<T> &req_data,
::ps::KVPairs<T> *res) {
size_t key_num = req_data.keys.size();
for (size_t i = 0; i < key_num; i++) {
Key key = req_data.keys[i];
@@ -215,12 +214,14 @@ void ParameterServer<T>::ServerHandler::HandleInitWeightToOptimId(const ::ps::KV
}

template <typename T>
void ParameterServer<T>::ServerHandler::HandleInitInputsShape(const ::ps::KVPairs<T> &req_data) {
void ParameterServer<T>::ServerHandler::HandleInitInputsShape(const ::ps::KVMeta &req_meta,
const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res) {
ps_->InitOptimInputsShape(req_data.keys, req_data.vals, req_data.lens);
}

template <typename T>
void ParameterServer<T>::ServerHandler::HandleInitEmbeddings(const ::ps::KVPairs<T> &req_data) {
void ParameterServer<T>::ServerHandler::HandleInitEmbeddings(const ::ps::KVMeta &req_meta,
const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res) {
std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> shapes =
std::make_shared<std::vector<std::shared_ptr<std::vector<size_t>>>>();
std::shared_ptr<std::vector<size_t>> input_shape = std::make_shared<std::vector<size_t>>();
@@ -249,10 +250,10 @@ template <typename T>
void ParameterServer<T>::ServerHandler::HandleEmbeddingLookup(const ::ps::KVMeta &req_meta,
const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res) {
const Key &key = req_data.keys[0];
for (size_t i = 0; i < req_data.vals.size(); i++) {
res->keys.push_back(req_data.vals[i]);
for (size_t i = 0; i < req_data.keys.size(); i++) {
res->keys.push_back(req_data.keys[i]);
}
ps_->DoEmbeddingLookup(key, req_data.vals, res);
ps_->DoEmbeddingLookup(key, req_data.keys.segment(1, req_data.keys.size()), res);
}
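// In the reworked request layout, keys[0] is the embedding table key and keys[1..n) are the
// lookup ids (previously the ids travelled in vals); every incoming key is echoed into
// res->keys, presumably so the worker can match the returned rows to its request.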

template <typename T>
@@ -268,6 +269,7 @@ bool ParameterServer<T>::Init(const FuncGraphPtr &func_graph) {
func_graph_ = func_graph;
rank_id_ = ::ps::MyRank();
handler_.reset(new ServerHandler(this));
handler_->Init();

InitOptimInfoBuilders();
@@ -364,7 +366,13 @@ void ParameterServer<T>::InitEmbeddingTable(
for (auto shape : input_shapes) {
total_dims *= shape;
}
WeightPtr embedding = std::make_shared<Weight>(total_dims, 0.01);

WeightPtr embedding = std::make_shared<Weight>(total_dims, 0);
std::default_random_engine engine;
std::normal_distribution<float> random(0, 0.01);
for (size_t i = 0; i < total_dims; i++) {
(*embedding)[i] = random(engine);
}
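// The table is now filled with samples from N(0, 0.01) instead of the constant 0.01; note that
// a default-constructed std::default_random_engine always starts from its default seed, so the
// generated values repeat across runs for a given standard library implementation.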
weights_[key] = embedding;

grads_accum_counter_[key] = 0;
@@ -480,8 +488,13 @@ void ParameterServer<T>::DoEmbeddingLookup(Key key, const LookupIds &lookup_ids,
inputs.push_back(indices);
embedding_table->addr = table_ptr->data();
embedding_table->size = table_ptr->size() * sizeof(T);
indices->addr = lookup_ids.data();
indices->size = lookup_ids.size() * sizeof(T);

std::unique_ptr<int[]> tmp_ids(new int[lookup_ids.size()]);
for (size_t i = 0; i < lookup_ids.size(); i++) {
tmp_ids[i] = static_cast<int>(lookup_ids[i]);
}
indices->addr = tmp_ids.get();
indices->size = lookup_ids.size() * sizeof(int);
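// The lookup ids arrive in the KV payload as values of type T, so they are copied into a
// temporary int buffer before the kernel launch; the sizeof(int) sizing above implies the
// embedding lookup kernel consumes 32-bit integer indices.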

std::vector<kernel::AddressPtr> workspaces;
std::vector<kernel::AddressPtr> outputs;
@@ -506,20 +519,6 @@ int ParameterServer<T>::SumOfShapes(const std::vector<int> &shapes) const {
return sum;
}

template <typename T>
size_t ParameterServer<T>::PreComputeCapacity(const Keys &keys, const Lengths &lens) {
size_t capacity = 0;
for (size_t i = 0; i < keys.size(); i++) {
Key key = keys[i];
if (embedding_row_lens_.count(key) > 0) {
capacity += embedding_row_lens_[key] * lens[i];
} else {
MS_LOG(ERROR) << "Invalid embedding lookup id " << key;
}
}
return capacity;
}

template <typename T>
inline bool ParameterServer<T>::ReadyForUpdateWeights() {
return grads_accum_counter_.size() > 0 && grad_accum_count_ == grads_accum_counter_.size();