multi-thread handlerequest

Experiment on vgg flower, 2 trainers, 1ps.
    more trainer could have more speedup.

    After:
    Pass = 0, Iters = 327, Speed = (7.52) img/s
    Before:
    Pass = 0, Iters = 385, Speed = (6.77) img/s
shanyi15-patch-3
Xin Pan 7 years ago
parent ebefdbe372
commit b4dd4c048d

@ -38,7 +38,7 @@ def str2bool(v):
parser = argparse.ArgumentParser(description=__doc__) parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument( parser.add_argument(
'--batch_size', type=int, default=128, help="Batch size for training.") '--batch_size', type=int, default=16, help="Batch size for training.")
parser.add_argument( parser.add_argument(
'--learning_rate', '--learning_rate',
type=float, type=float,
@ -61,7 +61,7 @@ parser.add_argument(
parser.add_argument( parser.add_argument(
'--data_set', '--data_set',
type=str, type=str,
default='cifar10', default='flowers',
choices=['cifar10', 'flowers'], choices=['cifar10', 'flowers'],
help='Optional dataset for benchmark.') help='Optional dataset for benchmark.')
parser.add_argument( parser.add_argument(
@ -200,26 +200,30 @@ def main():
fetch_list=[avg_cost, batch_acc, batch_size]) fetch_list=[avg_cost, batch_acc, batch_size])
return loss, acc, b_size return loss, acc, b_size
if args.profile and args.task_index == 0: if args.profile:
# warmup. with profiler.profiler('All', 'total',
for batch_id, data in enumerate(train_reader()): '/tmp/profile_vgg_%d' % args.task_index):
if batch_id > 5: break
run_step(batch_id, data)
with profiler.profiler('All', 'total', '/tmp/profile_vgg'):
for batch_id, data in enumerate(train_reader()): for batch_id, data in enumerate(train_reader()):
if batch_id > 5: break if batch_id > 4: break
run_step(batch_id, data) run_step(batch_id, data)
total_time = 0.0
count = 0
for batch_id, data in enumerate(train_reader()): for batch_id, data in enumerate(train_reader()):
ts = time.time() ts = time.time()
loss, acc, b_size = run_step(batch_id, data) loss, acc, b_size = run_step(batch_id, data)
iters += 1 iters += 1
num_samples += len(data) num_samples += len(data)
train_pass_acc.add(value=acc, weight=b_size) train_pass_acc.add(value=acc, weight=b_size)
duration = time.time() - ts
total_time += duration
count += len(data)
print( print(
"Pass = %d, Iters = %d, Loss = %f, Accuracy = %f, " "Pass = %d, Iters = %d, Loss = %f, Accuracy = %f, "
"Speed = %.2f img/s" % (pass_id, iters, loss, acc, "Speed = %.2f (%.2f) img/s" % (pass_id, iters, loss, acc,
len(data) / (time.time() - ts)) len(data) / duration,
count / total_time)
) # The accuracy is the accumulation of batches, but not the current batch. ) # The accuracy is the accumulation of batches, but not the current batch.
pass_elapsed = time.time() - start_time pass_elapsed = time.time() - start_time

@ -33,7 +33,7 @@ ExternalProject_Add(
extern_grpc extern_grpc
DEPENDS protobuf zlib DEPENDS protobuf zlib
GIT_REPOSITORY "https://github.com/grpc/grpc.git" GIT_REPOSITORY "https://github.com/grpc/grpc.git"
GIT_TAG "v1.10.x" GIT_TAG "v1.8.x"
PREFIX ${GRPC_SOURCES_DIR} PREFIX ${GRPC_SOURCES_DIR}
UPDATE_COMMAND "" UPDATE_COMMAND ""
CONFIGURE_COMMAND "" CONFIGURE_COMMAND ""

@ -350,12 +350,9 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
} }
} }
} }
platform::DeviceContextPool::Instance().Get(place_)->Wait(); // platform::DeviceContextPool::Instance().Get(place_)->Wait();
if (create_vars && create_local_scope) { if (create_vars && create_local_scope) {
scope->DeleteScope(local_scope); scope->DeleteScope(local_scope);
} else {
// Delete the local scopes created in operators.
scope->DropKids();
} }
if (FLAGS_benchmark) { if (FLAGS_benchmark) {
VLOG(2) << "-------------------------------------------------------"; VLOG(2) << "-------------------------------------------------------";

@ -19,6 +19,7 @@ limitations under the License. */
#include <limits> #include <limits>
#include "paddle/fluid/framework/threadpool.h" #include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/platform/profiler.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
@ -196,9 +197,14 @@ bool RPCClient::Wait() {
const size_t kReqCnt = req_count_; const size_t kReqCnt = req_count_;
bool a[kReqCnt]; bool a[kReqCnt];
std::vector<std::future<void>> waits(req_count_); std::vector<std::future<void>> waits(req_count_);
std::mutex mu;
for (int i = 0; i < req_count_; i++) { for (int i = 0; i < req_count_; i++) {
waits[i] = framework::AsyncIO([i, &a, this] { a[i] = Proceed(); }); waits[i] = framework::AsyncIO([i, &a, &mu, this] {
bool ret = Proceed();
std::lock_guard<std::mutex> l(mu);
a[i] = ret;
});
} }
for (int i = 0; i < req_count_; i++) { for (int i = 0; i < req_count_; i++) {

File diff suppressed because it is too large Load Diff

@ -17,6 +17,7 @@ limitations under the License. */
#include <string> #include <string>
#include <thread> // NOLINT #include <thread> // NOLINT
#include <utility> #include <utility>
#include <vector>
#include "grpc++/grpc++.h" #include "grpc++/grpc++.h"
#include "paddle/fluid/framework/blocking_queue.h" #include "paddle/fluid/framework/blocking_queue.h"
@ -30,6 +31,7 @@ limitations under the License. */
#include "paddle/fluid/operators/detail/send_recv.grpc.pb.h" #include "paddle/fluid/operators/detail/send_recv.grpc.pb.h"
#include "paddle/fluid/operators/detail/send_recv.pb.h" #include "paddle/fluid/operators/detail/send_recv.pb.h"
#include "paddle/fluid/operators/detail/sendrecvop_utils.h" #include "paddle/fluid/operators/detail/sendrecvop_utils.h"
#include "paddle/fluid/platform/profiler.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
@ -82,19 +84,25 @@ class AsyncGRPCServer final {
protected: protected:
void HandleRequest(::grpc::ServerCompletionQueue *cq, void HandleRequest(::grpc::ServerCompletionQueue *cq,
const std::string &cq_name, const std::string &cq_name,
std::function<void()> TryToRegisterNewOne); std::function<void(int)> TryToRegisterNewOne);
void TryToRegisterNewSendOne(); void TryToRegisterNewSendOne(int i);
void TryToRegisterNewGetOne(); void TryToRegisterNewGetOne(int i);
void TryToRegisterNewPrefetchOne(); void TryToRegisterNewPrefetchOne(int i);
void ShutdownQueue(); void ShutdownQueue();
private: private:
static const int kSendReqsBufSize = 100;
static const int kGetReqsBufSize = 100;
std::mutex cq_mutex_; std::mutex cq_mutex_;
volatile bool is_shut_down_ = false; volatile bool is_shut_down_ = false;
std::unique_ptr<::grpc::ServerCompletionQueue> cq_send_; std::unique_ptr<::grpc::ServerCompletionQueue> cq_send_;
std::unique_ptr<::grpc::ServerCompletionQueue> cq_get_; std::unique_ptr<::grpc::ServerCompletionQueue> cq_get_;
std::unique_ptr<::grpc::ServerCompletionQueue> cq_prefetch_; std::unique_ptr<::grpc::ServerCompletionQueue> cq_prefetch_;
RequestBase *send_reqs_[kSendReqsBufSize];
RequestBase *get_reqs_[kGetReqsBufSize];
GrpcService::AsyncService service_; GrpcService::AsyncService service_;
std::unique_ptr<::grpc::Server> server_; std::unique_ptr<::grpc::Server> server_;
@ -113,8 +121,9 @@ class AsyncGRPCServer final {
mutable int barrier_cond_step_; mutable int barrier_cond_step_;
std::condition_variable barrier_condition_; std::condition_variable barrier_condition_;
std::unique_ptr<std::thread> t_send_; std::vector<std::unique_ptr<std::thread>> t_sends_;
std::unique_ptr<std::thread> t_get_; std::vector<std::unique_ptr<std::thread>> t_gets_;
std::unique_ptr<std::thread> t_prefetch_; std::unique_ptr<std::thread> t_prefetch_;
std::unique_ptr<framework::ExecutorPrepareContext> prefetch_ctx_; std::unique_ptr<framework::ExecutorPrepareContext> prefetch_ctx_;

@ -25,6 +25,8 @@
#include <grpc++/support/byte_buffer.h> #include <grpc++/support/byte_buffer.h>
#include "paddle/fluid/operators/detail/variable_response.h" #include "paddle/fluid/operators/detail/variable_response.h"
#include "paddle/fluid/platform/profiler.h"
// NOTE: This method was originally created by tensorflow // NOTE: This method was originally created by tensorflow
// (https://github.com/tensorflow/tensorflow/) we borrow this // (https://github.com/tensorflow/tensorflow/) we borrow this
// method and did some modifications so that we can parse gRPC // method and did some modifications so that we can parse gRPC

@ -73,7 +73,7 @@ message VariableMessage {
// If true, the ps server will start profiling, the ps // If true, the ps server will start profiling, the ps
// server stops profiling and generates a profile to /tmp/profile_ps_* // server stops profiling and generates a profile to /tmp/profile_ps_*
// when profile switches from true to false. // when profile switches from true to false.
bool profile = 11; int64 profile = 11;
} }
message VoidMessage {} message VoidMessage {}

@ -122,7 +122,13 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
// 1 trainer returns true for ShouldSendProfileState(). It tells PS // 1 trainer returns true for ShouldSendProfileState(). It tells PS
// servers the trainer's profiling state so that PS can follow the // servers the trainer's profiling state so that PS can follow the
// trainer. // trainer.
request.set_profile(platform::IsProfileEnabled()); if (platform::ShouldSendProfileState()) {
if (platform::IsProfileEnabled()) {
request.set_profile(1);
} else {
request.set_profile(2);
}
}
if (!out_name.empty()) { if (!out_name.empty()) {
request.set_out_varname(out_name); request.set_out_varname(out_name);
} }

@ -449,8 +449,8 @@ int VariableResponse::Parse(Source* source) {
break; break;
} }
case sendrecv::VariableMessage::kProfileFieldNumber: { case sendrecv::VariableMessage::kProfileFieldNumber: {
bool profiling; uint64_t profiling = 0;
if (!input.ReadRaw(reinterpret_cast<void*>(&profiling), 1)) { if (!input.ReadVarint64(&profiling)) {
return tag; return tag;
} }
meta_.set_profile(profiling); meta_.set_profile(profiling);
@ -458,9 +458,9 @@ int VariableResponse::Parse(Source* source) {
if (listener_id <= 0) { if (listener_id <= 0) {
break; break;
} }
if (profiling && !platform::IsProfileEnabled()) { if (profiling == 1 && !platform::IsProfileEnabled()) {
platform::EnableProfiler(platform::ProfilerState::kCPU); platform::EnableProfiler(platform::ProfilerState::kCPU);
} else if (!profiling && platform::IsProfileEnabled()) { } else if (profiling == 2 && platform::IsProfileEnabled()) {
// TODO(panyx0718): Should we allow to customize file dir. // TODO(panyx0718): Should we allow to customize file dir.
platform::DisableProfiler( platform::DisableProfiler(
platform::EventSortingKey::kDefault, platform::EventSortingKey::kDefault,

@ -245,7 +245,6 @@ class DeviceTracerImpl : public DeviceTracer {
void Enable() { void Enable() {
std::lock_guard<std::mutex> l(trace_mu_); std::lock_guard<std::mutex> l(trace_mu_);
if (enabled_) { if (enabled_) {
fprintf(stderr, "DeviceTracer already enabled\n");
return; return;
} }
EnableActivity(); EnableActivity();

Loading…
Cancel
Save