[Feature] one ps (3/4) (#29604)

* oneps (3/4)

Co-authored-by: MrChengmo <cmchengmo@163.com>
Co-authored-by: malin10 <malin10@baidu.com>
Co-authored-by: chengmo <chengmo@baidu.com>
revert-31562-mean
tangwei12 5 years ago committed by GitHub
parent edc06c6a1b
commit 032414ca2a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -246,17 +246,6 @@ endif()
include(third_party) # download, build, install third_party, Contains about 20+ dependencies
if(WITH_DISTRIBUTE)
if(WITH_GRPC)
message(STATUS "Use grpc framework.")
include(external/grpc)
else()
message(STATUS "Use brpc framework.")
include(external/leveldb)
include(external/brpc)
endif()
endif()
include(flags) # set paddle compile flags
if(WITH_PROFILER)

@ -14,7 +14,7 @@
INCLUDE(ExternalProject)
find_package(OpenSSL REQUIRED)
find_package(OpenSSL REQUIRED)
message(STATUS "ssl:" ${OPENSSL_SSL_LIBRARY})
message(STATUS "crypto:" ${OPENSSL_CRYPTO_LIBRARY})
@ -33,39 +33,43 @@ SET(BRPC_LIBRARIES "${BRPC_INSTALL_DIR}/lib/libbrpc.a" CACHE FILEPATH "brpc libr
INCLUDE_DIRECTORIES(${BRPC_INCLUDE_DIR})
# Reference https://stackoverflow.com/questions/45414507/pass-a-list-of-prefix-paths-to-externalproject-add-in-cmake-args
set(prefix_path "${THIRD_PARTY_PATH}/install/gflags|${THIRD_PARTY_PATH}/install/gtest|${THIRD_PARTY_PATH}/install/protobuf|${THIRD_PARTY_PATH}/install/zlib|${THIRD_PARTY_PATH}/install/glog")
set(prefix_path "${THIRD_PARTY_PATH}/install/gflags|${THIRD_PARTY_PATH}/install/leveldb|${THIRD_PARTY_PATH}/install/snappy|${THIRD_PARTY_PATH}/install/gtest|${THIRD_PARTY_PATH}/install/protobuf|${THIRD_PARTY_PATH}/install/zlib|${THIRD_PARTY_PATH}/install/glog")
# If minimal .a is need, you can set WITH_DEBUG_SYMBOLS=OFF
ExternalProject_Add(
extern_brpc
${EXTERNAL_PROJECT_LOG_ARGS}
${SHALLOW_CLONE}
GIT_REPOSITORY "${GIT_URL}/apache/incubator-brpc.git"
GIT_TAG "ad00fe940b4f05225b214131959293bbed8744a0" #rdma branch's head now.
PREFIX ${BRPC_SOURCES_DIR}
UPDATE_COMMAND ""
CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}
-DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}
-DCMAKE_CXX_FLAGS=${CMAKE_CXX_FLAGS}
-DCMAKE_C_FLAGS=${CMAKE_C_FLAGS}
-DCMAKE_INSTALL_PREFIX=${BRPC_INSTALL_DIR}
-DCMAKE_INSTALL_LIBDIR=${BRPC_INSTALL_DIR}/lib
-DCMAKE_POSITION_INDEPENDENT_CODE=ON
-DCMAKE_BUILD_TYPE=${THIRD_PARTY_BUILD_TYPE}
-DCMAKE_PREFIX_PATH=${prefix_path}
-DWITH_GLOG=ON
-DIOBUF_WITH_HUGE_BLOCK=ON
-DBRPC_WITH_RDMA=${WITH_BRPC_RDMA}
${EXTERNAL_OPTIONAL_ARGS}
LIST_SEPARATOR |
CMAKE_CACHE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${BRPC_INSTALL_DIR}
-DCMAKE_INSTALL_LIBDIR:PATH=${BRPC_INSTALL_DIR}/lib
-DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON
-DCMAKE_BUILD_TYPE:STRING=${THIRD_PARTY_BUILD_TYPE}
extern_brpc
${EXTERNAL_PROJECT_LOG_ARGS}
# TODO(gongwb): change to de newst repo when they changed.
GIT_REPOSITORY "https://github.com/wangjiawei04/brpc"
GIT_TAG "6d79e0b17f25107c35b705ea58d888083f59ff47"
PREFIX ${BRPC_SOURCES_DIR}
UPDATE_COMMAND ""
CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}
-DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}
-DCMAKE_CXX_FLAGS=${CMAKE_CXX_FLAGS}
-DCMAKE_C_FLAGS=${CMAKE_C_FLAGS}
-DCMAKE_INSTALL_PREFIX=${BRPC_INSTALL_DIR}
-DCMAKE_INSTALL_LIBDIR=${BRPC_INSTALL_DIR}/lib
-DCMAKE_POSITION_INDEPENDENT_CODE=ON
-DCMAKE_BUILD_TYPE=${THIRD_PARTY_BUILD_TYPE}
-DCMAKE_PREFIX_PATH=${prefix_path}
-DWITH_GLOG=ON
-DIOBUF_WITH_HUGE_BLOCK=ON
-DBRPC_WITH_RDMA=${WITH_BRPC_RDMA}
${EXTERNAL_OPTIONAL_ARGS}
LIST_SEPARATOR |
CMAKE_CACHE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${BRPC_INSTALL_DIR}
-DCMAKE_INSTALL_LIBDIR:PATH=${BRPC_INSTALL_DIR}/lib
-DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON
-DCMAKE_BUILD_TYPE:STRING=${THIRD_PARTY_BUILD_TYPE}
)
ADD_DEPENDENCIES(extern_brpc protobuf ssl crypto leveldb gflags glog gtest)
# ADD_DEPENDENCIES(extern_brpc protobuf ssl crypto leveldb gflags glog gtest snappy)
ADD_DEPENDENCIES(extern_brpc protobuf ssl crypto leveldb gflags glog snappy)
ADD_LIBRARY(brpc STATIC IMPORTED GLOBAL)
SET_PROPERTY(TARGET brpc PROPERTY IMPORTED_LOCATION ${BRPC_LIBRARIES})
ADD_DEPENDENCIES(brpc extern_brpc)
add_definitions(-DBRPC_WITH_GLOG)
LIST(APPEND external_project_dependencies brpc)

