checkpoint at distributed training (#14854)

Checkpoint support for distributed training: add a GetVariableNoBarrier RPC and an AsyncGetVarNoBarrier client call so trainers can fetch variables from the parameter server without waiting on barriers, extend the recv op with with_barrier and varnames attributes, record parameter placement on pservers in Program, and extend the save/load tests to checkpoint and restore pserver-side state.
inference-pre-release-gpu
tangwei12 6 years ago committed by GitHub
parent 07dc5a1506
commit 8b50ad80ff

@ -74,7 +74,7 @@ VarHandlePtr GRPCClient::AsyncSendVar(const std::string& ep,
const framework::Scope* p_scope = &scope;
const auto ch = GetChannel(ep_val);
SendProcessor* s = new SendProcessor(ch);
const std::string method = "SendRPC";
const std::string method = kSendRPC;
VarHandlePtr h(new VarHandle(ep, method, var_name_val, p_ctx, p_scope));
s->Prepare(h, time_out);
@ -107,7 +107,7 @@ VarHandlePtr GRPCClient::AsyncSendVar(const std::string& ep,
void ProcGetResponse(const VarHandle& var_h,
const ::grpc::ByteBuffer& ret_msg) {
VLOG(100) << "ProcGetResponse";
VLOG(4) << "ProcGetResponse";
framework::Variable* outvar = nullptr;
// the trainer_id in the response is not used
int trainer_id;
@ -127,59 +127,74 @@ VarHandlePtr GRPCClient::AsyncGetVar(const std::string& ep,
const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& var_name,
const std::string& out_varname,
int64_t time_out) {
return _AsyncGetVar(ep, ctx, scope, var_name,
return _AsyncGetVar(ep, ctx, scope, kGetRPC, var_name, out_varname,
"/sendrecv.SendRecvService/GetVariable", time_out);
}
VarHandlePtr GRPCClient::AsyncGetVarNoBarrier(
const std::string& ep, const platform::DeviceContext& ctx,
const framework::Scope& scope, const std::string& var_name,
const std::string& out_varname, int64_t time_out) {
std::string var_name_no_barrier =
string::Sprintf("%s%s", var_name, WITHOUT_BARRIER_MESSAGE);
return _AsyncGetVar(
ep, ctx, scope, kGetNoBarrierRPC, var_name_no_barrier, out_varname,
"/sendrecv.SendRecvService/GetVariableNoBarrier", time_out);
}
VarHandlePtr GRPCClient::AsyncGetMonomerVariable(
const std::string& ep, const platform::DeviceContext& ctx,
const framework::Scope& scope, const std::string& var_name,
int64_t time_out) {
return _AsyncGetVar(ep, ctx, scope, var_name,
return _AsyncGetVar(ep, ctx, scope, kGetMonomerRPC, var_name, var_name,
"/sendrecv.SendRecvService/GetMonomerVariable", time_out);
}
VarHandlePtr GRPCClient::_AsyncGetVar(const std::string& ep,
const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& var_name,
const std::string& rpc_path,
int64_t time_out) {
VarHandlePtr GRPCClient::_AsyncGetVar(
const std::string& ep, const platform::DeviceContext& ctx,
const framework::Scope& scope, const std::string& method,
const std::string& var_name, const std::string& out_varname,
const std::string& rpc_path, int64_t time_out) {
const platform::DeviceContext* p_ctx = &ctx;
const std::string ep_val = ep;
const std::string var_name_val = var_name;
const std::string out_varname_val = out_varname;
const framework::Scope* p_scope = &scope;
const auto ch = GetChannel(ep_val);
GetProcessor* s = new GetProcessor(ch);
const std::string method = "GetRPC";
VarHandlePtr h(new VarHandle(ep, method, var_name_val, p_ctx, p_scope));
VarHandlePtr h(new VarHandle(ep, method, out_varname_val, p_ctx, p_scope));
s->Prepare(h, time_out);
framework::AsyncIO([var_name_val, s, method, p_ctx, h, rpc_path, this] {
// prepare input
sendrecv::VariableMessage req;
req.set_varname(var_name_val);
req.set_trainer_id(trainer_id_);
::grpc::ByteBuffer buf;
RequestToByteBuffer<sendrecv::VariableMessage>(req, &buf);
framework::AsyncIO(
[var_name_val, out_varname_val, s, method, p_ctx, h, rpc_path, this] {
// prepare input
sendrecv::VariableMessage req;
req.set_varname(var_name_val);
req.set_out_varname(out_varname_val);
req.set_trainer_id(trainer_id_);
::grpc::ByteBuffer buf;
RequestToByteBuffer<sendrecv::VariableMessage>(req, &buf);
VLOG(3) << s->GetVarHandlePtr()->String() << " begin";
VLOG(3) << s->GetVarHandlePtr()->String() << " begin";
// stub context
s->response_call_back_ = ProcGetResponse;
// stub context
s->response_call_back_ = ProcGetResponse;
platform::RecordRPCEvent record_event(method, p_ctx);
platform::RecordRPCEvent record_event(method, p_ctx);
auto call =
s->stub_g_.PrepareUnaryCall(s->context_.get(), rpc_path, buf, &cq_);
call->StartCall();
call->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
auto call =
s->stub_g_.PrepareUnaryCall(s->context_.get(), rpc_path, buf, &cq_);
call->StartCall();
call->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
if (UNLIKELY(platform::IsProfileEnabled())) {
h->Wait();
}
});
if (UNLIKELY(platform::IsProfileEnabled())) {
h->Wait();
}
});
req_count_++;
@ -202,7 +217,7 @@ VarHandlePtr GRPCClient::AsyncPrefetchVar(const std::string& ep,
const auto ch = GetChannel(ep_val);
GetProcessor* s = new GetProcessor(ch);
const std::string method = "PrefetchRPC";
const std::string method = kPrefetchRPC;
VarHandlePtr h(new VarHandle(ep, method, out_var_name_val, p_ctx, p_scope));
s->Prepare(h, time_out);
@ -242,7 +257,7 @@ VarHandlePtr GRPCClient::AsyncSendBatchBarrier(const std::string& ep,
const auto ch = GetChannel(ep);
BatchBarrierProcessor* s = new BatchBarrierProcessor(ch);
const std::string method = "BatchBarrierRPC";
const std::string method = kBatchBarrierRPC;
VarHandlePtr h(
new VarHandle(ep, method, BATCH_BARRIER_MESSAGE, nullptr, nullptr));
s->Prepare(h, time_out);
@ -267,7 +282,7 @@ VarHandlePtr GRPCClient::AsyncSendFetchBarrier(const std::string& ep,
int64_t time_out) {
const auto ch = GetChannel(ep);
FetchBarrierProcessor* s = new FetchBarrierProcessor(ch);
const std::string method = "FetchBarrierRPC";
const std::string method = kFetchBarrierRPC;
VarHandlePtr h(
new VarHandle(ep, method, FETCH_BARRIER_MESSAGE, nullptr, nullptr));
s->Prepare(h, time_out);
@ -293,7 +308,7 @@ VarHandlePtr GRPCClient::AsyncGetMonomerBarrier(const std::string& ep,
int64_t time_out) {
const auto ch = GetChannel(ep);
BatchBarrierProcessor* s = new BatchBarrierProcessor(ch);
const std::string method = "SendMonomerFetchBarrierRPC";
const std::string method = kSendMonomerFetchBarrierRPC;
VarHandlePtr h(new VarHandle(ep, method, var_name, nullptr, nullptr));
s->Prepare(h, time_out);
@ -320,7 +335,7 @@ VarHandlePtr GRPCClient::AsyncSendComplete(const std::string& ep,
const auto ch = GetChannel(ep);
BatchBarrierProcessor* s = new BatchBarrierProcessor(ch);
const std::string method = "SendCompleteRPC";
const std::string method = kSendCompleteRPC;
VarHandlePtr h(new VarHandle(ep, method, COMPLETE_MESSAGE, nullptr, nullptr));
s->Prepare(h, time_out);
@ -347,7 +362,7 @@ VarHandlePtr GRPCClient::AsyncCheckpointNotify(const std::string& ep,
CheckpointNotifyProcessor* s = new CheckpointNotifyProcessor(ch);
const std::string method = "CheckPointNotifyRPC";
const std::string method = kCheckPointNotifyRPC;
VarHandlePtr h(
new VarHandle(ep, method, CHECKPOINT_SAVE_MESSAGE, nullptr, nullptr));
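A minimal sketch of how a caller now drives the reworked client, mirroring the recv op later in this diff. The endpoint, variable names, and the wrapper function are illustrative only; the snippet assumes it is compiled inside namespace paddle::operators (so the distributed/framework/platform aliases resolve as above) and that RPCCLIENT_T is the client macro used by the operators in this change.

// Sketch: fetch a pserver variable into a differently named local variable,
// once through the barriered path and once through the new barrier-free path.
void FetchMomentSketch(const framework::Scope& scope,
                       const platform::DeviceContext& ctx) {
  auto* client =
      distributed::RPCClient::GetInstance<RPCCLIENT_T>(0 /* trainer_id */);

  // Barriered get: read "moment_1" on the pserver, store it locally as
  // "moment_1@127.0.0.1:6174" (the new out_varname argument).
  auto h1 = client->AsyncGetVar("127.0.0.1:6174", ctx, scope, "moment_1",
                                "moment_1@127.0.0.1:6174");

  // Barrier-free get: same arguments, served by the GetVariableNoBarrier RPC.
  auto h2 = client->AsyncGetVarNoBarrier("127.0.0.1:6174", ctx, scope,
                                         "moment_1",
                                         "moment_1@127.0.0.1:6174");

  PADDLE_ENFORCE(h1->Wait(), "internal error in RPCClient");
  PADDLE_ENFORCE(h2->Wait(), "internal error in RPCClient");
}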

@ -186,8 +186,15 @@ class GRPCClient : public RPCClient {
const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& var_name,
const std::string& out_varname,
int64_t time_out = FLAGS_rpc_deadline) override;
VarHandlePtr AsyncGetVarNoBarrier(
const std::string& ep, const platform::DeviceContext& ctx,
const framework::Scope& scope, const std::string& var_name,
const std::string& out_varname,
int64_t time_out = FLAGS_rpc_deadline) override;
VarHandlePtr AsyncGetMonomerVariable(
const std::string& ep, const platform::DeviceContext& ctx,
const framework::Scope& scope, const std::string& var_name,
@ -228,11 +235,11 @@ class GRPCClient : public RPCClient {
void Proceed();
std::shared_ptr<grpc::Channel> GetChannel(const std::string& ep);
VarHandlePtr _AsyncGetVar(const std::string& ep,
const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& var_name, const std::string& rpc,
int64_t time_out);
VarHandlePtr _AsyncGetVar(
const std::string& ep, const platform::DeviceContext& ctx,
const framework::Scope& scope, const std::string& method,
const std::string& var_name, const std::string& out_varname,
const std::string& rpc_path, int64_t time_out = FLAGS_rpc_deadline);
private:
grpc::CompletionQueue cq_;

@ -136,17 +136,65 @@ class RequestGet final : public RequestBase {
void Process() override {
// proc request.
std::string varname = request_.varname();
std::string out_varname = request_.out_varname();
int trainer_id = request_.trainer_id();
VLOG(4) << "RequestGet " << varname;
VLOG(4) << "RequestGet " << out_varname << " from " << varname;
auto scope = request_handler_->scope();
auto invar = scope->FindVar(varname);
framework::Variable* invar = nullptr;
framework::Variable* outvar = nullptr;
request_handler_->Handle(varname, scope, invar, &outvar, trainer_id);
request_handler_->Handle(varname, scope, invar, &outvar, trainer_id,
out_varname);
if (outvar) {
SerializeToByteBuffer(varname, outvar, *request_handler_->dev_ctx(),
SerializeToByteBuffer(out_varname, outvar, *request_handler_->dev_ctx(),
&reply_);
}
Finish(reply_, &responder_);
}
protected:
sendrecv::VariableMessage request_;
::grpc::ByteBuffer reply_;
ServerAsyncResponseWriter<::grpc::ByteBuffer> responder_;
};
class RequestGetNoBarrier final : public RequestBase {
public:
explicit RequestGetNoBarrier(GrpcService::AsyncService* service,
::grpc::ServerCompletionQueue* cq,
RequestHandler* request_handler, int req_id)
: RequestBase(service, cq, request_handler, req_id), responder_(&ctx_) {
auto method_id =
static_cast<int>(distributed::GrpcMethod::kGetVariableNoBarrier);
service_->RequestAsyncUnary(
method_id, &ctx_, &request_, &responder_, cq_, cq_,
reinterpret_cast<void*>(static_cast<intptr_t>(req_id)));
}
virtual ~RequestGetNoBarrier() {}
std::string GetReqName() override { return request_.varname(); }
void Process() override {
// proc request.
std::string varname = request_.varname();
std::string out_varname = request_.out_varname();
int trainer_id = request_.trainer_id();
VLOG(4) << "RequestGetNoBarrier " << out_varname << " from " << varname;
auto scope = request_handler_->scope();
framework::Variable* invar = nullptr;
framework::Variable* outvar = nullptr;
request_handler_->Handle(varname, scope, invar, &outvar, trainer_id,
out_varname);
if (outvar) {
SerializeToByteBuffer(out_varname, outvar, *request_handler_->dev_ctx(),
&reply_);
}
Finish(reply_, &responder_);
@ -460,6 +508,9 @@ void AsyncGRPCServer::TryToRegisterNewOne(const std::string& rpc_name,
b = new RequestSend(&service_, cq.get(), handler, req_id);
} else if (rpc_name == kRequestGet) {
b = new RequestGet(&service_, cq.get(), handler, req_id);
} else if (rpc_name == kRequestGetNoBarrier) {
b = new RequestGetNoBarrier(&service_, cq.get(), handler, req_id);
} else if (rpc_name == kRequestGetMonomerVariable) {
b = new RequestGetMonomerVariable(&service_, cq.get(), handler, req_id,
this);

@ -81,6 +81,7 @@ enum class GrpcMethod {
kGetVariable,
kPrefetchVariable,
kCheckpointNotify,
kGetVariableNoBarrier,
kGetMonomerVariable,
kGetMonomerBarrier,
};
@ -94,6 +95,8 @@ inline const char* GrpcMethodName(GrpcMethod id) {
return "/sendrecv.SendRecvService/SendVariable";
case GrpcMethod::kGetVariable:
return "/sendrecv.SendRecvService/GetVariable";
case GrpcMethod::kGetVariableNoBarrier:
return "/sendrecv.SendRecvService/GetVariableNoBarrier";
case GrpcMethod::kGetMonomerVariable:
return "/sendrecv.SendRecvService/GetMonomerVariable";
case GrpcMethod::kGetMonomerBarrier:

@ -42,11 +42,24 @@ constexpr char kRequestGetMonomerBarrier[] = "RequestGetMonomerBarrier";
constexpr char kRequestPrefetch[] = "RequestPrefetch";
constexpr char kRequestCheckpoint[] = "RequestCheckpoint";
constexpr char kRequestPassBarrier[] = "RequestPassBarrier";
constexpr char kRequestGetNoBarrier[] = "GetVariableNoBarrier";
constexpr char kSendRPC[] = "SendRPC";
constexpr char kGetRPC[] = "GetRPC";
constexpr char kGetNoBarrierRPC[] = "GetNoBarrierRPC";
constexpr char kGetMonomerRPC[] = "GetMonomerRPC";
constexpr char kPrefetchRPC[] = "PrefetchRPC";
constexpr char kBatchBarrierRPC[] = "BatchBarrierRPC";
constexpr char kFetchBarrierRPC[] = "FetchBarrierRPC";
constexpr char kSendMonomerFetchBarrierRPC[] = "SendMonomerFetchBarrierRPC";
constexpr char kSendCompleteRPC[] = "SendCompleteRPC";
constexpr char kCheckPointNotifyRPC[] = "CheckPointNotifyRPC";
#define LISTEN_TERMINATE_MESSAGE "TERMINATE@RECV"
#define BATCH_BARRIER_MESSAGE "BATCH_BARRIER@RECV"
#define FETCH_BARRIER_MESSAGE "FETCH_BARRIER@RECV"
#define COMPLETE_MESSAGE "COMPLETE@RECV"
#define WITHOUT_BARRIER_MESSAGE "@WITHOUT_BARRIER@RECV"
#define CHECKPOINT_SAVE_MESSAGE "SAVE@CHECKPOINTNOTIFY"
#define CHECKPOINT_LOAD_MESSAGE "LOAD@CHECKPOINTNOTIFY"
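The WITHOUT_BARRIER_MESSAGE suffix is the only signal that a get should skip the send/recv barriers: the client appends it to the requested variable name, and the server-side handler strips it again before looking the variable up in the scope. A small sketch of that round trip, assuming string:: resolves to paddle::string as it does in the handlers; var_name is a placeholder:

// Client side (see GRPCClient::AsyncGetVarNoBarrier): tag the name.
const std::string var_name = "moment_1";
std::string tagged = string::Sprintf("%s%s", var_name, WITHOUT_BARRIER_MESSAGE);

// Server side (see RequestGetNoBarrierHandler::Handle): untag before FindVar.
string::Piece piece(tagged);
if (string::Contains(piece, string::Piece(WITHOUT_BARRIER_MESSAGE))) {
  piece = string::TrimSuffix(piece, string::Piece(WITHOUT_BARRIER_MESSAGE));
}
std::string real_name = piece.ToString();  // back to "moment_1"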

@ -23,6 +23,7 @@
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/operators/distributed/rpc_server.h"
#include "paddle/fluid/string/piece.h"
#include "paddle/fluid/string/printf.h"
namespace paddle {
@ -81,7 +82,8 @@ bool RequestGetHandler::Handle(const std::string& varname,
const int trainer_id,
const std::string& out_var_name,
const std::string& table_name) {
VLOG(4) << "RequestGetHandler:" << varname;
VLOG(4) << "RequestGetHandler:" << varname
<< " out_var_name: " << out_var_name;
if (sync_mode_) {
if (varname == FETCH_BARRIER_MESSAGE) {
@ -112,6 +114,32 @@ bool RequestGetHandler::Handle(const std::string& varname,
return true;
}
bool RequestGetNoBarrierHandler::Handle(const std::string& varname,
framework::Scope* scope,
framework::Variable* invar,
framework::Variable** outvar,
const int trainer_id,
const std::string& out_var_name,
const std::string& table_name) {
VLOG(4) << "RequestGetNoBarrierHandler:" << varname
<< " out_var_name: " << out_var_name;
// get var from pserver immediately without barriers
string::Piece without_barrier_piece(WITHOUT_BARRIER_MESSAGE);
string::Piece var_name_piece = string::Piece(varname);
if (string::Contains(var_name_piece, without_barrier_piece)) {
var_name_piece = string::TrimSuffix(var_name_piece, without_barrier_piece);
VLOG(4) << "Get var " << var_name_piece << " with "
<< WITHOUT_BARRIER_MESSAGE;
*outvar = scope_->FindVar(var_name_piece.ToString());
return true;
} else {
PADDLE_THROW("GetNoBarrier must contain %s", WITHOUT_BARRIER_MESSAGE);
}
return true;
}
bool RequestPrefetchHandler::Handle(const std::string& varname,
framework::Scope* scope,
framework::Variable* invar,

@ -67,6 +67,16 @@ class RequestGetHandler final : public RequestHandler {
bool enable_dc_asgd_;
};
class RequestGetNoBarrierHandler final : public RequestHandler {
public:
RequestGetNoBarrierHandler() : RequestHandler(false) {}
virtual ~RequestGetNoBarrierHandler() {}
bool Handle(const std::string& varname, framework::Scope* scope,
framework::Variable* var, framework::Variable** outvar,
const int trainer_id, const std::string& out_var_name = "",
const std::string& table_name = "") override;
};
static inline void BuildVar(const std::string& param_name,
std::initializer_list<const char*> arguments,
paddle::framework::proto::OpDesc::Var* var) {

@ -43,8 +43,15 @@ class RPCClient {
const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& var_name,
const std::string& out_varname,
int64_t time_out = FLAGS_rpc_deadline) = 0;
virtual VarHandlePtr AsyncGetVarNoBarrier(
const std::string& ep, const platform::DeviceContext& ctx,
const framework::Scope& scope, const std::string& var_name,
const std::string& out_varname,
int64_t time_out = FLAGS_rpc_deadline) = 0;
virtual VarHandlePtr AsyncGetMonomerVariable(
const std::string& ep, const platform::DeviceContext& ctx,
const framework::Scope& scope, const std::string& var_name,

@ -17,8 +17,14 @@ package sendrecv;
option cc_generic_services = @cc_generic_services@;
service SendRecvService {
// For parameter server round-robin like hashing, do not split tensors.
// Send and recv only one tensor
// TODO(typhoonzero): add streaming API
rpc SendVariable(VariableMessage) returns (VoidMessage) {}
// Argument VariableMessage for GetVariable should only contain varname.
rpc GetVariable(VariableMessage) returns (VariableMessage) {}
rpc GetVariableNoBarrier(VariableMessage) returns (VariableMessage) {}
// pre-fetch variable by given variable name and Ids
rpc PrefetchVariable(VariableMessage) returns (VariableMessage) {}
rpc CheckpointNotify(VariableMessage) returns (VoidMessage) {}
@ -27,12 +33,17 @@ service SendRecvService {
rpc GetMonomerBarrier(VariableMessage) returns (VoidMessage) {}
}
// It can be: LoDTensorSelectedRows or NCCL_ID
enum VarType {
LOD_TENSOR = 0;
SELECTED_ROWS = 1;
NCCL_ID = 2;
}
// VariableMessage is the serialized Paddle variable message.
// NOTICE(gongwb): don't modify this proto unless you are
// familiar with how we serialize it in sendrecvop_utils.h
// and deserialize it in variable_response.h.
message VariableMessage {
enum Type {
// Pod Types
@ -49,14 +60,21 @@ message VariableMessage {
string varname = 1;
// TODO(Yancey1989): reference framework::proto::VarDesc::VarType
VarType type = 2;
// bool persistable is not needed for sending.
// tensor info:
Type data_type = 3;
repeated int64 dims = 4;
// lod details:
int64 lod_level = 5;
repeated LodData lod = 6;
// selected_rows height, aka. original dim0
int64 slr_height = 7;
// tensor data
bytes serialized = 8;
// selected_rows data
bytes rows = 9;
// Look up table block execution output variable name.
string out_varname = 10;
// If 1, the ps server will start profiling, the ps
// server stops profiling and generates a profile to /tmp/profile_ps_*
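For reference, this is roughly the request the client now builds for a get (compare the lambda in GRPCClient::_AsyncGetVar earlier in this diff): the variable to read on the pserver, the name the result should take on the trainer (the new out_varname field 10), and the trainer id. The generated header name below is an assumption; the exact path depends on the build:

#include "send_recv.pb.h"  // generated from this proto

sendrecv::VariableMessage MakeGetRequest(const std::string& var_name,
                                         const std::string& out_varname,
                                         int trainer_id) {
  sendrecv::VariableMessage req;
  req.set_varname(var_name);         // variable looked up on the pserver
  req.set_out_varname(out_varname);  // name used for the reply on the trainer
  req.set_trainer_id(trainer_id);
  return req;
}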

@ -347,6 +347,8 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
new distributed::RequestPrefetchHandler(sync_mode));
request_checkpoint_handler_.reset(new distributed::RequestCheckpointHandler(
sync_mode, checkpoint_block_id));
request_get_no_barrier_handler_.reset(
new distributed::RequestGetNoBarrierHandler());
rpc_service_->RegisterRPC(distributed::kRequestSend,
request_send_handler_.get(),
@ -359,6 +361,8 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
FLAGS_rpc_prefetch_thread_num);
rpc_service_->RegisterRPC(distributed::kRequestCheckpoint,
request_checkpoint_handler_.get());
rpc_service_->RegisterRPC(distributed::kRequestGetNoBarrier,
request_get_no_barrier_handler_.get());
auto optimize_blocks =
Attr<std::vector<framework::BlockDesc *>>(kOptimizeBlocks);
@ -413,6 +417,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
f(request_get_handler_.get());
f(request_prefetch_handler_.get());
f(request_checkpoint_handler_.get());
f(request_get_no_barrier_handler_.get());
// start the server listening after all member initialized.
server_thread_.reset(new std::thread(RunServer, rpc_service_));

@ -55,7 +55,6 @@ class ListenAndServOp : public framework::OperatorBase {
const framework::VariableNameMap& inputs,
const framework::VariableNameMap& outputs,
const framework::AttributeMap& attrs);
virtual ~ListenAndServOp();
void RunSyncLoop(framework::Executor* executor,
@ -89,6 +88,8 @@ class ListenAndServOp : public framework::OperatorBase {
mutable std::shared_ptr<distributed::RPCServer> rpc_service_;
mutable std::shared_ptr<distributed::RequestHandler> request_send_handler_;
mutable std::shared_ptr<distributed::RequestHandler> request_get_handler_;
mutable std::shared_ptr<distributed::RequestHandler>
request_get_no_barrier_handler_;
mutable std::shared_ptr<distributed::RequestHandler>
request_prefetch_handler_;
mutable std::shared_ptr<distributed::RequestHandler>

@ -27,30 +27,50 @@ namespace operators {
class RecvOp : public framework::OperatorBase {
public:
RecvOp(const std::string& type, const framework::VariableNameMap& inputs,
const framework::VariableNameMap& outputs,
const framework::AttributeMap& attrs)
RecvOp(const std::string &type, const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
void RunImpl(const framework::Scope& scope,
const platform::Place& place) const override {
auto outs = Outputs("Out");
void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override {
std::vector<std::string> epmap = Attr<std::vector<std::string>>("epmap");
std::vector<std::string> varnames =
Attr<std::vector<std::string>>("varnames");
int sync_mode = Attr<int>("sync_mode");
auto outs = Outputs("Out");
bool with_barrier = Attr<bool>("with_barrier");
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& ctx = *pool.Get(place);
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &ctx = *pool.Get(place);
distributed::RPCClient* rpc_client =
distributed::RPCClient *rpc_client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>(
Attr<int>("trainer_id"));
std::vector<distributed::VarHandlePtr> rets;
for (size_t i = 0; i < outs.size(); i++) {
VLOG(3) << "getting " << outs[i] << " from " << epmap[i];
rets.push_back(rpc_client->AsyncGetVar(epmap[i], ctx, scope, outs[i]));
}
if (sync_mode) {
if (with_barrier) {
std::vector<distributed::VarHandlePtr> rets;
for (size_t i = 0; i < outs.size(); i++) {
std::string varname = varnames.size() == 0 ? outs[i] : varnames[i];
VLOG(4) << "recv " << outs[i] << " from " << epmap[i] << " with "
<< varname << " and with AsyncGetVar";
rets.push_back(
rpc_client->AsyncGetVar(epmap[i], ctx, scope, varname, outs[i]));
}
if (sync_mode) {
for (size_t i = 0; i < rets.size(); i++) {
PADDLE_ENFORCE(rets[i]->Wait(), "internal error in RPCClient");
}
}
} else {
std::vector<distributed::VarHandlePtr> rets;
for (size_t i = 0; i < outs.size(); i++) {
std::string varname = varnames.size() == 0 ? outs[i] : varnames[i];
VLOG(4) << "recv " << outs[i] << " from " << epmap[i] << " with "
<< varname << " and with AsyncGetVarNoBarrier";
rets.push_back(rpc_client->AsyncGetVarNoBarrier(epmap[i], ctx, scope,
varname, outs[i]));
}
for (size_t i = 0; i < rets.size(); i++) {
PADDLE_ENFORCE(rets[i]->Wait(), "internal error in RPCClient");
}
@ -79,12 +99,23 @@ This operator can get variables from server side.
"(int, default 0)"
"sync recv or async recv.")
.SetDefault(0);
AddAttr<bool>("with_barrier",
"(bool, default True) if with_barrier=False, will use "
"AsyncGetVarNoBarrier get variable from pserver immediately")
.SetDefault(true);
AddAttr<std::vector<std::string>>(
"varnames",
"(string vector, default {}) "
"sometimes we need to put received var in another name "
"for example: we need var named 'moment_1@127.0.0.1:1001', "
"and it real name on parameter server is 'moment_1'. ")
.SetDefault({});
}
};
class RecvOpShapeInference : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext* ctx) const override {}
void operator()(framework::InferShapeContext *ctx) const override {}
};
} // namespace operators

@ -365,7 +365,7 @@ class TransposeMKLDNNHandler : public MKLDNNHandler {
mem_fmt.ndims = axis.size();
for (unsigned int i = 0; i < nchw_tz.size(); ++i) {
mem_fmt.dims[i] = nchw_tz[i]; // logical dimensions (nchw format,
// regardless physical layout)
// regardless physical layout)
}
mem_fmt.data_type = mkldnn_f32;
mem_fmt.format = mkldnn_blocked;
@ -374,7 +374,7 @@ class TransposeMKLDNNHandler : public MKLDNNHandler {
for (int i = nchw_tz.size() - 1; i >= 0; --i) {
mem_fmt.layout_desc.blocking.padding_dims[i] =
nchw_tz[i]; // logical dimensions (nchw format, regardless physical
// layout)
// layout)
mem_fmt.layout_desc.blocking.block_dims[i] = 1;
mem_fmt.layout_desc.blocking.offset_padding_to_data[i] = 0; // no offset
mem_fmt.layout_desc.blocking.strides[0][axis[i]] = total_stride;

@ -1696,12 +1696,20 @@ class Program(object):
self._current_role = core.op_proto_and_checker_maker.OpRole.Forward
self._op_role_var = []
# for distribute
# for distributed training
# _is_distributed = True if under distributed training
self._is_distributed = False
# _is_chief = True if the trainer is the first one, usually No.0
self._is_chief = False
self._slice_vars_and_attrs = []
# _parameters_on_pservers records all the parameters distributed on parameter servers.
self._parameters_on_pservers = None
# _endpoints is a list of parameter server endpoints ("ip:port"), e.g. ["ip:port", "ip:port"]
self._endpoints = []
# if current role is parameter server, the _ps_endpoint is its "ip:port"
self._ps_endpoint = None
# _trainers_endpoints is the list of trainer endpoints, used for distributed training.
self._trainers_endpoints = []
# the distributed lookup table names
self._distributed_lookup_table = None
@property
@ -2232,8 +2240,9 @@ class Program(object):
"Program")
self._is_distributed = other._is_distributed
self._is_chief = other._is_chief
self._slice_vars_and_attrs = other._slice_vars_and_attrs
self._parameters_on_pservers = other._parameters_on_pservers
self._endpoints = other._endpoints
self._ps_endpoint = other._ps_endpoint
self._distributed_lookup_table = other._distributed_lookup_table
def _copy_data_info_from(self, other):

File diff suppressed because it is too large.

@ -80,7 +80,8 @@ class TestDistSaveLoad2x2(TestDistSimnetBow2x2):
# NOTE: pserver should not call memory optimize
t = self.get_transpiler(args.trainer_id,
fluid.default_main_program(), args.endpoints,
args.trainers, args.sync_mode)
args.trainers, args.sync_mode, False,
args.current_endpoint)
pserver_prog = t.get_pserver_program(args.current_endpoint)
startup_prog = t.get_startup_program(args.current_endpoint,
pserver_prog)
@ -93,7 +94,8 @@ class TestDistSaveLoad2x2(TestDistSimnetBow2x2):
exe.run(startup_prog)
if need_load and model_dir:
self._load_persistable_vars(exe, model_dir, startup_prog)
fluid.io.load_persistables(exe, model_dir, pserver_prog)
exe.run(pserver_prog)
def run_trainer(self, args):
@ -158,19 +160,46 @@ class TestDistSaveLoad2x2(TestDistSimnetBow2x2):
need_save = bool(int(os.getenv("SAVE", "0")))
model_dir = os.getenv("MODEL_DIR", "")
if need_save:
for _ in six.moves.xrange(RUN_STEP):
loss, = exe.run(fetch_list=[avg_cost.name],
feed=feeder.feed(get_data()))
if need_save and model_dir:
io.save_persistables(startup_exe, model_dir, trainer_prog)
var = np.array(fluid.global_scope().find_var('__fc_b__').get_tensor())
if six.PY2:
print(pickle.dumps(np.ravel(var).tolist()))
save_mode = os.getenv("SAVE_MODE", "")
if save_mode == "LOCAL":
if need_save:
for _ in six.moves.xrange(RUN_STEP):
loss, = exe.run(fetch_list=[avg_cost.name],
feed=feeder.feed(get_data()))
if need_save and model_dir:
io.save_persistables(startup_exe, model_dir, trainer_prog)
var = np.array(fluid.global_scope().find_var('__fc_b__').get_tensor(
))
if six.PY2:
print(pickle.dumps(np.ravel(var).tolist()))
else:
sys.stdout.buffer.write(pickle.dumps(np.ravel(var).tolist()))
elif save_mode == "DIST":
skip_steps = int(os.getenv("SKIP_STEPS"))
loss = None
if need_save:
for idx in six.moves.xrange(8):
loss, = exe.run(fetch_list=[avg_cost.name],
feed=feeder.feed(get_data()))
if need_save and model_dir and idx == skip_steps and args.trainer_id == 0:
io.save_persistables(startup_exe, model_dir,
trainer_prog)
else:
for idx in six.moves.xrange(8):
data = get_data()
if idx <= skip_steps:
continue
loss, = exe.run(fetch_list=[avg_cost.name],
feed=feeder.feed(data))
if six.PY2:
print(pickle.dumps(loss.tolist()))
else:
sys.stdout.buffer.write(pickle.dumps(loss.tolist()))
else:
sys.stdout.buffer.write(pickle.dumps(np.ravel(var).tolist()))
raise Exception("save_mode must be LOCAL or DIST")
if __name__ == "__main__":

@ -75,9 +75,13 @@ def get_loss(cos_q_pt, cos_q_nt):
return avg_cost
def get_optimizer():
# SGD optimizer
optimizer = fluid.optimizer.SGD(learning_rate=base_lr)
def get_optimizer(op="sgd"):
if op.upper() == "sgd".upper():
optimizer = fluid.optimizer.SGD(learning_rate=base_lr)
elif op.upper() == "adam".upper():
optimizer = fluid.optimizer.Adam(learning_rate=base_lr)
else:
optimizer = fluid.optimizer.SGD(learning_rate=base_lr)
return optimizer
@ -237,7 +241,8 @@ class TestDistSimnetBow2x2(TestDistRunnerBase):
inference_program = fluid.default_main_program().clone()
# Optimization
opt = get_optimizer()
opt = os.getenv('OPTIMIZER', 'sgd')
opt = get_optimizer(opt)
opt.minimize(avg_cost)
# Reader

@ -43,7 +43,8 @@ class TestDistRunnerBase(object):
pserver_endpoints,
trainers,
sync_mode,
dc_asgd=False):
dc_asgd=False,
current_endpoint=None):
# NOTE: import fluid until runtime, or else forking processes will cause error.
config = fluid.DistributeTranspilerConfig()
config.enable_dc_asgd = dc_asgd
@ -53,7 +54,8 @@ class TestDistRunnerBase(object):
program=main_program,
pservers=pserver_endpoints,
trainers=trainers,
sync_mode=sync_mode)
sync_mode=sync_mode,
current_endpoint=current_endpoint)
return t
def run_pserver(self, args):

@ -33,7 +33,6 @@ class TestDistSaveLoadDense2x2(TestDistBase):
delta=1e-3,
check_error_log=False,
need_envs={}):
required_envs = {
"PATH": os.getenv("PATH", ""),
"PYTHONPATH": os.getenv("PYTHONPATH", ""),
@ -77,7 +76,77 @@ class TestDistSaveLoadDense2x2(TestDistBase):
need_envs = {
"IS_DISTRIBUTED": '0',
"IS_SPARSE": '0',
'IS_SELF_CONTAINED_LR': '1'
'IS_SELF_CONTAINED_LR': '1',
'SAVE_MODE': 'LOCAL',
}
self.check_with_place(
"dist_save_load.py",
delta=0,
check_error_log=False,
need_envs=need_envs)
class TestDistSaveLoadWithPServerStateDense2x2(TestDistBase):
def _setup_config(self):
self._sync_mode = True
self._enforce_place = "CPU"
def check_with_place(self,
model_file,
delta=1e-3,
check_error_log=False,
need_envs={}):
required_envs = {
"PATH": os.getenv("PATH", ""),
"PYTHONPATH": os.getenv("PYTHONPATH", ""),
"LD_LIBRARY_PATH": os.getenv("LD_LIBRARY_PATH", ""),
"http_proxy": ""
}
required_envs.update(need_envs)
if check_error_log:
required_envs["GLOG_v"] = "3"
required_envs["GLOG_logtostderr"] = "1"
model_dir = tempfile.mkdtemp()
save_env = {}
save_env["SAVE_MODE"] = "DIST"
save_env["SAVE"] = "1"
save_env["MODEL_DIR"] = model_dir
save_env.update(required_envs)
tr0_var_1, tr1_var_1 = self._run_cluster(model_file, save_env,
check_error_log)
load_env = {}
load_env["LOAD"] = "1"
load_env["MODEL_DIR"] = model_dir
load_env.update(required_envs)
tr0_var_2, tr1_var_2 = self._run_cluster(model_file, load_env,
check_error_log)
shutil.rmtree(model_dir)
train0_1_np = np.array(tr0_var_1)
train1_1_np = np.array(tr1_var_1)
train0_2_np = np.array(tr0_var_2)
train1_2_np = np.array(tr1_var_2)
self.assertAlmostEqual(
train0_1_np.all(), train0_2_np.all(), delta=delta)
self.assertAlmostEqual(
train1_1_np.all(), train1_2_np.all(), delta=delta)
def test_dist(self):
need_envs = {
"IS_DISTRIBUTED": '0',
"IS_SPARSE": '0',
'IS_SELF_CONTAINED_LR': '1',
'SAVE_MODE': 'DIST',
'OPTIMIZER': 'ADAM',
'SKIP_STEPS': str(np.random.randint(2, 6))
}
self.check_with_place(
"dist_save_load.py",

@ -741,21 +741,40 @@ class TestLoadSliceVar(TranspilerTest):
pserver, _ = self.get_pserver(self.pserver1_ep)
pserver2, _ = self.get_pserver(self.pserver2_ep)
self.assertTrue(pserver._slice_vars_and_attrs)
self.assertTrue(pserver2._slice_vars_and_attrs)
for idx in six.moves.xrange(len(pserver._slice_vars_and_attrs)):
self.assertEqual(pserver._slice_vars_and_attrs[idx][0],
pserver2._slice_vars_and_attrs[idx][0])
total_numel = six.moves.reduce(
lambda x, y: x * y, pserver._slice_vars_and_attrs[idx][0].shape)
self.assertEqual(
total_numel,
six.moves.reduce(lambda x, y: x * y,
pserver._slice_vars_and_attrs[idx][2].shape) +
six.moves.reduce(lambda x, y: x * y,
pserver2._slice_vars_and_attrs[idx][2].shape))
vars_ps1 = pserver._parameters_on_pservers.get_distributed_vars_by_ep(
self.pserver1_ep)
vars_ps2 = pserver._parameters_on_pservers.get_distributed_vars_by_ep(
self.pserver2_ep)
self.assertTrue(vars_ps1)
self.assertTrue(vars_ps2)
for idx in six.moves.xrange(len(vars_ps1)):
total_numel = 0
ps1_numel, ps2_numel = 0, 0
ps1_var = vars_ps1[idx]
if not ps1_var.is_slice:
total_numel = six.moves.reduce(lambda x, y: x * y,
vars_ps1[idx].origin.shape)
ps1_numel = six.moves.reduce(lambda x, y: x * y,
vars_ps1[idx].slice.shape)
else:
ps2_var = None
for var in vars_ps2:
if var.origin.name == ps1_var.origin.name:
ps2_var = var
break
total_numel = six.moves.reduce(lambda x, y: x * y,
ps1_var.origin.shape)
ps1_numel = six.moves.reduce(lambda x, y: x * y,
ps1_var.slice.shape)
ps2_numel = six.moves.reduce(lambda x, y: x * y,
ps2_var.slice.shape)
self.assertEqual(total_numel, ps1_numel + ps2_numel)
class TestNCCL2Transpile(TranspilerTest):

File diff suppressed because it is too large.