From 1239fce771a5d0907045fc285cb1966bdb61b180 Mon Sep 17 00:00:00 2001 From: Yancey1989 Date: Fri, 8 Jun 2018 09:59:54 +0800 Subject: [PATCH] polish sparse update code --- .../fluid/operators/detail/request_handler_impl.cc | 3 +++ paddle/fluid/operators/detail/rpc_server.cc | 13 +++++++++++++ paddle/fluid/operators/detail/rpc_server.h | 6 ++++++ paddle/fluid/operators/listen_and_serv_op.cc | 1 + 4 files changed, 23 insertions(+) diff --git a/paddle/fluid/operators/detail/request_handler_impl.cc b/paddle/fluid/operators/detail/request_handler_impl.cc index 145ee53107..b5ee3ab51e 100644 --- a/paddle/fluid/operators/detail/request_handler_impl.cc +++ b/paddle/fluid/operators/detail/request_handler_impl.cc @@ -63,6 +63,9 @@ bool RequestSendHandler::Handle(const std::string& varname, PADDLE_THROW("sync: Can not find server side var"); return false; } + if (invar->IsType()) { + rpc_server_->RecordSparseVar(invar); + } } return true; diff --git a/paddle/fluid/operators/detail/rpc_server.cc b/paddle/fluid/operators/detail/rpc_server.cc index 448763372a..7feddbeca8 100644 --- a/paddle/fluid/operators/detail/rpc_server.cc +++ b/paddle/fluid/operators/detail/rpc_server.cc @@ -73,6 +73,19 @@ void RPCServer::ResetBarrierCounter() { t.second = 0; } } +void RPCServer::RecordSparseVar(framework::Variable* sparse_var) { + std::unique_lock lock(mutex_sparse_var_recorder_); + sparse_vars_.push_back(sparse_var); +} + +void RPCServer::ResetSparseVarsRecorder() { + VLOG(3) << "RPCServer reset sparse vars recorder."; + std::unique_lock lock(mutex_sparse_var_recorder_); + for (auto* var : sparse_vars_) { + var->GetMutable()->mutable_rows()->clear(); + } + sparse_vars_.clear(); +} void RPCServer::RegisterRPC(const std::string& rpc_name, RequestHandler* handler, int thread_num) { diff --git a/paddle/fluid/operators/detail/rpc_server.h b/paddle/fluid/operators/detail/rpc_server.h index c2e7ae706c..94a21ef8d0 100644 --- a/paddle/fluid/operators/detail/rpc_server.h +++ b/paddle/fluid/operators/detail/rpc_server.h @@ -60,7 +60,10 @@ class RPCServer { void SetCond(const std::string& rpc_name); void WaitCond(const std::string& rpc_name); void IncreaseBatchBarrier(const std::string rpc_name); + void ResetBarrierCounter(); + void RecordSparseVar(framework::Variable* sparse_var); + void ResetSparseVarsRecorder(); protected: virtual void ShutDownImpl() = 0; @@ -74,6 +77,9 @@ class RPCServer { std::atomic cur_cond_; std::condition_variable rpc_cond_; + std::vector sparse_vars_; + std::mutex mutex_sparse_var_recorder_; + protected: std::string bind_address_; std::atomic exit_flag_; diff --git a/paddle/fluid/operators/listen_and_serv_op.cc b/paddle/fluid/operators/listen_and_serv_op.cc index 0c9d2b5a74..ee7b01a54c 100644 --- a/paddle/fluid/operators/listen_and_serv_op.cc +++ b/paddle/fluid/operators/listen_and_serv_op.cc @@ -146,6 +146,7 @@ void ListenAndServOp::RunSyncLoop(framework::Executor *executor, rpc_service_->SetCond(detail::kRequestGet); rpc_service_->WaitBarrier(detail::kRequestGet); rpc_service_->ResetBarrierCounter(); + rpc_service_->ResetSparseVarsRecorder(); } // while(true) }