From 7570d8e77cadf89760187a787b48693608cc8aaf Mon Sep 17 00:00:00 2001
From: Yancey1989
Date: Wed, 18 Jul 2018 16:22:19 +0800
Subject: [PATCH 1/6] add rpc complete interface

---
 paddle/fluid/framework/executor.cc            | 10 ++----
 paddle/fluid/framework/executor.h             |  8 ++---
 .../operators/distributed/CMakeLists.txt      |  2 +-
 .../operators/distributed/grpc_client.cc      | 31 +++----------------
 .../fluid/operators/distributed/grpc_client.h | 11 ++-----
 .../operators/distributed/request_handler.h   |  2 --
 .../distributed/request_handler_impl.cc       | 11 +++----
 .../fluid/operators/distributed/rpc_client.h  | 14 +++------
 .../fluid/operators/distributed/rpc_server.cc | 22 ++++++-------
 .../fluid/operators/distributed/rpc_server.h  |  5 ++-
 .../operators/distributed/rpc_server_test.cc  | 29 ++++++++++++++---
 paddle/fluid/pybind/pybind.cc                 |  3 +-
 python/paddle/fluid/executor.py               |  7 ++---
 13 files changed, 62 insertions(+), 93 deletions(-)

diff --git a/paddle/fluid/framework/executor.cc b/paddle/fluid/framework/executor.cc
index 84f67fafa1..750a08d3a3 100644
--- a/paddle/fluid/framework/executor.cc
+++ b/paddle/fluid/framework/executor.cc
@@ -46,16 +46,10 @@ ExecutorPrepareContext::~ExecutorPrepareContext() {
 
 Executor::Executor(const platform::Place& place) : place_(place) {}
 
 #ifdef PADDLE_WITH_DISTRIBUTE
-void Executor::BeginPass() {
+void Executor::Complete() {
   ::paddle::operators::distributed::RPCClient::GetInstance<
       ::paddle::operators::distributed::GRPCClient>()
-      ->SendBeginPass();
-}
-
-void Executor::EndPass() {
-  ::paddle::operators::distributed::RPCClient::GetInstance<
-      ::paddle::operators::distributed::GRPCClient>()
-      ->SendEndPass();
+      ->SendComplete();
 }
 #endif
 
diff --git a/paddle/fluid/framework/executor.h b/paddle/fluid/framework/executor.h
index 563a4b2bb6..53ebf18bd2 100644
--- a/paddle/fluid/framework/executor.h
+++ b/paddle/fluid/framework/executor.h
@@ -46,14 +46,10 @@ class Executor {
 
 #ifdef PADDLE_WITH_DISTRIBUTE
   /*
-   * Sending signal to pserver to mark current pass started.
+   * Sending signal to pserver to mark current trainer completed.
    */
-  void BeginPass();
+  void Complete();
 
-  /*
-   * Sending signal to pserver to mark current pass finished.
-   */
-  void EndPass();
 #endif
 
   /* @Brief
diff --git a/paddle/fluid/operators/distributed/CMakeLists.txt b/paddle/fluid/operators/distributed/CMakeLists.txt
index 675ca36774..a6b1c43ce9 100644
--- a/paddle/fluid/operators/distributed/CMakeLists.txt
+++ b/paddle/fluid/operators/distributed/CMakeLists.txt
@@ -6,7 +6,7 @@ if(WITH_GRPC)
   set_source_files_properties(grpc_serde_test.cc rpc_server_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
   cc_test(serde_test SRCS grpc_serde_test.cc variable_response.cc DEPS grpc++_unsecure grpc_unsecure gpr cares zlib protobuf sendrecvop_grpc scope profiler math_function SERIAL)
-  cc_test(grpc_server_test SRCS rpc_server_test.cc DEPS sendrecvop_grpc
+  cc_test(rpc_server_test SRCS rpc_server_test.cc DEPS sendrecvop_grpc
     grpc++_unsecure grpc_unsecure gpr cares zlib protobuf executor proto_desc lookup_table_op SERIAL)
   return()
diff --git a/paddle/fluid/operators/distributed/grpc_client.cc b/paddle/fluid/operators/distributed/grpc_client.cc
index 4d60801b6a..7ef7482cab 100644
--- a/paddle/fluid/operators/distributed/grpc_client.cc
+++ b/paddle/fluid/operators/distributed/grpc_client.cc
@@ -35,18 +35,10 @@ void GRPCClient::InitEventLoop() {
   client_thread_.reset(new std::thread(std::bind(&GRPCClient::Proceed, this)));
 }
 
-void GRPCClient::SendBeginPass() {
+void GRPCClient::SendComplete() {
   for (auto& it : channels_) {
-    VLOG(3) << "send begin pass to: " << it.first;
-    this->AsyncSendBeginPass(it.first);
-  }
-  this->Wait();
-}
-
-void GRPCClient::SendEndPass() {
-  for (auto& it : channels_) {
-    VLOG(3) << "send end pass to " << it.first;
-    this->AsyncSendEndPass(it.first);
+    VLOG(3) << "send complete message to " << it.first;
+    this->AsyncSendComplete(it.first);
   }
   this->Wait();
 }
@@ -238,32 +230,19 @@ void GRPCClient::AsyncSendFetchBarrier(const std::string& ep,
   req_count_++;
 }
 
-void GRPCClient::AsyncSendBeginPass(const std::string& ep, int64_t time_out) {
+void GRPCClient::AsyncSendComplete(const std::string& ep, int64_t time_out) {
   const auto ch = GetChannel(ep);
 
   BatchBarrierProcessor* s = new BatchBarrierProcessor(ch);
   s->Prepare(time_out);
 
   sendrecv::VariableMessage req;
-  req.set_varname(BEGIN_PASS_MESSAGE);
+  req.set_varname(COMPLETE_MESSAGE);
   auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_);
   rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
   req_count_++;
 }
 
-void GRPCClient::AsyncSendEndPass(const std::string& ep, int64_t time_out) {
-  const auto ch = GetChannel(ep);
-
-  FetchBarrierProcessor* s = new FetchBarrierProcessor(ch);
-  s->Prepare(time_out);
-
-  sendrecv::VariableMessage req;
-  req.set_varname(END_PASS_MESSAGE);
-  auto rpc = s->stub_->AsyncGetVariable(s->context_.get(), req, &cq_);
-  rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
-  req_count_++;
-}
-
 void GRPCClient::AsyncCheckpointNotify(const std::string& ep,
                                        const std::string& dir,
                                        int64_t time_out) {
diff --git a/paddle/fluid/operators/distributed/grpc_client.h b/paddle/fluid/operators/distributed/grpc_client.h
index d03a3e56ae..26cad7548e 100644
--- a/paddle/fluid/operators/distributed/grpc_client.h
+++ b/paddle/fluid/operators/distributed/grpc_client.h
@@ -215,17 +215,12 @@ class GRPCClient : public RPCClient {
   void AsyncCheckpointNotify(const std::string& ep, const std::string& dir,
                              int64_t time_out = FLAGS_rpc_deadline) override;
 
-  void AsyncSendBeginPass(const std::string& ep,
-                          int64_t time_out = FLAGS_rpc_deadline) override;
-
-  void AsyncSendEndPass(const std::string& ep,
-                        int64_t time_out = FLAGS_rpc_deadline) override;
+  void AsyncSendComplete(const std::string& ep,
+                         int64_t time_out = FLAGS_rpc_deadline) override;
 
   bool Wait() override;
 
-  void SendBeginPass() override;
-
-  void SendEndPass() override;
+  void SendComplete() override;
 
  protected:
   void InitImpl() override;
diff --git a/paddle/fluid/operators/distributed/request_handler.h b/paddle/fluid/operators/distributed/request_handler.h
index 271306d5d2..f68f9d8f3c 100644
--- a/paddle/fluid/operators/distributed/request_handler.h
+++ b/paddle/fluid/operators/distributed/request_handler.h
@@ -43,8 +43,6 @@ constexpr char kRequestPassBarrier[] = "RequestPassBarrier";
 #define BATCH_BARRIER_MESSAGE "BATCH_BARRIER@RECV"
 #define FETCH_BARRIER_MESSAGE "FETCH_BARRIER@RECV"
 #define COMPLETE_MESSAGE "COMPLETE@RECV"
-#define BEGIN_PASS_MESSAGE "BEGIN_PASS@RECV"
-#define END_PASS_MESSAGE "END_PASS@RECV"
 
 #define CHECKPOINT_SAVE_MESSAGE "SAVE@CHECKPOINTNOTIFY"
 #define CHECKPOINT_LOAD_MESSAGE "LOAD@CHECKPOINTNOTIFY"
diff --git a/paddle/fluid/operators/distributed/request_handler_impl.cc b/paddle/fluid/operators/distributed/request_handler_impl.cc
index 5e6bff20f5..d9a1ac583f 100644
--- a/paddle/fluid/operators/distributed/request_handler_impl.cc
+++ b/paddle/fluid/operators/distributed/request_handler_impl.cc
@@ -55,10 +55,9 @@ bool RequestSendHandler::Handle(const std::string& varname,
     if (varname == BATCH_BARRIER_MESSAGE) {
       VLOG(3) << "sync: recv batch barrier message";
       rpc_server_->IncreaseBatchBarrier(kRequestSend);
-    } else if (varname == BEGIN_PASS_MESSAGE) {
-      VLOG(3) << "sync: recv begin pass message";
-      rpc_server_->WaitCond(kRequestSend);
-      rpc_server_->BeginPass();
+    } else if (varname == COMPLETE_MESSAGE) {
+      VLOG(3) << "sync: recv complete message";
+      rpc_server_->Complete();
     } else {
       VLOG(3) << "sync: received var_name: " << varname;
       rpc_server_->WaitCond(kRequestSend);
@@ -95,14 +94,12 @@ bool RequestGetHandler::Handle(const std::string& varname,
     if (varname == FETCH_BARRIER_MESSAGE) {
       VLOG(3) << "sync: recv fetch barrier message";
       rpc_server_->IncreaseBatchBarrier(kRequestGet);
-    } else if (varname == END_PASS_MESSAGE) {
-      rpc_server_->EndPass();
     } else {
       rpc_server_->WaitCond(kRequestGet);
       *outvar = scope_->FindVar(varname);
     }
   } else {
-    if (varname != FETCH_BARRIER_MESSAGE && varname != END_PASS_MESSAGE) {
+    if (varname != FETCH_BARRIER_MESSAGE && varname != COMPLETE_MESSAGE) {
       *outvar = scope_->FindVar(varname);
     }
   }
diff --git a/paddle/fluid/operators/distributed/rpc_client.h b/paddle/fluid/operators/distributed/rpc_client.h
index 4d87376fbf..22a022a5d2 100644
--- a/paddle/fluid/operators/distributed/rpc_client.h
+++ b/paddle/fluid/operators/distributed/rpc_client.h
@@ -60,17 +60,13 @@ class RPCClient {
                                      const std::string& dir,
                                      int64_t time_out = FLAGS_rpc_deadline) = 0;
 
-  virtual void AsyncSendBeginPass(const std::string& ep,
-                                  int64_t time_out = FLAGS_rpc_deadline) = 0;
+  virtual void AsyncSendComplete(const std::string& ep,
+                                 int64_t time_out = FLAGS_rpc_deadline) = 0;
 
-  virtual void AsyncSendEndPass(const std::string& ep,
-                                int64_t time_out = FLAGS_rpc_deadline) = 0;
-
-  // BeginePass/EndPass tells all the pserver that start/end a pass, so that
-  // the pserver can increase/reduce it's barrier count, and continue to train
+  // Complete tells all the pserver instances that the trainer has finished
+  // training, so the pserver can reduce its barrier count and continue to train
   // with other trainers.
-  virtual void SendBeginPass() = 0;
-  virtual void SendEndPass() = 0;
+  virtual void SendComplete() = 0;
 
   virtual bool Wait() = 0;
 
diff --git a/paddle/fluid/operators/distributed/rpc_server.cc b/paddle/fluid/operators/distributed/rpc_server.cc
index d49ee34eea..42dc7bd10b 100644
--- a/paddle/fluid/operators/distributed/rpc_server.cc
+++ b/paddle/fluid/operators/distributed/rpc_server.cc
@@ -64,18 +64,7 @@ void RPCServer::IncreaseBatchBarrier(const std::string rpc_name) {
   }
 }
 
-void RPCServer::BeginPass() {
-  VLOG(4) << "RPCServer begin increase pass barrier";
-  {
-    std::unique_lock<std::mutex> lock(mutex_);
-    client_num_++;
-    VLOG(4) << "increase client_num to: " << client_num_;
-  }
-  barrier_cond_.notify_all();
-}
-
-void RPCServer::EndPass() {
-  VLOG(4) << "RPCServer begin increase pass barrier";
+void RPCServer::Complete() {
   {
     std::unique_lock<std::mutex> lock(mutex_);
     client_num_--;
     if (cur_cond_.load() == rpc_cond_map_[kRequestGet]) {
       barrier_counter_[kRequestGet]--;
     }
+    if (client_num_ == 0) {
+      exit_flag_ = true;
+      VLOG(1) << "No active Trainer instance, PServer will exit...";
+    }
   }
   barrier_cond_.notify_all();
 }
 
+int RPCServer::GetClientNum() {
+  std::unique_lock<std::mutex> lock(mutex_);
+  return client_num_;
+}
+
 void RPCServer::ResetBarrierCounter() {
   VLOG(3) << "RPCServer ResetBarrierCounter ";
   std::unique_lock<std::mutex> lock(mutex_);
diff --git a/paddle/fluid/operators/distributed/rpc_server.h b/paddle/fluid/operators/distributed/rpc_server.h
index 833991c8aa..fd914d7a72 100644
--- a/paddle/fluid/operators/distributed/rpc_server.h
+++ b/paddle/fluid/operators/distributed/rpc_server.h
@@ -44,7 +44,7 @@ class RPCServer {
 
   int GetSelectedPort() const { return selected_port_; }
 
-  int GetClientNum() const;
+  int GetClientNum();
 
   void SavePort() const;
 
@@ -64,8 +64,7 @@ class RPCServer {
   void WaitCond(const std::string& rpc_name);
   void IncreaseBatchBarrier(const std::string rpc_name);
 
-  void BeginPass();
-  void EndPass();
+  void Complete();
 
   void ResetBarrierCounter();
 
diff --git a/paddle/fluid/operators/distributed/rpc_server_test.cc b/paddle/fluid/operators/distributed/rpc_server_test.cc
index a0693cffab..9f2360ec70 100644
--- a/paddle/fluid/operators/distributed/rpc_server_test.cc
+++ b/paddle/fluid/operators/distributed/rpc_server_test.cc
@@ -91,7 +91,7 @@ void InitTensorsOnServer(framework::Scope* scope, platform::CPUPlace* place,
   }
 }
 
-void StartServer() {
+void StartServer(const std::string& rpc_name) {
   framework::ProgramDesc program;
   framework::Scope scope;
   platform::CPUPlace place;
@@ -107,14 +107,14 @@
   std::unordered_map<std::string,
                      std::shared_ptr<framework::ExecutorPrepareContext>>
       prefetch_var_name_to_prepared;
   prefetch_var_name_to_prepared[in_var_name] = prepared[0];
+  g_req_handler->SetProgram(&program);
   g_req_handler->SetPrefetchPreparedCtx(&prefetch_var_name_to_prepared);
   g_req_handler->SetDevCtx(&ctx);
   g_req_handler->SetScope(&scope);
   g_req_handler->SetExecutor(&exe);
 
-  g_rpc_service->RegisterRPC(distributed::kRequestPrefetch,
-                             g_req_handler.get());
+  g_rpc_service->RegisterRPC(rpc_name, g_req_handler.get());
   g_req_handler->SetRPCServer(g_rpc_service.get());
 
   std::thread server_thread(
@@ -129,7 +129,7 @@ TEST(PREFETCH, CPU) {
   distributed::RPCClient* client =
       distributed::RPCClient::GetInstance<RPCCLIENT_T>();
 
-  std::thread server_thread(StartServer);
+  std::thread server_thread(StartServer, distributed::kRequestPrefetch);
   g_rpc_service->WaitServerReady();
 
   int port = g_rpc_service->GetSelectedPort();
@@ -162,3 +162,24 @@ TEST(PREFETCH, CPU) {
   g_rpc_service.reset(nullptr);
   g_req_handler.reset(nullptr);
 }
+
+TEST(COMPLETE, CPU) {
+  g_req_handler.reset(new distributed::RequestSendHandler(true));
+  g_rpc_service.reset(new RPCSERVER_T("127.0.0.1:0", 2));
+  distributed::RPCClient* client =
+      distributed::RPCClient::GetInstance<RPCCLIENT_T>();
+  PADDLE_ENFORCE(client != nullptr);
+  std::thread server_thread(StartServer, distributed::kRequestSend);
+  g_rpc_service->WaitServerReady();
+  int port = g_rpc_service->GetSelectedPort();
+  std::string ep = paddle::string::Sprintf("127.0.0.1:%d", port);
+  client->AsyncSendComplete(ep);
+  client->Wait();
+
+  EXPECT_EQ(g_rpc_service->GetClientNum(), 1);
+
+  g_rpc_service->ShutDown();
+  server_thread.join();
+  g_rpc_service.reset(nullptr);
+  g_req_handler.reset(nullptr);
+}
diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc
index 216c4666c0..9669e4c083 100644
--- a/paddle/fluid/pybind/pybind.cc
+++ b/paddle/fluid/pybind/pybind.cc
@@ -503,8 +503,7 @@ All parameter, weight, gradient are variables in Paddle.
   py::class_<Executor>(m, "Executor")
       .def(py::init<const platform::Place &>())
 #ifdef PADDLE_WITH_DISTRIBUTE
-      .def("begin_pass", &Executor::BeginPass)
-      .def("end_pass", &Executor::EndPass)
+      .def("complete", &Executor::Complete)
 #endif
       .def("run", [](Executor &self, const ProgramDesc &prog, Scope *scope,
                      int block_id, bool create_local_scope, bool create_vars) {
diff --git a/python/paddle/fluid/executor.py b/python/paddle/fluid/executor.py
index f9e600cb4c..d5f11619a3 100644
--- a/python/paddle/fluid/executor.py
+++ b/python/paddle/fluid/executor.py
@@ -348,11 +348,8 @@ class Executor(object):
         ]
         return outs
 
-    def begin_pass(self):
-        self.executor.begin_pass()
-
-    def end_pass(self):
-        self.executor.end_pass()
+    def complete(self):
+        self.executor.complete()
 
     def run(self,
             program=None,

From d0771cf912143e486edc94f6779c09bfba1537e6 Mon Sep 17 00:00:00 2001
From: Yancey1989
Date: Wed, 18 Jul 2018 16:54:41 +0800
Subject: [PATCH 2/6] update

---
 paddle/fluid/operators/distributed/rpc_server.cc | 4 ----
 1 file changed, 4 deletions(-)

diff --git a/paddle/fluid/operators/distributed/rpc_server.cc b/paddle/fluid/operators/distributed/rpc_server.cc
index 42dc7bd10b..83b14fa64d 100644
--- a/paddle/fluid/operators/distributed/rpc_server.cc
+++ b/paddle/fluid/operators/distributed/rpc_server.cc
@@ -72,10 +72,6 @@ void RPCServer::Complete() {
     if (cur_cond_.load() == rpc_cond_map_[kRequestGet]) {
       barrier_counter_[kRequestGet]--;
     }
-    if (client_num_ == 0) {
-      exit_flag_ = true;
-      VLOG(1) << "No active Trainer instance, PServer will exit...";
-    }
   }
   barrier_cond_.notify_all();
 }

From 800702ccee6b81d033e003e13d2c3d2e26b067e2 Mon Sep 17 00:00:00 2001
From: Yancey1989
Date: Wed, 18 Jul 2018 20:26:36 +0800
Subject: [PATCH 3/6] fix APIspec

---
 paddle/fluid/API.spec | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec
index 5c17ec4fc3..b2feffc630 100644
--- a/paddle/fluid/API.spec
+++ b/paddle/fluid/API.spec
@@ -42,8 +42,7 @@ paddle.fluid.program_guard ArgSpec(args=[], varargs='args', keywords='kwds', def
 paddle.fluid.get_var ArgSpec(args=['name', 'program'], varargs=None, keywords=None, defaults=(None,))
 paddle.fluid.Executor.__init__ ArgSpec(args=['self', 'place'], varargs=None, keywords=None, defaults=None)
 paddle.fluid.Executor.as_lodtensor ArgSpec(args=['self', 'data'], varargs=None, keywords=None, defaults=None)
-paddle.fluid.Executor.begin_pass ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None)
-paddle.fluid.Executor.end_pass ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None)
+paddle.fluid.Executor.complete ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None)
 paddle.fluid.Executor.run ArgSpec(args=['self', 'program', 'feed', 'fetch_list', 'feed_var_name', 'fetch_var_name', 'scope', 'return_numpy', 'use_program_cache'], varargs=None, keywords=None, defaults=(None, None, None, 'feed', 'fetch', None, True, False))
 paddle.fluid.global_scope ArgSpec(args=[], varargs=None, keywords=None, defaults=None)
 paddle.fluid.scope_guard ArgSpec(args=[], varargs='args', keywords='kwds', defaults=None)

From efd5a849867125105cbde1682a7e2f20034ff336 Mon Sep 17 00:00:00 2001
From: Yancey1989
Date: Thu, 19 Jul 2018 14:39:42 +0800
Subject: [PATCH 4/6] update executor interface

---
 paddle/fluid/framework/executor.cc            |  4 ++--
 paddle/fluid/framework/executor.h             |  2 +-
 .../operators/distributed/grpc_client.cc      | 12 ++++++----
 .../fluid/operators/distributed/grpc_client.h |  6 ++++-
 paddle/fluid/pybind/pybind.cc                 |  4 +---
 python/paddle/fluid/executor.py               | 24 +++++++++++++++++--
 6 files changed, 39 insertions(+), 13 deletions(-)

diff --git a/paddle/fluid/framework/executor.cc b/paddle/fluid/framework/executor.cc
index 750a08d3a3..c2800c972a 100644
--- a/paddle/fluid/framework/executor.cc
+++ b/paddle/fluid/framework/executor.cc
@@ -45,13 +45,13 @@ ExecutorPrepareContext::~ExecutorPrepareContext() {
 
 Executor::Executor(const platform::Place& place) : place_(place) {}
 
+void Executor::Close() {
 #ifdef PADDLE_WITH_DISTRIBUTE
-void Executor::Complete() {
   ::paddle::operators::distributed::RPCClient::GetInstance<
       ::paddle::operators::distributed::GRPCClient>()
       ->SendComplete();
-}
 #endif
+}
 
 void InitializeVariable(Variable* var, proto::VarType::Type var_type) {
   if (var_type == proto::VarType::LOD_TENSOR) {
diff --git a/paddle/fluid/framework/executor.h b/paddle/fluid/framework/executor.h
index 53ebf18bd2..3e78572334 100644
--- a/paddle/fluid/framework/executor.h
+++ b/paddle/fluid/framework/executor.h
@@ -48,7 +48,7 @@ class Executor {
   /*
    * Sending signal to pserver to mark current trainer completed.
    */
-  void Complete();
+  void Close();
 
 #endif
 
diff --git a/paddle/fluid/operators/distributed/grpc_client.cc b/paddle/fluid/operators/distributed/grpc_client.cc
index 7ef7482cab..092e9dfd74 100644
--- a/paddle/fluid/operators/distributed/grpc_client.cc
+++ b/paddle/fluid/operators/distributed/grpc_client.cc
@@ -36,11 +36,15 @@ void GRPCClient::InitEventLoop() {
 }
 
 void GRPCClient::SendComplete() {
-  for (auto& it : channels_) {
-    VLOG(3) << "send complete message to " << it.first;
-    this->AsyncSendComplete(it.first);
+  std::unique_lock<std::mutex> lk(completed_mutex_);
+  if (!completed_) {
+    for (auto& it : channels_) {
+      VLOG(3) << "send complete message to " << it.first;
+      this->AsyncSendComplete(it.first);
+    }
+    PADDLE_ENFORCE(this->Wait(), "internal grpc error");
+    completed_ = true;
   }
-  this->Wait();
 }
 
 GRPCClient::~GRPCClient() {
diff --git a/paddle/fluid/operators/distributed/grpc_client.h b/paddle/fluid/operators/distributed/grpc_client.h
index 26cad7548e..373183be55 100644
--- a/paddle/fluid/operators/distributed/grpc_client.h
+++ b/paddle/fluid/operators/distributed/grpc_client.h
@@ -188,7 +188,7 @@ class CheckpointNotifyProcessor : public BaseProcessor {
 
 class GRPCClient : public RPCClient {
  public:
-  GRPCClient() : ok_(true) {}
+  GRPCClient() : ok_(true), completed_(false) {}
   virtual ~GRPCClient();
 
   bool AsyncSendVar(const std::string& ep, const platform::DeviceContext& ctx,
@@ -247,6 +247,10 @@
   // mutex for GetChannel thread safety
   std::mutex chan_mutex_;
   DISABLE_COPY_AND_ASSIGN(GRPCClient);
+
+  // mutex for sending complete message only once
+  std::mutex completed_mutex_;
+  bool completed_;
 };
 
 }  // namespace distributed
diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc
index 9669e4c083..c681eb9d91 100644
--- a/paddle/fluid/pybind/pybind.cc
+++ b/paddle/fluid/pybind/pybind.cc
@@ -502,9 +502,7 @@ All parameter, weight, gradient are variables in Paddle.
   py::class_<Executor>(m, "Executor")
       .def(py::init<const platform::Place &>())
-#ifdef PADDLE_WITH_DISTRIBUTE
-      .def("complete", &Executor::Complete)
-#endif
+      .def("close", &Executor::Close)
       .def("run", [](Executor &self, const ProgramDesc &prog, Scope *scope,
                      int block_id, bool create_local_scope, bool create_vars) {
         pybind11::gil_scoped_release release;
diff --git a/python/paddle/fluid/executor.py b/python/paddle/fluid/executor.py
index d5f11619a3..4178971398 100644
--- a/python/paddle/fluid/executor.py
+++ b/python/paddle/fluid/executor.py
@@ -247,6 +247,7 @@ class Executor(object):
         p.set_place(place)
         self.executor = core.Executor(p)
         self.program_caches = dict()
+        self._closed = False
 
     def as_lodtensor(self, data):
         """
@@ -348,8 +349,23 @@ class Executor(object):
         ]
         return outs
 
-    def complete(self):
-        self.executor.complete()
+    def close(self):
+        """
+        Close this executor.
+
+        You can no longer use this executor after calling this method.
+        For distributed training, this method frees the resources on the
+        PServers related to the current Trainer.
+
+        Example:
+            >>> cpu = core.CPUPlace()
+            >>> exe = Executor(cpu)
+            >>> ...
+            >>> exe.close()
+        """
+        if not self._closed:
+            self.executor.close()
+            self._closed = True
 
     def run(self,
             program=None,
@@ -402,6 +418,10 @@ class Executor(object):
             >>> feed={'X': x},
             >>> fetch_list=[loss.name])
         """
+
+        if self._closed:
+            raise RuntimeError("Attempted to use a closed Executor")
+
         if feed is None:
             feed = {}
         if not isinstance(feed, dict):

From 9be4ee9221a9c9a9a7e6c96c6b6481398a1302b2 Mon Sep 17 00:00:00 2001
From: Yancey1989
Date: Thu, 19 Jul 2018 14:40:18 +0800
Subject: [PATCH 5/6] update api spec

---
 paddle/fluid/API.spec | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec
index b2feffc630..d25be7ddbb 100644
--- a/paddle/fluid/API.spec
+++ b/paddle/fluid/API.spec
@@ -42,7 +42,7 @@ paddle.fluid.program_guard ArgSpec(args=[], varargs='args', keywords='kwds', def
 paddle.fluid.get_var ArgSpec(args=['name', 'program'], varargs=None, keywords=None, defaults=(None,))
 paddle.fluid.Executor.__init__ ArgSpec(args=['self', 'place'], varargs=None, keywords=None, defaults=None)
 paddle.fluid.Executor.as_lodtensor ArgSpec(args=['self', 'data'], varargs=None, keywords=None, defaults=None)
-paddle.fluid.Executor.complete ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None)
+paddle.fluid.Executor.close ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None)
 paddle.fluid.Executor.run ArgSpec(args=['self', 'program', 'feed', 'fetch_list', 'feed_var_name', 'fetch_var_name', 'scope', 'return_numpy', 'use_program_cache'], varargs=None, keywords=None, defaults=(None, None, None, 'feed', 'fetch', None, True, False))
 paddle.fluid.global_scope ArgSpec(args=[], varargs=None, keywords=None, defaults=None)
 paddle.fluid.scope_guard ArgSpec(args=[], varargs='args', keywords='kwds', defaults=None)

From be772741cf2d5db7fd8c0f3cbfe1a1c9a6c6fae9 Mon Sep 17 00:00:00 2001
From: Yancey1989
Date: Thu, 19 Jul 2018 16:44:26 +0800
Subject: [PATCH 6/6] compile with cpu

---
 paddle/fluid/framework/executor.h | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/paddle/fluid/framework/executor.h b/paddle/fluid/framework/executor.h
index 3e78572334..214ca3dc49 100644
--- a/paddle/fluid/framework/executor.h
+++ b/paddle/fluid/framework/executor.h
@@ -44,14 +44,12 @@ class Executor {
 
   explicit Executor(const platform::Place& place);
 
-#ifdef PADDLE_WITH_DISTRIBUTE
   /*
-   * Sending signal to pserver to mark current trainer completed.
+   * Close this Executor.
+   * Calling this method will send complete messages to all pserver instances.
    */
   void Close();
 
-#endif
-
   /* @Brief
   * Runtime evaluation of the given ProgramDesc under certain Scope
   *
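
Taken together, the series replaces the begin_pass()/end_pass() pair with a
single Executor.close() call on the Python side. A minimal trainer-side
sketch of the new interface follows; the transpiled program plus the
train_reader, feeder, and avg_cost objects are assumed to be defined
elsewhere, and those names are illustrative only:

    import paddle.fluid as fluid

    exe = fluid.Executor(fluid.CPUPlace())
    exe.run(fluid.default_startup_program())

    for data in train_reader():
        exe.run(fluid.default_main_program(),
                feed=feeder.feed(data),
                fetch_list=[avg_cost])

    # close() sends COMPLETE to every pserver at most once per client
    # (guarded by completed_mutex_ in PATCH 4/6) and marks the executor
    # unusable; a later run() raises
    # RuntimeError("Attempted to use a closed Executor").
    exe.close()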