From 01c6618de904e1d49660486cd65f8810cc9665a3 Mon Sep 17 00:00:00 2001
From: typhoonzero
Date: Sun, 8 Apr 2018 09:38:26 +0800
Subject: [PATCH 1/6] first wip commit

---
 .../fluid/framework/details/send_op_handle.cc | 78 +++++++++++++++++++
 .../fluid/framework/details/send_op_handle.h  | 50 ++++++++++++
 paddle/fluid/operators/detail/grpc_client.cc  |  3 +-
 3 files changed, 129 insertions(+), 2 deletions(-)
 create mode 100644 paddle/fluid/framework/details/send_op_handle.cc
 create mode 100644 paddle/fluid/framework/details/send_op_handle.h

diff --git a/paddle/fluid/framework/details/send_op_handle.cc b/paddle/fluid/framework/details/send_op_handle.cc
new file mode 100644
index 0000000000..bd2a0a9c29
--- /dev/null
+++ b/paddle/fluid/framework/details/send_op_handle.cc
@@ -0,0 +1,78 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/framework/details/send_op_handle.h"
+
+namespace paddle {
+namespace framework {
+namespace details {
+
+SendOpHandle::SendOpHandle(const std::vector<Scope *> &local_scopes,
+                           const std::vector<platform::Place> &places,
+                           const platform::NCCLContextMap &ctxs)
+    : local_scopes_(local_scopes), places_(places) {}
+
+void SendOpHandle::RunImpl() {
+  if (inputs_.size() == 1) {
+    return;  // No need to all-reduce when the GPU count is 1.
+  } else {
+    // Wait input done
+    for (auto *in : inputs_) {
+      auto &p = static_cast<VarHandle *>(in)->place_;
+      in->generated_op_->Wait(dev_ctxes_[p]);
+    }
+
+    auto &var_name = static_cast<VarHandle *>(this->inputs_[0])->name_;
+    int dtype = -1;
+    size_t numel = 0;
+
+    std::vector<std::function<void()>> all_reduce_calls;
+
+    for (size_t i = 0; i < local_scopes_.size(); ++i) {
+      auto &p = places_[i];
+      auto *s = local_scopes_[i];
+      int dev_id = boost::get<platform::CUDAPlace>(p).device;
+
+      auto &lod_tensor = s->FindVar(var_name)->Get<LoDTensor>();
+      void *buffer = const_cast<void *>(lod_tensor.data<void>());
+
+      if (dtype == -1) {
+        dtype = platform::ToNCCLDataType(lod_tensor.type());
+      }
+
+      if (numel == 0) {
+        numel = static_cast<size_t>(lod_tensor.numel());
+      }
+
+      auto &nccl_ctx = nccl_ctxs_.at(dev_id);
+      auto stream = nccl_ctx.stream();
+      auto comm = nccl_ctx.comm_;
+      all_reduce_calls.emplace_back([=] {
+        PADDLE_ENFORCE(platform::dynload::ncclAllReduce(
+            buffer, buffer, numel, static_cast<ncclDataType_t>(dtype), ncclSum,
+            comm, stream));
+      });
+    }
+
+    platform::NCCLGroupGuard guard;
+    for (auto &call : all_reduce_calls) {
+      call();
+    }
+  }
+}
+
+std::string NCCLAllReduceOpHandle::Name() const { return "nccl_all_reduce"; }
+}  // namespace details
+}  // namespace framework
+}  // namespace paddle
diff --git a/paddle/fluid/framework/details/send_op_handle.h b/paddle/fluid/framework/details/send_op_handle.h
new file mode 100644
index 0000000000..515f1a10a8
--- /dev/null
+++ b/paddle/fluid/framework/details/send_op_handle.h
@@ -0,0 +1,50 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <string>
+#include <vector>
+
+#include "paddle/fluid/framework/details/op_handle_base.h"
+#include "paddle/fluid/framework/lod_tensor.h"
+#include "paddle/fluid/framework/scope.h"
+#include "paddle/fluid/platform/nccl_helper.h"
+
+namespace paddle {
+namespace framework {
+namespace details {
+
+struct SendOpHandle : public OpHandleBase {
+  const std::vector<Scope *> &local_scopes_;
+  const std::vector<platform::Place> &places_;
+  const platform::NCCLContextMap &nccl_ctxs_;
+
+  SendOpHandle(const std::vector<Scope *> &local_scopes,
+               const std::vector<platform::Place> &places,
+               const platform::NCCLContextMap &ctxs);
+
+  std::string Name() const override;
+
+  // Delaying and buffering nccl_all_reduce calls together can significantly
+  // increase performance. Disable this feature by returning false.
+  bool IsMultiDeviceTransfer() override { return true; };
+
+ protected:
+  void RunImpl() override;
+};
+
+}  // namespace details
+}  // namespace framework
+}  // namespace paddle
diff --git a/paddle/fluid/operators/detail/grpc_client.cc b/paddle/fluid/operators/detail/grpc_client.cc
index ef987d07f0..3cf286575e 100644
--- a/paddle/fluid/operators/detail/grpc_client.cc
+++ b/paddle/fluid/operators/detail/grpc_client.cc
@@ -65,9 +65,8 @@ bool RPCClient::AsyncSendVariable(const std::string& ep,
 }
 
 void ProcGetResponse(const VarHandle& var_h,
-                     // const sendrecv::VariableMessage& ret_msg) {
                      const ::grpc::ByteBuffer& ret_msg) {
-  framework::Variable* outvar = NULL;
+  framework::Variable* outvar = nullptr;
   DeserializeFromByteBuffer(ret_msg, *var_h.ctx, var_h.scope, &outvar);
 }
 

From baea2cf17892f2cba47c8bde29bccd7488c2ee52 Mon Sep 17 00:00:00 2001
From: typhoonzero
Date: Sun, 8 Apr 2018 18:35:49 +0800
Subject: [PATCH 2/6] wip

---
 paddle/fluid/framework/details/CMakeLists.txt |  1 +
 .../details/multi_devices_graph_builder.cc    | 59 +++++++++++++----
 .../details/multi_devices_graph_builder.h     | 14 ++++-
 .../fluid/framework/details/send_op_handle.cc | 63 ++++---------------
 .../fluid/framework/details/send_op_handle.h  | 15 ++---
 python/paddle/fluid/framework.py              |  7 +++
 6 files changed, 87 insertions(+), 72 deletions(-)

diff --git a/paddle/fluid/framework/details/CMakeLists.txt b/paddle/fluid/framework/details/CMakeLists.txt
index 89b5c6847f..caaf418076 100644
--- a/paddle/fluid/framework/details/CMakeLists.txt
+++ b/paddle/fluid/framework/details/CMakeLists.txt
@@ -5,6 +5,7 @@ cc_library(fetch_op_handle SRCS fetch_op_handle.cc DEPS op_handle_base scope lod
 nv_library(nccl_all_reduce_op_handle SRCS nccl_all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory dynload_cuda)
 cc_library(computation_op_handle SRCS computation_op_handle.cc DEPS framework_proto scope place operator op_registry)
+cc_library(send_op_handle SRCS send_op_handle.cc DEPS framework_proto scope place operator op_registry)
 
 cc_library(ssa_graph SRCS ssa_graph.cc DEPS var_handle op_handle_base)
 cc_library(ssa_graph_builder SRCS ssa_graph_builder.cc DEPS ssa_graph)
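Note: the next diff factors the input/output wiring out of Build() into a new
CreateOpHandleIOs() helper. The idea behind CreateOrGetLatestVarHandle and
CreateOpOutput is SSA versioning: every write to a variable produces a new
numbered version, and an op handle depends on the latest version of each of its
inputs. A minimal, self-contained C++ model of that bookkeeping (illustrative
only; nothing below is part of these commits):

    #include <iostream>
    #include <map>
    #include <string>

    int main() {
      std::map<std::string, int> version;  // variable name -> latest version

      // Like CreateOrGetLatestVarHandle: read the latest version (0 if unseen).
      auto read_latest = [&](const std::string &name) { return version[name]; };
      // Like CreateOpOutput: a write creates a fresh version.
      auto write_new = [&](const std::string &name) { return ++version[name]; };

      read_latest("w");       // op A reads w@0
      write_new("w@GRAD");    // op A writes w@GRAD@1
      read_latest("w@GRAD");  // op B (all-reduce or send) now depends on op A
      std::cout << version["w@GRAD"] << "\n";  // prints 1
    }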
diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.cc b/paddle/fluid/framework/details/multi_devices_graph_builder.cc
index 128a5344fb..bea9489bbd 100644
--- a/paddle/fluid/framework/details/multi_devices_graph_builder.cc
+++ b/paddle/fluid/framework/details/multi_devices_graph_builder.cc
@@ -15,6 +15,7 @@
 #include "paddle/fluid/framework/details/multi_devices_graph_builder.h"
 #include "paddle/fluid/framework/details/computation_op_handle.h"
 #include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h"
+#include "paddle/fluid/framework/details/send_op_handle.h"
 #include "paddle/fluid/framework/scope.h"
 
 #ifdef PADDLE_WITH_CUDA
@@ -34,26 +35,46 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
     const std::string &loss_var_name,
     const std::unordered_set<std::string> &params,
     const std::vector<Scope *> &local_scopes,
-    platform::NCCLContextMap *nccl_ctxs)
+    platform::NCCLContextMap *nccl_ctxs, bool distributed)
    : loss_var_name_(loss_var_name),
      places_(places),
      local_scopes_(local_scopes),
+      distributed_(distributed),
      nccl_ctxs_(nccl_ctxs) {
 #else
 MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
     const std::vector<platform::Place> &places,
     const std::string &loss_var_name,
     const std::unordered_set<std::string> &params,
-    const std::vector<Scope *> &local_scopes)
+    const std::vector<Scope *> &local_scopes, bool distributed)
    : loss_var_name_(loss_var_name),
      places_(places),
-      local_scopes_(local_scopes) {
+      local_scopes_(local_scopes),
+      distributed_(distributed) {
 #endif
   for (auto &p : params) {
     grad_names_.insert(GradVarName(p));
   }
 }
 
+void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result, OpDesc *op,
+                                                const platform::Place &p,
+                                                const size_t &i) const {
+  auto *op_handle = result->ops_.back().get();
+
+  auto var_names = op->InputArgumentNames();
+
+  for (auto &each_var_name : var_names) {
+    VarHandle *var = CreateOrGetLatestVarHandle(result, each_var_name, p, i);
+    op_handle->AddInput(var);
+  }
+  var_names = op->OutputArgumentNames();
+
+  for (auto &each_var_name : var_names) {
+    CreateOpOutput(result, op_handle, each_var_name, p, i);
+  }
+}
+
 std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
     const ProgramDesc &program) const {
   auto graph = new SSAGraph();
@@ -72,6 +93,17 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
       }
     }
 
+    // Append a send op if the program is a distributed trainer main program;
+    // always use the first device.
+    if (is_forwarding && distributed_ && op->Type() == "send") {
+      auto &p = places_[0];
+      auto *s = local_scopes_[0];
+      size_t i = 0;
+      result.ops_.emplace_back(new SendOpHandle(*op, s, p));
+      CreateOpHandleIOs(&result, op, p, i);
+      continue;
+    }
+
     for (size_t i = 0; i < places_.size(); ++i) {
       auto &p = places_[i];
       auto *s = local_scopes_[i];
@@ -81,18 +113,19 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
 
       op_handle->dev_ctxes_[p] = const_cast<platform::DeviceContext *>(
           platform::DeviceContextPool::Instance().Get(p));
 
-      auto var_names = op->InputArgumentNames();
+      CreateOpHandleIOs(&result, op, p, i);
+      // auto var_names = op->InputArgumentNames();
 
-      for (auto &each_var_name : var_names) {
-        VarHandle *var =
-            CreateOrGetLatestVarHandle(&result, each_var_name, p, i);
-        op_handle->AddInput(var);
-      }
-      var_names = op->OutputArgumentNames();
+      // for (auto &each_var_name : var_names) {
+      //   VarHandle *var =
+      //       CreateOrGetLatestVarHandle(&result, each_var_name, p, i);
+      //   op_handle->AddInput(var);
+      // }
+      auto var_names = op->OutputArgumentNames();
 
-      for (auto &each_var_name : var_names) {
-        CreateOpOutput(&result, op_handle, each_var_name, p, i);
-      }
+      // for (auto &each_var_name : var_names) {
+      //   CreateOpOutput(&result, op_handle, each_var_name, p, i);
+      // }
 
       if (is_forwarding) {
         if (var_names.size() == 1 && var_names[0] == loss_var_name_) {
           // Insert ScaleCost OpHandle
diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.h b/paddle/fluid/framework/details/multi_devices_graph_builder.h
index d3c8e582cf..004d6d50ab 100644
--- a/paddle/fluid/framework/details/multi_devices_graph_builder.h
+++ b/paddle/fluid/framework/details/multi_devices_graph_builder.h
@@ -14,6 +14,9 @@
 
 #pragma once
 
+#include <string>
+#include <vector>
+
 #include "paddle/fluid/framework/details/ssa_graph_builder.h"
 
 namespace paddle {
@@ -31,21 +34,28 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
                           const std::string &loss_var_name,
                           const std::unordered_set<std::string> &params,
                           const std::vector<Scope *> &local_scopes,
-                          platform::NCCLContextMap *nccl_ctxs);
+                          platform::NCCLContextMap *nccl_ctxs,
+                          bool distributed = false);
 #else
   MultiDevSSAGraphBuilder(const std::vector<platform::Place> &places,
                           const std::string &loss_var_name,
                           const std::unordered_set<std::string> &params,
-                          const std::vector<Scope *> &local_scopes);
+                          const std::vector<Scope *> &local_scopes,
+                          bool distributed = false);
 #endif
 
   std::unique_ptr<SSAGraph> Build(const ProgramDesc &program) const override;
 
+ private:
+  void CreateOpHandleIOs(SSAGraph *result, OpDesc *op, const platform::Place &p,
+                         const size_t &i) const;
+
  private:
   std::string loss_var_name_;
   const std::vector<platform::Place> &places_;
   const std::vector<Scope *> &local_scopes_;
   std::unordered_set<std::string> grad_names_;
+  bool distributed_;
 
 #ifdef PADDLE_WITH_CUDA
   platform::NCCLContextMap *nccl_ctxs_;
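Note: the distributed_ flag guards the send branch added to Build() above. It
assumes the trainer's main program came out of the distribute transpiler, which
strips the optimizer ops and appends a send op. A toy sketch of such a program
built against the C++ ProgramDesc API (the "X"/"Out" parameter names and the
"endpoints" attribute are assumptions for illustration, not taken from this
series):

    #include <string>
    #include <vector>

    #include "paddle/fluid/framework/program_desc.h"

    // A miniature "trainer main program" whose last op is a send.
    paddle::framework::ProgramDesc MakeToyTrainerProgram() {
      paddle::framework::ProgramDesc program;
      auto *block = program.MutableBlock(0);

      auto *send = block->AppendOp();
      send->SetType("send");
      send->SetInput("X", {"w@GRAD"});  // gradients to push
      send->SetOutput("Out", {"w"});    // parameters pulled back
      send->SetAttr("endpoints", std::vector<std::string>{"127.0.0.1:6174"});
      return program;
    }

The builder recognizes exactly this trailing op by name and pins it to
places_[0] instead of replicating it per device.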
diff --git a/paddle/fluid/framework/details/send_op_handle.cc b/paddle/fluid/framework/details/send_op_handle.cc
index bd2a0a9c29..ae5637b804 100644
--- a/paddle/fluid/framework/details/send_op_handle.cc
+++ b/paddle/fluid/framework/details/send_op_handle.cc
@@ -18,61 +18,24 @@ namespace paddle {
 namespace framework {
 namespace details {
 
-SendOpHandle::SendOpHandle(const std::vector<Scope *> &local_scopes,
-                           const std::vector<platform::Place> &places,
-                           const platform::NCCLContextMap &ctxs)
-    : local_scopes_(local_scopes), places_(places) {}
+SendOpHandle::SendOpHandle(const framework::OpDesc &op_desc,
+                           const Scope *local_scope,
+                           const platform::Place &place)
+    : op_(framework::OpRegistry::CreateOp(op_desc)),
+      local_scope_(local_scope),
+      place_(place) {}
 
 void SendOpHandle::RunImpl() {
-  if (inputs_.size() == 1) {
-    return;  // No need to all-reduce when the GPU count is 1.
-  } else {
-    // Wait input done
-    for (auto *in : inputs_) {
-      auto &p = static_cast<VarHandle *>(in)->place_;
-      in->generated_op_->Wait(dev_ctxes_[p]);
-    }
-
-    auto &var_name = static_cast<VarHandle *>(this->inputs_[0])->name_;
-    int dtype = -1;
-    size_t numel = 0;
-
-    std::vector<std::function<void()>> all_reduce_calls;
-
-    for (size_t i = 0; i < local_scopes_.size(); ++i) {
-      auto &p = places_[i];
-      auto *s = local_scopes_[i];
-      int dev_id = boost::get<platform::CUDAPlace>(p).device;
-
-      auto &lod_tensor = s->FindVar(var_name)->Get<LoDTensor>();
-      void *buffer = const_cast<void *>(lod_tensor.data<void>());
-
-      if (dtype == -1) {
-        dtype = platform::ToNCCLDataType(lod_tensor.type());
-      }
-
-      if (numel == 0) {
-        numel = static_cast<size_t>(lod_tensor.numel());
-      }
-
-      auto &nccl_ctx = nccl_ctxs_.at(dev_id);
-      auto stream = nccl_ctx.stream();
-      auto comm = nccl_ctx.comm_;
-      all_reduce_calls.emplace_back([=] {
-        PADDLE_ENFORCE(platform::dynload::ncclAllReduce(
-            buffer, buffer, numel, static_cast<ncclDataType_t>(dtype), ncclSum,
-            comm, stream));
-      });
-    }
-
-    platform::NCCLGroupGuard guard;
-    for (auto &call : all_reduce_calls) {
-      call();
-    }
+  // Wait input done
+  for (auto *in : inputs_) {
+    auto &p = static_cast<VarHandle *>(in)->place_;
+    in->generated_op_->Wait(dev_ctxes_[p]);
   }
+
+  op_->Run(*local_scope_, place_);
 }
 
-std::string NCCLAllReduceOpHandle::Name() const { return "nccl_all_reduce"; }
+std::string SendOpHandle::Name() const { return "send"; }
 }  // namespace details
 }  // namespace framework
 }  // namespace paddle
diff --git a/paddle/fluid/framework/details/send_op_handle.h b/paddle/fluid/framework/details/send_op_handle.h
index 515f1a10a8..e7857c1f23 100644
--- a/paddle/fluid/framework/details/send_op_handle.h
+++ b/paddle/fluid/framework/details/send_op_handle.h
@@ -19,6 +19,8 @@
 
 #include "paddle/fluid/framework/details/op_handle_base.h"
 #include "paddle/fluid/framework/lod_tensor.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/framework/operator.h"
 #include "paddle/fluid/framework/scope.h"
 #include "paddle/fluid/platform/nccl_helper.h"
 
@@ -27,19 +29,18 @@ namespace framework {
 namespace details {
 
 struct SendOpHandle : public OpHandleBase {
-  const std::vector<Scope *> &local_scopes_;
-  const std::vector<platform::Place> &places_;
-  const platform::NCCLContextMap &nccl_ctxs_;
+  std::unique_ptr<OperatorBase> op_;
+  const Scope* local_scope_;
+  const platform::Place& place_;
 
-  SendOpHandle(const std::vector<Scope *> &local_scopes,
-               const std::vector<platform::Place> &places,
-               const platform::NCCLContextMap &ctxs);
+  SendOpHandle(const framework::OpDesc& op_desc, const Scope* local_scope,
+               const platform::Place& place);
 
   std::string Name() const override;
 
   // Delaying and buffering nccl_all_reduce calls together can significantly
   // increase performance. Disable this feature by returning false.
-  bool IsMultiDeviceTransfer() override { return true; };
+  bool IsMultiDeviceTransfer() override { return false; };
 
  protected:
   void RunImpl() override;
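Note: flipping IsMultiDeviceTransfer() to false is a scheduling decision, not a
cosmetic one. The threaded executor may hold back ops flagged as multi-device
transfers so it can issue them in one batch, while a send must run as soon as
its inputs are ready. A self-contained model of that gating (illustrative; the
real logic lives in threaded_ssa_graph_executor.cc and may differ):

    #include <functional>
    #include <iostream>
    #include <vector>

    struct ToyOp {
      const char *name;
      bool is_multi_device_transfer;
      std::function<void()> run;
    };

    int main() {
      const bool allow_op_delay = true;
      std::vector<ToyOp> ready = {
          {"nccl_all_reduce", true, [] { std::cout << "all_reduce\n"; }},
          {"send", false, [] { std::cout << "send\n"; }}};

      std::vector<ToyOp *> delayed;
      for (auto &op : ready) {
        if (allow_op_delay && op.is_multi_device_transfer) {
          delayed.push_back(&op);  // batch with other transfers
        } else {
          op.run();  // a send runs immediately, in issue order
        }
      }
      for (auto *op : delayed) op->run();  // flush the batched transfers
    }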
diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py
index 39d4017861..8bd9161fcb 100644
--- a/python/paddle/fluid/framework.py
+++ b/python/paddle/fluid/framework.py
@@ -951,6 +951,13 @@ class Block(object):
         if var.type == core.VarDesc.VarType.STEP_SCOPES:
             ret_var = self.create_var(
                 name=var.name, persistable=var.persistable, type=var.type)
+        elif var.type == core.VarDesc.VarType.SELECTED_ROWS:
+            ret_var = self.create_var(
+                name=var.name,
+                shape=var.shape,
+                dtype=var.dtype,
+                type=var.type,
+                persistable=True)
         else:
             ret_var = self.create_var(
                 name=var.name,

From 0bf799a52388dd77743623dcb2d1ebacb352858b Mon Sep 17 00:00:00 2001
From: typhoonzero
Date: Tue, 10 Apr 2018 17:00:06 +0800
Subject: [PATCH 3/6] wip testing

---
 paddle/fluid/framework/details/CMakeLists.txt         |  2 +-
 .../framework/details/multi_devices_graph_builder.cc | 10 ++++------
 .../framework/details/multi_devices_graph_builder.h  |  7 ++-----
 paddle/fluid/framework/parallel_executor.h            |  4 ++--
 paddle/fluid/operators/detail/serde_test.cc           |  2 +-
 paddle/fluid/pybind/pybind.cc                         |  1 +
 python/paddle/fluid/parallel_executor.py              |  7 +++++--
 7 files changed, 16 insertions(+), 17 deletions(-)

diff --git a/paddle/fluid/framework/details/CMakeLists.txt b/paddle/fluid/framework/details/CMakeLists.txt
index caaf418076..85b649b293 100644
--- a/paddle/fluid/framework/details/CMakeLists.txt
+++ b/paddle/fluid/framework/details/CMakeLists.txt
@@ -16,7 +16,7 @@ else()
   set(multi_devices_graph_builder_deps)
 endif()
 cc_library(multi_devices_graph_builder SRCS multi_devices_graph_builder.cc DEPS ssa_graph_builder computation_op_handle
-        scale_loss_grad_op_handle ${multi_devices_graph_builder_deps})
+        scale_loss_grad_op_handle send_op_handle ${multi_devices_graph_builder_deps})
 
 cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ssa_graph framework_proto)
 cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope simple_threadpool device_context)
diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.cc b/paddle/fluid/framework/details/multi_devices_graph_builder.cc
index 8a28b18715..8a53270110 100644
--- a/paddle/fluid/framework/details/multi_devices_graph_builder.cc
+++ b/paddle/fluid/framework/details/multi_devices_graph_builder.cc
@@ -35,22 +35,20 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
     const std::string &loss_var_name,
     const std::unordered_set<std::string> &params,
     const std::vector<Scope *> &local_scopes,
-    platform::NCCLContextMap *nccl_ctxs, bool distributed)
+    platform::NCCLContextMap *nccl_ctxs)
    : loss_var_name_(loss_var_name),
      places_(places),
      local_scopes_(local_scopes),
-      distributed_(distributed),
      nccl_ctxs_(nccl_ctxs) {
 #else
 MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
     const std::vector<platform::Place> &places,
     const std::string &loss_var_name,
     const std::unordered_set<std::string> &params,
-    const std::vector<Scope *> &local_scopes, bool distributed)
+    const std::vector<Scope *> &local_scopes)
    : loss_var_name_(loss_var_name),
      places_(places),
-      local_scopes_(local_scopes),
-      distributed_(distributed) {
+      local_scopes_(local_scopes) {
 #endif
   for (auto &p : params) {
     grad_names_.insert(GradVarName(p));
@@ -99,7 +97,7 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
 
     // Append a send op if the program is a distributed trainer main program;
     // always use the first device.
-    if (is_forwarding && distributed_ && op->Type() == "send") {
+    if (!is_forwarding && op->Type() == "send") {
       auto &p = places_[0];
       auto *s = local_scopes_[0];
       size_t i = 0;
diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.h b/paddle/fluid/framework/details/multi_devices_graph_builder.h
index 004d6d50ab..de34caab1b 100644
--- a/paddle/fluid/framework/details/multi_devices_graph_builder.h
+++ b/paddle/fluid/framework/details/multi_devices_graph_builder.h
@@ -34,14 +34,12 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
                           const std::string &loss_var_name,
                           const std::unordered_set<std::string> &params,
                           const std::vector<Scope *> &local_scopes,
-                          platform::NCCLContextMap *nccl_ctxs,
-                          bool distributed = false);
+                          platform::NCCLContextMap *nccl_ctxs);
 #else
   MultiDevSSAGraphBuilder(const std::vector<platform::Place> &places,
                           const std::string &loss_var_name,
                           const std::unordered_set<std::string> &params,
-                          const std::vector<Scope *> &local_scopes,
-                          bool distributed = false);
+                          const std::vector<Scope *> &local_scopes);
 #endif
 
   std::unique_ptr<SSAGraph> Build(const ProgramDesc &program) const override;
@@ -55,7 +53,6 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
   const std::vector<platform::Place> &places_;
   const std::vector<Scope *> &local_scopes_;
   std::unordered_set<std::string> grad_names_;
-  bool distributed_;
 
 #ifdef PADDLE_WITH_CUDA
   platform::NCCLContextMap *nccl_ctxs_;
diff --git a/paddle/fluid/framework/parallel_executor.h b/paddle/fluid/framework/parallel_executor.h
index c048c3865f..b4f16dba85 100644
--- a/paddle/fluid/framework/parallel_executor.h
+++ b/paddle/fluid/framework/parallel_executor.h
@@ -48,13 +48,13 @@ class ParallelExecutor {
            const std::string& fetched_var_name,
            const std::unordered_map<std::string, LoDTensor>& feed_tensors);
 
+  void BCastParamsToGPUs(const std::unordered_set<std::string>& vars) const;
+
  private:
   void SplitTensorToPlaces(
       const std::unordered_map<std::string, LoDTensor>& feed_tensors);
 
   ParallelExecutorPrivate* member_;
-
-  void BCastParamsToGPUs(const std::unordered_set<std::string>& vars) const;
 };
diff --git a/paddle/fluid/operators/detail/serde_test.cc b/paddle/fluid/operators/detail/serde_test.cc
index f8cae6b26a..cb5f895834 100644
--- a/paddle/fluid/operators/detail/serde_test.cc
+++ b/paddle/fluid/operators/detail/serde_test.cc
@@ -107,7 +107,7 @@ void RunSerdeTestSelectedRows(platform::Place place) {
   for (int i = 0; i < tensor_numel; ++i) {
     EXPECT_FLOAT_EQ(tensor_data2[i], 32.7);
   }
-  for (int64_t i = 0; i < rows2->size(); ++i) {
+  for (size_t i = 0; i < rows2->size(); ++i) {
     EXPECT_EQ(rows_data2[i], i);
   }
   EXPECT_EQ(slr2->height(), 1000);
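Note: the next two diffs expose the now-public BCastParamsToGPUs through pybind
as bcast_params, so a distributed trainer can re-broadcast parameters after a
send/recv round has refreshed them in the first scope. A C++-side usage sketch
against the signature shown in parallel_executor.h above (the parameter names
are made up for illustration):

    #include <string>
    #include <unordered_set>

    #include "paddle/fluid/framework/parallel_executor.h"

    // Push the named persistable variables from device 0 to all devices.
    void RebroadcastParams(const paddle::framework::ParallelExecutor &exe) {
      std::unordered_set<std::string> params = {"fc_0.w_0", "fc_0.b_0"};
      exe.BCastParamsToGPUs(params);
    }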
diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc
index 3924040455..a9a5d87d77 100644
--- a/paddle/fluid/pybind/pybind.cc
+++ b/paddle/fluid/pybind/pybind.cc
@@ -554,6 +554,7 @@ All parameter, weight, gradient are variables in Paddle.
                   bcast_vars, main_program, loss_var_name, scope, local_scopes,
                   allow_op_delay);
            })
+      .def("bcast_params", &ParallelExecutor::BCastParamsToGPUs)
       .def("local_scopes",
            [](ParallelExecutor &self) -> std::vector<Scope *> * {
              return &self.GetLocalScopes();
diff --git a/python/paddle/fluid/parallel_executor.py b/python/paddle/fluid/parallel_executor.py
index b93f2f974c..a23cc9b772 100644
--- a/python/paddle/fluid/parallel_executor.py
+++ b/python/paddle/fluid/parallel_executor.py
@@ -99,7 +99,7 @@ class ParallelExecutor(object):
         local_scopes = share_vars_from.executor.local_scopes(
         ) if share_vars_from else []
 
-        persistable_vars = [
+        self.persistable_vars = [
             v.name
             for v in filter(lambda var: var.persistable, main.list_vars())
         ]
@@ -112,7 +112,7 @@ class ParallelExecutor(object):
                 p.name for p in main.global_block().iter_parameters()
                 if not p.stop_gradient
             ]),
-            set(persistable_vars),
+            set(self.persistable_vars),
             main.desc,
             loss_name if loss_name else '',
             scope,
@@ -142,3 +142,6 @@ class ParallelExecutor(object):
             self.executor.run(fetch_list, fetch_var_name, feed_tensor_dict)
         arr = self.scope.find_var(fetch_var_name).get_lod_tensor_array()
         return [arr[i] for i in range(len(arr))]
+
+    def bcast_params(self):
+        self.executor.bcast_params(set(self.persistable_vars))

From ce08dc8751b5f605ce6aece70ce6f16af72f4759 Mon Sep 17 00:00:00 2001
From: typhoonzero
Date: Tue, 10 Apr 2018 20:40:43 +0800
Subject: [PATCH 4/6] have stream removed error

---
 .../details/multi_devices_graph_builder.cc    | 34 ++++++++-----------
 .../details/multi_devices_graph_builder.h     |  2 +-
 .../fluid/framework/details/send_op_handle.cc | 10 +++---
 .../fluid/framework/details/send_op_handle.h  |  4 +--
 python/paddle/fluid/distribute_transpiler.py  |  1 +
 python/paddle/fluid/parallel_executor.py      |  4 ++-
 6 files changed, 24 insertions(+), 31 deletions(-)

diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.cc b/paddle/fluid/framework/details/multi_devices_graph_builder.cc
index 8a53270110..0ebcd627bd 100644
--- a/paddle/fluid/framework/details/multi_devices_graph_builder.cc
+++ b/paddle/fluid/framework/details/multi_devices_graph_builder.cc
@@ -57,8 +57,11 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
 
 void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result, OpDesc *op,
                                                 const platform::Place &p,
-                                                const size_t &i) const {
+                                                const size_t &i,
+                                                bool create_output) const {
   auto *op_handle = result->ops_.back().get();
+  op_handle->dev_ctxes_[p] = const_cast<platform::DeviceContext *>(
+      platform::DeviceContextPool::Instance().Get(p));
 
   auto var_names = op->InputArgumentNames();
 
@@ -66,10 +69,12 @@ void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result, OpDesc *op,
     VarHandle *var = CreateOrGetLatestVarHandle(result, each_var_name, p, i);
     op_handle->AddInput(var);
   }
-  var_names = op->OutputArgumentNames();
+  if (create_output) {
+    var_names = op->OutputArgumentNames();
 
-  for (auto &each_var_name : var_names) {
-    CreateOpOutput(result, op_handle, each_var_name, p, i);
+    for (auto &each_var_name : var_names) {
+      CreateOpOutput(result, op_handle, each_var_name, p, i);
+    }
   }
 }
 
@@ -100,9 +105,11 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
     if (!is_forwarding && op->Type() == "send") {
       auto &p = places_[0];
       auto *s = local_scopes_[0];
-      size_t i = 0;
-      result.ops_.emplace_back(new SendOpHandle(*op, s, p));
-      CreateOpHandleIOs(&result, op, p, i);
+      // FIXME(wuyi): send op always copy from GPU 0
+      result.ops_.emplace_back(new SendOpHandle(*op, s));
+      // Create inputs on the original place; no SSA output is created
+      // for the send op.
+      CreateOpHandleIOs(&result, op, p, 0, false);
       continue;
     }
 
diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.h b/paddle/fluid/framework/details/multi_devices_graph_builder.h
index de34caab1b..137c817fde 100644
--- a/paddle/fluid/framework/details/multi_devices_graph_builder.h
+++ b/paddle/fluid/framework/details/multi_devices_graph_builder.h
@@ -46,7 +46,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
 
  private:
   void CreateOpHandleIOs(SSAGraph *result, OpDesc *op, const platform::Place &p,
-                         const size_t &i) const;
+                         const size_t &i, bool create_output = true) const;
 
  private:
   std::string loss_var_name_;
diff --git a/paddle/fluid/framework/details/send_op_handle.cc b/paddle/fluid/framework/details/send_op_handle.cc
index ae5637b804..caacfa6b1e 100644
--- a/paddle/fluid/framework/details/send_op_handle.cc
+++ b/paddle/fluid/framework/details/send_op_handle.cc
@@ -19,11 +19,9 @@ namespace framework {
 namespace details {
 
 SendOpHandle::SendOpHandle(const framework::OpDesc &op_desc,
-                           const Scope *local_scope,
-                           const platform::Place &place)
+                           const Scope *local_scope)
    : op_(framework::OpRegistry::CreateOp(op_desc)),
-      local_scope_(local_scope),
-      place_(place) {}
+      local_scope_(local_scope) {}
 
 void SendOpHandle::RunImpl() {
   // Wait input done
@@ -31,8 +29,8 @@ void SendOpHandle::RunImpl() {
     auto &p = static_cast<VarHandle *>(in)->place_;
     in->generated_op_->Wait(dev_ctxes_[p]);
   }
-
-  op_->Run(*local_scope_, place_);
+  platform::CPUPlace cpu;
+  op_->Run(*local_scope_, cpu);
 }
 
 std::string SendOpHandle::Name() const { return "send"; }
diff --git a/paddle/fluid/framework/details/send_op_handle.h b/paddle/fluid/framework/details/send_op_handle.h
index e7857c1f23..8a7b62ba1c 100644
--- a/paddle/fluid/framework/details/send_op_handle.h
+++ b/paddle/fluid/framework/details/send_op_handle.h
@@ -31,10 +31,8 @@ namespace details {
 struct SendOpHandle : public OpHandleBase {
   std::unique_ptr<OperatorBase> op_;
   const Scope* local_scope_;
-  const platform::Place& place_;
 
-  SendOpHandle(const framework::OpDesc& op_desc, const Scope* local_scope,
-               const platform::Place& place);
+  SendOpHandle(const framework::OpDesc& op_desc, const Scope* local_scope);
 
   std::string Name() const override;
 
diff --git a/python/paddle/fluid/distribute_transpiler.py b/python/paddle/fluid/distribute_transpiler.py
index 0ec3ebc7e3..e18ace844e 100644
--- a/python/paddle/fluid/distribute_transpiler.py
+++ b/python/paddle/fluid/distribute_transpiler.py
@@ -255,6 +255,7 @@ class DistributeTranspiler:
     def get_trainer_program(self):
         # remove optimize ops and add a send op to main_program
         self.program.global_block().delete_ops(self.optimize_ops)
+        self.program.sync_with_cpp()
         # FIXME(typhoonzero): serializing once fixes an error that occurs when cloning.
         self.program.__str__()
         return self.program
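Note: get_trainer_program() above deletes the optimizer ops but keeps the send
op, which is exactly what the builder's send branch keys on. The added
sync_with_cpp() call matters because delete_ops() mutates the underlying C++
desc; without resyncing, the Python-side op list still holds the deleted
operators. A small sanity check over the resulting desc on the C++ side (an
illustrative helper, not part of the series):

    #include "paddle/fluid/framework/program_desc.h"

    // True if block 0 still contains a send op after transpilation.
    bool HasSendOp(const paddle::framework::ProgramDesc &program) {
      for (auto *op : program.Block(0).AllOps()) {
        if (op->Type() == "send") return true;
      }
      return false;
    }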
diff --git a/python/paddle/fluid/parallel_executor.py b/python/paddle/fluid/parallel_executor.py
index a23cc9b772..c709f364c1 100644
--- a/python/paddle/fluid/parallel_executor.py
+++ b/python/paddle/fluid/parallel_executor.py
@@ -101,7 +101,9 @@ class ParallelExecutor(object):
 
         self.persistable_vars = [
             v.name
-            for v in filter(lambda var: var.persistable, main.list_vars())
+            for v in filter(lambda var: \
+                var.persistable and var.type != core.VarDesc.VarType.RAW,
+                main.list_vars())
         ]
 

From 16a9dfe4805fa88670338b52bf898f60043fc16f Mon Sep 17 00:00:00 2001
From: typhoonzero
Date: Wed, 11 Apr 2018 15:12:58 +0800
Subject: [PATCH 5/6] finish

---
 .../details/multi_devices_graph_builder.cc       | 16 +++++++---------
 .../details/multi_devices_graph_builder.h        |  2 +-
 paddle/fluid/framework/details/send_op_handle.cc | 12 ++++++++----
 paddle/fluid/framework/details/send_op_handle.h  |  4 +++-
 4 files changed, 19 insertions(+), 15 deletions(-)

diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.cc b/paddle/fluid/framework/details/multi_devices_graph_builder.cc
index 0ebcd627bd..e0dd9e6068 100644
--- a/paddle/fluid/framework/details/multi_devices_graph_builder.cc
+++ b/paddle/fluid/framework/details/multi_devices_graph_builder.cc
@@ -57,8 +57,7 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
 
 void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result, OpDesc *op,
                                                 const platform::Place &p,
-                                                const size_t &i,
-                                                bool create_output) const {
+                                                const size_t &i) const {
   auto *op_handle = result->ops_.back().get();
   op_handle->dev_ctxes_[p] = const_cast<platform::DeviceContext *>(
       platform::DeviceContextPool::Instance().Get(p));
@@ -69,12 +68,11 @@ void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result, OpDesc *op,
     VarHandle *var = CreateOrGetLatestVarHandle(result, each_var_name, p, i);
     op_handle->AddInput(var);
   }
-  if (create_output) {
-    var_names = op->OutputArgumentNames();
+  var_names = op->OutputArgumentNames();
 
-    for (auto &each_var_name : var_names) {
-      CreateOpOutput(result, op_handle, each_var_name, p, i);
-    }
+  for (auto &each_var_name : var_names) {
+    CreateOpOutput(result, op_handle, each_var_name, p, i);
   }
 }
 
@@ -106,10 +104,10 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
       auto &p = places_[0];
       auto *s = local_scopes_[0];
       // FIXME(wuyi): send op always copy from GPU 0
-      result.ops_.emplace_back(new SendOpHandle(*op, s));
+      result.ops_.emplace_back(new SendOpHandle(*op, s, p));
       // Create inputs on the original place; no SSA output is created
       // for the send op.
-      CreateOpHandleIOs(&result, op, p, 0, false);
+      CreateOpHandleIOs(&result, op, p, 0);
       continue;
     }
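Note: the send_op_handle.cc change below skips the dependency wait for inputs
whose DebugString() is "dummy". Those are the control-dependency variables the
SSA builder inserts, and they carry no meaningful place to wait on. A more
explicit form of the same guard (a sketch, assuming DummyVarHandle is the type
behind those edges):

    // Skip pure control-dependency inputs instead of string-matching "dummy".
    if (dynamic_cast<DummyVarHandle *>(in) != nullptr) {
      continue;
    }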
- CreateOpHandleIOs(&result, op, p, 0, false); + CreateOpHandleIOs(&result, op, p, 0); continue; } diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.h b/paddle/fluid/framework/details/multi_devices_graph_builder.h index 137c817fde..de34caab1b 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.h +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.h @@ -46,7 +46,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { private: void CreateOpHandleIOs(SSAGraph *result, OpDesc *op, const platform::Place &p, - const size_t &i, bool create_output = true) const; + const size_t &i) const; private: std::string loss_var_name_; diff --git a/paddle/fluid/framework/details/send_op_handle.cc b/paddle/fluid/framework/details/send_op_handle.cc index caacfa6b1e..d181607e86 100644 --- a/paddle/fluid/framework/details/send_op_handle.cc +++ b/paddle/fluid/framework/details/send_op_handle.cc @@ -19,18 +19,22 @@ namespace framework { namespace details { SendOpHandle::SendOpHandle(const framework::OpDesc &op_desc, - const Scope *local_scope) + const Scope *local_scope, + const platform::Place &place) : op_(framework::OpRegistry::CreateOp(op_desc)), - local_scope_(local_scope) {} + local_scope_(local_scope), + place_(place) {} void SendOpHandle::RunImpl() { // Wait input done for (auto *in : inputs_) { auto &p = static_cast(in)->place_; + if (in->DebugString() == "dummy") { // HACK + continue; + } in->generated_op_->Wait(dev_ctxes_[p]); } - platform::CPUPlace cpu; - op_->Run(*local_scope_, cpu); + op_->Run(*local_scope_, place_); } std::string SendOpHandle::Name() const { return "send"; } diff --git a/paddle/fluid/framework/details/send_op_handle.h b/paddle/fluid/framework/details/send_op_handle.h index 8a7b62ba1c..e7857c1f23 100644 --- a/paddle/fluid/framework/details/send_op_handle.h +++ b/paddle/fluid/framework/details/send_op_handle.h @@ -31,8 +31,10 @@ namespace details { struct SendOpHandle : public OpHandleBase { std::unique_ptr op_; const Scope* local_scope_; + const platform::Place& place_; - SendOpHandle(const framework::OpDesc& op_desc, const Scope* local_scope); + SendOpHandle(const framework::OpDesc& op_desc, const Scope* local_scope, + const platform::Place& place); std::string Name() const override; From d1e63a1d9205e99483a3b69058fdf36e54dc348e Mon Sep 17 00:00:00 2001 From: typhoonzero Date: Wed, 11 Apr 2018 15:18:55 +0800 Subject: [PATCH 6/6] fix ci --- paddle/fluid/framework/details/send_op_handle.h | 1 - 1 file changed, 1 deletion(-) diff --git a/paddle/fluid/framework/details/send_op_handle.h b/paddle/fluid/framework/details/send_op_handle.h index e7857c1f23..173f9d7261 100644 --- a/paddle/fluid/framework/details/send_op_handle.h +++ b/paddle/fluid/framework/details/send_op_handle.h @@ -22,7 +22,6 @@ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/scope.h" -#include "paddle/fluid/platform/nccl_helper.h" namespace paddle { namespace framework {