prefetch optimize (#29095)

* test=develop, optimize async prefetch
Branch: release/2.0-rc1
Author: 123malin (committed by GitHub)
parent 7c61ba3afb
commit 03d4665f44

@@ -162,6 +162,18 @@ void AsyncCommunicator::SendByCommunicator() {
       auto after_send = GetCurrentUS();
       VLOG(3) << "send " << var_name << " use time "
               << after_send - after_merge;
+
+      if (var_name.rfind("@GRAD") != var_name.size() - 5) return;
+
+      auto recv_param = var_name.substr(0, var_name.size() - 5);
+      if (recv_varname_to_ctx_.find(recv_param) == recv_varname_to_ctx_.end())
+        return;
+
+      auto recv_functor = distributed::ParameterRecv<float>();
+      recv_functor(recv_varname_to_ctx_.at(recv_param), *recv_scope_);
+      auto after_recv = GetCurrentUS();
+      VLOG(3) << "recv " << recv_param << " use time "
+              << after_recv - after_send;
     };
     task_futures.emplace_back(send_threadpool_->enqueue(std::move(send_task)));
   }
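
Note on the hunk above: once a gradient var (name ending in the literal suffix "@GRAD") has been sent, the communicator immediately pulls the corresponding parameter back via ParameterRecv, keyed by the name with that suffix stripped. A minimal sketch of the naming convention this relies on; the helper below is illustrative only and is not part of the patch:

#include <string>

// Illustrative helper: map a gradient var name to the parameter it updates,
// returning an empty string when the "@GRAD" convention does not hold.
std::string GradToParam(const std::string& grad_name) {
  const std::string suffix = "@GRAD";
  if (grad_name.size() <= suffix.size() ||
      grad_name.rfind(suffix) != grad_name.size() - suffix.size()) {
    return "";
  }
  return grad_name.substr(0, grad_name.size() - suffix.size());
}

// e.g. GradToParam("fc_0.w_0@GRAD") == "fc_0.w_0"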

@@ -23,6 +23,7 @@ limitations under the License. */
 #include "paddle/fluid/platform/port.h"
 #include "paddle/fluid/platform/profiler.h"
 
+DEFINE_int32(rpc_client_threads, 2, "");
 DECLARE_bool(rpc_disable_reuse_port);
 
 namespace paddle {
@@ -32,10 +33,11 @@ namespace distributed {
 void GRPCClient::InitImpl() {
   // start the client process thread
   // TODO(wuyi): can make this in a threadpool
-  PADDLE_ENFORCE_EQ(client_thread_ == nullptr, true,
-                    platform::errors::PreconditionNotMet(
-                        "please not re init proceed thread"));
-  client_thread_.reset(new std::thread(std::bind(&GRPCClient::Proceed, this)));
+  client_threads_.resize(FLAGS_rpc_client_threads);
+  for (int i = 0; i < FLAGS_rpc_client_threads; i++) {
+    client_threads_[i].reset(
+        new std::thread(std::bind(&GRPCClient::Proceed, this)));
+  }
 }
 
 void GRPCClient::SendComplete() {
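
For context on the InitImpl change above: instead of a single Proceed() thread, FLAGS_rpc_client_threads worker threads (default 2, settable like any other gflag, e.g. --rpc_client_threads=4) now drain the same grpc::CompletionQueue. This is safe because a completion queue may be polled from several threads concurrently, with each event delivered to exactly one poller. A minimal standalone sketch of that pattern, not the actual Proceed() body (which is unchanged by this patch):

#include <grpcpp/grpcpp.h>

// Sketch: several threads can call Next() on one CompletionQueue; each
// completed RPC event is handed to exactly one of the pollers.
void PollLoop(grpc::CompletionQueue* cq) {
  void* tag = nullptr;
  bool ok = false;
  while (cq->Next(&tag, &ok)) {
    // in Proceed() the tag identifies the in-flight request; the handler
    // checks `ok` before touching the reply and then finishes the request
  }
}
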
@@ -62,7 +64,8 @@ GRPCClient::~GRPCClient() {
     }
     channels_.clear();
   }
-  client_thread_->join();
+  for (size_t i = 0; i < client_threads_.size(); i++)
+    client_threads_[i]->join();
 }
 
 VarHandlePtr GRPCClient::AsyncSendVar(const std::string& ep,
@@ -84,7 +87,7 @@ VarHandlePtr GRPCClient::AsyncSendVar(const std::string& ep,
   VarHandlePtr h(new VarHandle(ep, method, var_name_val, p_ctx, p_scope));
   s->Prepare(h, time_out);
 
-  framework::AsyncIO([var_name_val, p_scope, p_ctx, s, method, h, this] {
+  framework::Async([var_name_val, p_scope, p_ctx, s, method, h, this] {
     auto* var = p_scope->FindVar(var_name_val);
 
     ::grpc::ByteBuffer req;
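
The framework::AsyncIO → framework::Async switch (here and in the hunks below) moves the serialize-and-issue closure off the IO-specific thread pool and onto the framework's default async pool, as the names suggest; the pool internals are not shown in this patch. Behaviorally it is close to the stand-in below, a sketch under that assumption:

#include <future>
#include <utility>

// Stand-in for framework::Async: run a closure on a worker thread and hand
// back a future the caller can wait on (the real version reuses a pool).
template <typename Callback>
std::future<void> Async(Callback callback) {
  return std::async(std::launch::async, std::move(callback));
}
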
@@ -206,8 +209,8 @@ VarHandlePtr GRPCClient::_AsyncGetVar(
   VarHandlePtr h(new VarHandle(ep, method, out_varname_val, p_ctx, p_scope));
   s->Prepare(h, time_out);
 
-  framework::AsyncIO([var_name_val, out_varname_val, table_name_val, s,
-                      method, p_ctx, h, rpc_path, this] {
+  framework::Async([var_name_val, out_varname_val, table_name_val, s, method,
+                    p_ctx, h, rpc_path, this] {
     // prepare input
     sendrecv::VariableMessage req;
     req.set_varname(var_name_val);
@@ -273,31 +276,29 @@ VarHandlePtr GRPCClient::AsyncPrefetchVar(const std::string& ep,
   VarHandlePtr h(new VarHandle(ep, method, out_var_name_val, p_ctx, p_scope));
   s->Prepare(h, kPrefetchTimeout);
 
-  framework::AsyncIO([in_var_name_val, out_var_name_val, ep_val, p_scope,
-                      p_ctx, s, method, h, table_name_val, this] {
-    auto* var = p_scope->FindVar(in_var_name_val);
-    ::grpc::ByteBuffer req;
-    SerializeToByteBuffer(in_var_name_val, var, *p_ctx, &req,
-                          out_var_name_val, 0, table_name_val);
+  auto* var = p_scope->FindVar(in_var_name_val);
+  ::grpc::ByteBuffer req;
+  SerializeToByteBuffer(in_var_name_val, var, *p_ctx, &req, out_var_name_val,
+                        0, table_name_val);
 
-    VLOG(3) << s->GetVarHandlePtr()->String() << " begin";
+  VLOG(3) << s->GetVarHandlePtr()->String() << " begin";
 
-    // stub context
-    s->response_call_back_ = ProcGetResponse;
+  // stub context
+  s->response_call_back_ = ProcGetResponse;
 
-    platform::RecordRPCEvent record_event(method);
+  platform::RecordRPCEvent record_event(method);
 
-    auto call = s->stub_g_.PrepareUnaryCall(
-        s->context_.get(), "/sendrecv.SendRecvService/PrefetchVariable", req,
-        &cq_);
-    call->StartCall();
-    call->Finish(&s->reply_, &s->status_, static_cast<void*>(s));
+  auto call = s->stub_g_.PrepareUnaryCall(
+      s->context_.get(), "/sendrecv.SendRecvService/PrefetchVariable", req,
+      &cq_);
+  call->StartCall();
+  call->Finish(&s->reply_, &s->status_, static_cast<void*>(s));
 
-    if (UNLIKELY(platform::IsProfileEnabled())) {
-      h->Wait();
-    }
-  });
+  if (UNLIKELY(platform::IsProfileEnabled())) {
+    h->Wait();
+  }
 
   req_count_++;
 
   if (FLAGS_rpc_retry_times > 0 && retry_times_ < FLAGS_rpc_retry_times) {
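
With the hunk above, AsyncPrefetchVar no longer wraps serialization in an async closure: the input var is serialized and the unary call is issued directly on the calling thread, while completion is still delivered asynchronously through the completion queue, so callers keep waiting on the returned handle. A rough sketch of that control flow; the names PrefetchHandle and IssuePrefetch are hypothetical, not from the patch:

#include <future>
#include <memory>
#include <thread>

// Sketch of the new control flow: the request is built synchronously on the
// calling thread, and only completion of the RPC stays asynchronous.
struct PrefetchHandle {
  std::promise<bool> done;
  void Wait() { done.get_future().wait(); }  // what h->Wait() amounts to
};

std::shared_ptr<PrefetchHandle> IssuePrefetch() {
  auto h = std::make_shared<PrefetchHandle>();
  // serialization would happen here, on the caller's thread (no AsyncIO hop);
  // the transport then completes the handle from a poller thread:
  std::thread([h] { h->done.set_value(true); }).detach();
  return h;
}
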
@@ -467,7 +468,7 @@ VarHandlePtr GRPCClient::AsyncDistributeNotify(
   VarHandlePtr h(new VarHandle(ep, method, var_name_val, p_ctx, p_scope));
   s->Prepare(h, time_out);
 
-  framework::AsyncIO([var_name_val, p_scope, p_ctx, s, method, h, this] {
+  framework::Async([var_name_val, p_scope, p_ctx, s, method, h, this] {
     auto* var = p_scope->FindVar(var_name_val);
 
     ::grpc::ByteBuffer req;
@@ -523,8 +524,8 @@ VarHandlePtr GRPCClient::AsyncSendAndRecv(const std::string& ep,
   s->Prepare(h, time_out);
   s->RecvPrepare(h_recv);
 
-  framework::AsyncIO([send_var_name_val, recv_var_name_val, table_name_val,
-                      p_scope, p_ctx, s, method, h, this] {
+  framework::Async([send_var_name_val, recv_var_name_val, table_name_val,
+                    p_scope, p_ctx, s, method, h, this] {
     auto* send_var = p_scope->FindVar(send_var_name_val);
     send_var->GetMutable<framework::LoDTensor>()->set_lod({});
     ::grpc::ByteBuffer buf;

@@ -297,7 +297,7 @@ class GRPCClient : public RPCClient {
  private:
   grpc::CompletionQueue cq_;
   std::unordered_map<std::string, std::shared_ptr<grpc::Channel>> channels_;
-  std::unique_ptr<std::thread> client_thread_{nullptr};
+  std::vector<std::unique_ptr<std::thread>> client_threads_;
 
   // mutex for Wait client sync
   std::mutex sync_mutex_;

@@ -85,7 +85,7 @@ class RPCServer {
   // class, and auto generate a condition id for this call
   // to be used for the barrier.
   void RegisterRPC(const std::string& rpc_name, RequestHandler* handler,
-                   int thread_num = 5);
+                   int thread_num = 1);
 
   int GetThreadNum(const std::string& rpc_name) {
     return rpc_thread_num_[rpc_name];
