feature/DC asgd (#12722)

* wip

* add ref_by_trainer_id op

* ready to test

* fix ref inputs

* refine rpc_op_handle

* fix merge bug
Wu Yi, committed by GitHub
parent c3cbf0b8ef
commit 306236c2c0
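For context on the feature itself: DC-ASGD (delay-compensated asynchronous SGD, Zheng et al., ICML 2017) lets the parameter server correct a stale gradient using the drift between its current weights and the weights the sending trainer actually started from. A sketch of the update rule, with notation taken from that paper rather than from this diff (\lambda is the compensation coefficient, \odot is the element-wise product):

    w_{t+1} = w_t - \eta \left( g(w_{bak}) + \lambda \, g(w_{bak}) \odot g(w_{bak}) \odot (w_t - w_{bak}) \right)

Here w_{bak} is the per-trainer backup of the parameters saved when that trainer pulled them (the "*.trainer_%d_bak" variables introduced below), and g(w_{bak}) is the delayed gradient the trainer later sends back.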

@@ -29,22 +29,19 @@ RPCOpHandle::RPCOpHandle(ir::Node *node, const framework::OpDesc &op_desc,
       place_(place) {}

 void RPCOpHandle::RunImpl() {
-  // TODO(wuyi): need further analysis whether wait VarDummyHandle.
-  // Wait input done
   for (auto *in : inputs_) {
     auto &p = static_cast<VarHandle *>(in)->place_;
-    // FIXME(Yancey1989): need a better solution instead of use DebugString()
-    if (ir::IsControlDepVar(*in->Node())) {  // HACK
+    if (ir::IsControlDepVar(*in->Node())) {
       continue;
     }
     if (in->GeneratedOp()) {
       in->GeneratedOp()->RecordWaitEventOnCtx(dev_ctxes_.at(p));
     }
   }
-  auto &tmp_scope = local_scope_->FindVar(kLocalExecScopeName)->Get<Scope *>();
-  // FIXME(wuyi): can not use RunAndRecordEvent here, for it will cause dead
-  // lock.
-  op_->Run(*tmp_scope, place_);
+  this->RunAndRecordEvent([this] {
+    op_->Run(*local_scope_->FindVar(kLocalExecScopeName)->Get<Scope *>(),
+             place_);
+  });
 }

 std::string RPCOpHandle::Name() const { return name_; }

@@ -85,8 +85,10 @@ Executor::Executor(const platform::Place& place) : place_(place) {}
 void Executor::Close() {
 #ifdef PADDLE_WITH_DISTRIBUTE
+  // TODO(typhoonzero): complete message will need to use real trainer_id,
+  // except 0.
   ::paddle::operators::distributed::RPCClient::GetInstance<
-      ::paddle::operators::distributed::GRPCClient>()
+      ::paddle::operators::distributed::GRPCClient>(0)
       ->SendComplete();
 #endif
 }

@@ -38,9 +38,10 @@ class CheckpointNotifyOp : public framework::OperatorBase {
     std::vector<std::string> epmap = Attr<std::vector<std::string>>("epmap");
     std::string dir = Attr<std::string>("dir");
     std::string lookup_table_name = Attr<std::string>("lookup_table");
+    int trainer_id = Attr<int>("trainer_id");

     distributed::RPCClient* rpc_client =
-        distributed::RPCClient::GetInstance<RPCCLIENT_T>();
+        distributed::RPCClient::GetInstance<RPCCLIENT_T>(trainer_id);
     for (size_t i = 0; i < epmap.size(); i++) {
       auto lookup_table_save_dir =
           string::Sprintf("%s/%s_%d", dir, lookup_table_name, i);
@@ -63,6 +64,7 @@ class CheckpointNotifyOpMaker : public framework::OpProtoAndCheckerMaker {
         "dir", "(string, default '') indicate the folder checkpoint will use");
     AddAttr<std::string>("lookup_table",
                          "(string, default '') the lookup table name");
+    AddAttr<int>("trainer_id", "trainer id from 0 ~ worker_num.").SetDefault(0);
     AddComment(R"DOC(
 CheckpointNotify operator

@@ -79,7 +79,7 @@ VarHandlePtr GRPCClient::AsyncSendVar(const std::string& ep,
     auto* var = p_scope->FindVar(var_name_val);

     ::grpc::ByteBuffer req;
-    SerializeToByteBuffer(var_name_val, var, *p_ctx, &req);
+    SerializeToByteBuffer(var_name_val, var, *p_ctx, &req, "", trainer_id_);

     VLOG(3) << s->GetVarHandlePtr()->String() << " begin";
@@ -105,7 +105,10 @@ VarHandlePtr GRPCClient::AsyncSendVar(const std::string& ep,
 void ProcGetResponse(const VarHandle& var_h,
                      const ::grpc::ByteBuffer& ret_msg) {
   framework::Variable* outvar = nullptr;
-  DeserializeFromByteBuffer(ret_msg, *var_h.ctx(), var_h.scope(), &outvar);
+  // get response's trainer_id is not used
+  int trainer_id;
+  DeserializeFromByteBuffer(ret_msg, *var_h.ctx(), var_h.scope(), &outvar,
+                            &trainer_id);
 }

 template <typename T>
@@ -135,6 +138,7 @@ VarHandlePtr GRPCClient::AsyncGetVar(const std::string& ep,
   // prepare input
   sendrecv::VariableMessage req;
   req.set_varname(var_name_val);
+  req.set_trainer_id(trainer_id_);
   ::grpc::ByteBuffer buf;
   RequestToByteBuffer<sendrecv::VariableMessage>(req, &buf);

@@ -34,8 +34,8 @@ namespace distributed {
 void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
                            const platform::DeviceContext& ctx,
-                           ::grpc::ByteBuffer* msg,
-                           const std::string& out_name) {
+                           ::grpc::ByteBuffer* msg, const std::string& out_name,
+                           const int trainer_id) {
   platform::RecordRPCEvent record_event("serial", &ctx);
   // Default DestroyCallback does nothing, When using GPU
   // the CPU buffer need to be freed.
@@ -45,6 +45,7 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
   size_t payload_size;

   request.set_varname(name);
+  request.set_trainer_id(trainer_id);
   // Note: normally the profiler is enabled in 1 trainer, hence only
   // 1 trainer returns true for ShouldSendProfileState(). It tells PS
   // servers the trainer's profiling state so that PS can follow the
@@ -147,11 +148,12 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
 void DeserializeFromByteBuffer(const ::grpc::ByteBuffer& msg,
                                const platform::DeviceContext& ctx,
                                const framework::Scope* scope,
-                               framework::Variable** var) {
+                               framework::Variable** var, int* trainer_id) {
   platform::RecordRPCEvent record_event("deserial", &ctx);
   operators::distributed::GRPCVariableResponse resp(scope, &ctx);
   PADDLE_ENFORCE(resp.Parse(msg) == 0, "parse bytebuffer to tensor error!");
   *var = resp.GetVar();
+  *trainer_id = resp.GetTrainerId();
 }

 }  // namespace distributed

@@ -38,12 +38,13 @@ typedef void (*DestroyCallback)(void*);
 void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
                            const platform::DeviceContext& ctx,
                            ::grpc::ByteBuffer* msg,
-                           const std::string& out_varname = std::string());
+                           const std::string& out_varname = std::string(),
+                           const int trainer_id = 0);

 void DeserializeFromByteBuffer(const ::grpc::ByteBuffer& msg,
                                const platform::DeviceContext& ctx,
                                const framework::Scope* scope,
-                               framework::Variable** var);
+                               framework::Variable** var, int* trainer_id);

 }  // namespace distributed
 }  // namespace operators

@@ -102,9 +102,10 @@ class RequestSend final : public RequestBase {
     auto scope = request_->GetMutableLocalScope();
     auto invar = request_->GetVar();
+    int trainer_id = request_->GetTrainerId();
     framework::Variable* outvar = nullptr;

-    request_handler_->Handle(varname, scope, invar, &outvar);
+    request_handler_->Handle(varname, scope, invar, &outvar, trainer_id);
     Finish(reply_, &responder_);
   }

@@ -133,13 +134,14 @@ class RequestGet final : public RequestBase {
   void Process() override {
     // proc request.
     std::string varname = request_.varname();
+    int trainer_id = request_.trainer_id();
     VLOG(4) << "RequestGet " << varname;

     auto scope = request_handler_->scope();
     auto invar = scope->FindVar(varname);
     framework::Variable* outvar = nullptr;

-    request_handler_->Handle(varname, scope, invar, &outvar);
+    request_handler_->Handle(varname, scope, invar, &outvar, trainer_id);

     if (outvar) {
       SerializeToByteBuffer(varname, outvar, *request_handler_->dev_ctx(),
@@ -179,6 +181,7 @@ class RequestPrefetch final : public RequestBase {
     // prefetch process...
     std::string in_var_name = request_->Varname();
     std::string out_var_name = request_->OutVarname();
+    int trainer_id = request_->GetTrainerId();
     VLOG(4) << "RequestPrefetch, in_var_name: " << in_var_name
             << " out_var_name: " << out_var_name;
@@ -187,7 +190,8 @@ class RequestPrefetch final : public RequestBase {
     // out var must be created in local scope!
     framework::Variable* outvar = scope->Var(out_var_name);

-    request_handler_->Handle(in_var_name, scope, invar, &outvar, out_var_name);
+    request_handler_->Handle(in_var_name, scope, invar, &outvar, trainer_id,
+                             out_var_name);

     SerializeToByteBuffer(out_var_name, outvar, *request_handler_->dev_ctx(),
                           &reply_);
@@ -225,12 +229,13 @@ class RequestCheckpointNotify final : public RequestBase {
     std::string checkpoint_notify = request_->Varname();
     std::string checkpoint_dir = request_->OutVarname();
+    int trainer_id = request_->GetTrainerId();

     VLOG(4) << "RequestCheckpointNotify notify: " << checkpoint_notify
             << ", dir: " << checkpoint_dir;

     request_handler_->Handle(checkpoint_notify, scope, nullptr, nullptr,
-                             checkpoint_dir);
+                             trainer_id, checkpoint_dir);
     Finish(reply_, &responder_);
   }

@@ -293,6 +293,14 @@ int GRPCVariableResponse::Parse(Source* source) {
       }
       break;
     }
+    case sendrecv::VariableMessage::kTrainerIdFieldNumber: {
+      uint64_t trainer_id = 0;
+      if (!input.ReadVarint64(&trainer_id)) {
+        return tag;
+      }
+      meta_.set_trainer_id(trainer_id);
+      break;
+    }
     default: {
       // Unknown tag, return unknown error.
       return -1;

@@ -190,6 +190,7 @@ class RequestHandler {
   // }
   virtual bool Handle(const std::string& varname, framework::Scope* scope,
                       framework::Variable* var, framework::Variable** outvar,
+                      const int trainer_id,
                       const std::string& out_var_name = "") = 0;

  protected:

@@ -36,6 +36,7 @@ bool RequestSendHandler::Handle(const std::string& varname,
                                 framework::Scope* scope,
                                 framework::Variable* invar,
                                 framework::Variable** outvar,
+                                const int trainer_id,
                                 const std::string& out_var_name) {
   VLOG(4) << "RequestSendHandler:" << varname;
@@ -76,6 +77,7 @@ bool RequestGetHandler::Handle(const std::string& varname,
                                framework::Scope* scope,
                                framework::Variable* invar,
                                framework::Variable** outvar,
+                               const int trainer_id,
                                const std::string& out_var_name) {
   VLOG(4) << "RequestGetHandler:" << varname;
   if (sync_mode_) {
@@ -88,6 +90,19 @@ bool RequestGetHandler::Handle(const std::string& varname,
     }
   } else {
     if (varname != FETCH_BARRIER_MESSAGE && varname != COMPLETE_MESSAGE) {
+      if (enable_dc_asgd_) {
+        // NOTE: the format is determined by distributed_transpiler.py
+        std::string param_bak_name =
+            string::Sprintf("%s.trainer_%d_bak", varname, trainer_id);
+        VLOG(3) << "getting " << param_bak_name << " trainer_id " << trainer_id;
+        auto var = scope_->FindVar(varname);
+        auto t_orig = var->Get<framework::LoDTensor>();
+        auto param_bak = scope_->Var(param_bak_name);
+        auto t = param_bak->GetMutable<framework::LoDTensor>();
+        t->mutable_data(dev_ctx_->GetPlace(), t_orig.type());
+        VLOG(3) << "copying " << varname << " to " << param_bak_name;
+        framework::TensorCopy(t_orig, dev_ctx_->GetPlace(), t);
+      }
       *outvar = scope_->FindVar(varname);
     }
   }
@@ -98,6 +113,7 @@ bool RequestPrefetchHandler::Handle(const std::string& varname,
                                     framework::Scope* scope,
                                     framework::Variable* invar,
                                     framework::Variable** outvar,
+                                    const int trainer_id,
                                     const std::string& out_var_name) {
   VLOG(4) << "RequestPrefetchHandler " << varname;
@@ -113,6 +129,7 @@ bool RequestCheckpointHandler::Handle(const std::string& varname,
                                       framework::Scope* scope,
                                       framework::Variable* invar,
                                       framework::Variable** outvar,
+                                      const int trainer_id,
                                       const std::string& out_var_name) {
   PADDLE_ENFORCE(
       checkpoint_notify_id != -1,
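The backup naming above is the glue between RequestGetHandler and the DC-ASGD optimize block: every time a trainer pulls a parameter, the server snapshots it under a per-trainer name. A minimal standalone sketch of that naming (not PaddlePaddle code; the parameter name and trainer id are made up):

#include <cstdio>
#include <string>

int main() {
  std::string varname = "fc_0.w_0";  // hypothetical parameter being fetched
  int trainer_id = 2;                // id carried in the gRPC VariableMessage
  char buf[64];
  std::snprintf(buf, sizeof(buf), "%s.trainer_%d_bak", varname.c_str(),
                trainer_id);
  // The server copies the current value of fc_0.w_0 into a variable with this
  // name, so a later optimize block knows which weights trainer 2 started from.
  std::printf("%s\n", buf);  // prints: fc_0.w_0.trainer_2_bak
  return 0;
}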

@@ -36,20 +36,34 @@ namespace distributed {
 class RequestSendHandler final : public RequestHandler {
  public:
-  explicit RequestSendHandler(bool sync_mode) : RequestHandler(sync_mode) {}
+  explicit RequestSendHandler(bool sync_mode, bool enable_dc_asgd = false)
+      : RequestHandler(sync_mode) {
+    enable_dc_asgd_ = enable_dc_asgd;
+  }
   virtual ~RequestSendHandler() {}
   bool Handle(const std::string& varname, framework::Scope* scope,
               framework::Variable* var, framework::Variable** outvar,
+              const int trainer_id,
               const std::string& out_var_name = "") override;
+
+ private:
+  bool enable_dc_asgd_;
 };

 class RequestGetHandler final : public RequestHandler {
  public:
-  explicit RequestGetHandler(bool sync_mode) : RequestHandler(sync_mode) {}
+  explicit RequestGetHandler(bool sync_mode, bool enable_dc_asgd = false)
+      : RequestHandler(sync_mode) {
+    enable_dc_asgd_ = enable_dc_asgd;
+  }
   virtual ~RequestGetHandler() {}
   bool Handle(const std::string& varname, framework::Scope* scope,
               framework::Variable* var, framework::Variable** outvar,
+              const int trainer_id,
               const std::string& out_var_name = "") override;
+
+ private:
+  bool enable_dc_asgd_;
 };

 class RequestPrefetchHandler final : public RequestHandler {
@@ -58,6 +72,7 @@ class RequestPrefetchHandler final : public RequestHandler {
   virtual ~RequestPrefetchHandler() {}
   bool Handle(const std::string& varname, framework::Scope* scope,
               framework::Variable* var, framework::Variable** outvar,
+              const int trainer_id,
               const std::string& out_var_name = "") override;
 };

@@ -70,6 +85,7 @@ class RequestCheckpointHandler final : public RequestHandler {
   virtual ~RequestCheckpointHandler() {}
   bool Handle(const std::string& varname, framework::Scope* scope,
               framework::Variable* var, framework::Variable** outvar,
+              const int trainer_id,
               const std::string& out_var_name = "") override;

  private:

@@ -24,6 +24,7 @@ namespace distributed {
 std::once_flag RPCClient::init_flag_;
 std::unique_ptr<RPCClient> RPCClient::rpc_client_(nullptr);
+int RPCClient::trainer_id_ = 0;

 }  // namespace distributed
 }  // namespace operators

@@ -72,14 +72,15 @@ class RPCClient {
   virtual bool Wait() = 0;

   template <typename T>
-  static RPCClient* GetInstance() {
-    std::call_once(init_flag_, &RPCClient::Init<T>);
+  static RPCClient* GetInstance(int trainer_id) {
+    std::call_once(init_flag_, &RPCClient::Init<T>, trainer_id);
     return rpc_client_.get();
   }

   // Init is called by GetInstance.
   template <typename T>
-  static void Init() {
+  static void Init(int trainer_id) {
+    trainer_id_ = trainer_id;
     if (rpc_client_.get() == nullptr) {
       rpc_client_.reset(new T());
       rpc_client_->InitImpl();
@@ -88,6 +89,8 @@ class RPCClient {
  protected:
   virtual void InitImpl() {}
+  // each trainer have exact one trainer id, it should be static
+  static int trainer_id_;

  private:
   static std::once_flag init_flag_;
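One consequence of the std::call_once pattern above: only the first GetInstance call actually runs Init, so the trainer_id passed on that first call is the one that sticks; later calls (including the hard-coded GetInstance<RPCCLIENT_T>(0) sites elsewhere in this diff) just return the already-initialized client. A minimal sketch of that behavior, not the Paddle class itself:

#include <cassert>
#include <memory>
#include <mutex>

struct Client {
  static Client* GetInstance(int trainer_id) {
    std::call_once(flag_, &Client::Init, trainer_id);
    return instance_.get();
  }
  static void Init(int trainer_id) {
    trainer_id_ = trainer_id;   // recorded only on the first call
    instance_.reset(new Client());
  }
  static int trainer_id_;

 private:
  static std::once_flag flag_;
  static std::unique_ptr<Client> instance_;
};

std::once_flag Client::flag_;
std::unique_ptr<Client> Client::instance_(nullptr);
int Client::trainer_id_ = 0;

int main() {
  Client::GetInstance(3);            // first call: Init runs, trainer_id_ = 3
  Client::GetInstance(0);            // later call with a different id: ignored
  assert(Client::trainer_id_ == 3);
  return 0;
}

Presumably this is why the utility call sites such as gen_nccl_id_op and the tests can pass 0 without clobbering an id set earlier by the trainer's own ops.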

@@ -125,7 +125,7 @@ TEST(PREFETCH, CPU) {
   g_req_handler.reset(new distributed::RequestPrefetchHandler(true));
   g_rpc_service.reset(new RPCSERVER_T("127.0.0.1:0", 1));
   distributed::RPCClient* client =
-      distributed::RPCClient::GetInstance<RPCCLIENT_T>();
+      distributed::RPCClient::GetInstance<RPCCLIENT_T>(0);

   std::thread server_thread(StartServer, distributed::kRequestPrefetch);
   g_rpc_service->WaitServerReady();
@@ -165,7 +165,7 @@ TEST(COMPLETE, CPU) {
   g_req_handler.reset(new distributed::RequestSendHandler(true));
   g_rpc_service.reset(new RPCSERVER_T("127.0.0.1:0", 2));
   distributed::RPCClient* client =
-      distributed::RPCClient::GetInstance<RPCCLIENT_T>();
+      distributed::RPCClient::GetInstance<RPCCLIENT_T>(0);
   PADDLE_ENFORCE(client != nullptr);
   std::thread server_thread(StartServer, distributed::kRequestSend);
   g_rpc_service->WaitServerReady();

@@ -79,6 +79,7 @@ message VariableMessage {
   // server stops profiling and generates a profile to /tmp/profile_ps_*
   // when profile switches from 1 to 2.
   int64 profile = 11;
+  int64 trainer_id = 12;
 }

 message VoidMessage {}

@@ -92,6 +92,8 @@ class VariableResponse {
     return scope_->FindVar(meta_.varname());
   }

+  int GetTrainerId() { return static_cast<int>(meta_.trainer_id()); }
+
  protected:
   bool ReadRaw(::google::protobuf::io::CodedInputStream* input,
                const platform::DeviceContext& dev_ctx, platform::Place place,

@@ -37,7 +37,8 @@ class FetchBarrierOp : public framework::OperatorBase {
                const platform::Place& place) const override {
     std::vector<std::string> eps = Attr<std::vector<std::string>>("endpoints");

     distributed::RPCClient* rpc_client =
-        distributed::RPCClient::GetInstance<RPCCLIENT_T>();
+        distributed::RPCClient::GetInstance<RPCCLIENT_T>(
+            Attr<int>("trainer_id"));

     PADDLE_ENFORCE(rpc_client->Wait(), "internal error in RPCClient");
@@ -61,6 +62,7 @@ This operator will send a send barrier signal to list_and_serv op, so that
 the Parameter Server would knew all variables have been sent.
 )DOC");

+    AddAttr<int>("trainer_id", "trainer id from 0 ~ worker_num.").SetDefault(0);
     AddAttr<std::vector<std::string>>("endpoints",
                                       "(string vector, default 127.0.0.1:6164)"
                                       "Server endpoints to send variables to.")

@@ -61,7 +61,7 @@ class GenNCCLIdOp : public framework::OperatorBase {
     std::vector<std::string> endpoint_list =
         Attr<std::vector<std::string>>("endpoint_list");
     distributed::RPCClient* client =
-        distributed::RPCClient::GetInstance<RPCCLIENT_T>();
+        distributed::RPCClient::GetInstance<RPCCLIENT_T>(0);

     for (auto& ep : endpoint_list) {
       VLOG(3) << "sending nccl id to " << ep;

@@ -218,23 +218,26 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
                                    framework::ProgramDesc *program,
                                    framework::Scope *recv_scope) const {
   VLOG(2) << "RunAsyncLoop";
-  // grad name to block id
-  std::unordered_map<std::string, int32_t> grad_to_block_id;
-  std::unordered_map<int32_t, std::string> id_to_grad;
   auto grad_to_block_id_str =
       Attr<std::vector<std::string>>("grad_to_block_id");
-  for (const auto &grad_and_id : grad_to_block_id_str) {
+  DoubleFindMap<std::string, int32_t> grad_to_block_id;
+
+  auto append_block_maps = [](DoubleFindMap<std::string, int32_t> *out_map,
+                              const std::string &grad_and_id) {
     std::vector<std::string> pieces;
     split(grad_and_id, ':', &pieces);
-    VLOG(3) << "after split, grad = " << pieces[0] << ", id=" << pieces[1];
+    VLOG(3) << "after split, key = " << pieces[0] << ", id=" << pieces[1];
     PADDLE_ENFORCE_EQ(pieces.size(), 2);
-    PADDLE_ENFORCE_EQ(grad_to_block_id.count(pieces[0]), 0);
+    PADDLE_ENFORCE_EQ(out_map->count(pieces[0]), 0);
     int block_id = std::stoi(pieces[1]);
-    grad_to_block_id[pieces[0]] = block_id;
-    id_to_grad[block_id] = pieces[0];
+    (*out_map)[pieces[0]] = block_id;
+  };
+
+  for (const auto &grad_and_id : grad_to_block_id_str) {
+    append_block_maps(&grad_to_block_id, grad_and_id);
   }

   size_t num_blocks = program->Size();
   PADDLE_ENFORCE_GE(num_blocks, 2,
                     "server program should have at least 2 blocks");
@@ -244,15 +247,22 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
     block_list.push_back(blkid);
   }
   auto optimize_prepared = executor->Prepare(*program, block_list);
-  // execute global block if needed
-  if (block_list[0] == 1 && id_to_grad.count(1) == 0) {
+  // execute global block if needed, block id 1 in the program is global
+  // block if it's not bind to a grad var for it's update.
+  if (block_list[0] == 1 &&
+      grad_to_block_id.find_value(static_cast<int32_t>(1)) ==
+          grad_to_block_id.end()) {
     executor->RunPreparedContext(optimize_prepared[0].get(), recv_scope);
   }
   std::unordered_map<std::string,
                      std::shared_ptr<framework::ExecutorPrepareContext>>
-      grad_to_prepared_ctx;
+      grad_to_prepared_ctx, param_to_prepared_ctx;
   for (size_t i = 0; i < block_list.size(); ++i) {
-    grad_to_prepared_ctx[id_to_grad[block_list[i]]] = optimize_prepared[i];
+    auto blkid = block_list[i];
+    auto it = grad_to_block_id.find_value(blkid);
+    if (it != grad_to_block_id.end()) {
+      grad_to_prepared_ctx[it->first] = optimize_prepared[i];
+    }
   }

   request_send_handler_->SetGradToPreparedCtx(&grad_to_prepared_ctx);
@@ -315,6 +325,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
   framework::Scope &recv_scope = scope.NewScope();

   bool sync_mode = Attr<bool>("sync_mode");
+  bool dc_sgd = Attr<bool>("dc_asgd");
   auto fan_in = Attr<int>("Fanin");

   auto inputs = Inputs("X");
@@ -328,8 +339,10 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
   rpc_service_.reset(new RPCSERVER_T(endpoint, fan_in));

-  request_send_handler_.reset(new distributed::RequestSendHandler(sync_mode));
-  request_get_handler_.reset(new distributed::RequestGetHandler(sync_mode));
+  request_send_handler_.reset(
+      new distributed::RequestSendHandler(sync_mode, dc_sgd));
+  request_get_handler_.reset(
+      new distributed::RequestGetHandler(sync_mode, dc_sgd));
   request_prefetch_handler_.reset(
       new distributed::RequestPrefetchHandler(sync_mode));
   request_checkpoint_handler_.reset(new distributed::RequestCheckpointHandler(
@@ -443,6 +456,8 @@ class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker {
                   "a map from grad name to it's optimize block id")
         .SetDefault({});
     AddAttr<bool>("sync_mode", "if works at sync_mode or not").SetDefault(true);
+    AddAttr<bool>("dc_asgd", "set to true will enable DC-ASGD training.")
+        .SetDefault(false);
     AddAttr<std::vector<framework::BlockDesc *>>(
         kOptimizeBlocks, "Optimize blocks to run on server side.")
         .SetDefault({});
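The grad_to_block_id attribute parsed above is a list of "grad_name:block_id" strings. A small standalone sketch of what the append_block_maps lambda extracts from each entry (the entry value and names are made up; the real list is generated by the Python transpiler):

#include <cassert>
#include <string>

int main() {
  std::string grad_and_id = "fc_0.w_0@GRAD:2";              // hypothetical entry
  auto colon = grad_and_id.find(':');
  std::string grad_name = grad_and_id.substr(0, colon);     // key: "fc_0.w_0@GRAD"
  int block_id = std::stoi(grad_and_id.substr(colon + 1));  // optimize block 2
  assert(grad_name == "fc_0.w_0@GRAD" && block_id == 2);
  return 0;
}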

@@ -18,6 +18,7 @@ limitations under the License. */
 #include <atomic>
 #include <set>
 #include <string>
+#include <utility>
 #include <vector>

 #include "paddle/fluid/framework/executor.h"
@@ -37,6 +38,17 @@ constexpr char kCheckpointBlockId[] = "checkpint_block_id";

 void RunServer(std::shared_ptr<distributed::RPCServer> service);

+template <class TKey, class TValue>
+class DoubleFindMap : public std::unordered_map<TKey, TValue> {
+ public:
+  typename std::unordered_map<TKey, TValue>::iterator find_value(TValue v) {
+    return std::find_if(this->begin(), this->end(),
+                        [&v](const std::pair<const std::string, int> p) {
+                          return p.second == v;
+                        });
+  }
+};
+
 class ListenAndServOp : public framework::OperatorBase {
  public:
   ListenAndServOp(const std::string& type,
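DoubleFindMap replaces the old grad_to_block_id/id_to_grad pair of maps with one map plus a linear reverse lookup. A usage sketch under the same definition (grad names and block ids are made up); note the committed lambda hard-codes std::pair<const std::string, int> and takes it by value, which works only because the map is instantiated with std::string/int32_t, so the sketch uses the templated pair type instead:

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <string>
#include <unordered_map>
#include <utility>

template <class TKey, class TValue>
class DoubleFindMap : public std::unordered_map<TKey, TValue> {
 public:
  // Reverse lookup: first entry whose mapped value equals v, else end().
  typename std::unordered_map<TKey, TValue>::iterator find_value(TValue v) {
    return std::find_if(this->begin(), this->end(),
                        [&v](const std::pair<const TKey, TValue>& p) {
                          return p.second == v;
                        });
  }
};

int main() {
  DoubleFindMap<std::string, int32_t> grad_to_block_id;
  grad_to_block_id["fc_0.w_0@GRAD"] = 2;     // forward: grad name -> block id
  grad_to_block_id["fc_0.b_0@GRAD"] = 3;

  auto it = grad_to_block_id.find_value(3);  // reverse: block id -> grad name
  if (it != grad_to_block_id.end()) {
    std::cout << it->first << std::endl;     // prints fc_0.b_0@GRAD
  }
  return 0;
}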

@@ -42,7 +42,8 @@ class PrefetchOp : public framework::OperatorBase {
     auto& ctx = *pool.Get(place);

     distributed::RPCClient* rpc_client =
-        distributed::RPCClient::GetInstance<RPCCLIENT_T>();
+        distributed::RPCClient::GetInstance<RPCCLIENT_T>(
+            Attr<int>("trainer_id"));

     std::vector<distributed::VarHandlePtr> rets;
     for (size_t i = 0; i < ins.size(); i++) {
@@ -69,6 +70,7 @@ class PrefetchOpMaker : public framework::OpProtoAndCheckerMaker {
               "(LoDTensor) result "
               "to be fetched from parameter server")
         .AsDuplicable();
+    AddAttr<int>("trainer_id", "trainer id from 0 ~ worker_num.").SetDefault(0);
     AddAttr<std::vector<std::string>>(
         "epmap",
         "(string vector, default 127.0.0.1:6164)"

@@ -42,7 +42,8 @@ class RecvOp : public framework::OperatorBase {
     auto& ctx = *pool.Get(place);

     distributed::RPCClient* rpc_client =
-        distributed::RPCClient::GetInstance<RPCCLIENT_T>();
+        distributed::RPCClient::GetInstance<RPCCLIENT_T>(
+            Attr<int>("trainer_id"));

     std::vector<distributed::VarHandlePtr> rets;
     for (size_t i = 0; i < outs.size(); i++) {
@@ -73,6 +74,7 @@ This operator can get variables from server side.
                           "Server endpoints in the order of input "
                           "variables for mapping")
         .SetDefault({});
+    AddAttr<int>("trainer_id", "trainer id from 0 ~ worker_num.").SetDefault(0);
     AddAttr<int>("sync_mode",
                  "(int, default 0)"
                  "sync recv or async recv.")

@@ -0,0 +1,79 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/ref_by_trainer_id_op.h"
#include <string>
namespace paddle {
namespace operators {
class RefByTrainerIdOp : public framework::OperatorWithKernel {
public:
RefByTrainerIdOp(const std::string &type,
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: OperatorWithKernel(type, inputs, outputs, attrs) {}
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInputs("X"),
"Input(X) of RefByTrainerIdOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("TrainerId"),
"Input(TrainerId) of RefByTrainerIdOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of RefByTrainerIdOp should not be null.");
PADDLE_ENFORCE_EQ(ctx->GetInputDim("TrainerId").size(), 1,
"TrainerId should be a scalar.");
// Out's shape is determined at runtime.
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
framework::ToDataType(
ctx.MultiInput<framework::Tensor>("X")[0]->type()),
ctx.GetPlace());
}
};
class RefByTrainerIdOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor) Input tensor list.").AsDuplicable();
AddInput("TrainerId", "(Tensor) Scalar int, the trainer id runtime value.");
AddOutput("Out", "(Tensor) Return one tensor reference of X[trainer_id]");
AddComment(R"DOC(
**RefByTrainerId operator**
Return a reference of a tensor, using trainer_id as the index to find from the input.
$$Out = X[TrainerId]$$
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(ref_by_trainer_id, ops::RefByTrainerIdOp,
ops::RefByTrainerIdOpMaker);
REGISTER_OP_CPU_KERNEL(
ref_by_trainer_id,
ops::RefByTrainerIdKernel<paddle::platform::CPUDeviceContext, float>,
ops::RefByTrainerIdKernel<paddle::platform::CPUDeviceContext, double>,
ops::RefByTrainerIdKernel<paddle::platform::CPUDeviceContext, int>,
ops::RefByTrainerIdKernel<paddle::platform::CPUDeviceContext, int64_t>);

@@ -0,0 +1,26 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/ref_by_trainer_id_op.h"
REGISTER_OP_CUDA_KERNEL(
ref_by_trainer_id,
paddle::operators::RefByTrainerIdKernel<paddle::platform::CUDADeviceContext,
float>,
paddle::operators::RefByTrainerIdKernel<paddle::platform::CUDADeviceContext,
double>,
paddle::operators::RefByTrainerIdKernel<paddle::platform::CUDADeviceContext,
int>,
paddle::operators::RefByTrainerIdKernel<paddle::platform::CUDADeviceContext,
int64_t>);

@@ -0,0 +1,49 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <stdio.h>
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
class RefByTrainerIdKernel : public framework::OpKernel<T> {
public:
virtual void Compute(const framework::ExecutionContext& context) const {
auto* out = context.Output<framework::Tensor>("Out");
auto in_list = context.MultiInput<framework::Tensor>("X");
auto* trainer_id_t = context.Input<framework::Tensor>("TrainerId");
int64_t trainer_id;
auto* trainer_id_data = trainer_id_t->data<int64_t>();
if (platform::is_gpu_place(context.GetPlace())) {
#ifdef PADDLE_WITH_CUDA
auto stream = context.cuda_device_context().stream();
memory::Copy<>(platform::CPUPlace(), &trainer_id,
boost::get<platform::CUDAPlace>(context.GetPlace()),
trainer_id_data, sizeof(int64_t), stream);
#endif
} else {
trainer_id = *trainer_id_data;
}
printf("after get trainer_id %lu\n", trainer_id);
PADDLE_ENFORCE_LT(trainer_id, in_list.size());
out->mutable_data<T>(context.GetPlace());
out->ShareDataWith(*(in_list[trainer_id]));
}
};
} // namespace operators
} // namespace paddle
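The kernel above boils down to picking one tensor out of the input list by the runtime trainer id and aliasing it via ShareDataWith, not copying it. A CPU-only sketch of that semantics with plain vectors standing in for tensors (names and values are made up); in the DC-ASGD server program the inputs would presumably be the per-trainer "*.trainer_%d_bak" backups created in RequestGetHandler:

#include <cassert>
#include <cstdint>
#include <vector>

using Tensor = std::vector<float>;  // stand-in for framework::Tensor

// Out aliases X[trainer_id]; nothing is copied.
const Tensor& RefByTrainerId(const std::vector<const Tensor*>& x,
                             int64_t trainer_id) {
  assert(trainer_id >= 0 && static_cast<size_t>(trainer_id) < x.size());
  return *x[trainer_id];
}

int main() {
  Tensor bak0 = {1.f, 1.f}, bak1 = {2.f, 2.f}, bak2 = {3.f, 3.f};
  const Tensor& out = RefByTrainerId({&bak0, &bak1, &bak2}, /*trainer_id=*/2);
  assert(&out == &bak2);  // the output is X[2] itself
  return 0;
}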

Some files were not shown because too many files have changed in this diff.
