|
|
|
@ -70,6 +70,7 @@ class ParameterServer {
|
|
|
|
|
handler_(nullptr),
|
|
|
|
|
func_graph_(nullptr),
|
|
|
|
|
sess_(nullptr),
|
|
|
|
|
running_(true),
|
|
|
|
|
thread_(nullptr) {}
|
|
|
|
|
~ParameterServer() = default;
|
|
|
|
|
ParameterServer(const ParameterServer &) = delete;
|
|
|
|
@ -106,6 +107,7 @@ class ParameterServer {
|
|
|
|
|
void InitGrad(const Key &key, const GradPtr &grad);
|
|
|
|
|
void InitEmbeddingTable(const Key &key,
|
|
|
|
|
const std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> &shapes);
|
|
|
|
|
void Finalize();
|
|
|
|
|
void UpdateWeights();
|
|
|
|
|
void AccumGrad(const Keys &key, const Values &values, const Lengths &lengths);
|
|
|
|
|
WeightPtr weight(const Key &key);
|
|
|
|
@ -123,6 +125,7 @@ class ParameterServer {
|
|
|
|
|
std::unique_ptr<ServerHandler> handler_;
|
|
|
|
|
FuncGraphPtr func_graph_;
|
|
|
|
|
std::shared_ptr<session::SessionBasic> sess_;
|
|
|
|
|
bool running_;
|
|
|
|
|
|
|
|
|
|
std::unordered_map<Key, std::shared_ptr<PServerKernel>> optimizers_;
|
|
|
|
|
std::unordered_map<Key, InputsShapePtr> optim_inputs_shape_;
|
|
|
|
@ -261,7 +264,7 @@ void ParameterServer<T>::ServerHandler::HandleEmbeddingLookup(const ::ps::KVMeta
|
|
|
|
|
template <typename T>
|
|
|
|
|
void ParameterServer<T>::ServerHandler::HandleFinalize(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data,
|
|
|
|
|
::ps::KVPairs<T> *res) {
|
|
|
|
|
::ps::Finalize(0, false);
|
|
|
|
|
ps_->Finalize();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
@ -381,11 +384,20 @@ void ParameterServer<T>::InitEmbeddingTable(
|
|
|
|
|
grads_accum_counter_[key] = 0;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
void ParameterServer<T>::Finalize() {
|
|
|
|
|
running_ = false;
|
|
|
|
|
apply_grads_cv_.notify_one();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
void ParameterServer<T>::UpdateWeights() {
|
|
|
|
|
while (true) {
|
|
|
|
|
std::unique_lock<std::mutex> lock(mutex_);
|
|
|
|
|
apply_grads_cv_.wait(lock, [this] { return this->ReadyForUpdateWeights(); });
|
|
|
|
|
apply_grads_cv_.wait(lock, [this] { return this->ReadyForUpdateWeights() || !running_; });
|
|
|
|
|
if (!running_) {
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for (auto iter = weights_.begin(); iter != weights_.end(); iter++) {
|
|
|
|
|
Key key = iter->first;
|
|
|
|
@ -550,6 +562,8 @@ void ParameterServer<T>::Run(const FuncGraphPtr &func_graph) {
|
|
|
|
|
}
|
|
|
|
|
Init(func_graph);
|
|
|
|
|
thread_->join();
|
|
|
|
|
::ps::Finalize(0, true);
|
|
|
|
|
exit(1);
|
|
|
|
|
}
|
|
|
|
|
} // namespace ps
|
|
|
|
|
} // namespace parallel
|
|
|
|
|