|
|
|
@ -74,7 +74,7 @@ VarHandlePtr GRPCClient::AsyncSendVar(const std::string& ep,
|
|
|
|
|
const framework::Scope* p_scope = &scope;
|
|
|
|
|
const auto ch = GetChannel(ep_val);
|
|
|
|
|
SendProcessor* s = new SendProcessor(ch);
|
|
|
|
|
const std::string method = "SendRPC";
|
|
|
|
|
const std::string method = kSendRPC;
|
|
|
|
|
VarHandlePtr h(new VarHandle(ep, method, var_name_val, p_ctx, p_scope));
|
|
|
|
|
s->Prepare(h, time_out);
|
|
|
|
|
|
|
|
|
@ -107,7 +107,7 @@ VarHandlePtr GRPCClient::AsyncSendVar(const std::string& ep,
|
|
|
|
|
|
|
|
|
|
void ProcGetResponse(const VarHandle& var_h,
|
|
|
|
|
const ::grpc::ByteBuffer& ret_msg) {
|
|
|
|
|
VLOG(100) << "ProcGetResponse";
|
|
|
|
|
VLOG(4) << "ProcGetResponse";
|
|
|
|
|
framework::Variable* outvar = nullptr;
|
|
|
|
|
// get response's trainer_id is not used
|
|
|
|
|
int trainer_id;
|
|
|
|
@ -127,59 +127,74 @@ VarHandlePtr GRPCClient::AsyncGetVar(const std::string& ep,
|
|
|
|
|
const platform::DeviceContext& ctx,
|
|
|
|
|
const framework::Scope& scope,
|
|
|
|
|
const std::string& var_name,
|
|
|
|
|
const std::string& out_varname,
|
|
|
|
|
int64_t time_out) {
|
|
|
|
|
return _AsyncGetVar(ep, ctx, scope, var_name,
|
|
|
|
|
return _AsyncGetVar(ep, ctx, scope, kGetRPC, var_name, out_varname,
|
|
|
|
|
"/sendrecv.SendRecvService/GetVariable", time_out);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
VarHandlePtr GRPCClient::AsyncGetVarNoBarrier(
|
|
|
|
|
const std::string& ep, const platform::DeviceContext& ctx,
|
|
|
|
|
const framework::Scope& scope, const std::string& var_name,
|
|
|
|
|
const std::string& out_varname, int64_t time_out) {
|
|
|
|
|
std::string var_name_no_barrier =
|
|
|
|
|
string::Sprintf("%s%s", var_name, WITHOUT_BARRIER_MESSAGE);
|
|
|
|
|
|
|
|
|
|
return _AsyncGetVar(
|
|
|
|
|
ep, ctx, scope, kGetNoBarrierRPC, var_name_no_barrier, out_varname,
|
|
|
|
|
"/sendrecv.SendRecvService/GetVariableNoBarrier", time_out);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
VarHandlePtr GRPCClient::AsyncGetMonomerVariable(
|
|
|
|
|
const std::string& ep, const platform::DeviceContext& ctx,
|
|
|
|
|
const framework::Scope& scope, const std::string& var_name,
|
|
|
|
|
int64_t time_out) {
|
|
|
|
|
return _AsyncGetVar(ep, ctx, scope, var_name,
|
|
|
|
|
return _AsyncGetVar(ep, ctx, scope, kGetMonomerRPC, var_name, var_name,
|
|
|
|
|
"/sendrecv.SendRecvService/GetMonomerVariable", time_out);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
VarHandlePtr GRPCClient::_AsyncGetVar(const std::string& ep,
|
|
|
|
|
const platform::DeviceContext& ctx,
|
|
|
|
|
const framework::Scope& scope,
|
|
|
|
|
const std::string& var_name,
|
|
|
|
|
const std::string& rpc_path,
|
|
|
|
|
int64_t time_out) {
|
|
|
|
|
VarHandlePtr GRPCClient::_AsyncGetVar(
|
|
|
|
|
const std::string& ep, const platform::DeviceContext& ctx,
|
|
|
|
|
const framework::Scope& scope, const std::string& method,
|
|
|
|
|
const std::string& var_name, const std::string& out_varname,
|
|
|
|
|
const std::string& rpc_path, int64_t time_out) {
|
|
|
|
|
const platform::DeviceContext* p_ctx = &ctx;
|
|
|
|
|
const std::string ep_val = ep;
|
|
|
|
|
const std::string var_name_val = var_name;
|
|
|
|
|
const std::string out_varname_val = out_varname;
|
|
|
|
|
const framework::Scope* p_scope = &scope;
|
|
|
|
|
const auto ch = GetChannel(ep_val);
|
|
|
|
|
GetProcessor* s = new GetProcessor(ch);
|
|
|
|
|
const std::string method = "GetRPC";
|
|
|
|
|
VarHandlePtr h(new VarHandle(ep, method, var_name_val, p_ctx, p_scope));
|
|
|
|
|
|
|
|
|
|
VarHandlePtr h(new VarHandle(ep, method, out_varname_val, p_ctx, p_scope));
|
|
|
|
|
s->Prepare(h, time_out);
|
|
|
|
|
|
|
|
|
|
framework::AsyncIO([var_name_val, s, method, p_ctx, h, rpc_path, this] {
|
|
|
|
|
// prepare input
|
|
|
|
|
sendrecv::VariableMessage req;
|
|
|
|
|
req.set_varname(var_name_val);
|
|
|
|
|
req.set_trainer_id(trainer_id_);
|
|
|
|
|
::grpc::ByteBuffer buf;
|
|
|
|
|
RequestToByteBuffer<sendrecv::VariableMessage>(req, &buf);
|
|
|
|
|
framework::AsyncIO(
|
|
|
|
|
[var_name_val, out_varname_val, s, method, p_ctx, h, rpc_path, this] {
|
|
|
|
|
// prepare input
|
|
|
|
|
sendrecv::VariableMessage req;
|
|
|
|
|
req.set_varname(var_name_val);
|
|
|
|
|
req.set_out_varname(out_varname_val);
|
|
|
|
|
req.set_trainer_id(trainer_id_);
|
|
|
|
|
::grpc::ByteBuffer buf;
|
|
|
|
|
RequestToByteBuffer<sendrecv::VariableMessage>(req, &buf);
|
|
|
|
|
|
|
|
|
|
VLOG(3) << s->GetVarHandlePtr()->String() << " begin";
|
|
|
|
|
VLOG(3) << s->GetVarHandlePtr()->String() << " begin";
|
|
|
|
|
|
|
|
|
|
// stub context
|
|
|
|
|
s->response_call_back_ = ProcGetResponse;
|
|
|
|
|
// stub context
|
|
|
|
|
s->response_call_back_ = ProcGetResponse;
|
|
|
|
|
|
|
|
|
|
platform::RecordRPCEvent record_event(method, p_ctx);
|
|
|
|
|
platform::RecordRPCEvent record_event(method, p_ctx);
|
|
|
|
|
|
|
|
|
|
auto call =
|
|
|
|
|
s->stub_g_.PrepareUnaryCall(s->context_.get(), rpc_path, buf, &cq_);
|
|
|
|
|
call->StartCall();
|
|
|
|
|
call->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
|
|
|
|
|
auto call =
|
|
|
|
|
s->stub_g_.PrepareUnaryCall(s->context_.get(), rpc_path, buf, &cq_);
|
|
|
|
|
call->StartCall();
|
|
|
|
|
call->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
|
|
|
|
|
|
|
|
|
|
if (UNLIKELY(platform::IsProfileEnabled())) {
|
|
|
|
|
h->Wait();
|
|
|
|
|
}
|
|
|
|
|
});
|
|
|
|
|
if (UNLIKELY(platform::IsProfileEnabled())) {
|
|
|
|
|
h->Wait();
|
|
|
|
|
}
|
|
|
|
|
});
|
|
|
|
|
|
|
|
|
|
req_count_++;
|
|
|
|
|
|
|
|
|
@ -202,7 +217,7 @@ VarHandlePtr GRPCClient::AsyncPrefetchVar(const std::string& ep,
|
|
|
|
|
const auto ch = GetChannel(ep_val);
|
|
|
|
|
GetProcessor* s = new GetProcessor(ch);
|
|
|
|
|
|
|
|
|
|
const std::string method = "PrefetchRPC";
|
|
|
|
|
const std::string method = kPrefetchRPC;
|
|
|
|
|
|
|
|
|
|
VarHandlePtr h(new VarHandle(ep, method, out_var_name_val, p_ctx, p_scope));
|
|
|
|
|
s->Prepare(h, time_out);
|
|
|
|
@ -242,7 +257,7 @@ VarHandlePtr GRPCClient::AsyncSendBatchBarrier(const std::string& ep,
|
|
|
|
|
const auto ch = GetChannel(ep);
|
|
|
|
|
|
|
|
|
|
BatchBarrierProcessor* s = new BatchBarrierProcessor(ch);
|
|
|
|
|
const std::string method = "BatchBarrierRPC";
|
|
|
|
|
const std::string method = kBatchBarrierRPC;
|
|
|
|
|
VarHandlePtr h(
|
|
|
|
|
new VarHandle(ep, method, BATCH_BARRIER_MESSAGE, nullptr, nullptr));
|
|
|
|
|
s->Prepare(h, time_out);
|
|
|
|
@ -267,7 +282,7 @@ VarHandlePtr GRPCClient::AsyncSendFetchBarrier(const std::string& ep,
|
|
|
|
|
int64_t time_out) {
|
|
|
|
|
const auto ch = GetChannel(ep);
|
|
|
|
|
FetchBarrierProcessor* s = new FetchBarrierProcessor(ch);
|
|
|
|
|
const std::string method = "FetchBarrierRPC";
|
|
|
|
|
const std::string method = kFetchBarrierRPC;
|
|
|
|
|
VarHandlePtr h(
|
|
|
|
|
new VarHandle(ep, method, FETCH_BARRIER_MESSAGE, nullptr, nullptr));
|
|
|
|
|
s->Prepare(h, time_out);
|
|
|
|
@ -293,7 +308,7 @@ VarHandlePtr GRPCClient::AsyncGetMonomerBarrier(const std::string& ep,
|
|
|
|
|
int64_t time_out) {
|
|
|
|
|
const auto ch = GetChannel(ep);
|
|
|
|
|
BatchBarrierProcessor* s = new BatchBarrierProcessor(ch);
|
|
|
|
|
const std::string method = "SendMonomerFetchBarrierRPC";
|
|
|
|
|
const std::string method = kSendMonomerFetchBarrierRPC;
|
|
|
|
|
VarHandlePtr h(new VarHandle(ep, method, var_name, nullptr, nullptr));
|
|
|
|
|
s->Prepare(h, time_out);
|
|
|
|
|
|
|
|
|
@ -320,7 +335,7 @@ VarHandlePtr GRPCClient::AsyncSendComplete(const std::string& ep,
|
|
|
|
|
const auto ch = GetChannel(ep);
|
|
|
|
|
|
|
|
|
|
BatchBarrierProcessor* s = new BatchBarrierProcessor(ch);
|
|
|
|
|
const std::string method = "SendCompleteRPC";
|
|
|
|
|
const std::string method = kSendCompleteRPC;
|
|
|
|
|
VarHandlePtr h(new VarHandle(ep, method, COMPLETE_MESSAGE, nullptr, nullptr));
|
|
|
|
|
s->Prepare(h, time_out);
|
|
|
|
|
|
|
|
|
@ -347,7 +362,7 @@ VarHandlePtr GRPCClient::AsyncCheckpointNotify(const std::string& ep,
|
|
|
|
|
|
|
|
|
|
CheckpointNotifyProcessor* s = new CheckpointNotifyProcessor(ch);
|
|
|
|
|
|
|
|
|
|
const std::string method = "CheckPointNotifyRPC";
|
|
|
|
|
const std::string method = kCheckPointNotifyRPC;
|
|
|
|
|
|
|
|
|
|
VarHandlePtr h(
|
|
|
|
|
new VarHandle(ep, method, CHECKPOINT_SAVE_MESSAGE, nullptr, nullptr));
|
|
|
|
|