From febbbf45c8d2d7c9bf2e37e0b0ac39fad5ceda39 Mon Sep 17 00:00:00 2001
From: chendongsheng
Date: Wed, 24 Feb 2021 14:45:37 +0800
Subject: [PATCH] added server

---
 mindspore/ccsrc/ps/CMakeLists.txt          |   1 +
 mindspore/ccsrc/ps/common.h                |   2 +
 mindspore/ccsrc/ps/core/protos/ps.proto    |   8 +
 .../ccsrc/ps/internal/parameter_server.cc  | 706 ++++++++++++++++++
 .../ccsrc/ps/internal/parameter_server.h   | 179 +++++
 tests/ut/cpp/CMakeLists.txt                |   1 +
 6 files changed, 897 insertions(+)
 create mode 100644 mindspore/ccsrc/ps/internal/parameter_server.cc
 create mode 100644 mindspore/ccsrc/ps/internal/parameter_server.h

diff --git a/mindspore/ccsrc/ps/CMakeLists.txt b/mindspore/ccsrc/ps/CMakeLists.txt
index c09960e291..8509d72b96 100644
--- a/mindspore/ccsrc/ps/CMakeLists.txt
+++ b/mindspore/ccsrc/ps/CMakeLists.txt
@@ -22,6 +22,7 @@ if(NOT (ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU)))
     list(REMOVE_ITEM _PS_SRC_FILES "core/scheduler_node.cc")
     list(REMOVE_ITEM _PS_SRC_FILES "core/http_client.cc")
     list(REMOVE_ITEM _PS_SRC_FILES "internal/worker.cc")
+    list(REMOVE_ITEM _PS_SRC_FILES "internal/parameter_server.cc")
 endif()
 
 if(NOT ENABLE_D)
diff --git a/mindspore/ccsrc/ps/common.h b/mindspore/ccsrc/ps/common.h
index e7db641864..062129ac04 100644
--- a/mindspore/ccsrc/ps/common.h
+++ b/mindspore/ccsrc/ps/common.h
@@ -76,6 +76,8 @@ constexpr int64_t kFinalizeCmd = 40;
 constexpr size_t kInvalidKey = UINT64_MAX;
 constexpr int64_t kInvalidID = -1;
 
+using DataPtr = std::shared_ptr<unsigned char[]>;
+using VectorPtr = std::shared_ptr<std::vector<unsigned char>>;
 using Key = ::ps::Key;
 using Keys = ::ps::SArray<Key>;
 using Values = ::ps::SArray<float>;
diff --git a/mindspore/ccsrc/ps/core/protos/ps.proto b/mindspore/ccsrc/ps/core/protos/ps.proto
index 7f293663a1..bc5c18246a 100644
--- a/mindspore/ccsrc/ps/core/protos/ps.proto
+++ b/mindspore/ccsrc/ps/core/protos/ps.proto
@@ -35,6 +35,13 @@ enum CommandCode {
   FINALIZE = 10;
 }
 
+message ParamInitInfoMessage {
+  int32 param_type = 1;
+  uint64 global_seed = 2;
+  uint64 op_seed = 3;
+  float init_val = 4;
+}
+
 message KVMessage {
   repeated int32 keys = 2;
   repeated float values = 3;
diff --git a/mindspore/ccsrc/ps/internal/parameter_server.cc b/mindspore/ccsrc/ps/internal/parameter_server.cc
new file mode 100644
index 0000000000..acd57c173e
--- /dev/null
+++ b/mindspore/ccsrc/ps/internal/parameter_server.cc
@@ -0,0 +1,706 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ps/internal/parameter_server.h"
+
+namespace mindspore {
+namespace ps {
+namespace internal {
+
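+// Server lifecycle: Run() wires up the cluster metadata, starts the server node, spawns the
+// weight-update thread, and blocks until a worker sends kFinalizeCmd; embedding tables are
+// synced back to the front-end parameters before the node is finished and stopped.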
+void ParameterServer::Run(const FuncGraphPtr &func_graph) {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  MS_LOG(INFO) << "PServer starts connecting to scheduler and workers...";
+  core::ClusterMetadata::instance()->Init(
+    PSContext::instance()->initial_worker_num(), PSContext::instance()->initial_server_num(),
+    PSContext::instance()->scheduler_host(), PSContext::instance()->scheduler_port());
+  MS_LOG(INFO) << "PServer connected successfully.";
+  if (!PSContext::instance()->is_server()) {
+    MS_LOG(INFO) << "This is not the Server node.";
+    return;
+  }
+  Init(func_graph);
+  server_node_.Start();
+  rank_id_ = server_node_.rank_id();
+  PSContext::instance()->SetPSRankId(rank_id_);
+  thread_->join();
+  SyncEmbeddingTables();
+  MS_LOG(INFO) << "PServer finished updating models, starts finalizing...";
+  server_node_.Finish();
+  server_node_.Stop();
+  MS_LOG(INFO) << "PServer finalized successfully.";
+}
+
+bool ParameterServer::Init(const FuncGraphPtr &func_graph) {
+  pserver_num_ = std::strtol(mindspore::common::GetEnv(kEnvPServerNum).c_str(), nullptr, 10);
+  worker_num_ = std::strtol(mindspore::common::GetEnv(kEnvWorkerNum).c_str(), nullptr, 10);
+  func_graph_ = func_graph;
+  handler_.reset(new ServerHandler(this));
+  handler_->Init();
+
+  InitOptimInfoBuilders();
+  server_node_.set_handler(*handler_);
+  thread_.reset(new std::thread(&ParameterServer::UpdateWeights, this));
+  GetEmbeddingTableParamPtr();
+  return true;
+}
+
+void ParameterServer::InitOptimInfoBuilders() {
+  std::shared_ptr<OptimizerInfoBuilder> momentum_info_builder = std::make_shared<MomentumOptimInfoBuilder>(worker_num_);
+  std::shared_ptr<OptimizerInfoBuilder> sparse_adam_info_builder =
+    std::make_shared<SparseAdamOptimInfoBuilder>(worker_num_);
+  std::shared_ptr<OptimizerInfoBuilder> sparse_ftrl_info_builder =
+    std::make_shared<SparseFtrlOptimInfoBuilder>(worker_num_);
+  optim_info_builders_[kApplyMomentum] = momentum_info_builder;
+  optim_info_builders_[kSparseAdam] = sparse_adam_info_builder;
+  optim_info_builders_[kSparseFtrl] = sparse_ftrl_info_builder;
+}
+
+void ParameterServer::InitWeightKeyToOptims(const Key &key, const int64_t &optim_id) {
+  if (weight_key_to_optims_.count(key) > 0 || Util::optimizer_name(optim_id) == "") {
+    return;
+  }
+  weight_key_to_optims_[key] = Util::optimizer_name(optim_id);
+  weight_key_to_optim_op_[key] = Util::optimizer_node_name(optim_id);
+  MS_LOG(INFO) << "Initializing optimizer id for key:" << key << ", optimizer name:" << weight_key_to_optims_[key]
+               << ", optimizer op name:" << weight_key_to_optim_op_[key];
+}
+
+void ParameterServer::InitOptimInputsShape(const Keys &keys, const Values &values, const Lengths &lengths) {
+  InputsShapePtr inputs_shape = std::make_shared<InputsShape>();
+  MS_EXCEPTION_IF_NULL(inputs_shape);
+  InputsShapePtr original_inputs_shape = std::make_shared<InputsShape>();
+  MS_EXCEPTION_IF_NULL(original_inputs_shape);
+  int64_t val_idx = 0;
+  const Key &key = keys[0];
+  MS_LOG(INFO) << "Initializing optimizer inputs shape for key:" << key;
+  if (optim_inputs_shape_.count(key) == 0) {
+    original_optim_inputs_shape_[key] = original_inputs_shape;
+    optim_inputs_shape_[key] = inputs_shape;
+  }
+  for (size_t i = 0; i < keys.size(); i++) {
+    auto shape = std::make_shared<std::vector<size_t>>();
+    MS_EXCEPTION_IF_NULL(shape);
+    auto original_shape = std::make_shared<std::vector<size_t>>();
+    MS_EXCEPTION_IF_NULL(original_shape);
+    inputs_shape->push_back(shape);
+    original_inputs_shape->push_back(original_shape);
+
+    for (int64_t j = 0; j < lengths[i]; j++) {
+      shape->push_back(values[val_idx]);
+      original_shape->push_back(values[val_idx++]);
+    }
+  }
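+  // The optimizer kernel for this key is constructed lazily, the first time its input shapes
+  // arrive, based on the optimizer name registered through InitWeightKeyToOptims.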
+  if (weight_key_to_optims_.count(key) > 0) {
+    const std::string &optim_name = weight_key_to_optims_[key];
+    const std::string &optim_op_name = weight_key_to_optim_op_[key];
+    if (optimizers_.count(key) == 0 && optim_inputs_shape_.count(key) > 0) {
+      const CNodePtr cnode = GetCNode(optim_op_name);
+      MS_EXCEPTION_IF_NULL(cnode);
+      if (optim_name == kSparseAdam) {
+        std::shared_ptr<PServerKernel> optimizer =
+          std::make_shared<kernel::ps::SparseApplyAdamPSKernel>(rank_id_, pserver_num_, worker_num_);
+        optimizer->InitKernel(cnode, optim_inputs_shape_[key]);
+        optimizers_[key] = optimizer;
+      } else if (optim_name == kSparseLazyAdam) {
+        std::shared_ptr<PServerKernel> optimizer =
+          std::make_shared<kernel::ps::SparseApplyLazyAdamPSKernel>(rank_id_, pserver_num_, worker_num_);
+        optimizer->InitKernel(cnode, optim_inputs_shape_[key]);
+        optimizers_[key] = optimizer;
+      } else if (optim_name == kApplyMomentum) {
+        std::shared_ptr<PServerKernel> optimizer =
+          std::make_shared<kernel::ps::ApplyMomentumPSKernel>(rank_id_, pserver_num_, worker_num_);
+        optimizer->InitKernel(cnode, optim_inputs_shape_[key]);
+        optimizers_[key] = optimizer;
+      } else if (optim_name == kSparseFtrl) {
+        std::shared_ptr<PServerKernel> optimizer =
+          std::make_shared<kernel::ps::SparseApplyFtrlPSKernel>(rank_id_, pserver_num_, worker_num_);
+        optimizer->InitKernel(cnode, optim_inputs_shape_[key]);
+        optimizers_[key] = optimizer;
+      }
+    }
+  }
+}
+
+void ParameterServer::InitWeight(const Key &key, const WeightPtr &weight) {
+  MS_EXCEPTION_IF_NULL(weight);
+  if ((weights_.count(key) == 0) || (is_embedding_[key] && weights_.count(key) != 0)) {
+    MS_LOG(INFO) << "Initializing weight for key " << key << ", server rank " << rank_id_;
+    weights_[key] = weight;
+    tokens_[key] = 0;
+    is_embedding_[key] = false;
+  }
+}
+
+void ParameterServer::InitGrad(const Key &key, const GradPtr &grad) {
+  MS_EXCEPTION_IF_NULL(grad);
+  if (grads_.count(key) == 0) {
+    grads_[key] = grad;
+    grads_accum_counter_[key] = 0;
+  }
+}
+
+void ParameterServer::InitEmbeddingTable(
+  const Key &key, const std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> &shapes,
+  const ParamInitInfo &param_init_info) {
+  MS_EXCEPTION_IF_NULL(shapes);
+  if (weights_.count(key) == 0) {
+    std::shared_ptr<PServerKernel> lookup =
+      std::make_shared<kernel::ps::EmbeddingLookUpPSKernel>(rank_id_, pserver_num_, worker_num_);
+    lookup->InitKernel(shapes);
+    embedding_lookup_ops_[key] = lookup;
+
+    // Init embedding weight
+    const std::vector<size_t> &input_shapes = lookup->input_sizes();
+    size_t total_dims =
+      std::accumulate(input_shapes.begin(), input_shapes.end(), IntToSize(1), std::multiplies<size_t>());
+    WeightPtr embedding = std::make_shared<Weight>(total_dims, 0);
+    MS_EXCEPTION_IF_NULL(embedding);
+    float *embedding_data = embedding->data();
+    std::default_random_engine engine;
+    std::normal_distribution<float> random(0, 0.01);
+    if (ps::PsDataPrefetch::GetInstance().cache_enable()) {
+      if (param_init_info.param_type_ == kWeight) {
+        InitRandomNormal(0, 0.01, input_shapes, param_init_info.global_seed_, param_init_info.op_seed_,
+                         embedding_data);
+      } else if (param_init_info.param_type_ == kAccumulation) {
+        for (size_t i = 0; i < total_dims; i++) {
+          embedding_data[i] = param_init_info.init_val_;
+        }
+      }
+    } else {
+      for (size_t i = 0; i < total_dims; i++) {
+        embedding_data[i] = random(engine);
+      }
+    }
+    weights_[key] = embedding;
+    MS_LOG(DEBUG) << "The key:" << key << " the embedding:" << *embedding;
+    tokens_[key] = 0;
+    is_embedding_[key] = true;
+
+    grads_accum_counter_[key] = 0;
+  }
+}
+
+bool ParameterServer::HasWeight(const Key &key) { return (weights_.count(key) > 0 && !is_embedding_.count(key)); }
+
+void ParameterServer::Finalize() {
+  running_ = false;
+  apply_grads_cv_.notify_one();
+}
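+
+// UpdateWeights runs on its own thread: it waits on apply_grads_cv_ until every worker has
+// pushed gradients for all weights (or Finalize() clears running_), then executes each key's
+// optimizer kernel and refills pull tokens for non-embedding weights.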
running_ << " the ready is:" << this->ReadyForUpdateWeights(); + std::unique_lock lock(mutex_); + apply_grads_cv_.wait(lock, [this] { return this->ReadyForUpdateWeights() || !running_; }); + if (!running_) { + break; + } + + for (auto iter = weights_.begin(); iter != weights_.end(); iter++) { + Key key = iter->first; + WeightPtr weight_ptr = iter->second; + + std::shared_ptr optimizer = nullptr; + if (weight_key_to_optims_.count(key) > 0) { + optimizer = optimizers_[key]; + } + MS_EXCEPTION_IF_NULL(optimizer); + + std::shared_ptr optim_info = optim_infos_[key]; + if (optim_info != nullptr) { + const std::vector &inputs = optim_info->inputs(); + const std::vector &workspaces = optim_info->workspaces(); + const std::vector &outputs = optim_info->outputs(); + + std::vector> shapes = {}; + std::vector indices_shape = {}; + indices_shape.emplace_back(optim_info->indice_size()); + shapes.push_back(indices_shape); + + if (original_optim_inputs_shape_.count(key) != 0) { + std::transform( + (*(original_optim_inputs_shape_[key])).begin(), (*(original_optim_inputs_shape_[key])).end(), + std::back_inserter(shapes), + [](std::shared_ptr> input_shapes) -> std::vector { return *input_shapes; }); + } + optimizer->ReInit(shapes); + optim_info->ComputeMean(shapes, worker_num_, pserver_num_, rank_id_); + optimizer->Execute(inputs, workspaces, outputs); + optim_info->Reset(); + } + if (!is_embedding_[key]) { + tokens_[key] = worker_num_; + } + } + ResetGradAccumCount(); + } +} + +void ParameterServer::AccumGrad(const Keys &keys, const Values &values, const Lengths &lengths) { + std::unique_lock lock(mutex_); + const Key &key = keys[0]; + bool no_sparse_grad = values.size() == 1 && values[0] == -100; + if (!no_sparse_grad) { + std::shared_ptr optim_info = optim_infos_[key]; + + // Create or update the optimizer info + std::shared_ptr pserver_kernel = optimizers_[key]; + if (pserver_kernel == nullptr) { + MS_LOG(EXCEPTION) << "no optimizer found for key " << key << " optim name " << weight_key_to_optims_[key]; + } + MS_EXCEPTION_IF_NULL(pserver_kernel); + optim_infos_[key] = optim_info; + } + + grads_accum_counter_[key] += 1; + if (grads_accum_counter_[key] == worker_num_) { + grad_accum_count_++; + } + if (ReadyForUpdateWeights()) { + apply_grads_cv_.notify_one(); + } +} + +WeightPtr ParameterServer::weight(const Key &key) { + std::unique_lock lock(mutex_); + if (weights_.count(key) == 0) { + MS_LOG(EXCEPTION) << "Invalid weight key " << key; + } + WeightPtr weight_ptr = weights_[key]; + MS_LOG(DEBUG) << "The weight ptr size is:" << weight_ptr->size(); + MS_EXCEPTION_IF_NULL(weight_ptr); + WeightPtr copy_weight_ptr = std::make_shared>(weight_ptr->size(), 0); + MS_EXCEPTION_IF_NULL(copy_weight_ptr); + copy_weight_ptr = weight_ptr; + tokens_[key] -= 1; + return copy_weight_ptr; +} + +void ParameterServer::DoEmbeddingLookup(Key key, const LookupIds &lookup_ids, KVMessage *res) { + std::unique_lock lock(mutex_); + MS_EXCEPTION_IF_NULL(res); + if (weights_.count(key) == 0) { + MS_LOG(ERROR) << "Invalid embedding table key " << key; + return; + } + if (embedding_lookup_ops_.count(key) == 0) { + MS_LOG(ERROR) << "Invalid embedding lookup op key " << key; + return; + } + WeightPtr table_ptr = weights_[key]; + MS_EXCEPTION_IF_NULL(table_ptr); + std::shared_ptr table_lookup_op = embedding_lookup_ops_[key]; + MS_EXCEPTION_IF_NULL(table_lookup_op); + + // Update shapes of lookup operator + std::vector> shapes = {}; + std::vector indices_shape = {}; + indices_shape.emplace_back(lookup_ids.size()); + 
+void ParameterServer::DoEmbeddingLookup(Key key, const LookupIds &lookup_ids, KVMessage *res) {
+  std::unique_lock<std::mutex> lock(mutex_);
+  MS_EXCEPTION_IF_NULL(res);
+  if (weights_.count(key) == 0) {
+    MS_LOG(ERROR) << "Invalid embedding table key " << key;
+    return;
+  }
+  if (embedding_lookup_ops_.count(key) == 0) {
+    MS_LOG(ERROR) << "Invalid embedding lookup op key " << key;
+    return;
+  }
+  WeightPtr table_ptr = weights_[key];
+  MS_EXCEPTION_IF_NULL(table_ptr);
+  std::shared_ptr<PServerKernel> table_lookup_op = embedding_lookup_ops_[key];
+  MS_EXCEPTION_IF_NULL(table_lookup_op);
+
+  // Update shapes of lookup operator
+  std::vector<std::vector<size_t>> shapes = {};
+  std::vector<size_t> indices_shape = {};
+  indices_shape.emplace_back(lookup_ids.size());
+  shapes.push_back(indices_shape);
+  table_lookup_op->ReInit(shapes);
+
+  const std::vector<size_t> output_shapes = table_lookup_op->output_sizes();
+  std::vector<AddressPtr> inputs;
+  AddressPtr embedding_table = std::make_shared<Address>();
+  MS_EXCEPTION_IF_NULL(embedding_table);
+  AddressPtr indices = std::make_shared<Address>();
+  MS_EXCEPTION_IF_NULL(indices);
+  inputs.push_back(embedding_table);
+  inputs.push_back(indices);
+  embedding_table->addr = table_ptr->data();
+  embedding_table->size = table_ptr->size() * sizeof(float);
+
+  std::unique_ptr<int[]> tmp_ids(new int[lookup_ids.size()]);
+  MS_EXCEPTION_IF_NULL(tmp_ids);
+  for (size_t i = 0; i < lookup_ids.size(); i++) {
+    tmp_ids[i] = static_cast<int>(lookup_ids[i]);
+  }
+  indices->addr = tmp_ids.get();
+  indices->size = lookup_ids.size() * sizeof(int);
+
+  std::vector<AddressPtr> workspaces;
+  std::vector<AddressPtr> outputs;
+  AddressPtr output = std::make_shared<Address>();
+  MS_EXCEPTION_IF_NULL(output);
+  std::shared_ptr<std::vector<float>> addr = std::make_shared<std::vector<float>>(output_shapes[0] / sizeof(float), 0);
+  MS_EXCEPTION_IF_NULL(addr);
+
+  output->addr = addr->data();
+  output->size = output_shapes[0];
+  outputs.push_back(output);
+
+  table_lookup_op->Execute(inputs, workspaces, outputs);
+  *res->mutable_values() = {addr->begin(), addr->end()};
+  res->add_len(res->values_size());
+}
+
+void ParameterServer::UpdateEmbeddings(const Key &key, const LookupIds &lookup_ids, const Values &vals) {
+  if (weights_.count(key) == 0) {
+    MS_LOG(ERROR) << "Invalid embedding table key " << key;
+    return;
+  }
+  if (embedding_lookup_ops_.count(key) == 0) {
+    MS_LOG(ERROR) << "Invalid embedding lookup op key " << key;
+    return;
+  }
+  WeightPtr table_ptr = weights_[key];
+  MS_EXCEPTION_IF_NULL(table_ptr);
+  std::shared_ptr<PServerKernel> table_lookup_op = embedding_lookup_ops_[key];
+  MS_EXCEPTION_IF_NULL(table_lookup_op);
+  table_lookup_op->UpdateEmbeddings(table_ptr->data(), lookup_ids.data(), vals.data(), lookup_ids.size());
+}
+
+inline bool ParameterServer::ReadyForUpdateWeights() {
+  return grads_accum_counter_.size() > 0 && grad_accum_count_ == grads_accum_counter_.size();
+}
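+
+// Push/pull handshake: a worker may push gradients for a key only after its pull tokens are
+// used up and the current round of updates is still pending; pulls are allowed only while
+// tokens_ remain, which UpdateWeights refills to worker_num_ and weight() decrements per pull.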
+inline bool ParameterServer::ReadyForPush(const Key &key) {
+  std::unique_lock<std::mutex> lock(mutex_);
+  if (weights_.empty()) {
+    MS_LOG(EXCEPTION) << "The weights in the server are empty. Many reasons could cause this: 1.The Worker didn't "
+                         "send the kInitWeightsCmd command. 2.The Server failed to initialize weights.";
+  }
+  MS_LOG(INFO) << "the grad_accum_count_:" << grad_accum_count_ << " the weights_:" << weights_.size()
+               << " the token:" << (tokens_[key] <= 0);
+  return grad_accum_count_ < weights_.size() && tokens_[key] <= 0;
+}
+
+inline bool ParameterServer::ReadyForPull(const Key &key) {
+  std::unique_lock<std::mutex> lock(mutex_);
+  if (tokens_.count(key) == 0 || weights_[key] == 0) {
+    MS_LOG(EXCEPTION) << "Invalid weight key " << key;
+  }
+  MS_LOG(INFO) << "ReadyForPull: " << (tokens_[key] > 0);
+  return tokens_[key] > 0;
+}
+
+inline void ParameterServer::ResetGradAccumCount() {
+  grad_accum_count_ = 0;
+  for (auto iter = grads_accum_counter_.begin(); iter != grads_accum_counter_.end(); iter++) {
+    grads_accum_counter_[iter->first] = 0;
+  }
+}
+
+const CNodePtr ParameterServer::GetCNode(const std::string &name) const {
+  std::list<CNodePtr> cnodes = func_graph_->GetOrderedCnodes();
+  for (CNodePtr cnode : cnodes) {
+    MS_EXCEPTION_IF_NULL(cnode);
+    std::string fullname = cnode->fullname_with_scope();
+    if (fullname.find(name) != std::string::npos && fullname.find("Push") != std::string::npos) {
+      return cnode;
+    }
+  }
+  return nullptr;
+}
+
+inline std::mutex &ParameterServer::mutex() { return mutex_; }
+
+void ParameterServer::GetEmbeddingTableParamPtr() {
+  MS_EXCEPTION_IF_NULL(func_graph_);
+  auto cnodes = func_graph_->GetOrderedCnodes();
+  Key count = 0;
+  for (auto cnode : cnodes) {
+    MS_EXCEPTION_IF_NULL(cnode);
+    std::string cnode_name = AnfAlgo::GetCNodeName(cnode);
+    if (cnode_name == kEmbeddingLookupOpName || cnode_name == kGatherV2OpName || cnode_name == kSparseGatherV2OpName) {
+      auto embedding_table = AnfAlgo::GetInputNode(cnode, 0);
+      if (IsPrimitiveCNode(embedding_table, prim::kPrimLoad)) {
+        auto embedding_cnode = embedding_table->cast<CNodePtr>();
+        embedding_table = AnfAlgo::GetInputNode(embedding_cnode, 0);
+      }
+      MS_EXCEPTION_IF_NULL(embedding_table);
+      if (embedding_table->isa<Parameter>()) {
+        MS_LOG(INFO) << "Embedding table name is " << embedding_table->fullname_with_scope() << ", key is " << count;
+        embedding_tables_.insert(std::make_pair(count, embedding_table->cast<ParameterPtr>()));
+        count++;
+      }
+    }
+  }
+}
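+
+// Before shutdown, each sharded embedding table held in weights_ is copied back into the
+// default parameter tensor of the corresponding front-end Parameter node.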
New tensor size:" << new_tensor_size + << ", embedding_table size:" << embedding_table_size; + } + MS_EXCEPTION_IF_NULL(new_tensor_data_ptr); + MS_EXCEPTION_IF_NULL(weights_[key]->data()); + int64_t ret = memcpy_s(new_tensor_data_ptr, new_tensor_size, weights_[key]->data(), embedding_table_size); + if (ret != 0) { + MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")"; + return; + } + + auto paramter_tensor_ptr = embedding_table.second->default_param(); + MS_EXCEPTION_IF_NULL(paramter_tensor_ptr); + paramter_tensor_ptr->cast()->AssignValue(*new_tensor); + } +} + +void ParameterServer::ServerHandler::Init() { + handlers_[kInitWeightsCmd] = &ServerHandler::HandleInitWeights; + handlers_[kInitWeightToOptimIdCmd] = &ServerHandler::HandleInitWeightToOptimId; + handlers_[kInitOptimInputsShapeCmd] = &ServerHandler::HandleInitInputsShape; + handlers_[kInitEmbeddingsCmd] = &ServerHandler::HandleInitEmbeddings; + handlers_[kCheckReadyForPushCmd] = &ServerHandler::HandleCheckReadyForPush; + handlers_[kCheckReadyForPullCmd] = &ServerHandler::HandleCheckReadyForPull; + handlers_[kEmbeddingLookupCmd] = &ServerHandler::HandleEmbeddingLookup; + handlers_[kUpdateEmbeddingsCmd] = &ServerHandler::HandleUpdateEmbeddings; + handlers_[kFinalizeCmd] = &ServerHandler::HandleFinalize; + handlers_[kPushCmd] = &ServerHandler::HandlePushReq; + handlers_[kPullCmd] = &ServerHandler::HandlePullReq; +} + +void ParameterServer::ServerHandler::operator()(std::shared_ptr conn, + std::shared_ptr meta, DataPtr data, size_t size) { + auto output = std::make_shared>(); + MS_LOG(INFO) << "The command is:" << meta->user_cmd(); + if (handlers_.count(meta->user_cmd()) == 0) { + MS_LOG(EXCEPTION) << "The command:" << meta->user_cmd() << " is not supported!"; + } + + auto &handler_ptr = handlers_[meta->user_cmd()]; + (this->*handler_ptr)(data, size, output); + std::shared_ptr res(new unsigned char[output->size()]); + MS_LOG(DEBUG) << "The output size is:" << output->size(); + if (output->size() > 0) { + int ret = memcpy_s(res.get(), output->size(), output->data(), output->size()); + if (ret != 0) { + MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; + } + } + + ps_->server_node_.Response(conn, meta, res, output->size()); + MS_LOG(DEBUG) << "The request id is:" << meta->request_id() << " the current time is:" + << std::chrono::time_point_cast(std::chrono::high_resolution_clock::now()) + .time_since_epoch() + .count(); +} + +void ParameterServer::ServerHandler::HandlePushReq(DataPtr data, size_t size, VectorPtr res) { + MS_EXCEPTION_IF_NULL(res); + KVMessage input; + input.ParseFromArray(data.get(), size); + Keys keys = {input.keys().begin(), input.keys().end()}; + Values values = {input.values().begin(), input.values().end()}; + Lengths lens = {input.len().begin(), input.len().end()}; + MS_LOG(DEBUG) << "The keys:" << keys << " the values:" << values << " the len:" << lens; + ps_->AccumGrad(keys, values, lens); +} + +void ParameterServer::ServerHandler::HandlePullReq(DataPtr data, size_t size, VectorPtr res) { + MS_EXCEPTION_IF_NULL(res); + KVMessage input; + input.ParseFromArray(data.get(), size); + KVMessage res_data; + *res_data.mutable_keys() = input.keys(); + Key key = input.keys()[0]; + auto weight = ps_->weight(key); + *res_data.mutable_values() = {weight->begin(), weight->end()}; + res->resize(res_data.ByteSizeLong()); + int ret = + memcpy_s(res->data(), res_data.ByteSizeLong(), res_data.SerializeAsString().data(), res_data.ByteSizeLong()); + if (ret != 0) { + MS_LOG(EXCEPTION) << "The memcpy_s error, 
errorno(" << ret << ")"; + } +} + +void ParameterServer::ServerHandler::HandleInitWeights(DataPtr data, size_t size, VectorPtr res) { + std::unique_lock lock(ps_->mutex()); + MS_EXCEPTION_IF_NULL(res); + KVMessage input; + input.ParseFromArray(data.get(), size); + int key_num = input.keys_size(); + const float *data_ptr = input.values().data(); + size_t pos = 0; + for (int i = 0; i < key_num; i++) { + Key key = input.keys()[i]; + size_t data_len = input.len_size() != key_num ? input.values_size() / key_num : input.len()[i]; + MS_LOG(DEBUG) << "The data len:" << data_len; + + if (!ps_->HasWeight(key)) { + WeightPtr weight_ptr = std::make_shared>(data_ptr + pos, data_ptr + (pos + data_len)); + MS_LOG(DEBUG) << "The weight ptr:" << *weight_ptr; + MS_EXCEPTION_IF_NULL(weight_ptr); + ps_->InitWeight(key, weight_ptr); + + GradPtr grad_ptr = std::make_shared>(data_len, 0); + MS_EXCEPTION_IF_NULL(grad_ptr); + ps_->InitGrad(key, grad_ptr); + } + pos += data_len; + } +} + +void ParameterServer::ServerHandler::HandleInitWeightToOptimId(DataPtr data, size_t size, VectorPtr res) { + std::unique_lock lock(ps_->mutex()); + MS_EXCEPTION_IF_NULL(res); + KVMessage input; + input.ParseFromArray(data.get(), size); + size_t key_num = input.keys_size(); + for (size_t i = 0; i < key_num; i++) { + Key key = input.keys()[i]; + float val = input.values()[i]; + if (init_weight_to_optim_[key]) { + continue; + } else { + init_weight_to_optim_[key] = true; + } + ps_->InitWeightKeyToOptims(key, val); + } +} + +void ParameterServer::ServerHandler::HandleInitInputsShape(DataPtr data, size_t size, VectorPtr res) { + std::unique_lock lock(ps_->mutex()); + MS_EXCEPTION_IF_NULL(res); + KVMessage input; + input.ParseFromArray(data.get(), size); + const Key &key = input.keys()[0]; + if (init_optim_info_[key]) { + return; + } else { + init_optim_info_[key] = true; + } + Keys keys = {input.keys().begin(), input.keys().end()}; + Values values = {input.values().begin(), input.values().end()}; + Lengths lens = {input.len().begin(), input.len().end()}; + ps_->InitOptimInputsShape(keys, values, lens); +} + +void ParameterServer::ServerHandler::HandleInitEmbeddings(DataPtr data, size_t size, VectorPtr res) { + std::unique_lock lock(ps_->mutex()); + EmbeddingTableMeta embedding_table_meta; + embedding_table_meta.ParseFromArray(data.get(), size); + const Key &key = embedding_table_meta.key(); + MS_LOG(INFO) << "Initializing embedding table for key:" << key; + std::shared_ptr>>> shapes = + std::make_shared>>>(); + MS_EXCEPTION_IF_NULL(shapes); + std::shared_ptr> input_shape = std::make_shared>( + embedding_table_meta.input_shape().begin(), embedding_table_meta.input_shape().end()); + MS_EXCEPTION_IF_NULL(input_shape); + std::shared_ptr> indices_shape = std::make_shared>( + embedding_table_meta.indices_shape().begin(), embedding_table_meta.indices_shape().end()); + MS_EXCEPTION_IF_NULL(indices_shape); + std::shared_ptr> output_shape = std::make_shared>( + embedding_table_meta.output_shape().begin(), embedding_table_meta.output_shape().end()); + MS_EXCEPTION_IF_NULL(output_shape); + shapes->push_back(input_shape); + shapes->push_back(indices_shape); + shapes->push_back(output_shape); + + const ParamInitInfoMessage &info = embedding_table_meta.info(); + ParamInitInfo param_init_info; + if (ps::PsDataPrefetch::GetInstance().cache_enable()) { + param_init_info.param_type_ = static_cast(info.param_type()); + if (param_init_info.param_type_ == kWeight) { + param_init_info.global_seed_ = info.global_seed(); + param_init_info.op_seed_ = 
+void ParameterServer::ServerHandler::HandleInitEmbeddings(DataPtr data, size_t size, VectorPtr res) {
+  std::unique_lock<std::mutex> lock(ps_->mutex());
+  EmbeddingTableMeta embedding_table_meta;
+  embedding_table_meta.ParseFromArray(data.get(), size);
+  const Key &key = embedding_table_meta.key();
+  MS_LOG(INFO) << "Initializing embedding table for key:" << key;
+  std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> shapes =
+    std::make_shared<std::vector<std::shared_ptr<std::vector<size_t>>>>();
+  MS_EXCEPTION_IF_NULL(shapes);
+  std::shared_ptr<std::vector<size_t>> input_shape = std::make_shared<std::vector<size_t>>(
+    embedding_table_meta.input_shape().begin(), embedding_table_meta.input_shape().end());
+  MS_EXCEPTION_IF_NULL(input_shape);
+  std::shared_ptr<std::vector<size_t>> indices_shape = std::make_shared<std::vector<size_t>>(
+    embedding_table_meta.indices_shape().begin(), embedding_table_meta.indices_shape().end());
+  MS_EXCEPTION_IF_NULL(indices_shape);
+  std::shared_ptr<std::vector<size_t>> output_shape = std::make_shared<std::vector<size_t>>(
+    embedding_table_meta.output_shape().begin(), embedding_table_meta.output_shape().end());
+  MS_EXCEPTION_IF_NULL(output_shape);
+  shapes->push_back(input_shape);
+  shapes->push_back(indices_shape);
+  shapes->push_back(output_shape);
+
+  const ParamInitInfoMessage &info = embedding_table_meta.info();
+  ParamInitInfo param_init_info;
+  if (ps::PsDataPrefetch::GetInstance().cache_enable()) {
+    param_init_info.param_type_ = static_cast<ParamType>(info.param_type());
+    if (param_init_info.param_type_ == kWeight) {
+      param_init_info.global_seed_ = info.global_seed();
+      param_init_info.op_seed_ = info.op_seed();
+    } else if (param_init_info.param_type_ == kAccumulation) {
+      param_init_info.init_val_ = info.init_val();
+    }
+  }
+  ps_->InitEmbeddingTable(key, shapes, param_init_info);
+}
+
+void ParameterServer::ServerHandler::HandleCheckReadyForPush(DataPtr data, size_t size, VectorPtr res) {
+  MS_EXCEPTION_IF_NULL(res);
+  KVMessage input;
+  input.ParseFromArray(data.get(), size);
+  const Key &key = input.keys()[0];
+  bool ready = ps_->ReadyForPush(key);
+  MS_LOG(INFO) << "the ready is:" << ready;
+  KVMessage res_data;
+  res_data.add_keys(key);
+  res_data.add_values(ready);
+  res->resize(res_data.ByteSizeLong());
+  int ret =
+    memcpy_s(res->data(), res_data.ByteSizeLong(), res_data.SerializeAsString().data(), res_data.ByteSizeLong());
+  if (ret != 0) {
+    MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
+  }
+}
+
+void ParameterServer::ServerHandler::HandleCheckReadyForPull(DataPtr data, size_t size, VectorPtr res) {
+  MS_EXCEPTION_IF_NULL(res);
+  KVMessage input;
+  input.ParseFromArray(data.get(), size);
+  const Key &key = input.keys()[0];
+  bool ready = ps_->ReadyForPull(key);
+  KVMessage res_data;
+  res_data.add_keys(key);
+  res_data.add_values(ready);
+  res->resize(res_data.ByteSizeLong());
+  int ret =
+    memcpy_s(res->data(), res_data.ByteSizeLong(), res_data.SerializeAsString().data(), res_data.ByteSizeLong());
+  if (ret != 0) {
+    MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
+  }
+}
+
+void ParameterServer::ServerHandler::HandleEmbeddingLookup(DataPtr data, size_t size, VectorPtr res) {
+  MS_EXCEPTION_IF_NULL(res);
+  EmbeddingTableLookup input;
+  input.ParseFromArray(data.get(), size);
+  const Key &key = input.key();
+  MS_LOG(DEBUG) << "The key is:" << key;
+
+  KVMessage res_data;
+  std::vector<Key> keys = {input.keys().begin(), input.keys().end()};
+  *res_data.mutable_keys() = {input.keys().begin(), input.keys().end()};
+
+  ps_->DoEmbeddingLookup(key, keys, &res_data);
+  res->resize(res_data.ByteSizeLong());
+  int ret =
+    memcpy_s(res->data(), res_data.ByteSizeLong(), res_data.SerializeAsString().data(), res_data.ByteSizeLong());
+  if (ret != 0) {
+    MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
+  }
+}
+
+void ParameterServer::ServerHandler::HandleUpdateEmbeddings(DataPtr data, size_t size, VectorPtr res) {
+  std::unique_lock<std::mutex> lock(ps_->mutex());
+  MS_EXCEPTION_IF_NULL(res);
+  KVMessage input;
+  input.ParseFromArray(data.get(), size);
+  const Key &key = input.keys()[0];
+  const LookupIds &lookup_ids = {input.keys().begin() + 1, input.keys().end()};
+  const Values &update_vals = {input.values().begin(), input.values().end()};
+  ps_->UpdateEmbeddings(key, lookup_ids, update_vals);
+}
+
+void ParameterServer::ServerHandler::HandleFinalize(DataPtr data, size_t size, VectorPtr res) {
+  MS_EXCEPTION_IF_NULL(res);
+  ps_->Finalize();
+}
+}  // namespace internal
+}  // namespace ps
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/ps/internal/parameter_server.h b/mindspore/ccsrc/ps/internal/parameter_server.h
new file mode 100644
index 0000000000..6fb25c7dc7
--- /dev/null
+++ b/mindspore/ccsrc/ps/internal/parameter_server.h
@@ -0,0 +1,179 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_PS_INTERNAL_PARAMETER_SERVER_H_
+#define MINDSPORE_CCSRC_PS_INTERNAL_PARAMETER_SERVER_H_
+
+#include <unistd.h>
+#include <unordered_map>
+#include <string>
+#include <iostream>
+#include <memory>
+#include <vector>
+#include <mutex>
+#include <condition_variable>
+#include <thread>
+#include <cmath>
+#include <random>
+#include <utility>
+#include <list>
+#include <map>
+#include <functional>
+#include <algorithm>
+#include "ir/func_graph.h"
+#include "backend/session/session_basic.h"
+#include "backend/session/anf_runtime_algorithm.h"
+#include "backend/session/session_factory.h"
+#include "ps/optimizer_info.h"
+#include "ps/optimizer_info_builder.h"
+#include "ps/ps_context.h"
+#include "runtime/device/cpu/kernel_select_cpu.h"
+#include "utils/ms_context.h"
+#include "backend/kernel_compiler/kernel.h"
+#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
+#include "backend/kernel_compiler/cpu/ps/pserver_kernel.h"
+#include "backend/kernel_compiler/cpu/ps/sparse_apply_adam_ps_kernel.h"
+#include "backend/kernel_compiler/cpu/ps/sparse_apply_lazy_adam_ps_kernel.h"
+#include "backend/kernel_compiler/cpu/ps/sparse_apply_ftrl_ps_kernel.h"
+#include "backend/kernel_compiler/cpu/ps/apply_momentum_ps_kernel.h"
+#include "backend/kernel_compiler/cpu/ps/embedding_look_up_ps_kernel.h"
+#include "ps/ps_cache/ps_data/ps_data_prefetch.h"
+#include "ps/random_normal/random_normal.h"
+
+#include "ps/internal/constants.h"
+#include "ps/util.h"
+#include "ps/embedding_table_shard_metadata.h"
+#include "utils/log_adapter.h"
+#include "proto/comm.pb.h"
+#include "proto/ps.pb.h"
+#include "ps/core/server_node.h"
+
+namespace mindspore {
+namespace ps {
+namespace internal {
+
+class ParameterServer {
+ public:
+  static ParameterServer &GetInstance() {
+    static ParameterServer instance;
+    return instance;
+  }
+
+  void Run(const FuncGraphPtr &func_graph);
+
+ private:
+  ParameterServer()
+      : pserver_num_(0),
+        worker_num_(0),
+        rank_id_(0),
+        grad_accum_count_(0),
+        handler_(nullptr),
+        func_graph_(nullptr),
+        sess_(nullptr),
+        running_(true),
+        thread_(nullptr) {}
+  ~ParameterServer() = default;
+  ParameterServer(const ParameterServer &) = delete;
+  ParameterServer &operator=(const ParameterServer &) = delete;
+
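+  // ServerHandler adapts raw TCP messages from the server node: it looks up the handler
+  // registered for each command code in Init() and dispatches it through operator().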
+  class ServerHandler {
+   public:
+    explicit ServerHandler(ParameterServer *ps) : ps_(ps) {}
+    ~ServerHandler() = default;
+    void Init();
+    void operator()(std::shared_ptr<core::TcpConnection> conn, std::shared_ptr<core::MessageMeta> meta, DataPtr data,
+                    size_t size);
+    void HandlePushReq(DataPtr data, size_t size, VectorPtr res);
+    void HandlePullReq(DataPtr data, size_t size, VectorPtr res);
+    void HandleInitWeights(DataPtr data, size_t size, VectorPtr res);
+    void HandleInitWeightToOptimId(DataPtr data, size_t size, VectorPtr res);
+    void HandleInitInputsShape(DataPtr data, size_t size, VectorPtr res);
+    void HandleInitEmbeddings(DataPtr data, size_t size, VectorPtr res);
+    void HandleCheckReadyForPush(DataPtr data, size_t size, VectorPtr res);
+    void HandleCheckReadyForPull(DataPtr data, size_t size, VectorPtr res);
+    void HandleEmbeddingLookup(DataPtr data, size_t size, VectorPtr res);
+    void HandleUpdateEmbeddings(DataPtr data, size_t size, VectorPtr res);
+    void HandleFinalize(DataPtr data, size_t size, VectorPtr res);
+
+   private:
+    ParameterServer *ps_;
+    typedef void (ServerHandler::*RequestHandler)(DataPtr data, size_t size, VectorPtr res);
+    std::unordered_map<int, RequestHandler> handlers_;
+    std::unordered_map<Key, bool> init_weights_;
+    std::unordered_map<Key, bool> init_weight_to_optim_;
+    std::unordered_map<Key, bool> init_optim_info_;
+  };
+
+  bool Init(const FuncGraphPtr &func_graph);
+  void InitOptimInfoBuilders();
+  void InitWeightKeyToOptims(const Key &key, const int64_t &optim_id);
+  void InitOptimInputsShape(const Keys &keys, const Values &values, const Lengths &lengths);
+  void InitWeight(const Key &key, const WeightPtr &weight);
+  void InitGrad(const Key &key, const GradPtr &grad);
+  void InitEmbeddingTable(const Key &key,
+                          const std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> &shapes,
+                          const ParamInitInfo &param_init_info);
+  bool HasWeight(const Key &key);
+  void Finalize();
+  void UpdateWeights();
+  void AccumGrad(const Keys &keys, const Values &values, const Lengths &lengths);
+  WeightPtr weight(const Key &key);
+  void DoEmbeddingLookup(Key key, const LookupIds &lookup_ids, KVMessage *res);
+  void UpdateEmbeddings(const Key &key, const LookupIds &lookup_ids, const Values &vals);
+  bool ReadyForUpdateWeights();
+  bool ReadyForPush(const Key &key);
+  bool ReadyForPull(const Key &key);
+  void ResetGradAccumCount();
+  const CNodePtr GetCNode(const std::string &name) const;
+  std::mutex &mutex();
+  void GetEmbeddingTableParamPtr();
+  void SyncEmbeddingTables();
+
+  size_t pserver_num_;
+  size_t worker_num_;
+  size_t rank_id_;
+  size_t grad_accum_count_;
+  std::unique_ptr<ServerHandler> handler_;
+  FuncGraphPtr func_graph_;
+  std::shared_ptr<session::SessionBasic> sess_;
+  bool running_;
+
+  std::unordered_map<Key, std::shared_ptr<PServerKernel>> optimizers_;
+  std::unordered_map<Key, InputsShapePtr> optim_inputs_shape_;
+  std::unordered_map<Key, InputsShapePtr> original_optim_inputs_shape_;
+  std::unordered_map<Key, std::shared_ptr<OptimizerInfo>> optim_infos_;
+  std::unordered_map<std::string, std::shared_ptr<OptimizerInfoBuilder>> optim_info_builders_;
+  std::unordered_map<Key, std::string> weight_key_to_optims_;
+  std::unordered_map<Key, std::string> weight_key_to_optim_op_;
+  std::unordered_map<Key, WeightPtr> weights_;
+  std::unordered_map<Key, bool> is_embedding_;
+  std::unordered_map<Key, GradPtr> grads_;
+  std::unordered_map<Key, size_t> grads_accum_counter_;
+  std::unordered_map<Key, std::shared_ptr<PServerKernel>> embedding_lookup_ops_;
+  std::unordered_map<Key, int64_t> tokens_;
+
+  std::mutex mutex_;
+  std::condition_variable apply_grads_cv_;
+
+  std::unique_ptr<std::thread> thread_;
+  core::ServerNode server_node_;
+  std::map<Key, ParameterPtr> embedding_tables_;
+
+  friend class ServerHandler;
+};
+}  // namespace internal
+}  // namespace ps
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_PS_INTERNAL_PARAMETER_SERVER_H_
diff --git a/tests/ut/cpp/CMakeLists.txt b/tests/ut/cpp/CMakeLists.txt
index 87778a1214..72113f19f8 100644
--- a/tests/ut/cpp/CMakeLists.txt
+++ b/tests/ut/cpp/CMakeLists.txt
@@ -146,6 +146,7 @@ list(REMOVE_ITEM MINDSPORE_SRC_LIST
     "../../../mindspore/ccsrc/frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.cc")
 list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/util.cc")
 list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/internal/worker.cc")
+list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/internal/parameter_server.cc")
 list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/scheduler.cc")
 list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/optimizer_info.cc")
 list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/optimizer_info_builder.cc")