* add service, remove ut on mac

* fix heter_profiler & add heter stop method

* fix code style
revert-31562-mean
tangwei12 5 years ago committed by GitHub
parent c0163837a5
commit 0034273b7e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -14,3 +14,17 @@ endif()
add_subdirectory(table)
add_subdirectory(test)
# open it until CI support brpc
return()
add_subdirectory(service)
get_property(RPC_DEPS GLOBAL PROPERTY RPC_DEPS)
set_source_files_properties(fleet.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_library(fleet
SRCS fleet.cc
DEPS framework_proto ps_framework_proto ps_service variable_helper scope op_registry fs shell ${RPC_DEPS})
target_link_libraries(fleet z)

File diff suppressed because it is too large Load Diff

@ -0,0 +1,246 @@
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <atomic>
#include <ctime>
#include <map>
#include <memory>
#include <random>
#include <string>
#include <unordered_map>
#include <vector>
#include <ThreadPool.h>
#include "paddle/fluid/distributed/communicator_common.h"
#include "paddle/fluid/distributed/service/service.h"
#include "paddle/fluid/framework/archive.h"
#include "paddle/fluid/framework/io/fs.h"
#include "paddle/fluid/framework/io/shell.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN
namespace paddle {
namespace distributed {
using framework::LoDTensor;
using framework::Scope;
using framework::SelectedRows;
using framework::Variable;
using RpcCtxMap = std::unordered_map<std::string, CommContext>;
class FleetWrapper {
public:
virtual ~FleetWrapper() {}
FleetWrapper() {
scale_sparse_gradient_with_batch_size_ = true;
// trainer sleep some time for pserver core dump
sleep_seconds_before_fail_exit_ = 300;
// pserver request server timeout ms
client2client_request_timeout_ms_ = 500000;
// pserver connect server timeout_ms
client2client_connect_timeout_ms_ = 10000;
// pserver request max retry
client2client_max_retry_ = 3;
}
// set client to client communication config
void SetClient2ClientConfig(int request_timeout_ms, int connect_timeout_ms,
int max_retry);
// Pull sparse variables from server in sync mode
// Param<in>: scope, table_id, var_names, fea_keys, fea_dim, var_emb_names
// Param<out>: fea_values
void PullSparseVarsSync(const Scope& scope, const uint64_t table_id,
const std::vector<std::string>& var_names,
std::vector<uint64_t>* fea_keys,
std::vector<std::vector<float>>* fea_values,
int fea_dim,
const std::vector<std::string>& var_emb_names);
// Pull sparse variables from server in async mode
// Param<in>: scope, table_id, var_names, fea_keys, fea_dim
// Param<out>: fea_values std::future
std::future<int32_t> PullSparseVarsAsync(
const Scope& scope, const uint64_t table_id,
const std::vector<std::string>& var_names,
std::vector<uint64_t>* fea_keys,
std::vector<std::vector<float>>* fea_values, int fea_dim);
// Pull sparse variables from server in sync mode
// pull immediately to tensors
void PullSparseToTensorSync(const uint64_t table_id, int fea_dim,
uint64_t padding_id, platform::Place place,
std::vector<const LoDTensor*>* inputs, // NOLINT
std::vector<LoDTensor*>* outputs); // NOLINT
// pull dense variables from server in sync mod
// Param<in>: scope, table_id, var_names
// Param<out>: void
void PullDenseVarsSync(const Scope& scope, const uint64_t table_id,
const std::vector<std::string>& var_names);
// pull dense variables from server in async mod
// Param<in>: scope, table_id, var_names
// Param<out>: pull_dense_status
void PullDenseVarsAsync(const Scope& scope, const uint64_t table_id,
const std::vector<std::string>& var_names,
std::vector<std::future<int32_t>>* pull_dense_status,
bool in_cpu);
// push dense parameters(not gradients) to server in sync mode
void PushDenseParamSync(const Scope& scope, const uint64_t table_id,
const std::vector<std::string>& var_names);
void PushDenseVarsAsync(const Scope& scope, const uint64_t table_id,
const std::vector<std::string>& var_names,
std::vector<std::future<int32_t>>* push_sparse_status,
float scale_datanorm, int batch_size);
// push dense variables to server in sync mode
void PushDenseVarsSync(Scope* scope, const uint64_t table_id,
const std::vector<std::string>& var_names);
void PushSparseVarsAsync(
const Scope& scope, const uint64_t table_id, const std::string& grad,
std::vector<std::future<int32_t>>* push_sparse_status);
// This is specially designed for click/show stats in server
// Param<in>: scope, table_id, fea_keys, fea_labels, sparse_key_names,
// sparse_grad_names, batch_size, use_cvm, dump_slot
// Param<out>: push_values, push_sparse_status
void PushSparseVarsWithLabelAsync(
const Scope& scope, const uint64_t table_id,
const std::vector<uint64_t>& fea_keys,
const std::vector<float>& fea_labels,
const std::vector<std::string>& sparse_key_names,
const std::vector<std::string>& sparse_grad_names, const int emb_dim,
std::vector<std::vector<float>>* push_values,
std::vector<std::future<int32_t>>* push_sparse_status,
const int batch_size, const bool use_cvm, const bool dump_slot,
std::vector<uint64_t>* sparse_push_keys, const bool no_cvm);
// Push sparse variables to server in async mode
void PushSparseFromTensorWithLabelAsync(
const Scope& scope, const uint64_t table_id, int fea_dim,
uint64_t padding_id, bool scale_sparse, const std::string& accesor,
const std::string& click_name, platform::Place place,
const std::vector<std::string>& input_names,
std::vector<const LoDTensor*>* inputs, // NOLINT
std::vector<const LoDTensor*>* outputs); // NOLINT
// Push sparse variables to server in Async mode
// Param<In>: scope, table_id, fea_keys, sparse_grad_names
// Param<Out>: push_values, push_sparse_status
// init server
void LoadSparseOnServer(const std::string& path, const std::string& meta,
uint32_t table_id);
// init server
// void InitServer(const std::string& dist_desc,
// const std::vector<uint64_t>& host_sign_list, int index);
void InitServer(const std::string& dist_desc,
const std::vector<std::string>& host_sign_list, int index);
// init trainer
void InitWorker(const std::string& dist_desc,
const std::vector<std::string>& host_sign_list, Scope* scope,
const RpcCtxMap& send_ctx,
const std::unordered_map<uint64_t, std::vector<std::string>>&
dense_varnames,
const std::map<std::string, std::string>& envs, int node_num,
int index);
// stop server
void StopServer();
// finalize worker to make worker can be stop
void FinalizeWorker();
// run server with ip port
uint64_t RunServer(const std::string& ip, uint32_t port);
// get client info
std::vector<uint64_t> GetClientsInfo();
// create client to client connection
void CreateClient2ClientConnection();
// flush all push requests
void ClientFlush();
// barrier with barrier table
void BarrierWithTable(uint32_t barrier_type);
void PrintTableStat(const uint64_t table_id);
// mode = 0, load all feature
// mode = 1, load delta feature, which means load diff
void LoadModel(const std::string& path, const int mode);
// mode = 0, load all feature
// mode = 1, load delta feature, which means load diff
void LoadModelOneTable(const uint64_t table_id, const std::string& path,
const int mode);
// mode = 0, save all feature
// mode = 1, save delta feature, which means save diff
void SaveModel(const std::string& path, const int mode);
// mode = 0, save all feature
// mode = 1, save delta feature, which means save diff
void SaveModelOneTable(const uint64_t table_id, const std::string& path,
const int mode);
// clear all models, release their memory
void ClearModel();
// clear one table
void ClearOneTable(const uint64_t table_id);
// shrink sparse table
void ShrinkSparseTable(int table_id);
// shrink dense table
void ShrinkDenseTable(int table_id, Scope* scope,
std::vector<std::string> var_list, float decay,
int emb_dim);
typedef std::function<int32_t(int, int, const std::string&)> MsgHandlerFunc;
// register client to client communication
int RegisterClientToClientMsgHandler(int msg_type, MsgHandlerFunc handler);
// send client to client message
std::future<int32_t> SendClientToClientMsg(int msg_type, int to_client_id,
const std::string& msg);
// FleetWrapper singleton
static std::shared_ptr<FleetWrapper> GetInstance() {
if (NULL == s_instance_) {
s_instance_.reset(new paddle::distributed::FleetWrapper());
}
return s_instance_;
}
// this performs better than rand_r, especially large data
std::default_random_engine& LocalRandomEngine();
static std::shared_ptr<paddle::distributed::PSCore> pserver_ptr_;
private:
static std::shared_ptr<FleetWrapper> s_instance_;
size_t GetAbsoluteSum(size_t start, size_t end, size_t level,
const framework::LoD& lod);
protected:
static bool is_initialized_;
std::map<uint64_t, std::vector<paddle::distributed::Region>> _regions;
bool scale_sparse_gradient_with_batch_size_;
int32_t sleep_seconds_before_fail_exit_;
int client2client_request_timeout_ms_;
int client2client_connect_timeout_ms_;
int client2client_max_retry_;
DISABLE_COPY_AND_ASSIGN(FleetWrapper);
};
} // end namespace distributed
} // end namespace paddle

@ -0,0 +1,40 @@
set(BRPC_SRCS ps_client.cc server.cc)
set_source_files_properties(${BRPC_SRCS})
set(BRPC_DEPS brpc ssl crypto protobuf gflags glog zlib leveldb snappy gflags glog)
brpc_library(sendrecv_rpc SRCS
${BRPC_SRCS}
PROTO sendrecv.proto
DEPS ${BRPC_DEPS} )
set_property(GLOBAL PROPERTY RPC_DEPS sendrecv_rpc ${BRPC_DEPS} string_helper)
get_property(RPC_DEPS GLOBAL PROPERTY RPC_DEPS)
set_source_files_properties(communicator.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(service.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(brpc_ps_server.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(brpc_ps_client.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(brpc_utils.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(heter_server.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(heter_client.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(client.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(ps_client.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(server.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_library(downpour_server SRCS brpc_ps_server.cc DEPS boost eigen3 table ${RPC_DEPS})
cc_library(downpour_client SRCS brpc_ps_client.cc DEPS boost eigen3 table ${RPC_DEPS})
cc_library(client SRCS ps_client.cc DEPS downpour_client boost ${RPC_DEPS})
cc_library(server SRCS server.cc DEPS downpour_server boost ${RPC_DEPS})
cc_library(communicator SRCS communicator.cc DEPS scope client boost table math_function selected_rows_functor ${RPC_DEPS})
cc_library(ps_service SRCS service.cc DEPS communicator client server boost ${RPC_DEPS})
cc_library(brpc_utils SRCS brpc_utils.cc DEPS ${COMMON_DEPS} ${RPC_DEPS})
cc_library(heter_server SRCS heter_server.cc DEPS brpc_utils ${COMMON_DEPS} ${RPC_DEPS})
cc_library(heter_client SRCS heter_client.cc DEPS brpc_utils ${COMMON_DEPS} ${RPC_DEPS})

File diff suppressed because it is too large Load Diff

@ -0,0 +1,212 @@
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include <vector>
#include "brpc/channel.h"
#include "brpc/controller.h"
#include "brpc/server.h"
#include "paddle/fluid/distributed/service/ps_client.h"
namespace paddle {
namespace distributed {
class DownpourPsClientService : public PsService {
public:
DownpourPsClientService() {}
virtual ~DownpourPsClientService() {}
virtual int32_t configure(PSClient *client, size_t rank_id) {
_client = client;
_rank = rank_id;
return 0;
}
virtual void service(::google::protobuf::RpcController *controller,
const ::paddle::PsRequestMessage *request,
::paddle::PsResponseMessage *response,
::google::protobuf::Closure *done) override;
protected:
size_t _rank;
PSClient *_client;
};
class DownpourBrpcClosure : public PSClientClosure {
public:
DownpourBrpcClosure(size_t num, PSClientCallBack callback)
: PSClientClosure(callback) {
_waiting_num = num;
_cntls.resize(num);
_requests.resize(num);
_responses.resize(num);
for (size_t i = 0; i < num; ++i) {
_cntls[i].reset(new brpc::Controller());
}
}
virtual ~DownpourBrpcClosure() {}
virtual void Run() override {
if (_waiting_num.fetch_sub(1) == 1) {
_callback(this);
delete this;
}
}
PsRequestMessage *request(size_t i) { return &_requests[i]; }
PsResponseMessage *response(size_t i) { return &_responses[i]; }
brpc::Controller *cntl(size_t i) { return _cntls[i].get(); }
int check_response(size_t request_idx, int cmd_id);
int check_save_response(size_t request_idx, int cmd_id);
std::string get_response(size_t request_idx, int cmd_id);
private:
std::atomic<int32_t> _waiting_num;
std::vector<PsRequestMessage> _requests;
std::vector<PsResponseMessage> _responses;
std::vector<std::shared_ptr<brpc::Controller>> _cntls;
};
template <class T>
struct array_deleter {
void operator()(T *&x) const { delete[] x; }
};
class BrpcPsClient : public PSClient {
public:
BrpcPsClient() {}
virtual ~BrpcPsClient() {
// _running = false;
// try {
// _async_push_dense_thread.join();
// _async_push_sparse_thread.join();
//} catch (...) {
//}
}
virtual int32_t create_client2client_connection(
int pserver_timeout_ms, int pserver_connect_timeout_ms, int max_retry);
virtual std::future<int32_t> shrink(uint32_t table_id) override;
virtual std::future<int32_t> load(const std::string &epoch,
const std::string &mode) override;
virtual std::future<int32_t> load(uint32_t table_id, const std::string &epoch,
const std::string &mode) override;
virtual std::future<int32_t> save(const std::string &epoch,
const std::string &mode) override;
virtual std::future<int32_t> save(uint32_t table_id, const std::string &epoch,
const std::string &mode) override;
virtual std::future<int32_t> clear() override;
virtual std::future<int32_t> clear(uint32_t table_id) override;
virtual std::future<int32_t> stop_server() override;
virtual std::future<int32_t> start_profiler() override;
virtual std::future<int32_t> stop_profiler() override;
virtual void finalize_worker() override;
virtual std::future<int32_t> pull_dense(Region *regions, size_t region_num,
size_t table_id);
virtual std::future<int32_t> push_dense_param(const Region *regions,
size_t region_num,
size_t table_id);
virtual std::future<int32_t> pull_sparse(float **select_values,
size_t table_id,
const uint64_t *keys, size_t num);
virtual std::future<int32_t> print_table_stat(uint32_t table_id);
virtual std::future<int32_t> barrier(size_t table_id, uint32_t barrier_type);
virtual std::future<int32_t> pull_geo_param(size_t table_id,
std::vector<float> *values,
std::vector<uint64_t> *keys,
int pserver_idx);
virtual std::future<int32_t> flush();
virtual std::future<int32_t> send_client2client_msg(
int msg_type, int to_client_id, const std::string &msg) override;
private:
virtual int32_t initialize() override;
inline uint32_t dense_dim_per_shard(uint32_t dense_dim_total,
uint32_t shard_num) {
return dense_dim_total / shard_num + 1;
}
std::future<int32_t> send_cmd(uint32_t table_id, int cmd_id,
const std::vector<std::string> &param);
std::future<int32_t> send_save_cmd(uint32_t table_id, int cmd_id,
const std::vector<std::string> &param);
inline brpc::Channel *get_sparse_channel(size_t server_id) {
return _server_channels[server_id][0].get();
}
inline brpc::Channel *get_dense_channel(size_t server_id) {
return _server_channels[server_id][1].get();
}
inline brpc::Channel *get_cmd_channel(size_t server_id) {
return _server_channels[server_id][2].get();
}
bool _running = false;
bool _flushing = false;
std::atomic<uint32_t> _async_call_num; //异步请求计数
std::vector<std::shared_ptr<brpc::Channel>>
_client_channels; // client2client
std::vector<std::array<std::shared_ptr<brpc::Channel>, 3>>
_server_channels; // client2server
virtual std::future<int32_t> push_dense_raw_gradient(
int table_id, float *total_send_data, size_t total_send_data_size,
void *done) override;
virtual std::future<int32_t> push_sparse_raw_gradient(
size_t table_id, const uint64_t *keys, const float **update_values,
size_t num, void *done) override;
virtual std::future<int32_t> push_sparse_raw_gradient_partial(
size_t table_id, const uint64_t *keys, const float **update_values,
uint32_t num, void *done, int pserver_idx) override;
virtual std::future<int32_t> push_sparse_param(size_t table_id,
const uint64_t *keys,
const float **update_values,
size_t num,
void *done) override;
virtual size_t get_server_nums() { return _server_channels.size(); }
private:
int32_t start_client_service();
float _mae = 0;
float _mse = 0;
uint16_t _push_times = 0;
brpc::Server _server;
DownpourPsClientService _service;
std::atomic_uint grad_num_{0};
};
} // namespace distributed
} // namespace paddle

File diff suppressed because it is too large Load Diff

@ -0,0 +1,153 @@
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "brpc/channel.h"
#include "brpc/controller.h"
#include "brpc/server.h"
#include <memory>
#include <vector>
#include "paddle/fluid/distributed/service/server.h"
namespace paddle {
namespace distributed {
class BrpcPsServer : public PSServer {
public:
BrpcPsServer() {}
virtual ~BrpcPsServer() {}
virtual uint64_t start(const std::string &ip, uint32_t port);
virtual int32_t stop() {
std::unique_lock<std::mutex> lock(mutex_);
stoped_ = true;
cv_.notify_all();
_server.Stop(1000);
_server.Join();
return 0;
}
virtual int32_t port();
private:
virtual int32_t initialize();
mutable std::mutex mutex_;
std::condition_variable cv_;
bool stoped_ = false;
brpc::Server _server;
std::shared_ptr<PsBaseService> _service;
std::vector<std::shared_ptr<brpc::Channel>> _pserver_channels;
};
class PsService;
typedef int32_t (PsService::*serviceHandlerFunc)(
Table *table, const PsRequestMessage &request, PsResponseMessage &response,
brpc::Controller *cntl);
class PsService : public PsBaseService {
public:
virtual int32_t initialize() override;
virtual void service(::google::protobuf::RpcController *controller,
const ::paddle::PsRequestMessage *request,
::paddle::PsResponseMessage *response,
::google::protobuf::Closure *done) override;
private:
int32_t initialize_shard_info();
int32_t pull_dense(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t push_dense(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t push_dense_param(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t push_sparse_param(Table *table, const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl);
int32_t pull_sparse(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t pull_geo_param(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t barrier(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t push_sparse(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t load_one_table(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t load_all_table(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t save_one_table(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t save_all_table(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t shrink_table(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t clear_one_table(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t clear_all_table(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t stop_server(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t start_profiler(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t stop_profiler(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t print_table_stat(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
bool _is_initialize_shard_info;
std::mutex _initialize_shard_mutex;
std::unordered_map<int32_t, serviceHandlerFunc> _service_handler_map;
std::unordered_map<int32_t, serviceHandlerFunc> _msg_handler_map;
std::vector<float> _ori_values;
};
class DownpourPServerBrpcClosure : public PServerClosure {
public:
DownpourPServerBrpcClosure(size_t num, PServerCallBack callback)
: PServerClosure(callback) {
_waiting_num = num;
_cntls.resize(num);
_requests.resize(num);
_responses.resize(num);
for (size_t i = 0; i < num; ++i) {
_cntls[i].reset(new brpc::Controller());
}
}
virtual ~DownpourPServerBrpcClosure() {}
virtual void Run() override {
if (_waiting_num.fetch_sub(1) == 1) {
_callback(this);
delete this;
}
}
PsRequestMessage *request(size_t i) { return &_requests[i]; }
PsResponseMessage *response(size_t i) { return &_responses[i]; }
brpc::Controller *cntl(size_t i) { return _cntls[i].get(); }
int check_response(size_t request_idx, int cmd_id) { return 1; }
int check_save_response(size_t request_idx, int cmd_id) { return 1; }
private:
std::atomic<int32_t> _waiting_num;
std::vector<PsRequestMessage> _requests;
std::vector<PsResponseMessage> _responses;
std::vector<std::shared_ptr<brpc::Controller>> _cntls;
};
} // namespace distributed
} // namespace paddle

File diff suppressed because it is too large Load Diff

@ -0,0 +1,86 @@
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <iostream>
#include <string>
#include <vector>
#include "brpc/channel.h"
#include "paddle/fluid/distributed/service/sendrecv.pb.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/platform/port.h"
namespace grpc {
class ByteBuffer;
} // namespace grpc
namespace paddle {
namespace framework {
class Scope;
class Variable;
} // namespace framework
namespace platform {
class DeviceContext;
} // namespace platform
} // namespace paddle
namespace paddle {
namespace distributed {
using MultiVarMsg = ::paddle::MultiVariableMessage;
using VarMsg = ::paddle::VariableMessage;
void SerializeToMultiVarMsgAndIOBuf(
const std::string& message_name,
const std::vector<std::string>& send_var_name_val,
const std::vector<std::string>& recv_var_name_val,
const platform::DeviceContext& ctx, const framework::Scope* scope,
MultiVarMsg* var_msg, butil::IOBuf* iobuf);
void SerializeLodTensor(framework::Variable* var,
const platform::DeviceContext& ctx, VarMsg* var_msg,
butil::IOBuf* iobuf);
void SerializeSelectedRows(framework::Variable* var,
const platform::DeviceContext& ctx, VarMsg* request,
butil::IOBuf* iobuf);
// Deserialize for Server
void DeserializeFromMultiVarMsgAndIOBuf(const MultiVarMsg& multi_msg,
const butil::IOBuf* iobuf,
const platform::DeviceContext& ctx,
framework::Scope* scope);
// Deserialize for Client
void DeserializeFromMultiVarMsgAndIOBuf(const MultiVarMsg& multi_msg,
const butil::IOBuf* iobuf,
const platform::DeviceContext& ctx,
const framework::Scope* scope);
void DeserializeLodTensor(framework::Variable* var, const VarMsg& msg,
butil::IOBufBytesIterator& iobuf,
const platform::DeviceContext& ctx);
void DeserializeSelectedRows(framework::Variable* var, const VarMsg& msg,
butil::IOBufBytesIterator& iobuf,
const platform::DeviceContext& ctx);
} // namespace distributed
} // namespace paddle

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

@ -0,0 +1,19 @@
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/service/env.h"
namespace paddle {
namespace distributed {} // namespace distributed
} // namespace paddle

File diff suppressed because it is too large Load Diff

@ -0,0 +1,168 @@
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/service/heter_client.h"
#include <algorithm>
#include <utility>
#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/framework/data_feed.h"
#include "paddle/fluid/framework/device_worker.h"
#include "paddle/fluid/framework/io/fs.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/platform/profiler.h"
#include "paddle/fluid/platform/timer.h"
DECLARE_int32(rpc_deadline);
namespace paddle {
namespace distributed {
DEFINE_int32(pserver_timeout_ms, 10800000, "pserver request server timeout_ms");
std::shared_ptr<HeterClient> HeterClient::s_instance_ = NULL;
bool HeterClient::is_initialized_ = false;
void HeterClient::MainThread() {
while (running_) {
RpcProfilerControl();
}
}
void HeterClient::Stop() {
running_ = false;
if (!is_initialized_) {
VLOG(0) << "HeterClient is not inited, do nothing";
} else {
if (main_thread_) {
auto status = StopHeterWorker();
status.wait();
main_thread_->join();
main_thread_.reset(nullptr);
}
VLOG(1) << "HeterClient Stop Done";
}
}
void HeterClient::RpcProfilerControl() {
if (trainer_id_ == 0) {
if (!do_server_profiler_ && platform::IsProfileEnabled()) {
// send profiler start flag
do_server_profiler_ = true;
auto start_status = StartProfiler();
start_status.wait();
} else if (do_server_profiler_ && !platform::IsProfileEnabled()) {
// send profiler end flag
auto stop_status = StopProfiler();
stop_status.wait();
do_server_profiler_ = false;
}
}
}
void HeterClient::CreateClient2XpuConnection() {
brpc::ChannelOptions options;
options.protocol = "baidu_std";
options.connection_type = "single";
options.timeout_ms = pserver_timeout_ms;
xpu_channels_.resize(xpu_list_.size());
for (size_t i = 0; i < xpu_list_.size(); ++i) {
xpu_channels_[i].reset(new brpc::Channel());
if (xpu_channels_[i]->Init(xpu_list_[i].c_str(), "", &options) != 0) {
VLOG(0) << "HeterServer channel init fail";
}
}
}
void HeterClient::SendAndRecvAsync(
const std::vector<std::string>& ep, const platform::DeviceContext& ctx,
const framework::Scope& scope, const std::string& message_name,
const std::vector<std::string>& send_var_name,
const std::vector<std::string>& recv_var_name) {
platform::RecordEvent record_event("HeterClient->SendAndRecvAsync");
const platform::DeviceContext* p_ctx = &ctx;
const framework::Scope* p_scope = &scope;
const std::string message_name_val = message_name;
const std::vector<std::string> send_var_name_val = send_var_name;
const std::vector<std::string> recv_var_name_val = recv_var_name;
VLOG(3) << "GRPCClient::SendAndRecv Begin, message_name: "
<< message_name_val;
// Todo: get correct channel
int num = trainer_id_ % xpu_channels_.size();
brpc::Controller cntl;
cntl.set_timeout_ms(pserver_timeout_ms);
distributed::MultiVarMsg request, response;
auto& request_io_buffer = cntl.request_attachment();
::paddle::PsService_Stub stub(xpu_channels_[num].get());
distributed::SerializeToMultiVarMsgAndIOBuf(
message_name_val, send_var_name_val, recv_var_name_val, *p_ctx, p_scope,
&request, &request_io_buffer);
stub.SendAndRecvVariable(&cntl, &request, &response, NULL);
PADDLE_ENFORCE_NE(
cntl.Failed(), true,
platform::errors::Unimplemented(
"HeterClient::SendAndRecv meets brpc error, error message is %s",
cntl.ErrorText()));
VLOG(4) << "call heter_worker success";
auto& response_io_buffer = cntl.response_attachment();
distributed::DeserializeFromMultiVarMsgAndIOBuf(response, &response_io_buffer,
ctx, p_scope);
}
std::future<int32_t> HeterClient::SendCmd(
uint32_t table_id, int cmd_id, const std::vector<std::string>& params) {
size_t request_call_num = xpu_channels_.size();
paddle::distributed::DownpourBrpcClosure* closure =
new paddle::distributed::DownpourBrpcClosure(
request_call_num, [request_call_num, cmd_id](void* done) {
int ret = 0;
auto* closure = (paddle::distributed::DownpourBrpcClosure*)done;
for (size_t i = 0; i < request_call_num; ++i) {
if (closure->check_response(i, cmd_id) != 0) {
ret = -1;
break;
}
}
closure->set_promise_value(ret);
});
auto promise = std::make_shared<std::promise<int32_t>>();
closure->add_promise(promise);
std::future<int> fut = promise->get_future();
for (size_t i = 0; i < request_call_num; ++i) {
closure->request(i)->set_cmd_id(cmd_id);
closure->request(i)->set_table_id(table_id);
closure->request(i)->set_client_id(trainer_id_);
for (const auto& param : params) {
closure->request(i)->add_params(param);
}
::paddle::PsService_Stub rpc_stub(xpu_channels_[i].get());
closure->cntl(i)->set_timeout_ms(
pserver_timeout_ms); // cmd msg don't limit timeout for save/load
rpc_stub.service(closure->cntl(i), closure->request(i),
closure->response(i), closure);
}
return fut;
}
std::future<int32_t> HeterClient::StartProfiler() {
return SendCmd(-1, PS_START_PROFILER, {});
}
std::future<int32_t> HeterClient::StopProfiler() {
return SendCmd(-1, PS_STOP_PROFILER, {});
}
} // end namespace distributed
} // end namespace paddle

@ -0,0 +1,127 @@
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <atomic>
#include <ctime>
#include <map>
#include <memory>
#include <random>
#include <string>
#include <unordered_map>
#include <vector>
#include "brpc/channel.h"
#include "brpc/controller.h"
#include "brpc/server.h"
#include "paddle/fluid/distributed/service/brpc_ps_client.h"
#include "paddle/fluid/distributed/service/brpc_utils.h"
#include "paddle/fluid/distributed/service/sendrecv.pb.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN
namespace paddle {
namespace distributed {
using MultiVarMsg = ::paddle::MultiVariableMessage;
using VarMsg = ::paddle::VariableMessage;
typedef std::function<void(void*)> HeterRpcCallbackFunc;
class OnHeterRpcDone : public google::protobuf::Closure {
public:
OnHeterRpcDone(HeterRpcCallbackFunc func) : handler_(func) {}
virtual ~OnHeterRpcDone() {}
void Run() {
std::unique_ptr<OnHeterRpcDone> self_guard(this);
handler_(this);
}
HeterRpcCallbackFunc handler_;
MultiVariableMessage response;
brpc::Controller cntl;
};
class HeterClient {
public:
virtual ~HeterClient() {}
HeterClient() {
running_ = true;
main_thread_.reset(
new std::thread(std::bind(&HeterClient::MainThread, this)));
}
void CreateClient2XpuConnection();
void SendAndRecvAsync(const std::vector<std::string>& ep,
const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& message_name,
const std::vector<std::string>& send_var_name,
const std::vector<std::string>& recv_var_name);
// HeterClient singleton
static std::shared_ptr<HeterClient> GetInstance(
const std::vector<std::string>& endpoint, const int& trainer_id) {
if (NULL == s_instance_) {
is_initialized_ = true;
s_instance_.reset(new paddle::distributed::HeterClient());
std::vector<std::string> xpu_list = {endpoint};
s_instance_->SetXpuList(endpoint);
s_instance_->SetTrainerID(trainer_id);
s_instance_->CreateClient2XpuConnection();
}
return s_instance_;
}
void Stop();
void MainThread();
void RpcProfilerControl();
std::future<int32_t> SendCmd(uint32_t table_id, int cmd_id,
const std::vector<std::string>& params);
std::future<int32_t> StartProfiler();
std::future<int32_t> StopProfiler();
std::future<int32_t> StopHeterWorker();
std::vector<std::string>& GetXpuList() { return xpu_list_; }
void SetXpuList(const std::vector<std::string>& xpu_list) {
xpu_list_ = xpu_list;
};
void SetTrainerID(const int& trainer_id) { trainer_id_ = trainer_id; }
private:
static std::shared_ptr<HeterClient> s_instance_;
protected:
static bool is_initialized_;
std::unique_ptr<std::thread> main_thread_{nullptr};
std::vector<std::shared_ptr<brpc::Channel>> xpu_channels_;
DISABLE_COPY_AND_ASSIGN(HeterClient);
std::vector<std::string> xpu_list_;
bool running_ = false;
int trainer_id_;
bool do_server_profiler_ = false;
};
} // end namespace distributed
} // end namespace paddle

@ -0,0 +1,91 @@
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/service/heter_server.h"
#include <algorithm>
#include <utility>
#include "paddle/fluid/framework/fleet/heter_wrapper.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/platform/timer.h"
namespace paddle {
namespace distributed {
std::shared_ptr<HeterServer> HeterServer::s_instance_ = NULL;
void HeterServer::RegisterServiceHandler(std::string message_name,
HeterServiceHandler func) {
service_.RegisterServiceHandler(message_name, func);
}
void HeterServer::StartHeterService() {
server_.AddService(&service_, brpc::SERVER_DOESNT_OWN_SERVICE);
brpc::ServerOptions options;
if (server_.Start(endpoint_.c_str(), &options) != 0) {
VLOG(0) << "heter server start fail";
} else {
VLOG(0) << "heter server start success! listen on " << endpoint_;
}
{
std::lock_guard<std::mutex> lock(this->mutex_ready_);
ready_ = 1;
}
condition_ready_.notify_all();
server_.Join();
}
void HeterServer::SetEndPoint(std::string& endpoint) {
endpoint_ = endpoint;
service_.SetEndpoint(endpoint);
}
void HeterServer::SetFanin(int& fan_in) { service_.SetFanin(fan_in); }
void HeterServer::WaitServerReady() {
std::unique_lock<std::mutex> lock(this->mutex_ready_);
condition_ready_.wait(lock, [=] { return this->ready_ == 1; });
}
int32_t HeterService::stop_profiler(const PsRequestMessage& request,
PsResponseMessage& response,
brpc::Controller* cntl) {
platform::DisableProfiler(
platform::EventSortingKey::kDefault,
string::Sprintf("heter_worker_%s_profile", endpoint_));
return 0;
}
int32_t HeterService::start_profiler(const PsRequestMessage& request,
PsResponseMessage& response,
brpc::Controller* cntl) {
platform::EnableProfiler(platform::ProfilerState::kAll);
return 0;
}
int32_t HeterService::stop_heter_worker(const PsRequestMessage& request,
PsResponseMessage& response,
brpc::Controller* cntl) {
auto client_id = request.client_id();
stop_cpu_worker_set_.insert(client_id);
if (stop_cpu_worker_set_.size() == fan_in_) {
is_exit_ = true;
}
return 0;
}
} // end namespace distributed
} // end namespace paddle

@ -0,0 +1,243 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <atomic>
#include <ctime>
#include <map>
#include <memory>
#include <random>
#include <string>
#include <unordered_map>
#include <vector>
#include "brpc/channel.h"
#include "brpc/controller.h"
#include "brpc/server.h"
#include "paddle/fluid/distributed/service/brpc_utils.h"
#include "paddle/fluid/distributed/service/sendrecv.pb.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN
#include "paddle/fluid/platform/profiler.h"
namespace paddle {
namespace distributed {
using MultiVarMsg = ::paddle::MultiVariableMessage;
using VarMsg = ::paddle::VariableMessage;
class HeterService;
typedef int32_t (HeterService::*serviceHandlerFunc)(
const PsRequestMessage& request, PsResponseMessage& response,
brpc::Controller* cntl);
typedef std::function<void(void*)> HeterRpcCallbackFunc;
typedef std::function<int(const MultiVarMsg*, MultiVarMsg*, brpc::Controller*)>
HeterServiceHandler;
class HeterService : public ::paddle::PsService {
public:
HeterService() {
_service_handler_map[PS_STOP_SERVER] = &HeterService::stop_heter_worker;
_service_handler_map[PS_START_PROFILER] = &HeterService::start_profiler;
_service_handler_map[PS_STOP_PROFILER] = &HeterService::stop_profiler;
}
virtual ~HeterService() {}
virtual void service(::google::protobuf::RpcController* controller,
const ::paddle::PsRequestMessage* request,
::paddle::PsResponseMessage* response,
::google::protobuf::Closure* done) {
brpc::ClosureGuard done_guard(done);
std::string log_label("ReceiveCmd-");
response->set_err_code(0);
response->set_err_msg("");
brpc::Controller* cntl = static_cast<brpc::Controller*>(controller);
auto itr = _service_handler_map.find(request->cmd_id());
if (itr == _service_handler_map.end()) {
std::string err_msg(
"undefined cmd_id, should match PsCmdID in ps.proto, cmd_id:");
err_msg.append(std::to_string(request->cmd_id()));
return;
}
serviceHandlerFunc handler_func = itr->second;
int service_ret = (this->*handler_func)(*request, *response, cntl);
if (service_ret != 0) {
response->set_err_code(service_ret);
response->set_err_msg("server internal error");
}
};
void SendAndRecvVariable(::google::protobuf::RpcController* controller,
const MultiVarMsg* request, MultiVarMsg* response,
::google::protobuf::Closure* done) {
brpc::ClosureGuard done_guard(done);
std::string message_name = request->message_name();
auto itr = handler_map_.find(message_name);
brpc::Controller* cntl = static_cast<brpc::Controller*>(controller);
PADDLE_ENFORCE_NE(
itr, handler_map_.end(),
platform::errors::InvalidArgument(
"HeterService::SendAndRecvVariable Get illegal message_name: %s "
"which is not in HeterService::handler_map_",
message_name));
itr->second(request, response, cntl);
}
void RegisterServiceHandler(std::string message_name,
HeterServiceHandler func) {
handler_map_[message_name] = func;
}
void SetEndpoint(const std::string& end_point) { endpoint_ = end_point; }
void SetFanin(const int& fan_in) { fan_in_ = fan_in; }
bool IsExit() { return is_exit_; }
private:
int32_t stop_profiler(const PsRequestMessage& request,
PsResponseMessage& response, brpc::Controller* cntl);
int32_t start_profiler(const PsRequestMessage& request,
PsResponseMessage& response, brpc::Controller* cntl);
int32_t stop_heter_worker(const PsRequestMessage& request,
PsResponseMessage& response,
brpc::Controller* cntl);
private:
std::string endpoint_;
std::unordered_map<std::string, HeterServiceHandler> handler_map_;
std::unordered_map<int32_t, serviceHandlerFunc> _service_handler_map;
std::unordered_set<int> stop_cpu_worker_set_;
int fan_in_;
bool is_exit_ = false;
};
class HeterServer {
public:
virtual ~HeterServer() {}
void Stop() {
server_.Stop(1000);
server_.Join();
}
bool IsExit() { return service_.IsExit(); }
HeterServer() {}
void RegisterServiceHandler(std::string message_name,
HeterServiceHandler func);
void StartHeterService();
void SetEndPoint(std::string& endpoint);
void SetFanin(int& fan_in);
// HeterWrapper singleton
static std::shared_ptr<HeterServer> GetInstance() {
if (NULL == s_instance_) {
s_instance_.reset(new HeterServer());
}
return s_instance_;
}
void WaitServerReady();
private:
static std::shared_ptr<HeterServer> s_instance_;
std::string endpoint_;
protected:
brpc::Server server_;
HeterService service_;
DISABLE_COPY_AND_ASSIGN(HeterServer);
std::mutex mutex_ready_;
std::condition_variable condition_ready_;
int ready_;
};
class HeterRequestHandler {
public:
HeterRequestHandler()
: dev_ctx_(nullptr),
executor_(nullptr),
scope_(nullptr),
program_(nullptr) {}
virtual ~HeterRequestHandler() {}
void SetScope(framework::Scope* scope) { scope_ = scope; }
void SetDevCtx(const platform::DeviceContext* dev_ctx) { dev_ctx_ = dev_ctx; }
void SetProgram(framework::ProgramDesc* program) { program_ = program; }
void SetExecutor(framework::Executor* executor) { executor_ = executor; }
void SetGradToPreparedCtx(
std::unordered_map<
std::string, std::shared_ptr<framework::ExecutorPrepareContext>>* g) {
message_to_prepared_ctx_ = g;
}
virtual int Handle(const MultiVarMsg* request, MultiVarMsg* response,
brpc::Controller* cntl) = 0;
protected:
const platform::DeviceContext* dev_ctx_;
framework::Executor* executor_;
framework::Scope* scope_;
framework::ProgramDesc* program_;
std::unordered_map<std::string,
std::shared_ptr<framework::ExecutorPrepareContext>>*
message_to_prepared_ctx_;
};
class RequestSendAndRecvHandler final : public HeterRequestHandler {
public:
RequestSendAndRecvHandler() {}
virtual ~RequestSendAndRecvHandler() {}
int Handle(const MultiVarMsg* request, MultiVarMsg* response,
brpc::Controller* cntl) override {
platform::RecordEvent record_event("RequestSendAndRecvHandler->Handle");
auto& local_scope = scope_->NewScope();
auto message_name = request->message_name();
auto& request_io_buffer = cntl->request_attachment();
distributed::DeserializeFromMultiVarMsgAndIOBuf(
*request, &request_io_buffer, *dev_ctx_, &local_scope);
executor_->RunPreparedContext(
(*message_to_prepared_ctx_)[message_name].get(), &local_scope, false);
auto response_var_nums = request->recv_var_names_size();
std::vector<std::string> response_var_names(response_var_nums),
empty_var_names{};
for (int var_idx = 0; var_idx < response_var_nums; ++var_idx) {
response_var_names[var_idx] = request->recv_var_names(var_idx);
}
auto& response_io_buffer = cntl->response_attachment();
distributed::SerializeToMultiVarMsgAndIOBuf(
message_name, response_var_names, empty_var_names, *dev_ctx_,
&local_scope, response, &response_io_buffer);
scope_->DeleteScope(&local_scope);
return 0;
}
};
} // end namespace distributed
} // end namespace paddle

@ -0,0 +1,89 @@
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/service/ps_client.h"
#include <map>
#include "brpc/server.h"
#include "glog/logging.h"
#include "paddle/fluid/distributed/service/brpc_ps_client.h"
#include "paddle/fluid/distributed/table/table.h"
namespace paddle {
namespace distributed {
REGISTER_CLASS(PSClient, BrpcPsClient);
int32_t PSClient::configure(
const PSParameter &config,
const std::map<uint64_t, std::vector<paddle::distributed::Region>> &regions,
PSEnvironment &env, size_t client_id) {
_env = &env;
_config = config;
_dense_pull_regions = regions;
_client_id = client_id;
_config.mutable_worker_param()
->mutable_downpour_worker_param()
->mutable_downpour_table_param()
->CopyFrom(_config.server_param()
.downpour_server_param()
.downpour_table_param());
const auto &work_param = _config.worker_param().downpour_worker_param();
for (size_t i = 0; i < work_param.downpour_table_param_size(); ++i) {
auto *accessor = CREATE_CLASS(
ValueAccessor,
work_param.downpour_table_param(i).accessor().accessor_class());
accessor->configure(work_param.downpour_table_param(i).accessor());
accessor->initialize();
_table_accessors[work_param.downpour_table_param(i).table_id()].reset(
accessor);
}
return initialize();
}
PSClient *PSClientFactory::create(const PSParameter &ps_config) {
const auto &config = ps_config.server_param();
if (!config.has_downpour_server_param()) {
LOG(ERROR) << "miss downpour_server_param in ServerParameter";
return NULL;
}
if (!config.downpour_server_param().has_service_param()) {
LOG(ERROR) << "miss service_param in ServerParameter.downpour_server_param";
return NULL;
}
if (!config.downpour_server_param().service_param().has_client_class()) {
LOG(ERROR) << "miss client_class in "
"ServerParameter.downpour_server_param.service_param";
return NULL;
}
const auto &service_param = config.downpour_server_param().service_param();
PSClient *client = CREATE_CLASS(PSClient, service_param.client_class());
if (client == NULL) {
LOG(ERROR) << "client is not registered, server_name:"
<< service_param.client_class();
return NULL;
}
TableManager::instance().initialize();
LOG(INFO) << "Create PSClient[" << service_param.client_class()
<< "] success";
return client;
}
} // namespace distributed
} // namespace paddle

@ -0,0 +1,208 @@
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <future>
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/distributed/ps.pb.h"
#include "paddle/fluid/distributed/service/env.h"
#include "paddle/fluid/distributed/service/sendrecv.pb.h"
#include "paddle/fluid/distributed/table/accessor.h"
namespace paddle {
namespace distributed {
typedef std::function<void(void *)> PSClientCallBack;
class PSClientClosure : public google::protobuf::Closure {
public:
PSClientClosure(PSClientCallBack callback) : _callback(callback) {}
virtual ~PSClientClosure() {}
virtual void set_promise_value(int value) {
for (auto &promise : _promises) {
promise->set_value(value);
}
}
void add_promise(std::shared_ptr<std::promise<int32_t>> &promise) {
_promises.push_back(promise);
}
protected:
PSClientCallBack _callback;
std::vector<std::shared_ptr<std::promise<int32_t>>> _promises;
};
class PSClient {
public:
PSClient() {}
virtual ~PSClient() {}
PSClient(PSClient &&) = delete;
PSClient(const PSClient &) = delete;
virtual int32_t configure(
const PSParameter &config,
const std::map<uint64_t, std::vector<paddle::distributed::Region>>
&regions,
PSEnvironment &_env, size_t client_id) final;
virtual int32_t create_client2client_connection(
int pserver_timeout_ms, int pserver_connect_timeout_ms,
int max_retry) = 0;
// 触发table数据退场
virtual std::future<int32_t> shrink(uint32_t table_id) = 0;
// 全量table进行数据load
virtual std::future<int32_t> load(const std::string &epoch,
const std::string &mode) = 0;
// 指定table数据load
virtual std::future<int32_t> load(uint32_t table_id, const std::string &epoch,
const std::string &mode) = 0;
// 全量table数据save value_accessor根据mode可能有不同的save条件
virtual std::future<int32_t> save(const std::string &epoch,
const std::string &mode) = 0;
// 指定table数据save value_accessor根据mode可能有不同的save条件
virtual std::future<int32_t> save(uint32_t table_id, const std::string &epoch,
const std::string &mode) = 0;
//清空table数据
virtual std::future<int32_t> clear() = 0;
virtual std::future<int32_t> clear(uint32_t table_id) = 0;
// pull dense的参数部分并分块填充到本地网络参数中
// start和num用于拉取部分参数
// future结束前keys和values缓冲区不能再次使用
// client将values按照区块拆包后送交多个sender
// sender聚集同一区块的请求累计多个填充buffer
// server将参数区块中配置的某一维提取返回
// 返回数据解包后填充到累计的多个buffer中
virtual std::future<int32_t> pull_dense(Region *regions, size_t region_num,
size_t table_id) = 0; //保留
// firstly push dense param for parameter server
// this is neccessary because dense weight initialized in trainer on cold
// start
virtual std::future<int32_t> push_dense_param(const Region *regions,
size_t region_num,
size_t table_id) = 0;
// 使用keys进行pull请求结果填充values
// keys和values的个数均为num个每个value占用select_size空间
// future结束前keys和values缓冲区不能再次使用
// 整合多个线程请求的keys聚集并分散发送到server
// 返回结果后遍历buffer并对values赋值
virtual std::future<int32_t> pull_sparse(float **select_values,
size_t table_id,
const uint64_t *keys,
size_t num) = 0;
virtual std::future<int32_t> print_table_stat(uint32_t table_id) = 0;
// 确保所有积攒中的请求都发起发送
virtual std::future<int32_t> flush() = 0;
// server优雅退出
virtual std::future<int32_t> stop_server() = 0;
// server profilera
virtual std::future<int32_t> start_profiler() = 0;
virtual std::future<int32_t> stop_profiler() = 0;
virtual std::future<int32_t> barrier(size_t table_id,
uint32_t barrier_type) = 0;
virtual std::future<int32_t> pull_geo_param(size_t table_id,
std::vector<float> *values,
std::vector<uint64_t> *keys,
int pserver_idx) = 0;
virtual void finalize_worker() = 0;
// client to client, 消息发送
virtual std::future<int32_t> send_client2client_msg(int msg_type,
int to_client_id,
const std::string &msg) {
LOG(FATAL) << "Did not implement";
std::promise<int32_t> promise;
std::future<int> fut = promise.get_future();
promise.set_value(-1);
return fut;
}
// client2client消息处理std::function<int32_t (int, int, const std::string&)
// -> ret (msg_type, from_client_id, msg)
typedef std::function<int32_t(int, int, const std::string &)> MsgHandlerFunc;
virtual int registe_client2client_msg_handler(int msg_type,
MsgHandlerFunc handler) {
_msg_handler_map[msg_type] = handler;
return 0;
}
virtual int handle_client2client_msg(int msg_type, int from_client_id,
const std::string &msg) {
auto itr = _msg_handler_map.find(msg_type);
if (itr == _msg_handler_map.end()) {
LOG(WARNING) << "unknown client2client_msg type:" << msg_type;
return -1;
}
return itr->second(msg_type, from_client_id, msg);
}
virtual ValueAccessor *table_accessor(size_t table_id) {
auto itr = _table_accessors.find(table_id);
if (itr == _table_accessors.end()) {
return NULL;
}
return itr->second.get();
}
virtual size_t get_server_nums() = 0;
virtual std::future<int32_t> push_dense_raw_gradient(
int table_id, float *total_send_data, size_t total_send_data_size,
void *done) = 0;
virtual std::future<int32_t> push_sparse_raw_gradient(
size_t table_id, const uint64_t *keys, const float **update_values,
size_t num, void *done) = 0;
virtual std::future<int32_t> push_sparse_raw_gradient_partial(
size_t table_id, const uint64_t *keys, const float **update_values,
uint32_t num, void *done, int pserver_idx) = 0;
virtual std::future<int32_t> push_sparse_param(size_t table_id,
const uint64_t *keys,
const float **update_values,
size_t num, void *done) = 0;
protected:
virtual int32_t initialize() = 0;
size_t _client_id;
PSParameter _config;
std::map<uint64_t, std::vector<paddle::distributed::Region>>
_dense_pull_regions;
PSEnvironment *_env;
std::unordered_map<uint32_t, std::shared_ptr<ValueAccessor>> _table_accessors;
std::unordered_map<int32_t, MsgHandlerFunc>
_msg_handler_map; //处理client2client消息
};
REGISTER_REGISTERER(PSClient);
class PSClientFactory {
public:
static PSClient *create(const PSParameter &config);
};
} // namespace distributed
} // namespace paddle

@ -0,0 +1,113 @@
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
package paddle;
option cc_generic_services = true;
option cc_enable_arenas = true;
enum PsCmdID {
PS_PULL_DENSE_TABLE = 0;
PS_PUSH_DENSE_TABLE = 1;
PS_PULL_SPARSE_TABLE = 2;
PS_PUSH_SPARSE_TABLE = 3;
PS_SHRINK_TABLE = 4;
PS_SAVE_ONE_TABLE = 5;
PS_SAVE_ALL_TABLE = 6;
PS_LOAD_ONE_TABLE = 7;
PS_LOAD_ALL_TABLE = 8;
PS_CLEAR_ONE_TABLE = 9;
PS_CLEAR_ALL_TABLE = 10;
PS_PUSH_DENSE_PARAM = 11;
PS_STOP_SERVER = 12;
PS_SAVE_ONE_CACHE_TABLE = 13;
PS_GET_CACHE_THRESHOLD = 14;
PS_CACHE_SHUFFLE = 15;
PS_COPY_TABLE = 16;
PS_COPY_TABLE_BY_FEASIGN = 17;
PS_PULL_SPARSE_TABLE_WITH_DEPENDENCY = 18;
PS_PUSH_SPARSE_TABLE_WITH_DEPENDENCY = 19;
PS_PRINT_TABLE_STAT = 20;
PS_SAVE_ONE_TABLE_PREFIX = 21;
PS_SAVE_ONE_TABLE_WITH_WHITELIST = 22;
PS_LOAD_ONE_TABLE_WITH_WHITELIST = 23;
PS_PULL_GEO_PARAM = 24;
PS_BARRIER = 25;
PS_PUSH_SPARSE_PARAM = 26;
PS_START_PROFILER = 27;
PS_STOP_PROFILER = 28;
}
message PsRequestMessage {
required uint32 cmd_id = 1;
optional uint32 table_id = 2;
repeated bytes params = 3;
optional int32 client_id = 4;
optional bytes data = 5;
};
message PsResponseMessage {
required int32 err_code = 1 [ default = 0 ];
required string err_msg = 2 [ default = "" ];
optional bytes data = 3;
};
enum VarType {
LOD_TENSOR = 0;
SELECTED_ROWS = 1;
}
message VariableMessage {
enum Type {
// Pod Types
BOOL = 0;
INT16 = 1;
INT32 = 2;
INT64 = 3;
FP16 = 4;
FP32 = 5;
FP64 = 6;
}
message LodData { repeated int64 lod_data = 1; }
optional string varname = 1;
// TODO(Yancey1989): reference framework::proto::VarDesc::VarType
optional VarType type = 2;
// bool persistable is not needed for sending.
// tensor info:
optional Type data_type = 3;
repeated int64 dims = 4;
// lod details:
optional int64 lod_level = 5;
repeated LodData lod = 6;
// selected_rows height, aka. original dim0
optional int64 slr_height = 7;
// tensor data
optional bytes data = 8;
}
// for SendAndRecv RPC method
message MultiVariableMessage {
// message flags
required string message_name = 1;
repeated string send_var_names = 2;
repeated string recv_var_names = 3;
repeated VariableMessage var_messages = 4;
};
service PsService {
rpc service(PsRequestMessage) returns (PsResponseMessage);
rpc SendAndRecvVariable(MultiVariableMessage) returns (MultiVariableMessage);
};

@ -0,0 +1,87 @@
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/service/server.h"
#include "glog/logging.h"
#include "paddle/fluid/distributed/service/brpc_ps_server.h"
#include "paddle/fluid/distributed/table/table.h"
namespace paddle {
namespace distributed {
REGISTER_CLASS(PSServer, BrpcPsServer);
REGISTER_CLASS(PsBaseService, PsService);
PSServer *PSServerFactory::create(const PSParameter &ps_config) {
const auto &config = ps_config.server_param();
if (!config.has_downpour_server_param()) {
LOG(ERROR) << "miss downpour_server_param in ServerParameter";
return NULL;
}
if (!config.downpour_server_param().has_service_param()) {
LOG(ERROR) << "miss service_param in ServerParameter.downpour_server_param";
return NULL;
}
if (!config.downpour_server_param().service_param().has_server_class()) {
LOG(ERROR) << "miss server_class in "
"ServerParameter.downpour_server_param.service_param";
return NULL;
}
const auto &service_param = config.downpour_server_param().service_param();
PSServer *server = CREATE_CLASS(PSServer, service_param.server_class());
if (server == NULL) {
LOG(ERROR) << "server is not registered, server_name:"
<< service_param.server_class();
return NULL;
}
TableManager::instance().initialize();
return server;
}
int32_t PSServer::configure(const PSParameter &config, PSEnvironment &env,
size_t server_rank) {
_config = config.server_param();
_rank = server_rank;
_environment = &env;
_shuffled_ins =
paddle::framework::MakeChannel<std::pair<uint64_t, std::string>>();
const auto &downpour_param = _config.downpour_server_param();
uint32_t barrier_table = UINT32_MAX;
for (size_t i = 0; i < downpour_param.downpour_table_param_size(); ++i) {
auto *table = CREATE_CLASS(
Table, downpour_param.downpour_table_param(i).table_class());
if (downpour_param.downpour_table_param(i).table_class() ==
"BarrierTable") {
barrier_table = downpour_param.downpour_table_param(i).table_id();
}
table->initialize(downpour_param.downpour_table_param(i),
config.fs_client_param());
_table_map[downpour_param.downpour_table_param(i).table_id()].reset(table);
}
if (barrier_table != UINT32_MAX) {
_table_map[barrier_table]->set_table_map(&_table_map);
}
return initialize();
}
} // namespace distributed
} // namespace paddle

@ -0,0 +1,150 @@
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <future>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "butil/endpoint.h"
#include "google/protobuf/service.h"
#include "paddle/fluid/distributed/common/registerer.h"
#include "paddle/fluid/distributed/ps.pb.h"
#include "paddle/fluid/distributed/service/env.h"
#include "paddle/fluid/distributed/service/sendrecv.pb.h"
#include "paddle/fluid/framework/channel.h"
namespace paddle {
namespace distributed {
class Table;
class PSServer {
public:
PSServer() {}
virtual ~PSServer() {}
PSServer(PSServer &&) = delete;
PSServer(const PSServer &) = delete;
virtual int32_t configure(const PSParameter &config, PSEnvironment &env,
size_t server_rank) final;
// return server_ip
virtual std::string ip() { return butil::my_ip_cstr(); }
// return server_port
virtual int32_t port() = 0;
virtual uint64_t start(const std::string &ip, uint32_t port) = 0;
virtual int32_t stop() = 0;
inline size_t rank() const { return _rank; }
inline PSEnvironment *environment() { return _environment; }
inline const ServerParameter *config() const { return &_config; }
inline Table *table(size_t table_id) {
auto itr = _table_map.find(table_id);
if (itr != _table_map.end()) {
return itr->second.get();
}
return NULL;
}
inline std::unordered_map<uint32_t, std::shared_ptr<Table>> *table() {
return &_table_map;
}
typedef std::function<int32_t(int, int, const std::string &)> MsgHandlerFunc;
virtual int registe_pserver2pserver_msg_handler(int msg_type,
MsgHandlerFunc handler) {
_msg_handler_map[msg_type] = handler;
return 0;
}
paddle::framework::Channel<std::pair<uint64_t, std::string>> _shuffled_ins;
protected:
virtual int32_t initialize() = 0;
protected:
size_t _rank;
ServerParameter _config;
PSEnvironment *_environment;
std::unordered_map<uint32_t, std::shared_ptr<Table>> _table_map;
std::unordered_map<int32_t, MsgHandlerFunc> _msg_handler_map;
};
REGISTER_REGISTERER(PSServer);
typedef std::function<void(void *)> PServerCallBack;
class PServerClosure : public google::protobuf::Closure {
public:
PServerClosure(PServerCallBack callback) : _callback(callback) {}
virtual ~PServerClosure() {}
virtual void set_promise_value(int value) {
for (auto &promise : _promises) {
promise->set_value(value);
}
}
void add_promise(std::shared_ptr<std::promise<int32_t>> &promise) {
_promises.push_back(promise);
}
protected:
PServerCallBack _callback;
std::vector<std::shared_ptr<std::promise<int32_t>>> _promises;
};
class PsBaseService : public PsService {
public:
PsBaseService() : _rank(0), _server(NULL), _config(NULL) {}
virtual ~PsBaseService() {}
virtual int32_t configure(PSServer *server) {
_server = server;
_rank = _server->rank();
_config = _server->config();
return 0;
}
virtual void service(::google::protobuf::RpcController *controller,
const ::paddle::PsRequestMessage *request,
::paddle::PsResponseMessage *response,
::google::protobuf::Closure *done) override = 0;
virtual void set_response_code(PsResponseMessage &response, int err_code,
const char *err_msg) {
response.set_err_msg(err_msg);
response.set_err_code(err_code);
LOG(WARNING) << "Resonse err_code:" << err_code << " msg:" << err_msg;
}
virtual int32_t initialize() = 0;
protected:
size_t _rank;
PSServer *_server;
const ServerParameter *_config;
};
REGISTER_REGISTERER(PsBaseService);
class PSServerFactory {
public:
static PSServer *create(const PSParameter &config);
};
} // namespace distributed
} // namespace paddle

@ -0,0 +1,129 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/distributed/service/service.h"
#include <fcntl.h>
#include <google/protobuf/io/zero_copy_stream_impl.h>
#include <google/protobuf/text_format.h>
#include <iostream>
#include "paddle/fluid/distributed/service/communicator.h"
#include "paddle/fluid/string/string_helper.h"
using namespace std;
namespace paddle {
namespace distributed {
paddle::distributed::PSParameter load_from_prototxt(
const std::string& filename) {
paddle::distributed::PSParameter param;
int file_descriptor = open(filename.c_str(), O_RDONLY);
if (file_descriptor == -1) {
VLOG(3) << "FATAL: fail to parse " << filename;
exit(-1);
}
google::protobuf::io::FileInputStream fileInput(file_descriptor);
if (!google::protobuf::TextFormat::Parse(&fileInput, &param)) {
VLOG(3) << "FATAL: fail to parse " << filename;
exit(-1);
}
close(file_descriptor);
return param;
}
void PSCore::init_gflag(const std::string& gflags) {
LOG(INFO) << "Init With Gflags:" << gflags;
std::vector<std::string> flags = paddle::string::split_string(gflags);
if (flags.size() < 1) {
flags.push_back("-max_body_size=314217728");
flags.push_back("-bthread_concurrency=40");
flags.push_back("-socket_max_unwritten_bytes=2048000000");
flags.push_back("-max_connection_pool_size=1950");
}
auto it = flags.begin();
flags.insert(it, "exe default");
char* flags_ptr[flags.size()];
for (size_t i = 0; i < flags.size(); ++i) {
flags_ptr[i] = (char*)(flags[i].c_str());
}
int params_cnt = flags.size();
char** params_ptr = &(flags_ptr[0]);
::google::ParseCommandLineFlags(&params_cnt, &params_ptr, true);
}
int PSCore::init_server(const std::string& dist_desc,
const std::vector<std::string>* host_sign_list,
int node_num, int index) {
google::protobuf::TextFormat::ParseFromString(dist_desc, &_ps_param);
init_gflag(_ps_param.init_gflags());
_ps_env = paddle::distributed::PaddlePSEnvironment();
_ps_env.set_ps_servers(host_sign_list, node_num);
int ret = 0;
_server_ptr = std::shared_ptr<paddle::distributed::PSServer>(
paddle::distributed::PSServerFactory::create(_ps_param));
ret = _server_ptr->configure(_ps_param, _ps_env, index);
CHECK(ret == 0) << "failed to configure server";
return ret;
}
int PSCore::init_worker(
const std::string& dist_desc,
const std::map<uint64_t, std::vector<paddle::distributed::Region>>& regions,
const std::vector<std::string>* host_sign_list, int node_num, int index) {
google::protobuf::TextFormat::ParseFromString(dist_desc, &_ps_param);
init_gflag(_ps_param.init_gflags());
_ps_env = paddle::distributed::PaddlePSEnvironment();
_ps_env.set_ps_servers(host_sign_list, node_num);
int ret = 0;
VLOG(1) << "PSCore::init_worker";
auto* communicator = Communicator::GetInstance();
ret = communicator->GetPsClient()->configure(_ps_param, regions, _ps_env,
index);
communicator->Start();
return ret;
}
std::vector<uint64_t> PSCore::get_client_info() {
return _ps_env.get_client_info();
}
int PSCore::create_client2client_connection(int pserver_timeout_ms,
int pserver_connect_timeout_ms,
int max_retry) {
int ret = _worker_ptr->create_client2client_connection(
pserver_timeout_ms, pserver_connect_timeout_ms, max_retry);
return ret;
}
uint64_t PSCore::run_server(const std::string& ip, uint32_t port) {
return _server_ptr->start(ip, port);
}
int PSCore::finalize_worker() {
_worker_ptr->finalize_worker();
return 0;
}
int PSCore::stop_server() {
auto stop_status = _worker_ptr->stop_server();
stop_status.wait();
return 0;
}
paddle::distributed::PSParameter* PSCore::get_param() { return &_ps_param; }
} // namespace distributed
} // namespace paddle

@ -0,0 +1,64 @@
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <map>
#include <memory>
#include <string>
#include <vector>
#include <glog/logging.h>
#include "paddle/fluid/distributed/ps.pb.h"
#include "paddle/fluid/distributed/service/ps_client.h"
#include "paddle/fluid/distributed/service/sendrecv.pb.h"
#include "paddle/fluid/distributed/service/server.h"
namespace paddle {
namespace distributed {
class PSCore {
public:
explicit PSCore() {}
virtual ~PSCore() {}
virtual int init_server(const std::string& dist_desc,
const std::vector<std::string>* host_sign_list,
int node_num, int index);
virtual int init_worker(
const std::string& dist_desc,
const std::map<uint64_t, std::vector<paddle::distributed::Region>>&
regions,
const std::vector<std::string>* host_sign_list, int node_num, int index);
virtual uint64_t run_server(const std::string& ip, uint32_t port);
virtual int stop_server();
virtual int finalize_worker();
virtual std::vector<uint64_t> get_client_info();
virtual int create_client2client_connection(int pserver_timeout_ms,
int pserver_connect_timeout_ms,
int max_retry);
std::shared_ptr<paddle::distributed::PSServer>
_server_ptr; // pointer to server
std::shared_ptr<paddle::distributed::PSClient>
_worker_ptr; // pointer to worker
virtual paddle::distributed::PSParameter* get_param();
private:
void init_gflag(const std::string& gflags);
paddle::distributed::PSParameter _ps_param;
paddle::distributed::PaddlePSEnvironment _ps_env;
};
} // namespace distributed
} // namespace paddle

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save