From ca9d21f51b6623df848f82b901bf0237c3a92687 Mon Sep 17 00:00:00 2001 From: Dun Liang Date: Wed, 8 Aug 2018 11:01:11 +0800 Subject: [PATCH 01/11] Fix #12578: Wrong error message when run out of GPU memory --- paddle/fluid/memory/detail/buddy_allocator.cc | 2 ++ paddle/fluid/memory/detail/buddy_allocator.h | 2 ++ paddle/fluid/memory/malloc.cc | 4 ++-- 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/memory/detail/buddy_allocator.cc b/paddle/fluid/memory/detail/buddy_allocator.cc index 01a8501dd4..3c961e5040 100644 --- a/paddle/fluid/memory/detail/buddy_allocator.cc +++ b/paddle/fluid/memory/detail/buddy_allocator.cc @@ -162,6 +162,8 @@ void BuddyAllocator::Free(void* p) { } size_t BuddyAllocator::Used() { return total_used_; } +size_t BuddyAllocator::GetMinChunkSize() {return min_chunk_size_;}; +size_t BuddyAllocator::GetMaxChunkSize() {return max_chunk_size_;}; void* BuddyAllocator::SystemAlloc(size_t size) { size_t index = 0; diff --git a/paddle/fluid/memory/detail/buddy_allocator.h b/paddle/fluid/memory/detail/buddy_allocator.h index f0c83efc23..3f86a51f0d 100644 --- a/paddle/fluid/memory/detail/buddy_allocator.h +++ b/paddle/fluid/memory/detail/buddy_allocator.h @@ -42,6 +42,8 @@ class BuddyAllocator { void* Alloc(size_t unaligned_size); void Free(void* ptr); size_t Used(); + size_t GetMinChunkSize(); + size_t GetMaxChunkSize(); public: // Disable copy and assignment diff --git a/paddle/fluid/memory/malloc.cc b/paddle/fluid/memory/malloc.cc index 7c800b3c16..283745e977 100644 --- a/paddle/fluid/memory/malloc.cc +++ b/paddle/fluid/memory/malloc.cc @@ -119,8 +119,8 @@ void* Alloc(platform::CUDAPlace place, size_t size) { LOG(WARNING) << "Cannot allocate " << size << " bytes in GPU " << place.device << ", available " << avail << " bytes"; LOG(WARNING) << "total " << total; - LOG(WARNING) << "GpuMinChunkSize " << platform::GpuMinChunkSize(); - LOG(WARNING) << "GpuMaxChunkSize " << platform::GpuMaxChunkSize(); + LOG(WARNING) << "GpuMinChunkSize " << buddy_allocator->GetMinChunkSize(); + LOG(WARNING) << "GpuMaxChunkSize " << buddy_allocator->GetMaxChunkSize(); LOG(WARNING) << "GPU memory used: " << Used(place); platform::SetDeviceId(cur_dev); } From 755edc2c4bb928c8381ec7e14d3d13244bdc4922 Mon Sep 17 00:00:00 2001 From: minqiyang Date: Mon, 27 Aug 2018 16:03:48 +0800 Subject: [PATCH 02/11] Accelerate python35 ci job --- paddle/scripts/paddle_build.sh | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/paddle/scripts/paddle_build.sh b/paddle/scripts/paddle_build.sh index 8460f93b84..8073046700 100755 --- a/paddle/scripts/paddle_build.sh +++ b/paddle/scripts/paddle_build.sh @@ -625,6 +625,12 @@ function main() { test_fluid_inference_lib assert_api_not_changed ;; + cicheck_py35) + cmake_gen ${PYTHON_ABI:-""} + build + run_test + assert_api_not_changed + ;; *) print_usage exit 0 From cc18fffb9000d4b5b9352568f341844c72d14fe1 Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Tue, 11 Sep 2018 12:05:25 +0800 Subject: [PATCH 03/11] add nest while_op --- paddle/fluid/operators/while_op.cc | 5 ++-- .../fluid/tests/unittests/test_while_op.py | 25 ++++++++++++++++--- 2 files changed, 25 insertions(+), 5 deletions(-) diff --git a/paddle/fluid/operators/while_op.cc b/paddle/fluid/operators/while_op.cc index 65a3bc928e..791138a8c0 100644 --- a/paddle/fluid/operators/while_op.cc +++ b/paddle/fluid/operators/while_op.cc @@ -63,7 +63,7 @@ class WhileOp : public framework::OperatorBase { while (cond.data()[0]) { auto ¤t_scope = scope.NewScope(); step_scopes->push_back(¤t_scope); - executor.RunPreparedContext(ctx.get(), ¤t_scope, false); + executor.RunPreparedContext(ctx.get(), ¤t_scope, false, true, true); if (is_test) { scope.DeleteScope(¤t_scope); } @@ -169,7 +169,8 @@ class WhileGradOp : public framework::OperatorBase { } } } - executor.RunPreparedContext(ctx.get(), *cur_scope_iter, false); + executor.RunPreparedContext(ctx.get(), *cur_scope_iter, false, true, + true); auto &pg_names = Outputs(kXGRAD); auto &p_names = Inputs(kX); diff --git a/python/paddle/fluid/tests/unittests/test_while_op.py b/python/paddle/fluid/tests/unittests/test_while_op.py index b75373cf24..43fd9d425b 100644 --- a/python/paddle/fluid/tests/unittests/test_while_op.py +++ b/python/paddle/fluid/tests/unittests/test_while_op.py @@ -1,4 +1,4 @@ -# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -30,8 +30,10 @@ class TestWhileOp(unittest.TestCase): "d1", shape=[10], append_batch_size=False, dtype='float32') d2 = layers.data( "d2", shape=[10], append_batch_size=False, dtype='float32') + i = layers.zeros(shape=[1], dtype='int64') i.stop_gradient = True + init = layers.zeros(shape=[10], dtype='float32') mem_array = layers.array_write(x=init, i=i) data_array = layers.array_write(x=d0, i=i) @@ -45,11 +47,19 @@ class TestWhileOp(unittest.TestCase): i = layers.zeros(shape=[1], dtype='int64') i.stop_gradient = True - array_len = layers.fill_constant(shape=[1], dtype='int64', value=3) + array_len = layers.fill_constant(shape=[1], dtype='int64', value=1) array_len.stop_gradient = True cond = layers.less_than(x=i, y=array_len) + j = layers.fill_constant(shape=[1], dtype='int64', value=1) + j.stop_gradient = True + + array_len2 = layers.fill_constant(shape=[1], dtype='int64', value=3) + array_len2.stop_gradient = True + cond2 = layers.less_than(x=j, y=array_len2) + while_op = layers.While(cond=cond) + while_op2 = layers.While(cond=cond2) with while_op.block(): d = layers.array_read(array=data_array, i=i) prev = layers.array_read(array=mem_array, i=i) @@ -59,7 +69,16 @@ class TestWhileOp(unittest.TestCase): layers.array_write(result, i=i, array=mem_array) layers.less_than(x=i, y=array_len, cond=cond) - sum_result = layers.array_read(array=mem_array, i=i) + with while_op2.block(): + d2 = layers.array_read(array=data_array, i=j) + prev2 = layers.array_read(array=mem_array, i=j) + result2 = layers.sums(input=[d2, prev2]) + + j = layers.increment(x=j, in_place=True) + layers.array_write(result2, i=j, array=mem_array) + layers.less_than(x=j, y=array_len2, cond=cond2) + + sum_result = layers.array_read(array=mem_array, i=j) loss = layers.mean(sum_result) append_backward(loss) From 24ea39c4c653fe8ebadc0b520ca009446662f872 Mon Sep 17 00:00:00 2001 From: sneaxiy Date: Sat, 15 Sep 2018 14:50:30 +0000 Subject: [PATCH 04/11] feature/eager_delete_tensor --- paddle/fluid/framework/details/CMakeLists.txt | 15 +- .../framework/details/computation_op_handle.h | 6 + .../fluid/framework/details/op_handle_base.h | 7 + .../details/reference_count_op_handle.h | 123 ++++++++++++ .../framework/details/reference_count_pass.cc | 152 +++++++++++++++ .../framework/details/reference_count_pass.h | 37 ++++ .../scope_buffered_ssa_graph_executor.cc | 20 ++ paddle/fluid/framework/executor.cc | 78 +++++++- paddle/fluid/framework/executor.h | 45 +++++ paddle/fluid/framework/garbage_collector.h | 163 ++++++++++++++++ paddle/fluid/framework/ir/graph.h | 183 ++++++++++++++++++ paddle/fluid/framework/parallel_executor.cc | 53 ++++- paddle/fluid/framework/parallel_executor.h | 20 +- paddle/fluid/framework/scope.cc | 12 ++ paddle/fluid/framework/scope.h | 2 + paddle/fluid/framework/tensor.h | 2 + paddle/fluid/platform/CMakeLists.txt | 4 +- paddle/fluid/platform/device_context.cc | 3 + paddle/fluid/platform/device_context.h | 23 ++- .../fluid/platform/stream_callback_manager.h | 82 ++++++++ python/paddle/fluid/__init__.py | 16 +- 21 files changed, 1023 insertions(+), 23 deletions(-) create mode 100644 paddle/fluid/framework/details/reference_count_op_handle.h create mode 100644 paddle/fluid/framework/details/reference_count_pass.cc create mode 100644 paddle/fluid/framework/details/reference_count_pass.h create mode 100644 paddle/fluid/framework/garbage_collector.h create mode 100644 paddle/fluid/framework/ir/graph.h create mode 100644 paddle/fluid/platform/stream_callback_manager.h diff --git a/paddle/fluid/framework/details/CMakeLists.txt b/paddle/fluid/framework/details/CMakeLists.txt index 4fb4ec38ee..8404bf4a3e 100644 --- a/paddle/fluid/framework/details/CMakeLists.txt +++ b/paddle/fluid/framework/details/CMakeLists.txt @@ -29,13 +29,20 @@ cc_library(data_balance_op_handle SRCS data_balance_op_handle.cc DEPS op_handle_ cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor) cc_library(fuse_vars_op_handle SRCS fuse_vars_op_handle.cc DEPS op_handle_base scope) -cc_library(multi_devices_graph_builder SRCS multi_devices_graph_builder.cc DEPS ssa_graph_builder computation_op_handle - scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle) +if(WITH_GPU) + cc_library(reference_count_pass SRCS reference_count_pass.cc DEPS computation_op_handle scale_loss_grad_op_handle rpc_op_handle + all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle graph graph_helper pass) +endif() +cc_library(multi_devices_graph_pass SRCS multi_devices_graph_pass.cc DEPS multi_devices_helper computation_op_handle + scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle) -cc_library(ssa_graph_builder_factory SRCS ssa_graph_builder_factory.cc DEPS multi_devices_graph_builder ssa_graph_printer ssa_graph_checker) +if(WITH_GPU) + cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS graph framework_proto reference_count_pass) +else() + cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS graph framework_proto) +endif() -cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ssa_graph framework_proto) cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope simple_threadpool device_context) diff --git a/paddle/fluid/framework/details/computation_op_handle.h b/paddle/fluid/framework/details/computation_op_handle.h index f048f973fd..401ebb7953 100644 --- a/paddle/fluid/framework/details/computation_op_handle.h +++ b/paddle/fluid/framework/details/computation_op_handle.h @@ -23,6 +23,8 @@ #include "paddle/fluid/framework/scope.h" #include "paddle/fluid/platform/device_context.h" +#include "paddle/fluid/framework/details/reference_count_op_handle.h" + namespace paddle { namespace framework { namespace details { @@ -33,6 +35,10 @@ struct ComputationOpHandle : public OpHandleBase { std::string Name() const override; + const Scope *GetScope() const { return scope_; } + + const platform::Place &GetPlace() const { return place_; } + protected: void RunImpl() override; diff --git a/paddle/fluid/framework/details/op_handle_base.h b/paddle/fluid/framework/details/op_handle_base.h index 6aec178831..3de22a0235 100644 --- a/paddle/fluid/framework/details/op_handle_base.h +++ b/paddle/fluid/framework/details/op_handle_base.h @@ -82,6 +82,13 @@ class OpHandleBase { size_t NoDummyInputSize() const; + ir::Node *Node() { return node_; } + + const std::map + &GetDeviceContexts() const { + return dev_ctxes_; + } + protected: void RunAndRecordEvent(const std::function &callback); diff --git a/paddle/fluid/framework/details/reference_count_op_handle.h b/paddle/fluid/framework/details/reference_count_op_handle.h new file mode 100644 index 0000000000..b76fc646c2 --- /dev/null +++ b/paddle/fluid/framework/details/reference_count_op_handle.h @@ -0,0 +1,123 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include + +#include "paddle/fluid/framework/details/op_handle_base.h" +#include "paddle/fluid/framework/garbage_collector.h" +#include "paddle/fluid/framework/scope.h" +#include "paddle/fluid/framework/tensor.h" + +namespace paddle { +namespace framework { +namespace details { + +using ReferenceCountMap = std::unordered_map; +using AtomicReferenceCountMap = + std::unordered_map>; +using DeviceReferenceCountMap = + std::unordered_map>; +using AtomicDeviceReferenceCountMap = + std::unordered_map>; +using DeviceGarbageCollectorMap = + std::unordered_map>>; + +class ReferenceCountOpHandle : public OpHandleBase { + public: + ReferenceCountOpHandle(ir::Node *node, const Scope *scope, + const platform::CUDAPlace &place, + const std::vector &var_names, + GarbageCollector *gc, + AtomicReferenceCountMap *ref_cnts) + : OpHandleBase(node), + scope_(scope), + var_names_(var_names), + gc_(gc), + ref_cnts_(ref_cnts) { + dev_ctx_ = static_cast( + platform::DeviceContextPool::Instance().Get(place)); + if (IsStreamGarabageCollector()) { + PADDLE_ENFORCE(cudaSetDevice(place.device)); + PADDLE_ENFORCE(cudaEventCreateWithFlags(&event_, cudaEventDisableTiming)); + } + } + + ~ReferenceCountOpHandle() { + if (IsStreamGarabageCollector()) { + auto gpu_place = boost::get(dev_ctx_->GetPlace()); + PADDLE_ENFORCE(cudaSetDevice(gpu_place.device)); + PADDLE_ENFORCE(cudaEventDestroy(event_)); + } + } + + std::string Name() const override { return "reference_count"; } + + // protected: + void RunImpl() override { + auto *exec_scope_ = scope_->FindVar(kLocalExecScopeName)->Get(); + std::vector tensors; + for (auto &name : var_names_) { + auto it = ref_cnts_->find(name); + if (it == ref_cnts_->end()) continue; + + auto *var = exec_scope_->FindVar(name); + if (var == nullptr || !var->IsType()) continue; + + if (it->second.fetch_sub(1) <= 1) { + tensors.emplace_back(var->GetMutable()); + } + } + + if (!tensors.empty()) { + ClearTensors(tensors); + } + } + + private: + void ClearTensors(const std::vector &tensors) const { + auto *gc = dynamic_cast *>(gc_); + if (gc != nullptr) { + auto compute_stream = dev_ctx_->stream(); + auto callback_stream = gc->stream(); + auto callback_func = [=]() { + PADDLE_ENFORCE(cudaEventRecord(event_, compute_stream)); + PADDLE_ENFORCE(cudaStreamWaitEvent(callback_stream, event_, 0)); + }; + gc_->Add(tensors, callback_func); + } else { + gc_->Add(tensors); + } + } + + bool IsStreamGarabageCollector() const { + return dynamic_cast *>(gc_) != nullptr; + } + + const Scope *scope_; + platform::CUDADeviceContext *dev_ctx_; + std::vector var_names_; + GarbageCollector *gc_; // not own + AtomicReferenceCountMap *ref_cnts_; // not own + cudaEvent_t event_; +}; + +} // namespace details +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/details/reference_count_pass.cc b/paddle/fluid/framework/details/reference_count_pass.cc new file mode 100644 index 0000000000..892e6ea48a --- /dev/null +++ b/paddle/fluid/framework/details/reference_count_pass.cc @@ -0,0 +1,152 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "paddle/fluid/framework/details/computation_op_handle.h" +#include "paddle/fluid/framework/details/multi_devices_helper.h" +#include "paddle/fluid/framework/details/reference_count_pass.h" + +namespace paddle { +namespace framework { +namespace details { + +std::unique_ptr ReferenceCountPass::ApplyImpl( + std::unique_ptr graph) const { + auto &ref_cnts = Get(kGlobalReferenceCount); + auto &cur_ref_cnts = Get(kCurReferenceCount); + auto &gcs = Get(kGarbageCollector); + + // It is not easy to find the right reference counts of varaibles in graph + // Step 1: Find all variables in computation ops + // Step 2: Find all variables in non-computation ops which refers to variables + // in computation ops + std::unordered_set names; + auto get_ref_cnts_from_compute_op = [&]( + const std::unique_ptr &op, + const std::vector &vars) { + std::vector var_names_in_op; + auto *compute_op = dynamic_cast(op.get()); + if (compute_op == nullptr || + !platform::is_gpu_place(compute_op->GetPlace())) + return var_names_in_op; + auto place = boost::get(compute_op->GetPlace()); + for (VarHandleBase *var_handle_base : vars) { + auto *var_handle = dynamic_cast(var_handle_base); + if (var_handle == nullptr || !var_handle->Node()->IsVar()) continue; + + if (!platform::is_gpu_place(var_handle->place_) || + boost::get(var_handle->place_) != place) + continue; + + VarDesc *var_desc = var_handle->Node()->Var(); + auto var_name = var_handle->Node()->Name(); + + // This is wierd but there is really some variables without var_desc + // in computation_op + if (var_desc == nullptr) { + if (compute_op->Node()->Op()->Block()->FindVar(var_name) == nullptr) + continue; + } else { + if (var_desc->Persistable() || + var_desc->Proto()->type().type() != proto::VarType::LOD_TENSOR) + continue; + } + + // compute op only runs in one device + if (ref_cnts[place.device]->count(var_name)) + ++(*ref_cnts[place.device])[var_name]; + else + (*ref_cnts[place.device])[var_name] = 1; + + names.insert(var_name); + var_names_in_op.push_back(var_name); + } + return var_names_in_op; + }; + + auto update_ref_cnts_from_non_compute_op = [&]( + const std::unique_ptr &op, + const std::vector &vars) { + if (dynamic_cast(op.get()) != nullptr) return; + for (VarHandleBase *var_handle_base : vars) { + auto *var_handle = dynamic_cast(var_handle_base); + if (var_handle == nullptr || !var_handle->Node()->IsVar()) continue; + + auto var_name = var_handle->Node()->Name(); + auto var_place = var_handle->place_; + if (!platform::is_gpu_place(var_place)) continue; + auto place = boost::get(var_place); + if (names.count(var_name) == 0) continue; + if (ref_cnts.count(place.device) && + ref_cnts[place.device]->count(var_name)) { + ++(*ref_cnts[place.device])[var_name]; + } + } + }; + + std::unordered_map + compute_ref_cnt_map; + auto &all_ops = graph->Get(kGraphOps); + for (auto &op : all_ops) { + auto in_var_names = get_ref_cnts_from_compute_op(op, op->Inputs()); + auto out_var_names = get_ref_cnts_from_compute_op(op, op->Outputs()); + if (in_var_names.empty() && out_var_names.empty()) continue; + in_var_names.insert(in_var_names.end(), out_var_names.begin(), + out_var_names.end()); + auto *compute_op = dynamic_cast(op.get()); + auto place = boost::get(compute_op->GetPlace()); + ir::Node *ref_cnt_node = + graph->CreateEmptyNode("reference_count", ir::Node::Type::kOperation); + auto *ref_cnt_handle = new ReferenceCountOpHandle( + ref_cnt_node, compute_op->GetScope(), place, in_var_names, + gcs[place.device].get(), cur_ref_cnts[place.device].get()); + auto *dep_var = new DummyVarHandle(graph->CreateControlDepVar()); + compute_op->AddOutput(dep_var); + ref_cnt_handle->AddInput(dep_var); + graph->Get(kGraphDepVars).emplace(dep_var); + compute_ref_cnt_map[compute_op] = ref_cnt_handle; + } + + for (auto &op : all_ops) { + update_ref_cnts_from_non_compute_op(op, op->Inputs()); + update_ref_cnts_from_non_compute_op(op, op->Outputs()); + } + + std::vector> new_all_ops; + new_all_ops.reserve(compute_ref_cnt_map.size() + all_ops.size()); + for (auto &op : all_ops) { + auto it = compute_ref_cnt_map.find(op.get()); + if (it != compute_ref_cnt_map.end()) { + new_all_ops.emplace_back(std::move(op)); + new_all_ops.emplace_back(std::unique_ptr(it->second)); + } else { + new_all_ops.emplace_back(std::move(op)); + } + } + + all_ops.swap(new_all_ops); + return graph; +} + +} // namespace details +} // namespace framework +} // namespace paddle + +REGISTER_PASS(reference_count_pass, + paddle::framework::details::ReferenceCountPass) + .RequirePassAttr(paddle::framework::details::kGlobalReferenceCount) + .RequirePassAttr(paddle::framework::details::kCurReferenceCount) + .RequirePassAttr(paddle::framework::details::kGarbageCollector); diff --git a/paddle/fluid/framework/details/reference_count_pass.h b/paddle/fluid/framework/details/reference_count_pass.h new file mode 100644 index 0000000000..7081280b06 --- /dev/null +++ b/paddle/fluid/framework/details/reference_count_pass.h @@ -0,0 +1,37 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/framework/details/reference_count_op_handle.h" +#include "paddle/fluid/framework/ir/graph.h" +#include "paddle/fluid/framework/ir/pass.h" + +namespace paddle { +namespace framework { +namespace details { + +constexpr char kGlobalReferenceCount[] = "reference_count"; +constexpr char kCurReferenceCount[] = "current_reference_count"; +constexpr char kGarbageCollector[] = "garbage_collector"; + +class ReferenceCountPass : public ir::Pass { + protected: + std::unique_ptr ApplyImpl( + std::unique_ptr graph) const override; +}; + +} // namespace details +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.cc b/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.cc index eb4e7ec52f..51e840ffa6 100644 --- a/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.cc @@ -16,6 +16,10 @@ #include #include #include "paddle/fluid/framework/executor.h" +#include "paddle/fluid/platform/profiler.h" +#ifdef PADDLE_WITH_CUDA +#include "paddle/fluid/framework/details/reference_count_op_handle.h" +#endif namespace paddle { namespace framework { @@ -56,12 +60,28 @@ FeedFetchList ScopeBufferedSSAGraphExecutor::Run( auto fetch_data = underlying_executor_->Run(fetch_tensors); drop_scope_counter_ += 1; + +#ifdef PADDLE_WITH_CUDA + const std::string gc_name = "garbage_collector"; + DeviceGarbageCollectorMap *gc = + Graph().Has(gc_name) ? &(Graph().Get(gc_name)) + : nullptr; +#endif + if (!fetch_tensors.empty() || drop_scope_counter_ == strategy_.num_iteration_per_drop_scope_) { drop_scope_counter_ = 0; // Wait All computational streams for (auto p : places_) { platform::DeviceContextPool::Instance().Get(p)->Wait(); +#ifdef PADDLE_WITH_CUDA + if (gc != nullptr && platform::is_gpu_place(p)) { + auto gpu_place = boost::get(p); + auto &gc_at_place = gc->at(gpu_place.device); + gc_at_place->Wait(); + gc_at_place->Reset(); + } +#endif } for (auto &scope : local_scopes_) { auto &local_scope = diff --git a/paddle/fluid/framework/executor.cc b/paddle/fluid/framework/executor.cc index 84f67fafa1..6868f639a0 100644 --- a/paddle/fluid/framework/executor.cc +++ b/paddle/fluid/framework/executor.cc @@ -37,7 +37,9 @@ int kProgramId = -1; ExecutorPrepareContext::ExecutorPrepareContext( const framework::ProgramDesc& prog, size_t block_id) - : prog_(prog), block_id_(block_id) {} + : prog_(prog), + block_id_(block_id), + ref_cnts_(GetNonPersistableReferenceCount(prog, block_id)) {} ExecutorPrepareContext::~ExecutorPrepareContext() { VLOG(5) << "destroy ExecutorPrepareContext"; @@ -335,20 +337,84 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope, CreateVariables(ctx->prog_, local_scope, ctx->block_id_); } + std::shared_ptr> erase_tensors( + new std::vector()); + int64_t max_memory_size = GetEagerDeletionThreshold(); + + std::unique_ptr> gc; + if (max_memory_size >= 0) { +#ifdef PADDLE_WITH_CUDA + if (platform::is_gpu_place(place_)) { + gc.reset(new DefaultStreamGarbageCollector( + boost::get(place_), max_memory_size)); + } else { +#endif + gc.reset(new CPUGarbageCollector( + boost::get(place_), max_memory_size)); +#ifdef PADDLE_WITH_CUDA + } +#endif + } + for (auto& op : ctx->ops_) { VLOG(4) << place_ << " " << op->DebugStringEx(local_scope); op->Run(*local_scope, place_); - // NOTE! Please do not delete this line, it's usefull because the debug - // string before and after op.run are different, after run the output - // will have right shape which is usefull for debug. - VLOG(3) << place_ << " " << op->DebugStringEx(local_scope); + +#ifdef PADDLE_WITH_CUDA + if (gc != nullptr) { + std::vector erase_vars; + for (auto& input : op->Inputs()) { + for (auto& input_name : input.second) { + auto it = ctx->ref_cnts_.find(input_name); + if (it == ctx->ref_cnts_.end()) continue; + if (it->second == 1) { // should delete it + erase_vars.emplace_back(input_name); + ctx->ref_cnts_.erase(input_name); + } else { + --(it->second); + } + } + } + + for (auto& output : op->Outputs()) { + for (auto& output_name : output.second) { + auto it = ctx->ref_cnts_.find(output_name); + if (it == ctx->ref_cnts_.end()) continue; + if (it->second == 1) { + erase_vars.emplace_back(output_name); + ctx->ref_cnts_.erase(output_name); + } else { + --(it->second); + } + } + } + + if (!erase_vars.empty()) { + std::vector erase_tensors; + for (auto& name : erase_vars) { + auto* var = local_scope->FindVar(name); + if (var == nullptr) continue; + if (var->IsType()) { + auto* tensor = var->GetMutable(); + erase_tensors.push_back(tensor); + } + } + if (!erase_tensors.empty()) gc->Add(erase_tensors); + } + } +#endif if (FLAGS_benchmark) { VLOG(2) << "Memory used after operator " + op->Type() + " running: " << memory::memory_usage(place_); } } - platform::DeviceContextPool::Instance().Get(place_)->Wait(); + + if (gc != nullptr) + gc->Wait(); + else + platform::DeviceContextPool::Instance().Get(place_)->Wait(); + if (local_scope != scope) { scope->DeleteScope(local_scope); } else { diff --git a/paddle/fluid/framework/executor.h b/paddle/fluid/framework/executor.h index 563a4b2bb6..81d83ecea5 100644 --- a/paddle/fluid/framework/executor.h +++ b/paddle/fluid/framework/executor.h @@ -17,6 +17,7 @@ limitations under the License. */ #include #include #include +#include "paddle/fluid/framework/garbage_collector.h" #include "paddle/fluid/framework/op_info.h" #include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/scope.h" @@ -27,6 +28,48 @@ namespace paddle { namespace framework { extern void InitializeVariable(Variable* var, proto::VarType::Type var_type); +int64_t GetEagerDeletionThreshold(); + +template +std::unordered_map GetNonPersistableReferenceCount( + const ProgramDesc& prog, size_t block_id) { + auto& block = prog.Block(block_id); + std::unordered_set ignored_vars; + std::unordered_map ref_cnts; + + for (auto var_desc : block.AllVars()) { + auto type = var_desc->Proto()->type().type(); + if (type != proto::VarType::LOD_TENSOR || var_desc->Persistable()) { + ignored_vars.insert(var_desc->Name()); // ignore persistable vars + } + } + + for (auto op_desc : block.AllOps()) { + for (auto& input : op_desc->Inputs()) { + for (auto& input_name : input.second) { + if (!ignored_vars.count(input_name)) { + if (ref_cnts.count(input_name)) + ++ref_cnts[input_name]; + else + ref_cnts[input_name] = 1; + } + } + } + + for (auto& output : op_desc->Outputs()) { + for (auto output_name : output.second) { + if (!ignored_vars.count(output_name)) { + if (ref_cnts.count(output_name)) + ++ref_cnts[output_name]; + else + ref_cnts[output_name] = 1; + } + } + } + } + return ref_cnts; +} + struct ExecutorPrepareContext { ExecutorPrepareContext(const framework::ProgramDesc& prog, size_t block_id); ~ExecutorPrepareContext(); @@ -34,6 +77,8 @@ struct ExecutorPrepareContext { const framework::ProgramDesc& prog_; size_t block_id_; std::vector> ops_; + + std::unordered_map ref_cnts_; }; class Executor { diff --git a/paddle/fluid/framework/garbage_collector.h b/paddle/fluid/framework/garbage_collector.h new file mode 100644 index 0000000000..b403252c97 --- /dev/null +++ b/paddle/fluid/framework/garbage_collector.h @@ -0,0 +1,163 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include // NOLINT +#include "paddle/fluid/platform/device_context.h" + +namespace paddle { +namespace framework { + +// T should have memory_size() and clear() method +template +class GarbageCollector { + public: + GarbageCollector(const platform::Place &place, size_t max_memory_size) + : max_memory_size_(std::max(max_memory_size, static_cast(1))) { + garbages_.reset(new std::deque()); + dev_ctx_ = platform::DeviceContextPool::Instance().Get(place); + } + + virtual ~GarbageCollector() {} + + void Reset() { + std::lock_guard guard(mutex_); + garbages_.reset(new std::deque()); + cur_memory_size_ = 0; + } + + template + void Add(const Container &objs) { + Add(objs, []() {}); + } + + template + void Add(const Container &objs, Callback &&callback) { + std::shared_ptr> clear_deque; + { + std::lock_guard guard(mutex_); + for (auto *obj : objs) { + garbages_->push_back(obj); + cur_memory_size_ += obj->memory_size(); + } + if (cur_memory_size_ >= max_memory_size_) { + cur_memory_size_ = 0; + clear_deque = garbages_; + garbages_.reset(new std::deque()); + } + } + + if (clear_deque != nullptr) { + callback(); + ClearCallback([=]() { + for (auto *obj : *clear_deque) obj->clear(); + }); + } + } + + virtual void Wait() const {} + + protected: + virtual void ClearCallback(const std::function &callback) = 0; + + platform::DeviceContext *dev_ctx_; + std::shared_ptr> garbages_; + mutable std::mutex mutex_; + const size_t max_memory_size_; + size_t cur_memory_size_ = 0; +}; + +template +class CPUGarbageCollector : public GarbageCollector { + public: + CPUGarbageCollector(const platform::CPUPlace &place, size_t max_memory_size) + : GarbageCollector(place, max_memory_size) {} + + protected: + void ClearCallback(const std::function &callback) override { + callback(); + } +}; + +#ifdef PADDLE_WITH_CUDA +template +class DefaultStreamGarbageCollector : public GarbageCollector { + public: + DefaultStreamGarbageCollector(const platform::CUDAPlace &place, + size_t max_memory_size) + : GarbageCollector(place, max_memory_size) {} + + cudaStream_t stream() const { + return static_cast(this->dev_ctx_) + ->stream(); + } + + void Wait() const override { + this->dev_ctx_->Wait(); + static_cast(this->dev_ctx_) + ->WaitStreamCallback(); + } + + protected: + void ClearCallback(const std::function &callback) override { + static_cast(this->dev_ctx_) + ->AddStreamCallback(callback); + } +}; + +template +class StreamGarbageCollector : public GarbageCollector { + public: + StreamGarbageCollector(const platform::CUDAPlace &place, + size_t max_memory_size) + : GarbageCollector(place, max_memory_size) { + PADDLE_ENFORCE(cudaSetDevice(place.device)); + PADDLE_ENFORCE(cudaStreamCreate(&stream_)); + callback_manager_.reset(new platform::StreamCallbackManager(stream_)); + } + + ~StreamGarbageCollector() { + auto place = boost::get(this->dev_ctx_->GetPlace()); + PADDLE_ENFORCE(cudaSetDevice(place.device)); + PADDLE_ENFORCE(cudaStreamSynchronize(stream_)); + PADDLE_ENFORCE(cudaStreamDestroy(stream_)); + } + + void Wait() const override { + PADDLE_ENFORCE(cudaStreamSynchronize(stream_)); + std::lock_guard guard(this->mutex_); + callback_manager_->Wait(); + } + + cudaStream_t stream() const { return stream_; } + + protected: + void ClearCallback(const std::function &callback) override { + std::lock_guard guard(this->mutex_); + callback_manager_->AddCallback(callback); + } + + private: + cudaStream_t stream_; + std::unique_ptr callback_manager_; +}; +#endif + +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/ir/graph.h b/paddle/fluid/framework/ir/graph.h new file mode 100644 index 0000000000..ab687e760a --- /dev/null +++ b/paddle/fluid/framework/ir/graph.h @@ -0,0 +1,183 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include +#include + +#include "paddle/fluid/framework/ir/node.h" +#include "paddle/fluid/framework/program_desc.h" +#include "paddle/fluid/platform/enforce.h" +#include "paddle/fluid/platform/variant.h" + +namespace paddle { +namespace framework { +namespace ir { + +/* + * The graph is a Directed Acyclic Single Static Assignment Graph. + * + * In more detail, the following properties must hold: + * + * The graph shouldn't contain cycle. Each node is a black-box to the graph + * so the node itself could be a loop operator. + * + * Each Variable-type node has only one input (thus single static assignment). + * + * The output/input of operator is variable and the output/input of variable + * is operator. + * + * The following data harzards in Program are addressed in the Graph: + * + * Write-After-Read + * a = op1(x) + * x = op2(b) + * A control-dependency connection is created bettwen op1 and op2 such that + * op1->op2, so as to ensure correct order. + * + * Write-After-Write + * x = op1(a) + * x = op2(b) + * A control-dependency connection is created between op1 and op2 such that + * op1->op2, so as to ensure correct order. + * + * Other properties currently hold, but is not enforced yet: + * + * Variable-type node (not control dep) with the same variable name share + * the same underlying VarDesc. + */ +class Graph { + public: + explicit Graph(const ProgramDesc &program); + + virtual ~Graph() { + for (auto &attr : attrs_) { + attr_dels_[attr.first](); + } + attrs_.clear(); + attr_dels_.clear(); + } + + bool Has(const std::string &attr_name) const { + return attrs_.find(attr_name) != attrs_.end(); + } + + template + AttrType &Get(const std::string &attr_name) const { + PADDLE_ENFORCE(Has(attr_name), "%s attr not registered for graph.", + attr_name); + return *boost::any_cast(attrs_.at(attr_name)); + } + + template + void Set(const std::string &attr_name, AttrType *attr) { + PADDLE_ENFORCE(attrs_.count(attr_name) == 0, "%s already set in the graph", + attr_name); + attrs_[attr_name] = attr; + attr_dels_[attr_name] = [attr, attr_name]() { + VLOG(3) << "deleting " << attr_name; + delete attr; + }; + } + + template + void SetNotOwned(const std::string &attr_name, AttrType *attr) { + PADDLE_ENFORCE(attrs_.count(attr_name) == 0, "%s already set in the graph", + attr_name); + attrs_[attr_name] = attr; + attr_dels_[attr_name] = []() {}; + } + + const std::unordered_set &Nodes() const { return node_set_; } + + // Create a normal variable with non-null VarDesc. + ir::Node *CreateVarNode(VarDesc *var_desc) { + PADDLE_ENFORCE(var_desc); + return AddNode(new ir::Node(var_desc)); + } + + // Create a normal runnable operator with OpDesc. + ir::Node *CreateOpNode(OpDesc *op_desc) { + PADDLE_ENFORCE(op_desc); + return AddNode(new ir::Node(op_desc)); + } + + // Create a control dependency var that connects 2 operations. The + // var doesn't hold any data. Other than that, it's no different from + // other var, considering dependency analysis. + ir::Node *CreateControlDepVar() { + // TODO(panyx0718): control var name should be really unique. + const std::string name = string::Sprintf( + "%s@%llu", ir::Node::kControlDepVarName, node_set_.size()); + return AddNode(new ir::Node(name, ir::Node::Type::kVariable)); + } + + // A more free style way of creating a graph node. Mostly use for test + // or "copy" from another node. Avoid using it if possible. + ir::Node *CreateEmptyNode(const std::string &name, ir::Node::Type type) { + return AddNode(new ir::Node(name, type)); + } + + // Clear all node information of the graph and return the ownership of the + // nodes. + std::vector> ReleaseNodes() { + std::vector> ret; + for (auto &n : nodes_) { + ret.emplace_back(n.second.release()); + } + nodes_.clear(); + node_set_.clear(); + return ret; + } + + void RemoveNode(ir::Node *node) { + PADDLE_ENFORCE(node_set_.find(node) != node_set_.end()); + node_set_.erase(node); + nodes_.erase(node); + } + + // NOTE low performance, but simple and secure. + Node *RetriveNode(int id) { + for (auto &node : nodes_) { + if (node.second->id() == id) { + return node.second.get(); + } + } + return nullptr; + } + + private: + // This method takes ownership of `node`. + ir::Node *AddNode(ir::Node *node) { + PADDLE_ENFORCE(node_set_.find(node) == node_set_.end()); + nodes_[node].reset(node); + node_set_.insert(node); + return node; + } + + // NOTE: program_ shouldn't be exposed to user. + const ProgramDesc program_; + std::map attrs_; + std::map> attr_dels_; + std::map> nodes_; + std::unordered_set node_set_; +}; + +bool IsControlDepVar(const ir::Node &var); +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc index b53a6f43fb..5a19e7f1bf 100644 --- a/paddle/fluid/framework/parallel_executor.cc +++ b/paddle/fluid/framework/parallel_executor.cc @@ -19,9 +19,15 @@ limitations under the License. */ #include #ifdef PADDLE_WITH_CUDA +#include "paddle/fluid/framework/details/reference_count_pass.h" #include "paddle/fluid/platform/nccl_helper.h" #endif +#include "paddle/fluid/framework/details/all_reduce_op_handle.h" +#include "paddle/fluid/framework/details/broadcast_op_handle.h" +#include "paddle/fluid/framework/details/computation_op_handle.h" +#include "paddle/fluid/framework/details/reduce_op_handle.h" +#include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h" #include "paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h" #include "paddle/fluid/framework/details/ssa_graph_builder_factory.h" #include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h" @@ -115,17 +121,39 @@ ParallelExecutor::ParallelExecutor( build_strategy); if (member_->use_cuda_) { #ifdef PADDLE_WITH_CUDA - builder_factory.SetNCCLContextMap(member_->nccl_ctxs_.get()); + std::unique_ptr graph = ApplyParallelExecutorPass( + main_program, member_->places_, loss_var_name, params, + member_->local_scopes_, member_->use_cuda_, build_strategy, + member_->nccl_ctxs_.get()); + + auto max_memory_size = GetEagerDeletionThreshold(); + if (max_memory_size >= 0) { + for (auto &place : member_->places_) { + if (!platform::is_gpu_place(place)) continue; + auto gpu_place = boost::get(place); + if (gcs_[gpu_place.device] == nullptr) { + ref_cnts_[gpu_place.device].reset(new details::ReferenceCountMap()); + cur_ref_cnts_[gpu_place.device].reset( + new details::AtomicReferenceCountMap()); + gcs_[gpu_place.device].reset( + new StreamGarbageCollector(gpu_place, max_memory_size)); + } + } + if (!gcs_.empty()) { + auto ref_cnt_pass = + ir::PassRegistry::Instance().Get("reference_count_pass"); + ref_cnt_pass->SetNotOwned(details::kGlobalReferenceCount, &ref_cnts_); + ref_cnt_pass->SetNotOwned(details::kCurReferenceCount, &cur_ref_cnts_); + ref_cnt_pass->SetNotOwned(details::kGarbageCollector, &gcs_); + graph = ref_cnt_pass->Apply(std::move(graph)); + graph->SetNotOwned("garbage_collector", &gcs_); + } + } #else PADDLE_THROW("Not compiled with CUDA"); #endif } - builder_ = builder_factory.Create(); - member_->executor_.reset(new details::ThreadedSSAGraphExecutor( - exec_strategy, member_->local_scopes_, places, - builder_->Build(main_program))); - member_->executor_.reset(new details::ScopeBufferedSSAGraphExecutor( exec_strategy, member_->local_scopes_, std::move(var_infos), member_->places_, std::move(member_->executor_))); @@ -216,6 +244,11 @@ void ParallelExecutor::BCastParamsToGPUs( void ParallelExecutor::Run(const std::vector &fetch_tensors, const std::string &fetched_var_name) { platform::RecordBlock b(0); +#ifdef PADDLE_WITH_CUDA + if (!gcs_.empty()) { + ResetReferenceCount(); + } +#endif auto fetch_data = member_->executor_->Run(fetch_tensors); *member_->global_scope_->Var(fetched_var_name)->GetMutable() = fetch_data; @@ -265,3 +298,11 @@ ParallelExecutor::~ParallelExecutor() { } // namespace framework } // namespace paddle + +USE_PASS(graph_viz_pass); +USE_PASS(multi_devices_pass); +USE_PASS(multi_devices_check_pass); +USE_PASS(multi_devices_print_pass); +#ifdef PADDLE_WITH_CUDA +USE_PASS(reference_count_pass); +#endif diff --git a/paddle/fluid/framework/parallel_executor.h b/paddle/fluid/framework/parallel_executor.h index 058f83f07c..2aa438e320 100644 --- a/paddle/fluid/framework/parallel_executor.h +++ b/paddle/fluid/framework/parallel_executor.h @@ -15,7 +15,9 @@ limitations under the License. */ #pragma once #include +#include #include +#include #include #include #include "paddle/fluid/framework/details/execution_strategy.h" @@ -70,7 +72,23 @@ class ParallelExecutor { private: ParallelExecutorPrivate *member_; - std::unique_ptr builder_; + +#ifdef PADDLE_WITH_CUDA + // ref_cnts_ is only initialized when ParallelExecutor constructs, and then + // keeps unchanged + // Before each iteration, cur_ref_cnts_ is reset to ref_cnts_ + details::DeviceReferenceCountMap ref_cnts_; + details::AtomicDeviceReferenceCountMap cur_ref_cnts_; + details::DeviceGarbageCollectorMap gcs_; + + void ResetReferenceCount() { + for (auto &pair1 : ref_cnts_) { + for (auto &pair2 : *(pair1.second)) { + (*(cur_ref_cnts_[pair1.first]))[pair2.first] = pair2.second; + } + } + } +#endif }; } // namespace framework diff --git a/paddle/fluid/framework/scope.cc b/paddle/fluid/framework/scope.cc index 50f374e370..caea191cb3 100644 --- a/paddle/fluid/framework/scope.cc +++ b/paddle/fluid/framework/scope.cc @@ -31,9 +31,21 @@ DEFINE_bool( "Delete local scope eagerly. It will reduce GPU memory usage but " "slow down the destruction of variables.(around 1% performance harm)"); +DEFINE_double( + eager_delete_tensor_GB, -1.0, + "Memory size threshold (GB) when the garbage collector clear tensors." + "Disabled when this value is less than 0"); + namespace paddle { namespace framework { +int64_t GetEagerDeletionThreshold() { + return FLAGS_eager_delete_tensor_GB < 0 + ? -1 + : static_cast(FLAGS_eager_delete_tensor_GB * + (static_cast(1) << 30)); +} + Scope::~Scope() { DropKids(); } Scope& Scope::NewScope() const { diff --git a/paddle/fluid/framework/scope.h b/paddle/fluid/framework/scope.h index e246241c0a..47d040240a 100644 --- a/paddle/fluid/framework/scope.h +++ b/paddle/fluid/framework/scope.h @@ -26,6 +26,8 @@ limitations under the License. */ namespace paddle { namespace framework { +int64_t GetEagerDeletionThreshold(); + class Scope; /** diff --git a/paddle/fluid/framework/tensor.h b/paddle/fluid/framework/tensor.h index ef224d68f1..775c01765c 100644 --- a/paddle/fluid/framework/tensor.h +++ b/paddle/fluid/framework/tensor.h @@ -149,6 +149,8 @@ class Tensor { void set_layout(const DataLayout layout) { layout_ = layout; } + void clear() { holder_ = nullptr; } + private: /** * @note Placeholder hides type T, so it doesn't appear as a template diff --git a/paddle/fluid/platform/CMakeLists.txt b/paddle/fluid/platform/CMakeLists.txt index 20037d0764..ac9bf9a505 100644 --- a/paddle/fluid/platform/CMakeLists.txt +++ b/paddle/fluid/platform/CMakeLists.txt @@ -45,8 +45,8 @@ ENDIF() # memcpy depends on device_context, here add deps individually for # avoiding cycle dependencies -cc_library(device_context SRCS device_context.cc init.cc DEPS malloc - place eigen3 stringpiece cpu_helper ${GPU_CTX_DEPS} ${MKLDNN_CTX_DEPS}) +cc_library(device_context SRCS device_context.cc init.cc DEPS simple_threadpool malloc + place eigen3 stringpiece cpu_helper cpu_info framework_proto ${GPU_CTX_DEPS} ${MKLDNN_CTX_DEPS}) nv_test(device_context_test SRCS device_context_test.cu DEPS device_context gpu_info) cc_test(init_test SRCS init_test.cc DEPS device_context) diff --git a/paddle/fluid/platform/device_context.cc b/paddle/fluid/platform/device_context.cc index 2cc26da013..a57ee2d8f5 100644 --- a/paddle/fluid/platform/device_context.cc +++ b/paddle/fluid/platform/device_context.cc @@ -159,11 +159,14 @@ CUDADeviceContext::CUDADeviceContext(CUDAPlace place) : place_(place) { } else { cudnn_handle_ = nullptr; } + + callback_manager_.reset(new StreamCallbackManager(stream_)); } CUDADeviceContext::~CUDADeviceContext() { SetDeviceId(place_.device); Wait(); + WaitStreamCallback(); PADDLE_ENFORCE(dynload::cublasDestroy(cublas_handle_)); if (cudnn_handle_ != nullptr) { PADDLE_ENFORCE(dynload::cudnnDestroy(cudnn_handle_)); diff --git a/paddle/fluid/platform/device_context.h b/paddle/fluid/platform/device_context.h index 88e0383146..0fb5338368 100644 --- a/paddle/fluid/platform/device_context.h +++ b/paddle/fluid/platform/device_context.h @@ -31,8 +31,13 @@ limitations under the License. */ #include "glog/logging.h" #include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/place.h" +#ifdef PADDLE_WITH_CUDA +#include "paddle/fluid/platform/stream_callback_manager.h" +#endif #include "unsupported/Eigen/CXX11/Tensor" +DECLARE_bool(clear_gpu_memory_when_unused); + namespace paddle { namespace platform { @@ -106,6 +111,17 @@ class CUDADeviceContext : public DeviceContext { PADDLE_ENFORCE(cudaEventRecord(ev, stream_)); } + template + void AddStreamCallback(Callback&& callback) const { + std::lock_guard guard(callback_mtx_); + callback_manager_->AddCallback(callback); + } + + void WaitStreamCallback() const { + std::lock_guard guard(callback_mtx_); + callback_manager_->Wait(); + } + private: CUDAPlace place_; @@ -119,7 +135,12 @@ class CUDADeviceContext : public DeviceContext { int multi_process; int max_threads_per_mp; - std::mutex mtx_; + mutable std::mutex mtx_; + + // This lock is only used by callback + // If we use mtx_ for StreamCallbackManager, deadlock may occur sometimes + mutable std::mutex callback_mtx_; + std::unique_ptr callback_manager_; }; template <> diff --git a/paddle/fluid/platform/stream_callback_manager.h b/paddle/fluid/platform/stream_callback_manager.h new file mode 100644 index 0000000000..6c984065aa --- /dev/null +++ b/paddle/fluid/platform/stream_callback_manager.h @@ -0,0 +1,82 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include "ThreadPool.h" +#include "paddle/fluid/platform/enforce.h" + +namespace paddle { +namespace platform { + +using StreamCallback = std::function; + +class StreamCallbackManager; + +struct StreamCallbackContext { + template + inline StreamCallbackContext(const StreamCallbackManager *manager, + Callback &&callback) + : manager_(manager), callback_(callback) {} + + const StreamCallbackManager *manager_; // do not own + StreamCallback callback_; +}; + +class StreamCallbackManager { + public: + explicit inline StreamCallbackManager(cudaStream_t stream = nullptr) + : stream_(stream), thread_pool_(new ThreadPool(1)) {} + + template + inline void AddCallback(Callback &&callback) const { + AddCallbackWithStreamAndErrorInfo( + [=](cudaStream_t, cudaError_t) { callback(); }); + } + + template + inline void AddCallbackWithStreamAndErrorInfo(Callback &&callback) const { + auto *stream_callback_context = new StreamCallbackContext(this, callback); + PADDLE_ENFORCE(cudaStreamAddCallback( + stream_, StreamCallbackManager::StreamCallbackFunc, + stream_callback_context, 0)); + } + + void Wait() const { thread_pool_.reset(new ThreadPool(1)); } + + private: + const cudaStream_t stream_; + mutable std::unique_ptr thread_pool_; + + // cudaStreamCallback cannot call CUDA API inside, so we have to use + // thread_pool here + static void CUDART_CB StreamCallbackFunc(cudaStream_t stream, + cudaError_t status, + void *user_data) { + auto *callback_context_ptr = + reinterpret_cast(user_data); + callback_context_ptr->manager_->thread_pool_->enqueue([=]() { + std::unique_ptr callback_context( + callback_context_ptr); + callback_context->callback_(stream, status); + }); + } +}; + +} // namespace platform +} // namespace paddle diff --git a/python/paddle/fluid/__init__.py b/python/paddle/fluid/__init__.py index 3034c1a087..74b268aede 100644 --- a/python/paddle/fluid/__init__.py +++ b/python/paddle/fluid/__init__.py @@ -117,9 +117,19 @@ def __bootstrap__(): os.environ['OMP_NUM_THREADS'] = str(num_threads) read_env_flags = [ - 'use_pinned_memory', 'check_nan_inf', 'benchmark', 'warpctc_dir', - 'eager_delete_scope', 'use_mkldnn', 'initial_cpu_memory_in_mb', - 'init_allocated_mem' + 'use_pinned_memory', + 'check_nan_inf', + 'benchmark', + 'warpctc_dir', + 'eager_delete_scope', + 'use_mkldnn', + 'initial_cpu_memory_in_mb', + 'init_allocated_mem', + 'free_idle_memory', + 'paddle_num_threads', + "dist_threadpool_size", + 'cpu_deterministic', + 'eager_delete_tensor_GB', ] if core.is_compiled_with_cuda(): read_env_flags += [ From 60afef1e858ada7709e95b500f36a995ddd29469 Mon Sep 17 00:00:00 2001 From: Dun Liang Date: Sun, 16 Sep 2018 15:00:47 +0800 Subject: [PATCH 05/11] fix code style --- paddle/fluid/memory/detail/buddy_allocator.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/memory/detail/buddy_allocator.cc b/paddle/fluid/memory/detail/buddy_allocator.cc index 3c961e5040..1af422feb8 100644 --- a/paddle/fluid/memory/detail/buddy_allocator.cc +++ b/paddle/fluid/memory/detail/buddy_allocator.cc @@ -162,8 +162,8 @@ void BuddyAllocator::Free(void* p) { } size_t BuddyAllocator::Used() { return total_used_; } -size_t BuddyAllocator::GetMinChunkSize() {return min_chunk_size_;}; -size_t BuddyAllocator::GetMaxChunkSize() {return max_chunk_size_;}; +size_t BuddyAllocator::GetMinChunkSize() { return min_chunk_size_; } +size_t BuddyAllocator::GetMaxChunkSize() { return max_chunk_size_; } void* BuddyAllocator::SystemAlloc(size_t size) { size_t index = 0; From 612e1a31554b4cc24eea61ff257014a975b47929 Mon Sep 17 00:00:00 2001 From: sneaxiy Date: Sat, 15 Sep 2018 15:18:10 +0000 Subject: [PATCH 06/11] modification --- .../framework/details/computation_op_handle.h | 2 -- paddle/fluid/framework/details/op_handle_base.h | 5 ----- .../details/reference_count_op_handle.h | 10 +++++----- .../framework/details/reference_count_pass.cc | 8 +++----- paddle/fluid/framework/executor.cc | 17 ++++++++--------- paddle/fluid/framework/executor.h | 2 -- paddle/fluid/framework/parallel_executor.cc | 1 - paddle/fluid/framework/parallel_executor.h | 4 ++++ paddle/fluid/framework/scope.cc | 6 +++--- paddle/fluid/platform/device_context.h | 2 -- python/paddle/fluid/__init__.py | 2 +- 11 files changed, 24 insertions(+), 35 deletions(-) diff --git a/paddle/fluid/framework/details/computation_op_handle.h b/paddle/fluid/framework/details/computation_op_handle.h index 9a330749ea..e98f1ab148 100644 --- a/paddle/fluid/framework/details/computation_op_handle.h +++ b/paddle/fluid/framework/details/computation_op_handle.h @@ -23,8 +23,6 @@ #include "paddle/fluid/framework/scope.h" #include "paddle/fluid/platform/device_context.h" -#include "paddle/fluid/framework/details/reference_count_op_handle.h" - namespace paddle { namespace framework { namespace details { diff --git a/paddle/fluid/framework/details/op_handle_base.h b/paddle/fluid/framework/details/op_handle_base.h index d4e2c44482..9fbefabc84 100644 --- a/paddle/fluid/framework/details/op_handle_base.h +++ b/paddle/fluid/framework/details/op_handle_base.h @@ -89,11 +89,6 @@ class OpHandleBase { ir::Node *Node() { return node_; } - const std::map - &GetDeviceContexts() const { - return dev_ctxes_; - } - protected: void RunAndRecordEvent(const std::function &callback); diff --git a/paddle/fluid/framework/details/reference_count_op_handle.h b/paddle/fluid/framework/details/reference_count_op_handle.h index b76fc646c2..71db8d952f 100644 --- a/paddle/fluid/framework/details/reference_count_op_handle.h +++ b/paddle/fluid/framework/details/reference_count_op_handle.h @@ -69,15 +69,15 @@ class ReferenceCountOpHandle : public OpHandleBase { std::string Name() const override { return "reference_count"; } - // protected: + protected: void RunImpl() override { - auto *exec_scope_ = scope_->FindVar(kLocalExecScopeName)->Get(); + auto *exec_scope = scope_->FindVar(kLocalExecScopeName)->Get(); std::vector tensors; for (auto &name : var_names_) { auto it = ref_cnts_->find(name); if (it == ref_cnts_->end()) continue; - auto *var = exec_scope_->FindVar(name); + auto *var = exec_scope->FindVar(name); if (var == nullptr || !var->IsType()) continue; if (it->second.fetch_sub(1) <= 1) { @@ -91,8 +91,8 @@ class ReferenceCountOpHandle : public OpHandleBase { } private: - void ClearTensors(const std::vector &tensors) const { - auto *gc = dynamic_cast *>(gc_); + void ClearTensors(const std::vector &tensors) { + auto *gc = dynamic_cast *>(gc_); if (gc != nullptr) { auto compute_stream = dev_ctx_->stream(); auto callback_stream = gc->stream(); diff --git a/paddle/fluid/framework/details/reference_count_pass.cc b/paddle/fluid/framework/details/reference_count_pass.cc index 892e6ea48a..344754d5a1 100644 --- a/paddle/fluid/framework/details/reference_count_pass.cc +++ b/paddle/fluid/framework/details/reference_count_pass.cc @@ -128,12 +128,10 @@ std::unique_ptr ReferenceCountPass::ApplyImpl( std::vector> new_all_ops; new_all_ops.reserve(compute_ref_cnt_map.size() + all_ops.size()); for (auto &op : all_ops) { - auto it = compute_ref_cnt_map.find(op.get()); + new_all_ops.emplace_back(std::move(op)); + auto it = compute_ref_cnt_map.find(new_all_ops.back().get()); if (it != compute_ref_cnt_map.end()) { - new_all_ops.emplace_back(std::move(op)); - new_all_ops.emplace_back(std::unique_ptr(it->second)); - } else { - new_all_ops.emplace_back(std::move(op)); + new_all_ops.emplace_back(it->second); } } diff --git a/paddle/fluid/framework/executor.cc b/paddle/fluid/framework/executor.cc index fd58de28af..650d9086d4 100644 --- a/paddle/fluid/framework/executor.cc +++ b/paddle/fluid/framework/executor.cc @@ -37,9 +37,11 @@ int kProgramId = -1; ExecutorPrepareContext::ExecutorPrepareContext( const framework::ProgramDesc& prog, size_t block_id) - : prog_(prog), - block_id_(block_id), - ref_cnts_(GetNonPersistableReferenceCount(prog, block_id)) {} + : prog_(prog), block_id_(block_id) { + if (GetEagerDeletionThreshold() >= 0) { + ref_cnts_ = GetNonPersistableReferenceCount(prog_, block_id_); + } +} ExecutorPrepareContext::~ExecutorPrepareContext() { VLOG(5) << "destroy ExecutorPrepareContext"; @@ -331,8 +333,6 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope, CreateVariables(ctx->prog_, local_scope, ctx->block_id_); } - std::shared_ptr> erase_tensors( - new std::vector()); int64_t max_memory_size = GetEagerDeletionThreshold(); std::unique_ptr> gc; @@ -353,7 +353,6 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope, for (auto& op : ctx->ops_) { op->Run(*local_scope, place_); -#ifdef PADDLE_WITH_CUDA if (gc != nullptr) { std::vector erase_vars; for (auto& input : op->Inputs()) { @@ -395,7 +394,6 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope, if (!erase_tensors.empty()) gc->Add(erase_tensors); } } -#endif if (FLAGS_benchmark) { VLOG(2) << "Memory used after operator " + op->Type() + " running: " @@ -403,10 +401,11 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope, } } - if (gc != nullptr) + if (gc != nullptr) { gc->Wait(); - else + } else { platform::DeviceContextPool::Instance().Get(place_)->Wait(); + } if (local_scope != scope) { scope->DeleteScope(local_scope); diff --git a/paddle/fluid/framework/executor.h b/paddle/fluid/framework/executor.h index 122bafedce..b746268760 100644 --- a/paddle/fluid/framework/executor.h +++ b/paddle/fluid/framework/executor.h @@ -28,8 +28,6 @@ namespace paddle { namespace framework { extern void InitializeVariable(Variable* var, proto::VarType::Type var_type); -int64_t GetEagerDeletionThreshold(); - template std::unordered_map GetNonPersistableReferenceCount( const ProgramDesc& prog, size_t block_id) { diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc index 880521f29e..ae393d66a3 100644 --- a/paddle/fluid/framework/parallel_executor.cc +++ b/paddle/fluid/framework/parallel_executor.cc @@ -22,7 +22,6 @@ limitations under the License. */ #include "paddle/fluid/framework/ir/graph_viz_pass.h" #ifdef PADDLE_WITH_CUDA -#include "paddle/fluid/framework/details/reference_count_pass.h" #include "paddle/fluid/platform/nccl_helper.h" #endif diff --git a/paddle/fluid/framework/parallel_executor.h b/paddle/fluid/framework/parallel_executor.h index a0f66c3f8f..88e2078454 100644 --- a/paddle/fluid/framework/parallel_executor.h +++ b/paddle/fluid/framework/parallel_executor.h @@ -29,6 +29,10 @@ limitations under the License. */ #include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/platform/device_context.h" +#ifdef PADDLE_WITH_CUDA +#include "paddle/fluid/framework/details/reference_count_pass.h" +#endif + namespace paddle { namespace framework { diff --git a/paddle/fluid/framework/scope.cc b/paddle/fluid/framework/scope.cc index ece9a69a99..1a727a2c8c 100644 --- a/paddle/fluid/framework/scope.cc +++ b/paddle/fluid/framework/scope.cc @@ -32,7 +32,7 @@ DEFINE_bool( "slow down the destruction of variables.(around 1% performance harm)"); DEFINE_double( - eager_delete_tensor_GB, -1.0, + eager_delete_tensor_gb, -1.0, "Memory size threshold (GB) when the garbage collector clear tensors." "Disabled when this value is less than 0"); @@ -40,9 +40,9 @@ namespace paddle { namespace framework { int64_t GetEagerDeletionThreshold() { - return FLAGS_eager_delete_tensor_GB < 0 + return FLAGS_eager_delete_tensor_gb < 0 ? -1 - : static_cast(FLAGS_eager_delete_tensor_GB * + : static_cast(FLAGS_eager_delete_tensor_gb * (static_cast(1) << 30)); } diff --git a/paddle/fluid/platform/device_context.h b/paddle/fluid/platform/device_context.h index c3b092b2a5..7953919515 100644 --- a/paddle/fluid/platform/device_context.h +++ b/paddle/fluid/platform/device_context.h @@ -36,8 +36,6 @@ limitations under the License. */ #endif #include "unsupported/Eigen/CXX11/Tensor" -DECLARE_bool(clear_gpu_memory_when_unused); - namespace paddle { namespace platform { diff --git a/python/paddle/fluid/__init__.py b/python/paddle/fluid/__init__.py index e4d7575ca4..1ca2ac2ddc 100644 --- a/python/paddle/fluid/__init__.py +++ b/python/paddle/fluid/__init__.py @@ -122,7 +122,7 @@ def __bootstrap__(): 'use_pinned_memory', 'check_nan_inf', 'benchmark', 'warpctc_dir', 'eager_delete_scope', 'use_mkldnn', 'initial_cpu_memory_in_mb', 'init_allocated_mem', 'free_idle_memory', 'paddle_num_threads', - "dist_threadpool_size", 'cpu_deterministic', 'eager_delete_tensor_GB' + "dist_threadpool_size", 'cpu_deterministic', 'eager_delete_tensor_gb' ] if core.is_compiled_with_dist(): read_env_flags.append('rpc_deadline') From 006c9246f311c455a9b350bc7a36905a14806487 Mon Sep 17 00:00:00 2001 From: Xin Pan Date: Mon, 17 Sep 2018 12:47:25 +0800 Subject: [PATCH 07/11] doc fix --- doc/fluid/dev/releasing_process_en.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/fluid/dev/releasing_process_en.md b/doc/fluid/dev/releasing_process_en.md index b810dc941d..00650946ff 100644 --- a/doc/fluid/dev/releasing_process_en.md +++ b/doc/fluid/dev/releasing_process_en.md @@ -1,6 +1,6 @@ # PaddlePaddle Releasing Process -PaddlePaddle manages its branches using "git-flow branching model", and [Semantic Versioning](http://semver.org/) as it's version number semantics. +PaddlePaddle manages its branches using Trunk Based Development, and [Semantic Versioning](http://semver.org/) as it's version number semantics. Each time we release a new PaddlePaddle version, we should follow the below steps: From d40402f9b7e9bbda72d6636273a436df02f2ea05 Mon Sep 17 00:00:00 2001 From: nhzlx Date: Mon, 17 Sep 2018 05:37:51 +0000 Subject: [PATCH 08/11] add dropout and sigmoid op converter --- paddle/fluid/inference/analysis/analyzer.cc | 5 +- .../api/api_tensorrt_subgraph_engine.cc | 10 +++ .../inference/tensorrt/convert/CMakeLists.txt | 6 +- .../tensorrt/convert/activation_op.cc | 48 +++++++++++-- .../inference/tensorrt/convert/dropout_op.cc | 71 +++++++++++++++++++ .../tensorrt/convert/test_activation_op.cc | 20 ++++-- .../tensorrt/convert/test_dropout_op.cc | 58 +++++++++++++++ 7 files changed, 202 insertions(+), 16 deletions(-) create mode 100644 paddle/fluid/inference/tensorrt/convert/dropout_op.cc create mode 100644 paddle/fluid/inference/tensorrt/convert/test_dropout_op.cc diff --git a/paddle/fluid/inference/analysis/analyzer.cc b/paddle/fluid/inference/analysis/analyzer.cc index 6dc39cae05..8a8aeb5e09 100644 --- a/paddle/fluid/inference/analysis/analyzer.cc +++ b/paddle/fluid/inference/analysis/analyzer.cc @@ -69,8 +69,9 @@ class DfgPassManagerImpl final : public DfgPassManager { if (FLAGS_IA_enable_tensorrt_subgraph_engine) { auto trt_teller = [&](const Node* node) { std::unordered_set teller_set( - {"elementwise_add", "mul", "conv2d", "pool2d", "relu", "softmax", - "depthwise_conv2d", "batch_norm", "concat"}); + {"mul", "conv2d", "pool2d", "relu", "softmax", "sigmoid", + "depthwise_conv2d", "batch_norm", "concat", "tanh", + "elementwise_add", "dropout"}); if (!node->IsFunction()) return false; const auto* func = static_cast(node); diff --git a/paddle/fluid/inference/api/api_tensorrt_subgraph_engine.cc b/paddle/fluid/inference/api/api_tensorrt_subgraph_engine.cc index abee375313..d9d6e139b8 100644 --- a/paddle/fluid/inference/api/api_tensorrt_subgraph_engine.cc +++ b/paddle/fluid/inference/api/api_tensorrt_subgraph_engine.cc @@ -153,11 +153,21 @@ CreatePaddlePredictor( } // namespace paddle USE_TRT_CONVERTER(elementwise_add_weight); +USE_TRT_CONVERTER(elementwise_add_tensor); +USE_TRT_CONVERTER(elementwise_sub_tensor); +USE_TRT_CONVERTER(elementwise_div_tensor); +USE_TRT_CONVERTER(elementwise_mul_tensor); +USE_TRT_CONVERTER(elementwise_max_tensor); +USE_TRT_CONVERTER(elementwise_min_tensor); +USE_TRT_CONVERTER(elementwise_pow_tensor); USE_TRT_CONVERTER(mul); USE_TRT_CONVERTER(conv2d); USE_TRT_CONVERTER(relu); +USE_TRT_CONVERTER(sigmoid); +USE_TRT_CONVERTER(tanh); USE_TRT_CONVERTER(fc); USE_TRT_CONVERTER(pool2d); USE_TRT_CONVERTER(softmax); USE_TRT_CONVERTER(batch_norm); USE_TRT_CONVERTER(concat); +USE_TRT_CONVERTER(dropout); diff --git a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt index 9d7be2d03c..fac1babf6e 100644 --- a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt +++ b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt @@ -1,7 +1,7 @@ # Add TRT tests nv_library(tensorrt_converter SRCS mul_op.cc conv2d_op.cc fc_op.cc pool2d_op.cc elementwise_op.cc -batch_norm_op.cc activation_op.cc softmax_op.cc concat_op.cc +batch_norm_op.cc activation_op.cc softmax_op.cc concat_op.cc dropout_op.cc DEPS tensorrt_engine operator scope framework_proto op_registry) nv_test(test_op_converter SRCS test_op_converter.cc DEPS @@ -24,6 +24,8 @@ nv_test(test_trt_softmax_op SRCS test_softmax_op.cc softmax_op.cc DEPS ${FLUID_CORE_MODULES} tensorrt_engine softmax_op SERIAL) nv_test(test_trt_batch_norm_op SRCS test_batch_norm_op.cc batch_norm_op.cc DEPS ${FLUID_CORE_MODULES} tensorrt_engine batch_norm_op SERIAL) - nv_test(test_trt_concat_op SRCS test_concat_op.cc concat_op.cc DEPS ${FLUID_CORE_MODULES} tensorrt_engine concat_op SERIAL) + +nv_test(test_trt_dropout_op SRCS test_dropout_op.cc dropout_op.cc + DEPS ${FLUID_CORE_MODULES} tensorrt_engine dropout_op SERIAL) diff --git a/paddle/fluid/inference/tensorrt/convert/activation_op.cc b/paddle/fluid/inference/tensorrt/convert/activation_op.cc index 8168cdff1b..e73c5bbf57 100644 --- a/paddle/fluid/inference/tensorrt/convert/activation_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/activation_op.cc @@ -19,23 +19,31 @@ namespace paddle { namespace inference { namespace tensorrt { -class ReluOpConverter : public OpConverter { +class ActivationOpConverter : public OpConverter { public: - ReluOpConverter() {} + ActivationOpConverter() {} void operator()(const framework::proto::OpDesc& op, const framework::Scope& scope, bool test_mode) override { // Here the two nullptr looks strange, that's because the // framework::OpDesc's constructor is strange. framework::OpDesc op_desc(op, nullptr); - LOG(INFO) << "convert a fluid relu op to tensorrt activation layer whose " - "type is Relu"; + LOG(INFO) + << "convert a fluid Activation op to tensorrt activation layer whose " + "type is " + << op_type_; const nvinfer1::ITensor* input_tensor = engine_->GetITensor(op_desc.Input("X")[0]); + + auto op_pair = ops.find(op_type_); + if (op_pair == ops.end()) { + PADDLE_THROW("Wrong activation op type!"); + } + nvinfer1::IActivationLayer* layer = TRT_ENGINE_ADD_LAYER( engine_, Activation, *const_cast(input_tensor), - nvinfer1::ActivationType::kRELU); + op_pair->second); auto output_name = op_desc.Output("Out")[0]; - layer->setName(("relu (Output: " + output_name + ")").c_str()); + layer->setName((op_type_ + " (Output: " + output_name + ")").c_str()); layer->getOutput(0)->setName(output_name.c_str()); engine_->SetITensor(output_name, layer->getOutput(0)); if (test_mode) { // the test framework can not determine which is the @@ -43,6 +51,32 @@ class ReluOpConverter : public OpConverter { engine_->DeclareOutput(output_name); } } + + protected: + std::string op_type_; + static const std::unordered_map ops; +}; + +const std::unordered_map + ActivationOpConverter::ops = { + {"relu", nvinfer1::ActivationType::kRELU}, + {"sigmoid", nvinfer1::ActivationType::kSIGMOID}, + {"tanh", nvinfer1::ActivationType::kTANH}, +}; + +class ReluOpConverter : public ActivationOpConverter { + public: + ReluOpConverter() { op_type_ = "relu"; } +}; + +class SigmoidOpConverter : public ActivationOpConverter { + public: + SigmoidOpConverter() { op_type_ = "sigmoid"; } +}; + +class TanhOpConverter : public ActivationOpConverter { + public: + TanhOpConverter() { op_type_ = "tanh"; } }; } // namespace tensorrt @@ -50,3 +84,5 @@ class ReluOpConverter : public OpConverter { } // namespace paddle REGISTER_TRT_OP_CONVERTER(relu, ReluOpConverter); +REGISTER_TRT_OP_CONVERTER(sigmoid, SigmoidOpConverter); +REGISTER_TRT_OP_CONVERTER(tanh, TanhOpConverter); diff --git a/paddle/fluid/inference/tensorrt/convert/dropout_op.cc b/paddle/fluid/inference/tensorrt/convert/dropout_op.cc new file mode 100644 index 0000000000..9533ecbcfd --- /dev/null +++ b/paddle/fluid/inference/tensorrt/convert/dropout_op.cc @@ -0,0 +1,71 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/inference/tensorrt/convert/op_converter.h" + +namespace paddle { +namespace inference { +namespace tensorrt { + +/* + * DropoutOp. This Layer doesn't has weights. + */ +class DropoutOpConverter : public OpConverter { + public: + void operator()(const framework::proto::OpDesc& op, + const framework::Scope& scope, bool test_mode) override { + VLOG(4) << "convert a fluid dropout op to tensorrt dropout layer"; + framework::OpDesc op_desc(op, nullptr); + // Declare inputs + auto* input1 = engine_->GetITensor(op_desc.Input("X")[0]); + float dropout_prob = boost::get(op_desc.GetAttr("dropout_prob")); + + platform::CPUPlace cpu_place; + std::unique_ptr weight_tensor( + new framework::LoDTensor()); + weight_tensor->Resize(framework::make_ddim({1})); + auto* weight_data = + weight_tensor->mutable_data(platform::CPUPlace()); + weight_data[0] = 1 - dropout_prob; + + TensorRTEngine::Weight scale_weights{ + nvinfer1::DataType::kFLOAT, static_cast(weight_data), + weight_tensor->memory_size() / sizeof(float)}; + TensorRTEngine::Weight shift_weights{nvinfer1::DataType::kFLOAT, nullptr, + 0}; + TensorRTEngine::Weight power_weights{nvinfer1::DataType::kFLOAT, nullptr, + 0}; + + auto* layer = TRT_ENGINE_ADD_LAYER( + engine_, Scale, *const_cast(input1), + nvinfer1::ScaleMode::kUNIFORM, shift_weights.get(), scale_weights.get(), + power_weights.get()); + + engine_->weight_map[op_desc.Output("Out").front() + "_dropout"] = + std::move(weight_tensor); + auto output_name = op_desc.Output("Out")[0]; + layer->setName(("dropout (Output: " + output_name + ")").c_str()); + engine_->SetITensor(output_name, layer->getOutput(0)); + if (test_mode) { + engine_->DeclareOutput(output_name); + } + } +}; + +} // namespace tensorrt +} // namespace inference +} // namespace paddle + +USE_OP(dropout); +REGISTER_TRT_OP_CONVERTER(dropout, DropoutOpConverter); diff --git a/paddle/fluid/inference/tensorrt/convert/test_activation_op.cc b/paddle/fluid/inference/tensorrt/convert/test_activation_op.cc index e82762ea03..dd3dfb0bc7 100644 --- a/paddle/fluid/inference/tensorrt/convert/test_activation_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/test_activation_op.cc @@ -20,18 +20,18 @@ namespace paddle { namespace inference { namespace tensorrt { -TEST(ReluOpConverter, main) { +void test_activation(std::string act_type) { framework::Scope scope; std::unordered_set parameters; TRTConvertValidation validator(10, parameters, scope, 1000); - validator.DeclInputVar("relu-X", nvinfer1::Dims2(10, 6)); - validator.DeclOutputVar("relu-Out", nvinfer1::Dims2(10, 6)); + validator.DeclInputVar("act-X", nvinfer1::Dims2(10, 6)); + validator.DeclOutputVar("act-Out", nvinfer1::Dims2(10, 6)); // Prepare Op description framework::OpDesc desc; - desc.SetType("relu"); - desc.SetInput("X", {"relu-X"}); - desc.SetOutput("Out", {"relu-Out"}); + desc.SetType(act_type); + desc.SetInput("X", {"act-X"}); + desc.SetOutput("Out", {"act-Out"}); LOG(INFO) << "set OP"; validator.SetOp(*desc.Proto()); @@ -40,8 +40,16 @@ TEST(ReluOpConverter, main) { validator.Execute(5); } +TEST(ReluOpConverter, main) { test_activation("relu"); } + +TEST(SigmoidOpConverter, main) { test_activation("sigmoid"); } + +TEST(TanhOpConverter, main) { test_activation("tanh"); } + } // namespace tensorrt } // namespace inference } // namespace paddle USE_OP(relu); +USE_OP(sigmoid); +USE_OP(tanh); diff --git a/paddle/fluid/inference/tensorrt/convert/test_dropout_op.cc b/paddle/fluid/inference/tensorrt/convert/test_dropout_op.cc new file mode 100644 index 0000000000..6b8e621b70 --- /dev/null +++ b/paddle/fluid/inference/tensorrt/convert/test_dropout_op.cc @@ -0,0 +1,58 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ +#include +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/inference/tensorrt/convert/ut_helper.h" + +namespace paddle { +namespace inference { +namespace tensorrt { + +TEST(DropoutOpConverter, main) { + framework::Scope scope; + std::unordered_set parameters; + TRTConvertValidation validator(8, parameters, scope, 1000); + + std::vector tensor_shape{8, 10}; + validator.DeclInputVar("dropout-X", tensor_shape, + nvinfer1::DimsCHW(10, 1, 1)); + validator.DeclOutputVar("dropout-Out", nvinfer1::DimsCHW(10, 1, 1)); + validator.DeclOutputVar("mask-Out", nvinfer1::DimsCHW(10, 1, 1)); + + // Prepare Op description + framework::OpDesc desc; + int is_test = 1; + float dropout_prob = 0.4; + + desc.SetType("dropout"); + desc.SetInput("X", {"dropout-X"}); + desc.SetOutput("Mask", {"mask-Out"}); + desc.SetOutput("Out", {"dropout-Out"}); + desc.SetAttr("is_test", is_test); + desc.SetAttr("dropout_prob", dropout_prob); + + LOG(INFO) << "set OP"; + validator.SetOp(*desc.Proto()); + LOG(INFO) << "execute"; + + std::unordered_set neglected_output = {"mask-Out"}; + + validator.Execute(8, neglected_output); +} + +} // namespace tensorrt +} // namespace inference +} // namespace paddle + +USE_OP(dropout); From 480c7c4ee300803ea9097266a654e72a47267143 Mon Sep 17 00:00:00 2001 From: yuyang Date: Mon, 17 Sep 2018 05:42:30 +0000 Subject: [PATCH 09/11] Fix sentiment dataset --- python/paddle/dataset/sentiment.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/dataset/sentiment.py b/python/paddle/dataset/sentiment.py index 22d867beea..8051acb881 100644 --- a/python/paddle/dataset/sentiment.py +++ b/python/paddle/dataset/sentiment.py @@ -67,7 +67,7 @@ def get_word_dict(): for field in movie_reviews.fileids(category): for words in movie_reviews.words(field): word_freq_dict[words] += 1 - words_sort_list = six.iteritems(word_freq_dict) + words_sort_list = list(six.iteritems(word_freq_dict)) words_sort_list.sort(cmp=lambda a, b: b[1] - a[1]) for index, word in enumerate(words_sort_list): words_freq_sorted.append((word[0], index)) From 114eb17587dfffffa2c2443bb79d2ad140801ffb Mon Sep 17 00:00:00 2001 From: sneaxiy Date: Mon, 17 Sep 2018 12:03:18 +0000 Subject: [PATCH 10/11] fix executor bug --- paddle/fluid/framework/executor.cc | 13 +++++++------ paddle/fluid/framework/executor.h | 3 +++ 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/paddle/fluid/framework/executor.cc b/paddle/fluid/framework/executor.cc index 650d9086d4..8d8042a056 100644 --- a/paddle/fluid/framework/executor.cc +++ b/paddle/fluid/framework/executor.cc @@ -337,6 +337,7 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope, std::unique_ptr> gc; if (max_memory_size >= 0) { + ctx->ResetReferenceCount(); #ifdef PADDLE_WITH_CUDA if (platform::is_gpu_place(place_)) { gc.reset(new DefaultStreamGarbageCollector( @@ -357,11 +358,11 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope, std::vector erase_vars; for (auto& input : op->Inputs()) { for (auto& input_name : input.second) { - auto it = ctx->ref_cnts_.find(input_name); - if (it == ctx->ref_cnts_.end()) continue; + auto it = ctx->cur_ref_cnts_.find(input_name); + if (it == ctx->cur_ref_cnts_.end()) continue; if (it->second == 1) { // should delete it erase_vars.emplace_back(input_name); - ctx->ref_cnts_.erase(input_name); + ctx->cur_ref_cnts_.erase(input_name); } else { --(it->second); } @@ -370,11 +371,11 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope, for (auto& output : op->Outputs()) { for (auto& output_name : output.second) { - auto it = ctx->ref_cnts_.find(output_name); - if (it == ctx->ref_cnts_.end()) continue; + auto it = ctx->cur_ref_cnts_.find(output_name); + if (it == ctx->cur_ref_cnts_.end()) continue; if (it->second == 1) { erase_vars.emplace_back(output_name); - ctx->ref_cnts_.erase(output_name); + ctx->cur_ref_cnts_.erase(output_name); } else { --(it->second); } diff --git a/paddle/fluid/framework/executor.h b/paddle/fluid/framework/executor.h index b746268760..f0cc1338a8 100644 --- a/paddle/fluid/framework/executor.h +++ b/paddle/fluid/framework/executor.h @@ -72,11 +72,14 @@ struct ExecutorPrepareContext { ExecutorPrepareContext(const framework::ProgramDesc& prog, size_t block_id); ~ExecutorPrepareContext(); + void ResetReferenceCount() { cur_ref_cnts_ = ref_cnts_; } + const framework::ProgramDesc& prog_; size_t block_id_; std::vector> ops_; std::unordered_map ref_cnts_; + std::unordered_map cur_ref_cnts_; }; class Executor { From 0c8c0d943f56fca95a037070c165d2a64686e3c3 Mon Sep 17 00:00:00 2001 From: gongweibao Date: Tue, 18 Sep 2018 09:52:05 +0800 Subject: [PATCH 11/11] fix macunittest (#13434) --- .../fluid/operators/math/cpu_lstm_compute.cc | 72 +------------------ .../fluid/operators/math/cpu_lstm_compute.h | 61 ++++++++++++++-- .../fluid/tests/unittests/test_desc_clone.py | 9 ++- .../fluid/transpiler/details/program_utils.py | 11 ++- 4 files changed, 73 insertions(+), 80 deletions(-) diff --git a/paddle/fluid/operators/math/cpu_lstm_compute.cc b/paddle/fluid/operators/math/cpu_lstm_compute.cc index f7c55c215b..58e6512021 100644 --- a/paddle/fluid/operators/math/cpu_lstm_compute.cc +++ b/paddle/fluid/operators/math/cpu_lstm_compute.cc @@ -1,11 +1,8 @@ /* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. - Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -13,76 +10,9 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/math/cpu_lstm_compute.h" -#include "paddle/fluid/operators/math/cpu_vec.h" -#include "paddle/fluid/platform/cpu_info.h" -#ifdef __AVX__ -#include -#endif namespace paddle { namespace operators { -namespace math { - -// TODO(TJ): ugly workaround, clean me -template -void lstm_compute_ctht(T* gates, const T* ct_1, T* ct, T* ht) { - // gates: W_ch, W_ih, W_fh, W_oh - vec_sigmoid(24, gates + 8, gates + 8); - vec_tanh(8, gates, gates); - const T *i = gates + 8, *f = gates + 16, *o = gates + 24; - const T min = SIGMOID_THRESHOLD_MIN; - const T max = SIGMOID_THRESHOLD_MAX; - for (int d = 0; d < 8; ++d) { - // C_t = C_t-1 * fgated + cand_gated * igated - ct[d] = ct_1[d] * f[d] + gates[d] * i[d]; - // H_t = act_cell(C_t) * ogated - T tmp = ct[d] * 2; - tmp = static_cast(0) - ((tmp < min) ? min : ((tmp > max) ? max : tmp)); - vec_exp(1, &tmp, &tmp); - tmp = static_cast(2) / (static_cast(1) + tmp) - static_cast(1); - ht[d] = tmp * o[d]; - } -} - -#ifdef __AVX__ -namespace detail { -namespace forward { -namespace avx { -__m256 Sigmoid(const __m256 a); -__m256 Tanh(const __m256 a); -} // namespace avx -} // namespace forward -} // namespace detail - -template <> -void lstm_compute_ctht(float* gates, const float* ct_1, float* ct, - float* ht) { - namespace act = detail::forward::avx; - // gates: W_ch, W_ih, W_fh, W_oh - __m256 c, i, f, o; - c = _mm256_loadu_ps(gates); - i = _mm256_loadu_ps(gates + 8); - f = _mm256_loadu_ps(gates + 16); - o = _mm256_loadu_ps(gates + 24); - - /* C_t = C_t-1 * fgated + cand_gated * igated*/ - c = _mm256_mul_ps(act::Tanh(c), act::Sigmoid(i)); - i = _mm256_loadu_ps(ct_1); - f = _mm256_mul_ps(i, act::Sigmoid(f)); - f = _mm256_add_ps(c, f); - _mm256_storeu_ps(ct, f); - - /* H_t = act_cell(C_t) * ogated */ - o = _mm256_mul_ps(act::Tanh(f), act::Sigmoid(o)); - _mm256_storeu_ps(ht, o); -} -#endif - -template void lstm_compute_ctht(float* gates, const float* ct_1, - float* ct, float* ht); -template void lstm_compute_ctht(double* gates, const double* ct_1, - double* ct, double* ht); - -} // namespace math +namespace math {} // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/math/cpu_lstm_compute.h b/paddle/fluid/operators/math/cpu_lstm_compute.h index 244164f08c..28b6f71729 100644 --- a/paddle/fluid/operators/math/cpu_lstm_compute.h +++ b/paddle/fluid/operators/math/cpu_lstm_compute.h @@ -1,11 +1,8 @@ /* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. - Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -14,6 +11,11 @@ limitations under the License. */ #pragma once #include +#include "paddle/fluid/operators/math/cpu_vec.h" +#include "paddle/fluid/platform/cpu_info.h" +#ifdef __AVX__ +#include +#endif namespace paddle { namespace operators { @@ -21,7 +23,58 @@ namespace math { // TODO(TJ): ugly workaround, clean me template -void lstm_compute_ctht(T* gates, const T* ct_1, T* ct, T* ht); +void lstm_compute_ctht(T* gates, const T* ct_1, T* ct, T* ht) { + // gates: W_ch, W_ih, W_fh, W_oh + vec_sigmoid(24, gates + 8, gates + 8); + vec_tanh(8, gates, gates); + const T *i = gates + 8, *f = gates + 16, *o = gates + 24; + const T min = SIGMOID_THRESHOLD_MIN; + const T max = SIGMOID_THRESHOLD_MAX; + for (int d = 0; d < 8; ++d) { + // C_t = C_t-1 * fgated + cand_gated * igated + ct[d] = ct_1[d] * f[d] + gates[d] * i[d]; + // H_t = act_cell(C_t) * ogated + T tmp = ct[d] * 2; + tmp = static_cast(0) - ((tmp < min) ? min : ((tmp > max) ? max : tmp)); + vec_exp(1, &tmp, &tmp); + tmp = static_cast(2) / (static_cast(1) + tmp) - static_cast(1); + ht[d] = tmp * o[d]; + } +} + +#ifdef __AVX__ +namespace detail { +namespace forward { +namespace avx { +__m256 Sigmoid(const __m256 a); +__m256 Tanh(const __m256 a); +} // namespace avx +} // namespace forward +} // namespace detail + +template <> +void lstm_compute_ctht(float* gates, const float* ct_1, float* ct, + float* ht) { + namespace act = detail::forward::avx; + // gates: W_ch, W_ih, W_fh, W_oh + __m256 c, i, f, o; + c = _mm256_loadu_ps(gates); + i = _mm256_loadu_ps(gates + 8); + f = _mm256_loadu_ps(gates + 16); + o = _mm256_loadu_ps(gates + 24); + + /* C_t = C_t-1 * fgated + cand_gated * igated*/ + c = _mm256_mul_ps(act::Tanh(c), act::Sigmoid(i)); + i = _mm256_loadu_ps(ct_1); + f = _mm256_mul_ps(i, act::Sigmoid(f)); + f = _mm256_add_ps(c, f); + _mm256_storeu_ps(ct, f); + + /* H_t = act_cell(C_t) * ogated */ + o = _mm256_mul_ps(act::Tanh(f), act::Sigmoid(o)); + _mm256_storeu_ps(ht, o); +} +#endif } // namespace math } // namespace operators diff --git a/python/paddle/fluid/tests/unittests/test_desc_clone.py b/python/paddle/fluid/tests/unittests/test_desc_clone.py index 08579c7dd6..82e704169e 100644 --- a/python/paddle/fluid/tests/unittests/test_desc_clone.py +++ b/python/paddle/fluid/tests/unittests/test_desc_clone.py @@ -109,15 +109,20 @@ def get_transpiler(trainer_id, main_program, pserver_endpoints, trainers): return t +from paddle.fluid.transpiler.details import op_to_code + + def operator_equal(a, b): + if op_to_code(a) != op_to_code(b): + raise ValueError("In operator_equal not equal\n") + for k, v in six.iteritems(a.__dict__): if isinstance(v, fluid.framework.Program) or \ isinstance(v, fluid.framework.Block): continue elif isinstance(v, core.OpDesc): - if v.serialize_to_string() != b.__dict__[k].serialize_to_string(): - raise ValueError("In operator_equal not equal:{0}\n".format(k)) + continue elif isinstance(v, collections.OrderedDict): v0 = sorted(list(six.iteritems(v)), key=lambda x: x[0]) diff --git a/python/paddle/fluid/transpiler/details/program_utils.py b/python/paddle/fluid/transpiler/details/program_utils.py index a83aa0f11e..200175cfe8 100644 --- a/python/paddle/fluid/transpiler/details/program_utils.py +++ b/python/paddle/fluid/transpiler/details/program_utils.py @@ -113,27 +113,32 @@ def op_to_code(op): inputs_str += ", " inputs_str += "}" + attr_names = sorted(op.attr_names) attrs_str = "" - for i in range(0, len(op.attr_names)): - name = op.attr_names[i] + for i in range(0, len(attr_names)): + name = attr_names[i] attr_type = op.desc.attr_type(name) if attr_type == core.AttrType.BLOCK: a = "{name} = block[{value}]".format( name=name, type=attr_type, value=op.block_attr_id(name)) attrs_str += a + if i != len(attr_names) - 1: + attrs_str += ", " continue if attr_type == core.AttrType.BLOCKS: a = "{name} = blocks{value}".format( name=name, type=attr_type, value=op.blocks_attr_ids(name)) attrs_str += a + if i != len(attr_names) - 1: + attrs_str += ", " continue a = "{name} = {value}".format( name=name, type=attr_type, value=op.desc.attr(name)) attrs_str += a - if i != len(op.attr_names) - 1: + if i != len(attr_names) - 1: attrs_str += ", " if outputs_str != "{}":