parent e99c29c7d9
commit db0a6f1e19
(deleted file, likely cmake/external_libs/pslite.cmake: the ps-lite external-package script)
@@ -1,22 +0,0 @@
if(ENABLE_GITEE)
    set(REQ_URL "https://gitee.com/mirrors/ps-lite/repository/archive/34fd45cae457d59850fdcb2066467778d0673f21.zip")
    set(MD5 "0d1543b8dcb0bc3610637e1643c94eb4")
else()
    set(REQ_URL "https://github.com/dmlc/ps-lite/archive/34fd45cae457d59850fdcb2066467778d0673f21.zip")
    set(MD5 "393c0e27b68bfaf96718caa3aa96f5a3")
endif()

set(pslite_USE_STATIC_LIBS ON)
if(${ENABLE_IBVERBS} STREQUAL "ON")
    set(pslite_CXXFLAGS "USE_IBVERBS=1")
endif()
mindspore_add_pkg(pslite
        LIBS ps
        URL ${REQ_URL}
        MD5 ${MD5}
        PATCHES ${CMAKE_SOURCE_DIR}/third_party/patch/pslite/ps_lite.patch001
        ONLY_MAKE True
        ONLY_MAKE_INCS include/*
        ONLY_MAKE_LIBS build/*)
include_directories(${pslite_INC})
add_library(mindspore::pslite ALIAS pslite::ps)
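For orientation, the "ps" library staged above is ps-lite's parameter-server runtime, pinned to commit 34fd45c and verified by MD5. A minimal, hypothetical bootstrap against that revision's API (ps::Start/ps::Finalize taking a customer id); nothing below is part of this commit:

#include "ps/ps.h"

int main() {
  // Role and topology come from the DMLC_* environment variables
  // (DMLC_ROLE, DMLC_NUM_SERVER, DMLC_NUM_WORKER, DMLC_PS_ROOT_URI, DMLC_PS_ROOT_PORT).
  ::ps::Start(0);           // customer id 0; blocks until scheduler, servers and workers connect
  if (::ps::IsServer()) {
    // a real server would register a KVServer request handler here
  }
  ::ps::Finalize(0, true);  // cluster-wide barrier, then shut down
  return 0;
}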
(deleted file, likely cmake/external_libs/zeromq.cmake: the zeromq external-package script)
@@ -1,5 +0,0 @@
mindspore_add_pkg(zeromq
        VER 4.1.4
        HEAD_ONLY ./
        URL https://raw.githubusercontent.com/mli/deps/master/build/zeromq-4.1.4.tar.gz
        MD5 a611ecc93fffeb6d058c0e6edf4ad4fb)
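ps-lite's default TCP transport (its ZMQ-based "van") needs the zeromq headers at build time, which is presumably why the tarball is unpacked with HEAD_ONLY ./ rather than configured and linked as a full dependency.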
(deleted file: mindspore/ccsrc/ps/common.h, inferred from its include guard)
@@ -1,140 +0,0 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_PS_COMMON_H_
#define MINDSPORE_CCSRC_PS_COMMON_H_

#include <limits.h>

#include <iostream>
#include <vector>
#include <memory>
#include <map>
#include <string>

#include "ps/ps.h"

namespace mindspore {
namespace ps {
constexpr char kEnvCommType[] = "MS_COMM_TYPE";
constexpr char kEnvInterface[] = "MS_INTERFACE";
constexpr char kEnvPServerNum[] = "MS_SERVER_NUM";
constexpr char kEnvWorkerNum[] = "MS_WORKER_NUM";
constexpr char kEnvSchedulerHost[] = "MS_SCHED_HOST";
constexpr char kEnvSchedulerPort[] = "MS_SCHED_PORT";

constexpr char kDmlcCommType[] = "DMLC_PS_VAN_TYPE";
constexpr char kDmlcInterface[] = "DMLC_INTERFACE";
constexpr char kDmlcPServerNum[] = "DMLC_NUM_SERVER";
constexpr char kDmlcWorkerNum[] = "DMLC_NUM_WORKER";
constexpr char kDmlcRole[] = "DMLC_ROLE";
constexpr char kDmlcSchedulerHost[] = "DMLC_PS_ROOT_URI";
constexpr char kDmlcSchedulerPort[] = "DMLC_PS_ROOT_PORT";

constexpr char kCommTypeOfIBVerbs[] = "ibverbs";
constexpr char kCommTypeOfTCP[] = "zmq";
constexpr char kRoleOfPServer[] = "server";
constexpr char kRoleOfWorker[] = "worker";
constexpr char kRoleOfScheduler[] = "scheduler";

constexpr char kLearningRate[] = "learning_rate";
constexpr char kMomentum[] = "momentum";

constexpr char kApplyMomentum[] = "ApplyMomentum";
constexpr char kSparseAdam[] = "Adam";
constexpr char kSparseLazyAdam[] = "LazyAdam";
constexpr char kSparseFtrl[] = "Ftrl";
constexpr char kApplyMomentumOp[] = "Momentum";
constexpr char kSparseAdamOp[] = "Adam";
constexpr char kSparseLazyAdamOp[] = "LazyAdam";
constexpr char kSparseFtrlOp[] = "FTRL";

constexpr int64_t kInitWeightsCmd = 10;
constexpr int64_t kInitWeightToOptimIdCmd = 11;
constexpr int64_t kInitOptimInputsShapeCmd = 12;
constexpr int64_t kInitKeyToPushNodeIdCmd = 13;
constexpr int64_t kInitEmbeddingsCmd = 20;
constexpr int64_t kUpdateEmbeddingsCmd = 21;
constexpr int64_t kCheckReadyForPushCmd = 25;
constexpr int64_t kCheckReadyForPullCmd = 26;
constexpr int64_t kEmbeddingLookupCmd = 30;
constexpr int64_t kFinalizeCmd = 40;

constexpr size_t kInvalidKey = UINT64_MAX;
constexpr int64_t kInvalidID = -1;

using DataPtr = std::shared_ptr<unsigned char>;
using VectorPtr = std::shared_ptr<std::vector<unsigned char>>;
using Key = ::ps::Key;
using Keys = ::ps::SArray<Key>;
using Values = ::ps::SArray<float>;
using ValuesPtr = std::shared_ptr<Values>;
using Weight = ::ps::SArray<float>;
using Grad = ::ps::SArray<float>;
using LookupIds = ::ps::SArray<Key>;
using Lengths = ::ps::SArray<int>;
using WeightPtr = std::shared_ptr<Weight>;
using GradPtr = std::shared_ptr<Grad>;
using InputsShape = std::vector<std::shared_ptr<std::vector<size_t>>>;
using InputsShapePtr = std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>>;

constexpr size_t INDEX_NOT_SEND = UINT_MAX;
using OptimOriginIdx = std::map<std::string, size_t>;
using OptimPSSendIdx = std::map<std::string, size_t>;

const OptimOriginIdx kMomentumOriginIdx = {{"weight", 0}, {"accum", 1}, {"lr", 2}, {"grad", 3}, {"momentum", 4}};
const OptimPSSendIdx kMomentumPSSendIdx = {
  {"weight", INDEX_NOT_SEND}, {"accum", INDEX_NOT_SEND}, {"lr", 0}, {"grad", 1}, {"momentum", 2}};

const OptimOriginIdx kSparseAdamOriginIdx = {{"weight", 0}, {"m", 1}, {"v", 2}, {"beta1_power", 3},
                                             {"beta2_power", 4}, {"lr", 5}, {"beta1", 6}, {"beta2", 7},
                                             {"eps", 8}, {"grad", 9}, {"indices", 10}};
const OptimPSSendIdx kSparseAdamPSSendIdx = {{"weight", INDEX_NOT_SEND},
                                             {"m", INDEX_NOT_SEND},
                                             {"v", INDEX_NOT_SEND},
                                             {"beta1_power", 0},
                                             {"beta2_power", 1},
                                             {"lr", 2},
                                             {"beta1", 3},
                                             {"beta2", 4},
                                             {"eps", 5},
                                             {"grad", 6},
                                             {"indices", 7}};

const OptimOriginIdx kSparseFtrlOriginIdx = {{"weight", 0}, {"accum", 1}, {"linear", 2}, {"grad", 3}, {"indices", 4}};
const OptimPSSendIdx kSparseFtrlPSSendIdx = {
  {"weight", INDEX_NOT_SEND}, {"accum", INDEX_NOT_SEND}, {"linear", INDEX_NOT_SEND}, {"grad", 0}, {"indices", 1}};

const std::map<std::string, OptimOriginIdx> kOptimToOriginIdx = {{kApplyMomentum, kMomentumOriginIdx},
                                                                 {kSparseAdam, kSparseAdamOriginIdx},
                                                                 {kSparseLazyAdam, kSparseAdamOriginIdx},
                                                                 {kSparseFtrl, kSparseFtrlOriginIdx}};
const std::map<std::string, OptimOriginIdx> kOptimToPSSendIdx = {{kApplyMomentum, kMomentumPSSendIdx},
                                                                 {kSparseAdam, kSparseAdamPSSendIdx},
                                                                 {kSparseLazyAdam, kSparseAdamPSSendIdx},
                                                                 {kSparseFtrl, kSparseFtrlPSSendIdx}};

#define EXC_IF_VEC_IDX_OOB(vec, idx)                                                            \
  {                                                                                             \
    size_t vec_size = vec.size();                                                               \
    if (idx >= vec_size) {                                                                      \
      MS_LOG(EXCEPTION) << "Vector " << #vec << " size is " << vec_size << ". So index " << idx \
                        << " is out of bound.";                                                 \
    }                                                                                           \
  }
} // namespace ps
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_COMMON_H_
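The paired MS_*/DMLC_* constants above exist so MindSpore-facing settings can be forwarded to the variable names ps-lite actually reads. A sketch of that forwarding, assuming POSIX setenv and the constants from this header in scope; EnvToDmlc is a hypothetical name, not a function from the deleted code:

#include <cstdlib>
#include <utility>

void EnvToDmlc() {
  // Each MS_* variable maps 1:1 onto a DMLC_* variable of the same meaning.
  const std::pair<const char *, const char *> kPairs[] = {
    {kEnvCommType, kDmlcCommType},           {kEnvInterface, kDmlcInterface},
    {kEnvPServerNum, kDmlcPServerNum},       {kEnvWorkerNum, kDmlcWorkerNum},
    {kEnvSchedulerHost, kDmlcSchedulerHost}, {kEnvSchedulerPort, kDmlcSchedulerPort}};
  for (const auto &p : kPairs) {
    if (const char *value = std::getenv(p.first)) {
      setenv(p.second, value, 1);  // overwrite any stale DMLC_* setting
    }
  }
}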
(deleted file: mindspore/ccsrc/ps/internal/parameter_server.h, inferred from its include guard)
@@ -1,179 +0,0 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_PS_INTERNAL_PARAMETER_SERVER_H_
#define MINDSPORE_CCSRC_PS_INTERNAL_PARAMETER_SERVER_H_

#include <unistd.h>
#include <unordered_map>
#include <string>
#include <iostream>
#include <memory>
#include <vector>
#include <mutex>
#include <condition_variable>
#include <thread>
#include <cmath>
#include <random>
#include <utility>
#include <list>
#include <map>
#include <functional>
#include "ir/func_graph.h"
#include "backend/session/session_basic.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "backend/session/session_factory.h"
#include "ps/optimizer_info.h"
#include "ps/optimizer_info_builder.h"
#include "ps/ps_context.h"
#include "runtime/device/cpu/kernel_select_cpu.h"
#include "utils/ms_context.h"
#include "backend/kernel_compiler/kernel.h"
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
#include "backend/kernel_compiler/cpu/ps/pserver_kernel.h"
#include "backend/kernel_compiler/cpu/ps/sparse_apply_adam_ps_kernel.h"
#include "backend/kernel_compiler/cpu/ps/sparse_apply_lazy_adam_ps_kernel.h"
#include "backend/kernel_compiler/cpu/ps/sparse_apply_ftrl_ps_kernel.h"
#include "backend/kernel_compiler/cpu/ps/apply_momentum_ps_kernel.h"
#include "backend/kernel_compiler/cpu/ps/embedding_look_up_ps_kernel.h"
#include "ps/ps_cache/ps_data/ps_data_prefetch.h"
#include "ps/random_normal/random_normal.h"

#include "ps/internal/constants.h"
#include "ps/util.h"
#include "ps/embedding_table_shard_metadata.h"
#include "utils/log_adapter.h"
#include "proto/comm.pb.h"
#include "proto/ps.pb.h"
#include "ps/core/server_node.h"

namespace mindspore {
namespace ps {
namespace internal {

class ParameterServer {
 public:
  static ParameterServer &GetInstance() {
    static ParameterServer instance;
    return instance;
  }

  void Run(const FuncGraphPtr &func_graph);

 private:
  ParameterServer()
      : pserver_num_(0),
        worker_num_(0),
        rank_id_(0),
        grad_accum_count_(0),
        handler_(nullptr),
        func_graph_(nullptr),
        sess_(nullptr),
        running_(true),
        thread_(nullptr) {}
  ~ParameterServer() = default;
  ParameterServer(const ParameterServer &) = delete;
  ParameterServer &operator=(const ParameterServer &) = delete;

  class ServerHandler {
   public:
    explicit ServerHandler(ParameterServer *ps) : ps_(ps) {}
    ~ServerHandler() = default;
    void Init();
    void operator()(std::shared_ptr<core::TcpConnection> conn, std::shared_ptr<core::MessageMeta> meta, DataPtr data,
                    size_t size);
    void HandlePushReq(DataPtr data, size_t size, VectorPtr res);
    void HandlePullReq(DataPtr data, size_t size, VectorPtr res);
    void HandleInitWeights(DataPtr data, size_t size, VectorPtr res);
    void HandleInitWeightToOptimId(DataPtr data, size_t size, VectorPtr res);
    void HandleInitInputsShape(DataPtr data, size_t size, VectorPtr res);
    void HandleInitEmbeddings(DataPtr data, size_t size, VectorPtr res);
    void HandleCheckReadyForPush(DataPtr data, size_t size, VectorPtr res);
    void HandleCheckReadyForPull(DataPtr data, size_t size, VectorPtr res);
    void HandleEmbeddingLookup(DataPtr data, size_t size, VectorPtr res);
    void HandleUpdateEmbeddings(DataPtr data, size_t size, VectorPtr res);
    void HandleFinalize(DataPtr data, size_t size, VectorPtr res);

   private:
    ParameterServer *ps_;
    typedef void (ServerHandler::*RequestHandler)(DataPtr data, size_t size, VectorPtr res);
    std::unordered_map<int, RequestHandler> handlers_;
    std::unordered_map<Key, bool> init_weights_;
    std::unordered_map<Key, bool> init_weight_to_optim_;
    std::unordered_map<Key, bool> init_optim_info_;
  };

  bool Init(const FuncGraphPtr &func_graph);
  void InitOptimInfoBuilders();
  void InitWeightKeyToOptims(const Key &key, const int64_t &optim_id);
  void InitOptimInputsShape(const Keys &keys, const Values &values, const Lengths &lengths);
  void InitWeight(const Key &key, const WeightPtr &weight);
  void InitGrad(const Key &key, const GradPtr &grad);
  void InitEmbeddingTable(const Key &key,
                          const std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> &shapes,
                          const ParamInitInfo &param_init_info);
  bool HasWeight(const Key &key);
  void Finalize();
  void UpdateWeights();
  void AccumGrad(const Keys &key, const Values &values, const Lengths &lengths);
  WeightPtr weight(const Key &key);
  void DoEmbeddingLookup(Key key, const LookupIds &lookup_ids, KVMessage *res);
  void UpdateEmbeddings(const Key &key, const LookupIds &lookup_ids, const Values &vals);
  bool ReadyForUpdateWeights();
  bool ReadyForPush(const Key &key);
  bool ReadyForPull(const Key &key);
  void ResetGradAccumCount();
  const CNodePtr GetCNode(const std::string &name) const;
  std::mutex &mutex();
  void GetEmbeddingTableParamPtr();
  void SyncEmbeddingTables();

  size_t pserver_num_;
  size_t worker_num_;
  size_t rank_id_;
  size_t grad_accum_count_;
  std::unique_ptr<ServerHandler> handler_;
  FuncGraphPtr func_graph_;
  std::shared_ptr<session::SessionBasic> sess_;
  bool running_;

  std::unordered_map<Key, std::shared_ptr<PServerKernel>> optimizers_;
  std::unordered_map<Key, InputsShapePtr> optim_inputs_shape_;
  std::unordered_map<Key, InputsShapePtr> original_optim_inputs_shape_;
  std::unordered_map<Key, std::shared_ptr<OptimizerInfo>> optim_infos_;
  std::unordered_map<std::string, std::shared_ptr<OptimizerInfoBuilder>> optim_info_builders_;
  std::unordered_map<Key, std::string> weight_key_to_optims_;
  std::unordered_map<Key, std::string> weight_key_to_optim_op_;
  std::unordered_map<Key, WeightPtr> weights_;
  std::unordered_map<Key, bool> is_embedding_;
  std::unordered_map<Key, WeightPtr> grads_;
  std::unordered_map<Key, size_t> grads_accum_counter_;
  std::unordered_map<Key, std::shared_ptr<PServerKernel>> embedding_lookup_ops_;
  std::unordered_map<Key, uint64_t> tokens_;

  std::mutex mutex_;
  std::condition_variable apply_grads_cv_;

  std::unique_ptr<std::thread> thread_;
  core::ServerNode server_node_;
  std::map<Key, ParameterPtr> embedding_tables_;

  friend class ServerHandler;
};
} // namespace internal
} // namespace ps
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_INTERNAL_PARAMETER_SERVER_H_
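ServerHandler pairs an unordered_map<int, RequestHandler> with pointer-to-member handlers, so operator() can dispatch on the command id carried in the message meta. The deleted .cc file is not in this hunk, so the following Init() is only an inference from the declarations above, assuming ps/internal/constants.h keeps the same kInit*/kCheck*/kFinalize command ids as the old common.h shown earlier:

void ParameterServer::ServerHandler::Init() {
  // One entry per command id; push/pull requests travel separately.
  handlers_[kInitWeightsCmd] = &ServerHandler::HandleInitWeights;
  handlers_[kInitWeightToOptimIdCmd] = &ServerHandler::HandleInitWeightToOptimId;
  handlers_[kInitOptimInputsShapeCmd] = &ServerHandler::HandleInitInputsShape;
  handlers_[kInitEmbeddingsCmd] = &ServerHandler::HandleInitEmbeddings;
  handlers_[kCheckReadyForPushCmd] = &ServerHandler::HandleCheckReadyForPush;
  handlers_[kCheckReadyForPullCmd] = &ServerHandler::HandleCheckReadyForPull;
  handlers_[kEmbeddingLookupCmd] = &ServerHandler::HandleEmbeddingLookup;
  handlers_[kUpdateEmbeddingsCmd] = &ServerHandler::HandleUpdateEmbeddings;
  handlers_[kFinalizeCmd] = &ServerHandler::HandleFinalize;
}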
(deleted file: mindspore/ccsrc/ps/internal/worker.h, inferred from its include guard)
@@ -1,157 +0,0 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_PS_INTERNAL_WORKER_H_
#define MINDSPORE_CCSRC_PS_INTERNAL_WORKER_H_

#include <utility>
#include <memory>
#include <vector>
#include <string>
#include <numeric>
#include <functional>
#include <algorithm>
#include <map>
#include <mutex>
#include <unordered_set>
#include <unordered_map>

#include "utils/log_adapter.h"
#include "ir/tensor.h"
#include "ps/util.h"
#include "ps/internal/constants.h"
#include "utils/shape_utils.h"
#include "ps/ps_cache/ps_data/ps_data_prefetch.h"
#include "ps/core/worker_node.h"
#include "ps/embedding_table_shard_metadata.h"
#include "proto/comm.pb.h"
#include "proto/ps.pb.h"
#include "ps/ps_context.h"

namespace mindspore {
namespace ps {
namespace internal {

class Worker {
 public:
  static Worker &GetInstance() {
    static Worker instance;
    return instance;
  }
  using Callback = std::function<void()>;
  using PartitionEmbeddingMessages = std::vector<std::pair<bool, EmbeddingTableLookup>>;
  using PartitionKVMessages = std::vector<std::pair<bool, KVMessage>>;

  using EmbeddingPartitioner = std::function<void(
    const EmbeddingTableLookup &send, PartitionEmbeddingMessages *partition, const std::map<int64_t, int64_t> &attrs)>;
  using KVPartitioner =
    std::function<void(const KVMessage &send, PartitionKVMessages *partition, const std::map<int64_t, int64_t> &attrs)>;

  void Run();
  void Push(const std::vector<size_t> &keys, std::vector<uintptr_t> addrs, const ShapeVector &sizes);
  void Pull(const size_t key, void *dev_addr, const size_t size);
  size_t SetParamKey(const std::string &param_name);
  size_t GetParamKey(const std::string &param_name);
  void SetParamInitInServer(const std::string &param_name, bool init_in_server);
  bool GetParamInitInServer(const std::string &param_name);
  void SetKeyOptimId(size_t key, const std::string &optimizer_name);
  void SetOptimInputShapes(size_t key, const ShapeVector &shape);
  void AddEmbeddingTable(const Key &key, const size_t &row_count);
  void InitPSEmbeddingTable(const size_t &key, const std::vector<size_t> &input_shape,
                            const std::vector<size_t> &indices_shape, const std::vector<size_t> &output_shape);
  void InitPSParamAndOptim(const AnfNodePtr &input_node, const tensor::TensorPtr &tensor);
  void DoPSEmbeddingLookup(const Key &key, const std::vector<int> &lookup_ids, std::vector<float> *lookup_result,
                           int64_t cmd);
  void UpdateEmbeddingTable(const std::vector<Key> &keys, const std::vector<int> &lookup_ids,
                            const std::vector<float> &vals);

  bool running() { return running_; }
  void Finalize();

 private:
  Worker() : running_(false), key_cnt_(0) {}
  ~Worker() = default;
  Worker(const Worker &) = delete;
  Worker &operator=(const Worker &) = delete;

  void Initialize();
  bool IsKeyInit(const size_t key);
  void AddKeyToServerId(const Key &key);
  void AddKeyByHashMod(const Key &key);
  void InitPSOptimId(const size_t param_key);
  void InitPSOptimInputShapes(const size_t key);
  void InitPSParamData(const std::vector<size_t> &keys, void *origin_addr, size_t size);
  bool IsReadyForPush(const Key &key);
  bool IsReadyForPull(const Key &key);
  void PrepareSparseGradient(const size_t begin, const size_t end, const std::unordered_set<int> &distinct_ids,
                             const std::vector<std::pair<int, float *>> &indice_to_grads, const int *all_indice,
                             const size_t segment_size, float *gradient, int *indices);
  void BuildSparseValue(const std::vector<int> &lengths, const size_t grad_index, const size_t indice_index,
                        const float *original_data, const float *grads, int *indices, std::vector<float> *reduced_data);

  void PushData(const std::vector<Key> &keys, const std::vector<float> &vals, const std::vector<int> &lens = {},
                int command = 0, int64_t priority = 0);
  void PushSparseData(const std::vector<Key> &keys, const std::vector<float> &vals, const std::vector<int> &lens,
                      size_t grad_index, size_t indice_index, size_t first_dim_size, size_t outer_dim_size);
  void PullData(const std::vector<Key> &keys, std::vector<float> *vals, std::vector<int> *lens = nullptr, int cmd = 0,
                int64_t priority = 0);

  void LookupIdPartitioner(const EmbeddingTableLookup &send, PartitionEmbeddingMessages *partition,
                           const std::map<int64_t, int64_t> &attrs);

  void SparsePartitioner(const KVMessage &send, PartitionKVMessages *partition,
                         const std::map<int64_t, int64_t> &attrs);
  void RoundRobinPartitioner(const KVMessage &send, PartitionKVMessages *partition,
                             const std::map<int64_t, int64_t> &attrs);
  void WorkerInitEmbeddingPartitioner(const KVMessage &send, std::vector<std::pair<bool, KVMessage>> *partition,
                                      const std::map<int64_t, int64_t> &attrs);
  void UpdateEmbeddingPartitioner(const KVMessage &send, PartitionKVMessages *partition,
                                  const std::map<int64_t, int64_t> &attrs);
  void BroadcastPartitioner(const KVMessage &send, PartitionKVMessages *partition,
                            const std::map<int64_t, int64_t> &attrs);
  void SendForPush(int cmd, const KVMessage &send, const KVPartitioner &partitioner,
                   const std::map<int64_t, int64_t> &attrs);
  void SendForPull(int cmd, const KVMessage &send, const KVPartitioner &partitioner,
                   const std::map<int64_t, int64_t> &attrs, std::vector<float> *vals, std::vector<int> *lens);

  int64_t server_num_;
  bool running_;
  std::mutex running_mutex_;
  size_t key_cnt_;
  std::map<std::string, size_t> param_to_key_;
  std::map<size_t, bool> init_keys_;
  std::map<size_t, int64_t> key_to_optimId_;
  std::map<size_t, std::vector<ShapeVector>> key_to_optim_shapes_;
  std::map<std::string, bool> param_to_init_in_server_;
  core::WorkerNode worker_node_;

  EmbeddingPartitioner lookup_partitioner_;
  KVPartitioner sparse_partitioner_;
  KVPartitioner round_robin_partitioner_;
  KVPartitioner worker_init_embedding_partitioner_;
  KVPartitioner update_embedding_partitioner_;
  KVPartitioner broadcast_partitioner_;
  std::unordered_map<Key, int64_t> key_to_server_id_;
  std::unordered_map<Key, size_t> embedding_row_cnt_;

  std::unordered_map<Key, std::shared_ptr<std::vector<EmbeddingTableShardMetadata>>> embedding_table_ranges_;
};

static Worker &worker = Worker::GetInstance();
} // namespace internal
} // namespace ps
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_INTERNAL_WORKER_H_
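Each KVPartitioner splits one outgoing KVMessage into a per-server vector of (has-data, message) pairs. A hedged sketch of the round-robin idea, not the deleted implementation: it assumes the comm.pb.h KVMessage exposes repeated keys/values fields, takes server_num as an explicit parameter instead of the Worker member, and ignores the variable-length values that the real code tracks through per-key lengths:

#include <utility>
#include <vector>

void RoundRobinSketch(const KVMessage &send, std::vector<std::pair<bool, KVMessage>> *partition,
                      int64_t server_num) {
  partition->clear();
  partition->resize(server_num);  // pairs default to {false, empty message}
  for (int i = 0; i < send.keys_size(); ++i) {
    const int64_t server_id = send.keys(i) % server_num;  // spread keys across servers
    auto &shard = partition->at(server_id);
    shard.first = true;  // this shard now has data to send
    shard.second.add_keys(send.keys(i));
    if (i < send.values_size()) {
      shard.second.add_values(send.values(i));  // simplified 1:1 key/value pairing
    }
  }
}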
Some files were not shown because too many files have changed in this diff.