add service (#29560)
* add service, remove ut on mac * fix heter_profiler & add heter stop method * fix code stylerevert-31562-mean
parent
c0163837a5
commit
0034273b7e
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,246 @@
|
||||
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <atomic>
|
||||
#include <ctime>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <random>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include <ThreadPool.h>
|
||||
#include "paddle/fluid/distributed/communicator_common.h"
|
||||
#include "paddle/fluid/distributed/service/service.h"
|
||||
#include "paddle/fluid/framework/archive.h"
|
||||
#include "paddle/fluid/framework/io/fs.h"
|
||||
#include "paddle/fluid/framework/io/shell.h"
|
||||
#include "paddle/fluid/framework/program_desc.h"
|
||||
#include "paddle/fluid/framework/scope.h"
|
||||
#include "paddle/fluid/framework/tensor.h"
|
||||
#include "paddle/fluid/framework/variable_helper.h"
|
||||
#include "paddle/fluid/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN
|
||||
|
||||
namespace paddle {
|
||||
namespace distributed {
|
||||
|
||||
using framework::LoDTensor;
|
||||
using framework::Scope;
|
||||
using framework::SelectedRows;
|
||||
using framework::Variable;
|
||||
|
||||
using RpcCtxMap = std::unordered_map<std::string, CommContext>;
|
||||
|
||||
class FleetWrapper {
|
||||
public:
|
||||
virtual ~FleetWrapper() {}
|
||||
FleetWrapper() {
|
||||
scale_sparse_gradient_with_batch_size_ = true;
|
||||
// trainer sleep some time for pserver core dump
|
||||
sleep_seconds_before_fail_exit_ = 300;
|
||||
// pserver request server timeout ms
|
||||
client2client_request_timeout_ms_ = 500000;
|
||||
// pserver connect server timeout_ms
|
||||
client2client_connect_timeout_ms_ = 10000;
|
||||
// pserver request max retry
|
||||
client2client_max_retry_ = 3;
|
||||
}
|
||||
|
||||
// set client to client communication config
|
||||
void SetClient2ClientConfig(int request_timeout_ms, int connect_timeout_ms,
|
||||
int max_retry);
|
||||
|
||||
// Pull sparse variables from server in sync mode
|
||||
// Param<in>: scope, table_id, var_names, fea_keys, fea_dim, var_emb_names
|
||||
// Param<out>: fea_values
|
||||
void PullSparseVarsSync(const Scope& scope, const uint64_t table_id,
|
||||
const std::vector<std::string>& var_names,
|
||||
std::vector<uint64_t>* fea_keys,
|
||||
std::vector<std::vector<float>>* fea_values,
|
||||
int fea_dim,
|
||||
const std::vector<std::string>& var_emb_names);
|
||||
|
||||
// Pull sparse variables from server in async mode
|
||||
// Param<in>: scope, table_id, var_names, fea_keys, fea_dim
|
||||
// Param<out>: fea_values std::future
|
||||
std::future<int32_t> PullSparseVarsAsync(
|
||||
const Scope& scope, const uint64_t table_id,
|
||||
const std::vector<std::string>& var_names,
|
||||
std::vector<uint64_t>* fea_keys,
|
||||
std::vector<std::vector<float>>* fea_values, int fea_dim);
|
||||
|
||||
// Pull sparse variables from server in sync mode
|
||||
// pull immediately to tensors
|
||||
void PullSparseToTensorSync(const uint64_t table_id, int fea_dim,
|
||||
uint64_t padding_id, platform::Place place,
|
||||
std::vector<const LoDTensor*>* inputs, // NOLINT
|
||||
std::vector<LoDTensor*>* outputs); // NOLINT
|
||||
|
||||
// pull dense variables from server in sync mod
|
||||
// Param<in>: scope, table_id, var_names
|
||||
// Param<out>: void
|
||||
void PullDenseVarsSync(const Scope& scope, const uint64_t table_id,
|
||||
const std::vector<std::string>& var_names);
|
||||
|
||||
// pull dense variables from server in async mod
|
||||
// Param<in>: scope, table_id, var_names
|
||||
// Param<out>: pull_dense_status
|
||||
void PullDenseVarsAsync(const Scope& scope, const uint64_t table_id,
|
||||
const std::vector<std::string>& var_names,
|
||||
std::vector<std::future<int32_t>>* pull_dense_status,
|
||||
bool in_cpu);
|
||||
|
||||
// push dense parameters(not gradients) to server in sync mode
|
||||
void PushDenseParamSync(const Scope& scope, const uint64_t table_id,
|
||||
const std::vector<std::string>& var_names);
|
||||
|
||||
void PushDenseVarsAsync(const Scope& scope, const uint64_t table_id,
|
||||
const std::vector<std::string>& var_names,
|
||||
std::vector<std::future<int32_t>>* push_sparse_status,
|
||||
float scale_datanorm, int batch_size);
|
||||
|
||||
// push dense variables to server in sync mode
|
||||
void PushDenseVarsSync(Scope* scope, const uint64_t table_id,
|
||||
const std::vector<std::string>& var_names);
|
||||
|
||||
void PushSparseVarsAsync(
|
||||
const Scope& scope, const uint64_t table_id, const std::string& grad,
|
||||
std::vector<std::future<int32_t>>* push_sparse_status);
|
||||
// This is specially designed for click/show stats in server
|
||||
// Param<in>: scope, table_id, fea_keys, fea_labels, sparse_key_names,
|
||||
// sparse_grad_names, batch_size, use_cvm, dump_slot
|
||||
// Param<out>: push_values, push_sparse_status
|
||||
void PushSparseVarsWithLabelAsync(
|
||||
const Scope& scope, const uint64_t table_id,
|
||||
const std::vector<uint64_t>& fea_keys,
|
||||
const std::vector<float>& fea_labels,
|
||||
const std::vector<std::string>& sparse_key_names,
|
||||
const std::vector<std::string>& sparse_grad_names, const int emb_dim,
|
||||
std::vector<std::vector<float>>* push_values,
|
||||
std::vector<std::future<int32_t>>* push_sparse_status,
|
||||
const int batch_size, const bool use_cvm, const bool dump_slot,
|
||||
std::vector<uint64_t>* sparse_push_keys, const bool no_cvm);
|
||||
|
||||
// Push sparse variables to server in async mode
|
||||
void PushSparseFromTensorWithLabelAsync(
|
||||
const Scope& scope, const uint64_t table_id, int fea_dim,
|
||||
uint64_t padding_id, bool scale_sparse, const std::string& accesor,
|
||||
const std::string& click_name, platform::Place place,
|
||||
const std::vector<std::string>& input_names,
|
||||
std::vector<const LoDTensor*>* inputs, // NOLINT
|
||||
std::vector<const LoDTensor*>* outputs); // NOLINT
|
||||
|
||||
// Push sparse variables to server in Async mode
|
||||
// Param<In>: scope, table_id, fea_keys, sparse_grad_names
|
||||
// Param<Out>: push_values, push_sparse_status
|
||||
|
||||
// init server
|
||||
void LoadSparseOnServer(const std::string& path, const std::string& meta,
|
||||
uint32_t table_id);
|
||||
// init server
|
||||
// void InitServer(const std::string& dist_desc,
|
||||
// const std::vector<uint64_t>& host_sign_list, int index);
|
||||
void InitServer(const std::string& dist_desc,
|
||||
const std::vector<std::string>& host_sign_list, int index);
|
||||
// init trainer
|
||||
void InitWorker(const std::string& dist_desc,
|
||||
const std::vector<std::string>& host_sign_list, Scope* scope,
|
||||
const RpcCtxMap& send_ctx,
|
||||
const std::unordered_map<uint64_t, std::vector<std::string>>&
|
||||
dense_varnames,
|
||||
const std::map<std::string, std::string>& envs, int node_num,
|
||||
int index);
|
||||
|
||||
// stop server
|
||||
void StopServer();
|
||||
// finalize worker to make worker can be stop
|
||||
void FinalizeWorker();
|
||||
// run server with ip port
|
||||
uint64_t RunServer(const std::string& ip, uint32_t port);
|
||||
// get client info
|
||||
std::vector<uint64_t> GetClientsInfo();
|
||||
// create client to client connection
|
||||
void CreateClient2ClientConnection();
|
||||
// flush all push requests
|
||||
void ClientFlush();
|
||||
|
||||
// barrier with barrier table
|
||||
void BarrierWithTable(uint32_t barrier_type);
|
||||
|
||||
void PrintTableStat(const uint64_t table_id);
|
||||
// mode = 0, load all feature
|
||||
// mode = 1, load delta feature, which means load diff
|
||||
void LoadModel(const std::string& path, const int mode);
|
||||
// mode = 0, load all feature
|
||||
// mode = 1, load delta feature, which means load diff
|
||||
void LoadModelOneTable(const uint64_t table_id, const std::string& path,
|
||||
const int mode);
|
||||
// mode = 0, save all feature
|
||||
// mode = 1, save delta feature, which means save diff
|
||||
void SaveModel(const std::string& path, const int mode);
|
||||
// mode = 0, save all feature
|
||||
// mode = 1, save delta feature, which means save diff
|
||||
void SaveModelOneTable(const uint64_t table_id, const std::string& path,
|
||||
const int mode);
|
||||
// clear all models, release their memory
|
||||
void ClearModel();
|
||||
// clear one table
|
||||
void ClearOneTable(const uint64_t table_id);
|
||||
// shrink sparse table
|
||||
void ShrinkSparseTable(int table_id);
|
||||
// shrink dense table
|
||||
void ShrinkDenseTable(int table_id, Scope* scope,
|
||||
std::vector<std::string> var_list, float decay,
|
||||
int emb_dim);
|
||||
|
||||
typedef std::function<int32_t(int, int, const std::string&)> MsgHandlerFunc;
|
||||
// register client to client communication
|
||||
int RegisterClientToClientMsgHandler(int msg_type, MsgHandlerFunc handler);
|
||||
// send client to client message
|
||||
std::future<int32_t> SendClientToClientMsg(int msg_type, int to_client_id,
|
||||
const std::string& msg);
|
||||
|
||||
// FleetWrapper singleton
|
||||
static std::shared_ptr<FleetWrapper> GetInstance() {
|
||||
if (NULL == s_instance_) {
|
||||
s_instance_.reset(new paddle::distributed::FleetWrapper());
|
||||
}
|
||||
return s_instance_;
|
||||
}
|
||||
// this performs better than rand_r, especially large data
|
||||
std::default_random_engine& LocalRandomEngine();
|
||||
|
||||
static std::shared_ptr<paddle::distributed::PSCore> pserver_ptr_;
|
||||
|
||||
private:
|
||||
static std::shared_ptr<FleetWrapper> s_instance_;
|
||||
size_t GetAbsoluteSum(size_t start, size_t end, size_t level,
|
||||
const framework::LoD& lod);
|
||||
|
||||
protected:
|
||||
static bool is_initialized_;
|
||||
std::map<uint64_t, std::vector<paddle::distributed::Region>> _regions;
|
||||
bool scale_sparse_gradient_with_batch_size_;
|
||||
int32_t sleep_seconds_before_fail_exit_;
|
||||
int client2client_request_timeout_ms_;
|
||||
int client2client_connect_timeout_ms_;
|
||||
int client2client_max_retry_;
|
||||
DISABLE_COPY_AND_ASSIGN(FleetWrapper);
|
||||
};
|
||||
|
||||
} // end namespace distributed
|
||||
} // end namespace paddle
|
@ -0,0 +1,40 @@
|
||||
set(BRPC_SRCS ps_client.cc server.cc)
|
||||
set_source_files_properties(${BRPC_SRCS})
|
||||
|
||||
set(BRPC_DEPS brpc ssl crypto protobuf gflags glog zlib leveldb snappy gflags glog)
|
||||
|
||||
brpc_library(sendrecv_rpc SRCS
|
||||
${BRPC_SRCS}
|
||||
PROTO sendrecv.proto
|
||||
DEPS ${BRPC_DEPS} )
|
||||
|
||||
set_property(GLOBAL PROPERTY RPC_DEPS sendrecv_rpc ${BRPC_DEPS} string_helper)
|
||||
|
||||
get_property(RPC_DEPS GLOBAL PROPERTY RPC_DEPS)
|
||||
|
||||
set_source_files_properties(communicator.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
|
||||
set_source_files_properties(service.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
|
||||
set_source_files_properties(brpc_ps_server.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
|
||||
set_source_files_properties(brpc_ps_client.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
|
||||
|
||||
set_source_files_properties(brpc_utils.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
|
||||
set_source_files_properties(heter_server.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
|
||||
set_source_files_properties(heter_client.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
|
||||
|
||||
set_source_files_properties(client.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
|
||||
set_source_files_properties(ps_client.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
|
||||
set_source_files_properties(server.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
|
||||
|
||||
|
||||
cc_library(downpour_server SRCS brpc_ps_server.cc DEPS boost eigen3 table ${RPC_DEPS})
|
||||
cc_library(downpour_client SRCS brpc_ps_client.cc DEPS boost eigen3 table ${RPC_DEPS})
|
||||
|
||||
cc_library(client SRCS ps_client.cc DEPS downpour_client boost ${RPC_DEPS})
|
||||
cc_library(server SRCS server.cc DEPS downpour_server boost ${RPC_DEPS})
|
||||
|
||||
cc_library(communicator SRCS communicator.cc DEPS scope client boost table math_function selected_rows_functor ${RPC_DEPS})
|
||||
cc_library(ps_service SRCS service.cc DEPS communicator client server boost ${RPC_DEPS})
|
||||
|
||||
cc_library(brpc_utils SRCS brpc_utils.cc DEPS ${COMMON_DEPS} ${RPC_DEPS})
|
||||
cc_library(heter_server SRCS heter_server.cc DEPS brpc_utils ${COMMON_DEPS} ${RPC_DEPS})
|
||||
cc_library(heter_client SRCS heter_client.cc DEPS brpc_utils ${COMMON_DEPS} ${RPC_DEPS})
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,212 @@
|
||||
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "brpc/channel.h"
|
||||
#include "brpc/controller.h"
|
||||
#include "brpc/server.h"
|
||||
#include "paddle/fluid/distributed/service/ps_client.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace distributed {
|
||||
|
||||
class DownpourPsClientService : public PsService {
|
||||
public:
|
||||
DownpourPsClientService() {}
|
||||
virtual ~DownpourPsClientService() {}
|
||||
|
||||
virtual int32_t configure(PSClient *client, size_t rank_id) {
|
||||
_client = client;
|
||||
_rank = rank_id;
|
||||
return 0;
|
||||
}
|
||||
virtual void service(::google::protobuf::RpcController *controller,
|
||||
const ::paddle::PsRequestMessage *request,
|
||||
::paddle::PsResponseMessage *response,
|
||||
::google::protobuf::Closure *done) override;
|
||||
|
||||
protected:
|
||||
size_t _rank;
|
||||
PSClient *_client;
|
||||
};
|
||||
|
||||
class DownpourBrpcClosure : public PSClientClosure {
|
||||
public:
|
||||
DownpourBrpcClosure(size_t num, PSClientCallBack callback)
|
||||
: PSClientClosure(callback) {
|
||||
_waiting_num = num;
|
||||
|
||||
_cntls.resize(num);
|
||||
_requests.resize(num);
|
||||
_responses.resize(num);
|
||||
for (size_t i = 0; i < num; ++i) {
|
||||
_cntls[i].reset(new brpc::Controller());
|
||||
}
|
||||
}
|
||||
virtual ~DownpourBrpcClosure() {}
|
||||
virtual void Run() override {
|
||||
if (_waiting_num.fetch_sub(1) == 1) {
|
||||
_callback(this);
|
||||
delete this;
|
||||
}
|
||||
}
|
||||
PsRequestMessage *request(size_t i) { return &_requests[i]; }
|
||||
PsResponseMessage *response(size_t i) { return &_responses[i]; }
|
||||
brpc::Controller *cntl(size_t i) { return _cntls[i].get(); }
|
||||
int check_response(size_t request_idx, int cmd_id);
|
||||
int check_save_response(size_t request_idx, int cmd_id);
|
||||
std::string get_response(size_t request_idx, int cmd_id);
|
||||
|
||||
private:
|
||||
std::atomic<int32_t> _waiting_num;
|
||||
std::vector<PsRequestMessage> _requests;
|
||||
std::vector<PsResponseMessage> _responses;
|
||||
std::vector<std::shared_ptr<brpc::Controller>> _cntls;
|
||||
};
|
||||
|
||||
template <class T>
|
||||
struct array_deleter {
|
||||
void operator()(T *&x) const { delete[] x; }
|
||||
};
|
||||
|
||||
class BrpcPsClient : public PSClient {
|
||||
public:
|
||||
BrpcPsClient() {}
|
||||
virtual ~BrpcPsClient() {
|
||||
// _running = false;
|
||||
// try {
|
||||
// _async_push_dense_thread.join();
|
||||
// _async_push_sparse_thread.join();
|
||||
//} catch (...) {
|
||||
//}
|
||||
}
|
||||
virtual int32_t create_client2client_connection(
|
||||
int pserver_timeout_ms, int pserver_connect_timeout_ms, int max_retry);
|
||||
virtual std::future<int32_t> shrink(uint32_t table_id) override;
|
||||
virtual std::future<int32_t> load(const std::string &epoch,
|
||||
const std::string &mode) override;
|
||||
virtual std::future<int32_t> load(uint32_t table_id, const std::string &epoch,
|
||||
const std::string &mode) override;
|
||||
|
||||
virtual std::future<int32_t> save(const std::string &epoch,
|
||||
const std::string &mode) override;
|
||||
|
||||
virtual std::future<int32_t> save(uint32_t table_id, const std::string &epoch,
|
||||
const std::string &mode) override;
|
||||
|
||||
virtual std::future<int32_t> clear() override;
|
||||
|
||||
virtual std::future<int32_t> clear(uint32_t table_id) override;
|
||||
|
||||
virtual std::future<int32_t> stop_server() override;
|
||||
|
||||
virtual std::future<int32_t> start_profiler() override;
|
||||
virtual std::future<int32_t> stop_profiler() override;
|
||||
|
||||
virtual void finalize_worker() override;
|
||||
|
||||
virtual std::future<int32_t> pull_dense(Region *regions, size_t region_num,
|
||||
size_t table_id);
|
||||
|
||||
virtual std::future<int32_t> push_dense_param(const Region *regions,
|
||||
size_t region_num,
|
||||
size_t table_id);
|
||||
|
||||
virtual std::future<int32_t> pull_sparse(float **select_values,
|
||||
size_t table_id,
|
||||
const uint64_t *keys, size_t num);
|
||||
|
||||
virtual std::future<int32_t> print_table_stat(uint32_t table_id);
|
||||
|
||||
virtual std::future<int32_t> barrier(size_t table_id, uint32_t barrier_type);
|
||||
|
||||
virtual std::future<int32_t> pull_geo_param(size_t table_id,
|
||||
std::vector<float> *values,
|
||||
std::vector<uint64_t> *keys,
|
||||
int pserver_idx);
|
||||
|
||||
virtual std::future<int32_t> flush();
|
||||
|
||||
virtual std::future<int32_t> send_client2client_msg(
|
||||
int msg_type, int to_client_id, const std::string &msg) override;
|
||||
|
||||
private:
|
||||
virtual int32_t initialize() override;
|
||||
|
||||
inline uint32_t dense_dim_per_shard(uint32_t dense_dim_total,
|
||||
uint32_t shard_num) {
|
||||
return dense_dim_total / shard_num + 1;
|
||||
}
|
||||
|
||||
std::future<int32_t> send_cmd(uint32_t table_id, int cmd_id,
|
||||
const std::vector<std::string> ¶m);
|
||||
|
||||
std::future<int32_t> send_save_cmd(uint32_t table_id, int cmd_id,
|
||||
const std::vector<std::string> ¶m);
|
||||
|
||||
inline brpc::Channel *get_sparse_channel(size_t server_id) {
|
||||
return _server_channels[server_id][0].get();
|
||||
}
|
||||
inline brpc::Channel *get_dense_channel(size_t server_id) {
|
||||
return _server_channels[server_id][1].get();
|
||||
}
|
||||
inline brpc::Channel *get_cmd_channel(size_t server_id) {
|
||||
return _server_channels[server_id][2].get();
|
||||
}
|
||||
|
||||
bool _running = false;
|
||||
bool _flushing = false;
|
||||
std::atomic<uint32_t> _async_call_num; //异步请求计数
|
||||
|
||||
std::vector<std::shared_ptr<brpc::Channel>>
|
||||
_client_channels; // client2client
|
||||
std::vector<std::array<std::shared_ptr<brpc::Channel>, 3>>
|
||||
_server_channels; // client2server
|
||||
virtual std::future<int32_t> push_dense_raw_gradient(
|
||||
int table_id, float *total_send_data, size_t total_send_data_size,
|
||||
void *done) override;
|
||||
|
||||
virtual std::future<int32_t> push_sparse_raw_gradient(
|
||||
size_t table_id, const uint64_t *keys, const float **update_values,
|
||||
size_t num, void *done) override;
|
||||
|
||||
virtual std::future<int32_t> push_sparse_raw_gradient_partial(
|
||||
size_t table_id, const uint64_t *keys, const float **update_values,
|
||||
uint32_t num, void *done, int pserver_idx) override;
|
||||
|
||||
virtual std::future<int32_t> push_sparse_param(size_t table_id,
|
||||
const uint64_t *keys,
|
||||
const float **update_values,
|
||||
size_t num,
|
||||
void *done) override;
|
||||
|
||||
virtual size_t get_server_nums() { return _server_channels.size(); }
|
||||
|
||||
private:
|
||||
int32_t start_client_service();
|
||||
|
||||
float _mae = 0;
|
||||
float _mse = 0;
|
||||
uint16_t _push_times = 0;
|
||||
brpc::Server _server;
|
||||
DownpourPsClientService _service;
|
||||
std::atomic_uint grad_num_{0};
|
||||
};
|
||||
} // namespace distributed
|
||||
} // namespace paddle
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,153 @@
|
||||
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "brpc/channel.h"
|
||||
#include "brpc/controller.h"
|
||||
#include "brpc/server.h"
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include "paddle/fluid/distributed/service/server.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace distributed {
|
||||
|
||||
class BrpcPsServer : public PSServer {
|
||||
public:
|
||||
BrpcPsServer() {}
|
||||
virtual ~BrpcPsServer() {}
|
||||
virtual uint64_t start(const std::string &ip, uint32_t port);
|
||||
virtual int32_t stop() {
|
||||
std::unique_lock<std::mutex> lock(mutex_);
|
||||
stoped_ = true;
|
||||
cv_.notify_all();
|
||||
|
||||
_server.Stop(1000);
|
||||
_server.Join();
|
||||
return 0;
|
||||
}
|
||||
virtual int32_t port();
|
||||
|
||||
private:
|
||||
virtual int32_t initialize();
|
||||
|
||||
mutable std::mutex mutex_;
|
||||
std::condition_variable cv_;
|
||||
bool stoped_ = false;
|
||||
brpc::Server _server;
|
||||
std::shared_ptr<PsBaseService> _service;
|
||||
std::vector<std::shared_ptr<brpc::Channel>> _pserver_channels;
|
||||
};
|
||||
|
||||
class PsService;
|
||||
|
||||
typedef int32_t (PsService::*serviceHandlerFunc)(
|
||||
Table *table, const PsRequestMessage &request, PsResponseMessage &response,
|
||||
brpc::Controller *cntl);
|
||||
|
||||
class PsService : public PsBaseService {
|
||||
public:
|
||||
virtual int32_t initialize() override;
|
||||
|
||||
virtual void service(::google::protobuf::RpcController *controller,
|
||||
const ::paddle::PsRequestMessage *request,
|
||||
::paddle::PsResponseMessage *response,
|
||||
::google::protobuf::Closure *done) override;
|
||||
|
||||
private:
|
||||
int32_t initialize_shard_info();
|
||||
int32_t pull_dense(Table *table, const PsRequestMessage &request,
|
||||
PsResponseMessage &response, brpc::Controller *cntl);
|
||||
int32_t push_dense(Table *table, const PsRequestMessage &request,
|
||||
PsResponseMessage &response, brpc::Controller *cntl);
|
||||
int32_t push_dense_param(Table *table, const PsRequestMessage &request,
|
||||
PsResponseMessage &response, brpc::Controller *cntl);
|
||||
int32_t push_sparse_param(Table *table, const PsRequestMessage &request,
|
||||
PsResponseMessage &response,
|
||||
brpc::Controller *cntl);
|
||||
int32_t pull_sparse(Table *table, const PsRequestMessage &request,
|
||||
PsResponseMessage &response, brpc::Controller *cntl);
|
||||
int32_t pull_geo_param(Table *table, const PsRequestMessage &request,
|
||||
PsResponseMessage &response, brpc::Controller *cntl);
|
||||
int32_t barrier(Table *table, const PsRequestMessage &request,
|
||||
PsResponseMessage &response, brpc::Controller *cntl);
|
||||
int32_t push_sparse(Table *table, const PsRequestMessage &request,
|
||||
PsResponseMessage &response, brpc::Controller *cntl);
|
||||
int32_t load_one_table(Table *table, const PsRequestMessage &request,
|
||||
PsResponseMessage &response, brpc::Controller *cntl);
|
||||
int32_t load_all_table(Table *table, const PsRequestMessage &request,
|
||||
PsResponseMessage &response, brpc::Controller *cntl);
|
||||
int32_t save_one_table(Table *table, const PsRequestMessage &request,
|
||||
PsResponseMessage &response, brpc::Controller *cntl);
|
||||
int32_t save_all_table(Table *table, const PsRequestMessage &request,
|
||||
PsResponseMessage &response, brpc::Controller *cntl);
|
||||
int32_t shrink_table(Table *table, const PsRequestMessage &request,
|
||||
PsResponseMessage &response, brpc::Controller *cntl);
|
||||
int32_t clear_one_table(Table *table, const PsRequestMessage &request,
|
||||
PsResponseMessage &response, brpc::Controller *cntl);
|
||||
int32_t clear_all_table(Table *table, const PsRequestMessage &request,
|
||||
PsResponseMessage &response, brpc::Controller *cntl);
|
||||
int32_t stop_server(Table *table, const PsRequestMessage &request,
|
||||
PsResponseMessage &response, brpc::Controller *cntl);
|
||||
int32_t start_profiler(Table *table, const PsRequestMessage &request,
|
||||
PsResponseMessage &response, brpc::Controller *cntl);
|
||||
int32_t stop_profiler(Table *table, const PsRequestMessage &request,
|
||||
PsResponseMessage &response, brpc::Controller *cntl);
|
||||
|
||||
int32_t print_table_stat(Table *table, const PsRequestMessage &request,
|
||||
PsResponseMessage &response, brpc::Controller *cntl);
|
||||
|
||||
bool _is_initialize_shard_info;
|
||||
std::mutex _initialize_shard_mutex;
|
||||
std::unordered_map<int32_t, serviceHandlerFunc> _service_handler_map;
|
||||
std::unordered_map<int32_t, serviceHandlerFunc> _msg_handler_map;
|
||||
std::vector<float> _ori_values;
|
||||
};
|
||||
|
||||
class DownpourPServerBrpcClosure : public PServerClosure {
|
||||
public:
|
||||
DownpourPServerBrpcClosure(size_t num, PServerCallBack callback)
|
||||
: PServerClosure(callback) {
|
||||
_waiting_num = num;
|
||||
_cntls.resize(num);
|
||||
_requests.resize(num);
|
||||
_responses.resize(num);
|
||||
for (size_t i = 0; i < num; ++i) {
|
||||
_cntls[i].reset(new brpc::Controller());
|
||||
}
|
||||
}
|
||||
virtual ~DownpourPServerBrpcClosure() {}
|
||||
|
||||
virtual void Run() override {
|
||||
if (_waiting_num.fetch_sub(1) == 1) {
|
||||
_callback(this);
|
||||
delete this;
|
||||
}
|
||||
}
|
||||
PsRequestMessage *request(size_t i) { return &_requests[i]; }
|
||||
PsResponseMessage *response(size_t i) { return &_responses[i]; }
|
||||
brpc::Controller *cntl(size_t i) { return _cntls[i].get(); }
|
||||
int check_response(size_t request_idx, int cmd_id) { return 1; }
|
||||
int check_save_response(size_t request_idx, int cmd_id) { return 1; }
|
||||
|
||||
private:
|
||||
std::atomic<int32_t> _waiting_num;
|
||||
std::vector<PsRequestMessage> _requests;
|
||||
std::vector<PsResponseMessage> _responses;
|
||||
std::vector<std::shared_ptr<brpc::Controller>> _cntls;
|
||||
};
|
||||
} // namespace distributed
|
||||
} // namespace paddle
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,86 @@
|
||||
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "brpc/channel.h"
|
||||
#include "paddle/fluid/distributed/service/sendrecv.pb.h"
|
||||
#include "paddle/fluid/framework/data_type.h"
|
||||
#include "paddle/fluid/framework/lod_tensor.h"
|
||||
#include "paddle/fluid/framework/scope.h"
|
||||
#include "paddle/fluid/framework/selected_rows.h"
|
||||
#include "paddle/fluid/framework/tensor_util.h"
|
||||
#include "paddle/fluid/framework/var_type.h"
|
||||
#include "paddle/fluid/platform/port.h"
|
||||
|
||||
namespace grpc {
|
||||
class ByteBuffer;
|
||||
} // namespace grpc
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
class Scope;
|
||||
class Variable;
|
||||
} // namespace framework
|
||||
namespace platform {
|
||||
class DeviceContext;
|
||||
} // namespace platform
|
||||
} // namespace paddle
|
||||
|
||||
namespace paddle {
|
||||
namespace distributed {
|
||||
|
||||
using MultiVarMsg = ::paddle::MultiVariableMessage;
|
||||
using VarMsg = ::paddle::VariableMessage;
|
||||
|
||||
void SerializeToMultiVarMsgAndIOBuf(
|
||||
const std::string& message_name,
|
||||
const std::vector<std::string>& send_var_name_val,
|
||||
const std::vector<std::string>& recv_var_name_val,
|
||||
const platform::DeviceContext& ctx, const framework::Scope* scope,
|
||||
MultiVarMsg* var_msg, butil::IOBuf* iobuf);
|
||||
|
||||
void SerializeLodTensor(framework::Variable* var,
|
||||
const platform::DeviceContext& ctx, VarMsg* var_msg,
|
||||
butil::IOBuf* iobuf);
|
||||
|
||||
void SerializeSelectedRows(framework::Variable* var,
|
||||
const platform::DeviceContext& ctx, VarMsg* request,
|
||||
butil::IOBuf* iobuf);
|
||||
|
||||
// Deserialize for Server
|
||||
void DeserializeFromMultiVarMsgAndIOBuf(const MultiVarMsg& multi_msg,
|
||||
const butil::IOBuf* iobuf,
|
||||
const platform::DeviceContext& ctx,
|
||||
framework::Scope* scope);
|
||||
|
||||
// Deserialize for Client
|
||||
void DeserializeFromMultiVarMsgAndIOBuf(const MultiVarMsg& multi_msg,
|
||||
const butil::IOBuf* iobuf,
|
||||
const platform::DeviceContext& ctx,
|
||||
const framework::Scope* scope);
|
||||
|
||||
void DeserializeLodTensor(framework::Variable* var, const VarMsg& msg,
|
||||
butil::IOBufBytesIterator& iobuf,
|
||||
const platform::DeviceContext& ctx);
|
||||
|
||||
void DeserializeSelectedRows(framework::Variable* var, const VarMsg& msg,
|
||||
butil::IOBufBytesIterator& iobuf,
|
||||
const platform::DeviceContext& ctx);
|
||||
|
||||
} // namespace distributed
|
||||
} // namespace paddle
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,19 @@
|
||||
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/fluid/distributed/service/env.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace distributed {} // namespace distributed
|
||||
} // namespace paddle
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,168 @@
|
||||
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/fluid/distributed/service/heter_client.h"
|
||||
#include <algorithm>
|
||||
#include <utility>
|
||||
#include "paddle/fluid/framework/channel.h"
|
||||
#include "paddle/fluid/framework/data_feed.h"
|
||||
#include "paddle/fluid/framework/device_worker.h"
|
||||
#include "paddle/fluid/framework/io/fs.h"
|
||||
#include "paddle/fluid/framework/scope.h"
|
||||
#include "paddle/fluid/platform/profiler.h"
|
||||
#include "paddle/fluid/platform/timer.h"
|
||||
|
||||
DECLARE_int32(rpc_deadline);
|
||||
namespace paddle {
|
||||
namespace distributed {
|
||||
|
||||
DEFINE_int32(pserver_timeout_ms, 10800000, "pserver request server timeout_ms");
|
||||
|
||||
std::shared_ptr<HeterClient> HeterClient::s_instance_ = NULL;
|
||||
bool HeterClient::is_initialized_ = false;
|
||||
|
||||
void HeterClient::MainThread() {
|
||||
while (running_) {
|
||||
RpcProfilerControl();
|
||||
}
|
||||
}
|
||||
|
||||
void HeterClient::Stop() {
|
||||
running_ = false;
|
||||
if (!is_initialized_) {
|
||||
VLOG(0) << "HeterClient is not inited, do nothing";
|
||||
} else {
|
||||
if (main_thread_) {
|
||||
auto status = StopHeterWorker();
|
||||
status.wait();
|
||||
main_thread_->join();
|
||||
main_thread_.reset(nullptr);
|
||||
}
|
||||
VLOG(1) << "HeterClient Stop Done";
|
||||
}
|
||||
}
|
||||
|
||||
void HeterClient::RpcProfilerControl() {
|
||||
if (trainer_id_ == 0) {
|
||||
if (!do_server_profiler_ && platform::IsProfileEnabled()) {
|
||||
// send profiler start flag
|
||||
do_server_profiler_ = true;
|
||||
auto start_status = StartProfiler();
|
||||
start_status.wait();
|
||||
} else if (do_server_profiler_ && !platform::IsProfileEnabled()) {
|
||||
// send profiler end flag
|
||||
auto stop_status = StopProfiler();
|
||||
stop_status.wait();
|
||||
do_server_profiler_ = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void HeterClient::CreateClient2XpuConnection() {
|
||||
brpc::ChannelOptions options;
|
||||
options.protocol = "baidu_std";
|
||||
options.connection_type = "single";
|
||||
options.timeout_ms = pserver_timeout_ms;
|
||||
|
||||
xpu_channels_.resize(xpu_list_.size());
|
||||
for (size_t i = 0; i < xpu_list_.size(); ++i) {
|
||||
xpu_channels_[i].reset(new brpc::Channel());
|
||||
if (xpu_channels_[i]->Init(xpu_list_[i].c_str(), "", &options) != 0) {
|
||||
VLOG(0) << "HeterServer channel init fail";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void HeterClient::SendAndRecvAsync(
|
||||
const std::vector<std::string>& ep, const platform::DeviceContext& ctx,
|
||||
const framework::Scope& scope, const std::string& message_name,
|
||||
const std::vector<std::string>& send_var_name,
|
||||
const std::vector<std::string>& recv_var_name) {
|
||||
platform::RecordEvent record_event("HeterClient->SendAndRecvAsync");
|
||||
const platform::DeviceContext* p_ctx = &ctx;
|
||||
const framework::Scope* p_scope = &scope;
|
||||
const std::string message_name_val = message_name;
|
||||
const std::vector<std::string> send_var_name_val = send_var_name;
|
||||
const std::vector<std::string> recv_var_name_val = recv_var_name;
|
||||
|
||||
VLOG(3) << "GRPCClient::SendAndRecv Begin, message_name: "
|
||||
<< message_name_val;
|
||||
// Todo: get correct channel
|
||||
int num = trainer_id_ % xpu_channels_.size();
|
||||
|
||||
brpc::Controller cntl;
|
||||
cntl.set_timeout_ms(pserver_timeout_ms);
|
||||
distributed::MultiVarMsg request, response;
|
||||
auto& request_io_buffer = cntl.request_attachment();
|
||||
::paddle::PsService_Stub stub(xpu_channels_[num].get());
|
||||
distributed::SerializeToMultiVarMsgAndIOBuf(
|
||||
message_name_val, send_var_name_val, recv_var_name_val, *p_ctx, p_scope,
|
||||
&request, &request_io_buffer);
|
||||
stub.SendAndRecvVariable(&cntl, &request, &response, NULL);
|
||||
PADDLE_ENFORCE_NE(
|
||||
cntl.Failed(), true,
|
||||
platform::errors::Unimplemented(
|
||||
"HeterClient::SendAndRecv meets brpc error, error message is %s",
|
||||
cntl.ErrorText()));
|
||||
VLOG(4) << "call heter_worker success";
|
||||
auto& response_io_buffer = cntl.response_attachment();
|
||||
distributed::DeserializeFromMultiVarMsgAndIOBuf(response, &response_io_buffer,
|
||||
ctx, p_scope);
|
||||
}
|
||||
|
||||
std::future<int32_t> HeterClient::SendCmd(
|
||||
uint32_t table_id, int cmd_id, const std::vector<std::string>& params) {
|
||||
size_t request_call_num = xpu_channels_.size();
|
||||
paddle::distributed::DownpourBrpcClosure* closure =
|
||||
new paddle::distributed::DownpourBrpcClosure(
|
||||
request_call_num, [request_call_num, cmd_id](void* done) {
|
||||
int ret = 0;
|
||||
auto* closure = (paddle::distributed::DownpourBrpcClosure*)done;
|
||||
for (size_t i = 0; i < request_call_num; ++i) {
|
||||
if (closure->check_response(i, cmd_id) != 0) {
|
||||
ret = -1;
|
||||
break;
|
||||
}
|
||||
}
|
||||
closure->set_promise_value(ret);
|
||||
});
|
||||
auto promise = std::make_shared<std::promise<int32_t>>();
|
||||
closure->add_promise(promise);
|
||||
std::future<int> fut = promise->get_future();
|
||||
for (size_t i = 0; i < request_call_num; ++i) {
|
||||
closure->request(i)->set_cmd_id(cmd_id);
|
||||
closure->request(i)->set_table_id(table_id);
|
||||
closure->request(i)->set_client_id(trainer_id_);
|
||||
for (const auto& param : params) {
|
||||
closure->request(i)->add_params(param);
|
||||
}
|
||||
::paddle::PsService_Stub rpc_stub(xpu_channels_[i].get());
|
||||
closure->cntl(i)->set_timeout_ms(
|
||||
pserver_timeout_ms); // cmd msg don't limit timeout for save/load
|
||||
rpc_stub.service(closure->cntl(i), closure->request(i),
|
||||
closure->response(i), closure);
|
||||
}
|
||||
return fut;
|
||||
}
|
||||
|
||||
std::future<int32_t> HeterClient::StartProfiler() {
|
||||
return SendCmd(-1, PS_START_PROFILER, {});
|
||||
}
|
||||
|
||||
std::future<int32_t> HeterClient::StopProfiler() {
|
||||
return SendCmd(-1, PS_STOP_PROFILER, {});
|
||||
}
|
||||
|
||||
} // end namespace distributed
|
||||
} // end namespace paddle
|
@ -0,0 +1,127 @@
|
||||
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
#include <atomic>
|
||||
#include <ctime>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <random>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
#include "brpc/channel.h"
|
||||
#include "brpc/controller.h"
|
||||
#include "brpc/server.h"
|
||||
#include "paddle/fluid/distributed/service/brpc_ps_client.h"
|
||||
#include "paddle/fluid/distributed/service/brpc_utils.h"
|
||||
#include "paddle/fluid/distributed/service/sendrecv.pb.h"
|
||||
#include "paddle/fluid/framework/scope.h"
|
||||
#include "paddle/fluid/framework/tensor.h"
|
||||
#include "paddle/fluid/framework/variable_helper.h"
|
||||
#include "paddle/fluid/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN
|
||||
|
||||
namespace paddle {
|
||||
namespace distributed {
|
||||
|
||||
using MultiVarMsg = ::paddle::MultiVariableMessage;
|
||||
using VarMsg = ::paddle::VariableMessage;
|
||||
|
||||
typedef std::function<void(void*)> HeterRpcCallbackFunc;
|
||||
|
||||
class OnHeterRpcDone : public google::protobuf::Closure {
|
||||
public:
|
||||
OnHeterRpcDone(HeterRpcCallbackFunc func) : handler_(func) {}
|
||||
virtual ~OnHeterRpcDone() {}
|
||||
void Run() {
|
||||
std::unique_ptr<OnHeterRpcDone> self_guard(this);
|
||||
handler_(this);
|
||||
}
|
||||
|
||||
HeterRpcCallbackFunc handler_;
|
||||
MultiVariableMessage response;
|
||||
brpc::Controller cntl;
|
||||
};
|
||||
|
||||
class HeterClient {
|
||||
public:
|
||||
virtual ~HeterClient() {}
|
||||
|
||||
HeterClient() {
|
||||
running_ = true;
|
||||
main_thread_.reset(
|
||||
new std::thread(std::bind(&HeterClient::MainThread, this)));
|
||||
}
|
||||
|
||||
void CreateClient2XpuConnection();
|
||||
|
||||
void SendAndRecvAsync(const std::vector<std::string>& ep,
|
||||
const platform::DeviceContext& ctx,
|
||||
const framework::Scope& scope,
|
||||
const std::string& message_name,
|
||||
const std::vector<std::string>& send_var_name,
|
||||
const std::vector<std::string>& recv_var_name);
|
||||
|
||||
// HeterClient singleton
|
||||
static std::shared_ptr<HeterClient> GetInstance(
|
||||
const std::vector<std::string>& endpoint, const int& trainer_id) {
|
||||
if (NULL == s_instance_) {
|
||||
is_initialized_ = true;
|
||||
s_instance_.reset(new paddle::distributed::HeterClient());
|
||||
std::vector<std::string> xpu_list = {endpoint};
|
||||
s_instance_->SetXpuList(endpoint);
|
||||
s_instance_->SetTrainerID(trainer_id);
|
||||
s_instance_->CreateClient2XpuConnection();
|
||||
}
|
||||
return s_instance_;
|
||||
}
|
||||
|
||||
void Stop();
|
||||
|
||||
void MainThread();
|
||||
|
||||
void RpcProfilerControl();
|
||||
|
||||
std::future<int32_t> SendCmd(uint32_t table_id, int cmd_id,
|
||||
const std::vector<std::string>& params);
|
||||
|
||||
std::future<int32_t> StartProfiler();
|
||||
std::future<int32_t> StopProfiler();
|
||||
std::future<int32_t> StopHeterWorker();
|
||||
|
||||
std::vector<std::string>& GetXpuList() { return xpu_list_; }
|
||||
|
||||
void SetXpuList(const std::vector<std::string>& xpu_list) {
|
||||
xpu_list_ = xpu_list;
|
||||
};
|
||||
|
||||
void SetTrainerID(const int& trainer_id) { trainer_id_ = trainer_id; }
|
||||
|
||||
private:
|
||||
static std::shared_ptr<HeterClient> s_instance_;
|
||||
|
||||
protected:
|
||||
static bool is_initialized_;
|
||||
std::unique_ptr<std::thread> main_thread_{nullptr};
|
||||
std::vector<std::shared_ptr<brpc::Channel>> xpu_channels_;
|
||||
DISABLE_COPY_AND_ASSIGN(HeterClient);
|
||||
std::vector<std::string> xpu_list_;
|
||||
|
||||
bool running_ = false;
|
||||
int trainer_id_;
|
||||
bool do_server_profiler_ = false;
|
||||
};
|
||||
|
||||
} // end namespace distributed
|
||||
} // end namespace paddle
|
@ -0,0 +1,91 @@
|
||||
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/fluid/distributed/service/heter_server.h"
|
||||
#include <algorithm>
|
||||
#include <utility>
|
||||
#include "paddle/fluid/framework/fleet/heter_wrapper.h"
|
||||
#include "paddle/fluid/framework/op_registry.h"
|
||||
#include "paddle/fluid/framework/scope.h"
|
||||
#include "paddle/fluid/platform/timer.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace distributed {
|
||||
|
||||
std::shared_ptr<HeterServer> HeterServer::s_instance_ = NULL;
|
||||
|
||||
void HeterServer::RegisterServiceHandler(std::string message_name,
|
||||
HeterServiceHandler func) {
|
||||
service_.RegisterServiceHandler(message_name, func);
|
||||
}
|
||||
|
||||
void HeterServer::StartHeterService() {
|
||||
server_.AddService(&service_, brpc::SERVER_DOESNT_OWN_SERVICE);
|
||||
brpc::ServerOptions options;
|
||||
if (server_.Start(endpoint_.c_str(), &options) != 0) {
|
||||
VLOG(0) << "heter server start fail";
|
||||
} else {
|
||||
VLOG(0) << "heter server start success! listen on " << endpoint_;
|
||||
}
|
||||
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(this->mutex_ready_);
|
||||
ready_ = 1;
|
||||
}
|
||||
condition_ready_.notify_all();
|
||||
|
||||
server_.Join();
|
||||
}
|
||||
|
||||
void HeterServer::SetEndPoint(std::string& endpoint) {
|
||||
endpoint_ = endpoint;
|
||||
service_.SetEndpoint(endpoint);
|
||||
}
|
||||
|
||||
void HeterServer::SetFanin(int& fan_in) { service_.SetFanin(fan_in); }
|
||||
|
||||
void HeterServer::WaitServerReady() {
|
||||
std::unique_lock<std::mutex> lock(this->mutex_ready_);
|
||||
condition_ready_.wait(lock, [=] { return this->ready_ == 1; });
|
||||
}
|
||||
|
||||
int32_t HeterService::stop_profiler(const PsRequestMessage& request,
|
||||
PsResponseMessage& response,
|
||||
brpc::Controller* cntl) {
|
||||
platform::DisableProfiler(
|
||||
platform::EventSortingKey::kDefault,
|
||||
string::Sprintf("heter_worker_%s_profile", endpoint_));
|
||||
return 0;
|
||||
}
|
||||
|
||||
int32_t HeterService::start_profiler(const PsRequestMessage& request,
|
||||
PsResponseMessage& response,
|
||||
brpc::Controller* cntl) {
|
||||
platform::EnableProfiler(platform::ProfilerState::kAll);
|
||||
return 0;
|
||||
}
|
||||
|
||||
int32_t HeterService::stop_heter_worker(const PsRequestMessage& request,
|
||||
PsResponseMessage& response,
|
||||
brpc::Controller* cntl) {
|
||||
auto client_id = request.client_id();
|
||||
stop_cpu_worker_set_.insert(client_id);
|
||||
if (stop_cpu_worker_set_.size() == fan_in_) {
|
||||
is_exit_ = true;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
} // end namespace distributed
|
||||
} // end namespace paddle
|
@ -0,0 +1,243 @@
|
||||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
#include <atomic>
|
||||
#include <ctime>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <random>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
#include "brpc/channel.h"
|
||||
#include "brpc/controller.h"
|
||||
#include "brpc/server.h"
|
||||
#include "paddle/fluid/distributed/service/brpc_utils.h"
|
||||
#include "paddle/fluid/distributed/service/sendrecv.pb.h"
|
||||
#include "paddle/fluid/framework/executor.h"
|
||||
#include "paddle/fluid/framework/program_desc.h"
|
||||
#include "paddle/fluid/framework/scope.h"
|
||||
#include "paddle/fluid/framework/tensor.h"
|
||||
#include "paddle/fluid/framework/variable_helper.h"
|
||||
#include "paddle/fluid/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN
|
||||
#include "paddle/fluid/platform/profiler.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace distributed {
|
||||
|
||||
using MultiVarMsg = ::paddle::MultiVariableMessage;
|
||||
using VarMsg = ::paddle::VariableMessage;
|
||||
|
||||
class HeterService;
|
||||
typedef int32_t (HeterService::*serviceHandlerFunc)(
|
||||
const PsRequestMessage& request, PsResponseMessage& response,
|
||||
brpc::Controller* cntl);
|
||||
|
||||
typedef std::function<void(void*)> HeterRpcCallbackFunc;
|
||||
typedef std::function<int(const MultiVarMsg*, MultiVarMsg*, brpc::Controller*)>
|
||||
HeterServiceHandler;
|
||||
|
||||
class HeterService : public ::paddle::PsService {
|
||||
public:
|
||||
HeterService() {
|
||||
_service_handler_map[PS_STOP_SERVER] = &HeterService::stop_heter_worker;
|
||||
_service_handler_map[PS_START_PROFILER] = &HeterService::start_profiler;
|
||||
_service_handler_map[PS_STOP_PROFILER] = &HeterService::stop_profiler;
|
||||
}
|
||||
|
||||
virtual ~HeterService() {}
|
||||
|
||||
virtual void service(::google::protobuf::RpcController* controller,
|
||||
const ::paddle::PsRequestMessage* request,
|
||||
::paddle::PsResponseMessage* response,
|
||||
::google::protobuf::Closure* done) {
|
||||
brpc::ClosureGuard done_guard(done);
|
||||
std::string log_label("ReceiveCmd-");
|
||||
|
||||
response->set_err_code(0);
|
||||
response->set_err_msg("");
|
||||
brpc::Controller* cntl = static_cast<brpc::Controller*>(controller);
|
||||
auto itr = _service_handler_map.find(request->cmd_id());
|
||||
if (itr == _service_handler_map.end()) {
|
||||
std::string err_msg(
|
||||
"undefined cmd_id, should match PsCmdID in ps.proto, cmd_id:");
|
||||
err_msg.append(std::to_string(request->cmd_id()));
|
||||
return;
|
||||
}
|
||||
serviceHandlerFunc handler_func = itr->second;
|
||||
int service_ret = (this->*handler_func)(*request, *response, cntl);
|
||||
if (service_ret != 0) {
|
||||
response->set_err_code(service_ret);
|
||||
response->set_err_msg("server internal error");
|
||||
}
|
||||
};
|
||||
|
||||
void SendAndRecvVariable(::google::protobuf::RpcController* controller,
|
||||
const MultiVarMsg* request, MultiVarMsg* response,
|
||||
::google::protobuf::Closure* done) {
|
||||
brpc::ClosureGuard done_guard(done);
|
||||
std::string message_name = request->message_name();
|
||||
auto itr = handler_map_.find(message_name);
|
||||
brpc::Controller* cntl = static_cast<brpc::Controller*>(controller);
|
||||
PADDLE_ENFORCE_NE(
|
||||
itr, handler_map_.end(),
|
||||
platform::errors::InvalidArgument(
|
||||
"HeterService::SendAndRecvVariable Get illegal message_name: %s "
|
||||
"which is not in HeterService::handler_map_",
|
||||
message_name));
|
||||
itr->second(request, response, cntl);
|
||||
}
|
||||
|
||||
void RegisterServiceHandler(std::string message_name,
|
||||
HeterServiceHandler func) {
|
||||
handler_map_[message_name] = func;
|
||||
}
|
||||
|
||||
void SetEndpoint(const std::string& end_point) { endpoint_ = end_point; }
|
||||
void SetFanin(const int& fan_in) { fan_in_ = fan_in; }
|
||||
bool IsExit() { return is_exit_; }
|
||||
|
||||
private:
|
||||
int32_t stop_profiler(const PsRequestMessage& request,
|
||||
PsResponseMessage& response, brpc::Controller* cntl);
|
||||
|
||||
int32_t start_profiler(const PsRequestMessage& request,
|
||||
PsResponseMessage& response, brpc::Controller* cntl);
|
||||
|
||||
int32_t stop_heter_worker(const PsRequestMessage& request,
|
||||
PsResponseMessage& response,
|
||||
brpc::Controller* cntl);
|
||||
|
||||
private:
|
||||
std::string endpoint_;
|
||||
std::unordered_map<std::string, HeterServiceHandler> handler_map_;
|
||||
std::unordered_map<int32_t, serviceHandlerFunc> _service_handler_map;
|
||||
std::unordered_set<int> stop_cpu_worker_set_;
|
||||
int fan_in_;
|
||||
bool is_exit_ = false;
|
||||
};
|
||||
|
||||
class HeterServer {
|
||||
public:
|
||||
virtual ~HeterServer() {}
|
||||
|
||||
void Stop() {
|
||||
server_.Stop(1000);
|
||||
server_.Join();
|
||||
}
|
||||
|
||||
bool IsExit() { return service_.IsExit(); }
|
||||
|
||||
HeterServer() {}
|
||||
|
||||
void RegisterServiceHandler(std::string message_name,
|
||||
HeterServiceHandler func);
|
||||
|
||||
void StartHeterService();
|
||||
|
||||
void SetEndPoint(std::string& endpoint);
|
||||
void SetFanin(int& fan_in);
|
||||
|
||||
// HeterWrapper singleton
|
||||
static std::shared_ptr<HeterServer> GetInstance() {
|
||||
if (NULL == s_instance_) {
|
||||
s_instance_.reset(new HeterServer());
|
||||
}
|
||||
return s_instance_;
|
||||
}
|
||||
|
||||
void WaitServerReady();
|
||||
|
||||
private:
|
||||
static std::shared_ptr<HeterServer> s_instance_;
|
||||
std::string endpoint_;
|
||||
|
||||
protected:
|
||||
brpc::Server server_;
|
||||
HeterService service_;
|
||||
DISABLE_COPY_AND_ASSIGN(HeterServer);
|
||||
std::mutex mutex_ready_;
|
||||
std::condition_variable condition_ready_;
|
||||
int ready_;
|
||||
};
|
||||
|
||||
class HeterRequestHandler {
|
||||
public:
|
||||
HeterRequestHandler()
|
||||
: dev_ctx_(nullptr),
|
||||
executor_(nullptr),
|
||||
scope_(nullptr),
|
||||
program_(nullptr) {}
|
||||
|
||||
virtual ~HeterRequestHandler() {}
|
||||
|
||||
void SetScope(framework::Scope* scope) { scope_ = scope; }
|
||||
void SetDevCtx(const platform::DeviceContext* dev_ctx) { dev_ctx_ = dev_ctx; }
|
||||
void SetProgram(framework::ProgramDesc* program) { program_ = program; }
|
||||
void SetExecutor(framework::Executor* executor) { executor_ = executor; }
|
||||
|
||||
void SetGradToPreparedCtx(
|
||||
std::unordered_map<
|
||||
std::string, std::shared_ptr<framework::ExecutorPrepareContext>>* g) {
|
||||
message_to_prepared_ctx_ = g;
|
||||
}
|
||||
|
||||
virtual int Handle(const MultiVarMsg* request, MultiVarMsg* response,
|
||||
brpc::Controller* cntl) = 0;
|
||||
|
||||
protected:
|
||||
const platform::DeviceContext* dev_ctx_;
|
||||
framework::Executor* executor_;
|
||||
framework::Scope* scope_;
|
||||
framework::ProgramDesc* program_;
|
||||
|
||||
std::unordered_map<std::string,
|
||||
std::shared_ptr<framework::ExecutorPrepareContext>>*
|
||||
message_to_prepared_ctx_;
|
||||
};
|
||||
|
||||
class RequestSendAndRecvHandler final : public HeterRequestHandler {
|
||||
public:
|
||||
RequestSendAndRecvHandler() {}
|
||||
virtual ~RequestSendAndRecvHandler() {}
|
||||
int Handle(const MultiVarMsg* request, MultiVarMsg* response,
|
||||
brpc::Controller* cntl) override {
|
||||
platform::RecordEvent record_event("RequestSendAndRecvHandler->Handle");
|
||||
auto& local_scope = scope_->NewScope();
|
||||
auto message_name = request->message_name();
|
||||
auto& request_io_buffer = cntl->request_attachment();
|
||||
distributed::DeserializeFromMultiVarMsgAndIOBuf(
|
||||
*request, &request_io_buffer, *dev_ctx_, &local_scope);
|
||||
executor_->RunPreparedContext(
|
||||
(*message_to_prepared_ctx_)[message_name].get(), &local_scope, false);
|
||||
|
||||
auto response_var_nums = request->recv_var_names_size();
|
||||
std::vector<std::string> response_var_names(response_var_nums),
|
||||
empty_var_names{};
|
||||
|
||||
for (int var_idx = 0; var_idx < response_var_nums; ++var_idx) {
|
||||
response_var_names[var_idx] = request->recv_var_names(var_idx);
|
||||
}
|
||||
auto& response_io_buffer = cntl->response_attachment();
|
||||
distributed::SerializeToMultiVarMsgAndIOBuf(
|
||||
message_name, response_var_names, empty_var_names, *dev_ctx_,
|
||||
&local_scope, response, &response_io_buffer);
|
||||
scope_->DeleteScope(&local_scope);
|
||||
return 0;
|
||||
}
|
||||
};
|
||||
|
||||
} // end namespace distributed
|
||||
} // end namespace paddle
|
@ -0,0 +1,89 @@
|
||||
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/fluid/distributed/service/ps_client.h"
|
||||
|
||||
#include <map>
|
||||
|
||||
#include "brpc/server.h"
|
||||
#include "glog/logging.h"
|
||||
#include "paddle/fluid/distributed/service/brpc_ps_client.h"
|
||||
#include "paddle/fluid/distributed/table/table.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace distributed {
|
||||
REGISTER_CLASS(PSClient, BrpcPsClient);
|
||||
|
||||
int32_t PSClient::configure(
|
||||
const PSParameter &config,
|
||||
const std::map<uint64_t, std::vector<paddle::distributed::Region>> ®ions,
|
||||
PSEnvironment &env, size_t client_id) {
|
||||
_env = &env;
|
||||
_config = config;
|
||||
_dense_pull_regions = regions;
|
||||
_client_id = client_id;
|
||||
_config.mutable_worker_param()
|
||||
->mutable_downpour_worker_param()
|
||||
->mutable_downpour_table_param()
|
||||
->CopyFrom(_config.server_param()
|
||||
.downpour_server_param()
|
||||
.downpour_table_param());
|
||||
|
||||
const auto &work_param = _config.worker_param().downpour_worker_param();
|
||||
|
||||
for (size_t i = 0; i < work_param.downpour_table_param_size(); ++i) {
|
||||
auto *accessor = CREATE_CLASS(
|
||||
ValueAccessor,
|
||||
work_param.downpour_table_param(i).accessor().accessor_class());
|
||||
accessor->configure(work_param.downpour_table_param(i).accessor());
|
||||
accessor->initialize();
|
||||
_table_accessors[work_param.downpour_table_param(i).table_id()].reset(
|
||||
accessor);
|
||||
}
|
||||
return initialize();
|
||||
}
|
||||
|
||||
PSClient *PSClientFactory::create(const PSParameter &ps_config) {
|
||||
const auto &config = ps_config.server_param();
|
||||
if (!config.has_downpour_server_param()) {
|
||||
LOG(ERROR) << "miss downpour_server_param in ServerParameter";
|
||||
return NULL;
|
||||
}
|
||||
|
||||
if (!config.downpour_server_param().has_service_param()) {
|
||||
LOG(ERROR) << "miss service_param in ServerParameter.downpour_server_param";
|
||||
return NULL;
|
||||
}
|
||||
|
||||
if (!config.downpour_server_param().service_param().has_client_class()) {
|
||||
LOG(ERROR) << "miss client_class in "
|
||||
"ServerParameter.downpour_server_param.service_param";
|
||||
return NULL;
|
||||
}
|
||||
|
||||
const auto &service_param = config.downpour_server_param().service_param();
|
||||
PSClient *client = CREATE_CLASS(PSClient, service_param.client_class());
|
||||
if (client == NULL) {
|
||||
LOG(ERROR) << "client is not registered, server_name:"
|
||||
<< service_param.client_class();
|
||||
return NULL;
|
||||
}
|
||||
|
||||
TableManager::instance().initialize();
|
||||
LOG(INFO) << "Create PSClient[" << service_param.client_class()
|
||||
<< "] success";
|
||||
return client;
|
||||
}
|
||||
} // namespace distributed
|
||||
} // namespace paddle
|
@ -0,0 +1,113 @@
|
||||
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
syntax = "proto2";
|
||||
package paddle;
|
||||
option cc_generic_services = true;
|
||||
option cc_enable_arenas = true;
|
||||
|
||||
enum PsCmdID {
|
||||
PS_PULL_DENSE_TABLE = 0;
|
||||
PS_PUSH_DENSE_TABLE = 1;
|
||||
PS_PULL_SPARSE_TABLE = 2;
|
||||
PS_PUSH_SPARSE_TABLE = 3;
|
||||
PS_SHRINK_TABLE = 4;
|
||||
PS_SAVE_ONE_TABLE = 5;
|
||||
PS_SAVE_ALL_TABLE = 6;
|
||||
PS_LOAD_ONE_TABLE = 7;
|
||||
PS_LOAD_ALL_TABLE = 8;
|
||||
PS_CLEAR_ONE_TABLE = 9;
|
||||
PS_CLEAR_ALL_TABLE = 10;
|
||||
PS_PUSH_DENSE_PARAM = 11;
|
||||
PS_STOP_SERVER = 12;
|
||||
PS_SAVE_ONE_CACHE_TABLE = 13;
|
||||
PS_GET_CACHE_THRESHOLD = 14;
|
||||
PS_CACHE_SHUFFLE = 15;
|
||||
PS_COPY_TABLE = 16;
|
||||
PS_COPY_TABLE_BY_FEASIGN = 17;
|
||||
PS_PULL_SPARSE_TABLE_WITH_DEPENDENCY = 18;
|
||||
PS_PUSH_SPARSE_TABLE_WITH_DEPENDENCY = 19;
|
||||
PS_PRINT_TABLE_STAT = 20;
|
||||
PS_SAVE_ONE_TABLE_PREFIX = 21;
|
||||
PS_SAVE_ONE_TABLE_WITH_WHITELIST = 22;
|
||||
PS_LOAD_ONE_TABLE_WITH_WHITELIST = 23;
|
||||
PS_PULL_GEO_PARAM = 24;
|
||||
PS_BARRIER = 25;
|
||||
PS_PUSH_SPARSE_PARAM = 26;
|
||||
PS_START_PROFILER = 27;
|
||||
PS_STOP_PROFILER = 28;
|
||||
}
|
||||
|
||||
message PsRequestMessage {
|
||||
required uint32 cmd_id = 1;
|
||||
optional uint32 table_id = 2;
|
||||
repeated bytes params = 3;
|
||||
optional int32 client_id = 4;
|
||||
optional bytes data = 5;
|
||||
};
|
||||
|
||||
message PsResponseMessage {
|
||||
required int32 err_code = 1 [ default = 0 ];
|
||||
required string err_msg = 2 [ default = "" ];
|
||||
optional bytes data = 3;
|
||||
};
|
||||
|
||||
enum VarType {
|
||||
LOD_TENSOR = 0;
|
||||
SELECTED_ROWS = 1;
|
||||
}
|
||||
|
||||
message VariableMessage {
|
||||
enum Type {
|
||||
// Pod Types
|
||||
BOOL = 0;
|
||||
INT16 = 1;
|
||||
INT32 = 2;
|
||||
INT64 = 3;
|
||||
FP16 = 4;
|
||||
FP32 = 5;
|
||||
FP64 = 6;
|
||||
}
|
||||
|
||||
message LodData { repeated int64 lod_data = 1; }
|
||||
optional string varname = 1;
|
||||
// TODO(Yancey1989): reference framework::proto::VarDesc::VarType
|
||||
optional VarType type = 2;
|
||||
// bool persistable is not needed for sending.
|
||||
// tensor info:
|
||||
optional Type data_type = 3;
|
||||
repeated int64 dims = 4;
|
||||
|
||||
// lod details:
|
||||
optional int64 lod_level = 5;
|
||||
repeated LodData lod = 6;
|
||||
// selected_rows height, aka. original dim0
|
||||
optional int64 slr_height = 7;
|
||||
// tensor data
|
||||
optional bytes data = 8;
|
||||
}
|
||||
|
||||
// for SendAndRecv RPC method
|
||||
message MultiVariableMessage {
|
||||
// message flags
|
||||
required string message_name = 1;
|
||||
repeated string send_var_names = 2;
|
||||
repeated string recv_var_names = 3;
|
||||
repeated VariableMessage var_messages = 4;
|
||||
};
|
||||
|
||||
service PsService {
|
||||
rpc service(PsRequestMessage) returns (PsResponseMessage);
|
||||
rpc SendAndRecvVariable(MultiVariableMessage) returns (MultiVariableMessage);
|
||||
};
|
@ -0,0 +1,87 @@
|
||||
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/fluid/distributed/service/server.h"
|
||||
#include "glog/logging.h"
|
||||
#include "paddle/fluid/distributed/service/brpc_ps_server.h"
|
||||
#include "paddle/fluid/distributed/table/table.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace distributed {
|
||||
|
||||
REGISTER_CLASS(PSServer, BrpcPsServer);
|
||||
REGISTER_CLASS(PsBaseService, PsService);
|
||||
|
||||
PSServer *PSServerFactory::create(const PSParameter &ps_config) {
|
||||
const auto &config = ps_config.server_param();
|
||||
|
||||
if (!config.has_downpour_server_param()) {
|
||||
LOG(ERROR) << "miss downpour_server_param in ServerParameter";
|
||||
return NULL;
|
||||
}
|
||||
|
||||
if (!config.downpour_server_param().has_service_param()) {
|
||||
LOG(ERROR) << "miss service_param in ServerParameter.downpour_server_param";
|
||||
return NULL;
|
||||
}
|
||||
|
||||
if (!config.downpour_server_param().service_param().has_server_class()) {
|
||||
LOG(ERROR) << "miss server_class in "
|
||||
"ServerParameter.downpour_server_param.service_param";
|
||||
return NULL;
|
||||
}
|
||||
|
||||
const auto &service_param = config.downpour_server_param().service_param();
|
||||
PSServer *server = CREATE_CLASS(PSServer, service_param.server_class());
|
||||
if (server == NULL) {
|
||||
LOG(ERROR) << "server is not registered, server_name:"
|
||||
<< service_param.server_class();
|
||||
return NULL;
|
||||
}
|
||||
TableManager::instance().initialize();
|
||||
return server;
|
||||
}
|
||||
|
||||
int32_t PSServer::configure(const PSParameter &config, PSEnvironment &env,
|
||||
size_t server_rank) {
|
||||
_config = config.server_param();
|
||||
_rank = server_rank;
|
||||
_environment = &env;
|
||||
_shuffled_ins =
|
||||
paddle::framework::MakeChannel<std::pair<uint64_t, std::string>>();
|
||||
const auto &downpour_param = _config.downpour_server_param();
|
||||
|
||||
uint32_t barrier_table = UINT32_MAX;
|
||||
|
||||
for (size_t i = 0; i < downpour_param.downpour_table_param_size(); ++i) {
|
||||
auto *table = CREATE_CLASS(
|
||||
Table, downpour_param.downpour_table_param(i).table_class());
|
||||
|
||||
if (downpour_param.downpour_table_param(i).table_class() ==
|
||||
"BarrierTable") {
|
||||
barrier_table = downpour_param.downpour_table_param(i).table_id();
|
||||
}
|
||||
table->initialize(downpour_param.downpour_table_param(i),
|
||||
config.fs_client_param());
|
||||
_table_map[downpour_param.downpour_table_param(i).table_id()].reset(table);
|
||||
}
|
||||
|
||||
if (barrier_table != UINT32_MAX) {
|
||||
_table_map[barrier_table]->set_table_map(&_table_map);
|
||||
}
|
||||
|
||||
return initialize();
|
||||
}
|
||||
} // namespace distributed
|
||||
} // namespace paddle
|
@ -0,0 +1,150 @@
|
||||
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <future>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include "butil/endpoint.h"
|
||||
#include "google/protobuf/service.h"
|
||||
#include "paddle/fluid/distributed/common/registerer.h"
|
||||
#include "paddle/fluid/distributed/ps.pb.h"
|
||||
#include "paddle/fluid/distributed/service/env.h"
|
||||
#include "paddle/fluid/distributed/service/sendrecv.pb.h"
|
||||
#include "paddle/fluid/framework/channel.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace distributed {
|
||||
|
||||
class Table;
|
||||
|
||||
class PSServer {
|
||||
public:
|
||||
PSServer() {}
|
||||
virtual ~PSServer() {}
|
||||
PSServer(PSServer &&) = delete;
|
||||
PSServer(const PSServer &) = delete;
|
||||
|
||||
virtual int32_t configure(const PSParameter &config, PSEnvironment &env,
|
||||
size_t server_rank) final;
|
||||
|
||||
// return server_ip
|
||||
virtual std::string ip() { return butil::my_ip_cstr(); }
|
||||
// return server_port
|
||||
virtual int32_t port() = 0;
|
||||
|
||||
virtual uint64_t start(const std::string &ip, uint32_t port) = 0;
|
||||
virtual int32_t stop() = 0;
|
||||
|
||||
inline size_t rank() const { return _rank; }
|
||||
|
||||
inline PSEnvironment *environment() { return _environment; }
|
||||
|
||||
inline const ServerParameter *config() const { return &_config; }
|
||||
inline Table *table(size_t table_id) {
|
||||
auto itr = _table_map.find(table_id);
|
||||
if (itr != _table_map.end()) {
|
||||
return itr->second.get();
|
||||
}
|
||||
return NULL;
|
||||
}
|
||||
|
||||
inline std::unordered_map<uint32_t, std::shared_ptr<Table>> *table() {
|
||||
return &_table_map;
|
||||
}
|
||||
|
||||
typedef std::function<int32_t(int, int, const std::string &)> MsgHandlerFunc;
|
||||
virtual int registe_pserver2pserver_msg_handler(int msg_type,
|
||||
MsgHandlerFunc handler) {
|
||||
_msg_handler_map[msg_type] = handler;
|
||||
return 0;
|
||||
}
|
||||
|
||||
paddle::framework::Channel<std::pair<uint64_t, std::string>> _shuffled_ins;
|
||||
|
||||
protected:
|
||||
virtual int32_t initialize() = 0;
|
||||
|
||||
protected:
|
||||
size_t _rank;
|
||||
ServerParameter _config;
|
||||
PSEnvironment *_environment;
|
||||
std::unordered_map<uint32_t, std::shared_ptr<Table>> _table_map;
|
||||
std::unordered_map<int32_t, MsgHandlerFunc> _msg_handler_map;
|
||||
};
|
||||
|
||||
REGISTER_REGISTERER(PSServer);
|
||||
|
||||
typedef std::function<void(void *)> PServerCallBack;
|
||||
|
||||
class PServerClosure : public google::protobuf::Closure {
|
||||
public:
|
||||
PServerClosure(PServerCallBack callback) : _callback(callback) {}
|
||||
virtual ~PServerClosure() {}
|
||||
virtual void set_promise_value(int value) {
|
||||
for (auto &promise : _promises) {
|
||||
promise->set_value(value);
|
||||
}
|
||||
}
|
||||
void add_promise(std::shared_ptr<std::promise<int32_t>> &promise) {
|
||||
_promises.push_back(promise);
|
||||
}
|
||||
|
||||
protected:
|
||||
PServerCallBack _callback;
|
||||
std::vector<std::shared_ptr<std::promise<int32_t>>> _promises;
|
||||
};
|
||||
|
||||
class PsBaseService : public PsService {
|
||||
public:
|
||||
PsBaseService() : _rank(0), _server(NULL), _config(NULL) {}
|
||||
virtual ~PsBaseService() {}
|
||||
|
||||
virtual int32_t configure(PSServer *server) {
|
||||
_server = server;
|
||||
_rank = _server->rank();
|
||||
_config = _server->config();
|
||||
return 0;
|
||||
}
|
||||
virtual void service(::google::protobuf::RpcController *controller,
|
||||
const ::paddle::PsRequestMessage *request,
|
||||
::paddle::PsResponseMessage *response,
|
||||
::google::protobuf::Closure *done) override = 0;
|
||||
|
||||
virtual void set_response_code(PsResponseMessage &response, int err_code,
|
||||
const char *err_msg) {
|
||||
response.set_err_msg(err_msg);
|
||||
response.set_err_code(err_code);
|
||||
LOG(WARNING) << "Resonse err_code:" << err_code << " msg:" << err_msg;
|
||||
}
|
||||
|
||||
virtual int32_t initialize() = 0;
|
||||
|
||||
protected:
|
||||
size_t _rank;
|
||||
PSServer *_server;
|
||||
const ServerParameter *_config;
|
||||
};
|
||||
REGISTER_REGISTERER(PsBaseService);
|
||||
|
||||
class PSServerFactory {
|
||||
public:
|
||||
static PSServer *create(const PSParameter &config);
|
||||
};
|
||||
} // namespace distributed
|
||||
} // namespace paddle
|
@ -0,0 +1,129 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/fluid/distributed/service/service.h"
|
||||
|
||||
#include <fcntl.h>
|
||||
#include <google/protobuf/io/zero_copy_stream_impl.h>
|
||||
#include <google/protobuf/text_format.h>
|
||||
#include <iostream>
|
||||
#include "paddle/fluid/distributed/service/communicator.h"
|
||||
#include "paddle/fluid/string/string_helper.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
namespace paddle {
|
||||
namespace distributed {
|
||||
|
||||
paddle::distributed::PSParameter load_from_prototxt(
|
||||
const std::string& filename) {
|
||||
paddle::distributed::PSParameter param;
|
||||
int file_descriptor = open(filename.c_str(), O_RDONLY);
|
||||
|
||||
if (file_descriptor == -1) {
|
||||
VLOG(3) << "FATAL: fail to parse " << filename;
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
google::protobuf::io::FileInputStream fileInput(file_descriptor);
|
||||
if (!google::protobuf::TextFormat::Parse(&fileInput, ¶m)) {
|
||||
VLOG(3) << "FATAL: fail to parse " << filename;
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
close(file_descriptor);
|
||||
return param;
|
||||
}
|
||||
|
||||
void PSCore::init_gflag(const std::string& gflags) {
|
||||
LOG(INFO) << "Init With Gflags:" << gflags;
|
||||
std::vector<std::string> flags = paddle::string::split_string(gflags);
|
||||
if (flags.size() < 1) {
|
||||
flags.push_back("-max_body_size=314217728");
|
||||
flags.push_back("-bthread_concurrency=40");
|
||||
flags.push_back("-socket_max_unwritten_bytes=2048000000");
|
||||
flags.push_back("-max_connection_pool_size=1950");
|
||||
}
|
||||
auto it = flags.begin();
|
||||
flags.insert(it, "exe default");
|
||||
char* flags_ptr[flags.size()];
|
||||
for (size_t i = 0; i < flags.size(); ++i) {
|
||||
flags_ptr[i] = (char*)(flags[i].c_str());
|
||||
}
|
||||
int params_cnt = flags.size();
|
||||
char** params_ptr = &(flags_ptr[0]);
|
||||
::google::ParseCommandLineFlags(¶ms_cnt, ¶ms_ptr, true);
|
||||
}
|
||||
|
||||
int PSCore::init_server(const std::string& dist_desc,
|
||||
const std::vector<std::string>* host_sign_list,
|
||||
int node_num, int index) {
|
||||
google::protobuf::TextFormat::ParseFromString(dist_desc, &_ps_param);
|
||||
init_gflag(_ps_param.init_gflags());
|
||||
_ps_env = paddle::distributed::PaddlePSEnvironment();
|
||||
_ps_env.set_ps_servers(host_sign_list, node_num);
|
||||
int ret = 0;
|
||||
_server_ptr = std::shared_ptr<paddle::distributed::PSServer>(
|
||||
paddle::distributed::PSServerFactory::create(_ps_param));
|
||||
ret = _server_ptr->configure(_ps_param, _ps_env, index);
|
||||
CHECK(ret == 0) << "failed to configure server";
|
||||
return ret;
|
||||
}
|
||||
|
||||
int PSCore::init_worker(
|
||||
const std::string& dist_desc,
|
||||
const std::map<uint64_t, std::vector<paddle::distributed::Region>>& regions,
|
||||
const std::vector<std::string>* host_sign_list, int node_num, int index) {
|
||||
google::protobuf::TextFormat::ParseFromString(dist_desc, &_ps_param);
|
||||
init_gflag(_ps_param.init_gflags());
|
||||
_ps_env = paddle::distributed::PaddlePSEnvironment();
|
||||
_ps_env.set_ps_servers(host_sign_list, node_num);
|
||||
int ret = 0;
|
||||
VLOG(1) << "PSCore::init_worker";
|
||||
auto* communicator = Communicator::GetInstance();
|
||||
ret = communicator->GetPsClient()->configure(_ps_param, regions, _ps_env,
|
||||
index);
|
||||
communicator->Start();
|
||||
return ret;
|
||||
}
|
||||
|
||||
std::vector<uint64_t> PSCore::get_client_info() {
|
||||
return _ps_env.get_client_info();
|
||||
}
|
||||
|
||||
int PSCore::create_client2client_connection(int pserver_timeout_ms,
|
||||
int pserver_connect_timeout_ms,
|
||||
int max_retry) {
|
||||
int ret = _worker_ptr->create_client2client_connection(
|
||||
pserver_timeout_ms, pserver_connect_timeout_ms, max_retry);
|
||||
return ret;
|
||||
}
|
||||
|
||||
uint64_t PSCore::run_server(const std::string& ip, uint32_t port) {
|
||||
return _server_ptr->start(ip, port);
|
||||
}
|
||||
|
||||
int PSCore::finalize_worker() {
|
||||
_worker_ptr->finalize_worker();
|
||||
return 0;
|
||||
}
|
||||
|
||||
int PSCore::stop_server() {
|
||||
auto stop_status = _worker_ptr->stop_server();
|
||||
stop_status.wait();
|
||||
return 0;
|
||||
}
|
||||
paddle::distributed::PSParameter* PSCore::get_param() { return &_ps_param; }
|
||||
} // namespace distributed
|
||||
} // namespace paddle
|
@ -0,0 +1,64 @@
|
||||
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include <glog/logging.h>
|
||||
#include "paddle/fluid/distributed/ps.pb.h"
|
||||
#include "paddle/fluid/distributed/service/ps_client.h"
|
||||
#include "paddle/fluid/distributed/service/sendrecv.pb.h"
|
||||
#include "paddle/fluid/distributed/service/server.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace distributed {
|
||||
|
||||
class PSCore {
|
||||
public:
|
||||
explicit PSCore() {}
|
||||
virtual ~PSCore() {}
|
||||
|
||||
virtual int init_server(const std::string& dist_desc,
|
||||
const std::vector<std::string>* host_sign_list,
|
||||
int node_num, int index);
|
||||
virtual int init_worker(
|
||||
const std::string& dist_desc,
|
||||
const std::map<uint64_t, std::vector<paddle::distributed::Region>>&
|
||||
regions,
|
||||
const std::vector<std::string>* host_sign_list, int node_num, int index);
|
||||
virtual uint64_t run_server(const std::string& ip, uint32_t port);
|
||||
virtual int stop_server();
|
||||
virtual int finalize_worker();
|
||||
virtual std::vector<uint64_t> get_client_info();
|
||||
virtual int create_client2client_connection(int pserver_timeout_ms,
|
||||
int pserver_connect_timeout_ms,
|
||||
int max_retry);
|
||||
std::shared_ptr<paddle::distributed::PSServer>
|
||||
_server_ptr; // pointer to server
|
||||
std::shared_ptr<paddle::distributed::PSClient>
|
||||
_worker_ptr; // pointer to worker
|
||||
virtual paddle::distributed::PSParameter* get_param();
|
||||
|
||||
private:
|
||||
void init_gflag(const std::string& gflags);
|
||||
paddle::distributed::PSParameter _ps_param;
|
||||
paddle::distributed::PaddlePSEnvironment _ps_env;
|
||||
};
|
||||
|
||||
} // namespace distributed
|
||||
} // namespace paddle
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in new issue