|
|
|
@ -33,8 +33,7 @@ void ParameterServer::Run(const FuncGraphPtr &func_graph) {
|
|
|
|
|
}
|
|
|
|
|
Init(func_graph);
|
|
|
|
|
server_node_->Start();
|
|
|
|
|
rank_id_ = server_node_->rank_id();
|
|
|
|
|
PSContext::instance()->SetPSRankId(rank_id_);
|
|
|
|
|
PSContext::instance()->SetPSRankId(server_node_->rank_id());
|
|
|
|
|
thread_->join();
|
|
|
|
|
SyncEmbeddingTables();
|
|
|
|
|
MS_LOG(INFO) << "PServer finished updating models, starts finalizing...";
|
|
|
|
@ -118,22 +117,22 @@ void ParameterServer::InitOptimInputsShape(const Keys &keys, const Values &value
|
|
|
|
|
MS_EXCEPTION_IF_NULL(cnode);
|
|
|
|
|
if (optim_name == kSparseAdam) {
|
|
|
|
|
std::shared_ptr<PServerKernel> optimizer =
|
|
|
|
|
std::make_shared<kernel::ps::SparseApplyAdamPSKernel>(rank_id_, pserver_num_, worker_num_);
|
|
|
|
|
std::make_shared<kernel::ps::SparseApplyAdamPSKernel>(server_node_->rank_id(), pserver_num_, worker_num_);
|
|
|
|
|
optimizer->InitKernel(cnode, optim_inputs_shape_[key]);
|
|
|
|
|
optimizers_[key] = optimizer;
|
|
|
|
|
} else if (optim_name == kSparseLazyAdam) {
|
|
|
|
|
std::shared_ptr<PServerKernel> optimizer =
|
|
|
|
|
std::make_shared<kernel::ps::SparseApplyLazyAdamPSKernel>(rank_id_, pserver_num_, worker_num_);
|
|
|
|
|
std::make_shared<kernel::ps::SparseApplyLazyAdamPSKernel>(server_node_->rank_id(), pserver_num_, worker_num_);
|
|
|
|
|
optimizer->InitKernel(cnode, optim_inputs_shape_[key]);
|
|
|
|
|
optimizers_[key] = optimizer;
|
|
|
|
|
} else if (optim_name == kApplyMomentum) {
|
|
|
|
|
std::shared_ptr<PServerKernel> optimizer =
|
|
|
|
|
std::make_shared<kernel::ps::ApplyMomentumPSKernel>(rank_id_, pserver_num_, worker_num_);
|
|
|
|
|
std::make_shared<kernel::ps::ApplyMomentumPSKernel>(server_node_->rank_id(), pserver_num_, worker_num_);
|
|
|
|
|
optimizer->InitKernel(cnode, optim_inputs_shape_[key]);
|
|
|
|
|
optimizers_[key] = optimizer;
|
|
|
|
|
} else if (optim_name == kSparseFtrl) {
|
|
|
|
|
std::shared_ptr<PServerKernel> optimizer =
|
|
|
|
|
std::make_shared<kernel::ps::SparseApplyFtrlPSKernel>(rank_id_, pserver_num_, worker_num_);
|
|
|
|
|
std::make_shared<kernel::ps::SparseApplyFtrlPSKernel>(server_node_->rank_id(), pserver_num_, worker_num_);
|
|
|
|
|
optimizer->InitKernel(cnode, optim_inputs_shape_[key]);
|
|
|
|
|
optimizers_[key] = optimizer;
|
|
|
|
|
}
|
|
|
|
@ -144,7 +143,7 @@ void ParameterServer::InitOptimInputsShape(const Keys &keys, const Values &value
|
|
|
|
|
void ParameterServer::InitWeight(const Key &key, const WeightPtr &weight) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(weight);
|
|
|
|
|
if ((weights_.count(key) == 0) || (is_embedding_[key] && weights_.count(key) != 0)) {
|
|
|
|
|
MS_LOG(INFO) << "Initializing weight for key " << key << ", server rank " << rank_id_;
|
|
|
|
|
MS_LOG(INFO) << "Initializing weight for key " << key << ", server rank " << server_node_->rank_id();
|
|
|
|
|
weights_[key] = weight;
|
|
|
|
|
tokens_[key] = 0;
|
|
|
|
|
is_embedding_[key] = false;
|
|
|
|
@ -165,7 +164,7 @@ void ParameterServer::InitEmbeddingTable(
|
|
|
|
|
MS_EXCEPTION_IF_NULL(shapes);
|
|
|
|
|
if (weights_.count(key) == 0) {
|
|
|
|
|
std::shared_ptr<PServerKernel> lookup =
|
|
|
|
|
std::make_shared<kernel::ps::EmbeddingLookUpPSKernel>(rank_id_, pserver_num_, worker_num_);
|
|
|
|
|
std::make_shared<kernel::ps::EmbeddingLookUpPSKernel>(server_node_->rank_id(), pserver_num_, worker_num_);
|
|
|
|
|
lookup->InitKernel(shapes);
|
|
|
|
|
embedding_lookup_ops_[key] = lookup;
|
|
|
|
|
|
|
|
|
@ -244,7 +243,7 @@ void ParameterServer::UpdateWeights() {
|
|
|
|
|
[](std::shared_ptr<std::vector<size_t>> input_shapes) -> std::vector<size_t> { return *input_shapes; });
|
|
|
|
|
}
|
|
|
|
|
optimizer->ReInit(shapes);
|
|
|
|
|
optim_info->ComputeMean(shapes, worker_num_, pserver_num_, rank_id_);
|
|
|
|
|
optim_info->ComputeMean(shapes, worker_num_, pserver_num_, server_node_->rank_id());
|
|
|
|
|
optimizer->Execute(inputs, workspaces, outputs);
|
|
|
|
|
optim_info->Reset();
|
|
|
|
|
}
|
|
|
|
@ -296,7 +295,6 @@ WeightPtr ParameterServer::weight(const Key &key) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Invalid weight key " << key;
|
|
|
|
|
}
|
|
|
|
|
WeightPtr weight_ptr = weights_[key];
|
|
|
|
|
MS_LOG(DEBUG) << "The weight ptr size is:" << weight_ptr->size();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(weight_ptr);
|
|
|
|
|
WeightPtr copy_weight_ptr = std::make_shared<std::vector<float>>(weight_ptr->size(), 0);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(copy_weight_ptr);
|
|
|
|
|