Fix/distributed proto (#29981)

* rename the sendrecv.proto package from paddle to paddle.distributed

* split the parameter-server (pscore) build out of the general distributed build via the new WITH_PSCORE option
Committed by tangwei12: commit 25f80fd304 (parent d479ae1725)

@ -160,6 +160,7 @@ option(WITH_BOX_PS "Compile with box_ps support" OFF)
option(WITH_XBYAK "Compile with xbyak support" ON)
option(WITH_CONTRIB "Compile the third-party contribution" OFF)
option(WITH_GRPC "Use grpc as the default rpc framework" ${WITH_DISTRIBUTE})
option(WITH_PSCORE "Compile with parameter server support" ${WITH_DISTRIBUTE})
option(WITH_INFERENCE_API_TEST "Test fluid inference C++ high-level api interface" OFF)
option(PY_VERSION "Compile PaddlePaddle with python3 support" ${PY_VERSION})
option(WITH_DGC "Use DGC(Deep Gradient Compression) or not" ${WITH_DISTRIBUTE})

@ -160,6 +160,11 @@ if(WITH_DISTRIBUTE)
add_definitions(-DPADDLE_WITH_DISTRIBUTE)
endif()
if(WITH_PSCORE)
add_definitions(-DPADDLE_WITH_PSCORE)
endif()
if(WITH_GRPC)
add_definitions(-DPADDLE_WITH_GRPC)
endif(WITH_GRPC)
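
A minimal C++ sketch (not part of this PR) of how the new compile definitions could be consumed downstream; the helper name and the gated body are illustrative assumptions:

// Illustrative only: gate pscore-specific code on the new definitions.
#ifdef PADDLE_WITH_PSCORE
#include "paddle/fluid/distributed/service/ps_client.h"
#endif

void MaybeInitPsCore() {  // hypothetical helper, not from this diff
#if defined(PADDLE_WITH_PSCORE) && defined(PADDLE_WITH_GRPC)
  // Both the parameter-server core and gRPC were enabled at configure time.
#endif
}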

@ -274,7 +274,7 @@ if(WITH_BOX_PS)
list(APPEND third_party_deps extern_box_ps)
endif(WITH_BOX_PS)
if (WITH_DISTRIBUTE)
if (WITH_PSCORE)
include(external/snappy)
list(APPEND third_party_deps extern_snappy)

@ -1,7 +1,4 @@
if (WITH_PSLIB)
return()
endif()
if(NOT WITH_DISTRIBUTE)
if(NOT WITH_PSCORE)
return()
endif()

@ -69,24 +69,24 @@ class ObjectFactory {
};
typedef std::map<std::string, ObjectFactory *> FactoryMap;
typedef std::map<std::string, FactoryMap> BaseClassMap;
typedef std::map<std::string, FactoryMap> PsCoreClassMap;
#ifdef __cplusplus
extern "C" {
#endif
inline BaseClassMap &global_factory_map() {
static BaseClassMap *base_class = new BaseClassMap();
inline PsCoreClassMap &global_factory_map() {
static PsCoreClassMap *base_class = new PsCoreClassMap();
return *base_class;
}
#ifdef __cplusplus
}
#endif
inline BaseClassMap &global_factory_map_cpp() { return global_factory_map(); }
inline PsCoreClassMap &global_factory_map_cpp() { return global_factory_map(); }
// typedef pa::Any Any;
// typedef ::FactoryMap FactoryMap;
#define REGISTER_REGISTERER(base_class) \
#define REGISTER_PSCORE_REGISTERER(base_class) \
class base_class##Registerer { \
public: \
static base_class *CreateInstanceByName(const ::std::string &name) { \
@ -107,7 +107,7 @@ inline BaseClassMap &global_factory_map_cpp() { return global_factory_map(); }
} \
};
#define REGISTER_CLASS(clazz, name) \
#define REGISTER_PSCORE_CLASS(clazz, name) \
class ObjectFactory##name : public ObjectFactory { \
public: \
Any NewInstance() { return Any(new name()); } \
@ -120,7 +120,7 @@ inline BaseClassMap &global_factory_map_cpp() { return global_factory_map(); }
} \
void register_factory_##name() __attribute__((constructor));
#define CREATE_CLASS(base_class, name) \
#define CREATE_PSCORE_CLASS(base_class, name) \
base_class##Registerer::CreateInstanceByName(name);
} // namespace distributed
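
The renamed macros keep the same registry pattern as before; a short usage sketch assembled from other hunks in this diff (combining them into one snippet is illustrative, not a literal file):

// Generate the TableRegisterer factory helpers for the base class.
REGISTER_PSCORE_REGISTERER(Table);

// Register a concrete implementation under its class name (runs at load time).
REGISTER_PSCORE_CLASS(Table, CommonDenseTable);

// Look the implementation up by name through the global factory map.
Table *table = CREATE_PSCORE_CLASS(Table, "CommonDenseTable");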

@ -86,7 +86,7 @@ message SparseTableParameter {
message ServerServiceParameter {
optional string server_class = 1 [ default = "BrpcPsServer" ];
optional string client_class = 2 [ default = "BrpcPsClient" ];
optional string service_class = 3 [ default = "PsService" ];
optional string service_class = 3 [ default = "BrpcPsService" ];
optional uint32 start_server_port = 4
[ default = 0 ]; // will find an available port from it
optional uint32 server_thread_num = 5 [ default = 12 ];
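
For context, a hedged C++ sketch of selecting the renamed default service class through the generated proto setters; it mirrors the test changes near the end of this diff rather than adding new behavior:

::paddle::distributed::ServerServiceParameter service;
service.set_service_class("BrpcPsService");  // renamed from "PsService"
service.set_server_class("BrpcPsServer");
service.set_client_class("BrpcPsClient");
service.set_start_server_port(0);  // 0 means pick an available port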

@ -17,8 +17,8 @@
#include <sstream>
#include <string>
#include <vector>
#include "Eigen/Dense"
#include "paddle/fluid/distributed/service/brpc_ps_client.h"
#include "paddle/fluid/distributed/table/table.h"
#include "paddle/fluid/framework/archive.h"
@ -80,8 +80,8 @@ inline size_t get_sparse_shard(uint32_t shard_num, uint32_t server_num,
void DownpourPsClientService::service(
::google::protobuf::RpcController *controller,
const ::paddle::PsRequestMessage *request,
::paddle::PsResponseMessage *response, ::google::protobuf::Closure *done) {
const PsRequestMessage *request, PsResponseMessage *response,
::google::protobuf::Closure *done) {
brpc::ClosureGuard done_guard(done);
int ret = _client->handle_client2client_msg(
request->cmd_id(), request->client_id(), request->data());

@ -40,8 +40,8 @@ class DownpourPsClientService : public PsService {
return 0;
}
virtual void service(::google::protobuf::RpcController *controller,
const ::paddle::PsRequestMessage *request,
::paddle::PsResponseMessage *response,
const PsRequestMessage *request,
PsResponseMessage *response,
::google::protobuf::Closure *done) override;
protected:

File diff suppressed because it is too large.

@ -52,19 +52,19 @@ class BrpcPsServer : public PSServer {
std::vector<std::shared_ptr<brpc::Channel>> _pserver_channels;
};
class PsService;
class BrpcPsService;
typedef int32_t (PsService::*serviceHandlerFunc)(
typedef int32_t (BrpcPsService::*serviceHandlerFunc)(
Table *table, const PsRequestMessage &request, PsResponseMessage &response,
brpc::Controller *cntl);
class PsService : public PsBaseService {
class BrpcPsService : public PsBaseService {
public:
virtual int32_t initialize() override;
virtual void service(::google::protobuf::RpcController *controller,
const ::paddle::PsRequestMessage *request,
::paddle::PsResponseMessage *response,
const PsRequestMessage *request,
PsResponseMessage *response,
::google::protobuf::Closure *done) override;
private:

@ -88,7 +88,7 @@ void SerializeLodTensor(framework::Variable* var,
const platform::DeviceContext& ctx, VarMsg* var_msg,
butil::IOBuf* iobuf) {
auto* tensor = var->GetMutable<framework::LoDTensor>();
var_msg->set_type(::paddle::LOD_TENSOR);
var_msg->set_type(::paddle::distributed::LOD_TENSOR);
const framework::LoD lod = tensor->lod();
if (lod.size() > 0) {
var_msg->set_lod_level(lod.size());
@ -135,7 +135,7 @@ void SerializeSelectedRows(framework::Variable* var,
auto* tensor = slr->mutable_value();
auto* rows = slr->mutable_rows();
var_msg->set_type(::paddle::SELECTED_ROWS);
var_msg->set_type(::paddle::distributed::SELECTED_ROWS);
var_msg->set_slr_height(slr->height());
auto* var_data = var_msg->mutable_data();
@ -194,9 +194,9 @@ void DeserializeFromMultiVarMsgAndIOBuf(const MultiVarMsg& multi_msg,
++recv_var_index) {
const auto& msg = multi_msg.var_messages(recv_var_index);
auto* var = scope->Var(msg.varname());
if (msg.type() == ::paddle::LOD_TENSOR) {
if (msg.type() == ::paddle::distributed::LOD_TENSOR) {
DeserializeLodTensor(var, msg, io_buffer_itr, ctx);
} else if (msg.type() == ::paddle::SELECTED_ROWS) {
} else if (msg.type() == ::paddle::distributed::SELECTED_ROWS) {
DeserializeSelectedRows(var, msg, io_buffer_itr, ctx);
}
}
@ -215,9 +215,9 @@ void DeserializeFromMultiVarMsgAndIOBuf(const MultiVarMsg& multi_msg,
PADDLE_ENFORCE_NE(var, nullptr,
platform::errors::InvalidArgument(
"Not find variable %s in scope.", msg.varname()));
if (msg.type() == ::paddle::LOD_TENSOR) {
if (msg.type() == ::paddle::distributed::LOD_TENSOR) {
DeserializeLodTensor(var, msg, io_buffer_itr, ctx);
} else if (msg.type() == ::paddle::SELECTED_ROWS) {
} else if (msg.type() == ::paddle::distributed::SELECTED_ROWS) {
DeserializeSelectedRows(var, msg, io_buffer_itr, ctx);
}
}

@ -44,8 +44,8 @@ class DeviceContext;
namespace paddle {
namespace distributed {
using MultiVarMsg = ::paddle::MultiVariableMessage;
using VarMsg = ::paddle::VariableMessage;
using MultiVarMsg = ::paddle::distributed::MultiVariableMessage;
using VarMsg = ::paddle::distributed::VariableMessage;
void SerializeToMultiVarMsgAndIOBuf(
const std::string& message_name,

@ -122,7 +122,7 @@ void HeterClient::SendAndRecvAsync(
cntl.set_timeout_ms(FLAGS_pserver_timeout_ms);
distributed::MultiVarMsg request, response;
auto& request_io_buffer = cntl.request_attachment();
::paddle::PsService_Stub stub(xpu_channels_[num].get());
::paddle::distributed::PsService_Stub stub(xpu_channels_[num].get());
distributed::SerializeToMultiVarMsgAndIOBuf(
message_name_val, send_var_name_val, recv_var_name_val, *p_ctx, p_scope,
&request, &request_io_buffer);
@ -164,7 +164,7 @@ std::future<int32_t> HeterClient::SendCmd(
for (const auto& param : params) {
closure->request(i)->add_params(param);
}
::paddle::PsService_Stub rpc_stub(xpu_channels_[i].get());
::paddle::distributed::PsService_Stub rpc_stub(xpu_channels_[i].get());
closure->cntl(i)->set_timeout_ms(
FLAGS_pserver_timeout_ms); // cmd msgs don't limit the timeout for save/load
rpc_stub.service(closure->cntl(i), closure->request(i),

@ -35,8 +35,8 @@ limitations under the License. */
namespace paddle {
namespace distributed {
using MultiVarMsg = ::paddle::MultiVariableMessage;
using VarMsg = ::paddle::VariableMessage;
using MultiVarMsg = ::paddle::distributed::MultiVariableMessage;
using VarMsg = ::paddle::distributed::VariableMessage;
typedef std::function<void(void*)> HeterRpcCallbackFunc;

@ -39,8 +39,8 @@ DECLARE_double(eager_delete_tensor_gb);
namespace paddle {
namespace distributed {
using MultiVarMsg = ::paddle::MultiVariableMessage;
using VarMsg = ::paddle::VariableMessage;
using MultiVarMsg = ::paddle::distributed::MultiVariableMessage;
using VarMsg = ::paddle::distributed::VariableMessage;
class HeterService;
typedef int32_t (HeterService::*serviceHandlerFunc)(
@ -51,7 +51,7 @@ typedef std::function<void(void*)> HeterRpcCallbackFunc;
typedef std::function<int(const MultiVarMsg*, MultiVarMsg*, brpc::Controller*)>
HeterServiceHandler;
class HeterService : public ::paddle::PsService {
class HeterService : public ::paddle::distributed::PsService {
public:
HeterService() {
_service_handler_map[PS_STOP_SERVER] = &HeterService::stop_heter_worker;
@ -62,8 +62,8 @@ class HeterService : public ::paddle::PsService {
virtual ~HeterService() {}
virtual void service(::google::protobuf::RpcController* controller,
const ::paddle::PsRequestMessage* request,
::paddle::PsResponseMessage* response,
const PsRequestMessage* request,
PsResponseMessage* response,
::google::protobuf::Closure* done) {
brpc::ClosureGuard done_guard(done);
std::string log_label("ReceiveCmd-");

@ -13,9 +13,7 @@
// limitations under the License.
#include "paddle/fluid/distributed/service/ps_client.h"
#include <map>
#include "brpc/server.h"
#include "glog/logging.h"
#include "paddle/fluid/distributed/service/brpc_ps_client.h"
@ -23,7 +21,7 @@
namespace paddle {
namespace distributed {
REGISTER_CLASS(PSClient, BrpcPsClient);
REGISTER_PSCORE_CLASS(PSClient, BrpcPsClient);
int32_t PSClient::configure(
const PSParameter &config,
@ -43,7 +41,7 @@ int32_t PSClient::configure(
const auto &work_param = _config.worker_param().downpour_worker_param();
for (size_t i = 0; i < work_param.downpour_table_param_size(); ++i) {
auto *accessor = CREATE_CLASS(
auto *accessor = CREATE_PSCORE_CLASS(
ValueAccessor,
work_param.downpour_table_param(i).accessor().accessor_class());
accessor->configure(work_param.downpour_table_param(i).accessor());
@ -73,7 +71,8 @@ PSClient *PSClientFactory::create(const PSParameter &ps_config) {
}
const auto &service_param = config.downpour_server_param().service_param();
PSClient *client = CREATE_CLASS(PSClient, service_param.client_class());
PSClient *client =
CREATE_PSCORE_CLASS(PSClient, service_param.client_class());
if (client == NULL) {
LOG(ERROR) << "client is not registered, server_name:"
<< service_param.client_class();

@ -28,6 +28,9 @@
namespace paddle {
namespace distributed {
using paddle::distributed::PsRequestMessage;
using paddle::distributed::PsResponseMessage;
typedef std::function<void(void *)> PSClientCallBack;
class PSClientClosure : public google::protobuf::Closure {
public:
@ -206,7 +209,7 @@ class PSClient {
std::unordered_map<int32_t, MsgHandlerFunc>
_msg_handler_map;  // handles client2client messages
};
REGISTER_REGISTERER(PSClient);
REGISTER_PSCORE_REGISTERER(PSClient);
class PSClientFactory {
public:

@ -13,7 +13,7 @@
// limitations under the License.
syntax = "proto2";
package paddle;
package paddle.distributed;
option cc_generic_services = true;
option cc_enable_arenas = true;
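
Moving the package from paddle to paddle.distributed inserts an extra namespace level into every generated C++ symbol, which is what the call-site hunks in this diff adjust. An illustrative summary:

// Before: ::paddle::PsRequestMessage, ::paddle::PsService_Stub, ::paddle::LOD_TENSOR
// After:  ::paddle::distributed::PsRequestMessage, ::paddle::distributed::PsService_Stub,
//         ::paddle::distributed::LOD_TENSOR
using paddle::distributed::PsRequestMessage;
using paddle::distributed::PsResponseMessage;
using paddle::distributed::PsService_Stub;  // RPC stub generated by protobuf (cc_generic_services)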

@ -13,6 +13,7 @@
// limitations under the License.
#include "paddle/fluid/distributed/service/server.h"
#include "glog/logging.h"
#include "paddle/fluid/distributed/service/brpc_ps_server.h"
#include "paddle/fluid/distributed/table/table.h"
@ -20,8 +21,8 @@
namespace paddle {
namespace distributed {
REGISTER_CLASS(PSServer, BrpcPsServer);
REGISTER_CLASS(PsBaseService, PsService);
REGISTER_PSCORE_CLASS(PSServer, BrpcPsServer);
REGISTER_PSCORE_CLASS(PsBaseService, BrpcPsService);
PSServer *PSServerFactory::create(const PSParameter &ps_config) {
const auto &config = ps_config.server_param();
@ -43,7 +44,8 @@ PSServer *PSServerFactory::create(const PSParameter &ps_config) {
}
const auto &service_param = config.downpour_server_param().service_param();
PSServer *server = CREATE_CLASS(PSServer, service_param.server_class());
PSServer *server =
CREATE_PSCORE_CLASS(PSServer, service_param.server_class());
if (server == NULL) {
LOG(ERROR) << "server is not registered, server_name:"
<< service_param.server_class();
@ -70,7 +72,7 @@ int32_t PSServer::configure(
uint32_t global_step_table = UINT32_MAX;
for (size_t i = 0; i < downpour_param.downpour_table_param_size(); ++i) {
auto *table = CREATE_CLASS(
auto *table = CREATE_PSCORE_CLASS(
Table, downpour_param.downpour_table_param(i).table_class());
if (downpour_param.downpour_table_param(i).table_class() ==

@ -46,6 +46,8 @@ namespace paddle {
namespace distributed {
class Table;
using paddle::distributed::PsRequestMessage;
using paddle::distributed::PsResponseMessage;
class PSServer {
public:
@ -107,7 +109,7 @@ class PSServer {
platform::Place place_ = platform::CPUPlace();
};
REGISTER_REGISTERER(PSServer);
REGISTER_PSCORE_REGISTERER(PSServer);
typedef std::function<void(void *)> PServerCallBack;
@ -141,8 +143,8 @@ class PsBaseService : public PsService {
return 0;
}
virtual void service(::google::protobuf::RpcController *controller,
const ::paddle::PsRequestMessage *request,
::paddle::PsResponseMessage *response,
const PsRequestMessage *request,
PsResponseMessage *response,
::google::protobuf::Closure *done) override = 0;
virtual void set_response_code(PsResponseMessage &response, int err_code,
@ -159,7 +161,7 @@ class PsBaseService : public PsService {
PSServer *_server;
const ServerParameter *_config;
};
REGISTER_REGISTERER(PsBaseService);
REGISTER_PSCORE_REGISTERER(PsBaseService);
class PSServerFactory {
public:

@ -28,6 +28,10 @@ limitations under the License. */
namespace paddle {
namespace distributed {
using paddle::distributed::PsRequestMessage;
using paddle::distributed::PsResponseMessage;
using paddle::distributed::PsService;
class PSCore {
public:
explicit PSCore() {}

@ -165,6 +165,6 @@ class ValueAccessor {
std::unordered_map<int, std::shared_ptr<struct DataConverter>>
_data_coverter_map;
};
REGISTER_REGISTERER(ValueAccessor);
REGISTER_PSCORE_REGISTERER(ValueAccessor);
} // namespace distributed
} // namespace paddle

@ -13,6 +13,7 @@
// limitations under the License.
#include "paddle/fluid/distributed/table/table.h"
#include <boost/preprocessor/repetition/repeat_from_to.hpp>
#include <boost/preprocessor/seq/elem.hpp>
#include "glog/logging.h"
@ -27,14 +28,14 @@
namespace paddle {
namespace distributed {
REGISTER_CLASS(Table, CommonDenseTable);
REGISTER_CLASS(Table, CommonSparseTable);
REGISTER_CLASS(Table, SparseGeoTable);
REGISTER_CLASS(Table, BarrierTable);
REGISTER_CLASS(Table, TensorTable);
REGISTER_CLASS(Table, DenseTensorTable);
REGISTER_CLASS(Table, GlobalStepTable);
REGISTER_CLASS(ValueAccessor, CommMergeAccessor);
REGISTER_PSCORE_CLASS(Table, CommonDenseTable);
REGISTER_PSCORE_CLASS(Table, CommonSparseTable);
REGISTER_PSCORE_CLASS(Table, SparseGeoTable);
REGISTER_PSCORE_CLASS(Table, BarrierTable);
REGISTER_PSCORE_CLASS(Table, TensorTable);
REGISTER_PSCORE_CLASS(Table, DenseTensorTable);
REGISTER_PSCORE_CLASS(Table, GlobalStepTable);
REGISTER_PSCORE_CLASS(ValueAccessor, CommMergeAccessor);
int32_t TableManager::initialize() {
static bool initialized = false;
@ -61,8 +62,8 @@ int32_t Table::initialize_accessor() {
<< _config.table_id();
return -1;
}
auto *accessor =
CREATE_CLASS(ValueAccessor,
auto *accessor = CREATE_PSCORE_CLASS(
ValueAccessor,
_config.accessor().accessor_class()) if (accessor == NULL) {
LOG(ERROR) << "accessor is unregisteg, table_id:" << _config.table_id()
<< ", accessor_name:" << _config.accessor().accessor_class();

@ -127,7 +127,7 @@ class Table {
float *_global_lr = nullptr;
std::shared_ptr<ValueAccessor> _value_accesor;
};
REGISTER_REGISTERER(Table);
REGISTER_PSCORE_REGISTERER(Table);
class TableManager {
public:

@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include <unistd.h>
#include <condition_variable> // NOLINT
#include <string>
#include <thread> // NOLINT
@ -94,7 +95,7 @@ void GetDownpourDenseTableProto(
server_proto->mutable_downpour_server_param();
::paddle::distributed::ServerServiceParameter* server_service_proto =
downpour_server_proto->mutable_service_param();
server_service_proto->set_service_class("PsService");
server_service_proto->set_service_class("BrpcPsService");
server_service_proto->set_server_class("BrpcPsServer");
server_service_proto->set_client_class("BrpcPsClient");
server_service_proto->set_start_server_port(0);
@ -124,7 +125,7 @@ void GetDownpourDenseTableProto(
server_proto->mutable_downpour_server_param();
::paddle::distributed::ServerServiceParameter* server_service_proto =
downpour_server_proto->mutable_service_param();
server_service_proto->set_service_class("PsService");
server_service_proto->set_service_class("BrpcPsService");
server_service_proto->set_server_class("BrpcPsServer");
server_service_proto->set_client_class("BrpcPsClient");
server_service_proto->set_start_server_port(0);
@ -244,7 +245,8 @@ void RunBrpcPushDense() {
int ret = 0;
auto* closure = (paddle::distributed::DownpourBrpcClosure*)done;
for (size_t i = 0; i < 1; ++i) {
if (closure->check_response(i, paddle::PS_PUSH_DENSE_TABLE) != 0) {
if (closure->check_response(
i, paddle::distributed::PS_PUSH_DENSE_TABLE) != 0) {
ret = -1;
break;
}

Some files were not shown because too many files have changed in this diff.
