|
|
|
@ -128,9 +128,11 @@ VarHandlePtr GRPCClient::AsyncGetVar(const std::string& ep,
|
|
|
|
|
const framework::Scope& scope,
|
|
|
|
|
const std::string& var_name,
|
|
|
|
|
const std::string& out_varname,
|
|
|
|
|
const std::string& table_name,
|
|
|
|
|
int64_t time_out) {
|
|
|
|
|
return _AsyncGetVar(ep, ctx, scope, kGetRPC, var_name, out_varname,
|
|
|
|
|
"/sendrecv.SendRecvService/GetVariable", time_out);
|
|
|
|
|
"/sendrecv.SendRecvService/GetVariable", table_name,
|
|
|
|
|
time_out);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
VarHandlePtr GRPCClient::AsyncGetVarNoBarrier(
|
|
|
|
@ -142,7 +144,7 @@ VarHandlePtr GRPCClient::AsyncGetVarNoBarrier(
|
|
|
|
|
|
|
|
|
|
return _AsyncGetVar(
|
|
|
|
|
ep, ctx, scope, kGetNoBarrierRPC, var_name_no_barrier, out_varname,
|
|
|
|
|
"/sendrecv.SendRecvService/GetVariableNoBarrier", time_out);
|
|
|
|
|
"/sendrecv.SendRecvService/GetVariableNoBarrier", "", time_out);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
VarHandlePtr GRPCClient::AsyncGetMonomerVariable(
|
|
|
|
@ -150,18 +152,21 @@ VarHandlePtr GRPCClient::AsyncGetMonomerVariable(
|
|
|
|
|
const framework::Scope& scope, const std::string& var_name,
|
|
|
|
|
int64_t time_out) {
|
|
|
|
|
return _AsyncGetVar(ep, ctx, scope, kGetMonomerRPC, var_name, var_name,
|
|
|
|
|
"/sendrecv.SendRecvService/GetMonomerVariable", time_out);
|
|
|
|
|
"/sendrecv.SendRecvService/GetMonomerVariable", "",
|
|
|
|
|
time_out);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
VarHandlePtr GRPCClient::_AsyncGetVar(
|
|
|
|
|
const std::string& ep, const platform::DeviceContext& ctx,
|
|
|
|
|
const framework::Scope& scope, const std::string& method,
|
|
|
|
|
const std::string& var_name, const std::string& out_varname,
|
|
|
|
|
const std::string& rpc_path, int64_t time_out) {
|
|
|
|
|
const std::string& rpc_path, const std::string& table_name,
|
|
|
|
|
int64_t time_out) {
|
|
|
|
|
const platform::DeviceContext* p_ctx = &ctx;
|
|
|
|
|
const std::string ep_val = ep;
|
|
|
|
|
const std::string var_name_val = var_name;
|
|
|
|
|
const std::string out_varname_val = out_varname;
|
|
|
|
|
const std::string table_name_val = table_name;
|
|
|
|
|
const framework::Scope* p_scope = &scope;
|
|
|
|
|
const auto ch = GetChannel(ep_val);
|
|
|
|
|
GetProcessor* s = new GetProcessor(ch);
|
|
|
|
@ -169,32 +174,33 @@ VarHandlePtr GRPCClient::_AsyncGetVar(
|
|
|
|
|
VarHandlePtr h(new VarHandle(ep, method, out_varname_val, p_ctx, p_scope));
|
|
|
|
|
s->Prepare(h, time_out);
|
|
|
|
|
|
|
|
|
|
framework::AsyncIO(
|
|
|
|
|
[var_name_val, out_varname_val, s, method, p_ctx, h, rpc_path, this] {
|
|
|
|
|
// prepare input
|
|
|
|
|
sendrecv::VariableMessage req;
|
|
|
|
|
req.set_varname(var_name_val);
|
|
|
|
|
req.set_out_varname(out_varname_val);
|
|
|
|
|
req.set_trainer_id(trainer_id_);
|
|
|
|
|
::grpc::ByteBuffer buf;
|
|
|
|
|
RequestToByteBuffer<sendrecv::VariableMessage>(req, &buf);
|
|
|
|
|
framework::AsyncIO([var_name_val, out_varname_val, table_name_val, s, method,
|
|
|
|
|
p_ctx, h, rpc_path, this] {
|
|
|
|
|
// prepare input
|
|
|
|
|
sendrecv::VariableMessage req;
|
|
|
|
|
req.set_varname(var_name_val);
|
|
|
|
|
req.set_out_varname(out_varname_val);
|
|
|
|
|
req.set_trainer_id(trainer_id_);
|
|
|
|
|
req.set_table_name(table_name_val);
|
|
|
|
|
::grpc::ByteBuffer buf;
|
|
|
|
|
RequestToByteBuffer<sendrecv::VariableMessage>(req, &buf);
|
|
|
|
|
|
|
|
|
|
VLOG(3) << s->GetVarHandlePtr()->String() << " begin";
|
|
|
|
|
VLOG(3) << s->GetVarHandlePtr()->String() << " begin";
|
|
|
|
|
|
|
|
|
|
// stub context
|
|
|
|
|
s->response_call_back_ = ProcGetResponse;
|
|
|
|
|
// stub context
|
|
|
|
|
s->response_call_back_ = ProcGetResponse;
|
|
|
|
|
|
|
|
|
|
platform::RecordRPCEvent record_event(method);
|
|
|
|
|
platform::RecordRPCEvent record_event(method);
|
|
|
|
|
|
|
|
|
|
auto call =
|
|
|
|
|
s->stub_g_.PrepareUnaryCall(s->context_.get(), rpc_path, buf, &cq_);
|
|
|
|
|
call->StartCall();
|
|
|
|
|
call->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
|
|
|
|
|
auto call =
|
|
|
|
|
s->stub_g_.PrepareUnaryCall(s->context_.get(), rpc_path, buf, &cq_);
|
|
|
|
|
call->StartCall();
|
|
|
|
|
call->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
|
|
|
|
|
|
|
|
|
|
if (UNLIKELY(platform::IsProfileEnabled())) {
|
|
|
|
|
h->Wait();
|
|
|
|
|
}
|
|
|
|
|
});
|
|
|
|
|
if (UNLIKELY(platform::IsProfileEnabled())) {
|
|
|
|
|
h->Wait();
|
|
|
|
|
}
|
|
|
|
|
});
|
|
|
|
|
|
|
|
|
|
req_count_++;
|
|
|
|
|
|
|
|
|
|