From db0a6f1e197e2ad48e46f8e63c610feeec764b4a Mon Sep 17 00:00:00 2001
From: chendongsheng
Date: Sat, 27 Feb 2021 17:10:00 +0800
Subject: [PATCH] replace ps-lite with the ps/core communication layer

---
 cmake/external_libs/pslite.cmake              | 22 -
 cmake/external_libs/zeromq.cmake              | 5 -
 cmake/mind_expression.cmake                   | 4 -
 mindspore/ccsrc/CMakeLists.txt                | 4 +-
 .../cpu/ps/embedding_look_up_proxy_kernel.cc  | 19 +-
 .../kernel_compiler/cpu/ps/pull_kernel.h      | 2 +-
 .../ccsrc/backend/session/cpu_session.cc      | 5 +-
 .../ccsrc/backend/session/session_basic.cc    | 7 +-
 .../ccsrc/frontend/parallel/step_parallel.cc  | 3 +-
 .../ccsrc/minddata/dataset/CMakeLists.txt     | 1 -
 mindspore/ccsrc/pipeline/jit/action.cc        | 4 +-
 mindspore/ccsrc/pipeline/jit/pass.cc          | 3 +-
 mindspore/ccsrc/pipeline/jit/pipeline.cc      | 16 +-
 mindspore/ccsrc/ps/CMakeLists.txt             | 4 +-
 mindspore/ccsrc/ps/common.h                   | 140 ---
 mindspore/ccsrc/ps/{internal => }/constants.h | 12 +-
 mindspore/ccsrc/ps/core/cluster_metadata.cc   | 4 +-
 mindspore/ccsrc/ps/core/cluster_metadata.h    | 4 +-
 mindspore/ccsrc/ps/core/comm_util.cc          | 4 +-
 mindspore/ccsrc/ps/core/node_manager.cc       | 2 +-
 .../ccsrc/ps/internal/parameter_server.h      | 179 ---
 mindspore/ccsrc/ps/internal/worker.h          | 157 ---
 mindspore/ccsrc/ps/optimizer_info.cc          | 24 +-
 mindspore/ccsrc/ps/optimizer_info.h           | 2 +-
 mindspore/ccsrc/ps/optimizer_info_builder.cc  | 28 +-
 .../ps/{internal => }/parameter_server.cc     | 35 +-
 mindspore/ccsrc/ps/parameter_server.h         | 1024 +++--------------
 .../ccsrc/ps/ps_cache/ps_cache_manager.cc     | 58 +-
 .../ccsrc/ps/ps_cache/ps_cache_manager.h      | 8 +-
 mindspore/ccsrc/ps/ps_context.cc              | 8 +-
 mindspore/ccsrc/ps/ps_context.h               | 1 +
 mindspore/ccsrc/ps/scheduler.cc               | 9 +-
 mindspore/ccsrc/ps/scheduler.h                | 6 +
 mindspore/ccsrc/ps/util.cc                    | 42 +-
 mindspore/ccsrc/ps/util.h                     | 3 -
 mindspore/ccsrc/ps/{internal => }/worker.cc   | 17 +-
 mindspore/ccsrc/ps/worker.h                   | 397 ++-----
 mindspore/ccsrc/ps/worker_proxy.h             | 873 --------------
 .../runtime/device/kernel_runtime_manager.cc  | 2 +-
 .../scripts/run_parameter_server_train.sh     | 1 -
 .../scripts/run_parameter_server_train_gpu.sh | 1 -
 .../run_parameter_server_train_cluster.sh     | 1 -
 .../run_parameter_server_train_distribute.sh  | 1 -
 .../run_parameter_server_train_standalone.sh  | 1 -
 .../ps/cmp_sparse_embedding/shell_run_test.sh | 3 +-
 tests/st/ps/full_ps/shell_run_test.sh         | 3 +-
 tests/st/ps/multi_full_ps/shell_run_test.sh   | 3 +-
 tests/st/ps/part_ps/shell_run_test.sh         | 3 +-
 tests/ut/cpp/CMakeLists.txt                   | 2 +
 third_party/patch/pslite/ps_lite.patch001     | 255 ----
 50 files changed, 416 insertions(+), 2996 deletions(-)
 delete mode 100644 cmake/external_libs/pslite.cmake
 delete mode 100644 cmake/external_libs/zeromq.cmake
 delete mode 100644 mindspore/ccsrc/ps/common.h
 rename mindspore/ccsrc/ps/{internal => }/constants.h (96%)
 delete mode 100644 mindspore/ccsrc/ps/internal/parameter_server.h
 delete mode 100644 mindspore/ccsrc/ps/internal/worker.h
 rename mindspore/ccsrc/ps/{internal => }/parameter_server.cc (96%)
 rename mindspore/ccsrc/ps/{internal => }/worker.cc (99%)
 delete mode 100644 mindspore/ccsrc/ps/worker_proxy.h
 delete mode 100644 third_party/patch/pslite/ps_lite.patch001

diff --git a/cmake/external_libs/pslite.cmake b/cmake/external_libs/pslite.cmake
deleted file mode 100644
index 7e64563fbb..0000000000
--- a/cmake/external_libs/pslite.cmake
+++ /dev/null
@@ -1,22 +0,0 @@
-if(ENABLE_GITEE)
-    set(REQ_URL "https://gitee.com/mirrors/ps-lite/repository/archive/34fd45cae457d59850fdcb2066467778d0673f21.zip")
-    set(MD5 "0d1543b8dcb0bc3610637e1643c94eb4")
-else()
-    set(REQ_URL
"https://github.com/dmlc/ps-lite/archive/34fd45cae457d59850fdcb2066467778d0673f21.zip") - set(MD5 "393c0e27b68bfaf96718caa3aa96f5a3") -endif() - -set(pslite_USE_STATIC_LIBS ON) -if(${ENABLE_IBVERBS} STREQUAL "ON") - set(pslite_CXXFLAGS "USE_IBVERBS=1") -endif() -mindspore_add_pkg(pslite - LIBS ps - URL ${REQ_URL} - MD5 ${MD5} - PATCHES ${CMAKE_SOURCE_DIR}/third_party/patch/pslite/ps_lite.patch001 - ONLY_MAKE True - ONLY_MAKE_INCS include/* - ONLY_MAKE_LIBS build/*) -include_directories(${pslite_INC}) -add_library(mindspore::pslite ALIAS pslite::ps) diff --git a/cmake/external_libs/zeromq.cmake b/cmake/external_libs/zeromq.cmake deleted file mode 100644 index 122f1ee90c..0000000000 --- a/cmake/external_libs/zeromq.cmake +++ /dev/null @@ -1,5 +0,0 @@ -mindspore_add_pkg(zeromq - VER 4.1.4 - HEAD_ONLY ./ - URL https://raw.githubusercontent.com/mli/deps/master/build/zeromq-4.1.4.tar.gz - MD5 a611ecc93fffeb6d058c0e6edf4ad4fb) diff --git a/cmake/mind_expression.cmake b/cmake/mind_expression.cmake index f7dec90a5f..9dee8f9096 100644 --- a/cmake/mind_expression.cmake +++ b/cmake/mind_expression.cmake @@ -32,10 +32,6 @@ include(${CMAKE_SOURCE_DIR}/cmake/external_libs/flatbuffers.cmake) if(USE_GLOG) include(${CMAKE_SOURCE_DIR}/cmake/external_libs/glog.cmake) endif() -if(ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU)) - include(${CMAKE_SOURCE_DIR}/cmake/external_libs/zeromq.cmake) - include(${CMAKE_SOURCE_DIR}/cmake/external_libs/pslite.cmake) -endif() find_package(Python3) include_directories(${Python3_INCLUDE_DIRS}) diff --git a/mindspore/ccsrc/CMakeLists.txt b/mindspore/ccsrc/CMakeLists.txt index 7b02ae0194..6eeb1f0cc7 100644 --- a/mindspore/ccsrc/CMakeLists.txt +++ b/mindspore/ccsrc/CMakeLists.txt @@ -339,8 +339,8 @@ elseif(CMAKE_SYSTEM_NAME MATCHES "Darwin") target_link_libraries(_c_expression PRIVATE -Wl,-force_load mindspore -Wl,-noall_load) else() if(ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU)) - target_link_libraries(mindspore mindspore::pslite proto_input mindspore::protobuf - mindspore::event mindspore::event_pthreads ${zeromq_DIRPATH}/zmq_install/lib/libzmq.a) + target_link_libraries(mindspore proto_input mindspore::protobuf + mindspore::event mindspore::event_pthreads) target_link_libraries(mindspore -Wl,--no-as-needed mindspore::event_core ps_cache) if(${ENABLE_IBVERBS} STREQUAL "ON") target_link_libraries(mindspore ibverbs rdmacm) diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/embedding_look_up_proxy_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/embedding_look_up_proxy_kernel.cc index 179c4f0a04..f533195803 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/embedding_look_up_proxy_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/embedding_look_up_proxy_kernel.cc @@ -17,6 +17,7 @@ #include #include #include "ps/worker.h" +#include "ps/util.h" namespace mindspore { namespace kernel { @@ -35,7 +36,7 @@ void EmbeddingLookUpProxyKernel::InitKernel(const CNodePtr &kernel_node) { << input_shape << " is too large."; } - if (mindspore::ps::Util::IsRoleOfWorker()) { + if (mindspore::ps::PSContext::instance()->is_worker()) { key_ = AnfAlgo::GetNodeAttr(kernel_node, kAttrPsKey); } std::vector keys{key_, key_, key_}; @@ -50,9 +51,10 @@ void EmbeddingLookUpProxyKernel::InitKernel(const CNodePtr &kernel_node) { << ", indices_shape:" << indices_shape << ", output_shape:" << output_shape; std::vector lens{SizeToLong(input_shape.size()), SizeToLong(indices_shape.size()), SizeToLong(output_shape.size())}; - if (mindspore::ps::Util::IsRoleOfWorker()) { + if 
(mindspore::ps::PSContext::instance()->is_worker()) {
     mindspore::ps::worker.AddEmbeddingTable(key_, input_shape[axis]);
-    mindspore::ps::worker.InitPSEmbeddingTable(keys, values, lens);
+    mindspore::ps::ParamInitInfoMessage info;
+    mindspore::ps::worker.InitPSEmbeddingTable(key_, input_shape, indices_shape, output_shape, info);
   }
 }
 
@@ -70,17 +72,16 @@ bool EmbeddingLookUpProxyKernel::Launch(const std::vector<kernel::AddressPtr> &i
   size_t input_size = inputs[1]->size;
   size_t output_size = outputs[0]->size;
 
-  size_t size = input_size / sizeof(float);
-  ::ps::SArray<int> lookup_ids(size, 0);
-  ::ps::SArray<int> lengths{size};
-  ::ps::SArray<float> lookup_result(output_size / sizeof(float), 0);
+  size_t size = input_size / sizeof(int);
+  std::vector<int> lookup_ids(size, 0);
+  std::vector<int> lengths{SizeToInt(size)};
+  std::vector<float> lookup_result(output_size / sizeof(float), 0);
   auto ret = memcpy_s(lookup_ids.data(), lookup_ids.size() * sizeof(int), indices_addr, input_size);
   if (ret != EOK) {
     MS_LOG(EXCEPTION) << "Lookup id memcpy failed.";
     return false;
   }
-  mindspore::ps::worker.DoPSEmbeddingLookup({key_}, lookup_ids, lengths, &lookup_result,
-                                            mindspore::ps::kEmbeddingLookupCmd);
+  mindspore::ps::worker.DoPSEmbeddingLookup(key_, lookup_ids, &lookup_result, mindspore::ps::kEmbeddingLookupCmd);
 
   auto ret2 = memcpy_s(output_addr, outputs[0]->size, lookup_result.data(), output_size);
   if (ret2 != EOK) {
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/pull_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/pull_kernel.h
index 5fa400b6ff..221d8cbcea 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/pull_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/pull_kernel.h
@@ -62,7 +62,7 @@ class PullKernel : public CPUKernel {
     MS_EXCEPTION_IF_NULL(param_node);
     param_name_ = param_node->fullname_with_scope();
 
-    if (mindspore::ps::Util::IsRoleOfWorker()) {
+    if (mindspore::ps::PSContext::instance()->is_worker()) {
       key_ = AnfAlgo::GetNodeAttr<size_t>(kernel_node, kAttrPsKey);
     }
     InitSizeLists();
diff --git a/mindspore/ccsrc/backend/session/cpu_session.cc b/mindspore/ccsrc/backend/session/cpu_session.cc
index 2663022cd9..36c4697eef 100644
--- a/mindspore/ccsrc/backend/session/cpu_session.cc
+++ b/mindspore/ccsrc/backend/session/cpu_session.cc
@@ -30,6 +30,7 @@
 #include "backend/optimizer/pass/replace_node_by_proxy.h"
 #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
 #include "ps/util.h"
+#include "ps/ps_context.h"
 #endif
 
 namespace mindspore {
@@ -75,9 +76,9 @@ GraphId CPUSession::CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtr
   MS_LOG(INFO) << "Set kernel info";
   SetKernelInfo(graph.get());
 #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
-  if (ps::Util::IsParamServerMode()) {
+  if (ps::PSContext::instance()->is_ps_mode()) {
     AssignParamKey(graph);
-    if (ps::Util::IsRoleOfWorker()) {
+    if (ps::PSContext::instance()->is_worker()) {
       Optimize(graph);
     }
   }
diff --git a/mindspore/ccsrc/backend/session/session_basic.cc b/mindspore/ccsrc/backend/session/session_basic.cc
index 82b9ddfc9b..edfb4d2c3d 100644
--- a/mindspore/ccsrc/backend/session/session_basic.cc
+++ b/mindspore/ccsrc/backend/session/session_basic.cc
@@ -41,8 +41,9 @@
 #include "utils/trace_base.h"
 #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
 #include "ps/ps_cache/ps_cache_manager.h"
-#include "ps/common.h"
+#include "ps/constants.h"
 #include "ps/util.h"
+#include "ps/ps_context.h"
 #include "abstract/abstract_value.h"
 #endif
 
@@ -2287,7 +2288,7 @@ void SessionBasic::RunOpHideNopNode(const KernelGraphPtr &kernel_graph) const {
 
 #if (ENABLE_CPU && (ENABLE_D ||
ENABLE_GPU)) void SessionBasic::InitPsWorker(const KernelGraphPtr &kernel_graph) { - if (!ps::Util::IsRoleOfWorker()) { + if (!ps::PSContext::instance()->is_worker()) { return; } CheckPSModeConsistence(kernel_graph); @@ -2384,7 +2385,7 @@ void SessionBasic::AssignParamKey(const KernelGraphPtr &kernel_graph) { void SessionBasic::InitPSParamAndOptim(const KernelGraphPtr &kernel_graph, const std::vector &inputs_const) { - if (!ps::Util::IsRoleOfWorker()) { + if (!ps::PSContext::instance()->is_worker()) { return; } std::vector inputs(inputs_const); diff --git a/mindspore/ccsrc/frontend/parallel/step_parallel.cc b/mindspore/ccsrc/frontend/parallel/step_parallel.cc index 3395115c6b..f9cf954e48 100644 --- a/mindspore/ccsrc/frontend/parallel/step_parallel.cc +++ b/mindspore/ccsrc/frontend/parallel/step_parallel.cc @@ -48,6 +48,7 @@ #include "mindspore/core/utils/parallel_node_check.h" #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) #include "ps/util.h" +#include "ps/ps_context.h" #endif using mindspore::tensor::Tensor; @@ -3283,7 +3284,7 @@ static void HandleNoUsedParameter(const FuncGraphPtr &root) { bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer) { #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) - if (ps::Util::IsRoleOfPServer() || ps::Util::IsRoleOfScheduler()) { + if (ps::PSContext::instance()->is_server() || ps::PSContext::instance()->is_scheduler()) { return false; } #endif diff --git a/mindspore/ccsrc/minddata/dataset/CMakeLists.txt b/mindspore/ccsrc/minddata/dataset/CMakeLists.txt index 21861b5ee3..6e98b3ae2a 100644 --- a/mindspore/ccsrc/minddata/dataset/CMakeLists.txt +++ b/mindspore/ccsrc/minddata/dataset/CMakeLists.txt @@ -288,7 +288,6 @@ if(${CMAKE_SYSTEM_NAME} MATCHES "Windows") else() target_link_libraries(_c_dataengine PRIVATE _c_mindrecord) if(ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU)) - target_link_libraries(_c_dataengine PRIVATE mindspore::pslite ${zeromq_DIRPATH}/zmq_install/lib/libzmq.a) if(${ENABLE_IBVERBS} STREQUAL "ON") target_link_libraries(_c_dataengine PRIVATE ibverbs rdmacm) endif() diff --git a/mindspore/ccsrc/pipeline/jit/action.cc b/mindspore/ccsrc/pipeline/jit/action.cc index ce8e692d6b..14627caba6 100644 --- a/mindspore/ccsrc/pipeline/jit/action.cc +++ b/mindspore/ccsrc/pipeline/jit/action.cc @@ -460,7 +460,7 @@ bool StartPSWorkerAction(const ResourcePtr &res) { bool StartPSServerAction(const ResourcePtr &res) { FuncGraphPtr func_graph = res->func_graph(); - auto &ps = ps::ParameterServer::GetInstance(); + auto &ps = ps::ParameterServer::GetInstance(); ps.Run(func_graph); return true; } @@ -626,7 +626,7 @@ std::vector VmPipeline() { actions.emplace_back(std::make_pair("validate", ValidateAction)); #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) - if (ps::Util::IsRoleOfWorker()) { + if (ps::PSContext::instance()->is_worker()) { actions.emplace_back(std::make_pair("worker", StartPSWorkerAction)); } #endif diff --git a/mindspore/ccsrc/pipeline/jit/pass.cc b/mindspore/ccsrc/pipeline/jit/pass.cc index 69361d6671..7a5e29c1b6 100644 --- a/mindspore/ccsrc/pipeline/jit/pass.cc +++ b/mindspore/ccsrc/pipeline/jit/pass.cc @@ -43,6 +43,7 @@ #include "pipeline/jit/static_analysis/auto_monad.h" #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) #include "ps/util.h" +#include "ps/ps_context.h" #endif namespace mindspore { @@ -406,7 +407,7 @@ bool AddRecomputationPass(const ResourcePtr &res) { bool AddCacheEmbeddingPass(const ResourcePtr &res) { #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) - if (ps::Util::IsParamServerMode()) { + if 
(ps::PSContext::instance()->is_ps_mode()) { return true; } #endif diff --git a/mindspore/ccsrc/pipeline/jit/pipeline.cc b/mindspore/ccsrc/pipeline/jit/pipeline.cc index ea41db48a3..2b57e59421 100644 --- a/mindspore/ccsrc/pipeline/jit/pipeline.cc +++ b/mindspore/ccsrc/pipeline/jit/pipeline.cc @@ -49,7 +49,7 @@ #include "utils/shape_utils.h" #include "utils/info.h" #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) -#include "ps/common.h" +#include "ps/constants.h" #include "ps/util.h" #include "ps/worker.h" #include "ps/ps_cache/ps_data/ps_data_prefetch.h" @@ -492,14 +492,11 @@ std::vector GetPipline(const ResourcePtr &resource, const std::strin std::string backend = MsContext::GetInstance()->backend_policy(); #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) - if (mindspore::ps::Util::IsParamServerMode()) { - mindspore::ps::Util::SetInternalEnvVar(); - } - if (ps::Util::IsRoleOfPServer()) { + if (ps::PSContext::instance()->is_server()) { resource->results()[kBackend] = compile::CreateBackend(); return PServerPipeline(); } - if (ps::Util::IsRoleOfScheduler()) { + if (ps::PSContext::instance()->is_scheduler()) { return PSchedulerPipeline(); } #endif @@ -978,7 +975,7 @@ bool InitExecDatasetVm(const std::string &queue_name, int64_t size, int64_t batc const std::vector &types, const std::vector> &shapes, const std::vector &input_indexes, bool need_run) { #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) - if ((ps::Util::IsParamServerMode()) && (!ps::Util::IsRoleOfWorker())) { + if ((ps::PSContext::instance()->is_ps_mode()) && (!ps::PSContext::instance()->is_worker())) { return true; } #endif @@ -1030,7 +1027,7 @@ bool InitExecDatasetVm(const std::string &queue_name, int64_t size, int64_t batc ConfigManager::GetInstance().set_iter_num(size); // PS cache does not support loop sink. 
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) - if (ps::Util::IsRoleOfWorker() && ps::PsDataPrefetch::GetInstance().cache_enable()) { + if (ps::PSContext::instance()->is_worker() && ps::PsDataPrefetch::GetInstance().cache_enable()) { ps::PsDataPrefetch::GetInstance().CreateDataChannel(queue_name, LongToSize(size)); ConfigManager::GetInstance().set_iter_num(1); } @@ -1151,10 +1148,11 @@ void ClearResAtexit() { pynative::ClearPyNativeSession(); session::ClearPythonParasMap(); #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) - if (ps::Util::IsParamServerMode() && ps::Util::IsRoleOfWorker()) { + if (ps::PSContext::instance()->is_ps_mode() && ps::PSContext::instance()->is_worker()) { if (ps::PsDataPrefetch::GetInstance().cache_enable()) { ps::ps_cache_instance.Finalize(); } + MS_LOG(INFO) << "ps::worker.Finalize"; ps::worker.Finalize(); } #endif diff --git a/mindspore/ccsrc/ps/CMakeLists.txt b/mindspore/ccsrc/ps/CMakeLists.txt index 8509d72b96..091b77193e 100644 --- a/mindspore/ccsrc/ps/CMakeLists.txt +++ b/mindspore/ccsrc/ps/CMakeLists.txt @@ -21,8 +21,8 @@ if(NOT (ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU))) list(REMOVE_ITEM _PS_SRC_FILES "core/abstract_node.cc") list(REMOVE_ITEM _PS_SRC_FILES "core/scheduler_node.cc") list(REMOVE_ITEM _PS_SRC_FILES "core/http_client.cc") - list(REMOVE_ITEM _PS_SRC_FILES "internal/worker.cc") - list(REMOVE_ITEM _PS_SRC_FILES "internal/parameter_server.cc") + list(REMOVE_ITEM _PS_SRC_FILES "worker.cc") + list(REMOVE_ITEM _PS_SRC_FILES "parameter_server.cc") endif() if(NOT ENABLE_D) diff --git a/mindspore/ccsrc/ps/common.h b/mindspore/ccsrc/ps/common.h deleted file mode 100644 index 062129ac04..0000000000 --- a/mindspore/ccsrc/ps/common.h +++ /dev/null @@ -1,140 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef MINDSPORE_CCSRC_PS_COMMON_H_ -#define MINDSPORE_CCSRC_PS_COMMON_H_ - -#include - -#include -#include -#include -#include -#include - -#include "ps/ps.h" - -namespace mindspore { -namespace ps { -constexpr char kEnvCommType[] = "MS_COMM_TYPE"; -constexpr char kEnvInterface[] = "MS_INTERFACE"; -constexpr char kEnvPServerNum[] = "MS_SERVER_NUM"; -constexpr char kEnvWorkerNum[] = "MS_WORKER_NUM"; -constexpr char kEnvSchedulerHost[] = "MS_SCHED_HOST"; -constexpr char kEnvSchedulerPort[] = "MS_SCHED_PORT"; - -constexpr char kDmlcCommType[] = "DMLC_PS_VAN_TYPE"; -constexpr char kDmlcInterface[] = "DMLC_INTERFACE"; -constexpr char kDmlcPServerNum[] = "DMLC_NUM_SERVER"; -constexpr char kDmlcWorkerNum[] = "DMLC_NUM_WORKER"; -constexpr char kDmlcRole[] = "DMLC_ROLE"; -constexpr char kDmlcSchedulerHost[] = "DMLC_PS_ROOT_URI"; -constexpr char kDmlcSchedulerPort[] = "DMLC_PS_ROOT_PORT"; - -constexpr char kCommTypeOfIBVerbs[] = "ibverbs"; -constexpr char kCommTypeOfTCP[] = "zmq"; -constexpr char kRoleOfPServer[] = "server"; -constexpr char kRoleOfWorker[] = "worker"; -constexpr char kRoleOfScheduler[] = "scheduler"; - -constexpr char kLearningRate[] = "learning_rate"; -constexpr char kMomentum[] = "momentum"; - -constexpr char kApplyMomentum[] = "ApplyMomentum"; -constexpr char kSparseAdam[] = "Adam"; -constexpr char kSparseLazyAdam[] = "LazyAdam"; -constexpr char kSparseFtrl[] = "Ftrl"; -constexpr char kApplyMomentumOp[] = "Momentum"; -constexpr char kSparseAdamOp[] = "Adam"; -constexpr char kSparseLazyAdamOp[] = "LazyAdam"; -constexpr char kSparseFtrlOp[] = "FTRL"; - -constexpr int64_t kInitWeightsCmd = 10; -constexpr int64_t kInitWeightToOptimIdCmd = 11; -constexpr int64_t kInitOptimInputsShapeCmd = 12; -constexpr int64_t kInitKeyToPushNodeIdCmd = 13; -constexpr int64_t kInitEmbeddingsCmd = 20; -constexpr int64_t kUpdateEmbeddingsCmd = 21; -constexpr int64_t kCheckReadyForPushCmd = 25; -constexpr int64_t kCheckReadyForPullCmd = 26; -constexpr int64_t kEmbeddingLookupCmd = 30; -constexpr int64_t kFinalizeCmd = 40; - -constexpr size_t kInvalidKey = UINT64_MAX; -constexpr int64_t kInvalidID = -1; - -using DataPtr = std::shared_ptr; -using VectorPtr = std::shared_ptr>; -using Key = ::ps::Key; -using Keys = ::ps::SArray; -using Values = ::ps::SArray; -using ValuesPtr = std::shared_ptr; -using Weight = ::ps::SArray; -using Grad = ::ps::SArray; -using LookupIds = ::ps::SArray; -using Lengths = ::ps::SArray; -using WeightPtr = std::shared_ptr; -using GradPtr = std::shared_ptr; -using InputsShape = std::vector>>; -using InputsShapePtr = std::shared_ptr>>>; - -constexpr size_t INDEX_NOT_SEND = UINT_MAX; -using OptimOriginIdx = std::map; -using OptimPSSendIdx = std::map; - -const OptimOriginIdx kMomentumOriginIdx = {{"weight", 0}, {"accum", 1}, {"lr", 2}, {"grad", 3}, {"momentum", 4}}; -const OptimPSSendIdx kMomentumPSSendIdx = { - {"weight", INDEX_NOT_SEND}, {"accum", INDEX_NOT_SEND}, {"lr", 0}, {"grad", 1}, {"momentum", 2}}; - -const OptimOriginIdx kSparseAdamOriginIdx = {{"weight", 0}, {"m", 1}, {"v", 2}, {"beta1_power", 3}, - {"beta2_power", 4}, {"lr", 5}, {"beta1", 6}, {"beta2", 7}, - {"eps", 8}, {"grad", 9}, {"indices", 10}}; -const OptimPSSendIdx kSparseAdamPSSendIdx = {{"weight", INDEX_NOT_SEND}, - {"m", INDEX_NOT_SEND}, - {"v", INDEX_NOT_SEND}, - {"beta1_power", 0}, - {"beta2_power", 1}, - {"lr", 2}, - {"beta1", 3}, - {"beta2", 4}, - {"eps", 5}, - {"grad", 6}, - {"indices", 7}}; - -const OptimOriginIdx kSparseFtrlOriginIdx = {{"weight", 0}, {"accum", 1}, {"linear", 2}, {"grad", 3}, 
{"indices", 4}}; -const OptimPSSendIdx kSparseFtrlPSSendIdx = { - {"weight", INDEX_NOT_SEND}, {"accum", INDEX_NOT_SEND}, {"linear", INDEX_NOT_SEND}, {"grad", 0}, {"indices", 1}}; - -const std::map kOptimToOriginIdx = {{kApplyMomentum, kMomentumOriginIdx}, - {kSparseAdam, kSparseAdamOriginIdx}, - {kSparseLazyAdam, kSparseAdamOriginIdx}, - {kSparseFtrl, kSparseFtrlOriginIdx}}; -const std::map kOptimToPSSendIdx = {{kApplyMomentum, kMomentumPSSendIdx}, - {kSparseAdam, kSparseAdamPSSendIdx}, - {kSparseLazyAdam, kSparseAdamPSSendIdx}, - {kSparseFtrl, kSparseFtrlPSSendIdx}}; - -#define EXC_IF_VEC_IDX_OOB(vec, idx) \ - { \ - size_t vec_size = vec.size(); \ - if (idx >= vec_size) { \ - MS_LOG(EXCEPTION) << "Vector " << #vec << " size is " << vec_size << ". So index " << idx \ - << " is out of bound."; \ - } \ - } -} // namespace ps -} // namespace mindspore -#endif // MINDSPORE_CCSRC_PS_COMMON_H_ diff --git a/mindspore/ccsrc/ps/internal/constants.h b/mindspore/ccsrc/ps/constants.h similarity index 96% rename from mindspore/ccsrc/ps/internal/constants.h rename to mindspore/ccsrc/ps/constants.h index 9fd6905740..59e4284587 100644 --- a/mindspore/ccsrc/ps/internal/constants.h +++ b/mindspore/ccsrc/ps/constants.h @@ -14,10 +14,11 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PS_INTERNAL_CONSTANTS_H_ -#define MINDSPORE_CCSRC_PS_INTERNAL_CONSTANTS_H_ +#ifndef MINDSPORE_CCSRC_PS_CONSTANTS_H_ +#define MINDSPORE_CCSRC_PS_CONSTANTS_H_ + +#include -#include #include #include #include @@ -26,8 +27,6 @@ namespace mindspore { namespace ps { -namespace internal { - constexpr char kEnvCommType[] = "MS_COMM_TYPE"; constexpr char kEnvInterface[] = "MS_INTERFACE"; constexpr char kEnvPServerNum[] = "MS_SERVER_NUM"; @@ -127,7 +126,6 @@ const std::map kOptimToPSSendIdx = {{kApplyMomentum << " is out of bound."; \ } \ } -} // namespace internal } // namespace ps } // namespace mindspore -#endif // MINDSPORE_CCSRC_PS_INTERNAL_CONSTANTS_H_ +#endif // MINDSPORE_CCSRC_PS_CONSTANTS_H_ diff --git a/mindspore/ccsrc/ps/core/cluster_metadata.cc b/mindspore/ccsrc/ps/core/cluster_metadata.cc index ccbe6b898c..7f704fab63 100644 --- a/mindspore/ccsrc/ps/core/cluster_metadata.cc +++ b/mindspore/ccsrc/ps/core/cluster_metadata.cc @@ -39,9 +39,9 @@ void ClusterMetadata::Init(const uint32_t &worker_num, const uint32_t &server_nu scheduler_port_ = scheduler_port; } -uint32_t ClusterMetadata::worker_num() { return worker_num_; } +uint32_t ClusterMetadata::total_worker_num() { return worker_num_; } -uint32_t ClusterMetadata::server_num() { return server_num_; } +uint32_t ClusterMetadata::total_server_num() { return server_num_; } uint32_t ClusterMetadata::heartbeat_interval() { return heartbeat_interval_; } diff --git a/mindspore/ccsrc/ps/core/cluster_metadata.h b/mindspore/ccsrc/ps/core/cluster_metadata.h index 47b7651b01..f27479af94 100644 --- a/mindspore/ccsrc/ps/core/cluster_metadata.h +++ b/mindspore/ccsrc/ps/core/cluster_metadata.h @@ -37,8 +37,8 @@ class ClusterMetadata { void Init(const uint32_t &worker_num, const uint32_t &server_num, std::string scheduler_host, const uint16_t &scheduler_port); - uint32_t worker_num(); - uint32_t server_num(); + uint32_t total_worker_num(); + uint32_t total_server_num(); uint32_t heartbeat_interval(); void set_heartbeat_interval(const uint32_t &heartbeat_interval); std::string scheduler_host(); diff --git a/mindspore/ccsrc/ps/core/comm_util.cc b/mindspore/ccsrc/ps/core/comm_util.cc index aeac337316..5bac87180c 100644 --- a/mindspore/ccsrc/ps/core/comm_util.cc +++ 
b/mindspore/ccsrc/ps/core/comm_util.cc
@@ -122,9 +122,9 @@ std::string CommUtil::NodeRoleToString(const NodeRole &role) {
   }
 }
 bool CommUtil::ValidateRankId(const enum NodeRole &node_role, const uint32_t &rank_id) {
-  if (node_role == NodeRole::SERVER && (rank_id > ClusterMetadata::instance()->server_num() - 1)) {
+  if (node_role == NodeRole::SERVER && (rank_id > ClusterMetadata::instance()->total_server_num() - 1)) {
     return false;
-  } else if (node_role == NodeRole::WORKER && (rank_id > ClusterMetadata::instance()->worker_num() - 1)) {
+  } else if (node_role == NodeRole::WORKER && (rank_id > ClusterMetadata::instance()->total_worker_num() - 1)) {
     return false;
   }
   return true;
diff --git a/mindspore/ccsrc/ps/core/node_manager.cc b/mindspore/ccsrc/ps/core/node_manager.cc
index 57b45ccebd..197f8a984f 100644
--- a/mindspore/ccsrc/ps/core/node_manager.cc
+++ b/mindspore/ccsrc/ps/core/node_manager.cc
@@ -20,7 +20,7 @@
 namespace mindspore {
 namespace ps {
 namespace core {
 void NodeManager::InitNodeNum() {
-  total_node_num_ = ClusterMetadata::instance()->server_num() + ClusterMetadata::instance()->worker_num();
+  total_node_num_ = ClusterMetadata::instance()->total_server_num() + ClusterMetadata::instance()->total_worker_num();
 }
 
 int NodeManager::NextRankId(const RegisterMessage &register_message) {
diff --git a/mindspore/ccsrc/ps/internal/parameter_server.h b/mindspore/ccsrc/ps/internal/parameter_server.h
deleted file mode 100644
index 6fb25c7dc7..0000000000
--- a/mindspore/ccsrc/ps/internal/parameter_server.h
+++ /dev/null
@@ -1,179 +0,0 @@
-/**
- * Copyright 2021 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */ - -#ifndef MINDSPORE_CCSRC_PS_INTERNAL_PARAMETER_SERVER_H_ -#define MINDSPORE_CCSRC_PS_INTERNAL_PARAMETER_SERVER_H_ - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include "ir/func_graph.h" -#include "backend/session/session_basic.h" -#include "backend/session/anf_runtime_algorithm.h" -#include "backend/session/session_factory.h" -#include "ps/optimizer_info.h" -#include "ps/optimizer_info_builder.h" -#include "ps/ps_context.h" -#include "runtime/device/cpu/kernel_select_cpu.h" -#include "utils/ms_context.h" -#include "backend/kernel_compiler/kernel.h" -#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" -#include "backend/kernel_compiler/cpu/ps/pserver_kernel.h" -#include "backend/kernel_compiler/cpu/ps/sparse_apply_adam_ps_kernel.h" -#include "backend/kernel_compiler/cpu/ps/sparse_apply_lazy_adam_ps_kernel.h" -#include "backend/kernel_compiler/cpu/ps/sparse_apply_ftrl_ps_kernel.h" -#include "backend/kernel_compiler/cpu/ps/apply_momentum_ps_kernel.h" -#include "backend/kernel_compiler/cpu/ps/embedding_look_up_ps_kernel.h" -#include "ps/ps_cache/ps_data/ps_data_prefetch.h" -#include "ps/random_normal/random_normal.h" - -#include "ps/internal/constants.h" -#include "ps/util.h" -#include "ps/embedding_table_shard_metadata.h" -#include "utils/log_adapter.h" -#include "proto/comm.pb.h" -#include "proto/ps.pb.h" -#include "ps/core/server_node.h" - -namespace mindspore { -namespace ps { -namespace internal { - -class ParameterServer { - public: - static ParameterServer &GetInstance() { - static ParameterServer instance; - return instance; - } - - void Run(const FuncGraphPtr &func_graph); - - private: - ParameterServer() - : pserver_num_(0), - worker_num_(0), - rank_id_(0), - grad_accum_count_(0), - handler_(nullptr), - func_graph_(nullptr), - sess_(nullptr), - running_(true), - thread_(nullptr) {} - ~ParameterServer() = default; - ParameterServer(const ParameterServer &) = delete; - ParameterServer &operator=(const ParameterServer &) = delete; - - class ServerHandler { - public: - explicit ServerHandler(ParameterServer *ps) : ps_(ps) {} - ~ServerHandler() = default; - void Init(); - void operator()(std::shared_ptr conn, std::shared_ptr meta, DataPtr data, - size_t size); - void HandlePushReq(DataPtr data, size_t size, VectorPtr res); - void HandlePullReq(DataPtr data, size_t size, VectorPtr res); - void HandleInitWeights(DataPtr data, size_t size, VectorPtr res); - void HandleInitWeightToOptimId(DataPtr data, size_t size, VectorPtr res); - void HandleInitInputsShape(DataPtr data, size_t size, VectorPtr res); - void HandleInitEmbeddings(DataPtr data, size_t size, VectorPtr res); - void HandleCheckReadyForPush(DataPtr data, size_t size, VectorPtr res); - void HandleCheckReadyForPull(DataPtr data, size_t size, VectorPtr res); - void HandleEmbeddingLookup(DataPtr data, size_t size, VectorPtr res); - void HandleUpdateEmbeddings(DataPtr data, size_t size, VectorPtr res); - void HandleFinalize(DataPtr data, size_t size, VectorPtr res); - - private: - ParameterServer *ps_; - typedef void (ServerHandler::*RequestHandler)(DataPtr data, size_t size, VectorPtr res); - std::unordered_map handlers_; - std::unordered_map init_weights_; - std::unordered_map init_weight_to_optim_; - std::unordered_map init_optim_info_; - }; - - bool Init(const FuncGraphPtr &func_graph); - void InitOptimInfoBuilders(); - void InitWeightKeyToOptims(const Key &key, const int64_t &optim_id); - void 
InitOptimInputsShape(const Keys &keys, const Values &values, const Lengths &lengths); - void InitWeight(const Key &key, const WeightPtr &weight); - void InitGrad(const Key &key, const GradPtr &grad); - void InitEmbeddingTable(const Key &key, - const std::shared_ptr>>> &shapes, - const ParamInitInfo ¶m_init_info); - bool HasWeight(const Key &key); - void Finalize(); - void UpdateWeights(); - void AccumGrad(const Keys &key, const Values &values, const Lengths &lengths); - WeightPtr weight(const Key &key); - void DoEmbeddingLookup(Key key, const LookupIds &lookup_ids, KVMessage *res); - void UpdateEmbeddings(const Key &key, const LookupIds &lookup_ids, const Values &vals); - bool ReadyForUpdateWeights(); - bool ReadyForPush(const Key &key); - bool ReadyForPull(const Key &key); - void ResetGradAccumCount(); - const CNodePtr GetCNode(const std::string &name) const; - std::mutex &mutex(); - void GetEmbeddingTableParamPtr(); - void SyncEmbeddingTables(); - - size_t pserver_num_; - size_t worker_num_; - size_t rank_id_; - size_t grad_accum_count_; - std::unique_ptr handler_; - FuncGraphPtr func_graph_; - std::shared_ptr sess_; - bool running_; - - std::unordered_map> optimizers_; - std::unordered_map optim_inputs_shape_; - std::unordered_map original_optim_inputs_shape_; - std::unordered_map> optim_infos_; - std::unordered_map> optim_info_builders_; - std::unordered_map weight_key_to_optims_; - std::unordered_map weight_key_to_optim_op_; - std::unordered_map weights_; - std::unordered_map is_embedding_; - std::unordered_map grads_; - std::unordered_map grads_accum_counter_; - std::unordered_map> embedding_lookup_ops_; - std::unordered_map tokens_; - - std::mutex mutex_; - std::condition_variable apply_grads_cv_; - - std::unique_ptr thread_; - core::ServerNode server_node_; - std::map embedding_tables_; - - friend class ServerHandler; -}; -} // namespace internal -} // namespace ps -} // namespace mindspore -#endif // MINDSPORE_CCSRC_PS_INTERNAL_PARAMETER_SERVER_H_ diff --git a/mindspore/ccsrc/ps/internal/worker.h b/mindspore/ccsrc/ps/internal/worker.h deleted file mode 100644 index 7298afd4a9..0000000000 --- a/mindspore/ccsrc/ps/internal/worker.h +++ /dev/null @@ -1,157 +0,0 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef MINDSPORE_CCSRC_PS_INTERNAL_WORKER_H_ -#define MINDSPORE_CCSRC_PS_INTERNAL_WORKER_H_ - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "utils/log_adapter.h" -#include "ir/tensor.h" -#include "ps/util.h" -#include "ps/internal/constants.h" -#include "utils/shape_utils.h" -#include "ps/ps_cache/ps_data/ps_data_prefetch.h" -#include "ps/core/worker_node.h" -#include "ps/embedding_table_shard_metadata.h" -#include "proto/comm.pb.h" -#include "proto/ps.pb.h" -#include "ps/ps_context.h" - -namespace mindspore { -namespace ps { -namespace internal { - -class Worker { - public: - static Worker &GetInstance() { - static Worker instance; - return instance; - } - using Callback = std::function; - using PartitionEmbeddingMessages = std::vector>; - using PartitionKVMessages = std::vector>; - - using EmbeddingPartitioner = std::function &attrs)>; - using KVPartitioner = - std::function &attrs)>; - - void Run(); - void Push(const std::vector &keys, std::vector addrs, const ShapeVector &sizes); - void Pull(const size_t key, void *dev_addr, const size_t size); - size_t SetParamKey(const std::string ¶m_name); - size_t GetParamKey(const std::string ¶m_name); - void SetParamInitInServer(const std::string ¶m_name, bool init_in_server); - bool GetParamInitInServer(const std::string ¶m_name); - void SetKeyOptimId(size_t key, const std::string &optimizer_name); - void SetOptimInputShapes(size_t key, const ShapeVector &shape); - void AddEmbeddingTable(const Key &key, const size_t &row_count); - void InitPSEmbeddingTable(const size_t &key, const std::vector &input_shape, - const std::vector &indices_shape, const std::vector &output_shape); - void InitPSParamAndOptim(const AnfNodePtr &input_node, const tensor::TensorPtr &tensor); - void DoPSEmbeddingLookup(const Key &key, const std::vector &lookup_ids, std::vector *lookup_result, - int64_t cmd); - void UpdateEmbeddingTable(const std::vector &keys, const std::vector &lookup_ids, - const std::vector &vals); - - bool running() { return running_; } - void Finalize(); - - private: - Worker() : running_(false), key_cnt_(0) {} - ~Worker() = default; - Worker(const Worker &) = delete; - Worker &operator=(const Worker &) = delete; - - void Initialize(); - bool IsKeyInit(const size_t key); - void AddKeyToServerId(const Key &key); - void AddKeyByHashMod(const Key &key); - void InitPSOptimId(const size_t param_key); - void InitPSOptimInputShapes(const size_t key); - void InitPSParamData(const std::vector &keys, void *origin_addr, size_t size); - bool IsReadyForPush(const Key &key); - bool IsReadyForPull(const Key &key); - void PrepareSparseGradient(const size_t begin, const size_t end, const std::unordered_set &distinct_ids, - const std::vector> &indice_to_grads, const int *all_indice, - const size_t segment_size, float *gradient, int *indices); - void BuildSparseValue(const std::vector &lengths, const size_t grad_index, const size_t indice_index, - const float *original_data, const float *grads, int *indices, std::vector *reduced_data); - - void PushData(const std::vector &keys, const std::vector &vals, const std::vector &lens = {}, - int command = 0, int64_t priority = 0); - void PushSparseData(const std::vector &keys, const std::vector &vals, const std::vector &lens, - size_t grad_index, size_t indice_index, size_t first_dim_size, size_t outer_dim_size); - void PullData(const std::vector &keys, std::vector *vals, std::vector *lens = nullptr, int cmd = 0, - int64_t priority = 0); - - 
void LookupIdPartitioner(const EmbeddingTableLookup &send, PartitionEmbeddingMessages *partition, - const std::map &attrs); - - void SparsePartitioner(const KVMessage &send, PartitionKVMessages *partition, - const std::map &attrs); - void RoundRobinPartitioner(const KVMessage &send, PartitionKVMessages *partition, - const std::map &attrs); - void WorkerInitEmbeddingPartitioner(const KVMessage &send, std::vector> *partition, - const std::map &attrs); - void UpdateEmbeddingPartitioner(const KVMessage &send, PartitionKVMessages *partition, - const std::map &attrs); - void BroadcastPartitioner(const KVMessage &send, PartitionKVMessages *partition, - const std::map &attrs); - void SendForPush(int cmd, const KVMessage &send, const KVPartitioner &partitioner, - const std::map &attrs); - void SendForPull(int cmd, const KVMessage &send, const KVPartitioner &partitioner, - const std::map &attrs, std::vector *vals, std::vector *lens); - - int64_t server_num_; - bool running_; - std::mutex running_mutex_; - size_t key_cnt_; - std::map param_to_key_; - std::map init_keys_; - std::map key_to_optimId_; - std::map> key_to_optim_shapes_; - std::map param_to_init_in_server_; - core::WorkerNode worker_node_; - - EmbeddingPartitioner lookup_partitioner_; - KVPartitioner sparse_partitioner_; - KVPartitioner round_robin_partitioner_; - KVPartitioner worker_init_embedding_partitioner_; - KVPartitioner update_embedding_partitioner_; - KVPartitioner broadcast_partitioner_; - std::unordered_map key_to_server_id_; - std::unordered_map embedding_row_cnt_; - - std::unordered_map>> embedding_table_ranges_; -}; - -static Worker &worker = Worker::GetInstance(); -} // namespace internal -} // namespace ps -} // namespace mindspore -#endif // MINDSPORE_CCSRC_PS_INTERNAL_WORKER_H_ diff --git a/mindspore/ccsrc/ps/optimizer_info.cc b/mindspore/ccsrc/ps/optimizer_info.cc index ab7553e556..03eb594753 100644 --- a/mindspore/ccsrc/ps/optimizer_info.cc +++ b/mindspore/ccsrc/ps/optimizer_info.cc @@ -84,7 +84,7 @@ void DenseOptimInfo::Accumulate(const Values &values, const Lengths &lengths) { for (size_t i = 0; i < grad_index; i++) { grad_offset += lengths[i]; } - float *grad_data = values.data() + grad_offset; + float *grad_data = const_cast(values.data()) + grad_offset; CHECK_EQ(size, static_cast(lengths[grad_index])); for (size_t i = 0; i < size; i++) { @@ -121,7 +121,7 @@ void SparseOptimInfo::Accumulate(const Values &values, const Lengths &lengths) { for (size_t i = 0; i < grad_index; i++) { grad_offset += lengths[i]; } - float *incr_grad_data = values.data() + grad_offset; + float *incr_grad_data = const_cast(values.data()) + grad_offset; MS_EXCEPTION_IF_NULL(incr_grad_data); size_t incr_grad_size = lengths[grad_index] * sizeof(float); @@ -148,7 +148,11 @@ void SparseOptimInfo::Accumulate(const Values &values, const Lengths &lengths) { for (size_t i = 0; i < indices_index; i++) { indice_offset += lengths[i]; } - int *incr_indice_data = reinterpret_cast(values.data()) + indice_offset; + + void *incr_indice_data_temp = const_cast(values.data()) + indice_offset; + + int *incr_indice_data = reinterpret_cast(incr_indice_data_temp); + MS_EXCEPTION_IF_NULL(incr_indice_data); size_t incr_indice_size = lengths[indices_index]; size_t incr_indice_data_size = incr_indice_size * sizeof(int); @@ -259,7 +263,7 @@ MomentumOptimInfo::MomentumOptimInfo(const AddressPtr &weight, const AddressPtr } void MomentumOptimInfo::Update(const Values &values, const Lengths &lens) { - UpdateOptimInputValue(kApplyMomentum, "lr", values.data(), lens); + 
UpdateOptimInputValue<float>(kApplyMomentum, "lr", const_cast<float *>(values.data()), lens);
 }
 
 const size_t SparseOptimInfo::indice_size() const { return indices_offset_; }
@@ -303,12 +307,12 @@ SparseAdamOptimInfo::SparseAdamOptimInfo(const AddressPtr &weight, const Address
 }
 
 void SparseAdamOptimInfo::Update(const Values &values, const Lengths &lens) {
-  UpdateOptimInputValue<float>(kSparseAdam, "beta1_power", values.data(), lens);
-  UpdateOptimInputValue<float>(kSparseAdam, "beta2_power", values.data(), lens);
-  UpdateOptimInputValue<float>(kSparseAdam, "lr", values.data(), lens);
-  UpdateOptimInputValue<float>(kSparseAdam, "beta1", values.data(), lens);
-  UpdateOptimInputValue<float>(kSparseAdam, "beta2", values.data(), lens);
-  UpdateOptimInputValue<float>(kSparseAdam, "eps", values.data(), lens);
+  UpdateOptimInputValue<float>(kSparseAdam, "beta1_power", const_cast<float *>(values.data()), lens);
+  UpdateOptimInputValue<float>(kSparseAdam, "beta2_power", const_cast<float *>(values.data()), lens);
+  UpdateOptimInputValue<float>(kSparseAdam, "lr", const_cast<float *>(values.data()), lens);
+  UpdateOptimInputValue<float>(kSparseAdam, "beta1", const_cast<float *>(values.data()), lens);
+  UpdateOptimInputValue<float>(kSparseAdam, "beta2", const_cast<float *>(values.data()), lens);
+  UpdateOptimInputValue<float>(kSparseAdam, "eps", const_cast<float *>(values.data()), lens);
 }
 
 const AddressPtr &SparseAdamOptimInfo::gradient() {
diff --git a/mindspore/ccsrc/ps/optimizer_info.h b/mindspore/ccsrc/ps/optimizer_info.h
index b91ea82d09..71d5b9ae2c 100644
--- a/mindspore/ccsrc/ps/optimizer_info.h
+++ b/mindspore/ccsrc/ps/optimizer_info.h
@@ -20,7 +20,7 @@
 #include <vector>
 #include <string>
 #include "backend/kernel_compiler/kernel.h"
-#include "ps/common.h"
+#include "ps/constants.h"
 
 namespace mindspore {
 namespace ps {
diff --git a/mindspore/ccsrc/ps/optimizer_info_builder.cc b/mindspore/ccsrc/ps/optimizer_info_builder.cc
index acc443f841..40329808c7 100644
--- a/mindspore/ccsrc/ps/optimizer_info_builder.cc
+++ b/mindspore/ccsrc/ps/optimizer_info_builder.cc
@@ -129,9 +129,9 @@ OptimizerInfo *MomentumOptimInfoBuilder::BuildInputs(const WeightPtr &weight, co
     return nullptr;
   }
 
-  AddressPtr learning_rate = GenInputAddrPtr<float>(kApplyMomentum, "lr", values.data(), lens);
-  AddressPtr gradient = GenInputAddrPtr<float>(kApplyMomentum, "grad", values.data(), lens);
-  AddressPtr momentum = GenInputAddrPtr<float>(kApplyMomentum, "momentum", values.data(), lens);
+  AddressPtr learning_rate = GenInputAddrPtr<float>(kApplyMomentum, "lr", const_cast<float *>(values.data()), lens);
+  AddressPtr gradient = GenInputAddrPtr<float>(kApplyMomentum, "grad", const_cast<float *>(values.data()), lens);
+  AddressPtr momentum = GenInputAddrPtr<float>(kApplyMomentum, "momentum", const_cast<float *>(values.data()), lens);
   return new MomentumOptimInfo(weight_addr, accumulate, learning_rate, gradient, momentum);
 }
 
@@ -172,14 +172,15 @@ OptimizerInfo *SparseAdamOptimInfoBuilder::BuildInputs(const WeightPtr &weight,
     return nullptr;
   }
 
-  AddressPtr beta1_power = GenInputAddrPtr<float>(kSparseAdam, "beta1_power", values.data(), lens);
-  AddressPtr beta2_power = GenInputAddrPtr<float>(kSparseAdam, "beta2_power", values.data(), lens);
-  AddressPtr learning_rate = GenInputAddrPtr<float>(kSparseAdam, "lr", values.data(), lens);
-  AddressPtr beta1 = GenInputAddrPtr<float>(kSparseAdam, "beta1", values.data(), lens);
-  AddressPtr beta2 = GenInputAddrPtr<float>(kSparseAdam, "beta2", values.data(), lens);
-  AddressPtr epsilon = GenInputAddrPtr<float>(kSparseAdam, "eps", values.data(), lens);
-  AddressPtr grad = GenInputAddrPtr<float>(kSparseAdam, "grad", values.data(), lens, inputs_shape);
-  AddressPtr indices = GenInputAddrPtr<float>(kSparseAdam, "indices", values.data(), lens, inputs_shape);
+  AddressPtr beta1_power =
GenInputAddrPtr(kSparseAdam, "beta1_power", const_cast(values.data()), lens); + AddressPtr beta2_power = GenInputAddrPtr(kSparseAdam, "beta2_power", const_cast(values.data()), lens); + AddressPtr learning_rate = GenInputAddrPtr(kSparseAdam, "lr", const_cast(values.data()), lens); + AddressPtr beta1 = GenInputAddrPtr(kSparseAdam, "beta1", const_cast(values.data()), lens); + AddressPtr beta2 = GenInputAddrPtr(kSparseAdam, "beta2", const_cast(values.data()), lens); + AddressPtr epsilon = GenInputAddrPtr(kSparseAdam, "eps", const_cast(values.data()), lens); + AddressPtr grad = GenInputAddrPtr(kSparseAdam, "grad", const_cast(values.data()), lens, inputs_shape); + AddressPtr indices = + GenInputAddrPtr(kSparseAdam, "indices", const_cast(values.data()), lens, inputs_shape); return new SparseAdamOptimInfo(weight_addr, m, v, beta1_power, beta2_power, learning_rate, beta1, beta2, epsilon, grad, indices, sharded); } @@ -218,8 +219,9 @@ OptimizerInfo *SparseFtrlOptimInfoBuilder::BuildInputs(const WeightPtr &weight, } linear->size = weight->size() * sizeof(float); - AddressPtr grad = GenInputAddrPtr(kSparseFtrl, "grad", values.data(), lens, inputs_shape); - AddressPtr indices = GenInputAddrPtr(kSparseFtrl, "indices", values.data(), lens, inputs_shape); + AddressPtr grad = GenInputAddrPtr(kSparseFtrl, "grad", const_cast(values.data()), lens, inputs_shape); + AddressPtr indices = + GenInputAddrPtr(kSparseFtrl, "indices", const_cast(values.data()), lens, inputs_shape); return new SparseFtrlOptimInfo(weight_addr, accum, linear, grad, indices, sharded); } } // namespace ps diff --git a/mindspore/ccsrc/ps/internal/parameter_server.cc b/mindspore/ccsrc/ps/parameter_server.cc similarity index 96% rename from mindspore/ccsrc/ps/internal/parameter_server.cc rename to mindspore/ccsrc/ps/parameter_server.cc index acd57c173e..71d330be47 100644 --- a/mindspore/ccsrc/ps/internal/parameter_server.cc +++ b/mindspore/ccsrc/ps/parameter_server.cc @@ -14,12 +14,10 @@ * limitations under the License. 
*/ -#include "ps/internal/parameter_server.h" +#include "ps/parameter_server.h" namespace mindspore { namespace ps { -namespace internal { - void ParameterServer::Run(const FuncGraphPtr &func_graph) { MS_EXCEPTION_IF_NULL(func_graph); MS_LOG(INFO) << "PServer starts connecting to scheduler and workers..."; @@ -44,8 +42,8 @@ void ParameterServer::Run(const FuncGraphPtr &func_graph) { } bool ParameterServer::Init(const FuncGraphPtr &func_graph) { - pserver_num_ = std::strtol(mindspore::common::GetEnv(kEnvWorkerNum).c_str(), nullptr, 10); - worker_num_ = std::strtol(mindspore::common::GetEnv(kEnvPServerNum).c_str(), nullptr, 10); + pserver_num_ = std::strtol(mindspore::common::GetEnv(kEnvPServerNum).c_str(), nullptr, 10); + worker_num_ = std::strtol(mindspore::common::GetEnv(kEnvWorkerNum).c_str(), nullptr, 10); func_graph_ = func_graph; handler_.reset(new ServerHandler(this)); handler_->Init(); @@ -257,12 +255,21 @@ void ParameterServer::AccumGrad(const Keys &keys, const Values &values, const Le std::shared_ptr optim_info = optim_infos_[key]; // Create or update the optimizer info - std::shared_ptr pserver_kernel = optimizers_[key]; - if (pserver_kernel == nullptr) { - MS_LOG(EXCEPTION) << "no optimizer found for key " << key << " optim name " << weight_key_to_optims_[key]; + if (optim_info == nullptr) { + const std::shared_ptr &builder = optim_info_builders_[weight_key_to_optims_[key]]; + std::shared_ptr pserver_kernel = optimizers_[key]; + if (pserver_kernel == nullptr) { + MS_LOG(EXCEPTION) << "no optimizer found for key " << key << " optim name " << weight_key_to_optims_[key]; + } + MS_EXCEPTION_IF_NULL(pserver_kernel); + OptimizerInfo *optim = builder->Build(pserver_kernel, weights_[key], keys, values, lengths, + optim_inputs_shape_[key], worker_num_, is_embedding_[key]); + optim_info.reset(optim); + optim_infos_[key] = optim_info; + } else { + optim_info->Update(values, lengths); + optim_info->Accumulate(values, lengths); } - MS_EXCEPTION_IF_NULL(pserver_kernel); - optim_infos_[key] = optim_info; } grads_accum_counter_[key] += 1; @@ -373,7 +380,7 @@ inline bool ParameterServer::ReadyForPush(const Key &key) { MS_LOG(EXCEPTION) << "The weights in server is empty. Many reasons could cause this: 1.The Worker didn't send " "kInitWeightsCmd command. 2.The Server failed to initialize weights."; } - MS_LOG(INFO) << "the grad_accum_count_:" << grad_accum_count_ << " the weights_:" << weights_.size() + MS_LOG(INFO) << "The grad_accum_count_:" << grad_accum_count_ << " the weights_:" << weights_.size() << " the token:" << (tokens_[key] <= 0); return grad_accum_count_ < weights_.size() && tokens_[key] <= 0; } @@ -544,11 +551,9 @@ void ParameterServer::ServerHandler::HandleInitWeights(DataPtr data, size_t size for (int i = 0; i < key_num; i++) { Key key = input.keys()[i]; size_t data_len = input.len_size() != key_num ? 
input.values_size() / key_num : input.len()[i]; - MS_LOG(DEBUG) << "The data len:" << data_len; if (!ps_->HasWeight(key)) { WeightPtr weight_ptr = std::make_shared>(data_ptr + pos, data_ptr + (pos + data_len)); - MS_LOG(DEBUG) << "The weight ptr:" << *weight_ptr; MS_EXCEPTION_IF_NULL(weight_ptr); ps_->InitWeight(key, weight_ptr); @@ -637,7 +642,7 @@ void ParameterServer::ServerHandler::HandleCheckReadyForPush(DataPtr data, size_ input.ParseFromArray(data.get(), size); const Key &key = input.keys()[0]; bool ready = ps_->ReadyForPush(key); - MS_LOG(INFO) << "the ready is:" << ready; + MS_LOG(INFO) << "The ready is:" << ready; KVMessage res_data; res_data.add_keys(key); res_data.add_values(ready); @@ -671,7 +676,6 @@ void ParameterServer::ServerHandler::HandleEmbeddingLookup(DataPtr data, size_t EmbeddingTableLookup input; input.ParseFromArray(data.get(), size); const Key &key = input.key(); - MS_LOG(DEBUG) << "The key is:" << key; KVMessage res_data; std::vector keys = {input.keys().begin(), input.keys().end()}; @@ -701,6 +705,5 @@ void ParameterServer::ServerHandler::HandleFinalize(DataPtr data, size_t size, V MS_EXCEPTION_IF_NULL(res); ps_->Finalize(); } -} // namespace internal } // namespace ps } // namespace mindspore diff --git a/mindspore/ccsrc/ps/parameter_server.h b/mindspore/ccsrc/ps/parameter_server.h index e28fcb4514..fdccac3030 100644 --- a/mindspore/ccsrc/ps/parameter_server.h +++ b/mindspore/ccsrc/ps/parameter_server.h @@ -1,847 +1,177 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef MINDSPORE_CCSRC_PS_PARAMETER_SERVER_H_ -#define MINDSPORE_CCSRC_PS_PARAMETER_SERVER_H_ - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include "ir/func_graph.h" -#include "backend/session/session_basic.h" -#include "backend/session/anf_runtime_algorithm.h" -#include "backend/session/session_factory.h" -#include "ps/common.h" -#include "ps/optimizer_info.h" -#include "ps/optimizer_info_builder.h" -#include "ps/util.h" -#include "ps/ps_context.h" -#include "runtime/device/cpu/kernel_select_cpu.h" -#include "utils/ms_context.h" -#include "backend/kernel_compiler/kernel.h" -#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" -#include "backend/kernel_compiler/cpu/ps/pserver_kernel.h" -#include "backend/kernel_compiler/cpu/ps/sparse_apply_adam_ps_kernel.h" -#include "backend/kernel_compiler/cpu/ps/sparse_apply_lazy_adam_ps_kernel.h" -#include "backend/kernel_compiler/cpu/ps/sparse_apply_ftrl_ps_kernel.h" -#include "backend/kernel_compiler/cpu/ps/apply_momentum_ps_kernel.h" -#include "backend/kernel_compiler/cpu/ps/embedding_look_up_ps_kernel.h" -#include "ps/ps_cache/ps_data/ps_data_prefetch.h" -#include "ps/random_normal/random_normal.h" - -namespace mindspore { -namespace ps { -using mindspore::kernel::ps::PServerKernel; -using AnfAlgo = session::AnfRuntimeAlgorithm; -template -class ParameterServer { - public: - static ParameterServer &GetInstance() { - static ParameterServer instance; - return instance; - } - - void Run(const FuncGraphPtr &func_graph); - - private: - ParameterServer() - : pserver_num_(0), - worker_num_(0), - rank_id_(0), - grad_accum_count_(0), - ps_(new ::ps::KVServer(0)), - handler_(nullptr), - func_graph_(nullptr), - sess_(nullptr), - running_(true), - thread_(nullptr) {} - ~ParameterServer() = default; - ParameterServer(const ParameterServer &) = delete; - ParameterServer &operator=(const ParameterServer &) = delete; - - class ServerHandler { - public: - explicit ServerHandler(ParameterServer *ps) : ps_(ps) {} - ~ServerHandler() = default; - void Init(); - void operator()(const ::ps::KVMeta &req_meta, const ::ps::KVPairs &req_data, ::ps::KVServer *server); - - private: - void HandlePushReq(const ::ps::KVMeta &req_meta, const ::ps::KVPairs &req_data, ::ps::KVPairs *res); - void HandlePullReq(const ::ps::KVMeta &req_meta, const ::ps::KVPairs &req_data, ::ps::KVPairs *res); - void HandleInitWeights(const ::ps::KVMeta &req_meta, const ::ps::KVPairs &req_data, ::ps::KVPairs *res); - void HandleInitWeightToOptimId(const ::ps::KVMeta &req_meta, const ::ps::KVPairs &req_data, - ::ps::KVPairs *res); - void HandleInitInputsShape(const ::ps::KVMeta &req_meta, const ::ps::KVPairs &req_data, ::ps::KVPairs *res); - void HandleInitEmbeddings(const ::ps::KVMeta &req_meta, const ::ps::KVPairs &req_data, ::ps::KVPairs *res); - void HandleCheckReadyForPush(const ::ps::KVMeta &req_meta, const ::ps::KVPairs &req_data, ::ps::KVPairs *res); - void HandleCheckReadyForPull(const ::ps::KVMeta &req_meta, const ::ps::KVPairs &req_data, ::ps::KVPairs *res); - void HandleEmbeddingLookup(const ::ps::KVMeta &req_meta, const ::ps::KVPairs &req_data, ::ps::KVPairs *res); - void HandleUpdateEmbeddings(const ::ps::KVMeta &req_meta, const ::ps::KVPairs &req_data, ::ps::KVPairs *res); - void HandleFinalize(const ::ps::KVMeta &req_meta, const ::ps::KVPairs &req_data, ::ps::KVPairs *res); - - ParameterServer *ps_; - typedef void (ServerHandler::*RequestHandler)(const ::ps::KVMeta 
&req_meta, const ::ps::KVPairs &req_data, - ::ps::KVPairs *res); - std::unordered_map handlers_; - std::unordered_map init_weights_; - std::unordered_map init_weight_to_optim_; - std::unordered_map init_optim_info_; - }; - - bool Init(const FuncGraphPtr &func_graph); - void InitOptimInfoBuilders(); - void InitWeightKeyToOptims(const Key &key, const int64_t &optim_id); - void InitOptimInputsShape(const Keys &keys, const Values &values, const Lengths &lengths); - void InitWeight(const Key &key, const WeightPtr &weight); - void InitGrad(const Key &key, const GradPtr &grad); - void InitEmbeddingTable(const Key &key, - const std::shared_ptr>>> &shapes, - const ParamInitInfo ¶m_init_info); - bool HasWeight(const Key &key); - void Finalize(); - void UpdateWeights(); - void AccumGrad(const Keys &key, const Values &values, const Lengths &lengths); - WeightPtr weight(const Key &key); - void DoEmbeddingLookup(Key key, const LookupIds &lookup_ids, ::ps::KVPairs *res); - void UpdateEmbeddings(const Key &key, const LookupIds &lookup_ids, const Values &vals); - bool ReadyForUpdateWeights(); - bool ReadyForPush(const Key &key); - bool ReadyForPull(const Key &key); - void ResetGradAccumCount(); - const CNodePtr GetCNode(const std::string &name) const; - std::mutex &mutex(); - void GetEmbeddingTableParamPtr(); - void SyncEmbeddingTables(); - - size_t pserver_num_; - size_t worker_num_; - size_t rank_id_; - size_t grad_accum_count_; - std::unique_ptr<::ps::KVServer> ps_; - std::unique_ptr handler_; - FuncGraphPtr func_graph_; - std::shared_ptr sess_; - bool running_; - - std::unordered_map> optimizers_; - std::unordered_map optim_inputs_shape_; - std::unordered_map original_optim_inputs_shape_; - std::unordered_map> optim_infos_; - std::unordered_map> optim_info_builders_; - std::unordered_map weight_key_to_optims_; - std::unordered_map weight_key_to_optim_op_; - std::unordered_map weights_; - std::unordered_map is_embedding_; - std::unordered_map grads_; - std::unordered_map grads_accum_counter_; - std::unordered_map> embedding_lookup_ops_; - std::unordered_map tokens_; - - std::mutex mutex_; - std::condition_variable apply_grads_cv_; - - std::unique_ptr thread_; - std::map embedding_tables_; - - friend class ServerHandler; -}; - -class FuncGraph; -template -void ParameterServer::ServerHandler::operator()(const ::ps::KVMeta &req_meta, const ::ps::KVPairs &req_data, - ::ps::KVServer *server) { - MS_EXCEPTION_IF_NULL(server); - ::ps::KVPairs res; - if (handlers_.count(req_meta.cmd) > 0) { - auto &handler_ptr = handlers_[req_meta.cmd]; - (this->*handler_ptr)(req_meta, req_data, &res); - } else if (req_meta.push) { - HandlePushReq(req_meta, req_data, &res); - } else { - HandlePullReq(req_meta, req_data, &res); - } - server->Response(req_meta, res); -} - -template -void ParameterServer::ServerHandler::Init() { - handlers_[kInitWeightsCmd] = &ServerHandler::HandleInitWeights; - handlers_[kInitWeightToOptimIdCmd] = &ServerHandler::HandleInitWeightToOptimId; - handlers_[kInitOptimInputsShapeCmd] = &ServerHandler::HandleInitInputsShape; - handlers_[kInitEmbeddingsCmd] = &ServerHandler::HandleInitEmbeddings; - handlers_[kCheckReadyForPushCmd] = &ServerHandler::HandleCheckReadyForPush; - handlers_[kCheckReadyForPullCmd] = &ServerHandler::HandleCheckReadyForPull; - handlers_[kEmbeddingLookupCmd] = &ServerHandler::HandleEmbeddingLookup; - handlers_[kUpdateEmbeddingsCmd] = &ServerHandler::HandleUpdateEmbeddings; - handlers_[kFinalizeCmd] = &ServerHandler::HandleFinalize; -} - -template -void 
ParameterServer::ServerHandler::HandlePushReq(const ::ps::KVMeta &req_meta, const ::ps::KVPairs &req_data, - ::ps::KVPairs *res) { - MS_EXCEPTION_IF_NULL(res); - ps_->AccumGrad(req_data.keys, req_data.vals, req_data.lens); -} - -template -void ParameterServer::ServerHandler::HandlePullReq(const ::ps::KVMeta &req_meta, const ::ps::KVPairs &req_data, - ::ps::KVPairs *res) { - MS_EXCEPTION_IF_NULL(res); - res->keys = req_data.keys; - ::ps::Key key = req_data.keys[0]; - res->vals = *(ps_->weight(key)); -} - -template -void ParameterServer::ServerHandler::HandleInitWeights(const ::ps::KVMeta &req_meta, - const ::ps::KVPairs &req_data, ::ps::KVPairs *res) { - std::unique_lock lock(ps_->mutex()); - MS_EXCEPTION_IF_NULL(res); - size_t key_num = req_data.keys.size(); - T *data_ptr = req_data.vals.data(); - size_t pos = 0; - for (size_t i = 0; i < key_num; i++) { - Key key = req_data.keys[i]; - size_t data_len = req_data.lens.size() != key_num ? req_data.vals.size() / key_num : req_data.lens[i]; - - if (!ps_->HasWeight(key)) { - WeightPtr weight_ptr = std::make_shared<::ps::SArray>(); - MS_EXCEPTION_IF_NULL(weight_ptr); - weight_ptr->CopyFrom(data_ptr + pos, data_len); - ps_->InitWeight(key, weight_ptr); - - GradPtr grad_ptr = std::make_shared<::ps::SArray>(data_len, 0); - MS_EXCEPTION_IF_NULL(grad_ptr); - ps_->InitGrad(key, grad_ptr); - } - pos += data_len; - } -} - -template -void ParameterServer::ServerHandler::HandleInitWeightToOptimId(const ::ps::KVMeta &req_meta, - const ::ps::KVPairs &req_data, - ::ps::KVPairs *res) { - std::unique_lock lock(ps_->mutex()); - MS_EXCEPTION_IF_NULL(res); - size_t key_num = req_data.keys.size(); - for (size_t i = 0; i < key_num; i++) { - Key key = req_data.keys[i]; - T val = req_data.vals[i]; - if (init_weight_to_optim_[key]) { - continue; - } else { - init_weight_to_optim_[key] = true; - } - ps_->InitWeightKeyToOptims(key, val); - } -} - -template -void ParameterServer::ServerHandler::HandleInitInputsShape(const ::ps::KVMeta &req_meta, - const ::ps::KVPairs &req_data, ::ps::KVPairs *res) { - std::unique_lock lock(ps_->mutex()); - MS_EXCEPTION_IF_NULL(res); - const Key &key = req_data.keys[0]; - if (init_optim_info_[key]) { - return; - } else { - init_optim_info_[key] = true; - } - ps_->InitOptimInputsShape(req_data.keys, req_data.vals, req_data.lens); -} - -template -void ParameterServer::ServerHandler::HandleInitEmbeddings(const ::ps::KVMeta &req_meta, - const ::ps::KVPairs &req_data, ::ps::KVPairs *res) { - std::unique_lock lock(ps_->mutex()); - MS_EXCEPTION_IF_NULL(res); - const Key &key = req_data.keys[0]; - MS_LOG(INFO) << "Initializing embedding table for key:" << key; - std::shared_ptr>>> shapes = - std::make_shared>>>(); - MS_EXCEPTION_IF_NULL(shapes); - std::shared_ptr> input_shape = std::make_shared>(); - MS_EXCEPTION_IF_NULL(input_shape); - std::shared_ptr> indices_shape = std::make_shared>(); - MS_EXCEPTION_IF_NULL(indices_shape); - std::shared_ptr> output_shape = std::make_shared>(); - MS_EXCEPTION_IF_NULL(output_shape); - shapes->push_back(input_shape); - shapes->push_back(indices_shape); - shapes->push_back(output_shape); - - const Lengths &lens = req_data.lens; - size_t index = 0; - for (int64_t i = 0; i < lens[0]; i++) { - input_shape->push_back(static_cast(req_data.vals[index++])); - } - for (int64_t j = 0; j < lens[1]; j++) { - indices_shape->push_back(static_cast(req_data.vals[index++])); - } - for (int64_t k = 0; k < lens[2]; k++) { - output_shape->push_back(static_cast(req_data.vals[index++])); - } - ParamInitInfo param_init_info; - if 
(ps::PsDataPrefetch::GetInstance().cache_enable()) { - param_init_info.param_type_ = static_cast(lens[3]); - if (param_init_info.param_type_ == kWeight) { - param_init_info.global_seed_ = static_cast(lens[4]); - param_init_info.op_seed_ = static_cast(lens[5]); - } else if (param_init_info.param_type_ == kAccumulation) { - param_init_info.init_val_ = req_data.vals[index]; - } - } - ps_->InitEmbeddingTable(key, shapes, param_init_info); -} - -template -void ParameterServer::ServerHandler::HandleCheckReadyForPush(const ::ps::KVMeta &req_meta, - const ::ps::KVPairs &req_data, - ::ps::KVPairs *res) { - MS_EXCEPTION_IF_NULL(res); - const Key &key = req_data.keys[0]; - bool ready = ps_->ReadyForPush(key); - res->keys.push_back(key); - res->vals.push_back(ready); -} - -template -void ParameterServer::ServerHandler::HandleCheckReadyForPull(const ::ps::KVMeta &req_meta, - const ::ps::KVPairs &req_data, - ::ps::KVPairs *res) { - MS_EXCEPTION_IF_NULL(res); - const Key &key = req_data.keys[0]; - bool ready = ps_->ReadyForPull(key); - res->keys.push_back(key); - res->vals.push_back(ready); -} - -template -void ParameterServer::ServerHandler::HandleEmbeddingLookup(const ::ps::KVMeta &req_meta, - const ::ps::KVPairs &req_data, ::ps::KVPairs *res) { - MS_EXCEPTION_IF_NULL(res); - const Key &key = req_data.keys[0]; - for (size_t i = 1; i < req_data.keys.size(); i++) { - res->keys.push_back(req_data.keys[i]); - } - ps_->DoEmbeddingLookup(key, req_data.keys.segment(1, req_data.keys.size()), res); -} - -template -void ParameterServer::ServerHandler::HandleUpdateEmbeddings(const ::ps::KVMeta &req_meta, - const ::ps::KVPairs &req_data, - ::ps::KVPairs *res) { - std::unique_lock lock(ps_->mutex()); - MS_EXCEPTION_IF_NULL(res); - const Key &key = req_data.keys[0]; - const LookupIds &lookup_ids = req_data.keys.segment(1, req_data.keys.size()); - const Values &update_vals = req_data.vals; - ps_->UpdateEmbeddings(key, lookup_ids, update_vals); -} - -template -void ParameterServer::ServerHandler::HandleFinalize(const ::ps::KVMeta &req_meta, const ::ps::KVPairs &req_data, - ::ps::KVPairs *res) { - MS_EXCEPTION_IF_NULL(res); - ps_->Finalize(); -} - -template -bool ParameterServer::Init(const FuncGraphPtr &func_graph) { - pserver_num_ = ::ps::NumServers(); - worker_num_ = ::ps::NumWorkers(); - func_graph_ = func_graph; - rank_id_ = ::ps::MyRank(); - handler_.reset(new ServerHandler(this)); - handler_->Init(); - - InitOptimInfoBuilders(); - ps_->set_request_handle(*handler_); - thread_.reset(new std::thread(&ParameterServer::UpdateWeights, this)); - GetEmbeddingTableParamPtr(); - return true; -} - -template -void ParameterServer::InitOptimInfoBuilders() { - std::shared_ptr momentum_info_builder = std::make_shared(worker_num_); - std::shared_ptr sparse_adam_info_builder = - std::make_shared(worker_num_); - std::shared_ptr sparse_ftrl_info_builder = - std::make_shared(worker_num_); - optim_info_builders_[kApplyMomentum] = momentum_info_builder; - optim_info_builders_[kSparseAdam] = sparse_adam_info_builder; - optim_info_builders_[kSparseFtrl] = sparse_ftrl_info_builder; -} - -template -void ParameterServer::InitWeightKeyToOptims(const Key &key, const int64_t &optim_id) { - if (weight_key_to_optims_.count(key) > 0 || Util::optimizer_name(optim_id) == "") { - return; - } - weight_key_to_optims_[key] = Util::optimizer_name(optim_id); - weight_key_to_optim_op_[key] = Util::optimizer_node_name(optim_id); - MS_LOG(INFO) << "Initializing optimizer id for key:" << key << ", optimizer name:" << weight_key_to_optims_[key] - << ", 
optimizer op name:" << weight_key_to_optim_op_[key]; -} - -template -void ParameterServer::InitOptimInputsShape(const Keys &keys, const Values &values, const Lengths &lengths) { - InputsShapePtr inputs_shape = std::make_shared(); - MS_EXCEPTION_IF_NULL(inputs_shape); - InputsShapePtr original_inputs_shape = std::make_shared(); - MS_EXCEPTION_IF_NULL(original_inputs_shape); - int64_t val_idx = 0; - const Key &key = keys[0]; - MS_LOG(INFO) << "Initializing optimizer inputs shape for key:" << key; - if (optim_inputs_shape_.count(key) == 0) { - original_optim_inputs_shape_[key] = original_inputs_shape; - optim_inputs_shape_[key] = inputs_shape; - } - for (size_t i = 0; i < keys.size(); i++) { - auto shape = std::make_shared>(); - MS_EXCEPTION_IF_NULL(shape); - auto original_shape = std::make_shared>(); - MS_EXCEPTION_IF_NULL(original_shape); - inputs_shape->push_back(shape); - original_inputs_shape->push_back(original_shape); - - for (int64_t j = 0; j < lengths[i]; j++) { - shape->push_back(values[val_idx]); - original_shape->push_back(values[val_idx++]); - } - } - if (weight_key_to_optims_.count(key) > 0) { - const std::string &optim_name = weight_key_to_optims_[key]; - const std::string &optim_op_name = weight_key_to_optim_op_[key]; - if (optimizers_.count(key) == 0 && optim_inputs_shape_.count(key) > 0) { - const CNodePtr cnode = GetCNode(optim_op_name); - MS_EXCEPTION_IF_NULL(cnode); - if (optim_name == kSparseAdam) { - std::shared_ptr optimizer = - std::make_shared(rank_id_, pserver_num_, worker_num_); - optimizer->InitKernel(cnode, optim_inputs_shape_[key]); - optimizers_[key] = optimizer; - } else if (optim_name == kSparseLazyAdam) { - std::shared_ptr optimizer = - std::make_shared(rank_id_, pserver_num_, worker_num_); - optimizer->InitKernel(cnode, optim_inputs_shape_[key]); - optimizers_[key] = optimizer; - } else if (optim_name == kApplyMomentum) { - std::shared_ptr optimizer = - std::make_shared(rank_id_, pserver_num_, worker_num_); - optimizer->InitKernel(cnode, optim_inputs_shape_[key]); - optimizers_[key] = optimizer; - } else if (optim_name == kSparseFtrl) { - std::shared_ptr optimizer = - std::make_shared(rank_id_, pserver_num_, worker_num_); - optimizer->InitKernel(cnode, optim_inputs_shape_[key]); - optimizers_[key] = optimizer; - } - } - } -} - -template -const CNodePtr ParameterServer::GetCNode(const std::string &name) const { - std::list cnodes = func_graph_->GetOrderedCnodes(); - for (CNodePtr cnode : cnodes) { - MS_EXCEPTION_IF_NULL(cnode); - std::string fullname = cnode->fullname_with_scope(); - if (fullname.find(name) != std::string::npos && fullname.find("Push") != std::string::npos) { - return cnode; - } - } - return nullptr; -} - -template -void ParameterServer::InitWeight(const Key &key, const WeightPtr &weight) { - MS_EXCEPTION_IF_NULL(weight); - if ((weights_.count(key) == 0) || (is_embedding_[key] && weights_.count(key) != 0)) { - MS_LOG(INFO) << "Initializing weight for key " << key << ", server rank " << rank_id_; - weights_[key] = weight; - tokens_[key] = 0; - is_embedding_[key] = false; - } -} - -template -void ParameterServer::InitGrad(const Key &key, const GradPtr &grad) { - MS_EXCEPTION_IF_NULL(grad); - if (grads_.count(key) == 0) { - grads_[key] = grad; - grads_accum_counter_[key] = 0; - } -} - -template -void ParameterServer::InitEmbeddingTable( - const Key &key, const std::shared_ptr>>> &shapes, - const ParamInitInfo ¶m_init_info) { - MS_EXCEPTION_IF_NULL(shapes); - if (weights_.count(key) == 0) { - std::shared_ptr lookup = - std::make_shared(rank_id_, 
pserver_num_, worker_num_); - lookup->InitKernel(shapes); - embedding_lookup_ops_[key] = lookup; - - // Init embedding weight - const std::vector &input_shapes = lookup->input_sizes(); - size_t total_dims = - std::accumulate(input_shapes.begin(), input_shapes.end(), IntToSize(1), std::multiplies()); - WeightPtr embedding = std::make_shared(total_dims, 0); - MS_EXCEPTION_IF_NULL(embedding); - T *embedding_data = embedding->data(); - std::default_random_engine engine; - std::normal_distribution random(0, 0.01); - if (ps::PsDataPrefetch::GetInstance().cache_enable()) { - if (param_init_info.param_type_ == kWeight) { - InitRandomNormal(0, 0.01, input_shapes, param_init_info.global_seed_, param_init_info.op_seed_, embedding_data); - } else if (param_init_info.param_type_ == kAccumulation) { - for (size_t i = 0; i < total_dims; i++) { - embedding_data[i] = param_init_info.init_val_; - } - } - } else { - for (size_t i = 0; i < total_dims; i++) { - embedding_data[i] = random(engine); - } - } - weights_[key] = embedding; - tokens_[key] = 0; - is_embedding_[key] = true; - - grads_accum_counter_[key] = 0; - } -} - -template -bool ParameterServer::HasWeight(const Key &key) { - return (weights_.count(key) > 0 && !is_embedding_.count(key)); -} - -template -void ParameterServer::Finalize() { - running_ = false; - apply_grads_cv_.notify_one(); -} - -template -void ParameterServer::UpdateWeights() { - while (true) { - std::unique_lock lock(mutex_); - apply_grads_cv_.wait(lock, [this] { return this->ReadyForUpdateWeights() || !running_; }); - if (!running_) { - break; - } - - for (auto iter = weights_.begin(); iter != weights_.end(); iter++) { - Key key = iter->first; - WeightPtr weight_ptr = iter->second; - - std::shared_ptr optimizer = nullptr; - if (weight_key_to_optims_.count(key) > 0) { - optimizer = optimizers_[key]; - } - MS_EXCEPTION_IF_NULL(optimizer); - - std::shared_ptr optim_info = optim_infos_[key]; - if (optim_info != nullptr) { - const std::vector &inputs = optim_info->inputs(); - const std::vector &workspaces = optim_info->workspaces(); - const std::vector &outputs = optim_info->outputs(); - - std::vector> shapes = {}; - std::vector indices_shape = {}; - indices_shape.emplace_back(optim_info->indice_size()); - shapes.push_back(indices_shape); - - if (original_optim_inputs_shape_.count(key) != 0) { - for (auto input_shapes : *(original_optim_inputs_shape_[key])) { - shapes.push_back(*input_shapes); - } - } - optimizer->ReInit(shapes); - optim_info->ComputeMean(shapes, worker_num_, pserver_num_, rank_id_); - optimizer->Execute(inputs, workspaces, outputs); - optim_info->Reset(); - } - if (!is_embedding_[key]) { - tokens_[key] = worker_num_; - } - } - ResetGradAccumCount(); - } -} - -template -void ParameterServer::AccumGrad(const Keys &keys, const Values &values, const Lengths &lengths) { - std::unique_lock lock(mutex_); - const Key &key = keys[0]; - bool no_sparse_grad = values.size() == 1 && values[0] == -100; - if (!no_sparse_grad) { - std::shared_ptr optim_info = optim_infos_[key]; - - // Create or update the optimizer info - if (optim_info == nullptr) { - const std::shared_ptr &builder = optim_info_builders_[weight_key_to_optims_[key]]; - std::shared_ptr pserver_kernel = optimizers_[key]; - if (pserver_kernel == nullptr) { - MS_LOG(EXCEPTION) << "no optimizer found for key " << key << " optim name " << weight_key_to_optims_[key]; - } - MS_EXCEPTION_IF_NULL(pserver_kernel); - OptimizerInfo *optim = builder->Build(pserver_kernel, weights_[key], keys, values, lengths, - 
optim_inputs_shape_[key], worker_num_, is_embedding_[key]); - optim_info.reset(optim); - optim_infos_[key] = optim_info; - } else { - optim_info->Update(values, lengths); - optim_info->Accumulate(values, lengths); - } - } - - grads_accum_counter_[key] += 1; - if (grads_accum_counter_[key] == worker_num_) { - grad_accum_count_++; - } - if (ReadyForUpdateWeights()) { - apply_grads_cv_.notify_one(); - } -} - -template -WeightPtr ParameterServer::weight(const Key &key) { - std::unique_lock lock(mutex_); - if (weights_.count(key) == 0) { - MS_LOG(EXCEPTION) << "Invalid weight key " << key; - } - WeightPtr weight_ptr = weights_[key]; - MS_EXCEPTION_IF_NULL(weight_ptr); - WeightPtr copy_weight_ptr = std::make_shared<::ps::SArray>(weight_ptr->size(), 0); - MS_EXCEPTION_IF_NULL(copy_weight_ptr); - copy_weight_ptr->CopyFrom(weight_ptr->data(), weight_ptr->size()); - tokens_[key] -= 1; - return copy_weight_ptr; -} - -template -void ParameterServer::DoEmbeddingLookup(Key key, const LookupIds &lookup_ids, ::ps::KVPairs *res) { - std::unique_lock lock(mutex_); - MS_EXCEPTION_IF_NULL(res); - if (weights_.count(key) == 0) { - MS_LOG(ERROR) << "Invalid embedding table key " << key; - return; - } - if (embedding_lookup_ops_.count(key) == 0) { - MS_LOG(ERROR) << "Invalid embedding lookup op key " << key; - return; - } - WeightPtr table_ptr = weights_[key]; - MS_EXCEPTION_IF_NULL(table_ptr); - std::shared_ptr table_lookup_op = embedding_lookup_ops_[key]; - MS_EXCEPTION_IF_NULL(table_lookup_op); - - // Update shapes of lookup operator - std::vector> shapes = {}; - std::vector indices_shape = {}; - indices_shape.emplace_back(lookup_ids.size()); - shapes.push_back(indices_shape); - table_lookup_op->ReInit(shapes); - - const std::vector output_shapes = table_lookup_op->output_sizes(); - std::vector inputs; - AddressPtr embedding_table = std::make_shared(); - MS_EXCEPTION_IF_NULL(embedding_table); - AddressPtr indices = std::make_shared(); - MS_EXCEPTION_IF_NULL(indices); - inputs.push_back(embedding_table); - inputs.push_back(indices); - embedding_table->addr = table_ptr->data(); - embedding_table->size = table_ptr->size() * sizeof(T); - - std::unique_ptr tmp_ids(new int[lookup_ids.size()]); - MS_EXCEPTION_IF_NULL(tmp_ids); - for (size_t i = 0; i < lookup_ids.size(); i++) { - tmp_ids[i] = static_cast(lookup_ids[i]); - } - indices->addr = tmp_ids.get(); - indices->size = lookup_ids.size() * sizeof(int); - - std::vector workspaces; - std::vector outputs; - AddressPtr output = std::make_shared(); - MS_EXCEPTION_IF_NULL(output); - std::shared_ptr addr = std::make_shared(output_shapes[0] / sizeof(T), 0); - MS_EXCEPTION_IF_NULL(addr); - - output->addr = addr->data(); - output->size = output_shapes[0]; - outputs.push_back(output); - - table_lookup_op->Execute(inputs, workspaces, outputs); - res->vals = *addr; - res->lens.push_back(res->vals.size()); -} - -template -void ParameterServer::UpdateEmbeddings(const Key &key, const LookupIds &lookup_ids, const Values &vals) { - if (weights_.count(key) == 0) { - MS_LOG(ERROR) << "Invalid embedding table key " << key; - return; - } - if (embedding_lookup_ops_.count(key) == 0) { - MS_LOG(ERROR) << "Invalid embedding lookup op key " << key; - return; - } - WeightPtr table_ptr = weights_[key]; - MS_EXCEPTION_IF_NULL(table_ptr); - std::shared_ptr table_lookup_op = embedding_lookup_ops_[key]; - MS_EXCEPTION_IF_NULL(table_lookup_op); - table_lookup_op->UpdateEmbeddings(table_ptr->data(), lookup_ids.data(), vals.data(), lookup_ids.size()); -} - -template -inline bool 
ParameterServer::ReadyForUpdateWeights() { - return grads_accum_counter_.size() > 0 && grad_accum_count_ == grads_accum_counter_.size(); -} - -template -inline bool ParameterServer::ReadyForPush(const Key &key) { - std::unique_lock lock(mutex_); - if (weights_.empty()) { - MS_LOG(EXCEPTION) << "The weights in server is empty. Many reasons could cause this: 1.The Worker didn't send " - "kInitWeightsCmd command. 2.The Server failed to initialize weights."; - } - return grad_accum_count_ < weights_.size() && tokens_[key] <= 0; -} - -template -inline bool ParameterServer::ReadyForPull(const Key &key) { - std::unique_lock lock(mutex_); - if (tokens_.count(key) == 0 || weights_[key] == 0) { - MS_LOG(EXCEPTION) << "Invalid weight key " << key; - } - return tokens_[key] > 0; -} - -template -inline void ParameterServer::ResetGradAccumCount() { - grad_accum_count_ = 0; - for (auto iter = grads_accum_counter_.begin(); iter != grads_accum_counter_.end(); iter++) { - grads_accum_counter_[iter->first] = 0; - } -} - -template -inline std::mutex &ParameterServer::mutex() { - return mutex_; -} - -template -void ParameterServer::GetEmbeddingTableParamPtr() { - MS_EXCEPTION_IF_NULL(func_graph_); - auto cnodes = func_graph_->GetOrderedCnodes(); - Key count = 0; - for (auto cnode : cnodes) { - MS_EXCEPTION_IF_NULL(cnode); - std::string cnode_name = AnfAlgo::GetCNodeName(cnode); - if (cnode_name == kEmbeddingLookupOpName || cnode_name == kGatherV2OpName || cnode_name == kSparseGatherV2OpName) { - auto embedding_table = AnfAlgo::GetInputNode(cnode, 0); - if (IsPrimitiveCNode(embedding_table, prim::kPrimLoad)) { - auto embedding_cnode = embedding_table->cast(); - embedding_table = AnfAlgo::GetInputNode(embedding_cnode, 0); - } - MS_EXCEPTION_IF_NULL(embedding_table); - if (embedding_table->isa()) { - MS_LOG(INFO) << "Embedding table name is " << embedding_table->fullname_with_scope() << ", key is " << count; - embedding_tables_.insert(std::make_pair(count, embedding_table->cast())); - count++; - } - } - } -} - -template -void ParameterServer::SyncEmbeddingTables() { - for (auto embedding_table : embedding_tables_) { - Key key = embedding_table.first; - if (embedding_lookup_ops_.count(key) == 0) { - MS_LOG(WARNING) << "Can't find look up PS kernel for key " << key; - continue; - } - auto lookup = embedding_lookup_ops_[key]; - const std::vector &input_shapes = lookup->input_sizes(); - std::vector new_tensor_shape(input_shapes.begin(), input_shapes.end()); - - tensor::TensorPtr new_tensor = std::make_shared(kNumberTypeFloat32, new_tensor_shape); - MS_EXCEPTION_IF_NULL(new_tensor); - float *new_tensor_data_ptr = reinterpret_cast(new_tensor->data_c()); - size_t new_tensor_size = static_cast(new_tensor->data().nbytes()); - size_t embedding_table_size = weights_[key]->size() * sizeof(float); - if (new_tensor_size != embedding_table_size) { - MS_LOG(EXCEPTION) << "Shape of embedding table can't match. 
New tensor size:" << new_tensor_size - << ", embedding_table size:" << embedding_table_size; - } - MS_EXCEPTION_IF_NULL(new_tensor_data_ptr); - MS_EXCEPTION_IF_NULL(weights_[key]->data()); - int64_t ret = memcpy_s(new_tensor_data_ptr, new_tensor_size, weights_[key]->data(), embedding_table_size); - if (ret != 0) { - MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")"; - return; - } - - auto paramter_tensor_ptr = embedding_table.second->default_param(); - MS_EXCEPTION_IF_NULL(paramter_tensor_ptr); - paramter_tensor_ptr->cast()->AssignValue(*new_tensor); - } -} - -template -void ParameterServer::Run(const FuncGraphPtr &func_graph) { - MS_EXCEPTION_IF_NULL(func_graph); - MS_LOG(INFO) << "PServer starts connecting to scheduler and workers..."; - ::ps::Start(0); - MS_LOG(INFO) << "PServer connected successfully."; - if (!::ps::IsServer()) { - std::cout << "This is not ther Server" << std::endl; - return; - } - Init(func_graph); - PSContext::instance()->SetPSRankId(rank_id_); - thread_->join(); - SyncEmbeddingTables(); - MS_LOG(INFO) << "PServer finished updating models, starts finalizing..."; - ::ps::Finalize(0, true); - MS_LOG(INFO) << "PServer finalized successfully."; -} -} // namespace ps -} // namespace mindspore -#endif // MINDSPORE_CCSRC_PS_PARAMETER_SERVER_H_ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_CCSRC_PS_PARAMETER_SERVER_H_ +#define MINDSPORE_CCSRC_PS_PARAMETER_SERVER_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "ir/func_graph.h" +#include "backend/session/session_basic.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "backend/session/session_factory.h" +#include "ps/optimizer_info.h" +#include "ps/optimizer_info_builder.h" +#include "ps/ps_context.h" +#include "runtime/device/cpu/kernel_select_cpu.h" +#include "utils/ms_context.h" +#include "backend/kernel_compiler/kernel.h" +#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" +#include "backend/kernel_compiler/cpu/ps/pserver_kernel.h" +#include "backend/kernel_compiler/cpu/ps/sparse_apply_adam_ps_kernel.h" +#include "backend/kernel_compiler/cpu/ps/sparse_apply_lazy_adam_ps_kernel.h" +#include "backend/kernel_compiler/cpu/ps/sparse_apply_ftrl_ps_kernel.h" +#include "backend/kernel_compiler/cpu/ps/apply_momentum_ps_kernel.h" +#include "backend/kernel_compiler/cpu/ps/embedding_look_up_ps_kernel.h" +#include "ps/ps_cache/ps_data/ps_data_prefetch.h" +#include "ps/random_normal/random_normal.h" + +#include "ps/constants.h" +#include "ps/util.h" +#include "ps/embedding_table_shard_metadata.h" +#include "utils/log_adapter.h" +#include "proto/comm.pb.h" +#include "proto/ps.pb.h" +#include "ps/core/server_node.h" + +namespace mindspore { +namespace ps { + +class ParameterServer { + public: + static ParameterServer &GetInstance() { + static ParameterServer instance; + return instance; + } + + void Run(const FuncGraphPtr &func_graph); + + private: + ParameterServer() + : pserver_num_(0), + worker_num_(0), + rank_id_(0), + grad_accum_count_(0), + handler_(nullptr), + func_graph_(nullptr), + sess_(nullptr), + running_(true), + thread_(nullptr) {} + ~ParameterServer() = default; + ParameterServer(const ParameterServer &) = delete; + ParameterServer &operator=(const ParameterServer &) = delete; + + class ServerHandler { + public: + explicit ServerHandler(ParameterServer *ps) : ps_(ps) {} + ~ServerHandler() = default; + void Init(); + void operator()(std::shared_ptr conn, std::shared_ptr meta, DataPtr data, + size_t size); + void HandlePushReq(DataPtr data, size_t size, VectorPtr res); + void HandlePullReq(DataPtr data, size_t size, VectorPtr res); + void HandleInitWeights(DataPtr data, size_t size, VectorPtr res); + void HandleInitWeightToOptimId(DataPtr data, size_t size, VectorPtr res); + void HandleInitInputsShape(DataPtr data, size_t size, VectorPtr res); + void HandleInitEmbeddings(DataPtr data, size_t size, VectorPtr res); + void HandleCheckReadyForPush(DataPtr data, size_t size, VectorPtr res); + void HandleCheckReadyForPull(DataPtr data, size_t size, VectorPtr res); + void HandleEmbeddingLookup(DataPtr data, size_t size, VectorPtr res); + void HandleUpdateEmbeddings(DataPtr data, size_t size, VectorPtr res); + void HandleFinalize(DataPtr data, size_t size, VectorPtr res); + + private: + ParameterServer *ps_; + typedef void (ServerHandler::*RequestHandler)(DataPtr data, size_t size, VectorPtr res); + std::unordered_map handlers_; + std::unordered_map init_weights_; + std::unordered_map init_weight_to_optim_; + std::unordered_map init_optim_info_; + }; + + bool Init(const FuncGraphPtr &func_graph); + void InitOptimInfoBuilders(); + void InitWeightKeyToOptims(const Key &key, const int64_t &optim_id); + void InitOptimInputsShape(const Keys &keys, const Values &values, const 
Lengths &lengths); + void InitWeight(const Key &key, const WeightPtr &weight); + void InitGrad(const Key &key, const GradPtr &grad); + void InitEmbeddingTable(const Key &key, + const std::shared_ptr>>> &shapes, + const ParamInitInfo ¶m_init_info); + bool HasWeight(const Key &key); + void Finalize(); + void UpdateWeights(); + void AccumGrad(const Keys &key, const Values &values, const Lengths &lengths); + WeightPtr weight(const Key &key); + void DoEmbeddingLookup(Key key, const LookupIds &lookup_ids, KVMessage *res); + void UpdateEmbeddings(const Key &key, const LookupIds &lookup_ids, const Values &vals); + bool ReadyForUpdateWeights(); + bool ReadyForPush(const Key &key); + bool ReadyForPull(const Key &key); + void ResetGradAccumCount(); + const CNodePtr GetCNode(const std::string &name) const; + std::mutex &mutex(); + void GetEmbeddingTableParamPtr(); + void SyncEmbeddingTables(); + + size_t pserver_num_; + size_t worker_num_; + size_t rank_id_; + size_t grad_accum_count_; + std::unique_ptr handler_; + FuncGraphPtr func_graph_; + std::shared_ptr sess_; + bool running_; + + std::unordered_map> optimizers_; + std::unordered_map optim_inputs_shape_; + std::unordered_map original_optim_inputs_shape_; + std::unordered_map> optim_infos_; + std::unordered_map> optim_info_builders_; + std::unordered_map weight_key_to_optims_; + std::unordered_map weight_key_to_optim_op_; + std::unordered_map weights_; + std::unordered_map is_embedding_; + std::unordered_map grads_; + std::unordered_map grads_accum_counter_; + std::unordered_map> embedding_lookup_ops_; + std::unordered_map tokens_; + + std::mutex mutex_; + std::condition_variable apply_grads_cv_; + + std::unique_ptr thread_; + core::ServerNode server_node_; + std::map embedding_tables_; + + friend class ServerHandler; +}; +} // namespace ps +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PS_PARAMETER_SERVER_H_ diff --git a/mindspore/ccsrc/ps/ps_cache/ps_cache_manager.cc b/mindspore/ccsrc/ps/ps_cache/ps_cache_manager.cc index 96e4e23c30..720e026d7e 100644 --- a/mindspore/ccsrc/ps/ps_cache/ps_cache_manager.cc +++ b/mindspore/ccsrc/ps/ps_cache/ps_cache_manager.cc @@ -145,7 +145,6 @@ const size_t &PsCacheManager::QueryHashTableSize(const std::string ¶m_name) void PsCacheManager::Initialize() { MS_LOG(INFO) << "PS cache initialize."; if (!worker.running()) { - Util::SetInternalEnvVar(); worker.Run(); } embedding_device_cache_ = std::make_shared(batch_elements_, vocab_cache_size_); @@ -177,22 +176,19 @@ void PsCacheManager::InitParameterServer() { for (const auto &item : hash_tables_) { const auto ¶m_name = item.first; size_t key = worker.SetParamKey(param_name); - std::vector keys{key, key, key, key, key, key}; - std::vector values{ - SizeToFloat(item.second.vocab_size), SizeToFloat(item.second.embedding_size), 1, 1, 1, 1, 1}; - std::vector lens{2, 2, 3}; const auto &hash_table_info = item.second; const auto ¶m_init_info = hash_table_info.param_init_info_; - if (param_init_info.param_type_ == kWeight) { - lens.push_back(1); - } else if (param_init_info.param_type_ == kAccumulation) { - lens.push_back(2); - } - values.push_back(param_init_info.init_val_); - lens.push_back(param_init_info.global_seed_); - lens.push_back(param_init_info.op_seed_); + + std::vector input_shape = {item.second.vocab_size, item.second.embedding_size}; + std::vector indices_shape = {1, 1}; + std::vector output_shape = {1, 1, 1}; + ParamInitInfoMessage info; + info.set_param_type(param_init_info.param_type_); + info.set_init_val(param_init_info.init_val_); + 
info.set_global_seed(param_init_info.global_seed_);
+    info.set_op_seed(param_init_info.op_seed_);
     // if worker role
-    worker.InitPSEmbeddingTable(keys, values, lens);
+    worker.InitPSEmbeddingTable(key, input_shape, indices_shape, output_shape, info);
   }
   finish_init_parameter_server_ = true;
@@ -245,7 +241,7 @@ void PsCacheManager::AllocMemForHashTable() {
 }
 
 void PsCacheManager::SetLocalIdRank() {
-  auto worker_num = ::ps::NumWorkers();
+  auto worker_num = PSContext::instance()->initial_worker_num();
   auto local_shard_size = FloatToInt(std::ceil(SizeToFloat(vocab_size_) / worker_num));
   vocab_cache_size_diff_ = local_shard_size - SizeToInt(vocab_cache_size_);
   emb_table_slice_bounds_.first = local_shard_size * rank_id_;
@@ -829,8 +825,8 @@ bool PsCacheManager::HashSwapHostToServer(size_t key, const HashTableInfo &hash_
   if (swap_indices_size == 0) {
     return true;
   }
-  ::ps::SArray<int> lookup_ids(swap_indices_size, 0);
-  ::ps::SArray<float> swap_out_data;
+  std::vector<int> lookup_ids(swap_indices_size, 0);
+  std::vector<float> swap_out_data;
   auto embedding_size = hash_info.embedding_size;
   swap_out_data.resize(swap_indices_size * embedding_size);
   auto host_hash_table_addr = reinterpret_cast<float *>(hash_info.host_address.get());
@@ -857,22 +853,21 @@ bool PsCacheManager::HashSwapServerToHost(size_t key, const HashTableInfo &hash_
   }
   auto host_hash_table_addr = reinterpret_cast<float *>(hash_info.host_address.get());
   auto embedding_size = hash_info.embedding_size;
-  ::ps::SArray<int> lengths{swap_indices_size};
-  ::ps::SArray<float> lookup_result(swap_indices_size * embedding_size, 0);
-  ::ps::SArray<int> lookup_ids(swap_indices_size, 0);
+  std::vector<float> lookup_result(swap_indices_size * embedding_size, 0);
+  std::vector<int> lookup_ids(swap_indices_size, 0);
   auto copy_len = swap_indices_size * sizeof(int);
   auto ret = memcpy_s(lookup_ids.data(), copy_len, server_to_host_ids, copy_len);
   if (ret != EOK) {
     MS_LOG(ERROR) << "Lookup id memcpy failed.";
     return false;
   }
-  worker.DoPSEmbeddingLookup({key}, lookup_ids, lengths, &lookup_result, mindspore::ps::kEmbeddingLookupCmd);
+  worker.DoPSEmbeddingLookup(key, lookup_ids, &lookup_result, mindspore::ps::kEmbeddingLookupCmd);
   RETURN_IF_FALSE(InsertHostHashTable(embedding_size, IntToSize(swap_indices_size), server_to_host_index,
                                       lookup_result.data(), host_hash_table_addr));
   return true;
 }
 
-bool PsCacheManager::HashSwapDeviceOut(int *swap_out_index, ::ps::SArray<float> *swap_out_data,
+bool PsCacheManager::HashSwapDeviceOut(int *swap_out_index, std::vector<float> *swap_out_data,
                                        const HashTableInfo &hash_info) {
   MS_ERROR_IF_NULL(swap_out_index);
   MS_ERROR_IF_NULL(swap_out_data);
@@ -912,16 +907,15 @@ bool PsCacheManager::HashSwapDeviceIn(int *swap_in_ids, int *swap_in_index, cons
   auto cache_vocab_size = hash_info.cache_vocab_size;
   auto embedding_size = hash_info.embedding_size;
   // Get id embs by swap_in_ids in host(Pipeline with hash swap-out in device).
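// A standalone sketch (not MindSpore API) of the flat swap-in buffer used in
// this hunk: the embedding row fetched for lookup_ids[i] starts at offset
// i * embedding_size inside lookup_result. Sizes here are illustrative.
#include <cstddef>
#include <iostream>
#include <vector>

int main() {
  const size_t swap_in_ids_size = 4;  // ids fetched from the parameter server
  const size_t embedding_size = 8;    // floats per embedding row
  std::vector<int> lookup_ids = {7, 3, 42, 5};
  std::vector<float> lookup_result(swap_in_ids_size * embedding_size, 0.0f);
  for (size_t i = 0; i < swap_in_ids_size; ++i) {
    // Row i holds the embedding of lookup_ids[i], back to back with the rest.
    const float *row = lookup_result.data() + i * embedding_size;
    std::cout << "id " << lookup_ids[i] << " -> offset " << (row - lookup_result.data()) << std::endl;
  }
  return 0;
}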
-  ::ps::SArray<int> lengths{swap_in_ids_size};
-  ::ps::SArray<float> lookup_result(swap_in_ids_size * embedding_size, 0);
-  ::ps::SArray<int> lookup_ids(swap_in_ids_size, 0);
+  std::vector<float> lookup_result(swap_in_ids_size * embedding_size, 0);
+  std::vector<int> lookup_ids(swap_in_ids_size, 0);
   auto copy_len = swap_in_ids_size * sizeof(int);
   auto ret = memcpy_s(lookup_ids.data(), copy_len, swap_in_ids, copy_len);
   if (ret != EOK) {
     MS_LOG(ERROR) << "Lookup id memcpy failed.";
     return false;
   }
-  worker.DoPSEmbeddingLookup({key}, lookup_ids, lengths, &lookup_result, mindspore::ps::kEmbeddingLookupCmd);
+  worker.DoPSEmbeddingLookup(key, lookup_ids, &lookup_result, mindspore::ps::kEmbeddingLookupCmd);
   // Hash swap-in in device.
   RETURN_IF_FALSE(embedding_device_cache_->cache_->CopyHostMemToDevice(
     embedding_device_cache_->hash_swap_value_addr_, lookup_result.data(),
@@ -934,7 +928,7 @@ bool PsCacheManager::HashSwapDeviceIn(int *swap_in_ids, int *swap_in_index, cons
   return true;
 }
 
-bool PsCacheManager::UpdataEmbeddingTable(const ::ps::SArray<float> &swap_out_data, int *swap_out_ids, size_t key) {
+bool PsCacheManager::UpdataEmbeddingTable(const std::vector<float> &swap_out_data, int *swap_out_ids, size_t key) {
   MS_ERROR_IF_NULL(embedding_device_cache_);
   MS_ERROR_IF_NULL(embedding_device_cache_->cache_);
   MS_ERROR_IF_NULL(swap_out_ids);
@@ -942,7 +936,7 @@ bool PsCacheManager::UpdataEmbeddingTable(const ::ps::SArray<float> &swap_out_da
   if (swap_out_ids_size == 0) {
     return true;
   }
-  ::ps::SArray<int> lookup_ids(swap_out_ids_size, 0);
+  std::vector<int> lookup_ids(swap_out_ids_size, 0);
   auto copy_len = swap_out_ids_size * sizeof(int);
   auto ret = memcpy_s(lookup_ids.data(), copy_len, swap_out_ids, copy_len);
   if (ret != EOK) {
@@ -994,8 +988,8 @@ bool PsCacheManager::SyncHostEmbeddingTable() {
       continue;
     }
     auto key = worker.GetParamKey(item.first);
-    ::ps::SArray<int> lookup_ids(swap_indices_lens, 0);
-    ::ps::SArray<float> swap_out_data;
+    std::vector<int> lookup_ids(swap_indices_lens, 0);
+    std::vector<float> swap_out_data;
     auto embedding_size = hash_info.embedding_size;
     swap_out_data.resize(swap_indices_lens * embedding_size);
     auto host_hash_table_addr = hash_info.host_address.get();
@@ -1038,8 +1032,8 @@ bool PsCacheManager::SyncDeviceEmbeddingTable() {
       continue;
     }
     auto key = worker.GetParamKey(item.first);
-    ::ps::SArray<int> lookup_ids(swap_indices_lens, 0);
-    ::ps::SArray<float> swap_out_data;
+    std::vector<int> lookup_ids(swap_indices_lens, 0);
+    std::vector<float> swap_out_data;
     auto embedding_size = hash_info.embedding_size;
     swap_out_data.resize(swap_indices_lens * embedding_size);
     std::unique_ptr<float[]> device_hash_table_addr_tmp =
diff --git a/mindspore/ccsrc/ps/ps_cache/ps_cache_manager.h b/mindspore/ccsrc/ps/ps_cache/ps_cache_manager.h
index c78efafb95..2e620a5273 100644
--- a/mindspore/ccsrc/ps/ps_cache/ps_cache_manager.h
+++ b/mindspore/ccsrc/ps/ps_cache/ps_cache_manager.h
@@ -29,9 +29,9 @@
 #include "backend/kernel_compiler/kernel.h"
 #include "utils/shape_utils.h"
 #include "ir/tensor.h"
-#include "ps/ps.h"
-#include "ps/common.h"
+#include "ps/constants.h"
 #include "ps/worker.h"
+#include "ps/ps_context.h"
 #include "ps/ps_cache/ps_data/ps_data_prefetch.h"
 #include "ps/ps_cache/embedding_hash_map.h"
 #include "ps/ps_cache/ps_cache_factory.h"
@@ -155,7 +155,7 @@ class PsCacheManager {
   bool ParseDeviceData(size_t id, bool *need_swap_device_to_host, bool *need_swap_host_to_device, int *hash_index);
   bool ParseHostDataHostToDevice(size_t id);
   bool ParseHostDataDeviceToHost();
-  bool HashSwapDeviceOut(int *swap_out_index, ::ps::SArray<float> *swap_out_data, const HashTableInfo &hash_info);
+  bool
HashSwapDeviceOut(int *swap_out_index, std::vector *swap_out_data, const HashTableInfo &hash_info); bool HashSwapDeviceIn(int *swap_in_ids, int *swap_in_index, const HashTableInfo &hash_info, size_t key); bool HashSwapHostToDevice(const HashTableInfo &hash_info); bool HashSwapDeviceToHost(const HashTableInfo &hash_info); @@ -165,7 +165,7 @@ class PsCacheManager { float *hash_table_addr); bool LookUpHostHashTable(size_t embedding_size, size_t indices_lens, const float *hash_table_addr, const int *indices_addr, float *output_addr); - bool UpdataEmbeddingTable(const ::ps::SArray &swap_out_data, int *swap_out_ids, size_t key); + bool UpdataEmbeddingTable(const std::vector &swap_out_data, int *swap_out_ids, size_t key); void LookUpTableTask(size_t indices_lens, size_t outer_dim_size, size_t first_dim_size, const float *input_addr, const int *indices_addr, float *output_addr); bool CheckFinishInsertInitInfo() const; diff --git a/mindspore/ccsrc/ps/ps_context.cc b/mindspore/ccsrc/ps/ps_context.cc index 3aabe2a7c9..1c11423bd5 100644 --- a/mindspore/ccsrc/ps/ps_context.cc +++ b/mindspore/ccsrc/ps/ps_context.cc @@ -48,10 +48,10 @@ void PSContext::SetPSEnable(bool enabled) { MS_LOG(WARNING) << "MS_ROLE is " << ms_role << ", which is invalid."; } - worker_num_ = std::strtol(common::GetEnv("MS_WORKER_NUM").c_str(), nullptr, 10); - server_num_ = std::strtol(common::GetEnv("MS_SERVER_NUM").c_str(), nullptr, 10); - scheduler_host_ = common::GetEnv("MS_SCHED_HOST"); - scheduler_port_ = std::strtol(common::GetEnv("MS_SCHED_PORT").c_str(), nullptr, 10); + worker_num_ = std::strtol(common::GetEnv(kEnvWorkerNum).c_str(), nullptr, 10); + server_num_ = std::strtol(common::GetEnv(kEnvPServerNum).c_str(), nullptr, 10); + scheduler_host_ = common::GetEnv(kEnvSchedulerHost); + scheduler_port_ = std::strtol(common::GetEnv(kEnvSchedulerPort).c_str(), nullptr, 10); } else { MS_LOG(INFO) << "PS mode is disabled."; is_worker_ = false; diff --git a/mindspore/ccsrc/ps/ps_context.h b/mindspore/ccsrc/ps/ps_context.h index c1da1e5263..e6106696ce 100644 --- a/mindspore/ccsrc/ps/ps_context.h +++ b/mindspore/ccsrc/ps/ps_context.h @@ -19,6 +19,7 @@ #include #include +#include "ps/constants.h" namespace mindspore { namespace ps { diff --git a/mindspore/ccsrc/ps/scheduler.cc b/mindspore/ccsrc/ps/scheduler.cc index f60845a7ed..2af96f9741 100755 --- a/mindspore/ccsrc/ps/scheduler.cc +++ b/mindspore/ccsrc/ps/scheduler.cc @@ -15,13 +15,16 @@ */ #include "ps/scheduler.h" -#include "ps/ps.h" namespace mindspore { namespace ps { void Scheduler::Run() { - ::ps::Start(0); - ::ps::Finalize(0, true); + core::ClusterMetadata::instance()->Init( + PSContext::instance()->initial_worker_num(), PSContext::instance()->initial_server_num(), + PSContext::instance()->scheduler_host(), PSContext::instance()->scheduler_port()); + scheduler_node_.Start(); + scheduler_node_.Finish(); + scheduler_node_.Stop(); exit(1); } } // namespace ps diff --git a/mindspore/ccsrc/ps/scheduler.h b/mindspore/ccsrc/ps/scheduler.h index 6466c4bf4d..ceb8b033ab 100755 --- a/mindspore/ccsrc/ps/scheduler.h +++ b/mindspore/ccsrc/ps/scheduler.h @@ -16,6 +16,11 @@ #ifndef MINDSPORE_CCSRC_PS_SCHEDULER_H_ #define MINDSPORE_CCSRC_PS_SCHEDULER_H_ + +#include "ps/core/scheduler_node.h" +#include "ps/util.h" +#include "ps/ps_context.h" + namespace mindspore { namespace ps { class Scheduler { @@ -32,6 +37,7 @@ class Scheduler { ~Scheduler() = default; Scheduler(const Scheduler &) = delete; Scheduler &operator=(const Scheduler &) = delete; + core::SchedulerNode scheduler_node_; }; } // 
namespace ps } // namespace mindspore diff --git a/mindspore/ccsrc/ps/util.cc b/mindspore/ccsrc/ps/util.cc index fc89d69888..4ad86dc62d 100644 --- a/mindspore/ccsrc/ps/util.cc +++ b/mindspore/ccsrc/ps/util.cc @@ -17,7 +17,7 @@ #include "ps/util.h" #include #include -#include "ps/common.h" +#include "ps/constants.h" #include "ps/ps_context.h" #include "utils/ms_utils.h" @@ -46,50 +46,10 @@ std::unordered_map Util::id_to_optimizer_nodes{ {3, kSparseFtrlOp}, }; -bool Util::IsParamServerMode() { return PSContext::instance()->is_ps_mode(); } - -bool Util::IsRoleOfWorker() { return PSContext::instance()->is_worker(); } - bool Util::IsRoleOfPServer() { return PSContext::instance()->is_server(); } bool Util::IsRoleOfScheduler() { return PSContext::instance()->is_scheduler(); } -void Util::SetInternalEnvVar() { - if (IsParamServerMode()) { - auto comm_type = common::GetEnv(kEnvCommType); - if (!comm_type.empty()) { - (void)common::SetEnv(kDmlcCommType, comm_type.c_str()); - } - auto interface = common::GetEnv(kEnvInterface); - if (!interface.empty()) { - (void)common::SetEnv(kDmlcInterface, interface.c_str()); - } - auto server_num = common::GetEnv(kEnvPServerNum); - if (!server_num.empty()) { - (void)common::SetEnv(kDmlcPServerNum, server_num.c_str()); - } - auto worker_num = common::GetEnv(kEnvWorkerNum); - if (!worker_num.empty()) { - (void)common::SetEnv(kDmlcWorkerNum, worker_num.c_str()); - } - if (IsRoleOfScheduler()) { - (void)common::SetEnv(kDmlcRole, kRoleOfScheduler); - } else if (IsRoleOfPServer()) { - (void)common::SetEnv(kDmlcRole, kRoleOfPServer); - } else if (IsRoleOfWorker()) { - (void)common::SetEnv(kDmlcRole, kRoleOfWorker); - } - auto scheduler_host = common::GetEnv(kEnvSchedulerHost); - if (!scheduler_host.empty()) { - (void)common::SetEnv(kDmlcSchedulerHost, scheduler_host.c_str()); - } - auto scheduler_port = common::GetEnv(kEnvSchedulerPort); - if (!scheduler_port.empty()) { - (void)common::SetEnv(kDmlcSchedulerPort, scheduler_port.c_str()); - } - } -} - int64_t Util::optimizer_id(std::string name) { if (optimizer_to_ids.count(name) > 0) { return optimizer_to_ids[name]; diff --git a/mindspore/ccsrc/ps/util.h b/mindspore/ccsrc/ps/util.h index 2a832d00f2..ba6034acec 100644 --- a/mindspore/ccsrc/ps/util.h +++ b/mindspore/ccsrc/ps/util.h @@ -37,11 +37,8 @@ struct ParamInitInfo { class Util { public: - static bool IsParamServerMode(); - static bool IsRoleOfWorker(); static bool IsRoleOfPServer(); static bool IsRoleOfScheduler(); - static void SetInternalEnvVar(); static int64_t optimizer_id(std::string name); static std::string optimizer_name(int64_t id); static std::string optimizer_node_name(int64_t id); diff --git a/mindspore/ccsrc/ps/internal/worker.cc b/mindspore/ccsrc/ps/worker.cc similarity index 99% rename from mindspore/ccsrc/ps/internal/worker.cc rename to mindspore/ccsrc/ps/worker.cc index bdd2c3a22c..42ca12c457 100644 --- a/mindspore/ccsrc/ps/internal/worker.cc +++ b/mindspore/ccsrc/ps/worker.cc @@ -14,11 +14,10 @@ * limitations under the License. 
*/ -#include "ps/internal/worker.h" +#include "ps/worker.h" namespace mindspore { namespace ps { -namespace internal { void Worker::Run() { std::lock_guard lock(running_mutex_); core::ClusterMetadata::instance()->Init( @@ -198,7 +197,8 @@ void Worker::AddEmbeddingTable(const Key &key, const size_t &row_count) { } void Worker::InitPSEmbeddingTable(const size_t &key, const std::vector &input_shape, - const std::vector &indices_shape, const std::vector &output_shape) { + const std::vector &indices_shape, const std::vector &output_shape, + const ParamInitInfoMessage &info) { bool has_init = IsKeyInit(key); if (has_init) { MS_LOG(DEBUG) << "The key embedding table of key " << key << " is initialized."; @@ -210,6 +210,7 @@ void Worker::InitPSEmbeddingTable(const size_t &key, const std::vector & *embedding_table_meta.mutable_input_shape() = {input_shape.begin(), input_shape.end()}; *embedding_table_meta.mutable_indices_shape() = {indices_shape.begin(), indices_shape.end()}; *embedding_table_meta.mutable_output_shape() = {output_shape.begin(), output_shape.end()}; + *embedding_table_meta.mutable_info() = info; std::string kv_data = embedding_table_meta.SerializeAsString(); @@ -295,19 +296,18 @@ void Worker::DoPSEmbeddingLookup(const Key &key, const std::vector &lookup_ int64_t single_id_len = SizeToLong(lookup_result->size() / lookup_ids.size()); std::unordered_map>> id_addr_map; std::shared_ptr> values = std::make_shared>(); + int64_t value_offset = 0; for (size_t i = 0; i < resp.size(); ++i) { KVMessage message; message.ParseFromArray(resp.at(i)->data(), resp.at(i)->size()); - int64_t offset = 0; - values->clear(); for (auto j = 0; j < message.values_size(); j++) { values->push_back(message.values(j)); } - MS_LOG(DEBUG) << "the embedding resp:" << values; + MS_LOG(DEBUG) << "The embedding resp:" << values; for (auto k = 0; k < message.keys_size(); k++) { const Key &key = message.keys(k); - float *addr = values->data() + offset; - offset += single_id_len; + float *addr = values->data() + value_offset; + value_offset += single_id_len; id_addr_map[key] = std::make_shared>(std::make_pair(addr, single_id_len)); } } @@ -969,6 +969,5 @@ void Worker::SendForPull(int cmd, const KVMessage &send, const KVPartitioner &pa } } } -} // namespace internal } // namespace ps } // namespace mindspore diff --git a/mindspore/ccsrc/ps/worker.h b/mindspore/ccsrc/ps/worker.h index a3b1930c5d..bef27a9c88 100644 --- a/mindspore/ccsrc/ps/worker.h +++ b/mindspore/ccsrc/ps/worker.h @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
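The DoPSEmbeddingLookup hunk above is a behavior fix as well as an API port: the old loop reset its offset to zero and cleared the values buffer for every server reply, while the new code keeps a single value_offset across all replies, so each key is mapped to its own slice of the concatenated result. A minimal standalone sketch of that reassembly follows; the reply payloads and single_id_len are illustrative, and the reserve() keeps the stored pointers valid while the buffer grows:

#include <cstdint>
#include <iostream>
#include <memory>
#include <unordered_map>
#include <utility>
#include <vector>

using Key = uint64_t;

int main() {
  // Two fake server replies, each carrying keys plus their flattened
  // embedding rows; single_id_len is the number of floats per id.
  const int64_t single_id_len = 3;
  const std::vector<std::pair<std::vector<Key>, std::vector<float>>> replies = {
      {{0, 1}, {0.f, 0.f, 0.f, 1.f, 1.f, 1.f}},
      {{2}, {2.f, 2.f, 2.f}}};

  auto values = std::make_shared<std::vector<float>>();
  values->reserve(9);  // keep pointers into the buffer stable while appending
  std::unordered_map<Key, std::shared_ptr<std::pair<float *, int64_t>>> id_addr_map;

  // The offset lives outside the reply loop: rows from later replies continue
  // where the previous reply ended instead of restarting at zero.
  int64_t value_offset = 0;
  for (const auto &reply : replies) {
    values->insert(values->end(), reply.second.begin(), reply.second.end());
    for (const Key &key : reply.first) {
      float *addr = values->data() + value_offset;
      value_offset += single_id_len;
      id_addr_map[key] = std::make_shared<std::pair<float *, int64_t>>(addr, single_id_len);
    }
  }
  std::cout << "first value of key 2: " << id_addr_map[2]->first[0] << std::endl;  // prints 2
  return 0;
}

Without the persistent offset, every key from the second reply onward would alias the first reply's rows.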
@@ -25,24 +25,38 @@ #include #include #include -#include "ps/ps.h" +#include +#include +#include + #include "utils/log_adapter.h" #include "ir/tensor.h" #include "ps/util.h" -#include "ps/common.h" -#include "ps/worker_proxy.h" +#include "ps/constants.h" #include "utils/shape_utils.h" #include "ps/ps_cache/ps_data/ps_data_prefetch.h" +#include "ps/core/worker_node.h" +#include "ps/embedding_table_shard_metadata.h" +#include "proto/comm.pb.h" +#include "proto/ps.pb.h" +#include "ps/ps_context.h" namespace mindspore { namespace ps { -template class Worker { public: static Worker &GetInstance() { static Worker instance; return instance; } + using Callback = std::function; + using PartitionEmbeddingMessages = std::vector>; + using PartitionKVMessages = std::vector>; + + using EmbeddingPartitioner = std::function &attrs)>; + using KVPartitioner = + std::function &attrs)>; void Run(); void Push(const std::vector &keys, std::vector addrs, const ShapeVector &sizes); @@ -53,340 +67,89 @@ class Worker { bool GetParamInitInServer(const std::string ¶m_name); void SetKeyOptimId(size_t key, const std::string &optimizer_name); void SetOptimInputShapes(size_t key, const ShapeVector &shape); - void AddEmbeddingTable(const ::ps::Key &key, const size_t &row_count); - void InitPSEmbeddingTable(const std::vector &keys, std::vector shapes, const ShapeVector &sizes); + void AddEmbeddingTable(const Key &key, const size_t &row_count); + void InitPSEmbeddingTable(const size_t &key, const std::vector &input_shape, + const std::vector &indices_shape, const std::vector &output_shape, + const ParamInitInfoMessage &info); void InitPSParamAndOptim(const AnfNodePtr &input_node, const tensor::TensorPtr &tensor); - void DoPSEmbeddingLookup(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray &lookup_ids, - const ::ps::SArray &lens, ::ps::SArray *lookup_result, int64_t cmd); - void UpdateEmbeddingTable(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray &lookup_ids, - const ::ps::SArray &vals); + void DoPSEmbeddingLookup(const Key &key, const std::vector &lookup_ids, std::vector *lookup_result, + int64_t cmd); + void UpdateEmbeddingTable(const std::vector &keys, const std::vector &lookup_ids, + const std::vector &vals); + bool running() { return running_; } void Finalize(); private: - Worker() : kv_worker_(nullptr), running_(false), key_cnt_(0) {} + Worker() : running_(false), key_cnt_(0) {} ~Worker() = default; Worker(const Worker &) = delete; Worker &operator=(const Worker &) = delete; + void Initialize(); bool IsKeyInit(const size_t key); + void AddKeyToServerId(const Key &key); + void AddKeyByHashMod(const Key &key); void InitPSOptimId(const size_t param_key); void InitPSOptimInputShapes(const size_t key); void InitPSParamData(const std::vector &keys, void *origin_addr, size_t size); - static void EmbeddingLookupIdSlicer(const ::ps::KVPairs &send, const std::vector<::ps::Range> &ranges, - std::vector>> *sliced) {} - - std::shared_ptr> kv_worker_; + bool IsReadyForPush(const Key &key); + bool IsReadyForPull(const Key &key); + void PrepareSparseGradient(const size_t begin, const size_t end, const std::unordered_set &distinct_ids, + const std::vector> &indice_to_grads, const int *all_indice, + const size_t segment_size, float *gradient, int *indices); + void BuildSparseValue(const std::vector &lengths, const size_t grad_index, const size_t indice_index, + const float *original_data, const float *grads, int *indices, std::vector *reduced_data); + + void PushData(const std::vector &keys, const std::vector &vals, const 
std::vector &lens = {}, + int command = 0, int64_t priority = 0); + void PushSparseData(const std::vector &keys, const std::vector &vals, const std::vector &lens, + size_t grad_index, size_t indice_index, size_t first_dim_size, size_t outer_dim_size); + void PullData(const std::vector &keys, std::vector *vals, std::vector *lens = nullptr, int cmd = 0, + int64_t priority = 0); + + void LookupIdPartitioner(const EmbeddingTableLookup &send, PartitionEmbeddingMessages *partition, + const std::map &attrs); + + void SparsePartitioner(const KVMessage &send, PartitionKVMessages *partition, + const std::map &attrs); + void RoundRobinPartitioner(const KVMessage &send, PartitionKVMessages *partition, + const std::map &attrs); + void WorkerInitEmbeddingPartitioner(const KVMessage &send, std::vector> *partition, + const std::map &attrs); + void UpdateEmbeddingPartitioner(const KVMessage &send, PartitionKVMessages *partition, + const std::map &attrs); + void BroadcastPartitioner(const KVMessage &send, PartitionKVMessages *partition, + const std::map &attrs); + void SendForPush(int cmd, const KVMessage &send, const KVPartitioner &partitioner, + const std::map &attrs); + void SendForPull(int cmd, const KVMessage &send, const KVPartitioner &partitioner, + const std::map &attrs, std::vector *vals, std::vector *lens); + + int64_t server_num_; bool running_; + std::mutex running_mutex_; size_t key_cnt_; std::map param_to_key_; std::map init_keys_; std::map key_to_optimId_; std::map> key_to_optim_shapes_; std::map param_to_init_in_server_; + core::WorkerNode worker_node_; + + EmbeddingPartitioner lookup_partitioner_; + KVPartitioner sparse_partitioner_; + KVPartitioner round_robin_partitioner_; + KVPartitioner worker_init_embedding_partitioner_; + KVPartitioner update_embedding_partitioner_; + KVPartitioner broadcast_partitioner_; + std::unordered_map key_to_server_id_; + std::unordered_map embedding_row_cnt_; + + std::unordered_map>> embedding_table_ranges_; }; -template -void Worker::Run() { - if (running_) { - MS_LOG(INFO) << "'Worker is already running."; - return; - } - MS_LOG(INFO) << "Worker starts connecting to scheduler and server..."; - ::ps::Start(0); - MS_LOG(INFO) << "Worker connected successfully."; - if (!::ps::IsWorker()) { - MS_LOG(EXCEPTION) << "The role is not worker."; - } - kv_worker_ = std::make_shared>(0, 0, 1, 2); - running_ = true; -} - -template -void Worker::Push(const std::vector &keys, std::vector addrs, const ShapeVector &sizes) { - if (keys.size() == 0) { - MS_LOG(EXCEPTION) << "key size should be greater than zero"; - } - if (key_to_optimId_.count(keys[0]) == 0) { - MS_LOG(EXCEPTION) << "no optim id found for key" << keys[0]; - } - Key key = keys[0]; - int64_t optim_id = key_to_optimId_[key]; - bool is_sparse = false; - if (optim_id == 1 || optim_id == 2 || optim_id == 3) { - is_sparse = true; - } - int64_t grad_index = -1; - int64_t indice_index = -1; - - // Sparse adam gradient - if (optim_id == 1 || optim_id == 2) { - grad_index = 6; - indice_index = 7; - - // Sparse ftrl gradient - } else if (optim_id == 3) { - grad_index = 0; - indice_index = 1; - } - - size_t total_size = std::accumulate(sizes.begin(), sizes.end(), 0, std::plus()); - ::ps::SArray total_buffer(total_size, 0); - size_t offset = 0; - size_t dst_size = 0; - size_t src_size = 0; - for (size_t i = 0; i < sizes.size(); i++) { - void *dst_data = total_buffer.data() + offset / sizeof(T); - void *src_data = reinterpret_cast(addrs[i]); - MS_EXCEPTION_IF_NULL(dst_data); - MS_EXCEPTION_IF_NULL(src_data); - dst_size = 
sizes[i] * sizeof(T); - src_size = sizes[i] * sizeof(T); - auto ret = memcpy_s(dst_data, dst_size, src_data, src_size); - if (ret != 0) { - MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")"; - return; - } - offset += sizes[i] * sizeof(T); - } - - while (!kv_worker_->IsReadyForPush(keys[0])) { - continue; - } - std::vector sizes_int; - (void)std::transform(sizes.begin(), sizes.end(), std::back_inserter(sizes_int), - [](const int64_t &value) { return static_cast(value); }); - if (!is_sparse) { - kv_worker_->PushData(::ps::SArray<::ps::Key>(keys), total_buffer, ::ps::SArray(sizes_int)); - } else { - std::vector &var_shape = key_to_optim_shapes_[key][0]; - int64_t first_dim_size = var_shape[0]; - int64_t outer_dim_size = std::accumulate(var_shape.begin() + 1, var_shape.end(), 1, std::multiplies()); - kv_worker_->PushSparseData(::ps::SArray<::ps::Key>(keys), total_buffer, ::ps::SArray(sizes_int), grad_index, - indice_index, first_dim_size, outer_dim_size); - } -} - -template -void Worker::Pull(const size_t key, void *dev_addr, const size_t size) { - MS_EXCEPTION_IF_NULL(dev_addr); - ::ps::SArray variables(size / sizeof(T), 0); - while (!kv_worker_->IsReadyForPull(key)) { - continue; - } - kv_worker_->PullData({key}, &variables); - size_t dst_size = size; - size_t src_size = size; - auto ret = memcpy_s(dev_addr, dst_size, variables.data(), src_size); - if (ret != 0) { - MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")"; - return; - } -} - -template -void Worker::DoPSEmbeddingLookup(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray &lookup_ids, - const ::ps::SArray &lens, ::ps::SArray *lookup_result, int64_t cmd) { - MS_EXCEPTION_IF_NULL(lookup_result); - kv_worker_->EmbeddingLookup(keys, lookup_ids, lens, lookup_result, cmd); -} - -template -void Worker::UpdateEmbeddingTable(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray &lookup_ids, - const ::ps::SArray &vals) { - kv_worker_->UpdateEmbeddingTable(keys, lookup_ids, vals); -} - -template -void Worker::Finalize() { - if (running_) { - MS_LOG(INFO) << "Worker starts finalizing..."; - kv_worker_->Finalize(); - kv_worker_.reset(); - running_ = false; - MS_LOG(INFO) << "Worker finalized successfully."; - } -} - -template -void Worker::InitPSParamData(const std::vector &keys, void *origin_addr, size_t size) { - MS_EXCEPTION_IF_NULL(origin_addr); - ::ps::SArray addr(reinterpret_cast(origin_addr), size / sizeof(T)); - ::ps::SArray<::ps::Key> key(keys); - ::ps::SArray lens; - lens.push_back(addr.size()); - kv_worker_->PushData(key, addr, lens, kInitWeightsCmd); - init_keys_[key[0]] = true; -} - -template -void Worker::SetOptimInputShapes(size_t key, const ShapeVector &shape) { - if (key_to_optim_shapes_.find(key) == key_to_optim_shapes_.end()) { - key_to_optim_shapes_[key] = {shape}; - } else { - key_to_optim_shapes_[key].push_back(shape); - } -} - -template -void Worker::InitPSOptimInputShapes(const size_t key) { - ::ps::SArray<::ps::Key> keys; - ::ps::SArray shape_len; - ::ps::SArray all_shape; - std::vector shapes = key_to_optim_shapes_[key]; - for (auto shape : shapes) { - keys.push_back(key); - if (shape.size() == 0) { - shape_len.push_back(1); - all_shape.push_back(1); - } else { - shape_len.push_back(SizeToLong(shape.size())); - for (auto dim : shape) { - all_shape.push_back(static_cast(dim)); - } - } - } - MS_LOG(INFO) << "keys:" << keys; - MS_LOG(INFO) << "shape_len:" << shape_len; - MS_LOG(INFO) << "all_shape:" << all_shape; - if (!init_keys_[key]) { - init_keys_[key] = true; - } - 
kv_worker_->PushData(keys, all_shape, shape_len, kInitOptimInputsShapeCmd); -} - -template -bool Worker::IsKeyInit(const size_t key) { - if (init_keys_.find(key) == init_keys_.end() || !init_keys_[key]) { - return false; - } - return true; -} - -template -size_t Worker::SetParamKey(const std::string ¶m_name) { - size_t key = UINT64_MAX; - if (param_to_key_.count(param_name)) { - key = param_to_key_[param_name]; - MS_LOG(INFO) << param_name << " key is already set: key value is " << key; - } else { - key = key_cnt_++; - param_to_key_[param_name] = key; - MS_LOG(INFO) << "Set key " << key << " for parameter " << param_name; - } - return key; -} - -template -void Worker::SetParamInitInServer(const std::string ¶m_name, bool init_in_server) { - MS_LOG(INFO) << "Set parameter " << param_name << " init_in_server:" << init_in_server; - param_to_init_in_server_[param_name] = init_in_server; -} - -template -bool Worker::GetParamInitInServer(const std::string ¶m_name) { - if (param_to_init_in_server_.count(param_name) == 0) { - return false; - } - return param_to_init_in_server_[param_name]; -} - -template -size_t Worker::GetParamKey(const std::string ¶m_name) { - size_t key = kInvalidKey; - if (param_to_key_.find(param_name) != param_to_key_.end()) { - key = param_to_key_[param_name]; - MS_LOG(DEBUG) << "Get key of parameter " << param_name << " key is " << key; - } - return key; -} - -template -void Worker::SetKeyOptimId(size_t key, const std::string &optimizer_name) { - key_to_optimId_[key] = Util::optimizer_id(optimizer_name); -} - -template -void Worker::InitPSOptimId(const size_t param_key) { - if (key_to_optimId_.count(param_key) == 0) { - MS_LOG(EXCEPTION) << "Can't find optimizer id of parameter key " << param_key; - } - int64_t optim_id = key_to_optimId_[param_key]; - - ::ps::SArray<::ps::Key> keys = {param_key}; - ::ps::SArray optim_id_vals = {static_cast(optim_id)}; - ::ps::SArray optim_id_lens = {optim_id_vals.size()}; - kv_worker_->PushData(keys, optim_id_vals, optim_id_lens, kInitWeightToOptimIdCmd); -} - -template -void Worker::InitPSEmbeddingTable(const std::vector &keys, std::vector shapes, const ShapeVector &sizes) { - bool has_init = IsKeyInit(keys[0]); - if (has_init) { - MS_LOG(DEBUG) << "The key embedding table of key " << keys[0] << " is initialized."; - return; - } - ::ps::SArray shapes_val; - for (auto dim : shapes) { - shapes_val.push_back(dim); - } - std::vector sizes_int; - (void)std::transform(sizes.begin(), sizes.end(), std::back_inserter(sizes_int), - [](const int64_t &value) { return static_cast(value); }); - kv_worker_->Wait( - kv_worker_->InitEmbeddingTable(::ps::SArray<::ps::Key>(keys), shapes_val, ::ps::SArray(sizes_int))); -} - -template -void Worker::InitPSParamAndOptim(const AnfNodePtr &input_node, const tensor::TensorPtr &tensor) { - MS_EXCEPTION_IF_NULL(tensor); - MS_EXCEPTION_IF_NULL(input_node); - auto pk_node = input_node->cast(); - MS_EXCEPTION_IF_NULL(pk_node); - const std::string ¶m_name = pk_node->fullname_with_scope(); - void *param_data = tensor->data_c(); - size_t param_size = LongToSize(tensor->data().nbytes()); - - size_t param_key = GetParamKey(param_name); - if (param_key == kInvalidKey) { - MS_LOG(DEBUG) << "Parameter " << param_name << " has no key assigned."; - return; - } - bool init_in_server = false; - auto param_info_ptr = pk_node->param_info(); - if (param_info_ptr != nullptr && param_info_ptr->init_in_server()) { - init_in_server = true; - } - SetParamInitInServer(param_name, init_in_server); - bool init = IsKeyInit(param_key); - if 
-  if (!init) {
-    MS_LOG(INFO) << "Init parameter and optimizer in parameter server side for " << param_name
-                 << ", whether init in server: " << init_in_server;
-    kv_worker_->AddKeyToServerId(param_key);
-    if (!PsDataPrefetch::GetInstance().cache_enable()) {
-      if (!init_in_server) {
-        if (param_size > INT_MAX) {
-          MS_LOG(EXCEPTION) << "PS mode max weight size is " << INT_MAX << ", " << param_name << " size is "
-                            << param_size;
-        }
-        InitPSParamData({param_key}, param_data, param_size);
-      }
-      InitPSOptimId(param_key);
-      InitPSOptimInputShapes(param_key);
-    }
-  }
-}
-
-template <typename T>
-void Worker<T>::AddEmbeddingTable(const ::ps::Key &key, const size_t &row_count) {
-  bool has_init = IsKeyInit(key);
-  if (has_init) {
-    return;
-  }
-  kv_worker_->AddEmbeddingTable(key, row_count);
-}
-
-static Worker<float> &worker = Worker<float>::GetInstance();
+static Worker &worker = Worker::GetInstance();
 }  // namespace ps
 }  // namespace mindspore
 #endif  // MINDSPORE_CCSRC_PS_WORKER_H_
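The `static Worker &worker = Worker::GetInstance();` line above is the Meyers-singleton idiom; the change only drops the `<float>` template argument now that `Worker` is no longer a class template. A minimal sketch of the pattern, with illustrative names rather than the MindSpore types:

// Hypothetical sketch, not the MindSpore sources: a Meyers singleton plus a
// file-scope reference keeps exactly one Worker per process and touches it
// during static initialization of the translation unit.
#include <iostream>

class Worker {
 public:
  static Worker &GetInstance() {
    static Worker instance;  // constructed once; thread-safe since C++11
    return instance;
  }
  void Run() { std::cout << "worker running\n"; }

 private:
  Worker() = default;
  Worker(const Worker &) = delete;
  Worker &operator=(const Worker &) = delete;
};

// Mirrors the `static Worker &worker = Worker::GetInstance();` line in the diff.
static Worker &worker = Worker::GetInstance();

int main() { worker.Run(); }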
diff --git a/mindspore/ccsrc/ps/worker_proxy.h b/mindspore/ccsrc/ps/worker_proxy.h
deleted file mode 100644
index 051308eda1..0000000000
--- a/mindspore/ccsrc/ps/worker_proxy.h
+++ /dev/null
@@ -1,873 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef MINDSPORE_CCSRC_PS_WORKER_PROXY_H_
-#define MINDSPORE_CCSRC_PS_WORKER_PROXY_H_
-
-#include <map>
-#include <memory>
-#include <mutex>
-#include <numeric>
-#include <unordered_map>
-#include <unordered_set>
-#include <utility>
-#include <vector>
-#include <functional>
-#include "ps/ps.h"
-#include "ps/util.h"
-#include "backend/kernel_compiler/common_utils.h"
-#include "ps/ps_context.h"
-
-namespace mindspore {
-namespace ps {
-template <typename T>
-class WorkerProxy : public ::ps::KVWorker<T> {
- public:
-  using Worker = ::ps::KVWorker<T>;
-  using Callback = std::function<void()>;
-  using SlicedKVs = std::vector<std::pair<bool, ::ps::KVPairs<T>>>;
-  using Slicer = std::function<void(int64_t timestamp, const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &ranges,
-                                    SlicedKVs *sliced, const std::map<int64_t, int64_t> &attrs)>;
-  using ::ps::SimpleApp::obj_;
-  explicit WorkerProxy(int64_t app_id, int64_t customer_id, int64_t lookup_customer_id, int64_t general_customer_id)
-      : Worker(app_id, customer_id) {
-    server_num_ = ::ps::NumServers();
-    MS_LOG(INFO) << "Server num:" << server_num_;
-    PSContext::instance()->SetPSRankId(::ps::MyRank());
-    using std::placeholders::_1;
-    using std::placeholders::_2;
-    using std::placeholders::_3;
-    using std::placeholders::_4;
-    using std::placeholders::_5;
-    lookup_customer_ = std::unique_ptr<::ps::Customer>(
-      new ::ps::Customer(app_id, lookup_customer_id, std::bind(&WorkerProxy<T>::ProcessLookupResult, this, _1)));
-    general_customer_ = std::unique_ptr<::ps::Customer>(
-      new ::ps::Customer(app_id, general_customer_id, std::bind(&WorkerProxy<T>::ProcessResponse, this, _1)));
-    lookup_slicer_ = std::bind(&WorkerProxy<T>::LookupIdSlicer, this, _1, _2, _3, _4, _5);
-    sparse_slicer_ = std::bind(&WorkerProxy<T>::SparseSlicer, this, _1, _2, _3, _4, _5);
-    broadcast_slicer_ = std::bind(&WorkerProxy<T>::BroadcastSlicer, this, _1, _2, _3, _4, _5);
-    round_robin_slicer_ = std::bind(&WorkerProxy<T>::RoundRobinSlicer, this, _1, _2, _3, _4, _5);
-    worker_init_embedding_slicer_ = std::bind(&WorkerProxy<T>::WorkerInitEmbeddingSlicer, this, _1, _2, _3, _4, _5);
-    update_embedding_slicer_ = std::bind(&WorkerProxy<T>::UpdateEmbeddingSlicer, this, _1, _2, _3, _4, _5);
-  }
-  ~WorkerProxy() override = default;
-
-  void AddEmbeddingTable(const ::ps::Key &key, const size_t &row_count);
-  void AddKeyToServerId(const ::ps::Key &key);
-  void EmbeddingLookup(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<int> &lookup_ids,
-                       const ::ps::SArray<int> &lens, ::ps::SArray<T> *outs, int64_t cmd = 0,
-                       const Callback &cb = nullptr, int64_t priority = 0);
-  int64_t InitEmbeddingTable(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<T> &vals,
-                             const ::ps::SArray<int> &lens = {}, const Callback &cb = nullptr, int64_t priority = 0);
-  void UpdateEmbeddingTable(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<int> &lookup_ids,
-                            const ::ps::SArray<T> &vals, const Callback &cb = nullptr, int64_t priority = 0);
-  bool IsReadyForPush(const Key &key);
-  bool IsReadyForPull(const Key &key);
-  void PushData(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<T> &vals, const ::ps::SArray<int> &lens = {},
-                int64_t cmd = 0, int64_t priority = 0);
-  void PushSparseData(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<T> &vals, const ::ps::SArray<int> &lens,
-                      size_t grad_index, size_t indice_index, size_t first_dim_size, size_t outer_dim_size);
-  void PullData(const ::ps::SArray<::ps::Key> &keys, ::ps::SArray<T> *vals, ::ps::SArray<int> *lens = nullptr,
-                int64_t cmd = 0, int64_t priority = 0);
-  void Finalize();
-
- private:
-  template <typename C>
-  int64_t AddLookupCB(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<int> &lookup_ids, C *vals, int64_t cmd,
-                      const Callback &cb);
-  int64_t AddGeneralRspCB(const ::ps::SArray<::ps::Key> &keys,
-                          ::ps::SArray<T> *vals, ::ps::SArray<int> *lens,
-                          int64_t cmd, const Callback &cb);
-  void LookupIdSlicer(int64_t timestamp, const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &,
-                      std::vector<std::pair<bool, ::ps::KVPairs<T>>> *sliced, const std::map<int64_t, int64_t> &attrs);
-  void SparseSlicer(int64_t timestamp, const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &,
-                    std::vector<std::pair<bool, ::ps::KVPairs<T>>> *sliced, const std::map<int64_t, int64_t> &attrs);
-  void BroadcastSlicer(int64_t timestamp, const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &,
-                       std::vector<std::pair<bool, ::ps::KVPairs<T>>> *sliced, const std::map<int64_t, int64_t> &attrs);
-  void RoundRobinSlicer(int64_t timestamp, const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &,
-                        std::vector<std::pair<bool, ::ps::KVPairs<T>>> *sliced,
-                        const std::map<int64_t, int64_t> &attrs);
-  void WorkerInitEmbeddingSlicer(int64_t timestamp, const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &,
-                                 std::vector<std::pair<bool, ::ps::KVPairs<T>>> *sliced,
-                                 const std::map<int64_t, int64_t> &attrs);
-  void UpdateEmbeddingSlicer(int timestamp, const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &,
-                             std::vector<std::pair<bool, ::ps::KVPairs<T>>> *sliced,
-                             const std::map<int64_t, int64_t> &attrs);
-  void ProcessLookupResult(const ::ps::Message &msg);
-  void ProcessResponse(const ::ps::Message &msg);
-  void Send(::ps::Customer *customer, int64_t timestamp, bool push, bool pull, int64_t cmd, const ::ps::KVPairs<T> &kvs,
-            const Slicer &slicer, std::map<int64_t, int64_t> attrs = {});
-  void AddKeyByHashMod(const ::ps::Key &key);
-
-  void PrepareSparseGradient(const size_t begin, const size_t end, const std::unordered_set<int> &distinct_ids,
-                             const std::vector<std::pair<int, T *>> &indice_to_grad, const int *all_indice,
-                             const size_t segment_size, T *gradient, int *indice);
-  void BuildSparseValue(const ::ps::SArray<int> &lengths, const size_t grad_index, const size_t indice_index,
-                        const T *original_data, const T *grads, int *indices, ::ps::SArray<T> *reduced_data);
-
-  int64_t server_num_;
-  std::unique_ptr<::ps::Customer> lookup_customer_;
-  std::unique_ptr<::ps::Customer> general_customer_;
-  std::unordered_map<::ps::Key, std::shared_ptr<std::vector<::ps::Range>>> embedding_table_ranges_;
-  std::unordered_map<int64_t, std::vector<::ps::KVPairs<T>>> lookup_results_;
-  std::unordered_map<int64_t, std::map<int64_t, ::ps::KVPairs<T>>> gathered_response_;
-  std::mutex mutex_;
-  Slicer lookup_slicer_;
-  Slicer sparse_slicer_;
-  Slicer broadcast_slicer_;
-  Slicer round_robin_slicer_;
-  Slicer worker_init_embedding_slicer_;
-  Slicer update_embedding_slicer_;
-  std::unordered_map<int64_t, Callback> lookup_callbacks_;
-  std::unordered_map<int64_t, Callback> general_callbacks_;
-  std::unordered_map<int64_t, int64_t> expected_result_count_;
-  std::unordered_map<::ps::Key, int64_t> key_to_server_id_;
-  std::unordered_map<::ps::Key, size_t> embedding_row_cnt_;
-};
-
-template <typename T>
-void WorkerProxy<T>::AddEmbeddingTable(const ::ps::Key &key, const size_t &row_count) {
-  uint64_t begin = 0;
-  uint64_t end = 0;
-  for (int64_t i = 0; i < server_num_; i++) {
-    int64_t local_row_cnt = Util::LocalShard(row_count, i, server_num_);
-    if (i == 0) {
-      end = local_row_cnt - 1;
-    } else {
-      begin = end + 1;
-      end += local_row_cnt;
-    }
-    ::ps::Range range(begin, end);
-    if (embedding_table_ranges_.count(key) == 0) {
-      embedding_table_ranges_[key] = std::make_shared<std::vector<::ps::Range>>();
-      MS_EXCEPTION_IF_NULL(embedding_table_ranges_[key]);
-    }
-    embedding_table_ranges_[key]->push_back(range);
-  }
-  embedding_row_cnt_[key] = row_count;
-}
-
-template <typename T>
-void WorkerProxy<T>::AddKeyByHashMod(const ::ps::Key &key) {
-  if (server_num_ == 0) {
-    MS_LOG(EXCEPTION) << "Server number is invalid:0";
-  }
-  key_to_server_id_[key] = static_cast<int64_t>(key % server_num_);
-  MS_LOG(INFO) << "The server id of key " << key << " is " << key_to_server_id_[key];
-}
-
-template <typename T>
-void WorkerProxy<T>::AddKeyToServerId(const ::ps::Key &key) {
-  AddKeyByHashMod(key);
-}
-
-template <typename T>
-void WorkerProxy<T>::EmbeddingLookup(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<int> &lookup_ids,
-                                     const ::ps::SArray<int> &lens, ::ps::SArray<T> *outs, int64_t cmd,
-                                     const Callback &cb, int64_t priority) {
-  int64_t ts = AddLookupCB(keys, lookup_ids, outs, cmd, cb);
-  ::ps::KVPairs<T> kvs;
-  kvs.keys = keys;
-  kvs.lens = lookup_ids;
-  kvs.priority = priority;
-  expected_result_count_[ts] = 0;
-  Send(lookup_customer_.get(), ts, true, true, cmd, kvs, lookup_slicer_);
-  int64_t expect_rt_count = expected_result_count_[ts];
-  lookup_customer_->AddResponse(ts, server_num_ - expect_rt_count);
-  lookup_customer_->WaitRequest(ts);
-  expected_result_count_.erase(ts);
-}
-
-template <typename T>
-int64_t WorkerProxy<T>::InitEmbeddingTable(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<T> &vals,
-                                           const ::ps::SArray<int> &lens, const Callback &cb, int64_t priority) {
-  int64_t ts = obj_->NewRequest(::ps::kServerGroup);
-  ::ps::KVPairs<T> kvs;
-  kvs.keys = keys;
-  kvs.vals = vals;
-  kvs.lens = lens;
-  kvs.priority = priority;
-  Send(obj_, ts, true, false, kInitEmbeddingsCmd, kvs, broadcast_slicer_);
-  return ts;
-}
-
-template <typename T>
-void WorkerProxy<T>::UpdateEmbeddingTable(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<int> &lookup_ids,
-                                          const ::ps::SArray<T> &vals, const Callback &cb, int64_t priority) {
-  int ts = AddGeneralRspCB(keys, nullptr, nullptr, 0, nullptr);
-  ::ps::KVPairs<T> kvs;
-  kvs.keys = keys;
-  kvs.lens = lookup_ids;
-  kvs.vals = vals;
-  kvs.priority = priority;
-  expected_result_count_[ts] = 0;
-  Send(general_customer_.get(), ts, true, false, kUpdateEmbeddingsCmd, kvs, update_embedding_slicer_);
-  if (expected_result_count_[ts] < server_num_) {
-    general_customer_->AddResponse(ts, server_num_ - expected_result_count_[ts]);
-  }
-  general_customer_->WaitRequest(ts);
-  expected_result_count_.erase(ts);
-}
-
-template <typename T>
-bool WorkerProxy<T>::IsReadyForPush(const Key &key) {
-  ::ps::SArray<T> result(1, 0);
-  PullData({key}, &result, nullptr, kCheckReadyForPushCmd);
-  if (result[0] > 0) {
-    return true;
-  } else {
-    return false;
-  }
-}
-
-template <typename T>
-bool WorkerProxy<T>::IsReadyForPull(const Key &key) {
-  ::ps::SArray<T> result(1, 0);
-  PullData({key}, &result, nullptr, kCheckReadyForPullCmd);
-  if (result[0] > 0) {
-    return true;
-  } else {
-    return false;
-  }
-}
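The `AddResponse(ts, server_num_ - expected_result_count_[ts])` calls above compensate for servers that receive no slice: the customer waits for one reply per server, so skipped servers are pre-counted as already answered. A hedged sketch of that bookkeeping, with hypothetical names (`PendingRequest` is illustrative, not ps-lite's `Customer`):

// Sketch of the response-count trick used by EmbeddingLookup/PushData above.
#include <condition_variable>
#include <mutex>

class PendingRequest {
 public:
  explicit PendingRequest(int expected) : remaining_(expected) {}
  // Pre-fill replies for servers the slicer skipped, or record a real reply.
  void AddResponse(int n) {
    std::lock_guard<std::mutex> lock(mu_);
    remaining_ -= n;
    if (remaining_ <= 0) cv_.notify_all();
  }
  // Blocks until every expected reply has been accounted for.
  void WaitRequest() {
    std::unique_lock<std::mutex> lock(mu_);
    cv_.wait(lock, [this] { return remaining_ <= 0; });
  }

 private:
  int remaining_;
  std::mutex mu_;
  std::condition_variable cv_;
};

With 8 servers but only 3 real slices, `AddResponse(8 - 3)` accounts for the 5 servers that will never reply, so `WaitRequest()` returns after the 3 genuine responses arrive.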
-template <typename T>
-void WorkerProxy<T>::PushData(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<T> &vals,
-                              const ::ps::SArray<int> &lens, int64_t cmd, int64_t priority) {
-  int64_t ts = AddGeneralRspCB(keys, nullptr, nullptr, cmd, nullptr);
-  ::ps::KVPairs<T> kvs;
-  kvs.keys = keys;
-  kvs.vals = vals;
-  kvs.lens = lens;
-  kvs.priority = priority;
-  if (embedding_table_ranges_.count(keys[0])) {
-    if (cmd == kInitWeightsCmd) {
-      Send(general_customer_.get(), ts, true, false, cmd, kvs, worker_init_embedding_slicer_);
-    } else {
-      Send(general_customer_.get(), ts, true, false, cmd, kvs, broadcast_slicer_);
-    }
-  } else {
-    Send(general_customer_.get(), ts, true, false, cmd, kvs, round_robin_slicer_);
-  }
-  if (expected_result_count_[ts] < server_num_) {
-    general_customer_->AddResponse(ts, server_num_ - expected_result_count_[ts]);
-  }
-  general_customer_->WaitRequest(ts);
-}
-
-template <typename T>
-void WorkerProxy<T>::PushSparseData(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<T> &vals,
-                                    const ::ps::SArray<int> &lens, size_t grad_index, size_t indice_index,
-                                    size_t first_dim_size, size_t outer_dim_size) {
-  int64_t ts = AddGeneralRspCB(keys, nullptr, nullptr, 0, nullptr);
-  ::ps::KVPairs<T> kvs;
-  kvs.keys = keys;
-  kvs.vals = vals;
-  kvs.lens = lens;
-  const int64_t cmd = 0;
-  if (embedding_table_ranges_.count(keys[0])) {
-    std::map<int64_t, int64_t> attrs{{0, grad_index}, {1, indice_index}, {2, first_dim_size}, {3, outer_dim_size}};
-    Send(general_customer_.get(), ts, true, false, cmd, kvs, sparse_slicer_, attrs);
-  } else {
-    Send(general_customer_.get(), ts, true, false, cmd, kvs, round_robin_slicer_);
-  }
-  if (expected_result_count_[ts] < server_num_) {
-    general_customer_->AddResponse(ts, server_num_ - expected_result_count_[ts]);
-  }
-  general_customer_->WaitRequest(ts);
-}
-
-template <typename T>
-void WorkerProxy<T>::PullData(const ::ps::SArray<::ps::Key> &keys, ::ps::SArray<T> *vals, ::ps::SArray<int> *lens,
-                              int64_t cmd, int64_t priority) {
-  MS_EXCEPTION_IF_NULL(vals);
-  int64_t ts = AddGeneralRspCB(keys, vals, lens, cmd, nullptr);
-  ::ps::KVPairs<T> kvs;
-  kvs.keys = keys;
-  kvs.priority = priority;
-  if (embedding_table_ranges_.count(keys[0])) {
-    Send(general_customer_.get(), ts, false, true, cmd, kvs, broadcast_slicer_);
-  } else {
-    Send(general_customer_.get(), ts, false, true, cmd, kvs, round_robin_slicer_);
-  }
-  if (expected_result_count_[ts] < server_num_) {
-    general_customer_->AddResponse(ts, server_num_ - expected_result_count_[ts]);
-  }
-  general_customer_->WaitRequest(ts);
-}
-
-template <typename T>
-void WorkerProxy<T>::Finalize() {
-  int64_t ts = obj_->NewRequest(::ps::kServerGroup);
-  ::ps::KVPairs<T> kvs;
-  kvs.keys.push_back(0);
-  kvs.vals.push_back(0.0f);
-  Send(obj_, ts, true, false, kFinalizeCmd, kvs, broadcast_slicer_);
-  obj_->WaitRequest(ts);
-  ::ps::Finalize(0, true);
-}
-
-template <typename T>
-template <typename C>
-int64_t WorkerProxy<T>::AddLookupCB(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<int> &lookup_ids,
-                                    C *lookup_result, int64_t cmd, const Callback &cb) {
-  MS_EXCEPTION_IF_NULL(lookup_result);
-  int64_t ts = lookup_customer_->NewRequest(::ps::kServerGroup);
-  const auto &callback = [this, ts, keys, lookup_ids, lookup_result, cb]() mutable {
-    mutex_.lock();
-    auto &kvs = lookup_results_[ts];
-    mutex_.unlock();
-
-    if (lookup_ids.empty()) {
-      MS_LOG(EXCEPTION) << "Lookup id is empty.";
-    }
-    int64_t single_id_len = SizeToLong(lookup_result->size() / lookup_ids.size());
-    std::unordered_map<Key, std::shared_ptr<std::pair<T *, int64_t>>> id_addr_map;
-    for (const auto &s : kvs) {
-      int64_t offset = 0;
-      for (size_t i = 0; i < s.keys.size(); i++) {
-        const Key &key = s.keys[i];
-        T *addr = s.vals.data() + offset;
-        offset += single_id_len;
-        id_addr_map[key] = std::make_shared<std::pair<T *, int64_t>>(std::make_pair(addr, single_id_len));
-        MS_EXCEPTION_IF_NULL(id_addr_map[key]);
-      }
-    }
-
-    T *result_addr = lookup_result->data();
-    MS_EXCEPTION_IF_NULL(result_addr);
-    int64_t offset = 0;
-    size_t dst_size = 0;
-    size_t src_size = 0;
-    void *dst_data = nullptr;
-    void *src_data = nullptr;
-    for (size_t i = 0; i < lookup_ids.size(); i++) {
-      if (id_addr_map.count(lookup_ids[i]) == 0) {
-        offset += single_id_len;
-        continue;
-      }
-      auto &pair = id_addr_map[static_cast<Key>(lookup_ids[i])];
-      int64_t size = single_id_len * sizeof(T);
-      dst_size = size;
-      src_size = size;
-      dst_data = result_addr + offset;
-      src_data = pair->first;
-      MS_EXCEPTION_IF_NULL(dst_data);
-      MS_EXCEPTION_IF_NULL(src_data);
-      auto ret = memcpy_s(dst_data, dst_size, src_data, src_size);
-      if (ret != 0) {
-        MS_LOG(EXCEPTION) << "memcpy_s error, errno(" << ret << ")";
-        return;
-      }
-      offset += single_id_len;
-    }
-
-    mutex_.lock();
-    lookup_results_.erase(ts);
-    mutex_.unlock();
-    if (cb) cb();
-  };
-  lookup_callbacks_[ts] = callback;
-  return ts;
-}
-
-template <typename T>
-int64_t WorkerProxy<T>::AddGeneralRspCB(const ::ps::SArray<::ps::Key> &keys, ::ps::SArray<T> *vals,
-                                        ::ps::SArray<int> *lens, int64_t cmd, const Callback &cb) {
-  int64_t ts = general_customer_->NewRequest(::ps::kServerGroup);
-  const auto &callback = [this, ts, keys, vals, lens, cb]() mutable {
-    mutex_.lock();
-    std::map<int64_t, ::ps::KVPairs<T>> server_kvs = gathered_response_[ts];
-    mutex_.unlock();
-
-    vals->clear();
-    for (auto kvs : server_kvs) {
-      for (auto val : kvs.second.vals) {
-        vals->push_back(val);
-      }
-      if (lens) {
-        for (auto len : kvs.second.lens) {
-          lens->push_back(len);
-        }
-      }
-    }
-
-    mutex_.lock();
-    gathered_response_.erase(ts);
-    mutex_.unlock();
-    if (cb) {
-      cb();
-    }
-  };
-  general_callbacks_[ts] = callback;
-  return ts;
-}
-
-template <typename T>
-void WorkerProxy<T>::LookupIdSlicer(int64_t timestamp, const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &,
-                                    std::vector<std::pair<bool, ::ps::KVPairs<T>>> *sliced,
-                                    const std::map<int64_t, int64_t> &attrs) {
-  MS_EXCEPTION_IF_NULL(sliced);
-  int32_t *lookup_ids = send.lens.data();
-  size_t id_size = send.lens.size();
-
-  const Key &key = send.keys[0];
-  const std::vector<::ps::Range> &ranges = *(embedding_table_ranges_[key]);
-  sliced->resize(ranges.size());
-
-  for (size_t i = 0; i < ranges.size(); i++) {
-    const ::ps::Range &range = ranges[i];
-    const auto &begin = range.begin();
-    const auto &end = range.end();
-    std::unordered_set<uint64_t> unique_ids;
-    auto &kvs = sliced->at(i).second;
-
-    kvs.keys.push_back(key);
-    kvs.vals.push_back(0.0f);
-
-    for (size_t j = 0; j < id_size; j++) {
-      auto lookup_id = static_cast<uint64_t>(lookup_ids[j]);
-      // If lookup_id is out of range, like negative number, unique_ids will not contain it.
-      // Servers always get lookup_ids in its embedding table range.
-      if (lookup_id >= begin && lookup_id <= end) {
-        unique_ids.insert(lookup_id);
-      }
-    }
-    for (const auto &lookup_id : unique_ids) {
-      kvs.keys.push_back(lookup_id);
-      kvs.vals.push_back(0.0f);
-    }
-
-    if (kvs.keys.size() <= 1) {
-      sliced->at(i).first = false;
-    } else {
-      sliced->at(i).first = true;
-      expected_result_count_[timestamp] += 1;
-    }
-  }
-}
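LookupIdSlicer above routes each lookup id to the single server whose row range contains it and deduplicates ids per server. A self-contained sketch of the idea, with illustrative types (`Range` and `SliceIds` are not the MindSpore API):

// Sketch: partition lookup ids over contiguous, inclusive server row ranges.
#include <cstdint>
#include <set>
#include <vector>

struct Range {
  uint64_t begin;
  uint64_t end;  // inclusive
};

std::vector<std::set<uint64_t>> SliceIds(const std::vector<int> &ids, const std::vector<Range> &ranges) {
  std::vector<std::set<uint64_t>> sliced(ranges.size());
  for (size_t i = 0; i < ranges.size(); ++i) {
    for (int id : ids) {
      // Negative ids wrap to huge values and fall outside every range,
      // mirroring the comment in LookupIdSlicer; std::set deduplicates.
      auto uid = static_cast<uint64_t>(id);
      if (uid >= ranges[i].begin && uid <= ranges[i].end) sliced[i].insert(uid);
    }
  }
  return sliced;
}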
-template <typename T>
-void WorkerProxy<T>::SparseSlicer(int64_t timestamp, const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &,
-                                  std::vector<std::pair<bool, ::ps::KVPairs<T>>> *sliced,
-                                  const std::map<int64_t, int64_t> &attrs) {
-  MS_EXCEPTION_IF_NULL(sliced);
-  // Init variables
-  T *data = send.vals.data();
-
-  if (attrs.count(0) == 0 || attrs.count(1) == 0 || attrs.count(2) == 0 || attrs.count(3) == 0) {
-    MS_LOG(EXCEPTION) << "Invalid attrs keys";
-  }
-  auto iter = attrs.find(0);
-  size_t grad_index = static_cast<size_t>(iter->second);
-  iter = attrs.find(1);
-  size_t indice_index = static_cast<size_t>(iter->second);
-  iter = attrs.find(2);
-  size_t first_dim_size = static_cast<size_t>(iter->second);
-  iter = attrs.find(3);
-  size_t outer_dim_size = static_cast<size_t>(iter->second);
-
-  int grad_size = send.lens[grad_index];
-  int indice_size = send.lens[indice_index];
-  int segment_size = grad_size / indice_size;
-
-  int64_t grad_offset = 0;
-  int64_t indice_offset = 0;
-  for (size_t i = 0; i < grad_index; i++) {
-    grad_offset += send.lens[i];
-  }
-  for (size_t j = 0; j < indice_index; j++) {
-    indice_offset += send.lens[j];
-  }
-
-  T *grad_data = data + grad_offset;
-  int *indice_data = reinterpret_cast<int *>(data) + indice_offset;
-
-  // Build the mappings of indice to gradient
-  std::vector<std::pair<int, T *>> indice_to_grads;
-  for (int i = 0; i < indice_size; i++) {
-    int indice = indice_data[i];
-    T *grad = grad_data + i * segment_size;
-    indice_to_grads.push_back(std::make_pair(indice, grad));
-  }
-
-  const Key &key = send.keys[0];
-  const std::vector<::ps::Range> &ranges = *(embedding_table_ranges_[key]);
-  sliced->resize(ranges.size());
-
-  // Construct reduced sparse data for each server
-  for (size_t i = 0; i < ranges.size(); i++) {
-    const ::ps::Range &range = ranges[i];
-    const auto &begin = range.begin();
-    const auto &end = range.end();
-    auto &kvs = sliced->at(i).second;
-    kvs.keys = send.keys;
-    kvs.lens = send.lens;
-
-    // Prepare the sparse gradient and indice
-    std::vector<int> indice_ids;
-    std::unordered_set<int> distinct_ids;
-    for (int j = 0; j < indice_size; j++) {
-      size_t indice = static_cast<size_t>(indice_data[j]);
-      if (indice >= begin && indice <= end) {
-        indice_ids.push_back(indice);
-        distinct_ids.insert(indice);
-      }
-    }
-    size_t indices_size = indice_ids.size();
-    if (indices_size > 0) {
-      int slice_segment_size = indices_size * segment_size;
-      std::vector<T> src_grad_data(slice_segment_size);
-      std::vector<int> src_indice_data(indices_size);
-      PrepareSparseGradient(begin, end, distinct_ids, indice_to_grads, indice_data, segment_size, src_grad_data.data(),
-                            src_indice_data.data());
-
-      // Reduce the sparse gradient and indice
-      std::vector<T> new_grad(slice_segment_size);
-      std::vector<int> new_indices(indices_size);
-      mindspore::kernel::SparseGradient unique_sparse_grad({new_grad.data(), new_indices.data(), indices_size});
-      Util::ReduceSparseGradient(src_grad_data.data(), src_indice_data.data(), indices_size, segment_size,
-                                 first_dim_size, outer_dim_size, &unique_sparse_grad);
-
-      // Update the length of reduce sparse gradient and indice
-      ::ps::SArray<int> reduced_lens;
-      reduced_lens.CopyFrom(kvs.lens);
-      reduced_lens[grad_index] = unique_sparse_grad.indices_size_ * segment_size;
-      reduced_lens[indice_index] = unique_sparse_grad.indices_size_;
-
-      // Build the sparse value to be sent
-      size_t total_size = std::accumulate(reduced_lens.begin(), reduced_lens.end(), 0, std::plus<int>());
-      ::ps::SArray<T> reduced_data(total_size, 0);
-      BuildSparseValue(reduced_lens, grad_index, indice_index, data, unique_sparse_grad.value_,
-                       unique_sparse_grad.indices_, &reduced_data);
-
-      kvs.lens = reduced_lens;
-      kvs.vals = reduced_data;
-    }
-
-    if (indices_size <= 0) {
-      ::ps::SArray<::ps::Key> no_keys;
-      ::ps::SArray<T> no_vals;
-      ::ps::SArray<int> no_lens;
-      no_keys.push_back(key);
-      no_vals.push_back(-100);
-      kvs.vals = no_vals;
-      kvs.lens = no_lens;
-    }
-    sliced->at(i).first = true;
-    expected_result_count_[timestamp] += 1;
-  }
-}
-
-template <typename T>
-void WorkerProxy<T>::PrepareSparseGradient(const size_t begin, const size_t end,
-                                           const std::unordered_set<int> &distinct_ids,
-                                           const std::vector<std::pair<int, T *>> &indice_to_grads,
-                                           const int *all_indice, const size_t segment_size, T *gradient,
-                                           int *indices) {
-  MS_EXCEPTION_IF_NULL(all_indice);
-  MS_EXCEPTION_IF_NULL(gradient);
-  MS_EXCEPTION_IF_NULL(indices);
-  int64_t offset = 0;
-  int64_t index = 0;
-  size_t segment_data_size = segment_size * sizeof(T);
-  size_t dst_size;
-  size_t src_size;
-  void *dst_data = nullptr;
-  void *src_data = nullptr;
-  for (auto &pair : indice_to_grads) {
-    if (distinct_ids.count(pair.first) == 0) {
-      continue;
-    }
-    indices[index++] = pair.first;
-
-    dst_size = segment_data_size;
-    src_size = segment_data_size;
-    dst_data = gradient + offset;
-    src_data = pair.second;
-    MS_EXCEPTION_IF_NULL(dst_data);
-    MS_EXCEPTION_IF_NULL(src_data);
-    auto ret = memcpy_s(gradient + offset, dst_size, pair.second, src_size);
-    if (ret != 0) {
-      MS_LOG(ERROR) << "memcpy_s error, errno(" << ret << ")";
-      return;
-    }
-    offset += segment_size;
-  }
-}
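Before a sparse push, duplicate embedding indices are merged so each row is transmitted once; in MindSpore that is `Util::ReduceSparseGradient`. A hedged sketch of the same reduction, assuming equal-sized row segments (`ReduceRows` is illustrative, not the library function):

// Sketch: sum gradient rows that share an embedding index.
#include <cstddef>
#include <map>
#include <vector>

void ReduceRows(const std::vector<float> &grad, const std::vector<int> &indices, size_t segment_size,
                std::vector<float> *out_grad, std::vector<int> *out_indices) {
  std::map<int, size_t> slot;  // embedding index -> position in reduced output
  for (size_t i = 0; i < indices.size(); ++i) {
    auto it = slot.find(indices[i]);
    if (it == slot.end()) {
      // First occurrence: reserve a zero-filled segment for this index.
      it = slot.emplace(indices[i], out_indices->size()).first;
      out_indices->push_back(indices[i]);
      out_grad->insert(out_grad->end(), segment_size, 0.0f);
    }
    for (size_t k = 0; k < segment_size; ++k) {
      (*out_grad)[it->second * segment_size + k] += grad[i * segment_size + k];
    }
  }
}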
-template <typename T>
-void WorkerProxy<T>::BuildSparseValue(const ::ps::SArray<int> &lengths, const size_t grad_index,
-                                      const size_t indice_index, const T *original_data, const T *grads, int *indices,
-                                      ::ps::SArray<T> *reduced_data) {
-  MS_EXCEPTION_IF_NULL(original_data);
-  MS_EXCEPTION_IF_NULL(grads);
-  MS_EXCEPTION_IF_NULL(indices);
-  MS_EXCEPTION_IF_NULL(reduced_data);
-  int64_t offset = 0;
-  size_t dst_size = 0;
-  size_t src_size = 0;
-  void *dst_data = nullptr;
-  void *src_data = nullptr;
-  for (size_t i = 0; i < lengths.size(); i++) {
-    if (i != grad_index && i != indice_index) {
-      int data_size = lengths[i] * sizeof(T);
-      dst_size = data_size;
-      src_size = data_size;
-      dst_data = reduced_data->data() + offset;
-      src_data = const_cast<T *>(original_data) + offset;
-      MS_EXCEPTION_IF_NULL(dst_data);
-      MS_EXCEPTION_IF_NULL(src_data);
-      auto ret = memcpy_s(dst_data, dst_size, src_data, src_size);
-      if (ret != 0) {
-        MS_LOG(EXCEPTION) << "memcpy_s error, errno(" << ret << ")";
-        return;
-      }
-    }
-    offset += lengths[i];
-  }
-
-  // Fill the reduced gradient
-  int64_t grad_offset = 0;
-  for (size_t i = 0; i < grad_index; i++) {
-    grad_offset += lengths[i];
-  }
-  int64_t data_size = lengths[grad_index] * sizeof(T);
-  dst_size = data_size;
-  src_size = data_size;
-  dst_data = reduced_data->data() + grad_offset;
-  src_data = const_cast<T *>(grads);
-  MS_EXCEPTION_IF_NULL(dst_data);
-  MS_EXCEPTION_IF_NULL(src_data);
-  auto ret = memcpy_s(dst_data, dst_size, src_data, src_size);
-  if (ret != 0) {
-    MS_LOG(EXCEPTION) << "memcpy_s error, errno(" << ret << ")";
-    return;
-  }
-
-  // Fill the reduced indice
-  int64_t indice_offset = grad_offset + lengths[grad_index];
-  data_size = lengths[indice_index] * sizeof(T);
-  T *indice_data = reduced_data->data() + indice_offset;
-  dst_size = data_size;
-  src_size = data_size;
-  dst_data = indice_data;
-  src_data = indices;
-  MS_EXCEPTION_IF_NULL(dst_data);
-  MS_EXCEPTION_IF_NULL(src_data);
-  ret = memcpy_s(dst_data, dst_size, src_data, src_size);
-  if (ret != 0) {
-    MS_LOG(EXCEPTION) << "memcpy_s error, errno(" << ret << ")";
-    return;
-  }
-}
-
-template <typename T>
-void WorkerProxy<T>::BroadcastSlicer(int64_t timestamp, const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &,
-                                     std::vector<std::pair<bool, ::ps::KVPairs<T>>> *sliced,
-                                     const std::map<int64_t, int64_t> &attr) {
-  MS_EXCEPTION_IF_NULL(sliced);
-  sliced->resize(server_num_);
-  for (int64_t i = 0; i < server_num_; i++) {
-    sliced->at(i).first = true;
-    sliced->at(i).second = send;
-    expected_result_count_[timestamp] += 1;
-  }
-}
-
-template <typename T>
-void WorkerProxy<T>::RoundRobinSlicer(int64_t timestamp, const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &,
-                                      std::vector<std::pair<bool, ::ps::KVPairs<T>>> *sliced,
-                                      const std::map<int64_t, int64_t> &attr) {
-  MS_EXCEPTION_IF_NULL(sliced);
-  sliced->resize(server_num_);
-  auto keys = send.keys;
-  auto vals = send.vals;
-  auto lens = send.lens;
-
-  int64_t server_id, len;
-  ::ps::Key param_key;
-  for (size_t i = 0; i < keys.size(); i++) {
-    param_key = keys[i];
-    server_id = key_to_server_id_[param_key];
-    if (!sliced->at(server_id).first) {
-      sliced->at(server_id).first = true;
-      expected_result_count_[timestamp] += 1;
-    }
-
-    ::ps::KVPairs<T> &server_kv_pairs = sliced->at(server_id).second;
-    server_kv_pairs.keys.push_back(param_key);
-    if (vals.empty()) {
-      continue;
-    }
-
-    len = lens[i];
-    int64_t offset = std::accumulate(lens.begin(), lens.begin() + i, 0);
-    auto val_begin = vals.begin() + offset;
-    auto val_end = val_begin + len;
-
-    for (auto iter = val_begin; iter != val_end; iter++) {
-      server_kv_pairs.vals.push_back(*iter);
-    }
-    server_kv_pairs.lens.push_back(len);
-  }
-}
-
-template <typename T>
-void WorkerProxy<T>::WorkerInitEmbeddingSlicer(int64_t timestamp, const ::ps::KVPairs<T> &send,
-                                               const std::vector<::ps::Range> &,
-                                               std::vector<std::pair<bool, ::ps::KVPairs<T>>> *sliced,
-                                               const std::map<int64_t, int64_t> &attrs) {
-  MS_EXCEPTION_IF_NULL(sliced);
-  sliced->resize(server_num_);
-  auto keys = send.keys;
-  auto vals = send.vals;
-  auto lens = send.lens;
-
-  size_t col_cnt = lens[0] / embedding_row_cnt_[keys[0]];
-  const std::vector<::ps::Range> &ranges = *(embedding_table_ranges_[keys[0]]);
-  for (size_t i = 0; i < ranges.size(); i++) {
-    size_t offset_begin = ranges[i].begin() * col_cnt;
-    size_t offset_end = (ranges[i].end() + 1) * col_cnt;
-    ::ps::KVPairs<T> kvs;
-    kvs.keys = keys;
-    kvs.vals = vals.segment(offset_begin, offset_end);
-    kvs.lens.push_back(offset_end - offset_begin);
-    sliced->at(i).first = true;
-    sliced->at(i).second = kvs;
-  }
-}
-
-template <typename T>
-void WorkerProxy<T>::UpdateEmbeddingSlicer(int timestamp, const ::ps::KVPairs<T> &send,
-                                           const std::vector<::ps::Range> &,
-                                           std::vector<std::pair<bool, ::ps::KVPairs<T>>> *sliced,
-                                           const std::map<int64_t, int64_t> &attrs) {
-  MS_EXCEPTION_IF_NULL(sliced);
-  T *embedding_vals = send.vals.data();
-  int *lookup_ids = send.lens.data();
-  size_t val_size = send.vals.size();
-  size_t id_size = send.lens.size();
-  size_t embedding_dim = val_size / id_size;
-
-  const Key &key = send.keys[0];
-  const std::vector<::ps::Range> &ranges = *(embedding_table_ranges_[key]);
-  sliced->resize(ranges.size());
-
-  for (size_t i = 0; i < ranges.size(); i++) {
-    const ::ps::Range &range = ranges[i];
-    const auto &begin = range.begin();
-    const auto &end = range.end();
-    auto &kvs = sliced->at(i).second;
-    kvs.keys.push_back(key);
-    for (size_t j = 0; j < id_size; j++) {
-      auto lookup_id = static_cast<uint64_t>(lookup_ids[j]);
-      if (lookup_id >= begin && lookup_id <= end) {
-        kvs.keys.push_back(lookup_id);
-        for (size_t k = 0; k < embedding_dim; k++) {
-          kvs.vals.push_back(embedding_vals[j * embedding_dim + k]);
-        }
-      }
-    }
-
-    if (kvs.keys.size() <= 1) {
-      sliced->at(i).first = false;
-    } else {
-      sliced->at(i).first = true;
-      expected_result_count_[timestamp] += 1;
-    }
-  }
-}
-
-template <typename T>
-void WorkerProxy<T>::ProcessLookupResult(const ::ps::Message &msg) {
-  int64_t ts = msg.meta.timestamp;
-  if (msg.meta.pull) {
-    CHECK_GE(msg.data.size(), (size_t)2);
-    ::ps::KVPairs<T> kvs;
-    kvs.keys = msg.data[0];
-    kvs.vals = msg.data[1];
-    if (msg.data.size() > (size_t)2) {
-      kvs.lens = msg.data[2];
-    }
-    mutex_.lock();
-    lookup_results_[ts].push_back(kvs);
-    mutex_.unlock();
-  }
-  if (lookup_customer_->NumResponse(ts) + 1 == server_num_) {
-    const auto &cb = lookup_callbacks_[ts];
-    cb();
-    lookup_callbacks_.erase(ts);
-  }
-}
-
-template <typename T>
-void WorkerProxy<T>::ProcessResponse(const ::ps::Message &msg) {
-  int64_t ts = msg.meta.timestamp;
-
-  if (msg.meta.pull) {
-    CHECK_GE(msg.data.size(), (size_t)2);
-    ::ps::KVPairs<T> kvs;
-    kvs.keys = msg.data[0];
-    kvs.vals = msg.data[1];
-    if (msg.data.size() > (size_t)2) {
-      kvs.lens = msg.data[2];
-    }
-    mutex_.lock();
-    int rsp_server_rank = ::ps::Postoffice::Get()->IDtoRank(msg.meta.sender);
-    gathered_response_[ts][rsp_server_rank] = kvs;
-    mutex_.unlock();
-    if (general_customer_->NumResponse(ts) + 1 == server_num_) {
-      const auto &cb = general_callbacks_[ts];
-      cb();
-      general_callbacks_.erase(ts);
-    }
-  }
-}
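ProcessResponse above files each reply under its server rank in a `std::map`, so the callback installed by AddGeneralRspCB concatenates results in rank order no matter when the replies arrive. A minimal sketch of that reassembly (illustrative types only):

// Sketch: arrival order is arbitrary, but std::map iterates by ascending rank,
// so the concatenation is deterministic.
#include <map>
#include <vector>

std::vector<float> Concat(const std::map<int, std::vector<float>> &by_rank) {
  std::vector<float> all;
  for (const auto &kv : by_rank) {
    all.insert(all.end(), kv.second.begin(), kv.second.end());
  }
  return all;
}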
-template <typename T>
-void WorkerProxy<T>::Send(::ps::Customer *customer, int64_t timestamp, bool push, bool pull, int64_t cmd,
-                          const ::ps::KVPairs<T> &kvs, const Slicer &slicer, std::map<int64_t, int64_t> attrs) {
-  MS_EXCEPTION_IF_NULL(customer);
-  SlicedKVs sliced;
-  slicer(timestamp, kvs, ::ps::Postoffice::Get()->GetServerKeyRanges(), &sliced, attrs);
-
-  for (size_t i = 0; i < sliced.size(); i++) {
-    const auto &s = sliced[i];
-    if (!s.first) continue;
-    ::ps::Message msg;
-    msg.meta.app_id = customer->app_id();
-    msg.meta.customer_id = customer->customer_id();
-    msg.meta.request = true;
-    msg.meta.push = push;
-    msg.meta.pull = pull;
-    msg.meta.head = cmd;
-    msg.meta.timestamp = timestamp;
-    msg.meta.recver = ::ps::Postoffice::Get()->ServerRankToID(i);
-    msg.meta.priority = kvs.priority;
-    const auto &kvs = s.second;
-    if (kvs.keys.size()) {
-      msg.AddData(kvs.keys);
-      msg.AddData(kvs.vals);
-      if (kvs.lens.size()) {
-        msg.AddData(kvs.lens);
-      }
-    }
-    ::ps::Postoffice::Get()->van()->Send(msg);
-  }
-}
-}  // namespace ps
-}  // namespace mindspore
-#endif  // MINDSPORE_CCSRC_PS_WORKER_PROXY_H_
diff --git a/mindspore/ccsrc/runtime/device/kernel_runtime_manager.cc b/mindspore/ccsrc/runtime/device/kernel_runtime_manager.cc
index f7263d72df..a3677bf15e 100644
--- a/mindspore/ccsrc/runtime/device/kernel_runtime_manager.cc
+++ b/mindspore/ccsrc/runtime/device/kernel_runtime_manager.cc
@@ -24,7 +24,7 @@ namespace mindspore {
 namespace device {
 void KernelRuntimeManager::ClearRuntimeResource() {
 #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
-  if (ps::Util::IsRoleOfWorker() && ps::PsDataPrefetch::GetInstance().cache_enable()) {
+  if (ps::PSContext::instance()->is_worker() && ps::PsDataPrefetch::GetInstance().cache_enable()) {
     ps::ps_cache_instance.SyncEmbeddingTable();
   }
 #endif
diff --git a/model_zoo/official/cv/resnet/scripts/run_parameter_server_train.sh b/model_zoo/official/cv/resnet/scripts/run_parameter_server_train.sh
index 180c27b417..bfd60cd98c 100644
--- a/model_zoo/official/cv/resnet/scripts/run_parameter_server_train.sh
+++ b/model_zoo/official/cv/resnet/scripts/run_parameter_server_train.sh
@@ -78,7 +78,6 @@ export DEVICE_NUM=8
 export RANK_SIZE=8
 export RANK_TABLE_FILE=$PATH1
 
-export MS_COMM_TYPE=zmq
 export MS_SCHED_NUM=1
 export MS_WORKER_NUM=$RANK_SIZE
 export MS_SERVER_NUM=8
diff --git a/model_zoo/official/cv/resnet/scripts/run_parameter_server_train_gpu.sh b/model_zoo/official/cv/resnet/scripts/run_parameter_server_train_gpu.sh
index 31b499a274..ca3a654ab7 100755
--- a/model_zoo/official/cv/resnet/scripts/run_parameter_server_train_gpu.sh
+++ b/model_zoo/official/cv/resnet/scripts/run_parameter_server_train_gpu.sh
@@ -70,7 +70,6 @@ fi
 
 export DEVICE_NUM=8
 export RANK_SIZE=8
-export MS_COMM_TYPE=zmq
 export MS_SCHED_NUM=1
 export MS_WORKER_NUM=8
 export MS_SERVER_NUM=8
diff --git a/model_zoo/official/recommend/wide_and_deep/script/run_parameter_server_train_cluster.sh b/model_zoo/official/recommend/wide_and_deep/script/run_parameter_server_train_cluster.sh
index 64a1f1ed76..e13d36f8d1 100644
--- a/model_zoo/official/recommend/wide_and_deep/script/run_parameter_server_train_cluster.sh
+++ b/model_zoo/official/recommend/wide_and_deep/script/run_parameter_server_train_cluster.sh
@@ -27,7 +27,6 @@ export EPOCH_SIZE=$2
 export DEVICE_TARGET=$3
 export DATASET=$4
 
-export MS_COMM_TYPE=zmq
 export MS_SCHED_NUM=1
 export MS_WORKER_NUM=$RANK_SIZE
 export LOCAL_WORKER_NUM=$5
diff --git a/model_zoo/official/recommend/wide_and_deep/script/run_parameter_server_train_distribute.sh b/model_zoo/official/recommend/wide_and_deep/script/run_parameter_server_train_distribute.sh
index 73db06fc69..8c89f95f39 100644
--- a/model_zoo/official/recommend/wide_and_deep/script/run_parameter_server_train_distribute.sh
+++ b/model_zoo/official/recommend/wide_and_deep/script/run_parameter_server_train_distribute.sh
@@ -25,7 +25,6 @@ export RANK_SIZE=$1
 export EPOCH_SIZE=$2
 export DEVICE_TARGET=$3
 export DATASET=$4
-export MS_COMM_TYPE=zmq
 export MS_SCHED_NUM=1
 export MS_WORKER_NUM=$RANK_SIZE
 export MS_SERVER_NUM=$5
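The scripts above describe the cluster purely through environment variables: MS_COMM_TYPE disappears because the ZeroMQ transport went away with ps-lite, while MS_SCHED_NUM, MS_WORKER_NUM, and MS_SERVER_NUM remain. A hedged sketch of reading that topology from the environment (the variable names come from the scripts; the `EnvToInt` helper is illustrative, not the MindSpore API):

// Sketch: parse the cluster topology exported by the launch scripts.
#include <cstdlib>
#include <iostream>
#include <string>

int EnvToInt(const char *name, int fallback) {
  const char *v = std::getenv(name);
  return v ? std::stoi(v) : fallback;
}

int main() {
  int sched_num = EnvToInt("MS_SCHED_NUM", 1);
  int worker_num = EnvToInt("MS_WORKER_NUM", 1);
  int server_num = EnvToInt("MS_SERVER_NUM", 1);
  std::cout << sched_num << " scheduler(s), " << worker_num << " worker(s), "
            << server_num << " server(s)\n";
}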
diff --git a/model_zoo/official/recommend/wide_and_deep/script/run_parameter_server_train_standalone.sh b/model_zoo/official/recommend/wide_and_deep/script/run_parameter_server_train_standalone.sh
index 8c7bb208c8..a68a2cbd4b 100644
--- a/model_zoo/official/recommend/wide_and_deep/script/run_parameter_server_train_standalone.sh
+++ b/model_zoo/official/recommend/wide_and_deep/script/run_parameter_server_train_standalone.sh
@@ -23,7 +23,6 @@ self_path=$(dirname "${script_self}")
 export EPOCH_SIZE=$1
 export DEVICE_TARGET=$2
 export DATASET=$3
-export MS_COMM_TYPE=zmq
 export MS_SCHED_NUM=1
 export MS_WORKER_NUM=1
 export MS_SERVER_NUM=$4
diff --git a/tests/st/ps/cmp_sparse_embedding/shell_run_test.sh b/tests/st/ps/cmp_sparse_embedding/shell_run_test.sh
index b9a7d3f634..44b294bf1d 100644
--- a/tests/st/ps/cmp_sparse_embedding/shell_run_test.sh
+++ b/tests/st/ps/cmp_sparse_embedding/shell_run_test.sh
@@ -15,8 +15,7 @@
 # ============================================================================
 
 execute_path=$(pwd)
-self_path=$(dirname "${script_self}")
-export MS_COMM_TYPE=zmq
+self_path=$(dirname $0)
 export MS_SCHED_NUM=1
 DEVICE_TARGET=$1
 export MS_WORKER_NUM=$2
diff --git a/tests/st/ps/full_ps/shell_run_test.sh b/tests/st/ps/full_ps/shell_run_test.sh
index 8222e76888..e61c9325fe 100644
--- a/tests/st/ps/full_ps/shell_run_test.sh
+++ b/tests/st/ps/full_ps/shell_run_test.sh
@@ -15,8 +15,7 @@
 # ============================================================================
 
 execute_path=$(pwd)
-self_path=$(dirname "${script_self}")
-export MS_COMM_TYPE=zmq
+self_path=$(dirname $0)
 export MS_SCHED_NUM=1
 DEVICE_TARGET=$1
 DATASET_PATH=$2
diff --git a/tests/st/ps/multi_full_ps/shell_run_test.sh b/tests/st/ps/multi_full_ps/shell_run_test.sh
index 564b7ce444..0a4adc393c 100644
--- a/tests/st/ps/multi_full_ps/shell_run_test.sh
+++ b/tests/st/ps/multi_full_ps/shell_run_test.sh
@@ -15,8 +15,7 @@
 # ============================================================================
 
 execute_path=$(pwd)
-self_path=$(dirname "${script_self}")
-export MS_COMM_TYPE=zmq
+self_path=$(dirname $0)
 export MS_SCHED_NUM=1
 DEVICE_TARGET=$1
 export MS_WORKER_NUM=$2
diff --git a/tests/st/ps/part_ps/shell_run_test.sh b/tests/st/ps/part_ps/shell_run_test.sh
index 1a2d5e6aee..ca42c5daf7 100644
--- a/tests/st/ps/part_ps/shell_run_test.sh
+++ b/tests/st/ps/part_ps/shell_run_test.sh
@@ -15,8 +15,7 @@
 # ============================================================================
 
 execute_path=$(pwd)
-self_path=$(dirname "${script_self}")
-export MS_COMM_TYPE=zmq
+self_path=$(dirname $0)
 export MS_SCHED_NUM=1
 DEVICE_TARGET=$1
 DATASET_PATH=$2
diff --git a/tests/ut/cpp/CMakeLists.txt b/tests/ut/cpp/CMakeLists.txt
index 72113f19f8..0eb4a31cb1 100644
--- a/tests/ut/cpp/CMakeLists.txt
+++ b/tests/ut/cpp/CMakeLists.txt
@@ -150,6 +150,8 @@ list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/internal/parame
 list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/scheduler.cc")
 list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/optimizer_info.cc")
 list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/optimizer_info_builder.cc")
+list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/worker.cc")
+list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/parameter_server.cc")
 list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/ps_cache/gpu/gpu_ps_cache.cc")
 list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/ps_cache/ascend/ascend_ps_cache.cc")
 list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/ps_cache/ps_cache_manager.cc")
diff --git a/third_party/patch/pslite/ps_lite.patch001 b/third_party/patch/pslite/ps_lite.patch001
deleted file mode 100644
index e2e51e93c8..0000000000
--- a/third_party/patch/pslite/ps_lite.patch001
+++ /dev/null
@@ -1,255 +0,0 @@
-diff -Npur ps-lite-master/include/dmlc/base.h ps-lite-master-new/include/dmlc/base.h
---- ps-lite-master/include/dmlc/base.h	2020-02-29 13:59:55.000000000 +0800
-+++ ps-lite-master-new/include/dmlc/base.h	2020-07-01 11:56:50.444833389 +0800
-@@ -8,7 +8,7 @@
- 
- /*! \brief whether use glog for logging */
- #ifndef DMLC_USE_GLOG
--#define DMLC_USE_GLOG 0
-+#define DMLC_USE_GLOG 1
- #endif
- 
- /*!
-diff -Npur ps-lite-master/include/dmlc/logging.h ps-lite-master-new/include/dmlc/logging.h
---- ps-lite-master/include/dmlc/logging.h	2020-02-29 13:59:55.000000000 +0800
-+++ ps-lite-master-new/include/dmlc/logging.h	2020-07-08 21:35:33.334584767 +0800
-@@ -52,7 +52,7 @@ struct Error : public std::runtime_error
- 
- namespace dmlc {
- inline void InitLogging(const char* argv0) {
--  google::InitGoogleLogging(argv0);
-+  //google::InitGoogleLogging(argv0);
- }
- }  // namespace dmlc
- 
-diff -Npur ps-lite-master/make/deps.mk ps-lite-master-new/make/deps.mk
---- ps-lite-master/make/deps.mk	2020-02-29 13:59:55.000000000 +0800
-+++ ps-lite-master-new/make/deps.mk	2020-06-17 10:35:46.253837426 +0800
-@@ -1,69 +1,7 @@
- # Install dependencies
--
--URL1=https://raw.githubusercontent.com/mli/deps/master/build
--URL2=https://github.com/google/protobuf/releases/download/v3.5.1
--ifndef WGET
--WGET = wget
--endif
--
--# protobuf
--PROTOBUF = ${DEPS_PATH}/include/google/protobuf/message.h
--${PROTOBUF}:
--	$(eval FILE=protobuf-cpp-3.5.1.tar.gz)
--	$(eval DIR=protobuf-3.5.1)
--	rm -rf $(FILE) $(DIR)
--	$(WGET) $(URL2)/$(FILE) && tar --no-same-owner -zxf $(FILE)
--	cd $(DIR) && export CFLAGS=-fPIC && export CXXFLAGS=-fPIC && ./configure -prefix=$(DEPS_PATH) && $(MAKE) && $(MAKE) install
--	rm -rf $(FILE) $(DIR)
--
- # zmq
--ZMQ = ${DEPS_PATH}/include/zmq.h
-+ZMQ = $(MS_ZMQ_INSTALL_PATH)/lib/libzmq.a
- 
- ${ZMQ}:
--	$(eval FILE=zeromq-4.1.4.tar.gz)
--	$(eval DIR=zeromq-4.1.4)
--	rm -rf $(FILE) $(DIR)
--	$(WGET) $(URL1)/$(FILE) && tar --no-same-owner -zxf $(FILE)
--	cd $(DIR) && export CFLAGS=-fPIC && export CXXFLAGS=-fPIC && ./configure -prefix=$(DEPS_PATH) --with-libsodium=no --with-libgssapi_krb5=no && $(MAKE) && $(MAKE) install
--	rm -rf $(FILE) $(DIR)
--
--# lz4
--LZ4 = ${DEPS_PATH}/include/lz4.h
--${LZ4}:
--	$(eval FILE=lz4-r129.tar.gz)
--	$(eval DIR=lz4-r129)
--	rm -rf $(FILE) $(DIR)
--	wget $(URL1)/$(FILE) && tar --no-same-owner -zxf $(FILE)
--	cd $(DIR) && $(MAKE) && PREFIX=$(DEPS_PATH) $(MAKE) install
--	rm -rf $(FILE) $(DIR)
--
--# cityhash
--CITYHASH = ${DEPS_PATH}/include/city.h
--${CITYHASH}:
--	$(eval FILE=cityhash-1.1.1.tar.gz)
--	$(eval DIR=cityhash-1.1.1)
--	rm -rf $(FILE) $(DIR)
--	wget $(URL1)/$(FILE) && tar --no-same-owner -zxf $(FILE)
--	cd $(DIR) && ./configure -prefix=$(DEPS_PATH) --enable-sse4.2 && $(MAKE) CXXFLAGS="-g -O3 -msse4.2" && $(MAKE) install
--	rm -rf $(FILE) $(DIR)
--
--
--# # gflags
--# ${DEPS_PATH}/include/google/gflags.h:
--# 	$(eval FILE=gflags-2.0-no-svn-files.tar.gz)
--# 	$(eval DIR=gflags-2.0)
--# 	rm -rf $(FILE) $(DIR)
--# 	wget $(URL)/$(FILE) && tar -zxf $(FILE)
--# 	cd $(DIR) && ./configure -prefix=$(DEPS_PATH) && $(MAKE) && $(MAKE) install
--# 	rm -rf $(FILE) $(DIR)
--# gflags: | ${DEPS_PATH}/include/google/gflags.h
-+	cd $(MS_ZMQ_DIR) && export CFLAGS="-fPIC -D_GLIBCXX_USE_CXX11_ABI=0" && export CXXFLAGS=-fPIC && ./configure -prefix=$(MS_ZMQ_INSTALL_PATH) --with-libsodium=no --with-libgssapi_krb5=no && $(MAKE) && $(MAKE) install
- 
--# # glog
--# ${DEPS_PATH}/include/glog/logging.h: | ${DEPS_PATH}/include/google/gflags.h
--# 	$(eval FILE=v0.3.4.tar.gz)
--# 	$(eval DIR=glog-0.3.4)
--# 	rm -rf $(FILE) $(DIR)
--# 	wget https://github.com/google/glog/archive/$(FILE) && tar -zxf $(FILE)
--# 	cd $(DIR) && ./configure -prefix=$(DEPS_PATH) --with-gflags=$(DEPS_PATH) && $(MAKE) && $(MAKE) install
--# 	rm -rf $(FILE) $(DIR)
--# glog: | ${DEPS_PATH}/include/glog/logging.h
-diff -Npur ps-lite-master/make/ps.mk ps-lite-master-new/make/ps.mk
---- ps-lite-master/make/ps.mk	2020-02-29 13:59:55.000000000 +0800
-+++ ps-lite-master-new/make/ps.mk	2020-06-05 09:28:35.337740291 +0800
-@@ -9,5 +9,5 @@ ifeq ($(USE_KEY32), 1)
- ADD_CFLAGS += -DUSE_KEY32=1
- endif
- 
--PS_LDFLAGS_SO = -L$(DEPS_PATH)/lib -lprotobuf-lite -lzmq
--PS_LDFLAGS_A = $(addprefix $(DEPS_PATH)/lib/, libprotobuf-lite.a libzmq.a)
-+PS_LDFLAGS_SO = -L$(MS_ZMQ_INSTALL_PATH)/lib -lzmq -L$(MS_PROTO_LIB_DIR) -lprotobuf-lite
-+PS_LDFLAGS_A = $(addprefix $(MS_ZMQ_INSTALL_PATH)/lib -L$(MS_PROTO_LIB_DIR), libprotobuf-lite.a libzmq.a)
-diff -Npur ps-lite-master/Makefile ps-lite-master-new/Makefile
---- ps-lite-master/Makefile	2020-02-29 13:59:55.000000000 +0800
-+++ ps-lite-master-new/Makefile	2020-06-17 11:09:20.240322660 +0800
-@@ -12,13 +12,24 @@ ifndef DEPS_PATH
- DEPS_PATH = $(shell pwd)/deps
- endif
- 
-+MS_PROTO_DIR = @protobuf_DIRPATH@
-+MS_GLOG_DIR = @glog_DIRPATH@
-+MS_ZMQ_DIR = @zeromq_DIRPATH@
-+
-+MS_PROTO_LIB_DIR = @protobuf_LIBPATH@
-+MS_GLOG_LIB_DIR = @glog_LIBPATH@
-+MS_ZMQ_INSTALL_PATH = $(MS_ZMQ_DIR)/zmq_install
- 
- ifndef PROTOC
--PROTOC = ${DEPS_PATH}/bin/protoc
-+PROTOC = $(MS_PROTO_DIR)/bin/protoc
- endif
- 
---INCPATH = -I./src -I./include -I$(DEPS_PATH)/include
--CFLAGS = -std=c++11 -msse2 -fPIC -O3 -ggdb -Wall -finline-functions $(INCPATH) $(ADD_CFLAGS)
-+INCPATH = -I./src -I./include -I$(MS_ZMQ_INSTALL_PATH)/include
-+INCPATH += -I$(MS_PROTO_DIR)/include
-+INCPATH += -I$(MS_GLOG_DIR)/include
-+
-+CXXFLAGS = -D_GLIBCXX_USE_CXX11_ABI=0
-+CFLAGS = -std=c++11 -fPIC -O3 -ggdb -Wall -finline-functions $(INCPATH) $(ADD_CFLAGS) -D_GLIBCXX_USE_CXX11_ABI=0
- LIBS = -pthread
- 
- ifdef USE_IBVERBS
-@@ -30,6 +41,7 @@ ifdef ASAN
- CFLAGS += -fsanitize=address -fno-omit-frame-pointer -fno-optimize-sibling-calls
- endif
- 
-+LIBS += -L$(MS_GLOG_LIB_DIR) -lglog
- 
- all: ps test
- 
-@@ -51,9 +63,9 @@ build/libps.a: $(OBJS)
- build/%.o: src/%.cc ${ZMQ} src/meta.pb.h
- 	@mkdir -p $(@D)
- 	$(CXX) $(INCPATH) -std=c++11 -MM -MT build/$*.o $< >build/$*.d
--	$(CXX) $(CFLAGS) $(LIBS) -c $< -o $@
-+	$(CXX) $(CFLAGS) $(CXXFLAGS) $(LIBS) -c $< -o $@
- 
--src/%.pb.cc src/%.pb.h : src/%.proto ${PROTOBUF}
-+src/%.pb.cc src/%.pb.h : src/%.proto
- 	$(PROTOC) --cpp_out=./src --proto_path=./src $<
- 
- -include build/*.d
-diff -Npur ps-lite-master/src/ibverbs_van.h ps-lite-master-new/src/ibverbs_van.h
---- ps-lite-master/src/ibverbs_van.h	2020-02-29 13:59:55.000000000 +0800
-+++ ps-lite-master-new/src/ibverbs_van.h	2020-06-02 20:52:11.076230014 +0800
-@@ -145,15 +145,15 @@ class SimpleMempool {
-     total_allocated_size += new_mem_size;
-   }
- 
--  CHECK_NE(free_list.end(), it) << "Not enough memory";
-+  //CHECK_NE(free_list.end(), it) << "Not enough memory";
-   CHECK_GE(it->first, proper_size);
- 
-   char *addr = it->second;
-   size_t space_left = it->first - proper_size;
- 
-   free_list.erase(it);
--  CHECK_EQ(used_list.find(addr), used_list.end())
--      << "Address is already allocated";
-+  //CHECK_EQ(used_list.find(addr), used_list.end())
//<< "Address is already allocated"; - - used_list.emplace(addr, proper_size); - -@@ -173,8 +173,8 @@ class SimpleMempool { - std::lock_guard lk(mu_); - - auto it = used_list.find(addr); -- CHECK_NE(used_list.end(), it) -- << "Cannot find info about address: " << (uintptr_t)addr; -+ //CHECK_NE(used_list.end(), it) -+ //<< "Cannot find info about address: " << (uintptr_t)addr; - - size_t size = it->second; - used_list.erase(it); -@@ -208,7 +208,7 @@ class SimpleMempool { - // Convert the memory address to its associated RDMA memory region - inline struct ibv_mr *Addr2MR(char *addr) { - auto it = mr_list.lower_bound(addr); -- CHECK_NE(it, mr_list.end()) << "cannot find the associated memory region"; -+ //CHECK_NE(it, mr_list.end()) << "cannot find the associated memory region"; - return it->second; - } - }; -@@ -330,7 +330,7 @@ class AddressPool { - CHECK(ptr); - uint32_t idx = indices_.front(); - indices_.pop(); -- CHECK_EQ(table_[idx], nullptr); -+ //CHECK_EQ(table_[idx], nullptr); - table_[idx] = ptr; - return idx; - } -@@ -636,7 +636,7 @@ class IBVerbsVan : public Van { - PBMeta meta; - PackMetaPB(msg.meta, &meta); - -- CHECK_NE(endpoints_.find(remote_id), endpoints_.end()); -+ //CHECK_NE(endpoints_.find(remote_id), endpoints_.end()); - Endpoint *endpoint = endpoints_[remote_id].get(); - MessageBuffer *msg_buf = new MessageBuffer(); - -diff -Npur ps-lite-master/src/van.cc ps-lite-master-new/src/van.cc ---- ps-lite-master/src/van.cc 2020-02-29 13:59:55.000000000 +0800 -+++ ps-lite-master-new/src/van.cc 2020-06-02 20:52:43.330405828 +0800 -@@ -448,6 +448,7 @@ void Van::PackMetaPB(const Meta& meta, P - if (meta.timestamp != Meta::kEmpty) pb->set_timestamp(meta.timestamp); - if (meta.body.size()) pb->set_body(meta.body); - pb->set_push(meta.push); -+ pb->set_pull(meta.pull); - pb->set_request(meta.request); - pb->set_simple_app(meta.simple_app); - pb->set_priority(meta.priority); -diff -Npur ps-lite-master/tests/test.mk ps-lite-master-new/tests/test.mk ---- ps-lite-master/tests/test.mk 2020-02-29 13:59:55.000000000 +0800 -+++ ps-lite-master-new/tests/test.mk 2020-06-16 19:15:06.025087897 +0800 -@@ -1,10 +1,10 @@ --TEST_SRC = $(wildcard tests/test_*.cc) --TEST = $(patsubst tests/test_%.cc, tests/test_%, $(TEST_SRC)) -+#TEST_SRC = $(wildcard tests/test_*.cc) -+#TEST = $(patsubst tests/test_%.cc, tests/test_%, $(TEST_SRC)) - --# -ltcmalloc_and_profiler --LDFLAGS = -Wl,-rpath,$(DEPS_PATH)/lib $(PS_LDFLAGS_SO) -pthread --tests/% : tests/%.cc build/libps.a -- $(CXX) $(CFLAGS) $(LIBS) -MM -MT tests/$* $< >tests/$*.d -- $(CXX) $(CFLAGS) $(LIBS) -o $@ $(filter %.cc %.a, $^) $(LDFLAGS) -- ---include tests/*.d -+## -ltcmalloc_and_profiler -+#LDFLAGS = -Wl,-rpath,$(DEPS_PATH)/lib $(PS_LDFLAGS_SO) -pthread -+#tests/% : tests/%.cc build/libps.a -+# $(CXX) $(CFLAGS) $(LIBS) -MM -MT tests/$* $< >tests/$*.d -+# $(CXX) $(CFLAGS) $(LIBS) -o $@ $(filter %.cc %.a, $^) $(LDFLAGS) -+# -+#-include tests/*.d