@ -21,20 +21,25 @@ SET(LEVELDB_LIBRARIES "${LEVELDB_INSTALL_DIR}/lib/libleveldb.a" CACHE FILEPATH "
INCLUDE_DIRECTORIES(${LEVELDB_INCLUDE_DIR})
ExternalProject_Add(
extern_leveldb
${EXTERNAL_PROJECT_LOG_ARGS}
${SHALLOW_CLONE}
PREFIX ${LEVELDB_SOURCES_DIR}
GIT_REPOSITORY "${GIT_URL}/google/leveldb.git"
GIT_TAG v1.18
CONFIGURE_COMMAND ""
BUILD_COMMAND CXXFLAGS=-fPIC make -j ${NUM_OF_PROCESSOR} libleveldb.a
INSTALL_COMMAND mkdir -p ${LEVELDB_INSTALL_DIR}/lib/
extern_leveldb
${EXTERNAL_PROJECT_LOG_ARGS}
PREFIX ${LEVELDB_SOURCES_DIR}
GIT_REPOSITORY "https://github.com/google/leveldb"
GIT_TAG v1.18
UPDATE_COMMAND ""
CONFIGURE_COMMAND ""
BUILD_COMMAND CXXFLAGS=-fPIC make -j ${NUM_OF_PROCESSOR} libleveldb.a
INSTALL_COMMAND mkdir -p ${LEVELDB_INSTALL_DIR}/lib/
&& cp ${LEVELDB_SOURCES_DIR}/src/extern_leveldb/libleveldb.a ${LEVELDB_LIBRARIES}
&& cp -r ${LEVELDB_SOURCES_DIR}/src/extern_leveldb/include ${LEVELDB_INSTALL_DIR}/
BUILD_IN_SOURCE 1
BUILD_IN_SOURCE 1
)
ADD_DEPENDENCIES(extern_leveldb snappy)
ADD_LIBRARY(leveldb STATIC IMPORTED GLOBAL)
SET_PROPERTY(TARGET leveldb PROPERTY IMPORTED_LOCATION ${LEVELDB_LIBRARIES})
ADD_DEPENDENCIES(leveldb extern_leveldb)
LIST(APPEND external_project_dependencies leveldb)

@ -0,0 +1,71 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
include (ExternalProject)
# NOTE: snappy is needed when linking with recordio
set(SNAPPY_SOURCES_DIR ${THIRD_PARTY_PATH}/snappy)
set(SNAPPY_INSTALL_DIR ${THIRD_PARTY_PATH}/install/snappy)
set(SNAPPY_INCLUDE_DIR "${SNAPPY_INSTALL_DIR}/include" CACHE PATH "snappy include directory." FORCE)
if(WIN32)
SET(SNAPPY_CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /wd4244 /wd4267")
else()
SET(SNAPPY_CMAKE_CXX_FLAGS ${CMAKE_CXX_FLAGS})
endif()
ExternalProject_Add(
extern_snappy
GIT_REPOSITORY "https://github.com/google/snappy"
GIT_TAG "1.1.7"
PREFIX ${SNAPPY_SOURCES_DIR}
UPDATE_COMMAND ""
CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}
-DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}
-DCMAKE_C_FLAGS=${CMAKE_C_FLAGS}
-DCMAKE_C_FLAGS_DEBUG=${CMAKE_C_FLAGS_DEBUG}
-DCMAKE_C_FLAGS_RELEASE=${CMAKE_C_FLAGS_RELEASE}
-DCMAKE_CXX_FLAGS=${SNAPPY_CMAKE_CXX_FLAGS}
-DCMAKE_CXX_FLAGS_RELEASE=${CMAKE_CXX_FLAGS_RELEASE}
-DCMAKE_CXX_FLAGS_DEBUG=${CMAKE_CXX_FLAGS_DEBUG}
-DCMAKE_INSTALL_PREFIX=${SNAPPY_INSTALL_DIR}
-DCMAKE_INSTALL_LIBDIR=${SNAPPY_INSTALL_DIR}/lib
-DCMAKE_POSITION_INDEPENDENT_CODE=ON
-DBUILD_TESTING=OFF
-DSNAPPY_BUILD_TESTS:BOOL=OFF
-DCMAKE_BUILD_TYPE=${THIRD_PARTY_BUILD_TYPE}
${EXTERNAL_OPTIONAL_ARGS}
CMAKE_CACHE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${SNAPPY_INSTALL_DIR}
-DCMAKE_INSTALL_LIBDIR:PATH=${SNAPPY_INSTALL_DIR}/lib
-DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON
-DCMAKE_BUILD_TYPE:STRING=${THIRD_PARTY_BUILD_TYPE}
)
IF(WIN32)
IF(NOT EXISTS "${SNAPPY_INSTALL_DIR}/lib/libsnappy.lib")
add_custom_command(TARGET extern_snappy POST_BUILD
COMMAND cmake -E copy ${SNAPPY_INSTALL_DIR}/lib/snappy.lib ${SNAPPY_INSTALL_DIR}/lib/libsnappy.lib
)
ENDIF()
set(SNAPPY_LIBRARIES "${SNAPPY_INSTALL_DIR}/lib/libsnappy.lib")
else(WIN32)
set(SNAPPY_LIBRARIES "${SNAPPY_INSTALL_DIR}/lib/libsnappy.a")
endif (WIN32)
add_library(snappy STATIC IMPORTED GLOBAL)
set_property(TARGET snappy PROPERTY IMPORTED_LOCATION ${SNAPPY_LIBRARIES})
include_directories(${SNAPPY_INCLUDE_DIR})
add_dependencies(snappy extern_snappy)

