|
|
|
@ -23,6 +23,7 @@ limitations under the License. */
|
|
|
|
|
#include "paddle/fluid/platform/port.h"
|
|
|
|
|
#include "paddle/fluid/platform/profiler.h"
|
|
|
|
|
|
|
|
|
|
DEFINE_int32(rpc_client_threads, 2, "");
|
|
|
|
|
DECLARE_bool(rpc_disable_reuse_port);
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
@ -32,10 +33,11 @@ namespace distributed {
|
|
|
|
|
void GRPCClient::InitImpl() {
|
|
|
|
|
// start the client process thread
|
|
|
|
|
// TODO(wuyi): can make this in a threadpool
|
|
|
|
|
PADDLE_ENFORCE_EQ(client_thread_ == nullptr, true,
|
|
|
|
|
platform::errors::PreconditionNotMet(
|
|
|
|
|
"please not re init proceed thread"));
|
|
|
|
|
client_thread_.reset(new std::thread(std::bind(&GRPCClient::Proceed, this)));
|
|
|
|
|
client_threads_.resize(FLAGS_rpc_client_threads);
|
|
|
|
|
for (int i = 0; i < FLAGS_rpc_client_threads; i++) {
|
|
|
|
|
client_threads_[i].reset(
|
|
|
|
|
new std::thread(std::bind(&GRPCClient::Proceed, this)));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void GRPCClient::SendComplete() {
|
|
|
|
@ -62,7 +64,8 @@ GRPCClient::~GRPCClient() {
|
|
|
|
|
}
|
|
|
|
|
channels_.clear();
|
|
|
|
|
}
|
|
|
|
|
client_thread_->join();
|
|
|
|
|
for (size_t i = 0; i < client_threads_.size(); i++)
|
|
|
|
|
client_threads_[i]->join();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
VarHandlePtr GRPCClient::AsyncSendVar(const std::string& ep,
|
|
|
|
@ -84,7 +87,7 @@ VarHandlePtr GRPCClient::AsyncSendVar(const std::string& ep,
|
|
|
|
|
VarHandlePtr h(new VarHandle(ep, method, var_name_val, p_ctx, p_scope));
|
|
|
|
|
s->Prepare(h, time_out);
|
|
|
|
|
|
|
|
|
|
framework::AsyncIO([var_name_val, p_scope, p_ctx, s, method, h, this] {
|
|
|
|
|
framework::Async([var_name_val, p_scope, p_ctx, s, method, h, this] {
|
|
|
|
|
auto* var = p_scope->FindVar(var_name_val);
|
|
|
|
|
|
|
|
|
|
::grpc::ByteBuffer req;
|
|
|
|
@ -206,8 +209,8 @@ VarHandlePtr GRPCClient::_AsyncGetVar(
|
|
|
|
|
VarHandlePtr h(new VarHandle(ep, method, out_varname_val, p_ctx, p_scope));
|
|
|
|
|
s->Prepare(h, time_out);
|
|
|
|
|
|
|
|
|
|
framework::AsyncIO([var_name_val, out_varname_val, table_name_val, s,
|
|
|
|
|
method, p_ctx, h, rpc_path, this] {
|
|
|
|
|
framework::Async([var_name_val, out_varname_val, table_name_val, s, method,
|
|
|
|
|
p_ctx, h, rpc_path, this] {
|
|
|
|
|
// prepare input
|
|
|
|
|
sendrecv::VariableMessage req;
|
|
|
|
|
req.set_varname(var_name_val);
|
|
|
|
@ -273,31 +276,29 @@ VarHandlePtr GRPCClient::AsyncPrefetchVar(const std::string& ep,
|
|
|
|
|
VarHandlePtr h(new VarHandle(ep, method, out_var_name_val, p_ctx, p_scope));
|
|
|
|
|
s->Prepare(h, kPrefetchTimeout);
|
|
|
|
|
|
|
|
|
|
framework::AsyncIO([in_var_name_val, out_var_name_val, ep_val, p_scope,
|
|
|
|
|
p_ctx, s, method, h, table_name_val, this] {
|
|
|
|
|
auto* var = p_scope->FindVar(in_var_name_val);
|
|
|
|
|
auto* var = p_scope->FindVar(in_var_name_val);
|
|
|
|
|
|
|
|
|
|
::grpc::ByteBuffer req;
|
|
|
|
|
SerializeToByteBuffer(in_var_name_val, var, *p_ctx, &req,
|
|
|
|
|
out_var_name_val, 0, table_name_val);
|
|
|
|
|
::grpc::ByteBuffer req;
|
|
|
|
|
SerializeToByteBuffer(in_var_name_val, var, *p_ctx, &req, out_var_name_val,
|
|
|
|
|
0, table_name_val);
|
|
|
|
|
|
|
|
|
|
VLOG(3) << s->GetVarHandlePtr()->String() << " begin";
|
|
|
|
|
VLOG(3) << s->GetVarHandlePtr()->String() << " begin";
|
|
|
|
|
|
|
|
|
|
// stub context
|
|
|
|
|
s->response_call_back_ = ProcGetResponse;
|
|
|
|
|
// stub context
|
|
|
|
|
s->response_call_back_ = ProcGetResponse;
|
|
|
|
|
|
|
|
|
|
platform::RecordRPCEvent record_event(method);
|
|
|
|
|
platform::RecordRPCEvent record_event(method);
|
|
|
|
|
|
|
|
|
|
auto call = s->stub_g_.PrepareUnaryCall(
|
|
|
|
|
s->context_.get(), "/sendrecv.SendRecvService/PrefetchVariable", req,
|
|
|
|
|
&cq_);
|
|
|
|
|
call->StartCall();
|
|
|
|
|
call->Finish(&s->reply_, &s->status_, static_cast<void*>(s));
|
|
|
|
|
auto call = s->stub_g_.PrepareUnaryCall(
|
|
|
|
|
s->context_.get(), "/sendrecv.SendRecvService/PrefetchVariable", req,
|
|
|
|
|
&cq_);
|
|
|
|
|
call->StartCall();
|
|
|
|
|
call->Finish(&s->reply_, &s->status_, static_cast<void*>(s));
|
|
|
|
|
|
|
|
|
|
if (UNLIKELY(platform::IsProfileEnabled())) {
|
|
|
|
|
h->Wait();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (UNLIKELY(platform::IsProfileEnabled())) {
|
|
|
|
|
h->Wait();
|
|
|
|
|
}
|
|
|
|
|
});
|
|
|
|
|
req_count_++;
|
|
|
|
|
|
|
|
|
|
if (FLAGS_rpc_retry_times > 0 && retry_times_ < FLAGS_rpc_retry_times) {
|
|
|
|
@ -467,7 +468,7 @@ VarHandlePtr GRPCClient::AsyncDistributeNotify(
|
|
|
|
|
VarHandlePtr h(new VarHandle(ep, method, var_name_val, p_ctx, p_scope));
|
|
|
|
|
s->Prepare(h, time_out);
|
|
|
|
|
|
|
|
|
|
framework::AsyncIO([var_name_val, p_scope, p_ctx, s, method, h, this] {
|
|
|
|
|
framework::Async([var_name_val, p_scope, p_ctx, s, method, h, this] {
|
|
|
|
|
auto* var = p_scope->FindVar(var_name_val);
|
|
|
|
|
|
|
|
|
|
::grpc::ByteBuffer req;
|
|
|
|
@ -523,8 +524,8 @@ VarHandlePtr GRPCClient::AsyncSendAndRecv(const std::string& ep,
|
|
|
|
|
s->Prepare(h, time_out);
|
|
|
|
|
s->RecvPrepare(h_recv);
|
|
|
|
|
|
|
|
|
|
framework::AsyncIO([send_var_name_val, recv_var_name_val, table_name_val,
|
|
|
|
|
p_scope, p_ctx, s, method, h, this] {
|
|
|
|
|
framework::Async([send_var_name_val, recv_var_name_val, table_name_val,
|
|
|
|
|
p_scope, p_ctx, s, method, h, this] {
|
|
|
|
|
auto* send_var = p_scope->FindVar(send_var_name_val);
|
|
|
|
|
send_var->GetMutable<framework::LoDTensor>()->set_lod({});
|
|
|
|
|
::grpc::ByteBuffer buf;
|
|
|
|
|