From 847e4f4e854b3f73625816d152f65ca5f5c7a27e Mon Sep 17 00:00:00 2001
From: Qiao Longfei
Date: Fri, 1 Mar 2019 11:24:14 +0800
Subject: [PATCH] Add pure async mode training

---
 .../details/async_ssa_graph_executor.cc       | 114 ++++++++++++------
 .../details/async_ssa_graph_executor.h        |  12 ++
 .../details/threaded_ssa_graph_executor.cc    |   2 +
 paddle/fluid/framework/parallel_executor.cc   |   8 +-
 paddle/fluid/framework/reader.cc              |   5 +-
 paddle/fluid/framework/reader.h               |  10 +-
 .../fluid/operators/reader/blocking_queue.h   |   3 +-
 .../fluid/operators/reader/buffered_reader.cc |   3 +
 .../operators/reader/create_py_reader_op.cc   |   7 +-
 .../reader/lod_tensor_blocking_queue.h        |   5 +-
 paddle/fluid/pybind/pybind.cc                 |   1 +
 .../test_async_ssa_graph_executor_mnist.py    |  41 ++++---
 12 files changed, 148 insertions(+), 63 deletions(-)

diff --git a/paddle/fluid/framework/details/async_ssa_graph_executor.cc b/paddle/fluid/framework/details/async_ssa_graph_executor.cc
index dfb9d73dcb..69f770afee 100644
--- a/paddle/fluid/framework/details/async_ssa_graph_executor.cc
+++ b/paddle/fluid/framework/details/async_ssa_graph_executor.cc
@@ -14,10 +14,31 @@
 
 #include "paddle/fluid/framework/details/async_ssa_graph_executor.h"
 