@ -95,7 +95,7 @@ include_directories("${PADDLE_SOURCE_DIR}/paddle/fluid/framework/io")
if(NOT APPLE)
find_package(Threads REQUIRED)
link_libraries(${CMAKE_THREAD_LIBS_INIT})
if(WITH_PSLIB)
if(WITH_PSLIB OR WITH_DISTRIBUTE)
set(CMAKE_CXX_LINK_EXECUTABLE "${CMAKE_CXX_LINK_EXECUTABLE} -pthread -ldl -lrt -lz -lssl")
else()
set(CMAKE_CXX_LINK_EXECUTABLE "${CMAKE_CXX_LINK_EXECUTABLE} -pthread -ldl -lrt")

@ -233,7 +233,7 @@ if(WITH_PYTHON)
list(APPEND third_party_deps extern_pybind)
endif()
IF(WITH_TESTING OR (WITH_DISTRIBUTE AND NOT WITH_GRPC))
IF(WITH_TESTING OR WITH_DISTRIBUTE)
include(external/gtest) # download, build, install gtest
list(APPEND third_party_deps extern_gtest)
ENDIF()
@ -275,14 +275,18 @@ if(WITH_BOX_PS)
list(APPEND third_party_deps extern_box_ps)
endif(WITH_BOX_PS)
if(WITH_DISTRIBUTE)
if (WITH_DISTRIBUTE)
include(external/snappy)
list(APPEND third_party_deps extern_snappy)
if(WITH_GRPC)
list(APPEND third_party_deps extern_grpc)
else()
list(APPEND third_party_deps extern_leveldb)
list(APPEND third_party_deps extern_brpc)
endif()
include(external/leveldb)
list(APPEND third_party_deps extern_leveldb)
include(external/brpc)
list(APPEND third_party_deps extern_brpc)
include(external/libmct) # download, build, install libmct
list(APPEND third_party_deps extern_libmct)
endif()
if(WITH_XBYAK)

@ -14,14 +14,9 @@ if (CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 7.0)
"${DISTRIBUTE_COMPILE_FLAGS} -faligned-new")
endif()
add_subdirectory(table)
add_subdirectory(test)
# open it until CI support brpc
return()
add_subdirectory(service)
add_subdirectory(test)
get_property(RPC_DEPS GLOBAL PROPERTY RPC_DEPS)

@ -35,6 +35,6 @@ cc_library(server SRCS server.cc DEPS downpour_server boost ${RPC_DEPS})
cc_library(communicator SRCS communicator.cc DEPS scope client boost table math_function selected_rows_functor ${RPC_DEPS})
cc_library(ps_service SRCS service.cc DEPS communicator client server boost ${RPC_DEPS})
cc_library(brpc_utils SRCS brpc_utils.cc DEPS ${COMMON_DEPS} ${RPC_DEPS})
cc_library(brpc_utils SRCS brpc_utils.cc DEPS tensor device_context ${COMMON_DEPS} ${RPC_DEPS})
cc_library(heter_server SRCS heter_server.cc DEPS brpc_utils ${COMMON_DEPS} ${RPC_DEPS})
cc_library(heter_client SRCS heter_client.cc DEPS brpc_utils ${COMMON_DEPS} ${RPC_DEPS})

