|
|
@ -62,7 +62,7 @@ VarHandlePtr BRPCClient::AsyncSendVar(const std::string& ep,
|
|
|
|
const std::string var_name_val = var_name;
|
|
|
|
const std::string var_name_val = var_name;
|
|
|
|
const framework::Scope* p_scope = &scope;
|
|
|
|
const framework::Scope* p_scope = &scope;
|
|
|
|
const auto ch_ptr = GetChannel(ep_val);
|
|
|
|
const auto ch_ptr = GetChannel(ep_val);
|
|
|
|
const std::string method = "SendRPC";
|
|
|
|
const std::string method = kSendRPC;
|
|
|
|
VarHandlePtr var_h(new VarHandle(ep, method, var_name_val, p_ctx, p_scope));
|
|
|
|
VarHandlePtr var_h(new VarHandle(ep, method, var_name_val, p_ctx, p_scope));
|
|
|
|
|
|
|
|
|
|
|
|
framework::AsyncIO([=] {
|
|
|
|
framework::AsyncIO([=] {
|
|
|
@ -156,15 +156,18 @@ VarHandlePtr BRPCClient::_AsyncGetVar(const std::string& ep,
|
|
|
|
const platform::DeviceContext& ctx,
|
|
|
|
const platform::DeviceContext& ctx,
|
|
|
|
const framework::Scope& scope,
|
|
|
|
const framework::Scope& scope,
|
|
|
|
const std::string& var_name,
|
|
|
|
const std::string& var_name,
|
|
|
|
|
|
|
|
const std::string& out_var_name,
|
|
|
|
const std::string& method_name,
|
|
|
|
const std::string& method_name,
|
|
|
|
int64_t time_out) {
|
|
|
|
int64_t time_out) {
|
|
|
|
const platform::DeviceContext* p_ctx = &ctx;
|
|
|
|
const platform::DeviceContext* p_ctx = &ctx;
|
|
|
|
const std::string ep_val = ep;
|
|
|
|
const std::string ep_val = ep;
|
|
|
|
const std::string var_name_val = var_name;
|
|
|
|
const std::string var_name_val = var_name;
|
|
|
|
|
|
|
|
const std::string out_varname_val = out_var_name;
|
|
|
|
const framework::Scope* p_scope = &scope;
|
|
|
|
const framework::Scope* p_scope = &scope;
|
|
|
|
const auto ch_ptr = GetChannel(ep_val);
|
|
|
|
const auto ch_ptr = GetChannel(ep_val);
|
|
|
|
const std::string method = "GetRPC";
|
|
|
|
const std::string method = kGetRPC;
|
|
|
|
VarHandlePtr var_h(new VarHandle(ep, method, var_name_val, p_ctx, p_scope));
|
|
|
|
VarHandlePtr var_h(
|
|
|
|
|
|
|
|
new VarHandle(ep, method, out_varname_val, p_ctx, p_scope));
|
|
|
|
|
|
|
|
|
|
|
|
framework::AsyncIO([=] {
|
|
|
|
framework::AsyncIO([=] {
|
|
|
|
auto ch_ctx = ch_ptr->Pop();
|
|
|
|
auto ch_ctx = ch_ptr->Pop();
|
|
|
@ -175,6 +178,7 @@ VarHandlePtr BRPCClient::_AsyncGetVar(const std::string& ep,
|
|
|
|
|
|
|
|
|
|
|
|
sendrecv::VariableMessage req;
|
|
|
|
sendrecv::VariableMessage req;
|
|
|
|
req.set_varname(var_name_val);
|
|
|
|
req.set_varname(var_name_val);
|
|
|
|
|
|
|
|
req.set_out_varname(out_varname_val);
|
|
|
|
req.set_trainer_id(trainer_id_);
|
|
|
|
req.set_trainer_id(trainer_id_);
|
|
|
|
|
|
|
|
|
|
|
|
google::protobuf::Closure* done = brpc::NewCallback(
|
|
|
|
google::protobuf::Closure* done = brpc::NewCallback(
|
|
|
@ -182,8 +186,10 @@ VarHandlePtr BRPCClient::_AsyncGetVar(const std::string& ep,
|
|
|
|
|
|
|
|
|
|
|
|
platform::RecordRPCEvent record_event(method, p_ctx);
|
|
|
|
platform::RecordRPCEvent record_event(method, p_ctx);
|
|
|
|
|
|
|
|
|
|
|
|
if (method_name == "GetMonomerVariable") {
|
|
|
|
if (method_name == kGetMonomerRPC) {
|
|
|
|
ch_ctx->stub->GetMonomerVariable(cntl, &req, response, done);
|
|
|
|
ch_ctx->stub->GetMonomerVariable(cntl, &req, response, done);
|
|
|
|
|
|
|
|
} else if (method_name == kGetNoBarrierRPC) {
|
|
|
|
|
|
|
|
ch_ctx->stub->GetVariableNoBarrier(cntl, &req, response, done);
|
|
|
|
} else {
|
|
|
|
} else {
|
|
|
|
ch_ctx->stub->GetVariable(cntl, &req, response, done);
|
|
|
|
ch_ctx->stub->GetVariable(cntl, &req, response, done);
|
|
|
|
}
|
|
|
|
}
|
|
|
@ -198,25 +204,39 @@ VarHandlePtr BRPCClient::_AsyncGetVar(const std::string& ep,
|
|
|
|
return var_h;
|
|
|
|
return var_h;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
VarHandlePtr BRPCClient::AsyncGetVarNoBarrier(
|
|
|
|
|
|
|
|
const std::string& ep, const platform::DeviceContext& ctx,
|
|
|
|
|
|
|
|
const framework::Scope& scope, const std::string& var_name,
|
|
|
|
|
|
|
|
const std::string& out_var_name, int64_t time_out) {
|
|
|
|
|
|
|
|
std::string var_name_no_barrier =
|
|
|
|
|
|
|
|
string::Sprintf("%s%s", var_name, WITHOUT_BARRIER_MESSAGE);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return _AsyncGetVar(ep, ctx, scope, var_name_no_barrier, out_var_name,
|
|
|
|
|
|
|
|
kGetNoBarrierRPC, time_out);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
VarHandlePtr BRPCClient::AsyncGetMonomerVariable(
|
|
|
|
VarHandlePtr BRPCClient::AsyncGetMonomerVariable(
|
|
|
|
const std::string& ep, const platform::DeviceContext& ctx,
|
|
|
|
const std::string& ep, const platform::DeviceContext& ctx,
|
|
|
|
const framework::Scope& scope, const std::string& var_name,
|
|
|
|
const framework::Scope& scope, const std::string& var_name,
|
|
|
|
int64_t time_out) {
|
|
|
|
int64_t time_out) {
|
|
|
|
return _AsyncGetVar(ep, ctx, scope, var_name, "GetMonomerVariable", time_out);
|
|
|
|
return _AsyncGetVar(ep, ctx, scope, var_name, var_name, kGetMonomerRPC,
|
|
|
|
|
|
|
|
time_out);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
VarHandlePtr BRPCClient::AsyncGetMonomerBarrier(const std::string& ep,
|
|
|
|
VarHandlePtr BRPCClient::AsyncGetMonomerBarrier(const std::string& ep,
|
|
|
|
const std::string& var_name,
|
|
|
|
const std::string& var_name,
|
|
|
|
int64_t time_out) {
|
|
|
|
int64_t time_out) {
|
|
|
|
return AsyncSendMessage(ep, "GetMonomerBarrier", var_name, time_out);
|
|
|
|
return AsyncSendMessage(ep, kSendMonomerFetchBarrierRPC, var_name, time_out);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
VarHandlePtr BRPCClient::AsyncGetVar(const std::string& ep,
|
|
|
|
VarHandlePtr BRPCClient::AsyncGetVar(const std::string& ep,
|
|
|
|
const platform::DeviceContext& ctx,
|
|
|
|
const platform::DeviceContext& ctx,
|
|
|
|
const framework::Scope& scope,
|
|
|
|
const framework::Scope& scope,
|
|
|
|
const std::string& var_name,
|
|
|
|
const std::string& var_name,
|
|
|
|
|
|
|
|
const std::string& out_var_name,
|
|
|
|
int64_t time_out) {
|
|
|
|
int64_t time_out) {
|
|
|
|
return _AsyncGetVar(ep, ctx, scope, var_name, "GetVariable", time_out);
|
|
|
|
return _AsyncGetVar(ep, ctx, scope, var_name, out_var_name, kGetRPC,
|
|
|
|
|
|
|
|
time_out);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
VarHandlePtr BRPCClient::AsyncPrefetchVar(const std::string& ep,
|
|
|
|
VarHandlePtr BRPCClient::AsyncPrefetchVar(const std::string& ep,
|
|
|
@ -234,7 +254,7 @@ VarHandlePtr BRPCClient::AsyncPrefetchVar(const std::string& ep,
|
|
|
|
const framework::Scope* p_scope = &scope;
|
|
|
|
const framework::Scope* p_scope = &scope;
|
|
|
|
const auto ch_ptr = GetChannel(ep_val);
|
|
|
|
const auto ch_ptr = GetChannel(ep_val);
|
|
|
|
|
|
|
|
|
|
|
|
const std::string method = "PrefetchRPC";
|
|
|
|
const std::string method = kPrefetchRPC;
|
|
|
|
|
|
|
|
|
|
|
|
VarHandlePtr var_h(
|
|
|
|
VarHandlePtr var_h(
|
|
|
|
new VarHandle(ep, method, out_var_name_val, p_ctx, p_scope));
|
|
|
|
new VarHandle(ep, method, out_var_name_val, p_ctx, p_scope));
|
|
|
@ -270,7 +290,7 @@ VarHandlePtr BRPCClient::AsyncPrefetchVar(const std::string& ep,
|
|
|
|
|
|
|
|
|
|
|
|
VarHandlePtr BRPCClient::AsyncSendBatchBarrier(const std::string& ep,
|
|
|
|
VarHandlePtr BRPCClient::AsyncSendBatchBarrier(const std::string& ep,
|
|
|
|
int64_t time_out) {
|
|
|
|
int64_t time_out) {
|
|
|
|
return AsyncSendMessage(ep, "BatchBarrierRPC", BATCH_BARRIER_MESSAGE,
|
|
|
|
return AsyncSendMessage(ep, kBatchBarrierRPC, BATCH_BARRIER_MESSAGE,
|
|
|
|
time_out);
|
|
|
|
time_out);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
@ -286,7 +306,7 @@ VarHandlePtr BRPCClient::AsyncSendFetchBarrier(const std::string& ep,
|
|
|
|
sendrecv::VariableMessage req;
|
|
|
|
sendrecv::VariableMessage req;
|
|
|
|
req.set_varname(FETCH_BARRIER_MESSAGE);
|
|
|
|
req.set_varname(FETCH_BARRIER_MESSAGE);
|
|
|
|
|
|
|
|
|
|
|
|
const std::string method = "FetchBarrierRPC";
|
|
|
|
const std::string method = kFetchBarrierRPC;
|
|
|
|
// var handle
|
|
|
|
// var handle
|
|
|
|
VarHandlePtr var_h(
|
|
|
|
VarHandlePtr var_h(
|
|
|
|
new VarHandle(ep, method, FETCH_BARRIER_MESSAGE, nullptr, nullptr));
|
|
|
|
new VarHandle(ep, method, FETCH_BARRIER_MESSAGE, nullptr, nullptr));
|
|
|
@ -367,7 +387,7 @@ ChannelQueuePtr BRPCClient::GetChannel(const std::string& ep) {
|
|
|
|
|
|
|
|
|
|
|
|
VarHandlePtr BRPCClient::AsyncSendComplete(const std::string& ep,
|
|
|
|
VarHandlePtr BRPCClient::AsyncSendComplete(const std::string& ep,
|
|
|
|
int64_t time_out) {
|
|
|
|
int64_t time_out) {
|
|
|
|
return AsyncSendMessage(ep, "SendCompleteRPC", COMPLETE_MESSAGE, time_out);
|
|
|
|
return AsyncSendMessage(ep, kSendCompleteRPC, COMPLETE_MESSAGE, time_out);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
void BRPCClient::SendComplete() {
|
|
|
|
void BRPCClient::SendComplete() {
|
|
|
@ -394,9 +414,9 @@ VarHandlePtr BRPCClient::AsyncSendVarMessage(
|
|
|
|
google::protobuf::Closure* done = brpc::NewCallback(
|
|
|
|
google::protobuf::Closure* done = brpc::NewCallback(
|
|
|
|
&HandleSendResponse, cntl, response, var_h, ch_ptr, ch_ctx, this);
|
|
|
|
&HandleSendResponse, cntl, response, var_h, ch_ptr, ch_ctx, this);
|
|
|
|
|
|
|
|
|
|
|
|
if (method_name == "CheckPointNotifyRPC") {
|
|
|
|
if (method_name == kCheckPointNotifyRPC) {
|
|
|
|
ch_ctx->stub->CheckpointNotify(cntl, &req, response, done);
|
|
|
|
ch_ctx->stub->CheckpointNotify(cntl, &req, response, done);
|
|
|
|
} else if (method_name == "GetMonomerBarrier") {
|
|
|
|
} else if (method_name == kSendMonomerFetchBarrierRPC) {
|
|
|
|
ch_ctx->stub->GetMonomerBarrier(cntl, &req, response, done);
|
|
|
|
ch_ctx->stub->GetMonomerBarrier(cntl, &req, response, done);
|
|
|
|
} else {
|
|
|
|
} else {
|
|
|
|
ch_ctx->stub->SendVariable(cntl, &req, response, done);
|
|
|
|
ch_ctx->stub->SendVariable(cntl, &req, response, done);
|
|
|
|