From: @anancds
Reviewed-by: @cristoval,@limingqi107
Signed-off-by: @cristoval
pull/12578/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit c34fda33df

@ -22,6 +22,7 @@ if(NOT (ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU)))
list(REMOVE_ITEM _PS_SRC_FILES "core/scheduler_node.cc")
list(REMOVE_ITEM _PS_SRC_FILES "core/http_client.cc")
list(REMOVE_ITEM _PS_SRC_FILES "internal/worker.cc")
list(REMOVE_ITEM _PS_SRC_FILES "internal/parameter_server.cc")
endif()
if(NOT ENABLE_D)

@ -76,6 +76,8 @@ constexpr int64_t kFinalizeCmd = 40;
constexpr size_t kInvalidKey = UINT64_MAX;
constexpr int64_t kInvalidID = -1;
using DataPtr = std::shared_ptr<unsigned char>;
using VectorPtr = std::shared_ptr<std::vector<unsigned char>>;
using Key = ::ps::Key;
using Keys = ::ps::SArray<Key>;
using Values = ::ps::SArray<float>;

@ -35,6 +35,13 @@ enum CommandCode {
FINALIZE = 10;
}
message ParamInitInfoMessage {
int32 param_type = 1;
uint64 global_seed = 2;
uint64 op_seed = 3;
float init_val = 4;
}
message KVMessage {
repeated int32 keys = 2;
repeated float values = 3;
@ -46,6 +53,7 @@ message EmbeddingTableMeta {
repeated uint64 input_shape = 2;
repeated uint64 indices_shape = 3;
repeated uint64 output_shape = 4;
ParamInitInfoMessage info = 5;
}
message EmbeddingTableLookup {

File diff suppressed because it is too large Load Diff

@ -0,0 +1,179 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PS_INTERNAL_PARAMETER_SERVER_H_
#define MINDSPORE_CCSRC_PS_INTERNAL_PARAMETER_SERVER_H_
#include <unistd.h>
#include <unordered_map>
#include <string>
#include <iostream>
#include <memory>
#include <vector>
#include <mutex>
#include <condition_variable>
#include <thread>
#include <cmath>
#include <random>
#include <utility>
#include <list>
#include <map>
#include <functional>
#include "ir/func_graph.h"
#include "backend/session/session_basic.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "backend/session/session_factory.h"
#include "ps/optimizer_info.h"
#include "ps/optimizer_info_builder.h"
#include "ps/ps_context.h"
#include "runtime/device/cpu/kernel_select_cpu.h"
#include "utils/ms_context.h"
#include "backend/kernel_compiler/kernel.h"
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
#include "backend/kernel_compiler/cpu/ps/pserver_kernel.h"
#include "backend/kernel_compiler/cpu/ps/sparse_apply_adam_ps_kernel.h"
#include "backend/kernel_compiler/cpu/ps/sparse_apply_lazy_adam_ps_kernel.h"
#include "backend/kernel_compiler/cpu/ps/sparse_apply_ftrl_ps_kernel.h"
#include "backend/kernel_compiler/cpu/ps/apply_momentum_ps_kernel.h"
#include "backend/kernel_compiler/cpu/ps/embedding_look_up_ps_kernel.h"
#include "ps/ps_cache/ps_data/ps_data_prefetch.h"
#include "ps/random_normal/random_normal.h"
#include "ps/internal/constants.h"
#include "ps/util.h"
#include "ps/embedding_table_shard_metadata.h"
#include "utils/log_adapter.h"
#include "proto/comm.pb.h"
#include "proto/ps.pb.h"
#include "ps/core/server_node.h"
namespace mindspore {
namespace ps {
namespace internal {
class ParameterServer {
public:
static ParameterServer &GetInstance() {
static ParameterServer instance;
return instance;
}
void Run(const FuncGraphPtr &func_graph);
private:
ParameterServer()
: pserver_num_(0),
worker_num_(0),
rank_id_(0),
grad_accum_count_(0),
handler_(nullptr),
func_graph_(nullptr),
sess_(nullptr),
running_(true),
thread_(nullptr) {}
~ParameterServer() = default;
ParameterServer(const ParameterServer &) = delete;
ParameterServer &operator=(const ParameterServer &) = delete;
class ServerHandler {
public:
explicit ServerHandler(ParameterServer *ps) : ps_(ps) {}
~ServerHandler() = default;
void Init();
void operator()(std::shared_ptr<core::TcpConnection> conn, std::shared_ptr<core::MessageMeta> meta, DataPtr data,
size_t size);
void HandlePushReq(DataPtr data, size_t size, VectorPtr res);
void HandlePullReq(DataPtr data, size_t size, VectorPtr res);
void HandleInitWeights(DataPtr data, size_t size, VectorPtr res);
void HandleInitWeightToOptimId(DataPtr data, size_t size, VectorPtr res);
void HandleInitInputsShape(DataPtr data, size_t size, VectorPtr res);
void HandleInitEmbeddings(DataPtr data, size_t size, VectorPtr res);
void HandleCheckReadyForPush(DataPtr data, size_t size, VectorPtr res);
void HandleCheckReadyForPull(DataPtr data, size_t size, VectorPtr res);
void HandleEmbeddingLookup(DataPtr data, size_t size, VectorPtr res);
void HandleUpdateEmbeddings(DataPtr data, size_t size, VectorPtr res);
void HandleFinalize(DataPtr data, size_t size, VectorPtr res);
private:
ParameterServer *ps_;
typedef void (ServerHandler::*RequestHandler)(DataPtr data, size_t size, VectorPtr res);
std::unordered_map<int, RequestHandler> handlers_;
std::unordered_map<Key, bool> init_weights_;
std::unordered_map<Key, bool> init_weight_to_optim_;
std::unordered_map<Key, bool> init_optim_info_;
};
bool Init(const FuncGraphPtr &func_graph);
void InitOptimInfoBuilders();
void InitWeightKeyToOptims(const Key &key, const int64_t &optim_id);
void InitOptimInputsShape(const Keys &keys, const Values &values, const Lengths &lengths);
void InitWeight(const Key &key, const WeightPtr &weight);
void InitGrad(const Key &key, const GradPtr &grad);
void InitEmbeddingTable(const Key &key,
const std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> &shapes,
const ParamInitInfo &param_init_info);
bool HasWeight(const Key &key);
void Finalize();
void UpdateWeights();
void AccumGrad(const Keys &key, const Values &values, const Lengths &lengths);
WeightPtr weight(const Key &key);
void DoEmbeddingLookup(Key key, const LookupIds &lookup_ids, KVMessage *res);
void UpdateEmbeddings(const Key &key, const LookupIds &lookup_ids, const Values &vals);
bool ReadyForUpdateWeights();
bool ReadyForPush(const Key &key);
bool ReadyForPull(const Key &key);
void ResetGradAccumCount();
const CNodePtr GetCNode(const std::string &name) const;
std::mutex &mutex();
void GetEmbeddingTableParamPtr();
void SyncEmbeddingTables();
size_t pserver_num_;
size_t worker_num_;
size_t rank_id_;
size_t grad_accum_count_;
std::unique_ptr<ServerHandler> handler_;
FuncGraphPtr func_graph_;
std::shared_ptr<session::SessionBasic> sess_;
bool running_;
std::unordered_map<Key, std::shared_ptr<PServerKernel>> optimizers_;
std::unordered_map<Key, InputsShapePtr> optim_inputs_shape_;
std::unordered_map<Key, InputsShapePtr> original_optim_inputs_shape_;
std::unordered_map<Key, std::shared_ptr<OptimizerInfo>> optim_infos_;
std::unordered_map<std::string, std::shared_ptr<OptimizerInfoBuilder>> optim_info_builders_;
std::unordered_map<Key, std::string> weight_key_to_optims_;
std::unordered_map<Key, std::string> weight_key_to_optim_op_;
std::unordered_map<Key, WeightPtr> weights_;
std::unordered_map<Key, bool> is_embedding_;
std::unordered_map<Key, WeightPtr> grads_;
std::unordered_map<Key, size_t> grads_accum_counter_;
std::unordered_map<Key, std::shared_ptr<PServerKernel>> embedding_lookup_ops_;
std::unordered_map<Key, uint64_t> tokens_;
std::mutex mutex_;
std::condition_variable apply_grads_cv_;
std::unique_ptr<std::thread> thread_;
core::ServerNode server_node_;
std::map<Key, ParameterPtr> embedding_tables_;
friend class ServerHandler;
};
} // namespace internal
} // namespace ps
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_INTERNAL_PARAMETER_SERVER_H_

@ -146,6 +146,7 @@ list(REMOVE_ITEM MINDSPORE_SRC_LIST
"../../../mindspore/ccsrc/frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.cc")
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/util.cc")
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/internal/worker.cc")
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/internal/parameter_server.cc")
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/scheduler.cc")
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/optimizer_info.cc")
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/optimizer_info_builder.cc")

Loading…
Cancel
Save