@ -741,7 +741,7 @@ std::future<int32_t> BrpcPsClient::pull_sparse(float **select_values,
request_call_num, [shard_sorted_kvs, value_size](void *done) {
int ret = 0;
auto *closure = (DownpourBrpcClosure *)done;
for (size_t i = 0; i < ids.size(); ++i) {
for (size_t i = 0; i < shard_sorted_kvs->size(); ++i) {
if (closure->check_response(i, PS_PULL_SPARSE_TABLE) != 0) {
ret = -1;
break;

@ -839,7 +839,7 @@ void GeoCommunicator::InitParams(const RecvCtxMap &recv_varname_to_ctx) {
for (auto &iter : send_varname_to_ctx_) {
auto &ctx = iter.second;
if (!ctx.is_sparse) return;
if (!ctx.is_sparse) continue;
auto &varname = ctx.origin_varnames[0];
auto &table_id = ctx.table_id;
auto param = varname.substr(0, varname.size() - 5);
@ -853,12 +853,12 @@ void GeoCommunicator::InitDense(std::vector<std::string> &varnames,
if (trainer_id_ == 0) {
RpcSendDenseParam(varnames, table_id, *recv_scope_);
BarrierWithTable(1);
VLOG(0) << "push dense param to table " << table_id
VLOG(1) << "push dense param to table " << table_id
<< " from 0' trainer done";
} else {
BarrierWithTable(1);
RpcRecvDense(varnames, table_id, recv_scope_);
VLOG(0) << "push dense param to table " << table_id
VLOG(1) << "pull dense param to table " << table_id
<< " from 0' trainer done";
}
@ -952,20 +952,20 @@ void GeoCommunicator::RecvDense(const CommContext &send_ctx) {
}
void GeoCommunicator::InitSparse(const std::string &var_name, int table_id) {
VLOG(0) << "Init Sparse " << var_name << " : table " << table_id << " begin.";
VLOG(1) << "Init Sparse " << var_name << " : table " << table_id << " begin.";
if (trainer_id_ == 0) {
RpcSendSparseParam(var_name, table_id, *recv_scope_);
BarrierWithTable(1);
VLOG(0) << "push sparse param to table " << table_id
VLOG(1) << "push sparse param to table " << table_id
<< " from 0' trainer done";
} else {
BarrierWithTable(1);
RpcRecvSparse(var_name, table_id, recv_scope_);
VLOG(0) << "push dense param to table " << table_id
VLOG(1) << "pull sparse param to table " << table_id
<< " from 0' trainer done";
}
VLOG(0) << "Init Sparse " << var_name << " : table " << table_id << " done.";
VLOG(1) << "Init Sparse " << var_name << " : table " << table_id << " done.";
auto *global_var = recv_scope_->FindVar(var_name);
auto *var = old_scope_->Var(var_name);
framework::CopyVariable(*global_var, var);

@ -24,11 +24,11 @@
#include "paddle/fluid/platform/timer.h"
DECLARE_int32(rpc_deadline);
DECLARE_int32(pserver_timeout_ms);
namespace paddle {
namespace distributed {
DEFINE_int32(pserver_timeout_ms, 10800000, "pserver request server timeout_ms");
std::shared_ptr<HeterClient> HeterClient::s_instance_ = NULL;
bool HeterClient::is_initialized_ = false;
@ -53,6 +53,23 @@ void HeterClient::Stop() {
}
}
void HeterClient::FinalizeWorker() {
running_ = false;
if (!is_initialized_) {
VLOG(0) << "HeterClient is not inited, do nothing";
} else {
if (main_thread_) {
main_thread_->join();
main_thread_.reset(nullptr);
}
VLOG(1) << "HeterClient Stop Done";
}
}
std::future<int32_t> HeterClient::StopHeterWorker() {
return SendCmd(-1, PS_STOP_SERVER, {});
}
void HeterClient::RpcProfilerControl() {
if (trainer_id_ == 0) {
if (!do_server_profiler_ && platform::IsProfileEnabled()) {
@ -73,7 +90,7 @@ void HeterClient::CreateClient2XpuConnection() {
brpc::ChannelOptions options;
options.protocol = "baidu_std";
options.connection_type = "single";
options.timeout_ms = pserver_timeout_ms;
options.timeout_ms = FLAGS_pserver_timeout_ms;
xpu_channels_.resize(xpu_list_.size());
for (size_t i = 0; i < xpu_list_.size(); ++i) {
@ -102,7 +119,7 @@ void HeterClient::SendAndRecvAsync(
int num = trainer_id_ % xpu_channels_.size();
brpc::Controller cntl;
cntl.set_timeout_ms(pserver_timeout_ms);
cntl.set_timeout_ms(FLAGS_pserver_timeout_ms);
distributed::MultiVarMsg request, response;
auto& request_io_buffer = cntl.request_attachment();
::paddle::PsService_Stub stub(xpu_channels_[num].get());
@ -149,7 +166,7 @@ std::future<int32_t> HeterClient::SendCmd(
}
::paddle::PsService_Stub rpc_stub(xpu_channels_[i].get());
closure->cntl(i)->set_timeout_ms(
pserver_timeout_ms); // cmd msg don't limit timeout for save/load
FLAGS_pserver_timeout_ms); // cmd msg don't limit timeout for save/load
rpc_stub.service(closure->cntl(i), closure->request(i),
closure->response(i), closure);
}

@ -42,7 +42,7 @@ typedef std::function<void(void*)> HeterRpcCallbackFunc;
class OnHeterRpcDone : public google::protobuf::Closure {
public:
OnHeterRpcDone(HeterRpcCallbackFunc func) : handler_(func) {}
explicit OnHeterRpcDone(HeterRpcCallbackFunc func) : handler_(func) {}
virtual ~OnHeterRpcDone() {}
void Run() {
std::unique_ptr<OnHeterRpcDone> self_guard(this);
@ -79,7 +79,6 @@ class HeterClient {
if (NULL == s_instance_) {
is_initialized_ = true;
s_instance_.reset(new paddle::distributed::HeterClient());
std::vector<std::string> xpu_list = {endpoint};
s_instance_->SetXpuList(endpoint);
s_instance_->SetTrainerID(trainer_id);
s_instance_->CreateClient2XpuConnection();
@ -89,6 +88,8 @@ class HeterClient {
void Stop();
void FinalizeWorker();
void MainThread();
void RpcProfilerControl();
@ -97,6 +98,7 @@ class HeterClient {
const std::vector<std::string>& params);
std::future<int32_t> StartProfiler();
std::future<int32_t> StopProfiler();
std::future<int32_t> StopHeterWorker();
@ -104,17 +106,16 @@ class HeterClient {
void SetXpuList(const std::vector<std::string>& xpu_list) {
xpu_list_ = xpu_list;
};
}
void SetTrainerID(const int& trainer_id) { trainer_id_ = trainer_id; }
private:
static std::shared_ptr<HeterClient> s_instance_;
protected:
static bool is_initialized_;
std::unique_ptr<std::thread> main_thread_{nullptr};
std::vector<std::shared_ptr<brpc::Channel>> xpu_channels_;
DISABLE_COPY_AND_ASSIGN(HeterClient);
std::vector<std::string> xpu_list_;

@ -45,7 +45,11 @@ void HeterServer::StartHeterService() {
}
condition_ready_.notify_all();
server_.Join();
std::unique_lock<std::mutex> running_lock(mutex_);
cv_.wait(running_lock, [&] {
VLOG(1) << "Heter Server is Stop? " << stoped_;
return stoped_;
});
}
void HeterServer::SetEndPoint(std::string& endpoint) {
@ -83,6 +87,7 @@ int32_t HeterService::stop_heter_worker(const PsRequestMessage& request,
stop_cpu_worker_set_.insert(client_id);
if (stop_cpu_worker_set_.size() == fan_in_) {
is_exit_ = true;
VLOG(0) << "Stop heter Service done.";
}
return 0;
}

@ -20,6 +20,7 @@ limitations under the License. */
#include <random>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "brpc/channel.h"
#include "brpc/controller.h"
@ -34,6 +35,7 @@ limitations under the License. */
#include "paddle/fluid/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN
#include "paddle/fluid/platform/profiler.h"
DECLARE_double(eager_delete_tensor_gb);
namespace paddle {
namespace distributed {
@ -82,7 +84,7 @@ class HeterService : public ::paddle::PsService {
response->set_err_code(service_ret);
response->set_err_msg("server internal error");
}
};
}
void SendAndRecvVariable(::google::protobuf::RpcController* controller,
const MultiVarMsg* request, MultiVarMsg* response,
@ -134,6 +136,10 @@ class HeterServer {
virtual ~HeterServer() {}
void Stop() {
VLOG(0) << "HeterServer Stop()";
std::unique_lock<std::mutex> lock(mutex_);
stoped_ = true;
cv_.notify_all();
server_.Stop(1000);
server_.Join();
}
@ -162,6 +168,10 @@ class HeterServer {
private:
static std::shared_ptr<HeterServer> s_instance_;
mutable std::mutex mutex_;
std::condition_variable cv_;
std::condition_variable condition_ready_;
bool stoped_ = false;
std::string endpoint_;
protected:
@ -169,7 +179,7 @@ class HeterServer {
HeterService service_;
DISABLE_COPY_AND_ASSIGN(HeterServer);
std::mutex mutex_ready_;
std::condition_variable condition_ready_;
int ready_;
};
@ -215,6 +225,7 @@ class RequestSendAndRecvHandler final : public HeterRequestHandler {
int Handle(const MultiVarMsg* request, MultiVarMsg* response,
brpc::Controller* cntl) override {
platform::RecordEvent record_event("RequestSendAndRecvHandler->Handle");
FLAGS_eager_delete_tensor_gb = -1;
auto& local_scope = scope_->NewScope();
auto message_name = request->message_name();
auto& request_io_buffer = cntl->request_attachment();

@ -60,6 +60,8 @@ int32_t PSServer::configure(const PSParameter &config, PSEnvironment &env,
_environment = &env;
_shuffled_ins =
paddle::framework::MakeChannel<std::pair<uint64_t, std::string>>();
size_t shard_num = env.get_ps_servers().size();
const auto &downpour_param = _config.downpour_server_param();
uint32_t barrier_table = UINT32_MAX;
@ -72,6 +74,7 @@ int32_t PSServer::configure(const PSParameter &config, PSEnvironment &env,
"BarrierTable") {
barrier_table = downpour_param.downpour_table_param(i).table_id();
}
table->set_shard(_rank, shard_num);
table->initialize(downpour_param.downpour_table_param(i),
config.fs_client_param());
_table_map[downpour_param.downpour_table_param(i).table_id()].reset(table);

@ -12,8 +12,7 @@ cc_library(common_table SRCS common_sparse_table.cc common_dense_table.cc sparse
set_source_files_properties(tensor_accessor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(tensor_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_library(tensor_accessor SRCS tensor_accessor.cc DEPS ${TABLE_DEPS} eigen3 ps_framework_proto device_context)
cc_library(tensor_table SRCS tensor_table.cc DEPS ps_framework_proto proto_desc enforce executor tensor device_context simple_threadpool gflags glog )
set_source_files_properties(table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_library(table SRCS table.cc DEPS common_table tensor_table tensor_accessor ps_framework_proto string_helper device_context gflags glog boost)
cc_library(table SRCS table.cc DEPS common_table tensor_accessor ps_framework_proto string_helper device_context gflags glog boost)

@ -251,6 +251,30 @@ int32_t CommonSparseTable::initialize_value() {
auto shard = std::make_shared<ValueBlock>(common, &initializers_);
shard_values_.emplace_back(shard);
}
auto accessor = _config.accessor();
std::vector<uint64_t> feasigns;
for (size_t x = 0; x < accessor.fea_dim(); ++x) {
if (x % _shard_num == _shard_idx) {
feasigns.push_back(x);
}
}
VLOG(0) << "has " << feasigns.size() << " ids need to be pre inited";
auto buckets = bucket(feasigns.size(), 10);
for (int x = 0; x < 10; ++x) {
auto bucket_feasigns = buckets[x + 1] - buckets[x];
std::vector<uint64_t> ids(bucket_feasigns);
std::copy(feasigns.begin() + buckets[x], feasigns.begin() + buckets[x + 1],
ids.begin());
std::vector<float> pulls;
pulls.resize(bucket_feasigns * param_dim_);
pull_sparse(pulls.data(), ids.data(), bucket_feasigns);
}
return 0;
}

@ -34,6 +34,18 @@ class Initializer {
virtual float GetValue() = 0;
virtual void GetValue(std::vector<float> *values, int numel) {
for (int x = 0; x < numel; ++x) {
values->push_back(GetValue());
}
}
virtual void GetValue(float *value, int numel) {
for (int x = 0; x < numel; ++x) {
value[x] = GetValue();
}
}
virtual ~Initializer() {}
protected:
@ -54,6 +66,11 @@ class UniformInitializer : public Initializer {
}
float GetValue() override { return dist_(*random_engine_); }
void GetValue(float *value, int numel) {
for (int x = 0; x < numel; ++x) {
value[x] = dist_(*random_engine_);
}
}
private:
float min_;
@ -77,6 +94,11 @@ class GaussianInitializer : public Initializer {
}
float GetValue() override { return dist_(*random_engine_); }
void GetValue(float *value, int numel) {
for (int x = 0; x < numel; ++x) {
value[x] = dist_(*random_engine_);
}
}
private:
float std_;
@ -94,6 +116,7 @@ class FillConstantInitializer : public Initializer {
}
float GetValue() override { return value_; }
void GetValue(float *value, int numel) { std::fill_n(value, numel, value_); }
private:
float value_;

@ -68,7 +68,7 @@ inline bool entry<float>(const int count, const float threshold) {
struct VALUE {
explicit VALUE(const std::vector<std::string> &names)
: names_(names), count_(0), unseen_days_(0) {
: names_(names), count_(1), unseen_days_(0), seen_after_last_save_(true) {
values_.resize(names.size());
for (int i = 0; i < static_cast<int>(names.size()); i++) {
places[names[i]] = i;
@ -79,6 +79,14 @@ struct VALUE {
values_ = std::move(*values);
}
void set(const std::vector<Initializer *> &inits, std::vector<int> numels) {
for (int x = 0; x < numels.size(); ++x) {
auto &value = values_[x];
value.resize(numels[x]);
inits[x]->GetValue(value.data(), numels[x]);
}
}
void set(const std::vector<std::string> &names,
const std::vector<std::vector<float>> &values) {
for (int i = 0; i < static_cast<int>(names.size()); i++) {
@ -117,8 +125,8 @@ struct VALUE {
std::vector<std::string> names_;
int count_;
bool seen_after_last_save_;
int unseen_days_;
bool seen_after_last_save_;
bool is_entry_;
std::vector<std::vector<float>> values_;
std::unordered_map<std::string, int> places;
@ -139,15 +147,20 @@ class ValueBlock {
value_dims_.push_back(dim);
}
for (auto &name : value_names_) {
initializer_list_.emplace_back(initializers_->at(name));
}
// for Entry
{
// entry will add later
std::string entry_attr = "none";
if (entry_attr == "none") {
has_entry = false;
entry_func_ =
std::bind(entry<std::string>, std::placeholders::_1, "none");
} else {
has_entry = true;
auto slices = string::split_string<std::string>(entry_attr, "&");
if (slices[0] == "count_filter") {
int threshold = std::stoi(slices[1]);
@ -181,6 +194,22 @@ class ValueBlock {
values_[id] = value;
}
void Init(const uint64_t &id, const std::vector<Initializer *> &inits,
int count) {
if (Has(id)) {
PADDLE_THROW(platform::errors::AlreadyExists("id already exist, error"));
}
if (inits.size() != value_names_.size()) {
PADDLE_THROW(
platform::errors::AlreadyExists("values can not match, error"));
}
auto value = new VALUE(value_names_);
value->set(inits, value_dims_);
values_[id] = value;
}
std::vector<std::vector<float> *> Get(
const uint64_t &id, const std::vector<std::string> &value_names) {
auto ret_values = values_.at(id)->get(value_names);
@ -195,27 +224,12 @@ class ValueBlock {
void InitFromInitializer(const uint64_t &id,
const std::vector<std::string> &value_names) {
if (Has(id)) {
Update(id);
return;
}
auto rets = std::vector<std::vector<float>>();
rets.resize(value_names_.size());
for (int i = 0; i < static_cast<int>(value_names_.size()); i++) {
auto name = value_names_[i];
auto *init = initializers_->at(name);
auto dim = value_dims_[i];
rets[i].resize(dim);
for (int j = 0; j < static_cast<int>(dim); j++) {
rets[i][j] = init->GetValue();
if (has_entry) {
Update(id);
}
return;
}
Init(id, &rets, 0);
Update(id);
Init(id, initializer_list_, 1);
}
bool GetEntry(const uint64_t &id) {
@ -254,10 +268,12 @@ class ValueBlock {
std::unordered_map<uint64_t, VALUE *> values_;
private:
bool has_entry = false;
std::vector<std::string> value_names_;
std::vector<int> value_dims_;
std::function<bool(uint64_t)> entry_func_;
std::unordered_map<std::string, Initializer *> *initializers_;
std::vector<Initializer *> initializer_list_;
};
} // namespace distributed

@ -22,14 +22,12 @@
#include "paddle/fluid/distributed/table/common_sparse_table.h"
#include "paddle/fluid/distributed/table/sparse_geo_table.h"
#include "paddle/fluid/distributed/table/tensor_accessor.h"
#include "paddle/fluid/distributed/table/tensor_table.h"
namespace paddle {
namespace distributed {
REGISTER_CLASS(Table, CommonDenseTable);
REGISTER_CLASS(Table, CommonSparseTable);
REGISTER_CLASS(Table, DenseTensorTable);
REGISTER_CLASS(Table, SparseGeoTable);
REGISTER_CLASS(Table, BarrierTable);

@ -1,26 +1,12 @@
if(APPLE)
return()
endif()
set_source_files_properties(table_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(table_test SRCS table_test.cc DEPS common_table table tensor_accessor ps_framework_proto ${COMMON_DEPS})
set_source_files_properties(dense_table_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(dense_table_test SRCS dense_table_test.cc DEPS common_table table tensor_accessor ps_framework_proto ${COMMON_DEPS})
set_source_files_properties(sparse_table_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(sparse_table_test SRCS sparse_table_test.cc DEPS common_table table tensor_accessor ps_framework_proto ${COMMON_DEPS})
set_source_files_properties(geo_table_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(geo_table_test SRCS geo_table_test.cc DEPS common_table table tensor_accessor ps_framework_proto ${COMMON_DEPS})
set_source_files_properties(barrier_table_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(barrier_table_test SRCS barrier_table_test.cc DEPS common_table table tensor_accessor ps_framework_proto ${COMMON_DEPS})
# open it until CI support brpc
return()
set_source_files_properties(brpc_service_dense_sgd_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(brpc_service_dense_sgd_test SRCS brpc_service_dense_sgd_test.cc DEPS scope server client communicator ps_service boost table ps_framework_proto ${COMMON_DEPS})

@ -120,7 +120,7 @@ TEST(CommonDenseTable, Adam) {
beta2_pow[0] *= beta2;
}
for (int j = 0; j < fea_dim; j++) {
ASSERT_TRUE(abs(param[j] - pull_values[j]) < 1e-6);
ASSERT_TRUE(abs(param[j] - pull_values[j]) < 1e-5);
}
}

@ -62,7 +62,7 @@ TEST(SparseGeoTable, SSUM) {
std::vector<float> pull_values(init_values.size());
table->pull_sparse(pull_values.data(), init_keys.data(), init_keys.size());
for (size_t i = 0; i < init_keys.size() * emb_dim; i++) {
ASSERT_TRUE(abs(pull_values[i] - init_values[i]) < 1e-6);
ASSERT_TRUE(abs(pull_values[i] - init_values[i]) < 1e-5);
}
std::vector<std::vector<uint64_t>> trainer_keys;

@ -0,0 +1,71 @@
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <ThreadPool.h>
#include <unistd.h>
#include <string>
#include <thread> // NOLINT
#include "google/protobuf/text_format.h"
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/ps.pb.h"
#include "paddle/fluid/distributed/table/common_sparse_table.h"
#include "paddle/fluid/distributed/table/depends/large_scale_kv.h"
#include "paddle/fluid/distributed/table/table.h"
namespace paddle {
namespace distributed {
TEST(BENCHMARK, LargeScaleKV) {
int emb_dim = 10;
int trainers = 2;
float beta1 = 0.9;
float beta2 = 0.999;
float epsilon = 1.0e-8;
TableParameter table_config;
table_config.set_table_class("CommonSparseTable");
FsClientParameter fs_config;
Table *table = new CommonSparseTable();
TableAccessorParameter *accessor_config = table_config.mutable_accessor();
accessor_config->set_accessor_class("CommMergeAccessor");
CommonAccessorParameter *common_config = table_config.mutable_common();
common_config->set_name("adam");
common_config->set_table_name("adam_test_table");
common_config->set_trainer_num(trainers);
common_config->add_params("Param");
common_config->add_dims(emb_dim);
common_config->add_initializers("uniform_random&0&-1.0&1.0");
common_config->add_params("LearningRate");
common_config->add_dims(1);
common_config->add_initializers("fill_constant&1.0");
common_config->add_params("Moment1");
common_config->add_dims(emb_dim);
common_config->add_initializers("fill_constant&0.0");
common_config->add_params("Moment2");
common_config->add_dims(emb_dim);
common_config->add_initializers("fill_constant&0.0");
common_config->add_params("Beta1Pow");
common_config->add_dims(1);
common_config->add_initializers("fill_constant&1.0");
common_config->add_params("Beta2Pow");
common_config->add_dims(1);
common_config->add_initializers("fill_constant&1.0");
auto ret = table->initialize(table_config, fs_config);
ASSERT_EQ(ret, 0);
}
} // namespace distributed
} // namespace paddle

@ -216,18 +216,18 @@ if(WITH_DISTRIBUTE)
set_source_files_properties(executor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
else()
cc_library(executor SRCS executor.cc multi_trainer.cc pipeline_trainer.cc dataset_factory.cc
dist_multi_trainer.cc trainer_factory.cc trainer.cc data_feed_factory.cc
heterxpu_trainer.cc
data_feed.cc device_worker.cc hogwild_worker.cc hetercpu_worker.cc ps_gpu_worker.cc
heterbox_worker.cc heterbox_trainer.cc ps_gpu_trainer.cc downpour_worker.cc downpour_worker_opt.cc
pull_dense_worker.cc section_worker.cc device_worker_factory.cc data_set.cc DEPS op_registry
device_context scope framework_proto trainer_desc_proto glog fs shell
fleet_wrapper heter_wrapper ps_gpu_wrapper box_wrapper lodtensor_printer
lod_rank_table feed_fetch_method sendrecvop_rpc communicator collective_helper ${GLOB_DISTRIBUTE_DEPS}
graph_to_program_pass variable_helper data_feed_proto timer monitor
heter_service_proto)
dist_multi_trainer.cc trainer_factory.cc trainer.cc data_feed_factory.cc
heterxpu_trainer.cc
data_feed.cc device_worker.cc hogwild_worker.cc hetercpu_worker.cc
heterbox_worker.cc heterbox_trainer.cc downpour_worker.cc downpour_worker_opt.cc
pull_dense_worker.cc section_worker.cc device_worker_factory.cc data_set.cc DEPS op_registry
device_context scope framework_proto data_feed_proto heter_service_proto trainer_desc_proto glog
lod_rank_table fs shell fleet_wrapper heter_wrapper box_wrapper lodtensor_printer feed_fetch_method
graph_to_program_pass variable_helper timer monitor heter_service_proto fleet)
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
set_source_files_properties(executor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(multi_trainer.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(hogwild_worker.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
endif()
elseif(WITH_PSLIB)
cc_library(executor SRCS executor.cc multi_trainer.cc pipeline_trainer.cc dataset_factory.cc
@ -239,11 +239,7 @@ elseif(WITH_PSLIB)
device_context scope framework_proto data_feed_proto heter_service_proto trainer_desc_proto glog
lod_rank_table fs shell fleet_wrapper heter_wrapper ps_gpu_wrapper box_wrapper lodtensor_printer feed_fetch_method
graph_to_program_pass variable_helper timer monitor pslib_brpc )
# TODO: Fix these unittest failed on Windows
# This unittest will always failed, now no CI will run this unittest
if(NOT WITH_MUSL AND NOT WIN32)
cc_test(test_naive_executor SRCS naive_executor_test.cc DEPS naive_executor elementwise_add_op)
endif()
else()
cc_library(executor SRCS executor.cc multi_trainer.cc pipeline_trainer.cc dataset_factory.cc
dist_multi_trainer.cc trainer_factory.cc trainer.cc data_feed_factory.cc
@ -254,11 +250,6 @@ else()
device_context scope framework_proto data_feed_proto heter_service_proto trainer_desc_proto glog
lod_rank_table fs shell fleet_wrapper heter_wrapper ps_gpu_wrapper box_wrapper lodtensor_printer feed_fetch_method
graph_to_program_pass variable_helper timer monitor)
# TODO: Fix these unittest failed on Windows
# This unittest will always failed, now no CI will run this unittest
if(NOT WITH_MUSL AND NOT WIN32)
cc_test(test_naive_executor SRCS naive_executor_test.cc DEPS naive_executor elementwise_add_op)
endif()
endif()
target_link_libraries(executor while_op_helper executor_gc_helper recurrent_op_helper conditional_block_op_helper)

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save