Fix/distributed proto (#29981)

* rename sendrecv.proto to namespace paddle.distributed

* split ps with distributed
revert-31562-mean
tangwei12 5 years ago committed by GitHub
parent d479ae1725
commit 25f80fd304
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -160,6 +160,7 @@ option(WITH_BOX_PS "Compile with box_ps support" OFF)
option(WITH_XBYAK "Compile with xbyak support" ON) option(WITH_XBYAK "Compile with xbyak support" ON)
option(WITH_CONTRIB "Compile the third-party contributation" OFF) option(WITH_CONTRIB "Compile the third-party contributation" OFF)
option(WITH_GRPC "Use grpc as the default rpc framework" ${WITH_DISTRIBUTE}) option(WITH_GRPC "Use grpc as the default rpc framework" ${WITH_DISTRIBUTE})
option(WITH_PSCORE "Compile with parameter server support" ${WITH_DISTRIBUTE})
option(WITH_INFERENCE_API_TEST "Test fluid inference C++ high-level api interface" OFF) option(WITH_INFERENCE_API_TEST "Test fluid inference C++ high-level api interface" OFF)
option(PY_VERSION "Compile PaddlePaddle with python3 support" ${PY_VERSION}) option(PY_VERSION "Compile PaddlePaddle with python3 support" ${PY_VERSION})
option(WITH_DGC "Use DGC(Deep Gradient Compression) or not" ${WITH_DISTRIBUTE}) option(WITH_DGC "Use DGC(Deep Gradient Compression) or not" ${WITH_DISTRIBUTE})

@ -160,6 +160,11 @@ if(WITH_DISTRIBUTE)
add_definitions(-DPADDLE_WITH_DISTRIBUTE) add_definitions(-DPADDLE_WITH_DISTRIBUTE)
endif() endif()
if(WITH_PSCORE)
add_definitions(-DPADDLE_WITH_PSCORE)
endif()
if(WITH_GRPC) if(WITH_GRPC)
add_definitions(-DPADDLE_WITH_GRPC) add_definitions(-DPADDLE_WITH_GRPC)
endif(WITH_GRPC) endif(WITH_GRPC)

@ -274,7 +274,7 @@ if(WITH_BOX_PS)
list(APPEND third_party_deps extern_box_ps) list(APPEND third_party_deps extern_box_ps)
endif(WITH_BOX_PS) endif(WITH_BOX_PS)
if (WITH_DISTRIBUTE) if (WITH_PSCORE)
include(external/snappy) include(external/snappy)
list(APPEND third_party_deps extern_snappy) list(APPEND third_party_deps extern_snappy)

@ -1,7 +1,4 @@
if (WITH_PSLIB) if(NOT WITH_PSCORE)
return()
endif()
if(NOT WITH_DISTRIBUTE)
return() return()
endif() endif()

@ -69,24 +69,24 @@ class ObjectFactory {
}; };
typedef std::map<std::string, ObjectFactory *> FactoryMap; typedef std::map<std::string, ObjectFactory *> FactoryMap;
typedef std::map<std::string, FactoryMap> BaseClassMap; typedef std::map<std::string, FactoryMap> PsCoreClassMap;
#ifdef __cplusplus #ifdef __cplusplus
extern "C" { extern "C" {
#endif #endif
inline BaseClassMap &global_factory_map() { inline PsCoreClassMap &global_factory_map() {
static BaseClassMap *base_class = new BaseClassMap(); static PsCoreClassMap *base_class = new PsCoreClassMap();
return *base_class; return *base_class;
} }
#ifdef __cplusplus #ifdef __cplusplus
} }
#endif #endif
inline BaseClassMap &global_factory_map_cpp() { return global_factory_map(); } inline PsCoreClassMap &global_factory_map_cpp() { return global_factory_map(); }
// typedef pa::Any Any; // typedef pa::Any Any;
// typedef ::FactoryMap FactoryMap; // typedef ::FactoryMap FactoryMap;
#define REGISTER_REGISTERER(base_class) \ #define REGISTER_PSCORE_REGISTERER(base_class) \
class base_class##Registerer { \ class base_class##Registerer { \
public: \ public: \
static base_class *CreateInstanceByName(const ::std::string &name) { \ static base_class *CreateInstanceByName(const ::std::string &name) { \
@ -107,7 +107,7 @@ inline BaseClassMap &global_factory_map_cpp() { return global_factory_map(); }
} \ } \
}; };
#define REGISTER_CLASS(clazz, name) \ #define REGISTER_PSCORE_CLASS(clazz, name) \
class ObjectFactory##name : public ObjectFactory { \ class ObjectFactory##name : public ObjectFactory { \
public: \ public: \
Any NewInstance() { return Any(new name()); } \ Any NewInstance() { return Any(new name()); } \
@ -120,7 +120,7 @@ inline BaseClassMap &global_factory_map_cpp() { return global_factory_map(); }
} \ } \
void register_factory_##name() __attribute__((constructor)); void register_factory_##name() __attribute__((constructor));
#define CREATE_CLASS(base_class, name) \ #define CREATE_PSCORE_CLASS(base_class, name) \
base_class##Registerer::CreateInstanceByName(name); base_class##Registerer::CreateInstanceByName(name);
} // namespace distributed } // namespace distributed

