From b4dd4c048d1d121109f9f7f03c91113e02b4f5d0 Mon Sep 17 00:00:00 2001 From: Xin Pan Date: Mon, 21 May 2018 21:59:52 -0700 Subject: [PATCH 1/6] multi-thread handlerequest Experiment on vgg flower, 2 trainers, 1ps. more trainer could have more speedup. After: Pass = 0, Iters = 327, Speed = (7.52) img/s Before: Pass = 0, Iters = 385, Speed = (6.77) img/s --- benchmark/cluster/vgg16/vgg16_fluid.py | 26 +-- cmake/external/grpc.cmake | 2 +- paddle/fluid/framework/executor.cc | 5 +- paddle/fluid/operators/detail/grpc_client.cc | 8 +- paddle/fluid/operators/detail/grpc_server.cc | 154 ++++++++++++------ paddle/fluid/operators/detail/grpc_server.h | 21 ++- paddle/fluid/operators/detail/grpc_service.h | 2 + paddle/fluid/operators/detail/send_recv.proto | 2 +- .../operators/detail/sendrecvop_utils.cc | 8 +- .../operators/detail/variable_response.cc | 8 +- paddle/fluid/platform/device_tracer.cc | 1 - 11 files changed, 158 insertions(+), 79 deletions(-) diff --git a/benchmark/cluster/vgg16/vgg16_fluid.py b/benchmark/cluster/vgg16/vgg16_fluid.py index 05b5f3977c..0f5cd2a253 100644 --- a/benchmark/cluster/vgg16/vgg16_fluid.py +++ b/benchmark/cluster/vgg16/vgg16_fluid.py @@ -38,7 +38,7 @@ def str2bool(v): parser = argparse.ArgumentParser(description=__doc__) parser.add_argument( - '--batch_size', type=int, default=128, help="Batch size for training.") + '--batch_size', type=int, default=16, help="Batch size for training.") parser.add_argument( '--learning_rate', type=float, @@ -61,7 +61,7 @@ parser.add_argument( parser.add_argument( '--data_set', type=str, - default='cifar10', + default='flowers', choices=['cifar10', 'flowers'], help='Optional dataset for benchmark.') parser.add_argument( @@ -200,26 +200,30 @@ def main(): fetch_list=[avg_cost, batch_acc, batch_size]) return loss, acc, b_size - if args.profile and args.task_index == 0: - # warmup. - for batch_id, data in enumerate(train_reader()): - if batch_id > 5: break - run_step(batch_id, data) - with profiler.profiler('All', 'total', '/tmp/profile_vgg'): + if args.profile: + with profiler.profiler('All', 'total', + '/tmp/profile_vgg_%d' % args.task_index): for batch_id, data in enumerate(train_reader()): - if batch_id > 5: break + if batch_id > 4: break run_step(batch_id, data) + total_time = 0.0 + count = 0 for batch_id, data in enumerate(train_reader()): ts = time.time() loss, acc, b_size = run_step(batch_id, data) iters += 1 num_samples += len(data) train_pass_acc.add(value=acc, weight=b_size) + + duration = time.time() - ts + total_time += duration + count += len(data) print( "Pass = %d, Iters = %d, Loss = %f, Accuracy = %f, " - "Speed = %.2f img/s" % (pass_id, iters, loss, acc, - len(data) / (time.time() - ts)) + "Speed = %.2f (%.2f) img/s" % (pass_id, iters, loss, acc, + len(data) / duration, + count / total_time) ) # The accuracy is the accumulation of batches, but not the current batch. pass_elapsed = time.time() - start_time diff --git a/cmake/external/grpc.cmake b/cmake/external/grpc.cmake index e90948782b..ef520b1287 100644 --- a/cmake/external/grpc.cmake +++ b/cmake/external/grpc.cmake @@ -33,7 +33,7 @@ ExternalProject_Add( extern_grpc DEPENDS protobuf zlib GIT_REPOSITORY "https://github.com/grpc/grpc.git" - GIT_TAG "v1.10.x" + GIT_TAG "v1.8.x" PREFIX ${GRPC_SOURCES_DIR} UPDATE_COMMAND "" CONFIGURE_COMMAND "" diff --git a/paddle/fluid/framework/executor.cc b/paddle/fluid/framework/executor.cc index 4e431561f8..55be9b6c3b 100644 --- a/paddle/fluid/framework/executor.cc +++ b/paddle/fluid/framework/executor.cc @@ -350,12 +350,9 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope, } } } - platform::DeviceContextPool::Instance().Get(place_)->Wait(); + // platform::DeviceContextPool::Instance().Get(place_)->Wait(); if (create_vars && create_local_scope) { scope->DeleteScope(local_scope); - } else { - // Delete the local scopes created in operators. - scope->DropKids(); } if (FLAGS_benchmark) { VLOG(2) << "-------------------------------------------------------"; diff --git a/paddle/fluid/operators/detail/grpc_client.cc b/paddle/fluid/operators/detail/grpc_client.cc index ae60ab1532..47892b1bcc 100644 --- a/paddle/fluid/operators/detail/grpc_client.cc +++ b/paddle/fluid/operators/detail/grpc_client.cc @@ -19,6 +19,7 @@ limitations under the License. */ #include #include "paddle/fluid/framework/threadpool.h" +#include "paddle/fluid/platform/profiler.h" namespace paddle { namespace operators { @@ -196,9 +197,14 @@ bool RPCClient::Wait() { const size_t kReqCnt = req_count_; bool a[kReqCnt]; std::vector> waits(req_count_); + std::mutex mu; for (int i = 0; i < req_count_; i++) { - waits[i] = framework::AsyncIO([i, &a, this] { a[i] = Proceed(); }); + waits[i] = framework::AsyncIO([i, &a, &mu, this] { + bool ret = Proceed(); + std::lock_guard l(mu); + a[i] = ret; + }); } for (int i = 0; i < req_count_; i++) { diff --git a/paddle/fluid/operators/detail/grpc_server.cc b/paddle/fluid/operators/detail/grpc_server.cc index eb114a47d9..604321cd1f 100644 --- a/paddle/fluid/operators/detail/grpc_server.cc +++ b/paddle/fluid/operators/detail/grpc_server.cc @@ -22,7 +22,10 @@ using ::grpc::ServerAsyncResponseWriter; namespace paddle { namespace operators { namespace detail { - +namespace { +const int kNumHandleSendThreads = 20; +const int kNumHandleGetThreads = 20; +} // namespace enum CallStatus { PROCESS = 0, FINISH }; // reference: @@ -63,18 +66,20 @@ class RequestSend final : public RequestBase { explicit RequestSend(GrpcService::AsyncService* service, ::grpc::ServerCompletionQueue* cq, bool sync_mode, framework::Scope* scope, ReceivedQueue* queue, - const platform::DeviceContext* dev_ctx) + const platform::DeviceContext* dev_ctx, int i) : RequestBase(service, cq, sync_mode, dev_ctx), queue_(queue), - responder_(&ctx_) { + responder_(&ctx_), + i_(i) { if (sync_mode_) { request_.reset(new VariableResponse(scope, dev_ctx_, false)); } else { request_.reset(new VariableResponse(scope, dev_ctx_, true)); } int method_id = static_cast(detail::GrpcMethod::kSendVariable); - service_->RequestAsyncUnary(method_id, &ctx_, request_.get(), &responder_, - cq_, cq_, this); + service_->RequestAsyncUnary( + method_id, &ctx_, request_.get(), &responder_, cq_, cq_, + reinterpret_cast(static_cast(i))); } virtual ~RequestSend() {} @@ -86,15 +91,17 @@ class RequestSend final : public RequestBase { VLOG(3) << "RequestSend " << var_name; queue_->Push(std::make_pair(var_name, request_)); - sendrecv::VoidMessage reply; - responder_.Finish(reply, ::grpc::Status::OK, this); status_ = FINISH; + responder_.Finish(reply_, ::grpc::Status::OK, + reinterpret_cast(static_cast(i_))); } protected: + sendrecv::VoidMessage reply_; std::shared_ptr request_; ReceivedQueue* queue_; ServerAsyncResponseWriter responder_; + int i_; }; class RequestGet final : public RequestBase { @@ -103,14 +110,16 @@ class RequestGet final : public RequestBase { ::grpc::ServerCompletionQueue* cq, bool sync_mode, framework::Scope* scope, const platform::DeviceContext* dev_ctx, - framework::BlockingQueue* queue) + framework::BlockingQueue* queue, int i) : RequestBase(service, cq, sync_mode, dev_ctx), responder_(&ctx_), scope_(scope), - queue_(queue) { + queue_(queue), + i_(i) { auto method_id = static_cast(detail::GrpcMethod::kGetVariable); - service_->RequestAsyncUnary(method_id, &ctx_, &request_, &responder_, cq_, - cq_, this); + service_->RequestAsyncUnary( + method_id, &ctx_, &request_, &responder_, cq_, cq_, + reinterpret_cast(static_cast(i))); } virtual ~RequestGet() {} @@ -123,13 +132,13 @@ class RequestGet final : public RequestBase { VLOG(3) << "RequestGet " << var_name; auto* var = scope_->FindVar(var_name); - ::grpc::ByteBuffer reply; if (var_name != FETCH_BARRIER_MESSAGE) { - SerializeToByteBuffer(var_name, var, *dev_ctx_, &reply); + SerializeToByteBuffer(var_name, var, *dev_ctx_, &reply_); } - responder_.Finish(reply, ::grpc::Status::OK, this); status_ = FINISH; + responder_.Finish(reply_, ::grpc::Status::OK, + reinterpret_cast(static_cast(i_))); if (var_name == FETCH_BARRIER_MESSAGE) { sendrecv::VariableMessage msg; @@ -140,9 +149,11 @@ class RequestGet final : public RequestBase { protected: sendrecv::VariableMessage request_; + ::grpc::ByteBuffer reply_; ServerAsyncResponseWriter<::grpc::ByteBuffer> responder_; framework::Scope* scope_; framework::BlockingQueue* queue_; + int i_; }; class RequestPrefetch final : public RequestBase { @@ -153,13 +164,15 @@ class RequestPrefetch final : public RequestBase { const platform::DeviceContext* dev_ctx, framework::Executor* executor, framework::ProgramDesc* program, - framework::ExecutorPrepareContext* prefetch_ctx) + framework::ExecutorPrepareContext* prefetch_ctx, + int i) : RequestBase(service, cq, sync_mode, dev_ctx), responder_(&ctx_), scope_(scope), executor_(executor), program_(program), - prefetch_ctx_(prefetch_ctx) { + prefetch_ctx_(prefetch_ctx), + i_(i) { if (sync_mode_) { request_.reset(new VariableResponse(scope, dev_ctx_, false)); } else { @@ -188,7 +201,8 @@ class RequestPrefetch final : public RequestBase { SerializeToByteBuffer(var_name, var, *dev_ctx_, &reply); - responder_.Finish(reply, ::grpc::Status::OK, this); + responder_.Finish(reply, ::grpc::Status::OK, + reinterpret_cast(static_cast(i_))); status_ = FINISH; } @@ -199,6 +213,7 @@ class RequestPrefetch final : public RequestBase { framework::Executor* executor_; framework::ProgramDesc* program_; framework::ExecutorPrepareContext* prefetch_ctx_; + int i_; }; void AsyncGRPCServer::WaitClientGet(int count) { @@ -232,20 +247,33 @@ void AsyncGRPCServer::RunSyncUpdate() { LOG(INFO) << "Server listening on " << address_ << " selected port: " << selected_port_; - std::function send_register = - std::bind(&AsyncGRPCServer::TryToRegisterNewSendOne, this); - std::function get_register = - std::bind(&AsyncGRPCServer::TryToRegisterNewGetOne, this); - std::function prefetch_register = - std::bind(&AsyncGRPCServer::TryToRegisterNewPrefetchOne, this); + std::function send_register = std::bind( + &AsyncGRPCServer::TryToRegisterNewSendOne, this, std::placeholders::_1); + std::function get_register = std::bind( + &AsyncGRPCServer::TryToRegisterNewGetOne, this, std::placeholders::_1); + std::function prefetch_register = + std::bind(&AsyncGRPCServer::TryToRegisterNewPrefetchOne, this, + std::placeholders::_1); + + for (int i = 0; i < kSendReqsBufSize; ++i) { + TryToRegisterNewSendOne(i); + } + for (int i = 0; i < kGetReqsBufSize; ++i) { + TryToRegisterNewGetOne(i); + } + + for (int i = 0; i < kNumHandleSendThreads; ++i) { + t_sends_.emplace_back( + new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this, + cq_send_.get(), "cq_send", send_register))); + } + for (int i = 0; i < kNumHandleGetThreads; ++i) { + t_gets_.emplace_back( + new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this, + cq_get_.get(), "cq_get", get_register))); + } // TODO(wuyi): Run these "HandleRequest" in thread pool - t_send_.reset( - new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this, - cq_send_.get(), "cq_send", send_register))); - t_get_.reset( - new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this, - cq_get_.get(), "cq_get", get_register))); t_prefetch_.reset(new std::thread( std::bind(&AsyncGRPCServer::HandleRequest, this, cq_prefetch_.get(), "cq_prefetch", prefetch_register))); @@ -257,8 +285,27 @@ void AsyncGRPCServer::RunSyncUpdate() { condition_ready_.notify_all(); // wait server server_->Wait(); - t_send_->join(); - t_get_->join(); + for (int i = 0; i < kNumHandleSendThreads; ++i) { + t_sends_[i]->join(); + } + for (int i = 0; i < kNumHandleGetThreads; ++i) { + t_gets_[i]->join(); + } + { + std::lock_guard l(cq_mutex_); + for (int i = 0; i < kSendReqsBufSize; ++i) { + if (send_reqs_[i]) { + delete send_reqs_[i]; + send_reqs_[i] = nullptr; + } + } + for (int i = 0; i < kGetReqsBufSize; ++i) { + if (get_reqs_[i]) { + delete get_reqs_[i]; + get_reqs_[i] = nullptr; + } + } + } t_prefetch_->join(); } @@ -276,47 +323,47 @@ void AsyncGRPCServer::ShutDown() { server_->Shutdown(); } -void AsyncGRPCServer::TryToRegisterNewSendOne() { +void AsyncGRPCServer::TryToRegisterNewSendOne(int i) { std::unique_lock lock(cq_mutex_); if (is_shut_down_) { VLOG(3) << "shutdown, do not TryToRegisterNewSendOne"; return; } RequestSend* send = new RequestSend(&service_, cq_send_.get(), sync_mode_, - scope_, &var_recv_queue_, dev_ctx_); + scope_, &var_recv_queue_, dev_ctx_, i); + send_reqs_[i] = static_cast(send); VLOG(4) << "Create RequestSend status:" << send->Status(); } -void AsyncGRPCServer::TryToRegisterNewGetOne() { +void AsyncGRPCServer::TryToRegisterNewGetOne(int i) { std::unique_lock lock(cq_mutex_); if (is_shut_down_) { VLOG(3) << "shutdown, do not TryToRegisterNewGetOne"; return; } RequestGet* get = new RequestGet(&service_, cq_get_.get(), sync_mode_, scope_, - dev_ctx_, &var_get_queue_); + dev_ctx_, &var_get_queue_, i); + get_reqs_[i] = static_cast(get); VLOG(4) << "Create RequestGet status:" << get->Status(); } -void AsyncGRPCServer::TryToRegisterNewPrefetchOne() { +void AsyncGRPCServer::TryToRegisterNewPrefetchOne(int i) { std::unique_lock lock(cq_mutex_); if (is_shut_down_) { VLOG(3) << "shutdown, do not TryToRegisterNewPrefetchOne"; return; } - RequestPrefetch* prefetch = - new RequestPrefetch(&service_, cq_prefetch_.get(), sync_mode_, scope_, - dev_ctx_, executor_, program_, prefetch_ctx_.get()); + RequestPrefetch* prefetch = new RequestPrefetch( + &service_, cq_prefetch_.get(), sync_mode_, scope_, dev_ctx_, executor_, + program_, prefetch_ctx_.get(), i); VLOG(4) << "Create RequestPrefetch status:" << prefetch->Status(); } // FIXME(typhoonzero): change cq_name to enum. -void AsyncGRPCServer::HandleRequest(::grpc::ServerCompletionQueue* cq, - const std::string& cq_name, - std::function TryToRegisterNewOne) { - TryToRegisterNewOne(); - +void AsyncGRPCServer::HandleRequest( + ::grpc::ServerCompletionQueue* cq, const std::string& cq_name, + std::function TryToRegisterNewOne) { void* tag = NULL; bool ok = false; @@ -327,8 +374,7 @@ void AsyncGRPCServer::HandleRequest(::grpc::ServerCompletionQueue* cq, break; } VLOG(3) << "HandleRequest for " << cq_name << " get Next"; - - PADDLE_ENFORCE(tag); + int i = static_cast(reinterpret_cast(tag)); if (sync_mode_) { // FIXME(typhoonzero): de-couple the barriers with recv_op @@ -337,7 +383,17 @@ void AsyncGRPCServer::HandleRequest(::grpc::ServerCompletionQueue* cq, VLOG(3) << "HandleRequest for " << cq_name << " after WaitCond"; } - RequestBase* base = reinterpret_cast(tag); + RequestBase* base = nullptr; + { + std::lock_guard l(cq_mutex_); + if (cq_name == "cq_get") { + base = get_reqs_[i]; + } else if (cq_name == "cq_send") { + base = send_reqs_[i]; + } else { + CHECK(false); + } + } // reference: // https://github.com/tensorflow/tensorflow/issues/5596 // https://groups.google.com/forum/#!topic/grpc-io/xftlRy-IQwM @@ -345,19 +401,19 @@ void AsyncGRPCServer::HandleRequest(::grpc::ServerCompletionQueue* cq, if (!ok) { LOG(WARNING) << cq_name << " recv no regular event:argument name[" << base->GetReqName() << "]"; - TryToRegisterNewOne(); + TryToRegisterNewOne(i); delete base; continue; } switch (base->Status()) { case PROCESS: { - TryToRegisterNewOne(); base->Process(); VLOG(4) << cq_name << " PROCESS status:" << base->Status(); break; } case FINISH: { + TryToRegisterNewOne(i); VLOG(4) << cq_name << " FINISH status:" << base->Status(); delete base; break; diff --git a/paddle/fluid/operators/detail/grpc_server.h b/paddle/fluid/operators/detail/grpc_server.h index 238aaa2963..d70be1b7ce 100644 --- a/paddle/fluid/operators/detail/grpc_server.h +++ b/paddle/fluid/operators/detail/grpc_server.h @@ -17,6 +17,7 @@ limitations under the License. */ #include #include // NOLINT #include +#include #include "grpc++/grpc++.h" #include "paddle/fluid/framework/blocking_queue.h" @@ -30,6 +31,7 @@ limitations under the License. */ #include "paddle/fluid/operators/detail/send_recv.grpc.pb.h" #include "paddle/fluid/operators/detail/send_recv.pb.h" #include "paddle/fluid/operators/detail/sendrecvop_utils.h" +#include "paddle/fluid/platform/profiler.h" namespace paddle { namespace operators { @@ -82,19 +84,25 @@ class AsyncGRPCServer final { protected: void HandleRequest(::grpc::ServerCompletionQueue *cq, const std::string &cq_name, - std::function TryToRegisterNewOne); - void TryToRegisterNewSendOne(); - void TryToRegisterNewGetOne(); - void TryToRegisterNewPrefetchOne(); + std::function TryToRegisterNewOne); + void TryToRegisterNewSendOne(int i); + void TryToRegisterNewGetOne(int i); + void TryToRegisterNewPrefetchOne(int i); void ShutdownQueue(); private: + static const int kSendReqsBufSize = 100; + static const int kGetReqsBufSize = 100; + std::mutex cq_mutex_; volatile bool is_shut_down_ = false; std::unique_ptr<::grpc::ServerCompletionQueue> cq_send_; std::unique_ptr<::grpc::ServerCompletionQueue> cq_get_; std::unique_ptr<::grpc::ServerCompletionQueue> cq_prefetch_; + RequestBase *send_reqs_[kSendReqsBufSize]; + RequestBase *get_reqs_[kGetReqsBufSize]; + GrpcService::AsyncService service_; std::unique_ptr<::grpc::Server> server_; @@ -113,8 +121,9 @@ class AsyncGRPCServer final { mutable int barrier_cond_step_; std::condition_variable barrier_condition_; - std::unique_ptr t_send_; - std::unique_ptr t_get_; + std::vector> t_sends_; + std::vector> t_gets_; + std::unique_ptr t_prefetch_; std::unique_ptr prefetch_ctx_; diff --git a/paddle/fluid/operators/detail/grpc_service.h b/paddle/fluid/operators/detail/grpc_service.h index e6dab2f5a3..e0505c2b9d 100644 --- a/paddle/fluid/operators/detail/grpc_service.h +++ b/paddle/fluid/operators/detail/grpc_service.h @@ -25,6 +25,8 @@ #include #include "paddle/fluid/operators/detail/variable_response.h" +#include "paddle/fluid/platform/profiler.h" + // NOTE: This method was originally created by tensorflow // (https://github.com/tensorflow/tensorflow/) we borrow this // method and did some modifications so that we can parse gRPC diff --git a/paddle/fluid/operators/detail/send_recv.proto b/paddle/fluid/operators/detail/send_recv.proto index 9478c5702b..078181909d 100644 --- a/paddle/fluid/operators/detail/send_recv.proto +++ b/paddle/fluid/operators/detail/send_recv.proto @@ -73,7 +73,7 @@ message VariableMessage { // If true, the ps server will start profiling, the ps // server stops profiling and generates a profile to /tmp/profile_ps_* // when profile switches from true to false. - bool profile = 11; + int64 profile = 11; } message VoidMessage {} diff --git a/paddle/fluid/operators/detail/sendrecvop_utils.cc b/paddle/fluid/operators/detail/sendrecvop_utils.cc index 07c43554bc..a9ea80c917 100644 --- a/paddle/fluid/operators/detail/sendrecvop_utils.cc +++ b/paddle/fluid/operators/detail/sendrecvop_utils.cc @@ -122,7 +122,13 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var, // 1 trainer returns true for ShouldSendProfileState(). It tells PS // servers the trainer's profiling state so that PS can follow the // trainer. - request.set_profile(platform::IsProfileEnabled()); + if (platform::ShouldSendProfileState()) { + if (platform::IsProfileEnabled()) { + request.set_profile(1); + } else { + request.set_profile(2); + } + } if (!out_name.empty()) { request.set_out_varname(out_name); } diff --git a/paddle/fluid/operators/detail/variable_response.cc b/paddle/fluid/operators/detail/variable_response.cc index 462e303096..2dfd9b2621 100644 --- a/paddle/fluid/operators/detail/variable_response.cc +++ b/paddle/fluid/operators/detail/variable_response.cc @@ -449,8 +449,8 @@ int VariableResponse::Parse(Source* source) { break; } case sendrecv::VariableMessage::kProfileFieldNumber: { - bool profiling; - if (!input.ReadRaw(reinterpret_cast(&profiling), 1)) { + uint64_t profiling = 0; + if (!input.ReadVarint64(&profiling)) { return tag; } meta_.set_profile(profiling); @@ -458,9 +458,9 @@ int VariableResponse::Parse(Source* source) { if (listener_id <= 0) { break; } - if (profiling && !platform::IsProfileEnabled()) { + if (profiling == 1 && !platform::IsProfileEnabled()) { platform::EnableProfiler(platform::ProfilerState::kCPU); - } else if (!profiling && platform::IsProfileEnabled()) { + } else if (profiling == 2 && platform::IsProfileEnabled()) { // TODO(panyx0718): Should we allow to customize file dir. platform::DisableProfiler( platform::EventSortingKey::kDefault, diff --git a/paddle/fluid/platform/device_tracer.cc b/paddle/fluid/platform/device_tracer.cc index c9e1063168..1a9be044e0 100644 --- a/paddle/fluid/platform/device_tracer.cc +++ b/paddle/fluid/platform/device_tracer.cc @@ -245,7 +245,6 @@ class DeviceTracerImpl : public DeviceTracer { void Enable() { std::lock_guard l(trace_mu_); if (enabled_) { - fprintf(stderr, "DeviceTracer already enabled\n"); return; } EnableActivity(); From 11fe3c796be0940e40c3fc96478d0da40c6afde6 Mon Sep 17 00:00:00 2001 From: Xin Pan Date: Tue, 22 May 2018 00:39:45 -0700 Subject: [PATCH 2/6] clean up --- benchmark/cluster/vgg16/vgg16_fluid.py | 2 +- cmake/external/grpc.cmake | 2 +- paddle/fluid/operators/detail/grpc_server.cc | 64 ++++++++------------ 3 files changed, 27 insertions(+), 41 deletions(-) diff --git a/benchmark/cluster/vgg16/vgg16_fluid.py b/benchmark/cluster/vgg16/vgg16_fluid.py index 0f5cd2a253..e9360ab4c7 100644 --- a/benchmark/cluster/vgg16/vgg16_fluid.py +++ b/benchmark/cluster/vgg16/vgg16_fluid.py @@ -204,7 +204,7 @@ def main(): with profiler.profiler('All', 'total', '/tmp/profile_vgg_%d' % args.task_index): for batch_id, data in enumerate(train_reader()): - if batch_id > 4: break + if batch_id > 5: break run_step(batch_id, data) total_time = 0.0 diff --git a/cmake/external/grpc.cmake b/cmake/external/grpc.cmake index ef520b1287..e90948782b 100644 --- a/cmake/external/grpc.cmake +++ b/cmake/external/grpc.cmake @@ -33,7 +33,7 @@ ExternalProject_Add( extern_grpc DEPENDS protobuf zlib GIT_REPOSITORY "https://github.com/grpc/grpc.git" - GIT_TAG "v1.8.x" + GIT_TAG "v1.10.x" PREFIX ${GRPC_SOURCES_DIR} UPDATE_COMMAND "" CONFIGURE_COMMAND "" diff --git a/paddle/fluid/operators/detail/grpc_server.cc b/paddle/fluid/operators/detail/grpc_server.cc index 604321cd1f..c2c1df4cd6 100644 --- a/paddle/fluid/operators/detail/grpc_server.cc +++ b/paddle/fluid/operators/detail/grpc_server.cc @@ -66,11 +66,11 @@ class RequestSend final : public RequestBase { explicit RequestSend(GrpcService::AsyncService* service, ::grpc::ServerCompletionQueue* cq, bool sync_mode, framework::Scope* scope, ReceivedQueue* queue, - const platform::DeviceContext* dev_ctx, int i) + const platform::DeviceContext* dev_ctx, int req_id) : RequestBase(service, cq, sync_mode, dev_ctx), queue_(queue), responder_(&ctx_), - i_(i) { + req_id_(req_id) { if (sync_mode_) { request_.reset(new VariableResponse(scope, dev_ctx_, false)); } else { @@ -79,7 +79,7 @@ class RequestSend final : public RequestBase { int method_id = static_cast(detail::GrpcMethod::kSendVariable); service_->RequestAsyncUnary( method_id, &ctx_, request_.get(), &responder_, cq_, cq_, - reinterpret_cast(static_cast(i))); + reinterpret_cast(static_cast(req_id))); } virtual ~RequestSend() {} @@ -93,7 +93,7 @@ class RequestSend final : public RequestBase { status_ = FINISH; responder_.Finish(reply_, ::grpc::Status::OK, - reinterpret_cast(static_cast(i_))); + reinterpret_cast(static_cast(req_id_))); } protected: @@ -101,7 +101,7 @@ class RequestSend final : public RequestBase { std::shared_ptr request_; ReceivedQueue* queue_; ServerAsyncResponseWriter responder_; - int i_; + int req_id_; }; class RequestGet final : public RequestBase { @@ -110,16 +110,17 @@ class RequestGet final : public RequestBase { ::grpc::ServerCompletionQueue* cq, bool sync_mode, framework::Scope* scope, const platform::DeviceContext* dev_ctx, - framework::BlockingQueue* queue, int i) + framework::BlockingQueue* queue, + int req_id) : RequestBase(service, cq, sync_mode, dev_ctx), responder_(&ctx_), scope_(scope), queue_(queue), - i_(i) { + req_id_(req_id) { auto method_id = static_cast(detail::GrpcMethod::kGetVariable); service_->RequestAsyncUnary( method_id, &ctx_, &request_, &responder_, cq_, cq_, - reinterpret_cast(static_cast(i))); + reinterpret_cast(static_cast(req_id_))); } virtual ~RequestGet() {} @@ -138,7 +139,7 @@ class RequestGet final : public RequestBase { status_ = FINISH; responder_.Finish(reply_, ::grpc::Status::OK, - reinterpret_cast(static_cast(i_))); + reinterpret_cast(static_cast(req_id_))); if (var_name == FETCH_BARRIER_MESSAGE) { sendrecv::VariableMessage msg; @@ -153,7 +154,7 @@ class RequestGet final : public RequestBase { ServerAsyncResponseWriter<::grpc::ByteBuffer> responder_; framework::Scope* scope_; framework::BlockingQueue* queue_; - int i_; + int req_id_; }; class RequestPrefetch final : public RequestBase { @@ -165,14 +166,14 @@ class RequestPrefetch final : public RequestBase { framework::Executor* executor, framework::ProgramDesc* program, framework::ExecutorPrepareContext* prefetch_ctx, - int i) + int req_id) : RequestBase(service, cq, sync_mode, dev_ctx), responder_(&ctx_), scope_(scope), executor_(executor), program_(program), prefetch_ctx_(prefetch_ctx), - i_(i) { + req_id_(req_id) { if (sync_mode_) { request_.reset(new VariableResponse(scope, dev_ctx_, false)); } else { @@ -202,7 +203,7 @@ class RequestPrefetch final : public RequestBase { SerializeToByteBuffer(var_name, var, *dev_ctx_, &reply); responder_.Finish(reply, ::grpc::Status::OK, - reinterpret_cast(static_cast(i_))); + reinterpret_cast(static_cast(req_id_))); status_ = FINISH; } @@ -213,7 +214,7 @@ class RequestPrefetch final : public RequestBase { framework::Executor* executor_; framework::ProgramDesc* program_; framework::ExecutorPrepareContext* prefetch_ctx_; - int i_; + int req_id_; }; void AsyncGRPCServer::WaitClientGet(int count) { @@ -291,21 +292,6 @@ void AsyncGRPCServer::RunSyncUpdate() { for (int i = 0; i < kNumHandleGetThreads; ++i) { t_gets_[i]->join(); } - { - std::lock_guard l(cq_mutex_); - for (int i = 0; i < kSendReqsBufSize; ++i) { - if (send_reqs_[i]) { - delete send_reqs_[i]; - send_reqs_[i] = nullptr; - } - } - for (int i = 0; i < kGetReqsBufSize; ++i) { - if (get_reqs_[i]) { - delete get_reqs_[i]; - get_reqs_[i] = nullptr; - } - } - } t_prefetch_->join(); } @@ -335,19 +321,19 @@ void AsyncGRPCServer::TryToRegisterNewSendOne(int i) { VLOG(4) << "Create RequestSend status:" << send->Status(); } -void AsyncGRPCServer::TryToRegisterNewGetOne(int i) { +void AsyncGRPCServer::TryToRegisterNewGetOne(int req_id) { std::unique_lock lock(cq_mutex_); if (is_shut_down_) { VLOG(3) << "shutdown, do not TryToRegisterNewGetOne"; return; } RequestGet* get = new RequestGet(&service_, cq_get_.get(), sync_mode_, scope_, - dev_ctx_, &var_get_queue_, i); - get_reqs_[i] = static_cast(get); + dev_ctx_, &var_get_queue_, req_id); + get_reqs_[req_id] = static_cast(get); VLOG(4) << "Create RequestGet status:" << get->Status(); } -void AsyncGRPCServer::TryToRegisterNewPrefetchOne(int i) { +void AsyncGRPCServer::TryToRegisterNewPrefetchOne(int req_id) { std::unique_lock lock(cq_mutex_); if (is_shut_down_) { VLOG(3) << "shutdown, do not TryToRegisterNewPrefetchOne"; @@ -355,7 +341,7 @@ void AsyncGRPCServer::TryToRegisterNewPrefetchOne(int i) { } RequestPrefetch* prefetch = new RequestPrefetch( &service_, cq_prefetch_.get(), sync_mode_, scope_, dev_ctx_, executor_, - program_, prefetch_ctx_.get(), i); + program_, prefetch_ctx_.get(), req_id); VLOG(4) << "Create RequestPrefetch status:" << prefetch->Status(); } @@ -374,7 +360,7 @@ void AsyncGRPCServer::HandleRequest( break; } VLOG(3) << "HandleRequest for " << cq_name << " get Next"; - int i = static_cast(reinterpret_cast(tag)); + int req_id = static_cast(reinterpret_cast(tag)); if (sync_mode_) { // FIXME(typhoonzero): de-couple the barriers with recv_op @@ -387,9 +373,9 @@ void AsyncGRPCServer::HandleRequest( { std::lock_guard l(cq_mutex_); if (cq_name == "cq_get") { - base = get_reqs_[i]; + base = get_reqs_[req_id]; } else if (cq_name == "cq_send") { - base = send_reqs_[i]; + base = send_reqs_[req_id]; } else { CHECK(false); } @@ -401,7 +387,7 @@ void AsyncGRPCServer::HandleRequest( if (!ok) { LOG(WARNING) << cq_name << " recv no regular event:argument name[" << base->GetReqName() << "]"; - TryToRegisterNewOne(i); + TryToRegisterNewOne(req_id); delete base; continue; } @@ -413,7 +399,7 @@ void AsyncGRPCServer::HandleRequest( break; } case FINISH: { - TryToRegisterNewOne(i); + TryToRegisterNewOne(req_id); VLOG(4) << cq_name << " FINISH status:" << base->Status(); delete base; break; From 722c078b154b0b9dd97bb4f9c0bfe391348143a7 Mon Sep 17 00:00:00 2001 From: Xin Pan Date: Tue, 22 May 2018 04:47:47 -0700 Subject: [PATCH 3/6] fix test and clean up --- paddle/fluid/operators/detail/grpc_server.cc | 37 ++++++++++++-------- paddle/fluid/operators/detail/grpc_server.h | 3 ++ 2 files changed, 25 insertions(+), 15 deletions(-) diff --git a/paddle/fluid/operators/detail/grpc_server.cc b/paddle/fluid/operators/detail/grpc_server.cc index c2c1df4cd6..51ddda6255 100644 --- a/paddle/fluid/operators/detail/grpc_server.cc +++ b/paddle/fluid/operators/detail/grpc_server.cc @@ -25,6 +25,7 @@ namespace detail { namespace { const int kNumHandleSendThreads = 20; const int kNumHandleGetThreads = 20; +const int kNumHandlePrefetchThreads = 1; } // namespace enum CallStatus { PROCESS = 0, FINISH }; @@ -180,8 +181,9 @@ class RequestPrefetch final : public RequestBase { request_.reset(new VariableResponse(scope, dev_ctx_, true)); } int method_id = static_cast(detail::GrpcMethod::kPrefetchVariable); - service_->RequestAsyncUnary(method_id, &ctx_, request_.get(), &responder_, - cq_, cq_, this); + service_->RequestAsyncUnary( + method_id, &ctx_, request_.get(), &responder_, cq_, cq_, + reinterpret_cast(static_cast(req_id_))); } virtual ~RequestPrefetch() {} @@ -190,7 +192,6 @@ class RequestPrefetch final : public RequestBase { virtual void Process() { // prefetch process... - ::grpc::ByteBuffer reply; std::string var_name = request_->OutVarname(); VLOG(3) << "RequestPrefetch " << var_name; @@ -200,15 +201,16 @@ class RequestPrefetch final : public RequestBase { InitializeVariable(var, var_desc->GetType()); executor_->RunPreparedContext(prefetch_ctx_, scope_); - SerializeToByteBuffer(var_name, var, *dev_ctx_, &reply); + SerializeToByteBuffer(var_name, var, *dev_ctx_, &reply_); - responder_.Finish(reply, ::grpc::Status::OK, - reinterpret_cast(static_cast(req_id_))); status_ = FINISH; + responder_.Finish(reply_, ::grpc::Status::OK, + reinterpret_cast(static_cast(req_id_))); } protected: std::shared_ptr request_; + ::grpc::ByteBuffer reply_; ServerAsyncResponseWriter<::grpc::ByteBuffer> responder_; framework::Scope* scope_; framework::Executor* executor_; @@ -262,6 +264,9 @@ void AsyncGRPCServer::RunSyncUpdate() { for (int i = 0; i < kGetReqsBufSize; ++i) { TryToRegisterNewGetOne(i); } + for (int i = 0; i < kPrefetchReqsBufSize; ++i) { + TryToRegisterNewPrefetchOne(i); + } for (int i = 0; i < kNumHandleSendThreads; ++i) { t_sends_.emplace_back( @@ -273,12 +278,11 @@ void AsyncGRPCServer::RunSyncUpdate() { new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this, cq_get_.get(), "cq_get", get_register))); } - - // TODO(wuyi): Run these "HandleRequest" in thread pool - t_prefetch_.reset(new std::thread( - std::bind(&AsyncGRPCServer::HandleRequest, this, cq_prefetch_.get(), - "cq_prefetch", prefetch_register))); - + for (int i = 0; i < kNumHandlePrefetchThreads; ++i) { + t_prefetchs_.emplace_back(new std::thread( + std::bind(&AsyncGRPCServer::HandleRequest, this, cq_prefetch_.get(), + "cq_prefetch", prefetch_register))); + } { std::lock_guard lock(this->mutex_ready_); ready_ = 1; @@ -292,7 +296,9 @@ void AsyncGRPCServer::RunSyncUpdate() { for (int i = 0; i < kNumHandleGetThreads; ++i) { t_gets_[i]->join(); } - t_prefetch_->join(); + for (int i = 0; i < kNumHandlePrefetchThreads; ++i) { + t_prefetchs_[i]->join(); + } } void AsyncGRPCServer::ShutdownQueue() { @@ -342,6 +348,7 @@ void AsyncGRPCServer::TryToRegisterNewPrefetchOne(int req_id) { RequestPrefetch* prefetch = new RequestPrefetch( &service_, cq_prefetch_.get(), sync_mode_, scope_, dev_ctx_, executor_, program_, prefetch_ctx_.get(), req_id); + prefetch_reqs_[req_id] = static_cast(prefetch); VLOG(4) << "Create RequestPrefetch status:" << prefetch->Status(); } @@ -376,8 +383,8 @@ void AsyncGRPCServer::HandleRequest( base = get_reqs_[req_id]; } else if (cq_name == "cq_send") { base = send_reqs_[req_id]; - } else { - CHECK(false); + } else if (cq_name == "cq_prefetch") { + base = prefetch_reqs_[req_id]; } } // reference: diff --git a/paddle/fluid/operators/detail/grpc_server.h b/paddle/fluid/operators/detail/grpc_server.h index d70be1b7ce..9a60ee5579 100644 --- a/paddle/fluid/operators/detail/grpc_server.h +++ b/paddle/fluid/operators/detail/grpc_server.h @@ -93,6 +93,7 @@ class AsyncGRPCServer final { private: static const int kSendReqsBufSize = 100; static const int kGetReqsBufSize = 100; + static const int kPrefetchReqsBufSize = 10; std::mutex cq_mutex_; volatile bool is_shut_down_ = false; @@ -102,6 +103,7 @@ class AsyncGRPCServer final { RequestBase *send_reqs_[kSendReqsBufSize]; RequestBase *get_reqs_[kGetReqsBufSize]; + RequestBase *prefetch_reqs_[kPrefetchReqsBufSize]; GrpcService::AsyncService service_; std::unique_ptr<::grpc::Server> server_; @@ -123,6 +125,7 @@ class AsyncGRPCServer final { std::vector> t_sends_; std::vector> t_gets_; + std::vector> t_prefetchs_; std::unique_ptr t_prefetch_; From a848303e10b77a61108ec22e48c02d20d4eeafaa Mon Sep 17 00:00:00 2001 From: Xin Pan Date: Tue, 22 May 2018 04:55:21 -0700 Subject: [PATCH 4/6] follow comments --- paddle/fluid/framework/executor.cc | 5 ++++- paddle/fluid/operators/detail/sendrecvop_utils.cc | 8 ++++++-- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/framework/executor.cc b/paddle/fluid/framework/executor.cc index 55be9b6c3b..4e431561f8 100644 --- a/paddle/fluid/framework/executor.cc +++ b/paddle/fluid/framework/executor.cc @@ -350,9 +350,12 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope, } } } - // platform::DeviceContextPool::Instance().Get(place_)->Wait(); + platform::DeviceContextPool::Instance().Get(place_)->Wait(); if (create_vars && create_local_scope) { scope->DeleteScope(local_scope); + } else { + // Delete the local scopes created in operators. + scope->DropKids(); } if (FLAGS_benchmark) { VLOG(2) << "-------------------------------------------------------"; diff --git a/paddle/fluid/operators/detail/sendrecvop_utils.cc b/paddle/fluid/operators/detail/sendrecvop_utils.cc index a9ea80c917..a0d3345685 100644 --- a/paddle/fluid/operators/detail/sendrecvop_utils.cc +++ b/paddle/fluid/operators/detail/sendrecvop_utils.cc @@ -31,6 +31,10 @@ limitations under the License. */ namespace paddle { namespace operators { namespace detail { +namespace { +const int kStartProfile = 1; +const int kStopProfile = 2; +} // namespace using VarMsg = sendrecv::VariableMessage; @@ -124,9 +128,9 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var, // trainer. if (platform::ShouldSendProfileState()) { if (platform::IsProfileEnabled()) { - request.set_profile(1); + request.set_profile(kStartProfile); } else { - request.set_profile(2); + request.set_profile(kStopProfile); } } if (!out_name.empty()) { From 08e4970e458a068c76af8ba89c78403b45c430d0 Mon Sep 17 00:00:00 2001 From: Xin Pan Date: Wed, 23 May 2018 01:18:09 -0700 Subject: [PATCH 5/6] follow comments --- paddle/fluid/operators/detail/grpc_server.cc | 24 ++++++++++--------- paddle/fluid/operators/detail/grpc_server.h | 6 ++--- .../operators/detail/sendrecvop_utils.cc | 8 ++----- .../operators/detail/variable_response.cc | 6 +++-- paddle/fluid/platform/profiler.h | 2 ++ 5 files changed, 24 insertions(+), 22 deletions(-) diff --git a/paddle/fluid/operators/detail/grpc_server.cc b/paddle/fluid/operators/detail/grpc_server.cc index 51ddda6255..58faead2bd 100644 --- a/paddle/fluid/operators/detail/grpc_server.cc +++ b/paddle/fluid/operators/detail/grpc_server.cc @@ -19,14 +19,16 @@ limitations under the License. */ using ::grpc::ServerAsyncResponseWriter; +DEFINE_int32(rpc_server_handle_send_threads, 20, + "Number of threads used to handle send at rpc server."); +DEFINE_int32(rpc_server_handle_get_threads, 20, + "Number of threads used to handle get at rpc server."); +DEFINE_int32(rpc_server_handle_prefetch_threads, 1, + "Number of threads used to handle prefetch at rpc server."); + namespace paddle { namespace operators { namespace detail { -namespace { -const int kNumHandleSendThreads = 20; -const int kNumHandleGetThreads = 20; -const int kNumHandlePrefetchThreads = 1; -} // namespace enum CallStatus { PROCESS = 0, FINISH }; // reference: @@ -268,17 +270,17 @@ void AsyncGRPCServer::RunSyncUpdate() { TryToRegisterNewPrefetchOne(i); } - for (int i = 0; i < kNumHandleSendThreads; ++i) { + for (int i = 0; i < FLAGS_rpc_server_handle_send_threads; ++i) { t_sends_.emplace_back( new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this, cq_send_.get(), "cq_send", send_register))); } - for (int i = 0; i < kNumHandleGetThreads; ++i) { + for (int i = 0; i < FLAGS_rpc_server_handle_get_threads; ++i) { t_gets_.emplace_back( new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this, cq_get_.get(), "cq_get", get_register))); } - for (int i = 0; i < kNumHandlePrefetchThreads; ++i) { + for (int i = 0; i < FLAGS_rpc_server_handle_prefetch_threads; ++i) { t_prefetchs_.emplace_back(new std::thread( std::bind(&AsyncGRPCServer::HandleRequest, this, cq_prefetch_.get(), "cq_prefetch", prefetch_register))); @@ -290,13 +292,13 @@ void AsyncGRPCServer::RunSyncUpdate() { condition_ready_.notify_all(); // wait server server_->Wait(); - for (int i = 0; i < kNumHandleSendThreads; ++i) { + for (int i = 0; i < FLAGS_rpc_server_handle_send_threads; ++i) { t_sends_[i]->join(); } - for (int i = 0; i < kNumHandleGetThreads; ++i) { + for (int i = 0; i < FLAGS_rpc_server_handle_get_threads; ++i) { t_gets_[i]->join(); } - for (int i = 0; i < kNumHandlePrefetchThreads; ++i) { + for (int i = 0; i < FLAGS_rpc_server_handle_prefetch_threads; ++i) { t_prefetchs_[i]->join(); } } diff --git a/paddle/fluid/operators/detail/grpc_server.h b/paddle/fluid/operators/detail/grpc_server.h index 9a60ee5579..bdff9801a9 100644 --- a/paddle/fluid/operators/detail/grpc_server.h +++ b/paddle/fluid/operators/detail/grpc_server.h @@ -85,9 +85,9 @@ class AsyncGRPCServer final { void HandleRequest(::grpc::ServerCompletionQueue *cq, const std::string &cq_name, std::function TryToRegisterNewOne); - void TryToRegisterNewSendOne(int i); - void TryToRegisterNewGetOne(int i); - void TryToRegisterNewPrefetchOne(int i); + void TryToRegisterNewSendOne(int req_id); + void TryToRegisterNewGetOne(int req_id); + void TryToRegisterNewPrefetchOne(int req_id); void ShutdownQueue(); private: diff --git a/paddle/fluid/operators/detail/sendrecvop_utils.cc b/paddle/fluid/operators/detail/sendrecvop_utils.cc index a0d3345685..0601988351 100644 --- a/paddle/fluid/operators/detail/sendrecvop_utils.cc +++ b/paddle/fluid/operators/detail/sendrecvop_utils.cc @@ -31,10 +31,6 @@ limitations under the License. */ namespace paddle { namespace operators { namespace detail { -namespace { -const int kStartProfile = 1; -const int kStopProfile = 2; -} // namespace using VarMsg = sendrecv::VariableMessage; @@ -128,9 +124,9 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var, // trainer. if (platform::ShouldSendProfileState()) { if (platform::IsProfileEnabled()) { - request.set_profile(kStartProfile); + request.set_profile(platform::kEnableProfiler); } else { - request.set_profile(kStopProfile); + request.set_profile(platform::kDisableProfiler); } } if (!out_name.empty()) { diff --git a/paddle/fluid/operators/detail/variable_response.cc b/paddle/fluid/operators/detail/variable_response.cc index 2dfd9b2621..24cb91a3bb 100644 --- a/paddle/fluid/operators/detail/variable_response.cc +++ b/paddle/fluid/operators/detail/variable_response.cc @@ -458,9 +458,11 @@ int VariableResponse::Parse(Source* source) { if (listener_id <= 0) { break; } - if (profiling == 1 && !platform::IsProfileEnabled()) { + if (profiling == platform::kEnableProfiler && + !platform::IsProfileEnabled()) { platform::EnableProfiler(platform::ProfilerState::kCPU); - } else if (profiling == 2 && platform::IsProfileEnabled()) { + } else if (profiling == platform::kDisableProfiler && + platform::IsProfileEnabled()) { // TODO(panyx0718): Should we allow to customize file dir. platform::DisableProfiler( platform::EventSortingKey::kDefault, diff --git a/paddle/fluid/platform/profiler.h b/paddle/fluid/platform/profiler.h index 643bb6183d..bf43925373 100644 --- a/paddle/fluid/platform/profiler.h +++ b/paddle/fluid/platform/profiler.h @@ -116,6 +116,8 @@ void ResetProfiler(); void DisableProfiler(EventSortingKey sorted_key, const std::string& profile_path); +const int kEnableProfiler = 1; +const int kDisableProfiler = 2; // Test if the profiler is currently enabled. bool IsProfileEnabled(); // Whether the trainer should send profiling state to PS. From 2643868c664832b8bec301fe32b93659d4678d5a Mon Sep 17 00:00:00 2001 From: Xin Pan Date: Wed, 23 May 2018 16:20:24 +0800 Subject: [PATCH 6/6] follow comments --- paddle/fluid/operators/detail/send_recv.proto | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/operators/detail/send_recv.proto b/paddle/fluid/operators/detail/send_recv.proto index 078181909d..a244afc46f 100644 --- a/paddle/fluid/operators/detail/send_recv.proto +++ b/paddle/fluid/operators/detail/send_recv.proto @@ -70,9 +70,9 @@ message VariableMessage { bytes rows = 9; // Look up table block execution output variable name. string out_varname = 10; - // If true, the ps server will start profiling, the ps + // If 1, the ps server will start profiling, the ps // server stops profiling and generates a profile to /tmp/profile_ps_* - // when profile switches from true to false. + // when profile switches from 1 to 2. int64 profile = 11; }