|
|
|
@ -73,10 +73,11 @@ VarHandlePtr GRPCClient::AsyncSendVar(const std::string& ep,
|
|
|
|
|
const framework::Scope* p_scope = &scope;
|
|
|
|
|
const auto ch = GetChannel(ep_val);
|
|
|
|
|
SendProcessor* s = new SendProcessor(ch);
|
|
|
|
|
VarHandlePtr h(new VarHandle(ep, "Send", var_name_val, p_ctx, p_scope));
|
|
|
|
|
const std::string method = "SendRPC";
|
|
|
|
|
VarHandlePtr h(new VarHandle(ep, method, var_name_val, p_ctx, p_scope));
|
|
|
|
|
s->Prepare(h, time_out);
|
|
|
|
|
|
|
|
|
|
framework::AsyncIO([var_name_val, p_scope, p_ctx, s, this] {
|
|
|
|
|
framework::AsyncIO([var_name_val, p_scope, p_ctx, s, method, h, this] {
|
|
|
|
|
auto* var = p_scope->FindVar(var_name_val);
|
|
|
|
|
|
|
|
|
|
::grpc::ByteBuffer req;
|
|
|
|
@ -87,10 +88,16 @@ VarHandlePtr GRPCClient::AsyncSendVar(const std::string& ep,
|
|
|
|
|
// stub context
|
|
|
|
|
s->response_call_back_ = nullptr;
|
|
|
|
|
|
|
|
|
|
platform::RecordEvent record_event(method, p_ctx);
|
|
|
|
|
|
|
|
|
|
auto call = s->stub_g_.PrepareUnaryCall(
|
|
|
|
|
s->context_.get(), "/sendrecv.SendRecvService/SendVariable", req, &cq_);
|
|
|
|
|
call->StartCall();
|
|
|
|
|
call->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
|
|
|
|
|
|
|
|
|
|
if (UNLIKELY(platform::IsProfileEnabled())) {
|
|
|
|
|
h->Wait();
|
|
|
|
|
}
|
|
|
|
|
});
|
|
|
|
|
req_count_++;
|
|
|
|
|
|
|
|
|
@ -122,10 +129,11 @@ VarHandlePtr GRPCClient::AsyncGetVar(const std::string& ep,
|
|
|
|
|
const framework::Scope* p_scope = &scope;
|
|
|
|
|
const auto ch = GetChannel(ep_val);
|
|
|
|
|
GetProcessor* s = new GetProcessor(ch);
|
|
|
|
|
VarHandlePtr h(new VarHandle(ep, "Get", var_name_val, p_ctx, p_scope));
|
|
|
|
|
const std::string method = "GetRPC";
|
|
|
|
|
VarHandlePtr h(new VarHandle(ep, method, var_name_val, p_ctx, p_scope));
|
|
|
|
|
s->Prepare(h, time_out);
|
|
|
|
|
|
|
|
|
|
framework::AsyncIO([var_name_val, s, this] {
|
|
|
|
|
framework::AsyncIO([var_name_val, s, method, p_ctx, h, this] {
|
|
|
|
|
// prepare input
|
|
|
|
|
sendrecv::VariableMessage req;
|
|
|
|
|
req.set_varname(var_name_val);
|
|
|
|
@ -137,10 +145,16 @@ VarHandlePtr GRPCClient::AsyncGetVar(const std::string& ep,
|
|
|
|
|
// stub context
|
|
|
|
|
s->response_call_back_ = ProcGetResponse;
|
|
|
|
|
|
|
|
|
|
platform::RecordEvent record_event(method, p_ctx);
|
|
|
|
|
|
|
|
|
|
auto call = s->stub_g_.PrepareUnaryCall(
|
|
|
|
|
s->context_.get(), "/sendrecv.SendRecvService/GetVariable", buf, &cq_);
|
|
|
|
|
call->StartCall();
|
|
|
|
|
call->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
|
|
|
|
|
|
|
|
|
|
if (UNLIKELY(platform::IsProfileEnabled())) {
|
|
|
|
|
h->Wait();
|
|
|
|
|
}
|
|
|
|
|
});
|
|
|
|
|
|
|
|
|
|
req_count_++;
|
|
|
|
@ -161,12 +175,14 @@ VarHandlePtr GRPCClient::AsyncPrefetchVar(const std::string& ep,
|
|
|
|
|
const framework::Scope* p_scope = &scope;
|
|
|
|
|
const auto ch = GetChannel(ep_val);
|
|
|
|
|
GetProcessor* s = new GetProcessor(ch);
|
|
|
|
|
VarHandlePtr h(
|
|
|
|
|
new VarHandle(ep, "Prefetch", out_var_name_val, p_ctx, p_scope));
|
|
|
|
|
|
|
|
|
|
const std::string method = "PrefetchRPC";
|
|
|
|
|
|
|
|
|
|
VarHandlePtr h(new VarHandle(ep, method, out_var_name_val, p_ctx, p_scope));
|
|
|
|
|
s->Prepare(h, time_out);
|
|
|
|
|
|
|
|
|
|
framework::AsyncIO([in_var_name_val, out_var_name_val, ep_val, p_scope, p_ctx,
|
|
|
|
|
s, this] {
|
|
|
|
|
s, method, h, this] {
|
|
|
|
|
auto* var = p_scope->FindVar(in_var_name_val);
|
|
|
|
|
|
|
|
|
|
::grpc::ByteBuffer req;
|
|
|
|
@ -177,11 +193,17 @@ VarHandlePtr GRPCClient::AsyncPrefetchVar(const std::string& ep,
|
|
|
|
|
// stub context
|
|
|
|
|
s->response_call_back_ = ProcGetResponse;
|
|
|
|
|
|
|
|
|
|
platform::RecordEvent record_event(method, p_ctx);
|
|
|
|
|
|
|
|
|
|
auto call = s->stub_g_.PrepareUnaryCall(
|
|
|
|
|
s->context_.get(), "/sendrecv.SendRecvService/PrefetchVariable", req,
|
|
|
|
|
&cq_);
|
|
|
|
|
call->StartCall();
|
|
|
|
|
call->Finish(&s->reply_, &s->status_, static_cast<void*>(s));
|
|
|
|
|
|
|
|
|
|
if (UNLIKELY(platform::IsProfileEnabled())) {
|
|
|
|
|
h->Wait();
|
|
|
|
|
}
|
|
|
|
|
});
|
|
|
|
|
|
|
|
|
|
req_count_++;
|
|
|
|
@ -193,15 +215,24 @@ VarHandlePtr GRPCClient::AsyncSendBatchBarrier(const std::string& ep,
|
|
|
|
|
const auto ch = GetChannel(ep);
|
|
|
|
|
|
|
|
|
|
BatchBarrierProcessor* s = new BatchBarrierProcessor(ch);
|
|
|
|
|
VarHandlePtr h(new VarHandle(ep, "BatchBarrier", BATCH_BARRIER_MESSAGE,
|
|
|
|
|
nullptr, nullptr));
|
|
|
|
|
const std::string method = "BatchBarrierRPC";
|
|
|
|
|
VarHandlePtr h(
|
|
|
|
|
new VarHandle(ep, method, BATCH_BARRIER_MESSAGE, nullptr, nullptr));
|
|
|
|
|
s->Prepare(h, time_out);
|
|
|
|
|
|
|
|
|
|
sendrecv::VariableMessage req;
|
|
|
|
|
req.set_varname(BATCH_BARRIER_MESSAGE);
|
|
|
|
|
|
|
|
|
|
platform::RecordEvent record_event(method, nullptr);
|
|
|
|
|
|
|
|
|
|
auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_);
|
|
|
|
|
rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
|
|
|
|
|
req_count_++;
|
|
|
|
|
|
|
|
|
|
if (UNLIKELY(platform::IsProfileEnabled())) {
|
|
|
|
|
h->Wait();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return h;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -209,15 +240,24 @@ VarHandlePtr GRPCClient::AsyncSendFetchBarrier(const std::string& ep,
|
|
|
|
|
int64_t time_out) {
|
|
|
|
|
const auto ch = GetChannel(ep);
|
|
|
|
|
FetchBarrierProcessor* s = new FetchBarrierProcessor(ch);
|
|
|
|
|
VarHandlePtr h(new VarHandle(ep, "FetchBarrier", FETCH_BARRIER_MESSAGE,
|
|
|
|
|
nullptr, nullptr));
|
|
|
|
|
const std::string method = "FetchBarrierRPC";
|
|
|
|
|
VarHandlePtr h(
|
|
|
|
|
new VarHandle(ep, method, FETCH_BARRIER_MESSAGE, nullptr, nullptr));
|
|
|
|
|
s->Prepare(h, time_out);
|
|
|
|
|
|
|
|
|
|
sendrecv::VariableMessage req;
|
|
|
|
|
req.set_varname(FETCH_BARRIER_MESSAGE);
|
|
|
|
|
|
|
|
|
|
platform::RecordEvent record_event(method, nullptr);
|
|
|
|
|
|
|
|
|
|
auto rpc = s->stub_->AsyncGetVariable(s->context_.get(), req, &cq_);
|
|
|
|
|
rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
|
|
|
|
|
req_count_++;
|
|
|
|
|
|
|
|
|
|
if (UNLIKELY(platform::IsProfileEnabled())) {
|
|
|
|
|
h->Wait();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return h;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -226,15 +266,23 @@ VarHandlePtr GRPCClient::AsyncSendComplete(const std::string& ep,
|
|
|
|
|
const auto ch = GetChannel(ep);
|
|
|
|
|
|
|
|
|
|
BatchBarrierProcessor* s = new BatchBarrierProcessor(ch);
|
|
|
|
|
VarHandlePtr h(
|
|
|
|
|
new VarHandle(ep, "SendComplete", COMPLETE_MESSAGE, nullptr, nullptr));
|
|
|
|
|
const std::string method = "SendCompleteRPC";
|
|
|
|
|
VarHandlePtr h(new VarHandle(ep, method, COMPLETE_MESSAGE, nullptr, nullptr));
|
|
|
|
|
s->Prepare(h, time_out);
|
|
|
|
|
|
|
|
|
|
sendrecv::VariableMessage req;
|
|
|
|
|
req.set_varname(COMPLETE_MESSAGE);
|
|
|
|
|
|
|
|
|
|
platform::RecordEvent record_event(method, nullptr);
|
|
|
|
|
|
|
|
|
|
auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_);
|
|
|
|
|
rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
|
|
|
|
|
req_count_++;
|
|
|
|
|
|
|
|
|
|
if (UNLIKELY(platform::IsProfileEnabled())) {
|
|
|
|
|
h->Wait();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return h;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -244,17 +292,27 @@ VarHandlePtr GRPCClient::AsyncCheckpointNotify(const std::string& ep,
|
|
|
|
|
const auto ch = GetChannel(ep);
|
|
|
|
|
|
|
|
|
|
CheckpointNotifyProcessor* s = new CheckpointNotifyProcessor(ch);
|
|
|
|
|
VarHandlePtr h(new VarHandle(ep, "CheckPointNotify", CHECKPOINT_SAVE_MESSAGE,
|
|
|
|
|
nullptr, nullptr));
|
|
|
|
|
|
|
|
|
|
const std::string method = "CheckPointNotifyRPC";
|
|
|
|
|
|
|
|
|
|
VarHandlePtr h(
|
|
|
|
|
new VarHandle(ep, method, CHECKPOINT_SAVE_MESSAGE, nullptr, nullptr));
|
|
|
|
|
s->Prepare(h, time_out);
|
|
|
|
|
|
|
|
|
|
sendrecv::VariableMessage req;
|
|
|
|
|
req.set_varname(CHECKPOINT_SAVE_MESSAGE);
|
|
|
|
|
req.set_out_varname(dir);
|
|
|
|
|
|
|
|
|
|
platform::RecordEvent record_event(method, nullptr);
|
|
|
|
|
|
|
|
|
|
auto rpc = s->stub_->AsyncCheckpointNotify(s->context_.get(), req, &cq_);
|
|
|
|
|
rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
|
|
|
|
|
req_count_++;
|
|
|
|
|
|
|
|
|
|
if (UNLIKELY(platform::IsProfileEnabled())) {
|
|
|
|
|
h->Wait();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return h;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -273,6 +331,7 @@ void GRPCClient::Proceed() {
|
|
|
|
|
BaseProcessor* c = static_cast<BaseProcessor*>(tag);
|
|
|
|
|
GPR_ASSERT(ok);
|
|
|
|
|
PADDLE_ENFORCE(c);
|
|
|
|
|
|
|
|
|
|
if (c->status_.ok()) {
|
|
|
|
|
VLOG(3) << c->GetVarHandlePtr()->String() << " process";
|
|
|
|
|
c->Process();
|
|
|
|
|