@ -86,7 +86,7 @@ message SparseTableParameter {
message ServerServiceParameter { message ServerServiceParameter {
optional string server_class = 1 [ default = "BrpcPsServer" ]; optional string server_class = 1 [ default = "BrpcPsServer" ];
optional string client_class = 2 [ default = "BrpcPsClient" ]; optional string client_class = 2 [ default = "BrpcPsClient" ];
optional string service_class = 3 [ default = "PsService" ]; optional string service_class = 3 [ default = "BrpcPsService" ];
optional uint32 start_server_port = 4 optional uint32 start_server_port = 4
[ default = 0 ]; // will find a avaliable port from it [ default = 0 ]; // will find a avaliable port from it
optional uint32 server_thread_num = 5 [ default = 12 ]; optional uint32 server_thread_num = 5 [ default = 12 ];

@ -17,8 +17,8 @@
#include <sstream> #include <sstream>
#include <string> #include <string>
#include <vector> #include <vector>
#include "Eigen/Dense" #include "Eigen/Dense"
#include "paddle/fluid/distributed/service/brpc_ps_client.h" #include "paddle/fluid/distributed/service/brpc_ps_client.h"
#include "paddle/fluid/distributed/table/table.h" #include "paddle/fluid/distributed/table/table.h"
#include "paddle/fluid/framework/archive.h" #include "paddle/fluid/framework/archive.h"
@ -80,8 +80,8 @@ inline size_t get_sparse_shard(uint32_t shard_num, uint32_t server_num,
void DownpourPsClientService::service( void DownpourPsClientService::service(
::google::protobuf::RpcController *controller, ::google::protobuf::RpcController *controller,
const ::paddle::PsRequestMessage *request, const PsRequestMessage *request, PsResponseMessage *response,
::paddle::PsResponseMessage *response, ::google::protobuf::Closure *done) { ::google::protobuf::Closure *done) {
brpc::ClosureGuard done_guard(done); brpc::ClosureGuard done_guard(done);
int ret = _client->handle_client2client_msg( int ret = _client->handle_client2client_msg(
request->cmd_id(), request->client_id(), request->data()); request->cmd_id(), request->client_id(), request->data());

@ -40,8 +40,8 @@ class DownpourPsClientService : public PsService {
return 0; return 0;
} }
virtual void service(::google::protobuf::RpcController *controller, virtual void service(::google::protobuf::RpcController *controller,
const ::paddle::PsRequestMessage *request, const PsRequestMessage *request,
::paddle::PsResponseMessage *response, PsResponseMessage *response,
::google::protobuf::Closure *done) override; ::google::protobuf::Closure *done) override;
protected: protected:

File diff suppressed because it is too large Load Diff

@ -52,19 +52,19 @@ class BrpcPsServer : public PSServer {
std::vector<std::shared_ptr<brpc::Channel>> _pserver_channels; std::vector<std::shared_ptr<brpc::Channel>> _pserver_channels;
}; };
class PsService; class BrpcPsService;
typedef int32_t (PsService::*serviceHandlerFunc)( typedef int32_t (BrpcPsService::*serviceHandlerFunc)(
Table *table, const PsRequestMessage &request, PsResponseMessage &response, Table *table, const PsRequestMessage &request, PsResponseMessage &response,
brpc::Controller *cntl); brpc::Controller *cntl);
class PsService : public PsBaseService { class BrpcPsService : public PsBaseService {
public: public:
virtual int32_t initialize() override; virtual int32_t initialize() override;
virtual void service(::google::protobuf::RpcController *controller, virtual void service(::google::protobuf::RpcController *controller,
const ::paddle::PsRequestMessage *request, const PsRequestMessage *request,
::paddle::PsResponseMessage *response, PsResponseMessage *response,
::google::protobuf::Closure *done) override; ::google::protobuf::Closure *done) override;
private: private:

@ -88,7 +88,7 @@ void SerializeLodTensor(framework::Variable* var,
const platform::DeviceContext& ctx, VarMsg* var_msg, const platform::DeviceContext& ctx, VarMsg* var_msg,
butil::IOBuf* iobuf) { butil::IOBuf* iobuf) {
auto* tensor = var->GetMutable<framework::LoDTensor>(); auto* tensor = var->GetMutable<framework::LoDTensor>();
var_msg->set_type(::paddle::LOD_TENSOR); var_msg->set_type(::paddle::distributed::LOD_TENSOR);
const framework::LoD lod = tensor->lod(); const framework::LoD lod = tensor->lod();
if (lod.size() > 0) { if (lod.size() > 0) {
var_msg->set_lod_level(lod.size()); var_msg->set_lod_level(lod.size());
@ -135,7 +135,7 @@ void SerializeSelectedRows(framework::Variable* var,
auto* tensor = slr->mutable_value(); auto* tensor = slr->mutable_value();
auto* rows = slr->mutable_rows(); auto* rows = slr->mutable_rows();
var_msg->set_type(::paddle::SELECTED_ROWS); var_msg->set_type(::paddle::distributed::SELECTED_ROWS);
var_msg->set_slr_height(slr->height()); var_msg->set_slr_height(slr->height());
auto* var_data = var_msg->mutable_data(); auto* var_data = var_msg->mutable_data();
@ -194,9 +194,9 @@ void DeserializeFromMultiVarMsgAndIOBuf(const MultiVarMsg& multi_msg,
++recv_var_index) { ++recv_var_index) {
const auto& msg = multi_msg.var_messages(recv_var_index); const auto& msg = multi_msg.var_messages(recv_var_index);
auto* var = scope->Var(msg.varname()); auto* var = scope->Var(msg.varname());
if (msg.type() == ::paddle::LOD_TENSOR) { if (msg.type() == ::paddle::distributed::LOD_TENSOR) {
DeserializeLodTensor(var, msg, io_buffer_itr, ctx); DeserializeLodTensor(var, msg, io_buffer_itr, ctx);
} else if (msg.type() == ::paddle::SELECTED_ROWS) { } else if (msg.type() == ::paddle::distributed::SELECTED_ROWS) {
DeserializeSelectedRows(var, msg, io_buffer_itr, ctx); DeserializeSelectedRows(var, msg, io_buffer_itr, ctx);
} }
} }
@ -215,9 +215,9 @@ void DeserializeFromMultiVarMsgAndIOBuf(const MultiVarMsg& multi_msg,
PADDLE_ENFORCE_NE(var, nullptr, PADDLE_ENFORCE_NE(var, nullptr,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"Not find variable %s in scope.", msg.varname())); "Not find variable %s in scope.", msg.varname()));
if (msg.type() == ::paddle::LOD_TENSOR) { if (msg.type() == ::paddle::distributed::LOD_TENSOR) {
DeserializeLodTensor(var, msg, io_buffer_itr, ctx); DeserializeLodTensor(var, msg, io_buffer_itr, ctx);
} else if (msg.type() == ::paddle::SELECTED_ROWS) { } else if (msg.type() == ::paddle::distributed::SELECTED_ROWS) {
DeserializeSelectedRows(var, msg, io_buffer_itr, ctx); DeserializeSelectedRows(var, msg, io_buffer_itr, ctx);
} }
} }

@ -44,8 +44,8 @@ class DeviceContext;
namespace paddle { namespace paddle {
namespace distributed { namespace distributed {
using MultiVarMsg = ::paddle::MultiVariableMessage; using MultiVarMsg = ::paddle::distributed::MultiVariableMessage;
using VarMsg = ::paddle::VariableMessage; using VarMsg = ::paddle::distributed::VariableMessage;
void SerializeToMultiVarMsgAndIOBuf( void SerializeToMultiVarMsgAndIOBuf(
const std::string& message_name, const std::string& message_name,

@ -122,7 +122,7 @@ void HeterClient::SendAndRecvAsync(
cntl.set_timeout_ms(FLAGS_pserver_timeout_ms); cntl.set_timeout_ms(FLAGS_pserver_timeout_ms);
distributed::MultiVarMsg request, response; distributed::MultiVarMsg request, response;
auto& request_io_buffer = cntl.request_attachment(); auto& request_io_buffer = cntl.request_attachment();
::paddle::PsService_Stub stub(xpu_channels_[num].get()); ::paddle::distributed::PsService_Stub stub(xpu_channels_[num].get());
distributed::SerializeToMultiVarMsgAndIOBuf( distributed::SerializeToMultiVarMsgAndIOBuf(
message_name_val, send_var_name_val, recv_var_name_val, *p_ctx, p_scope, message_name_val, send_var_name_val, recv_var_name_val, *p_ctx, p_scope,
&request, &request_io_buffer); &request, &request_io_buffer);
@ -164,7 +164,7 @@ std::future<int32_t> HeterClient::SendCmd(
for (const auto& param : params) { for (const auto& param : params) {
closure->request(i)->add_params(param); closure->request(i)->add_params(param);
} }
::paddle::PsService_Stub rpc_stub(xpu_channels_[i].get()); ::paddle::distributed::PsService_Stub rpc_stub(xpu_channels_[i].get());
closure->cntl(i)->set_timeout_ms( closure->cntl(i)->set_timeout_ms(
FLAGS_pserver_timeout_ms); // cmd msg don't limit timeout for save/load FLAGS_pserver_timeout_ms); // cmd msg don't limit timeout for save/load
rpc_stub.service(closure->cntl(i), closure->request(i), rpc_stub.service(closure->cntl(i), closure->request(i),

@ -35,8 +35,8 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace distributed { namespace distributed {
using MultiVarMsg = ::paddle::MultiVariableMessage; using MultiVarMsg = ::paddle::distributed::MultiVariableMessage;
using VarMsg = ::paddle::VariableMessage; using VarMsg = ::paddle::distributed::VariableMessage;
typedef std::function<void(void*)> HeterRpcCallbackFunc; typedef std::function<void(void*)> HeterRpcCallbackFunc;

@ -39,8 +39,8 @@ DECLARE_double(eager_delete_tensor_gb);
namespace paddle { namespace paddle {
namespace distributed { namespace distributed {
using MultiVarMsg = ::paddle::MultiVariableMessage; using MultiVarMsg = ::paddle::distributed::MultiVariableMessage;
using VarMsg = ::paddle::VariableMessage; using VarMsg = ::paddle::distributed::VariableMessage;
class HeterService; class HeterService;
typedef int32_t (HeterService::*serviceHandlerFunc)( typedef int32_t (HeterService::*serviceHandlerFunc)(
@ -51,7 +51,7 @@ typedef std::function<void(void*)> HeterRpcCallbackFunc;
typedef std::function<int(const MultiVarMsg*, MultiVarMsg*, brpc::Controller*)> typedef std::function<int(const MultiVarMsg*, MultiVarMsg*, brpc::Controller*)>
HeterServiceHandler; HeterServiceHandler;
class HeterService : public ::paddle::PsService { class HeterService : public ::paddle::distributed::PsService {
public: public:
HeterService() { HeterService() {
_service_handler_map[PS_STOP_SERVER] = &HeterService::stop_heter_worker; _service_handler_map[PS_STOP_SERVER] = &HeterService::stop_heter_worker;
@ -62,8 +62,8 @@ class HeterService : public ::paddle::PsService {
virtual ~HeterService() {} virtual ~HeterService() {}
virtual void service(::google::protobuf::RpcController* controller, virtual void service(::google::protobuf::RpcController* controller,
const ::paddle::PsRequestMessage* request, const PsRequestMessage* request,
::paddle::PsResponseMessage* response, PsResponseMessage* response,
::google::protobuf::Closure* done) { ::google::protobuf::Closure* done) {
brpc::ClosureGuard done_guard(done); brpc::ClosureGuard done_guard(done);
std::string log_label("ReceiveCmd-"); std::string log_label("ReceiveCmd-");

@ -13,9 +13,7 @@
// limitations under the License. // limitations under the License.
#include "paddle/fluid/distributed/service/ps_client.h" #include "paddle/fluid/distributed/service/ps_client.h"
#include <map> #include <map>
#include "brpc/server.h" #include "brpc/server.h"
#include "glog/logging.h" #include "glog/logging.h"
#include "paddle/fluid/distributed/service/brpc_ps_client.h" #include "paddle/fluid/distributed/service/brpc_ps_client.h"
@ -23,7 +21,7 @@
namespace paddle { namespace paddle {
namespace distributed { namespace distributed {
REGISTER_CLASS(PSClient, BrpcPsClient); REGISTER_PSCORE_CLASS(PSClient, BrpcPsClient);
int32_t PSClient::configure( int32_t PSClient::configure(
const PSParameter &config, const PSParameter &config,
@ -43,7 +41,7 @@ int32_t PSClient::configure(
const auto &work_param = _config.worker_param().downpour_worker_param(); const auto &work_param = _config.worker_param().downpour_worker_param();
for (size_t i = 0; i < work_param.downpour_table_param_size(); ++i) { for (size_t i = 0; i < work_param.downpour_table_param_size(); ++i) {
auto *accessor = CREATE_CLASS( auto *accessor = CREATE_PSCORE_CLASS(
ValueAccessor, ValueAccessor,
work_param.downpour_table_param(i).accessor().accessor_class()); work_param.downpour_table_param(i).accessor().accessor_class());
accessor->configure(work_param.downpour_table_param(i).accessor()); accessor->configure(work_param.downpour_table_param(i).accessor());
@ -73,7 +71,8 @@ PSClient *PSClientFactory::create(const PSParameter &ps_config) {
} }
const auto &service_param = config.downpour_server_param().service_param(); const auto &service_param = config.downpour_server_param().service_param();
PSClient *client = CREATE_CLASS(PSClient, service_param.client_class()); PSClient *client =
CREATE_PSCORE_CLASS(PSClient, service_param.client_class());
if (client == NULL) { if (client == NULL) {
LOG(ERROR) << "client is not registered, server_name:" LOG(ERROR) << "client is not registered, server_name:"
<< service_param.client_class(); << service_param.client_class();

@ -28,6 +28,9 @@
namespace paddle { namespace paddle {
namespace distributed { namespace distributed {
using paddle::distributed::PsRequestMessage;
using paddle::distributed::PsResponseMessage;
typedef std::function<void(void *)> PSClientCallBack; typedef std::function<void(void *)> PSClientCallBack;
class PSClientClosure : public google::protobuf::Closure { class PSClientClosure : public google::protobuf::Closure {
public: public:
@ -206,7 +209,7 @@ class PSClient {
std::unordered_map<int32_t, MsgHandlerFunc> std::unordered_map<int32_t, MsgHandlerFunc>
_msg_handler_map; //处理client2client消息 _msg_handler_map; //处理client2client消息
}; };
REGISTER_REGISTERER(PSClient); REGISTER_PSCORE_REGISTERER(PSClient);
class PSClientFactory { class PSClientFactory {
public: public:

@ -13,7 +13,7 @@
// limitations under the License. // limitations under the License.
syntax = "proto2"; syntax = "proto2";
package paddle; package paddle.distributed;
option cc_generic_services = true; option cc_generic_services = true;
option cc_enable_arenas = true; option cc_enable_arenas = true;

@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
#include "paddle/fluid/distributed/service/server.h" #include "paddle/fluid/distributed/service/server.h"
#include "glog/logging.h" #include "glog/logging.h"
#include "paddle/fluid/distributed/service/brpc_ps_server.h" #include "paddle/fluid/distributed/service/brpc_ps_server.h"
#include "paddle/fluid/distributed/table/table.h" #include "paddle/fluid/distributed/table/table.h"
@ -20,8 +21,8 @@
namespace paddle { namespace paddle {
namespace distributed { namespace distributed {
REGISTER_CLASS(PSServer, BrpcPsServer); REGISTER_PSCORE_CLASS(PSServer, BrpcPsServer);
REGISTER_CLASS(PsBaseService, PsService); REGISTER_PSCORE_CLASS(PsBaseService, BrpcPsService);
PSServer *PSServerFactory::create(const PSParameter &ps_config) { PSServer *PSServerFactory::create(const PSParameter &ps_config) {
const auto &config = ps_config.server_param(); const auto &config = ps_config.server_param();
@ -43,7 +44,8 @@ PSServer *PSServerFactory::create(const PSParameter &ps_config) {
} }
const auto &service_param = config.downpour_server_param().service_param(); const auto &service_param = config.downpour_server_param().service_param();
PSServer *server = CREATE_CLASS(PSServer, service_param.server_class()); PSServer *server =
CREATE_PSCORE_CLASS(PSServer, service_param.server_class());
if (server == NULL) { if (server == NULL) {
LOG(ERROR) << "server is not registered, server_name:" LOG(ERROR) << "server is not registered, server_name:"
<< service_param.server_class(); << service_param.server_class();
@ -70,7 +72,7 @@ int32_t PSServer::configure(
uint32_t global_step_table = UINT32_MAX; uint32_t global_step_table = UINT32_MAX;
for (size_t i = 0; i < downpour_param.downpour_table_param_size(); ++i) { for (size_t i = 0; i < downpour_param.downpour_table_param_size(); ++i) {
auto *table = CREATE_CLASS( auto *table = CREATE_PSCORE_CLASS(
Table, downpour_param.downpour_table_param(i).table_class()); Table, downpour_param.downpour_table_param(i).table_class());
if (downpour_param.downpour_table_param(i).table_class() == if (downpour_param.downpour_table_param(i).table_class() ==

@ -46,6 +46,8 @@ namespace paddle {
namespace distributed { namespace distributed {
class Table; class Table;
using paddle::distributed::PsRequestMessage;
using paddle::distributed::PsResponseMessage;
class PSServer { class PSServer {
public: public:
@ -107,7 +109,7 @@ class PSServer {
platform::Place place_ = platform::CPUPlace(); platform::Place place_ = platform::CPUPlace();
}; };
REGISTER_REGISTERER(PSServer); REGISTER_PSCORE_REGISTERER(PSServer);
typedef std::function<void(void *)> PServerCallBack; typedef std::function<void(void *)> PServerCallBack;
@ -141,8 +143,8 @@ class PsBaseService : public PsService {
return 0; return 0;
} }
virtual void service(::google::protobuf::RpcController *controller, virtual void service(::google::protobuf::RpcController *controller,
const ::paddle::PsRequestMessage *request, const PsRequestMessage *request,
::paddle::PsResponseMessage *response, PsResponseMessage *response,
::google::protobuf::Closure *done) override = 0; ::google::protobuf::Closure *done) override = 0;
virtual void set_response_code(PsResponseMessage &response, int err_code, virtual void set_response_code(PsResponseMessage &response, int err_code,
@ -159,7 +161,7 @@ class PsBaseService : public PsService {
PSServer *_server; PSServer *_server;
const ServerParameter *_config; const ServerParameter *_config;
}; };
REGISTER_REGISTERER(PsBaseService); REGISTER_PSCORE_REGISTERER(PsBaseService);
class PSServerFactory { class PSServerFactory {
public: public:

@ -28,6 +28,10 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace distributed { namespace distributed {
using paddle::distributed::PsRequestMessage;
using paddle::distributed::PsResponseMessage;
using paddle::distributed::PsService;
class PSCore { class PSCore {
public: public:
explicit PSCore() {} explicit PSCore() {}

@ -165,6 +165,6 @@ class ValueAccessor {
std::unordered_map<int, std::shared_ptr<struct DataConverter>> std::unordered_map<int, std::shared_ptr<struct DataConverter>>
_data_coverter_map; _data_coverter_map;
}; };
REGISTER_REGISTERER(ValueAccessor); REGISTER_PSCORE_REGISTERER(ValueAccessor);
} // namespace distributed } // namespace distributed
} // namespace paddle } // namespace paddle

@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
#include "paddle/fluid/distributed/table/table.h" #include "paddle/fluid/distributed/table/table.h"
#include <boost/preprocessor/repetition/repeat_from_to.hpp> #include <boost/preprocessor/repetition/repeat_from_to.hpp>
#include <boost/preprocessor/seq/elem.hpp> #include <boost/preprocessor/seq/elem.hpp>
#include "glog/logging.h" #include "glog/logging.h"
@ -27,14 +28,14 @@
namespace paddle { namespace paddle {
namespace distributed { namespace distributed {
REGISTER_CLASS(Table, CommonDenseTable); REGISTER_PSCORE_CLASS(Table, CommonDenseTable);
REGISTER_CLASS(Table, CommonSparseTable); REGISTER_PSCORE_CLASS(Table, CommonSparseTable);
REGISTER_CLASS(Table, SparseGeoTable); REGISTER_PSCORE_CLASS(Table, SparseGeoTable);
REGISTER_CLASS(Table, BarrierTable); REGISTER_PSCORE_CLASS(Table, BarrierTable);
REGISTER_CLASS(Table, TensorTable); REGISTER_PSCORE_CLASS(Table, TensorTable);
REGISTER_CLASS(Table, DenseTensorTable); REGISTER_PSCORE_CLASS(Table, DenseTensorTable);
REGISTER_CLASS(Table, GlobalStepTable); REGISTER_PSCORE_CLASS(Table, GlobalStepTable);
REGISTER_CLASS(ValueAccessor, CommMergeAccessor); REGISTER_PSCORE_CLASS(ValueAccessor, CommMergeAccessor);
int32_t TableManager::initialize() { int32_t TableManager::initialize() {
static bool initialized = false; static bool initialized = false;
@ -61,8 +62,8 @@ int32_t Table::initialize_accessor() {
<< _config.table_id(); << _config.table_id();
return -1; return -1;
} }
auto *accessor = auto *accessor = CREATE_PSCORE_CLASS(
CREATE_CLASS(ValueAccessor, ValueAccessor,
_config.accessor().accessor_class()) if (accessor == NULL) { _config.accessor().accessor_class()) if (accessor == NULL) {
LOG(ERROR) << "accessor is unregisteg, table_id:" << _config.table_id() LOG(ERROR) << "accessor is unregisteg, table_id:" << _config.table_id()
<< ", accessor_name:" << _config.accessor().accessor_class(); << ", accessor_name:" << _config.accessor().accessor_class();

@ -127,7 +127,7 @@ class Table {
float *_global_lr = nullptr; float *_global_lr = nullptr;
std::shared_ptr<ValueAccessor> _value_accesor; std::shared_ptr<ValueAccessor> _value_accesor;
}; };
REGISTER_REGISTERER(Table); REGISTER_PSCORE_REGISTERER(Table);
class TableManager { class TableManager {
public: public:

@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include <unistd.h> #include <unistd.h>
#include <condition_variable> // NOLINT #include <condition_variable> // NOLINT
#include <string> #include <string>
#include <thread> // NOLINT #include <thread> // NOLINT
@ -94,7 +95,7 @@ void GetDownpourDenseTableProto(
server_proto->mutable_downpour_server_param(); server_proto->mutable_downpour_server_param();
::paddle::distributed::ServerServiceParameter* server_service_proto = ::paddle::distributed::ServerServiceParameter* server_service_proto =
downpour_server_proto->mutable_service_param(); downpour_server_proto->mutable_service_param();
server_service_proto->set_service_class("PsService"); server_service_proto->set_service_class("BrpcPsService");
server_service_proto->set_server_class("BrpcPsServer"); server_service_proto->set_server_class("BrpcPsServer");
server_service_proto->set_client_class("BrpcPsClient"); server_service_proto->set_client_class("BrpcPsClient");
server_service_proto->set_start_server_port(0); server_service_proto->set_start_server_port(0);
@ -124,7 +125,7 @@ void GetDownpourDenseTableProto(
server_proto->mutable_downpour_server_param(); server_proto->mutable_downpour_server_param();
::paddle::distributed::ServerServiceParameter* server_service_proto = ::paddle::distributed::ServerServiceParameter* server_service_proto =
downpour_server_proto->mutable_service_param(); downpour_server_proto->mutable_service_param();
server_service_proto->set_service_class("PsService"); server_service_proto->set_service_class("BrpcPsService");
server_service_proto->set_server_class("BrpcPsServer"); server_service_proto->set_server_class("BrpcPsServer");
server_service_proto->set_client_class("BrpcPsClient"); server_service_proto->set_client_class("BrpcPsClient");
server_service_proto->set_start_server_port(0); server_service_proto->set_start_server_port(0);
@ -244,7 +245,8 @@ void RunBrpcPushDense() {
int ret = 0; int ret = 0;
auto* closure = (paddle::distributed::DownpourBrpcClosure*)done; auto* closure = (paddle::distributed::DownpourBrpcClosure*)done;
for (size_t i = 0; i < 1; ++i) { for (size_t i = 0; i < 1; ++i) {
if (closure->check_response(i, paddle::PS_PUSH_DENSE_TABLE) != 0) { if (closure->check_response(
i, paddle::distributed::PS_PUSH_DENSE_TABLE) != 0) {
ret = -1; ret = -1;
break; break;
} }

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save