+#include "paddle/fluid/framework/variable_helper.h"
+
 namespace paddle {
 namespace framework {
 namespace details {
 
+inline void NewTempScopeAndInitVars(const std::vector<VarInfo> &var_infos,
+                                    Scope *scope) {
+  Scope &local_scope = scope->NewScope();
+  *scope->Var(details::kLocalExecScopeName)->GetMutable<Scope *>() =
+      &local_scope;
+
+  for (auto &info : var_infos) {
+    if (scope->FindVar(info.name_) != nullptr) {
+      continue;
+    }
+
+    if (info.persistable_) {  // Persistable
+      InitializeVariable(scope->Var(info.name_), info.type_);
+    } else {
+      InitializeVariable(local_scope.Var(info.name_), info.type_);
+    }
+  }
+}
+
 AsyncSSAGraphExecutor::AsyncSSAGraphExecutor(
     const ExecutionStrategy &strategy, const std::vector<Scope *> &local_scopes,
     const std::vector<platform::Place> &places, std::vector<ir::Graph *> graphs)
@@ -39,58 +60,81 @@ AsyncSSAGraphExecutor::AsyncSSAGraphExecutor(
     executors_.emplace_back(new details::ThreadedSSAGraphExecutor(
         strategy_, {local_scopes_[i]}, {places_[i]}, graphs_[i]));
   }
-}
 
-FeedFetchList AsyncSSAGraphExecutor::Run(
-    const std::vector<std::string> &fetch_tensors) {
-  std::vector<std::future<FeedFetchList>> run_futures;
-
-  std::vector<FeedFetchList> fetch_data;
-  FeedFetchList ret;
-
-  fetch_data.reserve(places_.size());
-  ret.reserve(fetch_tensors.size());
-  exception_holder_.Clear();
+  for (auto &node : graphs_[0]->Nodes()) {
+    if (node->IsVar() && !node->IsCtrlVar() && node->Var()) {
+      var_infos_.emplace_back();
+      var_infos_.back().name_ = node->Var()->Name();
+      var_infos_.back().type_ = node->Var()->GetType();
+      var_infos_.back().persistable_ = node->Var()->Persistable();
+    }
+  }
+  for (auto *scope : local_scopes_) {
+    NewTempScopeAndInitVars(var_infos_, scope);
+  }
+}
 
-  for (size_t i = 0; i < places_.size(); ++i) {
-    auto call = [this, i, &fetch_tensors]() -> FeedFetchList {
+void AsyncSSAGraphExecutor::StartOffPythonTrainLoop() {
+  VLOG(3) << "StartOffPythonTrainLoop size = " << places_.size();
+  for (size_t i = 1; i < places_.size(); ++i) {
+    auto call = [this, i]() -> void {
+      VLOG(3) << "start off python thread " << i;
       try {
-        return executors_[i]->Run(fetch_tensors);
+        while (true) {
+          executors_[i]->Run({});
+        }
       } catch (...) {
         exception_holder_.Catch(std::current_exception());
+        VLOG(3) << "get exception type = " << exception_holder_.Type();
       }
-      return FeedFetchList();
+      VLOG(3) << "thread " << i << " exited!";
     };
-
-    if (pool_) {
-      run_futures.emplace_back(pool_->enqueue(std::move(call)));
-    } else {
-      fetch_data.emplace_back(std::move(call()));
-    }
-  }
-
-  if (pool_) {
-    for (auto &f : run_futures) {
-      if (exception_holder_.IsCaught()) {
-        f.wait();
-      } else {
-        fetch_data.emplace_back(std::move(f.get()));
-      }
-    }
+    run_futures_.emplace_back(pool_->enqueue(std::move(call)));
   }
+}
 
+void AsyncSSAGraphExecutor::HandleException() {
   if (exception_holder_.IsCaught()) {
+    for (auto &f : run_futures_) {
+      VLOG(3) << "wait future";
+      f.wait();
+    }
     VLOG(3) << "caught exception " << exception_holder_.Type()
             << ", rethrow it";
+    run_futures_.clear();
     exception_holder_.ReThrow();
   }
+}
+
+FeedFetchList AsyncSSAGraphExecutor::Run(
+    const std::vector<std::string> &fetch_tensors) {
+  // init once
+  if (run_futures_.size() == 0 && places_.size() > 1) {
+    exception_holder_.Clear();
+    StartOffPythonTrainLoop();
+  }
+
+  if (places_.size() == 1) {
+    exception_holder_.Clear();
+  } else {
+    HandleException();
+  }
+
+  FeedFetchList fetch_data;
+  fetch_data.reserve(fetch_tensors.size());
+
+  try {
+    fetch_data = executors_[0]->Run(fetch_tensors);
+  } catch (...) {
+    exception_holder_.Catch(std::current_exception());
+  }
+
+  HandleException();
 
+  FeedFetchList ret;
   for (size_t fetch_idx = 0; fetch_idx < fetch_tensors.size(); ++fetch_idx) {
     std::vector<const LoDTensor *> lodtensor_ptrs;
-    lodtensor_ptrs.reserve(local_scopes_.size());
-    for (size_t scope_idx = 0; scope_idx < local_scopes_.size(); ++scope_idx) {
-      lodtensor_ptrs.push_back(&fetch_data.at(scope_idx).at(fetch_idx));
-    }
+    lodtensor_ptrs.push_back(&fetch_data.at(fetch_idx));
     ret.emplace_back();
     ret.back().MergeLoDTensor(lodtensor_ptrs, platform::CPUPlace());
   }
diff --git a/paddle/fluid/framework/details/async_ssa_graph_executor.h b/paddle/fluid/framework/details/async_ssa_graph_executor.h
index ff85ba2c6c..7d7296772d 100644
--- a/paddle/fluid/framework/details/async_ssa_graph_executor.h
+++ b/paddle/fluid/framework/details/async_ssa_graph_executor.h
@@ -24,6 +24,12 @@ namespace paddle {
 namespace framework {
 namespace details {
 
+struct VarInfo {
+  std::string name_;
+  proto::VarType::Type type_;
+  bool persistable_;
+};
+
 class AsyncSSAGraphExecutor : public SSAGraphExecutor {
  public:
   AsyncSSAGraphExecutor(const ExecutionStrategy &strategy,
@@ -35,6 +41,10 @@ class AsyncSSAGraphExecutor : public SSAGraphExecutor {
 
   FeedFetchList Run(const std::vector<std::string> &fetch_tensors) override;
 
+ private:
+  void StartOffPythonTrainLoop();
+  void HandleException();
+
  private:
   ExecutionStrategy strategy_;
   std::vector<Scope *> local_scopes_;
@@ -44,6 +54,8 @@ class AsyncSSAGraphExecutor : public SSAGraphExecutor {
   std::vector<std::unique_ptr<details::ThreadedSSAGraphExecutor>> executors_;
 
   ExceptionHolder exception_holder_;
+  std::vector<std::future<void>> run_futures_;
+  std::vector<VarInfo> var_infos_;
 };
 
 }  // namespace details
diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc
index 8436626362..fa0c90e1f4 100644
--- a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc
+++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc
@@ -119,6 +119,8 @@ inline FeedFetchList ThreadedSSAGraphExecutor::RunImpl(
 
     if (timeout) {
       if (exception_holder_.IsCaught()) {
+        VLOG(3) << "caught exception " << exception_holder_.Type()
+                << ", rethrow it";
         for (auto &run_op_future : run_op_futures_) {
           run_op_future.wait();
         }
diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc
index b1f4091148..c133772e6e 100644
--- a/paddle/fluid/framework/parallel_executor.cc
+++ b/paddle/fluid/framework/parallel_executor.cc
@@ -379,9 +379,11 @@ ParallelExecutor::ParallelExecutor(
   }
 
   VLOG(3) << "use ScopeBufferedSSAGraphExecutor";
-  member_->executor_.reset(new details::ScopeBufferedSSAGraphExecutor(
-      exec_strategy, member_->local_scopes_, std::move(var_infos),
-      member_->places_, std::move(member_->executor_)));
+  if (!build_strategy.async_mode_) {
+    member_->executor_.reset(new details::ScopeBufferedSSAGraphExecutor(
+        exec_strategy, member_->local_scopes_, std::move(var_infos),
+        member_->places_, std::move(member_->executor_)));
+  }
 }
 
 void ParallelExecutor::BCastParamsToDevices(
diff --git a/paddle/fluid/framework/reader.cc b/paddle/fluid/framework/reader.cc
index 40eafda9bf..d3513fb7db 100644
--- a/paddle/fluid/framework/reader.cc
+++ b/paddle/fluid/framework/reader.cc
@@ -69,6 +69,9 @@ void ReaderBase::Start() {
 
 ReaderBase::~ReaderBase() {}
 
-DecoratedReader::~DecoratedReader() { reader_->Shutdown(); }
+DecoratedReader::~DecoratedReader() {
+  VLOG(1) << "~DecoratedReader";
+  reader_->Shutdown();
+}
 
 }  // namespace framework
 }  // namespace paddle
diff --git a/paddle/fluid/framework/reader.h b/paddle/fluid/framework/reader.h
index 82562bf883..6cf0ec2937 100644
--- a/paddle/fluid/framework/reader.h
+++ b/paddle/fluid/framework/reader.h
@@ -77,7 +77,10 @@ class DecoratedReader : public ReaderBase,
   ~DecoratedReader();
 
  protected:
-  void ShutdownImpl() override { reader_->Shutdown(); }
+  void ShutdownImpl() override {
+    VLOG(1) << "ShutdownImpl";
+    reader_->Shutdown();
+  }
 
   void StartImpl() override { reader_->Start(); }
 
@@ -98,6 +101,8 @@ class ReaderHolder {
     reader_ = reader_base;
   }
 
+  ~ReaderHolder() { VLOG(1) << "~ReaderHolder"; }
+
   const std::shared_ptr<ReaderBase>& Get() const { return reader_; }
 
   void ReadNext(std::vector<LoDTensor>* out) {
@@ -106,6 +111,7 @@ class ReaderHolder {
   }
 
   void ResetAll() {
+    VLOG(1) << "ResetAll";
     auto end_readers = reader_->GetEndPoints();
     for (auto* reader : end_readers) {
       reader->Shutdown();
@@ -116,11 +122,13 @@ class ReaderHolder {
   }
 
   void Shutdown() {
+    VLOG(1) << "Shutdown";
     PADDLE_ENFORCE_NOT_NULL(reader_);
     reader_->Shutdown();
   }
 
   void Start() {
+    VLOG(1) << "start";
     PADDLE_ENFORCE_NOT_NULL(reader_);
     reader_->Start();
   }
diff --git a/paddle/fluid/operators/reader/blocking_queue.h b/paddle/fluid/operators/reader/blocking_queue.h
index c99b2bc593..fe3f2f4031 100644
--- a/paddle/fluid/operators/reader/blocking_queue.h
+++ b/paddle/fluid/operators/reader/blocking_queue.h
@@ -86,6 +86,7 @@ class BlockingQueue {
 
   void ReOpen() {
     std::lock_guard<std::mutex> lock(mutex_);
+    VLOG(1) << "reopen queue";
     closed_ = false;
     std::deque<T> new_deque;
     queue_.swap(new_deque);
@@ -95,7 +96,7 @@ class BlockingQueue {
 
   void Close() {
     std::lock_guard<std::mutex> lock(mutex_);
-    VLOG(3) << "close queue";
+    VLOG(1) << "close queue";
     closed_ = true;
     send_cv_.notify_all();
     receive_cv_.notify_all();
diff --git a/paddle/fluid/operators/reader/buffered_reader.cc b/paddle/fluid/operators/reader/buffered_reader.cc
index defc29b91f..db80fda695 100644
--- a/paddle/fluid/operators/reader/buffered_reader.cc
+++ b/paddle/fluid/operators/reader/buffered_reader.cc
@@ -20,6 +20,7 @@ namespace paddle {
 namespace operators {
 namespace reader {
 BufferedReader::~BufferedReader() {
+  VLOG(1) << "~BufferedReader";
   reader_->Shutdown();
   while (!position_.empty()) {
     position_.front().wait();
@@ -41,6 +42,7 @@ BufferedReader::BufferedReader(
       thread_pool_(1),
       place_(place),
       buffer_size_(buffer_size) {
+  VLOG(1) << "BufferedReader";
 #ifdef PADDLE_WITH_CUDA
   if (platform::is_gpu_place(place_)) {
     platform::SetDeviceId(boost::get<platform::CUDAPlace>(place_).device);
@@ -121,6 +123,7 @@ void BufferedReader::ReadAsync(size_t i) {
 }
 
 void BufferedReader::ShutdownImpl() {
+  VLOG(1) << "ShutdownImpl";
   reader_->Shutdown();
   while (!position_.empty()) {
     position_.pop();
diff --git a/paddle/fluid/operators/reader/create_py_reader_op.cc b/paddle/fluid/operators/reader/create_py_reader_op.cc
index b2469ad0eb..2916be618c 100644
--- a/paddle/fluid/operators/reader/create_py_reader_op.cc
+++ b/paddle/fluid/operators/reader/create_py_reader_op.cc
@@ -33,10 +33,13 @@ class PyReader : public framework::FileReader {
     if (!success) out->clear();
   }
 
-  ~PyReader() { queue_->Close(); }
+  ~PyReader() {
+    VLOG(1) << "~PyReader";
+    queue_->Close();
+  }
 
   void Shutdown() override {
-    VLOG(3) << "PyReader shutdown!";
+    VLOG(1) << "PyReader shutdown!";
     queue_->Close();
   }
 
diff --git a/paddle/fluid/operators/reader/lod_tensor_blocking_queue.h b/paddle/fluid/operators/reader/lod_tensor_blocking_queue.h
index 5b53edff5d..eeba330d66 100644
--- a/paddle/fluid/operators/reader/lod_tensor_blocking_queue.h
+++ b/paddle/fluid/operators/reader/lod_tensor_blocking_queue.h
@@ -57,7 +57,10 @@ class LoDTensorBlockingQueue {
 
   inline void ReOpen() { queue_.ReOpen(); }
 
-  inline void Close() { queue_.Close(); }
+  inline void Close() {
+    VLOG(1) << "LoDTensorBlockingQueue close";
+    queue_.Close();
+  }
 
   inline bool IsClosed() const { return queue_.IsClosed(); }
 
diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc
index fdee5a6d66..af049127aa 100644
--- a/paddle/fluid/pybind/pybind.cc
+++ b/paddle/fluid/pybind/pybind.cc
@@ -557,6 +557,7 @@ All parameter, weight, gradient are variables in Paddle.
m.def("init_lod_tensor_blocking_queue", [](Variable &var, size_t capacity) -> std::shared_ptr { + VLOG(1) << "init_lod_tensor_blocking_queue"; auto *holder = var.GetMutable(); holder->InitOnce(capacity, FLAGS_reader_queue_speed_test_mode); return holder->GetQueue(); diff --git a/python/paddle/fluid/tests/unittests/test_async_ssa_graph_executor_mnist.py b/python/paddle/fluid/tests/unittests/test_async_ssa_graph_executor_mnist.py index 41fa39e06b..4fbda407f1 100644 --- a/python/paddle/fluid/tests/unittests/test_async_ssa_graph_executor_mnist.py +++ b/python/paddle/fluid/tests/unittests/test_async_ssa_graph_executor_mnist.py @@ -36,7 +36,7 @@ def convolutional_neural_network(use_py_reader): capacity=64, feed_list=[img, label], name='py_reader', - use_double_buffer=True) + use_double_buffer=False) img, label = fluid.layers.read_file(py_reader) conv_pool_1 = fluid.nets.simple_img_conv_pool( @@ -139,20 +139,21 @@ def train(use_cuda, thread_num, cpu_num): exec_strategy=exec_strategy) py_reader.decorate_paddle_reader(train_reader) - py_reader.start() - - step = 0 - try: - while True: - loss_val = pe.run(fetch_list=[avg_loss.name]) - loss_val = numpy.mean(loss_val) - if step % 100 == 0: - print("Batch %d, Cost %f, queue size %d" % - (step, loss_val, py_reader.queue.size())) - step += 1 - except fluid.core.EOFException: - print("train end") - py_reader.reset() + + for pass_id in range(2): + step = 0 + py_reader.start() + try: + while True: + loss_val = pe.run(fetch_list=[avg_loss.name]) + loss_val = numpy.mean(loss_val) + if step % 10 == 0: + print("Pass %d, Batch %d, Cost %f, queue size %d" % + (pass_id, step, loss_val, py_reader.queue.size())) + step += 1 + except fluid.core.EOFException: + print("train end pass = " + str(pass_id)) + py_reader.reset() return step @@ -161,10 +162,11 @@ class TestAsyncSSAGraphExecutor(unittest.TestCase): def test_check_async_ssa_exe_train(self): step_list = [] for cpu_num in [1, 2, 4]: - scope = fluid.core.Scope() - with fluid.scope_guard(scope): + print("run cpu_num -> " + str(cpu_num)) + with fluid.scope_guard(fluid.core.Scope()): with fluid.program_guard( - fluid.Program(), startup_program=fluid.Program()): + main_program=fluid.Program(), + startup_program=fluid.Program()): start_time = time.time() step = train( use_cuda=False, thread_num=cpu_num, cpu_num=cpu_num) @@ -173,7 +175,8 @@ class TestAsyncSSAGraphExecutor(unittest.TestCase): print("cpu_num -> " + str(cpu_num) + " step -> " + str(step) + " time -> " + str(end_time - start_time)) with fluid.program_guard( - fluid.Program(), startup_program=fluid.Program()): + main_program=fluid.Program(), + startup_program=fluid.Program()): test() assert int(step_list[0] / 2) == int(step_list[1]) assert int(step_list[1] / 2) == int(step_list[2])