From 08acc035223589f4c1c35623c4158b027d71b8c6 Mon Sep 17 00:00:00 2001 From: Xin Pan Date: Tue, 10 Jul 2018 17:33:27 +0800 Subject: [PATCH 01/22] start --- paddle/fluid/framework/CMakeLists.txt | 1 + paddle/fluid/framework/ir/CMakeLists.txt | 3 +++ paddle/fluid/framework/ir/graph.cc | 19 +++++++++++++++++++ paddle/fluid/framework/ir/graph.h | 19 +++++++++++++++++++ paddle/fluid/framework/ir/node.cc | 19 +++++++++++++++++++ paddle/fluid/framework/ir/node.h | 19 +++++++++++++++++++ paddle/fluid/framework/ir/pass.cc | 19 +++++++++++++++++++ paddle/fluid/framework/ir/pass.h | 19 +++++++++++++++++++ 8 files changed, 118 insertions(+) create mode 100644 paddle/fluid/framework/ir/CMakeLists.txt create mode 100644 paddle/fluid/framework/ir/graph.cc create mode 100644 paddle/fluid/framework/ir/graph.h create mode 100644 paddle/fluid/framework/ir/node.cc create mode 100644 paddle/fluid/framework/ir/node.h create mode 100644 paddle/fluid/framework/ir/pass.cc create mode 100644 paddle/fluid/framework/ir/pass.h diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index ec252929d5..bae8f51bcf 100644 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -1,4 +1,5 @@ add_subdirectory(details) +add_subdirectory(ir) # ddim lib proto_library(framework_proto SRCS framework.proto) diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt new file mode 100644 index 0000000000..4cd373e8ea --- /dev/null +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -0,0 +1,3 @@ +cc_library(graph SRCS graph.cc) +cc_library(node SRCS node.cc) +cc_library(pass SRCS pass.cc) diff --git a/paddle/fluid/framework/ir/graph.cc b/paddle/fluid/framework/ir/graph.cc new file mode 100644 index 0000000000..b5c5ba7c14 --- /dev/null +++ b/paddle/fluid/framework/ir/graph.cc @@ -0,0 +1,19 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/framework/ir/graph.h" + +namespace paddle { +namespace framework {} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/ir/graph.h b/paddle/fluid/framework/ir/graph.h new file mode 100644 index 0000000000..6f4bb172c6 --- /dev/null +++ b/paddle/fluid/framework/ir/graph.h @@ -0,0 +1,19 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#pragma once + +namespace paddle { +namespace framework {} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/ir/node.cc b/paddle/fluid/framework/ir/node.cc new file mode 100644 index 0000000000..ca83fa7a83 --- /dev/null +++ b/paddle/fluid/framework/ir/node.cc @@ -0,0 +1,19 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/framework/ir/node.h" + +namespace paddle { +namespace framework {} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/ir/node.h b/paddle/fluid/framework/ir/node.h new file mode 100644 index 0000000000..6f4bb172c6 --- /dev/null +++ b/paddle/fluid/framework/ir/node.h @@ -0,0 +1,19 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +namespace paddle { +namespace framework {} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/ir/pass.cc b/paddle/fluid/framework/ir/pass.cc new file mode 100644 index 0000000000..91b0decd25 --- /dev/null +++ b/paddle/fluid/framework/ir/pass.cc @@ -0,0 +1,19 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/framework/ir/pass.h" + +namespace paddle { +namespace framework {} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/ir/pass.h b/paddle/fluid/framework/ir/pass.h new file mode 100644 index 0000000000..6f4bb172c6 --- /dev/null +++ b/paddle/fluid/framework/ir/pass.h @@ -0,0 +1,19 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +namespace paddle { +namespace framework {} // namespace framework +} // namespace paddle From fcda23a3e4a336a4d69376af240f0a6bd3d0d546 Mon Sep 17 00:00:00 2001 From: Xin Pan Date: Tue, 10 Jul 2018 19:45:27 +0800 Subject: [PATCH 02/22] simple node --- paddle/fluid/framework/ir/node.h | 42 +++++++++++++++++++++++++++++++- 1 file changed, 41 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/framework/ir/node.h b/paddle/fluid/framework/ir/node.h index 6f4bb172c6..0c52127069 100644 --- a/paddle/fluid/framework/ir/node.h +++ b/paddle/fluid/framework/ir/node.h @@ -14,6 +14,46 @@ limitations under the License. */ #pragma once +#include +#include +#include +#include "paddle/fluid/platform/macros.h" + namespace paddle { -namespace framework {} // namespace framework +namespace framework { + +class Node { + public: + enum class Type { kNone = -1, kOperation, kVariable }; + + Node() {} + virtual ~Node() {} + + template + Subclass &As() { + return *dynamic_cast(this); + } + + int64_t ID() const { return id_; } + + std::string Name() const { return name_; } + + virtual std::string ToString() const { + return Name() + "(" + std::to_string(ID()) + ")"; + } + + Type NodeType() const { return type_; } + + std::vector inputs; + std::vector outputs; + + protected: + int64_t id_ = 0; + std::string name_; + Type type_; + + DISABLE_COPY_AND_ASSIGN(Node); +}; + +} // namespace framework } // namespace paddle From 7781297c701d3c8d72d8d252e079a5f3fe5ecde3 Mon Sep 17 00:00:00 2001 From: Xin Pan Date: Tue, 10 Jul 2018 21:12:10 +0800 Subject: [PATCH 03/22] variants --- paddle/fluid/framework/details/var_handle.h | 1 + paddle/fluid/framework/ir/graph.h | 21 ++++++++++++++++++++- paddle/fluid/framework/ir/node.h | 10 +++++----- paddle/fluid/platform/variant.h | 1 + 4 files changed, 27 insertions(+), 6 deletions(-) diff --git a/paddle/fluid/framework/details/var_handle.h b/paddle/fluid/framework/details/var_handle.h index cae9af7217..c62f9a9d08 100644 --- a/paddle/fluid/framework/details/var_handle.h +++ b/paddle/fluid/framework/details/var_handle.h @@ -18,6 +18,7 @@ #include #include +#include "paddle/fluid/framework/ir/node.h" #include "paddle/fluid/platform/place.h" namespace paddle { diff --git a/paddle/fluid/framework/ir/graph.h b/paddle/fluid/framework/ir/graph.h index 6f4bb172c6..d1805d7434 100644 --- a/paddle/fluid/framework/ir/graph.h +++ b/paddle/fluid/framework/ir/graph.h @@ -14,6 +14,25 @@ limitations under the License. */ #pragma once +#include +#include +#include +#include + +#include "paddle/fluid/framework/ir/node.h" +#include "paddle/fluid/platform/variant.h" + namespace paddle { -namespace framework {} // namespace framework +namespace framework { + +class Graph { + public: + std::map> attrs; + + std::vector inputs; + std::vector outputs; + std::vector> nodes; +}; + +} // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/ir/node.h b/paddle/fluid/framework/ir/node.h index 0c52127069..9a280afb3b 100644 --- a/paddle/fluid/framework/ir/node.h +++ b/paddle/fluid/framework/ir/node.h @@ -15,9 +15,11 @@ limitations under the License. 
*/ #pragma once #include +#include #include #include #include "paddle/fluid/platform/macros.h" +#include "paddle/fluid/platform/variant.h" namespace paddle { namespace framework { @@ -29,11 +31,6 @@ class Node { Node() {} virtual ~Node() {} - template - Subclass &As() { - return *dynamic_cast(this); - } - int64_t ID() const { return id_; } std::string Name() const { return name_; } @@ -42,12 +39,15 @@ class Node { return Name() + "(" + std::to_string(ID()) + ")"; } + virtual std::string DebugString() const = 0; + Type NodeType() const { return type_; } std::vector inputs; std::vector outputs; protected: + std::map> attrs_; int64_t id_ = 0; std::string name_; Type type_; diff --git a/paddle/fluid/platform/variant.h b/paddle/fluid/platform/variant.h index 45f60fc9d7..dc9fad29f2 100644 --- a/paddle/fluid/platform/variant.h +++ b/paddle/fluid/platform/variant.h @@ -38,6 +38,7 @@ limitations under the License. */ #endif #endif +#include #include #include #include From 2eeaa8d5cff85532e84b95749c13819724fd1d51 Mon Sep 17 00:00:00 2001 From: Xin Pan Date: Wed, 11 Jul 2018 09:37:42 +0800 Subject: [PATCH 04/22] Graph in ParallelExecutor Builder --- .../details/multi_devices_graph_builder.cc | 134 +++++++++++------- .../details/multi_devices_graph_builder.h | 24 ++-- .../framework/details/ssa_graph_builder.cc | 27 ++-- .../framework/details/ssa_graph_builder.h | 17 ++- paddle/fluid/framework/ir/graph.h | 2 +- paddle/fluid/framework/ir/pass.h | 23 ++- 6 files changed, 151 insertions(+), 76 deletions(-) diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.cc b/paddle/fluid/framework/details/multi_devices_graph_builder.cc index 6f5d4471a9..da0272d48e 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.cc @@ -25,6 +25,7 @@ #include "paddle/fluid/framework/details/reduce_op_handle.h" #include "paddle/fluid/framework/details/rpc_op_handle.h" #include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h" +#include "paddle/fluid/framework/ir/node.h" #include "paddle/fluid/framework/op_info.h" #include "paddle/fluid/framework/scope.h" @@ -66,11 +67,11 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder( } } -void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result, - const OpDesc &op, +void MultiDevSSAGraphBuilder::CreateOpHandleIOs(Graph *result, const OpDesc &op, size_t place_id) const { auto p = places_[place_id]; - auto *op_handle = result->ops_.back().get(); + auto *op_handle = + boost::any_cast(result->attrs["ops"])->back().get(); op_handle->SetDeviceContext(p, platform::DeviceContextPool::Instance().Get(p)); @@ -169,18 +170,21 @@ size_t MultiDevSSAGraphBuilder::GetAppropriateDeviceID( std::unique_ptr MultiDevSSAGraphBuilder::Build( const ProgramDesc &program) const { + std::unique_ptr graph(new Graph); for (auto *var : program.Block(0).AllVars()) { all_vars_.emplace(var->Name(), var); } - auto graph = new SSAGraph(); - SSAGraph &result = *graph; + Graph &result = *graph; std::unordered_set og_has_been_broadcast; // We cannot invoke resize. 
It is a bug of GCC 4.8 - result.vars_ = std::vector< + result.attrs["vars"] = new std::vector< std::unordered_map>>>( places_.size()); + result.attrs["dep_vars"] = + new std::unordered_set>(); + result.attrs["ops"] = new std::vector>(); // find send/recv vars so that we can place the distributed training // realted op in the place 0 @@ -303,7 +307,15 @@ std::unique_ptr MultiDevSSAGraphBuilder::Build( */ AddOutputToLeafOps(&result); - return std::unique_ptr(graph); + std::unique_ptr ssa_graph(new SSAGraph); + ssa_graph->vars_ = + std::move(*boost::any_cast(graph->attrs["vars"])); + ssa_graph->ops_ = + std::move(*boost::any_cast(graph->attrs["ops"])); + ssa_graph->dep_vars_ = + std::move(*boost::any_cast(graph->attrs["dep_vars"])); + + return std::move(ssa_graph); } bool MultiDevSSAGraphBuilder::IsSparseGradient(const std::string &og) const { @@ -327,7 +339,7 @@ void MultiDevSSAGraphBuilder::SetCommunicationContext( #endif } -void MultiDevSSAGraphBuilder::CreateBroadcastOp(SSAGraph *result, +void MultiDevSSAGraphBuilder::CreateBroadcastOp(Graph *result, const std::string &p_name, size_t src_dev_id) const { #ifdef PADDLE_WITH_CUDA @@ -336,42 +348,50 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(SSAGraph *result, auto *op_handle = new BroadcastOpHandle(local_scopes_, places_); #endif - result->ops_.emplace_back(op_handle); - auto *in = result->vars_.at(src_dev_id).at(p_name).back().get(); + boost::any_cast(result->attrs["ops"])->emplace_back(op_handle); + auto *in = boost::any_cast(result->attrs["vars"]) + ->at(src_dev_id) + .at(p_name) + .back() + .get(); op_handle->AddInput(in); for (size_t i = 0; i < places_.size(); ++i) { auto &p = places_[i]; SetCommunicationContext(op_handle, p); - auto &vars = result->vars_.at(i).at(p_name); + auto &vars = + boost::any_cast(result->attrs["vars"])->at(i).at(p_name); auto *out_var = new VarHandle(vars.size(), i, p_name, p); vars.emplace_back(out_var); op_handle->AddOutput(out_var); } } -void MultiDevSSAGraphBuilder::CreateComputationalOp(SSAGraph *result, +void MultiDevSSAGraphBuilder::CreateComputationalOp(Graph *result, const OpDesc &op, int dev_id) const { - result->ops_.emplace_back( - new ComputationOpHandle(op, local_scopes_[dev_id], places_[dev_id])); + boost::any_cast(result->attrs["ops"]) + ->emplace_back( + new ComputationOpHandle(op, local_scopes_[dev_id], places_[dev_id])); CreateOpHandleIOs(result, op, dev_id); } -void MultiDevSSAGraphBuilder::InsertAllReduceOp(SSAGraph *result, +void MultiDevSSAGraphBuilder::InsertAllReduceOp(Graph *result, const std::string &og) const { #ifdef PADDLE_WITH_CUDA - result->ops_.emplace_back( - new AllReduceOpHandle(local_scopes_, places_, nccl_ctxs_)); + boost::any_cast(result->attrs["ops"]) + ->emplace_back(new AllReduceOpHandle(local_scopes_, places_, nccl_ctxs_)); #else - result->ops_.emplace_back(new AllReduceOpHandle(local_scopes_, places_)); + boost::any_cast(result->attrs["ops"]) + ->emplace_back(new AllReduceOpHandle(local_scopes_, places_)); #endif - auto *op_handle = result->ops_.back().get(); + auto *op_handle = + boost::any_cast(result->attrs["ops"])->back().get(); for (size_t i = 0; i < places_.size(); ++i) { auto &p = places_[i]; SetCommunicationContext(op_handle, p); - auto &vars = result->vars_[i][og]; + auto &vars = (*boost::any_cast(result->attrs["vars"]))[i][og]; PADDLE_ENFORCE(!vars.empty()); auto &prev_grad = vars.back(); op_handle->AddInput(prev_grad.get()); @@ -383,19 +403,23 @@ void MultiDevSSAGraphBuilder::InsertAllReduceOp(SSAGraph *result, } void 
MultiDevSSAGraphBuilder::InsertDataBalanceOp( - SSAGraph *result, const std::vector &datas) const { + Graph *result, const std::vector &datas) const { #ifdef PADDLE_WITH_CUDA - result->ops_.emplace_back( - new DataBalanceOpHandle(local_scopes_, places_, nccl_ctxs_)); + boost::any_cast(result->attrs["ops"]) + ->emplace_back( + new DataBalanceOpHandle(local_scopes_, places_, nccl_ctxs_)); #else - result->ops_.emplace_back(new DataBalanceOpHandle(local_scopes_, places_)); + boost::any_cast(result->attrs["ops"]) + ->emplace_back(new DataBalanceOpHandle(local_scopes_, places_)); #endif - auto *op_handle = result->ops_.back().get(); + auto *op_handle = + boost::any_cast(result->attrs["ops"])->back().get(); for (size_t i = 0; i < places_.size(); ++i) { auto &p = places_[i]; SetCommunicationContext(op_handle, p); for (const std::string &d_name : datas) { - auto &vars = result->vars_[i][d_name]; + auto &vars = + (*boost::any_cast(result->attrs["vars"]))[i][d_name]; PADDLE_ENFORCE(!vars.empty()); op_handle->AddInput(vars.back().get()); auto var = new VarHandle(vars.size(), i, d_name, p); @@ -441,7 +465,7 @@ int MultiDevSSAGraphBuilder::GetVarDeviceID(const std::string &varname) const { return got == var_name_on_devices_.end() ? -1 : got->second; } -void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(SSAGraph *result) const { +void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(Graph *result) const { for (size_t i = 0; i < places_.size(); ++i) { // Insert ScaleCost OpHandle #ifdef PADDLE_WITH_CUDA @@ -456,7 +480,7 @@ void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(SSAGraph *result) const { auto *op_handle = new ScaleLossGradOpHandle(local_scopes_.size(), local_scopes_[i], places_[i], communication_dev_ctx); - result->ops_.emplace_back(op_handle); + boost::any_cast(result->attrs["ops"])->emplace_back(op_handle); // FIXME: Currently ScaleLossGradOp only use device_count as scale // factor. So it does not depend on any other operators. 
@@ -469,37 +493,41 @@ void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(SSAGraph *result) const { } } -void MultiDevSSAGraphBuilder::CreateComputationalOps(SSAGraph *result, +void MultiDevSSAGraphBuilder::CreateComputationalOps(Graph *result, const OpDesc &op, size_t num_places) const { for (size_t scope_idx = 0; scope_idx < num_places; ++scope_idx) { auto p = places_[scope_idx]; auto s = local_scopes_[scope_idx]; - result->ops_.emplace_back(new ComputationOpHandle(op, s, p)); + boost::any_cast(result->attrs["ops"]) + ->emplace_back(new ComputationOpHandle(op, s, p)); CreateOpHandleIOs(result, op, scope_idx); } } -VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(SSAGraph *result, +VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(Graph *result, const std::string &og, int dst_dev_id) const { #ifdef PADDLE_WITH_CUDA - result->ops_.emplace_back( - new ReduceOpHandle(local_scopes_, places_, nccl_ctxs_)); + boost::any_cast(result->attrs["ops"]) + ->emplace_back(new ReduceOpHandle(local_scopes_, places_, nccl_ctxs_)); #else - result->ops_.emplace_back(new ReduceOpHandle(local_scopes_, places_)); + boost::any_cast(result->attrs["ops"]) + ->emplace_back(new ReduceOpHandle(local_scopes_, places_)); #endif - auto *op_handle = result->ops_.back().get(); + auto *op_handle = + boost::any_cast(result->attrs["ops"])->back().get(); for (size_t i = 0; i < places_.size(); ++i) { auto &p = places_[i]; SetCommunicationContext(op_handle, p); - auto &vars = result->vars_[i][og]; + auto &vars = (*boost::any_cast(result->attrs["vars"]))[i][og]; PADDLE_ENFORCE(!vars.empty()); auto &prev_grad = vars.back(); op_handle->AddInput(prev_grad.get()); } - auto &vars = result->vars_[dst_dev_id][og]; + auto &vars = + (*boost::any_cast(result->attrs["vars"]))[dst_dev_id][og]; auto var = new VarHandle(vars.size(), dst_dev_id, og, places_[dst_dev_id]); vars.emplace_back(var); op_handle->AddOutput(var); @@ -508,19 +536,20 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(SSAGraph *result, // Find the first occurence of `prev_op_name` and make current `op` depend // on it. -void MultiDevSSAGraphBuilder::ConnectOp(SSAGraph *result, OpHandleBase *op, +void MultiDevSSAGraphBuilder::ConnectOp(Graph *result, OpHandleBase *op, const std::string &prev_op_name) const { - for (auto &prev_op : result->ops_) { + for (auto &prev_op : (*boost::any_cast(result->attrs["ops"]))) { if (prev_op->Name() == prev_op_name) { auto *dep_var = new DummyVarHandle(); prev_op->AddOutput(dep_var); - result->dep_vars_.emplace(dep_var); + boost::any_cast(result->attrs["dep_vars"]) + ->emplace(dep_var); op->AddInput(dep_var); } } } -void MultiDevSSAGraphBuilder::CreateDistTrainOp(SSAGraph *result, +void MultiDevSSAGraphBuilder::CreateDistTrainOp(Graph *result, const OpDesc &op) const { int op_dev_id = -1; if (op.Type() == "split_byref" || op.Type() == "split_selected_rows") { @@ -550,12 +579,14 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(SSAGraph *result, CreateComputationalOp(result, op, op_dev_id); if (op.Type() == "concat") { - ConnectOp(result, result->ops_.back().get(), "fetch_barrier"); + ConnectOp(result, + boost::any_cast(result->attrs["ops"])->back().get(), + "fetch_barrier"); } } // Create RPC related op handles that connects its in ops and out ops. 
-void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result, +void MultiDevSSAGraphBuilder::CreateRPCOp(Graph *result, const OpDesc &op) const { int op_dev_id = -1; if (op.Type() == "send") { @@ -584,15 +615,22 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result, PADDLE_ENFORCE(op_dev_id != -1, "can not find the right place for rpc op: %s", op.Type()); - result->ops_.emplace_back(new RPCOpHandle(op, local_scopes_[op_dev_id], - op.Type(), places_[op_dev_id])); + boost::any_cast(result->attrs["ops"]) + ->emplace_back(new RPCOpHandle(op, local_scopes_[op_dev_id], op.Type(), + places_[op_dev_id])); if (op.Type() == "send_barrier") { - ConnectOp(result, result->ops_.back().get(), "send"); + ConnectOp(result, + boost::any_cast(result->attrs["ops"])->back().get(), + "send"); } else if (op.Type() == "recv") { - ConnectOp(result, result->ops_.back().get(), "send_barrier"); + ConnectOp(result, + boost::any_cast(result->attrs["ops"])->back().get(), + "send_barrier"); } else if (op.Type() == "fetch_barrier") { - ConnectOp(result, result->ops_.back().get(), "recv"); + ConnectOp(result, + boost::any_cast(result->attrs["ops"])->back().get(), + "recv"); } else if (op.Type() == "send") { // do nothing } else { diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.h b/paddle/fluid/framework/details/multi_devices_graph_builder.h index a964e02488..3d7642f522 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.h +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.h @@ -19,6 +19,7 @@ #include "paddle/fluid/framework/details/build_strategy.h" #include "paddle/fluid/framework/details/ssa_graph_builder.h" +#include "paddle/fluid/framework/ir/graph.h" namespace paddle { namespace platform { @@ -50,7 +51,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { int GetVarDeviceID(const std::string &varname) const override; private: - void CreateOpHandleIOs(SSAGraph *result, const OpDesc &op, + void CreateOpHandleIOs(Graph *result, const OpDesc &op, size_t device_id) const; private: @@ -65,8 +66,8 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { bool IsScaleLossOp(const OpDesc &op) const; - void CreateRPCOp(SSAGraph *result, const OpDesc &op) const; - void CreateDistTrainOp(SSAGraph *result, const OpDesc &op) const; + void CreateRPCOp(Graph *result, const OpDesc &op) const; + void CreateDistTrainOp(Graph *result, const OpDesc &op) const; /** * Is this operator as the end-point operator before/after send operator. 
@@ -81,17 +82,16 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { std::vector FindDistTrainRecvVars( const ProgramDesc &program) const; - void ConnectOp(SSAGraph *result, OpHandleBase *op, + void ConnectOp(Graph *result, OpHandleBase *op, const std::string &prev_op_name) const; - void CreateComputationalOps(SSAGraph *result, const OpDesc &op, + void CreateComputationalOps(Graph *result, const OpDesc &op, size_t num_places) const; - void CreateScaleLossGradOp(SSAGraph *result) const; - VarHandle *CreateReduceOp(SSAGraph *result, const std::string &og, + void CreateScaleLossGradOp(Graph *result) const; + VarHandle *CreateReduceOp(Graph *result, const std::string &og, int dst_dev_id) const; - void CreateComputationalOp(SSAGraph *result, const OpDesc &op, - int dev_id) const; + void CreateComputationalOp(Graph *result, const OpDesc &op, int dev_id) const; bool IsParameterGradientOnce( const std::string &og, @@ -99,12 +99,12 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { int GetOpDeviceID(const OpDesc &op) const; - void InsertAllReduceOp(SSAGraph *result, const std::string &og) const; + void InsertAllReduceOp(Graph *result, const std::string &og) const; - void InsertDataBalanceOp(SSAGraph *result, + void InsertDataBalanceOp(Graph *result, const std::vector &datas) const; - void CreateBroadcastOp(SSAGraph *result, const std::string &p_name, + void CreateBroadcastOp(Graph *result, const std::string &p_name, size_t src_dev_id) const; bool IsSparseGradient(const std::string &og) const; diff --git a/paddle/fluid/framework/details/ssa_graph_builder.cc b/paddle/fluid/framework/details/ssa_graph_builder.cc index 88a21f4887..2c0873cc87 100644 --- a/paddle/fluid/framework/details/ssa_graph_builder.cc +++ b/paddle/fluid/framework/details/ssa_graph_builder.cc @@ -17,8 +17,8 @@ namespace paddle { namespace framework { namespace details { -void SSAGraphBuilder::PolishGraphToSupportDataHazards(SSAGraph *graph) { - for (auto &var_map : graph->vars_) { +void SSAGraphBuilder::PolishGraphToSupportDataHazards(Graph *graph) { + for (auto &var_map : *boost::any_cast(graph->attrs["vars"])) { for (auto &name_pair : var_map) { if (name_pair.second.size() <= 1) { continue; @@ -40,7 +40,8 @@ void SSAGraphBuilder::PolishGraphToSupportDataHazards(SSAGraph *graph) { auto *dep_var = new DummyVarHandle(); read_op->AddOutput(dep_var); write_op->AddInput(dep_var); - graph->dep_vars_.emplace(dep_var); + boost::any_cast(graph->attrs["dep_vars"]) + ->emplace(dep_var); } } } @@ -48,9 +49,10 @@ void SSAGraphBuilder::PolishGraphToSupportDataHazards(SSAGraph *graph) { } VarHandle *SSAGraphBuilder::CreateOrGetLatestVarHandle( - SSAGraph *graph, const std::string &each_var_name, + Graph *graph, const std::string &each_var_name, const platform::Place &place, size_t place_offset) { - auto &var_holders = graph->vars_[place_offset]; + auto &var_holders = + (*boost::any_cast(graph->attrs["vars"]))[place_offset]; auto &var_holder = var_holders[each_var_name]; VarHandle *var = nullptr; if (var_holder.empty()) { @@ -62,24 +64,29 @@ VarHandle *SSAGraphBuilder::CreateOrGetLatestVarHandle( return var; } -void SSAGraphBuilder::CreateOpOutput(SSAGraph *graph, OpHandleBase *op_handle, +void SSAGraphBuilder::CreateOpOutput(Graph *graph, OpHandleBase *op_handle, const std::string &each_var_name, const platform::Place &place, size_t place_offset) { - auto &vars = graph->vars_[place_offset][each_var_name]; + auto &vars = + (*boost::any_cast(graph->attrs["vars"]))[place_offset] + [each_var_name]; size_t version = vars.size(); auto 
var = new VarHandle(version, place_offset, each_var_name, place); vars.emplace_back(var); op_handle->AddOutput(var); } -void SSAGraphBuilder::AddOutputToLeafOps(SSAGraph *graph) { - for (auto &op : graph->ops_) { +void SSAGraphBuilder::AddOutputToLeafOps(Graph *graph) { + GraphOps &all_ops = *boost::any_cast(graph->attrs["ops"]); + + for (auto &op : all_ops) { if (!op->Outputs().empty()) { continue; } auto *dummy_leaf = new DummyVarHandle(); - graph->dep_vars_.emplace(dummy_leaf); + boost::any_cast(graph->attrs["dep_vars"]) + ->emplace(dummy_leaf); op->AddOutput(dummy_leaf); } } diff --git a/paddle/fluid/framework/details/ssa_graph_builder.h b/paddle/fluid/framework/details/ssa_graph_builder.h index 18612c3c1b..d5aabb9fd1 100644 --- a/paddle/fluid/framework/details/ssa_graph_builder.h +++ b/paddle/fluid/framework/details/ssa_graph_builder.h @@ -16,15 +16,24 @@ #include #include +#include #include "paddle/fluid/framework/details/ssa_graph.h" #include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/platform/place.h" +#include "paddle/fluid/framework/ir/graph.h" + namespace paddle { namespace framework { namespace details { +typedef std::vector< + std::unordered_map>>> + GraphVars; +typedef std::unordered_set> GraphDepVars; +typedef std::vector> GraphOps; + class SSAGraphBuilder { public: SSAGraphBuilder() {} @@ -42,20 +51,20 @@ class SSAGraphBuilder { * * https://en.wikipedia.org/wiki/Hazard_(computer_architecture)#Write_after_read_(WAR) */ - static void PolishGraphToSupportDataHazards(SSAGraph *graph); + static void PolishGraphToSupportDataHazards(Graph *graph); - static VarHandle *CreateOrGetLatestVarHandle(SSAGraph *graph, + static VarHandle *CreateOrGetLatestVarHandle(Graph *graph, const std::string &each_var_name, const platform::Place &place, size_t place_offset); // Add an output variable (each_var_name, place, place_offset) to op_handle, // which belongs to graph - static void CreateOpOutput(SSAGraph *graph, OpHandleBase *op_handle, + static void CreateOpOutput(Graph *graph, OpHandleBase *op_handle, const std::string &each_var_name, const platform::Place &place, size_t place_offset); - static void AddOutputToLeafOps(SSAGraph *graph); + static void AddOutputToLeafOps(Graph *graph); }; } // namespace details } // namespace framework diff --git a/paddle/fluid/framework/ir/graph.h b/paddle/fluid/framework/ir/graph.h index d1805d7434..8996c2d43a 100644 --- a/paddle/fluid/framework/ir/graph.h +++ b/paddle/fluid/framework/ir/graph.h @@ -27,7 +27,7 @@ namespace framework { class Graph { public: - std::map> attrs; + std::map attrs; std::vector inputs; std::vector outputs; diff --git a/paddle/fluid/framework/ir/pass.h b/paddle/fluid/framework/ir/pass.h index 6f4bb172c6..087ebb8709 100644 --- a/paddle/fluid/framework/ir/pass.h +++ b/paddle/fluid/framework/ir/pass.h @@ -14,6 +14,27 @@ limitations under the License. 
*/ #pragma once +#include "paddle/fluid/framework/ir/graph.h" +#include "paddle/fluid/framework/ir/node.h" +#include "paddle/fluid/framework/program_desc.h" + namespace paddle { -namespace framework {} // namespace framework +namespace framework { + +class Pass { + public: + Pass() = default; + virtual ~Pass() {} + virtual std::unique_ptr Apply(std::unique_ptr graph) { + return std::move(graph); + } +}; + +std::unique_ptr ProgramToGraph(const ProgramDesc& program) { + std::unique_ptr g(new Graph); + + return std::move(g); +} + +} // namespace framework } // namespace paddle From 9b9603306c1b2c8fc0ec6ea54b0b289ab974b97b Mon Sep 17 00:00:00 2001 From: Xin Pan Date: Wed, 11 Jul 2018 19:59:28 +0800 Subject: [PATCH 05/22] graph attrs --- .../details/multi_devices_graph_builder.cc | 115 +++++++----------- .../framework/details/ssa_graph_builder.cc | 17 +-- paddle/fluid/framework/ir/graph.h | 65 +++++++++- 3 files changed, 111 insertions(+), 86 deletions(-) diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.cc b/paddle/fluid/framework/details/multi_devices_graph_builder.cc index da0272d48e..9ac961f1b1 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.cc @@ -70,8 +70,7 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder( void MultiDevSSAGraphBuilder::CreateOpHandleIOs(Graph *result, const OpDesc &op, size_t place_id) const { auto p = places_[place_id]; - auto *op_handle = - boost::any_cast(result->attrs["ops"])->back().get(); + auto *op_handle = result->Get("ops").back().get(); op_handle->SetDeviceContext(p, platform::DeviceContextPool::Instance().Get(p)); @@ -179,13 +178,9 @@ std::unique_ptr MultiDevSSAGraphBuilder::Build( std::unordered_set og_has_been_broadcast; // We cannot invoke resize. 
It is a bug of GCC 4.8 - result.attrs["vars"] = new std::vector< - std::unordered_map>>>( - places_.size()); - result.attrs["dep_vars"] = - new std::unordered_set>(); - result.attrs["ops"] = new std::vector>(); - + result.Set("vars", new GraphVars(places_.size())); + result.Set("dep_vars", new GraphDepVars); + result.Set("ops", new GraphOps); // find send/recv vars so that we can place the distributed training // realted op in the place 0 auto send_vars = FindDistTrainSendVars(program); @@ -308,13 +303,9 @@ std::unique_ptr MultiDevSSAGraphBuilder::Build( AddOutputToLeafOps(&result); std::unique_ptr ssa_graph(new SSAGraph); - ssa_graph->vars_ = - std::move(*boost::any_cast(graph->attrs["vars"])); - ssa_graph->ops_ = - std::move(*boost::any_cast(graph->attrs["ops"])); - ssa_graph->dep_vars_ = - std::move(*boost::any_cast(graph->attrs["dep_vars"])); - + ssa_graph->vars_ = std::move(*graph->Erase("vars")); + ssa_graph->ops_ = std::move(*graph->Erase("ops")); + ssa_graph->dep_vars_ = std::move(*graph->Erase("dep_vars")); return std::move(ssa_graph); } @@ -347,20 +338,15 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(Graph *result, #else auto *op_handle = new BroadcastOpHandle(local_scopes_, places_); #endif - - boost::any_cast(result->attrs["ops"])->emplace_back(op_handle); - auto *in = boost::any_cast(result->attrs["vars"]) - ->at(src_dev_id) - .at(p_name) - .back() - .get(); + result->Get("ops").emplace_back(op_handle); + auto *in = + result->Get("vars").at(src_dev_id).at(p_name).back().get(); op_handle->AddInput(in); for (size_t i = 0; i < places_.size(); ++i) { auto &p = places_[i]; SetCommunicationContext(op_handle, p); - auto &vars = - boost::any_cast(result->attrs["vars"])->at(i).at(p_name); + auto &vars = result->Get("vars").at(i).at(p_name); auto *out_var = new VarHandle(vars.size(), i, p_name, p); vars.emplace_back(out_var); op_handle->AddOutput(out_var); @@ -370,28 +356,26 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(Graph *result, void MultiDevSSAGraphBuilder::CreateComputationalOp(Graph *result, const OpDesc &op, int dev_id) const { - boost::any_cast(result->attrs["ops"]) - ->emplace_back( - new ComputationOpHandle(op, local_scopes_[dev_id], places_[dev_id])); + result->Get("ops").emplace_back( + new ComputationOpHandle(op, local_scopes_[dev_id], places_[dev_id])); CreateOpHandleIOs(result, op, dev_id); } void MultiDevSSAGraphBuilder::InsertAllReduceOp(Graph *result, const std::string &og) const { #ifdef PADDLE_WITH_CUDA - boost::any_cast(result->attrs["ops"]) - ->emplace_back(new AllReduceOpHandle(local_scopes_, places_, nccl_ctxs_)); + result->Get("ops").emplace_back( + new AllReduceOpHandle(local_scopes_, places_, nccl_ctxs_)); #else - boost::any_cast(result->attrs["ops"]) - ->emplace_back(new AllReduceOpHandle(local_scopes_, places_)); + result->Get("ops").emplace_back( + new AllReduceOpHandle(local_scopes_, places_)); #endif - auto *op_handle = - boost::any_cast(result->attrs["ops"])->back().get(); + auto *op_handle = result->Get("ops").back().get(); for (size_t i = 0; i < places_.size(); ++i) { auto &p = places_[i]; SetCommunicationContext(op_handle, p); - auto &vars = (*boost::any_cast(result->attrs["vars"]))[i][og]; + auto &vars = result->Get("vars")[i][og]; PADDLE_ENFORCE(!vars.empty()); auto &prev_grad = vars.back(); op_handle->AddInput(prev_grad.get()); @@ -405,21 +389,18 @@ void MultiDevSSAGraphBuilder::InsertAllReduceOp(Graph *result, void MultiDevSSAGraphBuilder::InsertDataBalanceOp( Graph *result, const std::vector &datas) const { #ifdef PADDLE_WITH_CUDA - 
boost::any_cast(result->attrs["ops"]) - ->emplace_back( - new DataBalanceOpHandle(local_scopes_, places_, nccl_ctxs_)); + result->Get("ops").emplace_back( + new DataBalanceOpHandle(local_scopes_, places_, nccl_ctxs_)); #else - boost::any_cast(result->attrs["ops"]) - ->emplace_back(new DataBalanceOpHandle(local_scopes_, places_)); + result->Get("ops").emplace_back( + new DataBalanceOpHandle(local_scopes_, places_)); #endif - auto *op_handle = - boost::any_cast(result->attrs["ops"])->back().get(); + auto *op_handle = result->Get("ops").back().get(); for (size_t i = 0; i < places_.size(); ++i) { auto &p = places_[i]; SetCommunicationContext(op_handle, p); for (const std::string &d_name : datas) { - auto &vars = - (*boost::any_cast(result->attrs["vars"]))[i][d_name]; + auto &vars = result->Get("vars")[i][d_name]; PADDLE_ENFORCE(!vars.empty()); op_handle->AddInput(vars.back().get()); auto var = new VarHandle(vars.size(), i, d_name, p); @@ -480,7 +461,7 @@ void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(Graph *result) const { auto *op_handle = new ScaleLossGradOpHandle(local_scopes_.size(), local_scopes_[i], places_[i], communication_dev_ctx); - boost::any_cast(result->attrs["ops"])->emplace_back(op_handle); + result->Get("ops").emplace_back(op_handle); // FIXME: Currently ScaleLossGradOp only use device_count as scale // factor. So it does not depend on any other operators. @@ -499,8 +480,8 @@ void MultiDevSSAGraphBuilder::CreateComputationalOps(Graph *result, for (size_t scope_idx = 0; scope_idx < num_places; ++scope_idx) { auto p = places_[scope_idx]; auto s = local_scopes_[scope_idx]; - boost::any_cast(result->attrs["ops"]) - ->emplace_back(new ComputationOpHandle(op, s, p)); + result->Get("ops").emplace_back( + new ComputationOpHandle(op, s, p)); CreateOpHandleIOs(result, op, scope_idx); } } @@ -509,25 +490,23 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(Graph *result, const std::string &og, int dst_dev_id) const { #ifdef PADDLE_WITH_CUDA - boost::any_cast(result->attrs["ops"]) - ->emplace_back(new ReduceOpHandle(local_scopes_, places_, nccl_ctxs_)); + result->Get("ops").emplace_back( + new ReduceOpHandle(local_scopes_, places_, nccl_ctxs_)); #else - boost::any_cast(result->attrs["ops"]) - ->emplace_back(new ReduceOpHandle(local_scopes_, places_)); + result->Get("ops").emplace_back( + new ReduceOpHandle(local_scopes_, places_)); #endif - auto *op_handle = - boost::any_cast(result->attrs["ops"])->back().get(); + auto *op_handle = result->Get("ops").back().get(); for (size_t i = 0; i < places_.size(); ++i) { auto &p = places_[i]; SetCommunicationContext(op_handle, p); - auto &vars = (*boost::any_cast(result->attrs["vars"]))[i][og]; + auto &vars = result->Get("vars")[i][og]; PADDLE_ENFORCE(!vars.empty()); auto &prev_grad = vars.back(); op_handle->AddInput(prev_grad.get()); } - auto &vars = - (*boost::any_cast(result->attrs["vars"]))[dst_dev_id][og]; + auto &vars = result->Get("vars")[dst_dev_id][og]; auto var = new VarHandle(vars.size(), dst_dev_id, og, places_[dst_dev_id]); vars.emplace_back(var); op_handle->AddOutput(var); @@ -538,12 +517,11 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(Graph *result, // on it. 
void MultiDevSSAGraphBuilder::ConnectOp(Graph *result, OpHandleBase *op, const std::string &prev_op_name) const { - for (auto &prev_op : (*boost::any_cast(result->attrs["ops"]))) { + for (auto &prev_op : result->Get("ops")) { if (prev_op->Name() == prev_op_name) { auto *dep_var = new DummyVarHandle(); prev_op->AddOutput(dep_var); - boost::any_cast(result->attrs["dep_vars"]) - ->emplace(dep_var); + result->Get("dep_vars").emplace(dep_var); op->AddInput(dep_var); } } @@ -579,8 +557,7 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(Graph *result, CreateComputationalOp(result, op, op_dev_id); if (op.Type() == "concat") { - ConnectOp(result, - boost::any_cast(result->attrs["ops"])->back().get(), + ConnectOp(result, result->Get("ops").back().get(), "fetch_barrier"); } } @@ -615,22 +592,16 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(Graph *result, PADDLE_ENFORCE(op_dev_id != -1, "can not find the right place for rpc op: %s", op.Type()); - boost::any_cast(result->attrs["ops"]) - ->emplace_back(new RPCOpHandle(op, local_scopes_[op_dev_id], op.Type(), - places_[op_dev_id])); + result->Get("ops").emplace_back(new RPCOpHandle( + op, local_scopes_[op_dev_id], op.Type(), places_[op_dev_id])); if (op.Type() == "send_barrier") { - ConnectOp(result, - boost::any_cast(result->attrs["ops"])->back().get(), - "send"); + ConnectOp(result, result->Get("ops").back().get(), "send"); } else if (op.Type() == "recv") { - ConnectOp(result, - boost::any_cast(result->attrs["ops"])->back().get(), + ConnectOp(result, result->Get("ops").back().get(), "send_barrier"); } else if (op.Type() == "fetch_barrier") { - ConnectOp(result, - boost::any_cast(result->attrs["ops"])->back().get(), - "recv"); + ConnectOp(result, result->Get("ops").back().get(), "recv"); } else if (op.Type() == "send") { // do nothing } else { diff --git a/paddle/fluid/framework/details/ssa_graph_builder.cc b/paddle/fluid/framework/details/ssa_graph_builder.cc index 2c0873cc87..2508ed0296 100644 --- a/paddle/fluid/framework/details/ssa_graph_builder.cc +++ b/paddle/fluid/framework/details/ssa_graph_builder.cc @@ -18,7 +18,7 @@ namespace paddle { namespace framework { namespace details { void SSAGraphBuilder::PolishGraphToSupportDataHazards(Graph *graph) { - for (auto &var_map : *boost::any_cast(graph->attrs["vars"])) { + for (auto &var_map : graph->Get("vars")) { for (auto &name_pair : var_map) { if (name_pair.second.size() <= 1) { continue; @@ -40,8 +40,7 @@ void SSAGraphBuilder::PolishGraphToSupportDataHazards(Graph *graph) { auto *dep_var = new DummyVarHandle(); read_op->AddOutput(dep_var); write_op->AddInput(dep_var); - boost::any_cast(graph->attrs["dep_vars"]) - ->emplace(dep_var); + graph->Get("dep_vars").emplace(dep_var); } } } @@ -51,8 +50,7 @@ void SSAGraphBuilder::PolishGraphToSupportDataHazards(Graph *graph) { VarHandle *SSAGraphBuilder::CreateOrGetLatestVarHandle( Graph *graph, const std::string &each_var_name, const platform::Place &place, size_t place_offset) { - auto &var_holders = - (*boost::any_cast(graph->attrs["vars"]))[place_offset]; + auto &var_holders = graph->Get("vars")[place_offset]; auto &var_holder = var_holders[each_var_name]; VarHandle *var = nullptr; if (var_holder.empty()) { @@ -68,9 +66,7 @@ void SSAGraphBuilder::CreateOpOutput(Graph *graph, OpHandleBase *op_handle, const std::string &each_var_name, const platform::Place &place, size_t place_offset) { - auto &vars = - (*boost::any_cast(graph->attrs["vars"]))[place_offset] - [each_var_name]; + auto &vars = graph->Get("vars")[place_offset][each_var_name]; size_t version = 
vars.size(); auto var = new VarHandle(version, place_offset, each_var_name, place); vars.emplace_back(var); @@ -78,15 +74,14 @@ void SSAGraphBuilder::CreateOpOutput(Graph *graph, OpHandleBase *op_handle, } void SSAGraphBuilder::AddOutputToLeafOps(Graph *graph) { - GraphOps &all_ops = *boost::any_cast(graph->attrs["ops"]); + GraphOps &all_ops = graph->Get("ops"); for (auto &op : all_ops) { if (!op->Outputs().empty()) { continue; } auto *dummy_leaf = new DummyVarHandle(); - boost::any_cast(graph->attrs["dep_vars"]) - ->emplace(dummy_leaf); + graph->Get("dep_vars").emplace(dummy_leaf); op->AddOutput(dummy_leaf); } } diff --git a/paddle/fluid/framework/ir/graph.h b/paddle/fluid/framework/ir/graph.h index 8996c2d43a..f1de4d740d 100644 --- a/paddle/fluid/framework/ir/graph.h +++ b/paddle/fluid/framework/ir/graph.h @@ -20,18 +20,77 @@ limitations under the License. */ #include #include "paddle/fluid/framework/ir/node.h" +#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/variant.h" namespace paddle { namespace framework { +class Graph; + +template +struct AnyAttr { + public: + explicit AnyAttr(AttrType* attr) : attr_(attr) {} + + AttrType& Get() { return *boost::any_cast(attr_); } + + private: + friend Graph; + + AttrType* Release() { + released_ = true; + return boost::any_cast(attr_); + } + + void Delete() { + if (!released_) { + delete boost::any_cast(attr_); + } + } + + bool released_ = false; + boost::any attr_; +}; + class Graph { public: - std::map attrs; + virtual ~Graph() { + for (auto& attr : attrs) { + attr_dels[attr.first](); + } + attrs.clear(); + attr_dels.clear(); + } + + template + AttrType& Get(const std::string& attr_name) { + return boost::any_cast>(attrs[attr_name]).Get(); + } + + template + void Set(const std::string& attr_name, AttrType* attr) { + AnyAttr any_attr = AnyAttr(attr); + attrs[attr_name] = any_attr; + attr_dels[attr_name] = [&any_attr]() { any_attr.Delete(); }; + } - std::vector inputs; - std::vector outputs; + template + AttrType* Erase(const std::string& attr_name) { + AnyAttr attr_type = + boost::any_cast>(attrs[attr_name]); + attrs.erase(attr_name); + attr_dels.erase(attr_name); + return attr_type.Release(); + } + + std::vector inputs; + std::vector outputs; std::vector> nodes; + std::map attrs; + std::map> attr_dels; + + private: }; } // namespace framework From 68aa5004512cd12e8d81b08d2fe40ddcdfb59f2f Mon Sep 17 00:00:00 2001 From: Xin Pan Date: Wed, 11 Jul 2018 20:18:31 +0800 Subject: [PATCH 06/22] polish attrs --- .../details/multi_devices_graph_builder.cc | 9 +-- .../details/multi_devices_graph_builder.h | 2 +- .../framework/details/ssa_graph_builder.h | 2 +- .../framework/details/ssa_graph_checker.cc | 8 +-- .../framework/details/ssa_graph_checker.h | 4 +- .../framework/details/ssa_graph_printer.cc | 10 ++-- .../framework/details/ssa_graph_printer.h | 6 +- paddle/fluid/framework/ir/graph.h | 60 ++++++------------- paddle/fluid/framework/parallel_executor.cc | 13 +++- 9 files changed, 46 insertions(+), 68 deletions(-) diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.cc b/paddle/fluid/framework/details/multi_devices_graph_builder.cc index 9ac961f1b1..9be4963c91 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.cc @@ -167,7 +167,7 @@ size_t MultiDevSSAGraphBuilder::GetAppropriateDeviceID( return dev_id; } -std::unique_ptr MultiDevSSAGraphBuilder::Build( +std::unique_ptr MultiDevSSAGraphBuilder::Build( const ProgramDesc 
&program) const { std::unique_ptr graph(new Graph); for (auto *var : program.Block(0).AllVars()) { @@ -301,12 +301,7 @@ std::unique_ptr MultiDevSSAGraphBuilder::Build( * Only variables should be the leaves of graph. */ AddOutputToLeafOps(&result); - - std::unique_ptr ssa_graph(new SSAGraph); - ssa_graph->vars_ = std::move(*graph->Erase("vars")); - ssa_graph->ops_ = std::move(*graph->Erase("ops")); - ssa_graph->dep_vars_ = std::move(*graph->Erase("dep_vars")); - return std::move(ssa_graph); + return std::move(graph); } bool MultiDevSSAGraphBuilder::IsSparseGradient(const std::string &og) const { diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.h b/paddle/fluid/framework/details/multi_devices_graph_builder.h index 3d7642f522..b9504665d0 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.h +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.h @@ -47,7 +47,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { const BuildStrategy &strategy); #endif - std::unique_ptr Build(const ProgramDesc &program) const override; + std::unique_ptr Build(const ProgramDesc &program) const override; int GetVarDeviceID(const std::string &varname) const override; private: diff --git a/paddle/fluid/framework/details/ssa_graph_builder.h b/paddle/fluid/framework/details/ssa_graph_builder.h index d5aabb9fd1..56c3077cb3 100644 --- a/paddle/fluid/framework/details/ssa_graph_builder.h +++ b/paddle/fluid/framework/details/ssa_graph_builder.h @@ -38,7 +38,7 @@ class SSAGraphBuilder { public: SSAGraphBuilder() {} virtual ~SSAGraphBuilder() {} - virtual std::unique_ptr Build(const ProgramDesc &program) const = 0; + virtual std::unique_ptr Build(const ProgramDesc &program) const = 0; virtual int GetVarDeviceID(const std::string &var_name) const = 0; DISABLE_COPY_AND_ASSIGN(SSAGraphBuilder); diff --git a/paddle/fluid/framework/details/ssa_graph_checker.cc b/paddle/fluid/framework/details/ssa_graph_checker.cc index da5428946e..c01334ca06 100644 --- a/paddle/fluid/framework/details/ssa_graph_checker.cc +++ b/paddle/fluid/framework/details/ssa_graph_checker.cc @@ -20,7 +20,7 @@ namespace paddle { namespace framework { namespace details { -bool SSAGraghBuilderWithChecker::IsValidGraph(const SSAGraph *graph) const { +bool SSAGraghBuilderWithChecker::IsValidGraph(const Graph *graph) const { std::unordered_map pending_ops; std::unordered_set pending_vars; std::unordered_set ready_vars; @@ -33,7 +33,7 @@ bool SSAGraghBuilderWithChecker::IsValidGraph(const SSAGraph *graph) const { } }; - for (auto &var_map : graph->vars_) { + for (auto &var_map : graph->Get("vars")) { for (auto &name_pair : var_map) { for (auto &version_pair : name_pair.second) { insert_pending_var(version_pair.get()); @@ -41,11 +41,11 @@ bool SSAGraghBuilderWithChecker::IsValidGraph(const SSAGraph *graph) const { } } - for (auto &var : graph->dep_vars_) { + for (auto &var : graph->Get("dep_vars")) { insert_pending_var(var.get()); } - for (auto &op : graph->ops_) { + for (auto &op : graph->Get("ops")) { if (op->Inputs().empty()) { ready_ops.insert(op.get()); } else { diff --git a/paddle/fluid/framework/details/ssa_graph_checker.h b/paddle/fluid/framework/details/ssa_graph_checker.h index 331aa9d2b5..20fa432a8b 100644 --- a/paddle/fluid/framework/details/ssa_graph_checker.h +++ b/paddle/fluid/framework/details/ssa_graph_checker.h @@ -29,7 +29,7 @@ class SSAGraghBuilderWithChecker : public SSAGraphBuilder { std::unique_ptr&& builder) : builder_(std::move(builder)) {} - std::unique_ptr Build(const 
ProgramDesc& program) const override { + std::unique_ptr Build(const ProgramDesc& program) const override { auto graph = builder_->Build(program); PADDLE_ENFORCE(IsValidGraph(graph.get())); return graph; @@ -39,7 +39,7 @@ class SSAGraghBuilderWithChecker : public SSAGraphBuilder { return builder_->GetVarDeviceID(var_name); } - bool IsValidGraph(const SSAGraph* graph) const; + bool IsValidGraph(const Graph* graph) const; private: std::unique_ptr builder_; diff --git a/paddle/fluid/framework/details/ssa_graph_printer.cc b/paddle/fluid/framework/details/ssa_graph_printer.cc index 22a40ca4b2..412b0a6ff2 100644 --- a/paddle/fluid/framework/details/ssa_graph_printer.cc +++ b/paddle/fluid/framework/details/ssa_graph_printer.cc @@ -21,8 +21,8 @@ namespace framework { namespace details { template -static inline void IterAllVar(const SSAGraph &graph, Callback callback) { - for (auto &each : graph.vars_) { +static inline void IterAllVar(const Graph &graph, Callback callback) { + for (auto &each : graph.Get("vars")) { for (auto &pair1 : each) { for (auto &pair2 : pair1.second) { callback(*pair2); @@ -30,12 +30,12 @@ static inline void IterAllVar(const SSAGraph &graph, Callback callback) { } } - for (auto &var : graph.dep_vars_) { + for (auto &var : graph.Get("dep_vars")) { callback(*var); } } -void GraphvizSSAGraphPrinter::Print(const SSAGraph &graph, +void GraphvizSSAGraphPrinter::Print(const Graph &graph, std::ostream &sout) const { size_t var_id = 0; std::unordered_map vars; @@ -61,7 +61,7 @@ void GraphvizSSAGraphPrinter::Print(const SSAGraph &graph, }); size_t op_id = 0; - for (auto &op : graph.ops_) { + for (auto &op : graph.Get("ops")) { std::string op_name = "op_" + std::to_string(op_id++); sout << op_name << " [label=\"" << op->Name() << "\", shape=rect]" << std::endl; diff --git a/paddle/fluid/framework/details/ssa_graph_printer.h b/paddle/fluid/framework/details/ssa_graph_printer.h index 09b0333ef2..da98685a21 100644 --- a/paddle/fluid/framework/details/ssa_graph_printer.h +++ b/paddle/fluid/framework/details/ssa_graph_printer.h @@ -25,12 +25,12 @@ struct SSAGraph; class SSAGraphPrinter { public: virtual ~SSAGraphPrinter() {} - virtual void Print(const SSAGraph& graph, std::ostream& sout) const = 0; + virtual void Print(const Graph& graph, std::ostream& sout) const = 0; }; class GraphvizSSAGraphPrinter : public SSAGraphPrinter { public: - void Print(const SSAGraph& graph, std::ostream& sout) const override; + void Print(const Graph& graph, std::ostream& sout) const override; }; class SSAGraghBuilderWithPrinter : public SSAGraphBuilder { @@ -50,7 +50,7 @@ class SSAGraghBuilderWithPrinter : public SSAGraphBuilder { stream_ptr_(std::move(sout)), stream_ref_(*stream_ptr_) {} - std::unique_ptr Build(const ProgramDesc& program) const override { + std::unique_ptr Build(const ProgramDesc& program) const override { auto graph = builder_->Build(program); printer_->Print(*graph, stream_ref_); return graph; diff --git a/paddle/fluid/framework/ir/graph.h b/paddle/fluid/framework/ir/graph.h index f1de4d740d..21b9fa943e 100644 --- a/paddle/fluid/framework/ir/graph.h +++ b/paddle/fluid/framework/ir/graph.h @@ -26,71 +26,45 @@ limitations under the License. 
*/ namespace paddle { namespace framework { -class Graph; - -template -struct AnyAttr { - public: - explicit AnyAttr(AttrType* attr) : attr_(attr) {} - - AttrType& Get() { return *boost::any_cast(attr_); } - - private: - friend Graph; - - AttrType* Release() { - released_ = true; - return boost::any_cast(attr_); - } - - void Delete() { - if (!released_) { - delete boost::any_cast(attr_); - } - } - - bool released_ = false; - boost::any attr_; -}; - class Graph { public: virtual ~Graph() { - for (auto& attr : attrs) { - attr_dels[attr.first](); + for (auto& attr : attrs_) { + attr_dels_[attr.first](); } - attrs.clear(); - attr_dels.clear(); + attrs_.clear(); + attr_dels_.clear(); } template - AttrType& Get(const std::string& attr_name) { - return boost::any_cast>(attrs[attr_name]).Get(); + AttrType& Get(const std::string& attr_name) const { + return *boost::any_cast(attrs_.at(attr_name)); } template void Set(const std::string& attr_name, AttrType* attr) { - AnyAttr any_attr = AnyAttr(attr); - attrs[attr_name] = any_attr; - attr_dels[attr_name] = [&any_attr]() { any_attr.Delete(); }; + attrs_[attr_name] = attr; + attr_dels_[attr_name] = [attr, attr_name]() { + VLOG(3) << "deleting " << attr_name; + delete attr; + }; } template AttrType* Erase(const std::string& attr_name) { - AnyAttr attr_type = - boost::any_cast>(attrs[attr_name]); - attrs.erase(attr_name); - attr_dels.erase(attr_name); - return attr_type.Release(); + AttrType* attr = boost::any_cast(attrs_[attr_name]); + attrs_.erase(attr_name); + attr_dels_.erase(attr_name); + return attr; } std::vector inputs; std::vector outputs; std::vector> nodes; - std::map attrs; - std::map> attr_dels; private: + std::map attrs_; + std::map> attr_dels_; }; } // namespace framework diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc index 9a72e1baa3..3db2d9cdc4 100644 --- a/paddle/fluid/framework/parallel_executor.cc +++ b/paddle/fluid/framework/parallel_executor.cc @@ -18,6 +18,8 @@ limitations under the License. 
*/ #include #include +#include "paddle/fluid/framework/details/ssa_graph.h" + #ifdef PADDLE_WITH_CUDA #include "paddle/fluid/platform/nccl_helper.h" #endif @@ -131,9 +133,16 @@ ParallelExecutor::ParallelExecutor( } builder_ = builder_factory.Create(); + std::unique_ptr graph = builder_->Build(main_program); + + std::unique_ptr ssa_graph(new details::SSAGraph); + ssa_graph->vars_ = std::move(graph->Get("vars")); + ssa_graph->ops_ = std::move(graph->Get("ops")); + ssa_graph->dep_vars_ = + std::move(graph->Get("dep_vars")); + member_->executor_.reset(new details::ThreadedSSAGraphExecutor( - exec_strategy, member_->local_scopes_, places, - builder_->Build(main_program))); + exec_strategy, member_->local_scopes_, places, std::move(ssa_graph))); member_->executor_.reset(new details::ScopeBufferedSSAGraphExecutor( exec_strategy, member_->local_scopes_, std::move(var_infos), From 7231ef6b68334ef095b643b565a8b2e52806c150 Mon Sep 17 00:00:00 2001 From: Xin Pan Date: Wed, 11 Jul 2018 21:43:42 +0800 Subject: [PATCH 07/22] tmp --- paddle/fluid/framework/ir/graph.h | 6 ++-- paddle/fluid/framework/ir/node.h | 46 +++++++++++++++++++++++++++---- 2 files changed, 44 insertions(+), 8 deletions(-) diff --git a/paddle/fluid/framework/ir/graph.h b/paddle/fluid/framework/ir/graph.h index 21b9fa943e..72602840fc 100644 --- a/paddle/fluid/framework/ir/graph.h +++ b/paddle/fluid/framework/ir/graph.h @@ -58,9 +58,9 @@ class Graph { return attr; } - std::vector inputs; - std::vector outputs; - std::vector> nodes; + std::vector inputs; + std::vector outputs; + std::vector> nodes; private: std::map attrs_; diff --git a/paddle/fluid/framework/ir/node.h b/paddle/fluid/framework/ir/node.h index 9a280afb3b..0fd8048390 100644 --- a/paddle/fluid/framework/ir/node.h +++ b/paddle/fluid/framework/ir/node.h @@ -15,6 +15,7 @@ limitations under the License. */ #pragma once #include +#include #include #include #include @@ -23,13 +24,23 @@ limitations under the License. 
*/ namespace paddle { namespace framework { +namespace ir { class Node { public: enum class Type { kNone = -1, kOperation, kVariable }; - Node() {} - virtual ~Node() {} + Node(const std::string& name, Type type) : name_(name), type_(type) {} + + virtual ~Node() { + for (auto& attr : attrs_) { + if (attr_dels_.find(attr.first) != attr_dels_.end()) { + attr_dels_[attr.first](); + } + } + attr_dels_.clear(); + attrs_.clear(); + } int64_t ID() const { return id_; } @@ -43,17 +54,42 @@ class Node { Type NodeType() const { return type_; } - std::vector inputs; - std::vector outputs; + template + void Set(const std::string& name, AttrType attr) { + attrs_[name] = attr; + } + + template + void Set(const std::string& name, AttrType* attr, + std::function attr_del) { + attrs_[name] = attr; + attr_dels_[name] = attr_del; + } + + std::vector inputs; + std::vector outputs; protected: - std::map> attrs_; + std::map attrs_; + std::map> attr_dels_; int64_t id_ = 0; std::string name_; Type type_; + private: DISABLE_COPY_AND_ASSIGN(Node); }; +class Variable : public Node { + public: + explicit Variable(const std::string& name) : Node(name, Type::kVariable) {} +}; + +class Operation : public Node { + public: + explicit Operation(const std::string& name) : Node(name, Type::kOperation) {} +}; + +} // namespace ir } // namespace framework } // namespace paddle From af79b192077a6485fc90be4112f80acce8fb748b Mon Sep 17 00:00:00 2001 From: Xin Pan Date: Thu, 12 Jul 2018 09:47:37 +0800 Subject: [PATCH 08/22] add a simple program to graph --- paddle/fluid/framework/CMakeLists.txt | 2 +- .../framework/details/multi_devices_graph_builder.cc | 4 ++-- .../framework/details/multi_devices_graph_builder.h | 2 +- paddle/fluid/framework/details/ssa_graph_builder.h | 2 +- paddle/fluid/framework/details/ssa_graph_checker.h | 8 ++++---- paddle/fluid/framework/details/ssa_graph_printer.h | 8 ++++---- paddle/fluid/framework/ir/graph.cc | 9 ++++++++- paddle/fluid/framework/ir/graph.h | 8 ++++++++ paddle/fluid/framework/ir/pass.h | 6 ------ paddle/fluid/framework/parallel_executor.cc | 2 +- 10 files changed, 30 insertions(+), 21 deletions(-) diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index bae8f51bcf..de06c860f5 100644 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -94,7 +94,7 @@ else() endif() -cc_library(parallel_executor SRCS parallel_executor.cc DEPS ssa_graph_builder_factory threaded_ssa_graph_executor scope_buffered_ssa_graph_executor) +cc_library(parallel_executor SRCS parallel_executor.cc DEPS ssa_graph_builder_factory threaded_ssa_graph_executor scope_buffered_ssa_graph_executor graph) cc_library(prune SRCS prune.cc DEPS framework_proto) cc_test(prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context) diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.cc b/paddle/fluid/framework/details/multi_devices_graph_builder.cc index 9be4963c91..0a95370419 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.cc @@ -168,8 +168,8 @@ size_t MultiDevSSAGraphBuilder::GetAppropriateDeviceID( } std::unique_ptr MultiDevSSAGraphBuilder::Build( - const ProgramDesc &program) const { - std::unique_ptr graph(new Graph); + std::unique_ptr graph) const { + const ProgramDesc &program = graph->Program(); for (auto *var : program.Block(0).AllVars()) { all_vars_.emplace(var->Name(), var); } diff --git 
a/paddle/fluid/framework/details/multi_devices_graph_builder.h b/paddle/fluid/framework/details/multi_devices_graph_builder.h index b9504665d0..248ea8ea62 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.h +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.h @@ -47,7 +47,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { const BuildStrategy &strategy); #endif - std::unique_ptr Build(const ProgramDesc &program) const override; + std::unique_ptr Build(std::unique_ptr graph) const override; int GetVarDeviceID(const std::string &varname) const override; private: diff --git a/paddle/fluid/framework/details/ssa_graph_builder.h b/paddle/fluid/framework/details/ssa_graph_builder.h index 56c3077cb3..4fbf036241 100644 --- a/paddle/fluid/framework/details/ssa_graph_builder.h +++ b/paddle/fluid/framework/details/ssa_graph_builder.h @@ -38,7 +38,7 @@ class SSAGraphBuilder { public: SSAGraphBuilder() {} virtual ~SSAGraphBuilder() {} - virtual std::unique_ptr Build(const ProgramDesc &program) const = 0; + virtual std::unique_ptr Build(std::unique_ptr graph) const = 0; virtual int GetVarDeviceID(const std::string &var_name) const = 0; DISABLE_COPY_AND_ASSIGN(SSAGraphBuilder); diff --git a/paddle/fluid/framework/details/ssa_graph_checker.h b/paddle/fluid/framework/details/ssa_graph_checker.h index 20fa432a8b..7078b778be 100644 --- a/paddle/fluid/framework/details/ssa_graph_checker.h +++ b/paddle/fluid/framework/details/ssa_graph_checker.h @@ -29,10 +29,10 @@ class SSAGraghBuilderWithChecker : public SSAGraphBuilder { std::unique_ptr&& builder) : builder_(std::move(builder)) {} - std::unique_ptr Build(const ProgramDesc& program) const override { - auto graph = builder_->Build(program); - PADDLE_ENFORCE(IsValidGraph(graph.get())); - return graph; + std::unique_ptr Build(std::unique_ptr graph) const override { + auto new_graph = builder_->Build(std::move(graph)); + PADDLE_ENFORCE(IsValidGraph(new_graph.get())); + return new_graph; } int GetVarDeviceID(const std::string& var_name) const override { diff --git a/paddle/fluid/framework/details/ssa_graph_printer.h b/paddle/fluid/framework/details/ssa_graph_printer.h index da98685a21..0bd2b10eda 100644 --- a/paddle/fluid/framework/details/ssa_graph_printer.h +++ b/paddle/fluid/framework/details/ssa_graph_printer.h @@ -50,10 +50,10 @@ class SSAGraghBuilderWithPrinter : public SSAGraphBuilder { stream_ptr_(std::move(sout)), stream_ref_(*stream_ptr_) {} - std::unique_ptr Build(const ProgramDesc& program) const override { - auto graph = builder_->Build(program); - printer_->Print(*graph, stream_ref_); - return graph; + std::unique_ptr Build(std::unique_ptr graph) const override { + auto new_graph = builder_->Build(std::move(graph)); + printer_->Print(*new_graph, stream_ref_); + return new_graph; } int GetVarDeviceID(const std::string& var_name) const override { diff --git a/paddle/fluid/framework/ir/graph.cc b/paddle/fluid/framework/ir/graph.cc index b5c5ba7c14..28ad4efc71 100644 --- a/paddle/fluid/framework/ir/graph.cc +++ b/paddle/fluid/framework/ir/graph.cc @@ -15,5 +15,12 @@ limitations under the License. 
*/ #include "paddle/fluid/framework/ir/graph.h" namespace paddle { -namespace framework {} // namespace framework +namespace framework { + +std::unique_ptr ProgramToGraph(const ProgramDesc &program) { + std::unique_ptr graph(new Graph(program)); + return std::move(graph); +} + +} // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/ir/graph.h b/paddle/fluid/framework/ir/graph.h index 72602840fc..e83cb5a82a 100644 --- a/paddle/fluid/framework/ir/graph.h +++ b/paddle/fluid/framework/ir/graph.h @@ -20,6 +20,7 @@ limitations under the License. */ #include #include "paddle/fluid/framework/ir/node.h" +#include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/variant.h" @@ -28,6 +29,8 @@ namespace framework { class Graph { public: + explicit Graph(const ProgramDesc& program) : program_(program) {} + virtual ~Graph() { for (auto& attr : attrs_) { attr_dels_[attr.first](); @@ -36,6 +39,8 @@ class Graph { attr_dels_.clear(); } + const ProgramDesc& Program() const { return program_; } + template AttrType& Get(const std::string& attr_name) const { return *boost::any_cast(attrs_.at(attr_name)); @@ -63,9 +68,12 @@ class Graph { std::vector> nodes; private: + const ProgramDesc& program_; std::map attrs_; std::map> attr_dels_; }; +std::unique_ptr ProgramToGraph(const ProgramDesc& program); + } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/ir/pass.h b/paddle/fluid/framework/ir/pass.h index 087ebb8709..2fc26c053f 100644 --- a/paddle/fluid/framework/ir/pass.h +++ b/paddle/fluid/framework/ir/pass.h @@ -30,11 +30,5 @@ class Pass { } }; -std::unique_ptr ProgramToGraph(const ProgramDesc& program) { - std::unique_ptr g(new Graph); - - return std::move(g); -} - } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc index 3db2d9cdc4..42bbd2b3ff 100644 --- a/paddle/fluid/framework/parallel_executor.cc +++ b/paddle/fluid/framework/parallel_executor.cc @@ -133,7 +133,7 @@ ParallelExecutor::ParallelExecutor( } builder_ = builder_factory.Create(); - std::unique_ptr graph = builder_->Build(main_program); + std::unique_ptr graph = builder_->Build(ProgramToGraph(main_program)); std::unique_ptr ssa_graph(new details::SSAGraph); ssa_graph->vars_ = std::move(graph->Get("vars")); From 9605fcd124ae6a3cdad171d2d61107e9cabe4c2f Mon Sep 17 00:00:00 2001 From: Xin Pan Date: Thu, 12 Jul 2018 11:26:08 +0800 Subject: [PATCH 09/22] all graphs --- paddle/fluid/framework/details/ssa_graph_checker.h | 1 - paddle/fluid/framework/details/ssa_graph_printer.h | 2 +- .../details/threaded_ssa_graph_executor.cc | 13 +++++++------ .../framework/details/threaded_ssa_graph_executor.h | 5 +++-- paddle/fluid/framework/parallel_executor.cc | 8 +------- 5 files changed, 12 insertions(+), 17 deletions(-) diff --git a/paddle/fluid/framework/details/ssa_graph_checker.h b/paddle/fluid/framework/details/ssa_graph_checker.h index 7078b778be..2c8b2e13c5 100644 --- a/paddle/fluid/framework/details/ssa_graph_checker.h +++ b/paddle/fluid/framework/details/ssa_graph_checker.h @@ -21,7 +21,6 @@ namespace paddle { namespace framework { namespace details { -struct SSAGraph; class SSAGraghBuilderWithChecker : public SSAGraphBuilder { public: diff --git a/paddle/fluid/framework/details/ssa_graph_printer.h b/paddle/fluid/framework/details/ssa_graph_printer.h index 0bd2b10eda..35f2a1b4f0 100644 --- a/paddle/fluid/framework/details/ssa_graph_printer.h 
+++ b/paddle/fluid/framework/details/ssa_graph_printer.h @@ -21,7 +21,7 @@ namespace paddle { namespace framework { namespace details { -struct SSAGraph; + class SSAGraphPrinter { public: virtual ~SSAGraphPrinter() {} diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc index 07097c7e75..ed8e38039e 100644 --- a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc @@ -14,13 +14,14 @@ #include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h" +#include "paddle/fluid/framework/details/ssa_graph_builder.h" + namespace paddle { namespace framework { namespace details { ThreadedSSAGraphExecutor::ThreadedSSAGraphExecutor( const ExecutionStrategy &strategy, const std::vector &local_scopes, - const std::vector &places, - std::unique_ptr &&graph) + const std::vector &places, std::unique_ptr &&graph) : graph_(std::move(graph)), pool_(strategy.num_threads_ >= 2 ? new ::ThreadPool(strategy.num_threads_) : nullptr), @@ -43,18 +44,18 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( std::unordered_set delayed_ops; // Transform SSAGraph to pending_ops & pending_vars - for (auto &var_map : graph_->vars_) { + for (auto &var_map : graph_->Get("vars")) { for (auto &name_pair : var_map) { for (auto &version_pair : name_pair.second) { InsertPendingVar(&pending_vars, &ready_vars, version_pair.get()); } } } - for (auto &var : graph_->dep_vars_) { + for (auto &var : graph_->Get("dep_vars")) { InsertPendingVar(&pending_vars, &ready_vars, var.get()); } - for (auto &op : graph_->ops_) { + for (auto &op : graph_->Get("ops")) { if (op->Inputs().empty()) { // Special case, Op has no input. ready_ops.insert(op.get()); } else { @@ -158,7 +159,7 @@ void ThreadedSSAGraphExecutor::InsertFetchOps( std::unordered_map> fetched_vars; for (auto &fetch_var_name : fetch_tensors) { - for (auto &var_map : graph_->vars_) { + for (auto &var_map : graph_->Get("vars")) { auto it = var_map.find(fetch_var_name); if (it != var_map.end()) { fetched_vars[fetch_var_name].push_back(it->second.rbegin()->get()); diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.h b/paddle/fluid/framework/details/threaded_ssa_graph_executor.h index 09973b7a72..7d0aaf2ddc 100644 --- a/paddle/fluid/framework/details/threaded_ssa_graph_executor.h +++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.h @@ -27,6 +27,7 @@ #include "paddle/fluid/framework/details/execution_strategy.h" #include "paddle/fluid/framework/details/fetch_op_handle.h" #include "paddle/fluid/framework/details/ssa_graph_executor.h" +#include "paddle/fluid/framework/ir/graph.h" namespace paddle { namespace framework { @@ -39,7 +40,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor { ThreadedSSAGraphExecutor(const ExecutionStrategy &strategy, const std::vector &local_scopes, const std::vector &places, - std::unique_ptr &&graph); + std::unique_ptr &&graph); // Run a SSAGraph by a thread pool // Use topological sort algorithm @@ -52,7 +53,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor { details::OpHandleBase *op); private: - std::unique_ptr graph_; + std::unique_ptr graph_; std::unique_ptr<::ThreadPool> pool_; std::vector local_scopes_; std::vector places_; diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc index 42bbd2b3ff..d30aba07a0 100644 --- a/paddle/fluid/framework/parallel_executor.cc +++ 
b/paddle/fluid/framework/parallel_executor.cc @@ -135,14 +135,8 @@ ParallelExecutor::ParallelExecutor( builder_ = builder_factory.Create(); std::unique_ptr graph = builder_->Build(ProgramToGraph(main_program)); - std::unique_ptr ssa_graph(new details::SSAGraph); - ssa_graph->vars_ = std::move(graph->Get("vars")); - ssa_graph->ops_ = std::move(graph->Get("ops")); - ssa_graph->dep_vars_ = - std::move(graph->Get("dep_vars")); - member_->executor_.reset(new details::ThreadedSSAGraphExecutor( - exec_strategy, member_->local_scopes_, places, std::move(ssa_graph))); + exec_strategy, member_->local_scopes_, places, std::move(graph))); member_->executor_.reset(new details::ScopeBufferedSSAGraphExecutor( exec_strategy, member_->local_scopes_, std::move(var_infos), From 37e514432b0d6453906aaaacbf54d8ae51ebddc4 Mon Sep 17 00:00:00 2001 From: Xin Pan Date: Thu, 12 Jul 2018 15:30:21 +0800 Subject: [PATCH 10/22] op compose node and update nodes. --- .../framework/details/all_reduce_op_handle.cc | 13 +++- .../framework/details/all_reduce_op_handle.h | 4 +- .../framework/details/broadcast_op_handle.h | 11 ++- .../details/broadcast_op_handle_test.cc | 32 +++++--- .../details/computation_op_handle.cc | 11 +-- .../framework/details/computation_op_handle.h | 2 +- .../details/data_balance_op_handle.cc | 8 +- .../details/data_balance_op_handle.h | 4 +- .../framework/details/fetch_op_handle.cc | 13 ++-- .../fluid/framework/details/fetch_op_handle.h | 2 +- .../framework/details/fuse_vars_op_handle.h | 6 +- .../framework/details/gather_op_handle.cc | 5 +- .../framework/details/gather_op_handle.h | 2 +- .../details/gather_op_handle_test.cc | 22 ++++-- .../details/multi_devices_graph_builder.cc | 73 ++++++++++++------- .../fluid/framework/details/op_handle_base.cc | 12 +-- .../fluid/framework/details/op_handle_base.h | 6 +- .../framework/details/reduce_op_handle.h | 11 ++- .../details/reduce_op_handle_test.cc | 26 ++++--- .../fluid/framework/details/rpc_op_handle.cc | 9 ++- .../fluid/framework/details/rpc_op_handle.h | 5 +- .../details/scale_loss_grad_op_handle.cc | 8 +- .../details/scale_loss_grad_op_handle.h | 3 +- .../framework/details/ssa_graph_builder.cc | 18 +++-- .../framework/details/ssa_graph_checker.cc | 4 +- .../details/threaded_ssa_graph_executor.cc | 19 +++-- .../details/threaded_ssa_graph_executor.h | 1 + paddle/fluid/framework/details/var_handle.h | 49 ++++++++++++- paddle/fluid/framework/ir/node.h | 36 ++------- 29 files changed, 262 insertions(+), 153 deletions(-) diff --git a/paddle/fluid/framework/details/all_reduce_op_handle.cc b/paddle/fluid/framework/details/all_reduce_op_handle.cc index b335d3a0d3..700c73c745 100644 --- a/paddle/fluid/framework/details/all_reduce_op_handle.cc +++ b/paddle/fluid/framework/details/all_reduce_op_handle.cc @@ -23,10 +23,14 @@ namespace framework { namespace details { #ifdef PADDLE_WITH_CUDA -AllReduceOpHandle::AllReduceOpHandle(const std::vector &local_scopes, +AllReduceOpHandle::AllReduceOpHandle(ir::Node *node, + const std::vector &local_scopes, const std::vector &places, const platform::NCCLContextMap *ctxs) - : local_scopes_(local_scopes), places_(places), nccl_ctxs_(ctxs) { + : OpHandleBase(node), + local_scopes_(local_scopes), + places_(places), + nccl_ctxs_(ctxs) { if (nccl_ctxs_) { for (auto &p : places_) { this->dev_ctxes_[p] = nccl_ctxs_->DevCtx(p); @@ -34,9 +38,10 @@ AllReduceOpHandle::AllReduceOpHandle(const std::vector &local_scopes, } } #else -AllReduceOpHandle::AllReduceOpHandle(const std::vector &local_scopes, 
+AllReduceOpHandle::AllReduceOpHandle(ir::Node *node, + const std::vector &local_scopes, const std::vector &places) - : local_scopes_(local_scopes), places_(places) {} + : OpHandleBase(node), local_scopes_(local_scopes), places_(places) {} #endif void AllReduceOpHandle::RunImpl() { diff --git a/paddle/fluid/framework/details/all_reduce_op_handle.h b/paddle/fluid/framework/details/all_reduce_op_handle.h index fdd250b0d3..f6ef3a1367 100644 --- a/paddle/fluid/framework/details/all_reduce_op_handle.h +++ b/paddle/fluid/framework/details/all_reduce_op_handle.h @@ -30,11 +30,11 @@ namespace details { struct AllReduceOpHandle : public OpHandleBase { #ifdef PADDLE_WITH_CUDA - AllReduceOpHandle(const std::vector &local_scopes, + AllReduceOpHandle(ir::Node *node, const std::vector &local_scopes, const std::vector &places, const platform::NCCLContextMap *ctxs); #else - AllReduceOpHandle(const std::vector &local_scopes, + AllReduceOpHandle(ir::Node *node, const std::vector &local_scopes, const std::vector &places); #endif std::string Name() const override; diff --git a/paddle/fluid/framework/details/broadcast_op_handle.h b/paddle/fluid/framework/details/broadcast_op_handle.h index 8036f756b6..fe4e733e43 100644 --- a/paddle/fluid/framework/details/broadcast_op_handle.h +++ b/paddle/fluid/framework/details/broadcast_op_handle.h @@ -35,10 +35,13 @@ namespace details { struct BroadcastOpHandle : public OpHandleBase { public: #ifdef PADDLE_WITH_CUDA - BroadcastOpHandle(const std::vector &local_scopes, + BroadcastOpHandle(ir::Node *node, const std::vector &local_scopes, const std::vector &places, const platform::NCCLContextMap *nccl_ctxs) - : local_scopes_(local_scopes), places_(places), nccl_ctxs_(nccl_ctxs) { + : OpHandleBase(node), + local_scopes_(local_scopes), + places_(places), + nccl_ctxs_(nccl_ctxs) { if (nccl_ctxs_) { for (auto &p_ctx : nccl_ctxs_->contexts_) { dev_ctxes_[platform::CUDAPlace(p_ctx.first)] = p_ctx.second.ctx_.get(); @@ -46,9 +49,9 @@ struct BroadcastOpHandle : public OpHandleBase { } } #else - BroadcastOpHandle(const std::vector &local_scopes, + BroadcastOpHandle(ir::Node *node, const std::vector &local_scopes, const std::vector &places) - : local_scopes_(local_scopes), places_(places) {} + : OpHandleBase(node), local_scopes_(local_scopes), places_(places) {} #endif std::string Name() const override; diff --git a/paddle/fluid/framework/details/broadcast_op_handle_test.cc b/paddle/fluid/framework/details/broadcast_op_handle_test.cc index c6e923ef77..90ee3f7d93 100644 --- a/paddle/fluid/framework/details/broadcast_op_handle_test.cc +++ b/paddle/fluid/framework/details/broadcast_op_handle_test.cc @@ -96,48 +96,56 @@ struct TestBroadcastOpHandle { } param_scopes_[input_scope_idx]->Var("input"); + std::unique_ptr n(new ir::Node(ir::Node::Type::kOperation)); if (use_gpu_) { #ifdef PADDLE_WITH_CUDA - op_handle_.reset( - new BroadcastOpHandle(local_scopes_, gpu_list_, nccl_ctxs_.get())); + op_handle_.reset(new BroadcastOpHandle(n.get(), local_scopes_, gpu_list_, + nccl_ctxs_.get())); #else PADDLE_THROW("CUDA is not support."); #endif } else { #ifdef PADDLE_WITH_CUDA - op_handle_.reset( - new BroadcastOpHandle(local_scopes_, gpu_list_, nccl_ctxs_.get())); + op_handle_.reset(new BroadcastOpHandle(n.get(), local_scopes_, gpu_list_, + nccl_ctxs_.get())); #else - op_handle_.reset(new BroadcastOpHandle(local_scopes_, gpu_list_)); + op_handle_.reset( + new BroadcastOpHandle(n.get(), local_scopes_, gpu_list_)); #endif } - auto* in_var_handle = - new VarHandle(1, input_scope_idx, "input", 
gpu_list_[input_scope_idx]); + std::unique_ptr v(new ir::Node(ir::Node::Type::kVariable)); + auto* in_var_handle = new VarHandle(v.get(), 1, input_scope_idx, "input", + gpu_list_[input_scope_idx]); vars_.emplace_back(in_var_handle); op_handle_->AddInput(in_var_handle); // add dummy var - vars_.emplace_back(new DummyVarHandle()); + + std::unique_ptr v2(new ir::Node(ir::Node::Type::kVariable)); + vars_.emplace_back(new DummyVarHandle(v2.get())); DummyVarHandle* dummy_var_handle = static_cast(vars_.back().get()); - dummy_var_handle->generated_op_ = nullptr; + dummy_var_handle->ClearGeneratedOp(); op_handle_->AddInput(dummy_var_handle); for (size_t j = 0; j < gpu_list_.size(); ++j) { if (!use_gpu_) { op_handle_->SetDeviceContext(gpu_list_[j], ctxs_[j].get()); } - VarHandle* out_var_handle = new VarHandle(2, j, "out", gpu_list_[j]); + std::unique_ptr v3(new ir::Node(ir::Node::Type::kVariable)); + VarHandle* out_var_handle = + new VarHandle(v3.get(), 2, j, "out", gpu_list_[j]); vars_.emplace_back(out_var_handle); op_handle_->AddOutput(out_var_handle); } // add dummy var - vars_.emplace_back(new DummyVarHandle()); + std::unique_ptr v4(new ir::Node(ir::Node::Type::kVariable)); + vars_.emplace_back(new DummyVarHandle(v4.get())); DummyVarHandle* out_dummy_var_handle = static_cast(vars_.back().get()); - out_dummy_var_handle->generated_op_ = nullptr; + out_dummy_var_handle->ClearGeneratedOp(); op_handle_->AddOutput(out_dummy_var_handle); } diff --git a/paddle/fluid/framework/details/computation_op_handle.cc b/paddle/fluid/framework/details/computation_op_handle.cc index df05bb0633..16ad30d491 100644 --- a/paddle/fluid/framework/details/computation_op_handle.cc +++ b/paddle/fluid/framework/details/computation_op_handle.cc @@ -19,9 +19,10 @@ namespace paddle { namespace framework { namespace details { -ComputationOpHandle::ComputationOpHandle(const OpDesc &op_desc, Scope *scope, - platform::Place place) - : op_(framework::OpRegistry::CreateOp(op_desc)), +ComputationOpHandle::ComputationOpHandle(ir::Node *node, const OpDesc &op_desc, + Scope *scope, platform::Place place) + : OpHandleBase(node), + op_(framework::OpRegistry::CreateOp(op_desc)), scope_(scope), place_(place) {} @@ -35,8 +36,8 @@ void ComputationOpHandle::RunImpl() { bool ComputationOpHandle::NeedWait(VarHandleBase *in_var) { bool need_wait = - in_var && in_var->generated_op_ && - in_var->generated_op_->DeviceContext(place_) != dev_ctxes_[place_]; + in_var && in_var->GeneratedOp() && + in_var->GeneratedOp()->DeviceContext(place_) != dev_ctxes_[place_]; return need_wait; } diff --git a/paddle/fluid/framework/details/computation_op_handle.h b/paddle/fluid/framework/details/computation_op_handle.h index f048f973fd..9ca1d927b8 100644 --- a/paddle/fluid/framework/details/computation_op_handle.h +++ b/paddle/fluid/framework/details/computation_op_handle.h @@ -28,7 +28,7 @@ namespace framework { namespace details { struct ComputationOpHandle : public OpHandleBase { public: - ComputationOpHandle(const OpDesc &op_desc, Scope *scope, + ComputationOpHandle(ir::Node *node, const OpDesc &op_desc, Scope *scope, platform::Place place); std::string Name() const override; diff --git a/paddle/fluid/framework/details/data_balance_op_handle.cc b/paddle/fluid/framework/details/data_balance_op_handle.cc index 68896c8ac1..525d243224 100644 --- a/paddle/fluid/framework/details/data_balance_op_handle.cc +++ b/paddle/fluid/framework/details/data_balance_op_handle.cc @@ -22,10 +22,10 @@ namespace details { #ifdef PADDLE_WITH_CUDA 
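// Editor's note: a minimal sketch (not part of the patch) of the constructor
// pattern this commit applies to every op handle. The ir::Node a handle
// represents is created by the caller and owned by the Graph (or, in tests, by
// a local vector of unique_ptrs); the handle only forwards the raw pointer to
// OpHandleBase so that AddInput/AddOutput can mirror handle-level edges into
// node->inputs / node->outputs. "ToyOpHandle" is a hypothetical class used
// purely for illustration and assumes the surrounding details headers.
struct ToyOpHandle : public OpHandleBase {
  ToyOpHandle(ir::Node *node, const std::vector<Scope *> &local_scopes,
              const std::vector<platform::Place> &places)
      : OpHandleBase(node),  // forward the node; no ownership is taken here
        local_scopes_(local_scopes),
        places_(places) {}

  std::string Name() const override { return "toy"; }

 protected:
  void RunImpl() override {}  // real handles do their device work here

 private:
  std::vector<Scope *> local_scopes_;
  std::vector<platform::Place> places_;
};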
DataBalanceOpHandle::DataBalanceOpHandle( - const std::vector &local_scopes, + ir::Node *node, const std::vector &local_scopes, const std::vector &places, const platform::NCCLContextMap *ctxs) - : local_scopes_(local_scopes), places_(places) { + : OpHandleBase(node), local_scopes_(local_scopes), places_(places) { if (ctxs) { for (auto &p : places_) { this->dev_ctxes_[p] = ctxs->DevCtx(p); @@ -34,9 +34,9 @@ DataBalanceOpHandle::DataBalanceOpHandle( } #else DataBalanceOpHandle::DataBalanceOpHandle( - const std::vector &local_scopes, + ir::Node *node, const std::vector &local_scopes, const std::vector &places) - : local_scopes_(local_scopes), places_(places) {} + : OpHandleBase(node), local_scopes_(local_scopes), places_(places) {} #endif std::string DataBalanceOpHandle::Name() const { return "data balance"; } diff --git a/paddle/fluid/framework/details/data_balance_op_handle.h b/paddle/fluid/framework/details/data_balance_op_handle.h index 76a407e361..0462fb6ec7 100644 --- a/paddle/fluid/framework/details/data_balance_op_handle.h +++ b/paddle/fluid/framework/details/data_balance_op_handle.h @@ -30,11 +30,11 @@ namespace details { struct DataBalanceOpHandle : public OpHandleBase { public: #ifdef PADDLE_WITH_CUDA - DataBalanceOpHandle(const std::vector &local_scopes, + DataBalanceOpHandle(ir::Node *node, const std::vector &local_scopes, const std::vector &places, const platform::NCCLContextMap *ctxs); #else - DataBalanceOpHandle(const std::vector &local_scopes, + DataBalanceOpHandle(ir::Node *node, const std::vector &local_scopes, const std::vector &places); #endif diff --git a/paddle/fluid/framework/details/fetch_op_handle.cc b/paddle/fluid/framework/details/fetch_op_handle.cc index d646c94460..fe18b2060c 100644 --- a/paddle/fluid/framework/details/fetch_op_handle.cc +++ b/paddle/fluid/framework/details/fetch_op_handle.cc @@ -21,13 +21,16 @@ namespace paddle { namespace framework { namespace details { -FetchOpHandle::FetchOpHandle(FeedFetchList *data, size_t offset, +FetchOpHandle::FetchOpHandle(ir::Node *node, FeedFetchList *data, size_t offset, std::vector *local_scopes) - : data_(data), offset_(offset), local_scopes_(local_scopes) {} + : OpHandleBase(node), + data_(data), + offset_(offset), + local_scopes_(local_scopes) {} FetchOpHandle::~FetchOpHandle() { for (auto *input_var : inputs_) { - input_var->pending_ops_.erase(this); + input_var->RemoveOutput(this, this->Node()); } } @@ -77,8 +80,8 @@ void FetchOpHandle::RunImpl() { void FetchOpHandle::WaitInputVarGenerated(const platform::Place &place) { auto cpu_ctx = platform::DeviceContextPool::Instance().Get(place); for (auto *input : inputs_) { - if (input->generated_op_) { - input->generated_op_->RecordWaitEventOnCtx(cpu_ctx); + if (input->GeneratedOp()) { + input->GeneratedOp()->RecordWaitEventOnCtx(cpu_ctx); } } } diff --git a/paddle/fluid/framework/details/fetch_op_handle.h b/paddle/fluid/framework/details/fetch_op_handle.h index e09bdd1d33..6ce42f92d7 100644 --- a/paddle/fluid/framework/details/fetch_op_handle.h +++ b/paddle/fluid/framework/details/fetch_op_handle.h @@ -28,7 +28,7 @@ namespace details { struct FetchOpHandle : public OpHandleBase { public: - FetchOpHandle(FeedFetchList *data, size_t offset, + FetchOpHandle(ir::Node *node, FeedFetchList *data, size_t offset, std::vector *local_scopes); ~FetchOpHandle(); diff --git a/paddle/fluid/framework/details/fuse_vars_op_handle.h b/paddle/fluid/framework/details/fuse_vars_op_handle.h index 140fb5bb49..3f360c510a 100644 --- a/paddle/fluid/framework/details/fuse_vars_op_handle.h +++ 
b/paddle/fluid/framework/details/fuse_vars_op_handle.h @@ -30,10 +30,12 @@ namespace details { struct FuseVarsOpHandle : public OpHandleBase { public: - FuseVarsOpHandle(Scope *local_scope, const platform::Place &place, + FuseVarsOpHandle(ir::Node *node, Scope *local_scope, + const platform::Place &place, const std::unordered_map &inputs_numel, const std::type_index &var_type) - : local_scope_(local_scope), + : OpHandleBase(node), + local_scope_(local_scope), place_(place), inputs_numel_(inputs_numel), type_(var_type) { diff --git a/paddle/fluid/framework/details/gather_op_handle.cc b/paddle/fluid/framework/details/gather_op_handle.cc index 2be0230456..9aae19fc73 100644 --- a/paddle/fluid/framework/details/gather_op_handle.cc +++ b/paddle/fluid/framework/details/gather_op_handle.cc @@ -20,9 +20,10 @@ namespace paddle { namespace framework { namespace details { -GatherOpHandle::GatherOpHandle(const std::vector &local_scopes, +GatherOpHandle::GatherOpHandle(ir::Node *node, + const std::vector &local_scopes, const std::vector &places) - : local_scopes_(local_scopes), places_(places) {} + : OpHandleBase(node), local_scopes_(local_scopes), places_(places) {} void GatherOpHandle::RunImpl() { if (places_.size() == 1) return; diff --git a/paddle/fluid/framework/details/gather_op_handle.h b/paddle/fluid/framework/details/gather_op_handle.h index d11ef8556a..d9afbc6547 100644 --- a/paddle/fluid/framework/details/gather_op_handle.h +++ b/paddle/fluid/framework/details/gather_op_handle.h @@ -30,7 +30,7 @@ namespace details { struct GatherOpHandle : public OpHandleBase { public: - GatherOpHandle(const std::vector &local_scopes, + GatherOpHandle(ir::Node *node, const std::vector &local_scopes, const std::vector &places); std::string Name() const override; diff --git a/paddle/fluid/framework/details/gather_op_handle_test.cc b/paddle/fluid/framework/details/gather_op_handle_test.cc index 3cce2cc164..5b11f8cdc7 100644 --- a/paddle/fluid/framework/details/gather_op_handle_test.cc +++ b/paddle/fluid/framework/details/gather_op_handle_test.cc @@ -70,6 +70,7 @@ struct TestGatherOpHandle { } void InitGatherOp(size_t input_scope_idx) { + std::vector> nodes; for (size_t j = 0; j < gpu_list_.size(); ++j) { local_scopes_.push_back(&(g_scope_.NewScope())); Scope& local_scope = local_scopes_.back()->NewScope(); @@ -81,30 +82,37 @@ struct TestGatherOpHandle { } param_scopes_[input_scope_idx]->Var("out"); - op_handle_.reset(new GatherOpHandle(local_scopes_, gpu_list_)); + nodes.emplace_back(new ir::Node(ir::Node::Type::kOperation)); + op_handle_.reset( + new GatherOpHandle(nodes.back().get(), local_scopes_, gpu_list_)); // add input for (size_t j = 0; j < gpu_list_.size(); ++j) { op_handle_->SetDeviceContext(gpu_list_[j], ctxs_[j].get()); - auto* in_var_handle = new VarHandle(1, j, "input", gpu_list_[j]); + nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable)); + auto* in_var_handle = + new VarHandle(nodes.back().get(), 1, j, "input", gpu_list_[j]); vars_.emplace_back(in_var_handle); op_handle_->AddInput(in_var_handle); } // add dummy var - vars_.emplace_back(new DummyVarHandle()); + nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable)); + vars_.emplace_back(new DummyVarHandle(nodes.back().get())); DummyVarHandle* in_dummy_var_handle = static_cast(vars_.back().get()); - in_dummy_var_handle->generated_op_ = nullptr; + in_dummy_var_handle->ClearGeneratedOp(); op_handle_->AddInput(in_dummy_var_handle); // add output - auto* out_var_handle = - new VarHandle(2, input_scope_idx, "out", 
gpu_list_[input_scope_idx]); + nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable)); + auto* out_var_handle = new VarHandle(nodes.back().get(), 2, input_scope_idx, + "out", gpu_list_[input_scope_idx]); vars_.emplace_back(out_var_handle); op_handle_->AddOutput(out_var_handle); // add dummy var - vars_.emplace_back(new DummyVarHandle()); + nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable)); + vars_.emplace_back(new DummyVarHandle(nodes.back().get())); DummyVarHandle* dummy_var_handle = static_cast(vars_.back().get()); op_handle_->AddOutput(dummy_var_handle); diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.cc b/paddle/fluid/framework/details/multi_devices_graph_builder.cc index 0a95370419..cb2ab90516 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.cc @@ -328,12 +328,16 @@ void MultiDevSSAGraphBuilder::SetCommunicationContext( void MultiDevSSAGraphBuilder::CreateBroadcastOp(Graph *result, const std::string &p_name, size_t src_dev_id) const { + result->nodes.emplace_back(new ir::Node(ir::Node::Type::kOperation)); #ifdef PADDLE_WITH_CUDA - auto *op_handle = new BroadcastOpHandle(local_scopes_, places_, nccl_ctxs_); + auto *op_handle = new BroadcastOpHandle(result->nodes.back().get(), + local_scopes_, places_, nccl_ctxs_); #else - auto *op_handle = new BroadcastOpHandle(local_scopes_, places_); + auto *op_handle = + new BroadcastOpHandle(result->nodes.back().get(), local_scopes_, places_); #endif result->Get("ops").emplace_back(op_handle); + auto *in = result->Get("vars").at(src_dev_id).at(p_name).back().get(); op_handle->AddInput(in); @@ -341,8 +345,10 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(Graph *result, for (size_t i = 0; i < places_.size(); ++i) { auto &p = places_[i]; SetCommunicationContext(op_handle, p); + result->nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable)); auto &vars = result->Get("vars").at(i).at(p_name); - auto *out_var = new VarHandle(vars.size(), i, p_name, p); + auto *out_var = + new VarHandle(result->nodes.back().get(), vars.size(), i, p_name, p); vars.emplace_back(out_var); op_handle->AddOutput(out_var); } @@ -351,19 +357,21 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(Graph *result, void MultiDevSSAGraphBuilder::CreateComputationalOp(Graph *result, const OpDesc &op, int dev_id) const { - result->Get("ops").emplace_back( - new ComputationOpHandle(op, local_scopes_[dev_id], places_[dev_id])); + result->nodes.emplace_back(new ir::Node(ir::Node::Type::kOperation)); + result->Get("ops").emplace_back(new ComputationOpHandle( + result->nodes.back().get(), op, local_scopes_[dev_id], places_[dev_id])); CreateOpHandleIOs(result, op, dev_id); } void MultiDevSSAGraphBuilder::InsertAllReduceOp(Graph *result, const std::string &og) const { + result->nodes.emplace_back(new ir::Node(ir::Node::Type::kOperation)); #ifdef PADDLE_WITH_CUDA - result->Get("ops").emplace_back( - new AllReduceOpHandle(local_scopes_, places_, nccl_ctxs_)); + result->Get("ops").emplace_back(new AllReduceOpHandle( + result->nodes.back().get(), local_scopes_, places_, nccl_ctxs_)); #else - result->Get("ops").emplace_back( - new AllReduceOpHandle(local_scopes_, places_)); + result->Get("ops").emplace_back(new AllReduceOpHandle( + result->nodes.back().get(), local_scopes_, places_)); #endif auto *op_handle = result->Get("ops").back().get(); @@ -375,7 +383,8 @@ void MultiDevSSAGraphBuilder::InsertAllReduceOp(Graph *result, auto &prev_grad = vars.back(); 
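// Editor's note: a minimal sketch (not in the patch) of the ownership split
// the surrounding lines rely on: the Graph owns every ir::Node, the per-place
// "vars" entry owns the VarHandle, and the op handle only borrows pointers.
// "AppendOutputVar" is a hypothetical helper; the builder inlines this pattern
// at each call site instead of factoring it out.
static VarHandle *AppendOutputVar(Graph *result, OpHandleBase *op_handle,
                                  std::vector<std::unique_ptr<VarHandle>> *vars,
                                  size_t place_idx, const std::string &name,
                                  const platform::Place &place) {
  // Node lifetime is tied to the graph, not to the handle that uses it.
  result->nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable));
  // vars->size() doubles as the SSA version of this variable on this place.
  auto *var = new VarHandle(result->nodes.back().get(), vars->size(),
                            place_idx, name, place);
  vars->emplace_back(var);    // the per-place slot takes ownership
  op_handle->AddOutput(var);  // also records the edge on the ir::Node
  return var;
}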
op_handle->AddInput(prev_grad.get()); - auto var = new VarHandle(vars.size(), i, og, p); + result->nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable)); + auto var = new VarHandle(result->nodes.back().get(), vars.size(), i, og, p); vars.emplace_back(var); op_handle->AddOutput(var); } @@ -383,12 +392,13 @@ void MultiDevSSAGraphBuilder::InsertAllReduceOp(Graph *result, void MultiDevSSAGraphBuilder::InsertDataBalanceOp( Graph *result, const std::vector &datas) const { + result->nodes.emplace_back(new ir::Node(ir::Node::Type::kOperation)); #ifdef PADDLE_WITH_CUDA - result->Get("ops").emplace_back( - new DataBalanceOpHandle(local_scopes_, places_, nccl_ctxs_)); + result->Get("ops").emplace_back(new DataBalanceOpHandle( + result->nodes.back().get(), local_scopes_, places_, nccl_ctxs_)); #else - result->Get("ops").emplace_back( - new DataBalanceOpHandle(local_scopes_, places_)); + result->Get("ops").emplace_back(new DataBalanceOpHandle( + result->nodes.back().get(), local_scopes_, places_)); #endif auto *op_handle = result->Get("ops").back().get(); for (size_t i = 0; i < places_.size(); ++i) { @@ -398,7 +408,9 @@ void MultiDevSSAGraphBuilder::InsertDataBalanceOp( auto &vars = result->Get("vars")[i][d_name]; PADDLE_ENFORCE(!vars.empty()); op_handle->AddInput(vars.back().get()); - auto var = new VarHandle(vars.size(), i, d_name, p); + result->nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable)); + auto var = + new VarHandle(result->nodes.back().get(), vars.size(), i, d_name, p); vars.emplace_back(var); op_handle->AddOutput(var); } @@ -452,10 +464,10 @@ void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(Graph *result) const { auto *communication_dev_ctx = platform::DeviceContextPool::Instance().Get(platform::CPUPlace()); #endif - - auto *op_handle = - new ScaleLossGradOpHandle(local_scopes_.size(), local_scopes_[i], - places_[i], communication_dev_ctx); + result->nodes.emplace_back(new ir::Node(ir::Node::Type::kOperation)); + auto *op_handle = new ScaleLossGradOpHandle( + result->nodes.back().get(), local_scopes_.size(), local_scopes_[i], + places_[i], communication_dev_ctx); result->Get("ops").emplace_back(op_handle); // FIXME: Currently ScaleLossGradOp only use device_count as scale @@ -475,8 +487,9 @@ void MultiDevSSAGraphBuilder::CreateComputationalOps(Graph *result, for (size_t scope_idx = 0; scope_idx < num_places; ++scope_idx) { auto p = places_[scope_idx]; auto s = local_scopes_[scope_idx]; + result->nodes.emplace_back(new ir::Node(ir::Node::Type::kOperation)); result->Get("ops").emplace_back( - new ComputationOpHandle(op, s, p)); + new ComputationOpHandle(result->nodes.back().get(), op, s, p)); CreateOpHandleIOs(result, op, scope_idx); } } @@ -484,12 +497,13 @@ void MultiDevSSAGraphBuilder::CreateComputationalOps(Graph *result, VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(Graph *result, const std::string &og, int dst_dev_id) const { + result->nodes.emplace_back(new ir::Node(ir::Node::Type::kOperation)); #ifdef PADDLE_WITH_CUDA - result->Get("ops").emplace_back( - new ReduceOpHandle(local_scopes_, places_, nccl_ctxs_)); + result->Get("ops").emplace_back(new ReduceOpHandle( + result->nodes.back().get(), local_scopes_, places_, nccl_ctxs_)); #else result->Get("ops").emplace_back( - new ReduceOpHandle(local_scopes_, places_)); + new ReduceOpHandle(result->nodes.back().get(), local_scopes_, places_)); #endif auto *op_handle = result->Get("ops").back().get(); @@ -502,7 +516,9 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(Graph *result, 
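// Editor's note: a minimal sketch (not in the patch) of how the builder is
// expected to seed the attribute slots it reads back through Graph::Get.
// Graph::Set stores a raw pointer plus a deleter and Get unwraps it with
// boost::any_cast, so the type used at Set time must match the one at Get
// time exactly. The typedef names below are assumptions for illustration;
// the diff's template arguments were lost in extraction.
using GraphVars = std::vector<
    std::unordered_map<std::string, std::vector<std::unique_ptr<VarHandle>>>>;
using GraphDepVars = std::unordered_set<std::unique_ptr<VarHandleBase>>;
using GraphOps = std::vector<std::unique_ptr<OpHandleBase>>;

// Hypothetical helper name; shown only to make the Set/Get pairing explicit.
void SeedBuilderAttrs(Graph *result, size_t num_places) {
  // Each attribute is heap-allocated; the Graph destructor runs the
  // registered deleter, so the builder never frees these itself.
  result->Set("vars", new GraphVars(num_places));
  result->Set("dep_vars", new GraphDepVars);
  result->Set("ops", new GraphOps);
}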
op_handle->AddInput(prev_grad.get()); } auto &vars = result->Get("vars")[dst_dev_id][og]; - auto var = new VarHandle(vars.size(), dst_dev_id, og, places_[dst_dev_id]); + result->nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable)); + auto var = new VarHandle(result->nodes.back().get(), vars.size(), dst_dev_id, + og, places_[dst_dev_id]); vars.emplace_back(var); op_handle->AddOutput(var); return var; @@ -514,7 +530,8 @@ void MultiDevSSAGraphBuilder::ConnectOp(Graph *result, OpHandleBase *op, const std::string &prev_op_name) const { for (auto &prev_op : result->Get("ops")) { if (prev_op->Name() == prev_op_name) { - auto *dep_var = new DummyVarHandle(); + result->nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable)); + auto *dep_var = new DummyVarHandle(result->nodes.back().get()); prev_op->AddOutput(dep_var); result->Get("dep_vars").emplace(dep_var); op->AddInput(dep_var); @@ -587,8 +604,10 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(Graph *result, PADDLE_ENFORCE(op_dev_id != -1, "can not find the right place for rpc op: %s", op.Type()); - result->Get("ops").emplace_back(new RPCOpHandle( - op, local_scopes_[op_dev_id], op.Type(), places_[op_dev_id])); + result->nodes.emplace_back(new ir::Node(ir::Node::Type::kOperation)); + result->Get("ops").emplace_back( + new RPCOpHandle(result->nodes.back().get(), op, local_scopes_[op_dev_id], + op.Type(), places_[op_dev_id])); if (op.Type() == "send_barrier") { ConnectOp(result, result->Get("ops").back().get(), "send"); diff --git a/paddle/fluid/framework/details/op_handle_base.cc b/paddle/fluid/framework/details/op_handle_base.cc index d80bdcf15d..ee9f9184da 100644 --- a/paddle/fluid/framework/details/op_handle_base.cc +++ b/paddle/fluid/framework/details/op_handle_base.cc @@ -80,19 +80,21 @@ void OpHandleBase::RecordWaitEventOnCtx(platform::DeviceContext *waited_ctx) { void OpHandleBase::AddInput(VarHandleBase *in) { this->inputs_.emplace_back(in); - in->pending_ops_.insert(this); + node_->inputs.push_back(in->Node()); + in->AddOutput(this, this->Node()); } void OpHandleBase::AddOutput(VarHandleBase *out) { outputs_.emplace_back(out); - out->generated_op_ = this; + node_->outputs.push_back(out->Node()); + out->AddInput(this, this->Node()); } void OpHandleBase::WaitInputVarGenerated() { for (auto in_var : inputs_) { if (NeedWait(in_var)) { for (auto &pair : dev_ctxes_) { - in_var->generated_op_->RecordWaitEventOnCtx(pair.second); + in_var->GeneratedOp()->RecordWaitEventOnCtx(pair.second); } } } @@ -101,7 +103,7 @@ void OpHandleBase::WaitInputVarGenerated() { void OpHandleBase::WaitInputVarGenerated(const platform::Place &place) { for (auto *in : inputs_) { if (NeedWait(in)) { - in->generated_op_->RecordWaitEventOnCtx(dev_ctxes_[place]); + in->GeneratedOp()->RecordWaitEventOnCtx(dev_ctxes_[place]); } } } @@ -117,7 +119,7 @@ size_t OpHandleBase::NoDummyInputSize() const { } bool OpHandleBase::NeedWait(VarHandleBase *in_var) { - return in_var && in_var->generated_op_; + return in_var && in_var->GeneratedOp(); } void OpHandleBase::RunAndRecordEvent(const std::function &callback) { diff --git a/paddle/fluid/framework/details/op_handle_base.h b/paddle/fluid/framework/details/op_handle_base.h index 6aec178831..368a153711 100644 --- a/paddle/fluid/framework/details/op_handle_base.h +++ b/paddle/fluid/framework/details/op_handle_base.h @@ -17,6 +17,7 @@ #include #include #include "paddle/fluid/framework/details/var_handle.h" +#include "paddle/fluid/framework/ir/node.h" #include "paddle/fluid/platform/device_context.h" #include 
"paddle/fluid/platform/macros.h" @@ -28,7 +29,7 @@ constexpr char kLocalExecScopeName[] = "@LCOAL_SCOPE@"; class OpHandleBase { public: - OpHandleBase() {} + explicit OpHandleBase(ir::Node *node) : node_(node) {} virtual ~OpHandleBase(); @@ -82,6 +83,8 @@ class OpHandleBase { size_t NoDummyInputSize() const; + ir::Node *Node() { return node_; } + protected: void RunAndRecordEvent(const std::function &callback); @@ -90,6 +93,7 @@ class OpHandleBase { virtual void RunImpl() = 0; + ir::Node *node_; std::vector inputs_; std::vector outputs_; std::map dev_ctxes_; diff --git a/paddle/fluid/framework/details/reduce_op_handle.h b/paddle/fluid/framework/details/reduce_op_handle.h index 4d14334cdf..a6289b055f 100644 --- a/paddle/fluid/framework/details/reduce_op_handle.h +++ b/paddle/fluid/framework/details/reduce_op_handle.h @@ -37,10 +37,13 @@ struct ReduceOpHandle : public OpHandleBase { #ifdef PADDLE_WITH_CUDA const platform::NCCLContextMap *nccl_ctxs_; - ReduceOpHandle(const std::vector &local_scopes, + ReduceOpHandle(ir::Node *node, const std::vector &local_scopes, const std::vector &places, const platform::NCCLContextMap *nccl_ctxs) - : local_scopes_(local_scopes), places_(places), nccl_ctxs_(nccl_ctxs) { + : OpHandleBase(node), + local_scopes_(local_scopes), + places_(places), + nccl_ctxs_(nccl_ctxs) { if (nccl_ctxs_) { for (auto &p_ctx : nccl_ctxs_->contexts_) { dev_ctxes_[platform::CUDAPlace(p_ctx.first)] = p_ctx.second.ctx_.get(); @@ -48,9 +51,9 @@ struct ReduceOpHandle : public OpHandleBase { } } #else - ReduceOpHandle(const std::vector &local_scopes, + ReduceOpHandle(ir::Node *node, const std::vector &local_scopes, const std::vector &places) - : local_scopes_(local_scopes), places_(places) {} + : OpHandleBase(node), local_scopes_(local_scopes), places_(places) {} #endif std::string Name() const override; diff --git a/paddle/fluid/framework/details/reduce_op_handle_test.cc b/paddle/fluid/framework/details/reduce_op_handle_test.cc index ffdd7c14eb..d029dd9e15 100644 --- a/paddle/fluid/framework/details/reduce_op_handle_test.cc +++ b/paddle/fluid/framework/details/reduce_op_handle_test.cc @@ -84,6 +84,7 @@ struct TestReduceOpHandle { } void InitReduceOp(size_t out_scope_idx) { + std::vector> nodes; // init scope for (size_t j = 0; j < gpu_list_.size(); ++j) { local_scopes_.push_back(&(g_scope_.NewScope())); @@ -96,19 +97,21 @@ struct TestReduceOpHandle { } param_scopes_[out_scope_idx]->Var("out"); + nodes.emplace_back(new ir::Node(ir::Node::Type::kOperation)); if (use_gpu_) { #ifdef PADDLE_WITH_CUDA - op_handle_.reset( - new ReduceOpHandle(local_scopes_, gpu_list_, nccl_ctxs_.get())); + op_handle_.reset(new ReduceOpHandle(nodes.back().get(), local_scopes_, + gpu_list_, nccl_ctxs_.get())); #else PADDLE_THROW("CUDA is not support."); #endif } else { #ifdef PADDLE_WITH_CUDA - op_handle_.reset( - new ReduceOpHandle(local_scopes_, gpu_list_, nccl_ctxs_.get())); + op_handle_.reset(new ReduceOpHandle(nodes.back().get(), local_scopes_, + gpu_list_, nccl_ctxs_.get())); #else - op_handle_.reset(new ReduceOpHandle(local_scopes_, gpu_list_)); + op_handle_.reset( + new ReduceOpHandle(nodes.back().get(), local_scopes_, gpu_list_)); #endif } @@ -118,8 +121,10 @@ struct TestReduceOpHandle { if (!use_gpu_) { op_handle_->SetDeviceContext(gpu_list_[j], ctxs_[j].get()); } - auto *in_var_handle = new VarHandle(1, j, "input", gpu_list_[j]); - in_var_handle->generated_op_ = nullptr; + nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable)); + auto *in_var_handle = + new VarHandle(nodes.back().get(), 1, j, 
"input", gpu_list_[j]); + in_var_handle->ClearGeneratedOp(); vars_.emplace_back(in_var_handle); op_handle_->AddInput(in_var_handle); } @@ -128,12 +133,13 @@ struct TestReduceOpHandle { vars_.emplace_back(new DummyVarHandle()); DummyVarHandle *in_dummy_var_handle = static_cast(vars_.back().get()); - in_dummy_var_handle->generated_op_ = nullptr; + in_dummy_var_handle->ClearGeneratedOp(); op_handle_->AddInput(in_dummy_var_handle); // add output - auto *out_var_handle = - new VarHandle(2, out_scope_idx, "out", gpu_list_[out_scope_idx]); + nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable)); + auto *out_var_handle = new VarHandle(nodes.back().get(), 2, out_scope_idx, + "out", gpu_list_[out_scope_idx]); vars_.emplace_back(out_var_handle); op_handle_->AddOutput(out_var_handle); diff --git a/paddle/fluid/framework/details/rpc_op_handle.cc b/paddle/fluid/framework/details/rpc_op_handle.cc index 586465f99f..924ff4d118 100644 --- a/paddle/fluid/framework/details/rpc_op_handle.cc +++ b/paddle/fluid/framework/details/rpc_op_handle.cc @@ -18,10 +18,11 @@ namespace paddle { namespace framework { namespace details { -RPCOpHandle::RPCOpHandle(const framework::OpDesc &op_desc, +RPCOpHandle::RPCOpHandle(ir::Node *node, const framework::OpDesc &op_desc, const Scope *local_scope, const std::string &name, const platform::Place &place) - : op_(framework::OpRegistry::CreateOp(op_desc)), + : OpHandleBase(node), + op_(framework::OpRegistry::CreateOp(op_desc)), local_scope_(local_scope), name_(name), place_(place) {} @@ -35,8 +36,8 @@ void RPCOpHandle::RunImpl() { if (in->DebugString() == "dummy") { // HACK continue; } - if (in->generated_op_) { - in->generated_op_->RecordWaitEventOnCtx(dev_ctxes_[p]); + if (in->GeneratedOp()) { + in->GeneratedOp()->RecordWaitEventOnCtx(dev_ctxes_[p]); } } auto &tmp_scope = local_scope_->FindVar(kLocalExecScopeName)->Get(); diff --git a/paddle/fluid/framework/details/rpc_op_handle.h b/paddle/fluid/framework/details/rpc_op_handle.h index ae38c7fe19..7f99cdeacf 100644 --- a/paddle/fluid/framework/details/rpc_op_handle.h +++ b/paddle/fluid/framework/details/rpc_op_handle.h @@ -28,8 +28,9 @@ namespace framework { namespace details { struct RPCOpHandle : public OpHandleBase { - RPCOpHandle(const framework::OpDesc& op_desc, const Scope* local_scope, - const std::string& name, const platform::Place& place); + RPCOpHandle(ir::Node* node, const framework::OpDesc& op_desc, + const Scope* local_scope, const std::string& name, + const platform::Place& place); std::string Name() const override; diff --git a/paddle/fluid/framework/details/scale_loss_grad_op_handle.cc b/paddle/fluid/framework/details/scale_loss_grad_op_handle.cc index d9c387e79d..609e185819 100644 --- a/paddle/fluid/framework/details/scale_loss_grad_op_handle.cc +++ b/paddle/fluid/framework/details/scale_loss_grad_op_handle.cc @@ -19,10 +19,14 @@ namespace paddle { namespace framework { namespace details { -ScaleLossGradOpHandle::ScaleLossGradOpHandle(size_t num_dev, Scope *scope, +ScaleLossGradOpHandle::ScaleLossGradOpHandle(ir::Node *node, size_t num_dev, + Scope *scope, platform::Place place, platform::DeviceContext *dev_ctx) - : coeff_(static_cast(1.0 / num_dev)), scope_(scope), place_(place) { + : OpHandleBase(node), + coeff_(static_cast(1.0 / num_dev)), + scope_(scope), + place_(place) { dev_ctxes_[place_] = dev_ctx; } diff --git a/paddle/fluid/framework/details/scale_loss_grad_op_handle.h b/paddle/fluid/framework/details/scale_loss_grad_op_handle.h index d93d599d46..523b55724c 100644 --- 
a/paddle/fluid/framework/details/scale_loss_grad_op_handle.h +++ b/paddle/fluid/framework/details/scale_loss_grad_op_handle.h @@ -25,7 +25,8 @@ namespace framework { namespace details { struct ScaleLossGradOpHandle : public OpHandleBase { - ScaleLossGradOpHandle(size_t num_dev, Scope *scope, platform::Place place, + ScaleLossGradOpHandle(ir::Node *node, size_t num_dev, Scope *scope, + platform::Place place, platform::DeviceContext *context); ~ScaleLossGradOpHandle() final; diff --git a/paddle/fluid/framework/details/ssa_graph_builder.cc b/paddle/fluid/framework/details/ssa_graph_builder.cc index 2508ed0296..846f98ddfa 100644 --- a/paddle/fluid/framework/details/ssa_graph_builder.cc +++ b/paddle/fluid/framework/details/ssa_graph_builder.cc @@ -27,8 +27,8 @@ void SSAGraphBuilder::PolishGraphToSupportDataHazards(Graph *graph) { auto it_old = name_pair.second.rbegin(); ++it_old; for (; it_old != name_pair.second.rend(); it_new = it_old, ++it_old) { - auto *write_op = (*it_new)->generated_op_; - auto &read_ops = (*it_old)->pending_ops_; + OpHandleBase *write_op = (*it_new)->GeneratedOp(); + const auto &read_ops = (*it_old)->PendingOps(); for (auto *read_op : read_ops) { // Manually add a dependency var from read_op to write_op; @@ -37,7 +37,8 @@ void SSAGraphBuilder::PolishGraphToSupportDataHazards(Graph *graph) { continue; } - auto *dep_var = new DummyVarHandle(); + graph->nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable)); + auto *dep_var = new DummyVarHandle(graph->nodes.back().get()); read_op->AddOutput(dep_var); write_op->AddInput(dep_var); graph->Get("dep_vars").emplace(dep_var); @@ -54,7 +55,9 @@ VarHandle *SSAGraphBuilder::CreateOrGetLatestVarHandle( auto &var_holder = var_holders[each_var_name]; VarHandle *var = nullptr; if (var_holder.empty()) { - var = new VarHandle(0, place_offset, each_var_name, place); + graph->nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable)); + var = new VarHandle(graph->nodes.back().get(), 0, place_offset, + each_var_name, place); var_holder.emplace_back(var); } else { var = var_holder.rbegin()->get(); @@ -68,7 +71,9 @@ void SSAGraphBuilder::CreateOpOutput(Graph *graph, OpHandleBase *op_handle, size_t place_offset) { auto &vars = graph->Get("vars")[place_offset][each_var_name]; size_t version = vars.size(); - auto var = new VarHandle(version, place_offset, each_var_name, place); + graph->nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable)); + auto var = new VarHandle(graph->nodes.back().get(), version, place_offset, + each_var_name, place); vars.emplace_back(var); op_handle->AddOutput(var); } @@ -80,7 +85,8 @@ void SSAGraphBuilder::AddOutputToLeafOps(Graph *graph) { if (!op->Outputs().empty()) { continue; } - auto *dummy_leaf = new DummyVarHandle(); + graph->nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable)); + auto *dummy_leaf = new DummyVarHandle(graph->nodes.back().get()); graph->Get("dep_vars").emplace(dummy_leaf); op->AddOutput(dummy_leaf); } diff --git a/paddle/fluid/framework/details/ssa_graph_checker.cc b/paddle/fluid/framework/details/ssa_graph_checker.cc index c01334ca06..6a211f52bb 100644 --- a/paddle/fluid/framework/details/ssa_graph_checker.cc +++ b/paddle/fluid/framework/details/ssa_graph_checker.cc @@ -28,7 +28,7 @@ bool SSAGraghBuilderWithChecker::IsValidGraph(const Graph *graph) const { auto insert_pending_var = [&](VarHandleBase *var) { pending_vars.insert(var); - if (var->generated_op_ == nullptr) { + if (var->GeneratedOp() == nullptr) { ready_vars.emplace(var); } }; @@ -71,7 +71,7 @@ bool 
SSAGraghBuilderWithChecker::IsValidGraph(const Graph *graph) const { for (auto ready_var : ready_vars) { pending_vars.erase(ready_var); - for (auto *op : ready_var->pending_ops_) { + for (auto *op : ready_var->PendingOps()) { auto &deps = --pending_ops[op]; if (deps == 0) { ready_ops.insert(op); diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc index ed8e38039e..9a2413118e 100644 --- a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc @@ -65,11 +65,12 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( // Step 2. Insert FetchOps std::vector> fetch_ops; + std::vector> tmp_nodes; std::unordered_set> fetch_dependencies; FeedFetchList fetch_data(fetch_tensors.size()); - InsertFetchOps(fetch_tensors, &fetch_ops, &fetch_dependencies, &pending_ops, - &pending_vars, &ready_vars, &fetch_data); + InsertFetchOps(fetch_tensors, &fetch_ops, &tmp_nodes, &fetch_dependencies, + &pending_ops, &pending_vars, &ready_vars, &fetch_data); auto run_all_ops = [&](std::unordered_set &set) { for (auto *op : set) { @@ -126,7 +127,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( // Find the ready_ops after the ready_var. for (auto ready_var : cur_ready_vars) { pending_vars.erase(ready_var); - for (auto *op : ready_var->pending_ops_) { + for (auto *op : ready_var->PendingOps()) { auto &deps = pending_ops[op]; --deps; if (deps == 0) { @@ -152,6 +153,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( void ThreadedSSAGraphExecutor::InsertFetchOps( const std::vector &fetch_tensors, std::vector> *fetch_ops, + std::vector> *temp_nodes, std::unordered_set> *fetch_dependencies, std::unordered_map *pending_ops, std::unordered_set *pending_vars, @@ -170,7 +172,10 @@ void ThreadedSSAGraphExecutor::InsertFetchOps( for (size_t i = 0; i < fetch_tensors.size(); ++i) { auto &var_name = fetch_tensors[i]; auto &vars = fetched_vars.at(var_name); - auto *op = new FetchOpHandle(fetch_data, i, &local_scopes_); + + ir::Node *fetch_n = new ir::Node(ir::Node::Type::kOperation); + auto *op = new FetchOpHandle(fetch_n, fetch_data, i, &local_scopes_); + temp_nodes->emplace_back(fetch_n); fetch_ops->emplace_back(op); for (auto &p : places_) { @@ -181,9 +186,11 @@ void ThreadedSSAGraphExecutor::InsertFetchOps( op->AddInput(var); } - auto *fetch_dummy = new DummyVarHandle(); + ir::Node *dummy_n = new ir::Node(ir::Node::Type::kVariable); + auto *fetch_dummy = new DummyVarHandle(dummy_n); op->AddOutput(fetch_dummy); fetch_dependencies->emplace(fetch_dummy); + temp_nodes->emplace_back(dummy_n); this->InsertPendingVar(pending_vars, ready_vars, fetch_dummy); this->InsertPendingOp(pending_ops, op); } @@ -199,7 +206,7 @@ void ThreadedSSAGraphExecutor::InsertPendingVar( std::unordered_set *pending_vars, BlockingQueue *ready_vars, VarHandleBase *var) const { pending_vars->insert(var); - if (var->generated_op_ == nullptr) { + if (var->GeneratedOp() == nullptr) { ready_vars->Push(var); } } diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.h b/paddle/fluid/framework/details/threaded_ssa_graph_executor.h index 7d0aaf2ddc..bf7c0a367a 100644 --- a/paddle/fluid/framework/details/threaded_ssa_graph_executor.h +++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.h @@ -72,6 +72,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor { void InsertFetchOps( const std::vector &fetch_tensors, std::vector> *fetch_ops, + std::vector> *temp_nodes, std::unordered_set> 
*fetch_dependencies, std::unordered_map *pending_ops, std::unordered_set *pending_vars, diff --git a/paddle/fluid/framework/details/var_handle.h b/paddle/fluid/framework/details/var_handle.h index c62f9a9d08..8bd3db9203 100644 --- a/paddle/fluid/framework/details/var_handle.h +++ b/paddle/fluid/framework/details/var_handle.h @@ -13,6 +13,8 @@ // limitations under the License. #pragma once + +#include #include #include #include @@ -30,15 +32,51 @@ class OpHandleBase; // A variable can only be generated by a single operator. i.e. // This is a single assignment graph. struct VarHandleBase { + explicit VarHandleBase(ir::Node* node) : node_(node) {} + virtual ~VarHandleBase(); + virtual std::string DebugString() const = 0; + void AddInput(OpHandleBase* in, ir::Node* node) { + node_->inputs.clear(); + node_->inputs.push_back(node); + generated_op_ = in; + } + + void AddOutput(OpHandleBase* out, ir::Node* node) { + if (pending_ops_.find(out) == pending_ops_.end()) { + pending_ops_.insert(out); + node_->outputs.push_back(node); + } + } + + void RemoveOutput(OpHandleBase* out, ir::Node* node) { + pending_ops_.erase(out); + std::remove(node_->outputs.begin(), node_->outputs.end(), node); + } + + void ClearGeneratedOp() { + generated_op_ = nullptr; + node_->inputs.clear(); + } + + OpHandleBase* GeneratedOp() { return generated_op_; } + + const std::unordered_set& PendingOps() const { + return pending_ops_; + } + + ir::Node* Node() { return node_; } + + protected: // The operator who generate this variable. nullptr if the variable // is a root node. OpHandleBase* generated_op_{nullptr}; // Operators which depend on this variable ready. std::unordered_set pending_ops_; + ir::Node* node_; }; // VarHandle is actually a single version of Runtime Variable. @@ -47,11 +85,14 @@ struct VarHandleBase { // // NOTE: runtime variables have place. struct VarHandle : public VarHandleBase { + explicit VarHandle(ir::Node* node) : VarHandleBase(node) {} + std::string DebugString() const override; - VarHandle(size_t version, size_t scope_index, std::string name, - platform::Place place) - : version_(version), + VarHandle(ir::Node* node, size_t version, size_t scope_index, + std::string name, platform::Place place) + : VarHandleBase(node), + version_(version), scope_idx_(scope_index), name_(std::move(name)), place_(std::move(place)) {} @@ -71,6 +112,8 @@ struct VarHandle : public VarHandleBase { // Dummy Variable. It is used to represent dependencies between operators struct DummyVarHandle : public VarHandleBase { + explicit DummyVarHandle(ir::Node* node) : VarHandleBase(node) {} + std::string DebugString() const override; }; diff --git a/paddle/fluid/framework/ir/node.h b/paddle/fluid/framework/ir/node.h index 0fd8048390..94ace92953 100644 --- a/paddle/fluid/framework/ir/node.h +++ b/paddle/fluid/framework/ir/node.h @@ -14,10 +14,12 @@ limitations under the License. 
*/ #pragma once +#include #include #include #include #include +#include #include #include "paddle/fluid/platform/macros.h" #include "paddle/fluid/platform/variant.h" @@ -30,10 +32,10 @@ class Node { public: enum class Type { kNone = -1, kOperation, kVariable }; - Node(const std::string& name, Type type) : name_(name), type_(type) {} + explicit Node(Type type) : type_(type) {} virtual ~Node() { - for (auto& attr : attrs_) { + for (auto &attr : attrs_) { if (attr_dels_.find(attr.first) != attr_dels_.end()) { attr_dels_[attr.first](); } @@ -42,54 +44,32 @@ class Node { attrs_.clear(); } - int64_t ID() const { return id_; } - - std::string Name() const { return name_; } - - virtual std::string ToString() const { - return Name() + "(" + std::to_string(ID()) + ")"; - } - - virtual std::string DebugString() const = 0; - Type NodeType() const { return type_; } template - void Set(const std::string& name, AttrType attr) { + void Set(const std::string &name, AttrType attr) { attrs_[name] = attr; } template - void Set(const std::string& name, AttrType* attr, + void Set(const std::string &name, AttrType *attr, std::function attr_del) { attrs_[name] = attr; attr_dels_[name] = attr_del; } - std::vector inputs; - std::vector outputs; + std::vector inputs; + std::vector outputs; protected: std::map attrs_; std::map> attr_dels_; - int64_t id_ = 0; - std::string name_; Type type_; private: DISABLE_COPY_AND_ASSIGN(Node); }; -class Variable : public Node { - public: - explicit Variable(const std::string& name) : Node(name, Type::kVariable) {} -}; - -class Operation : public Node { - public: - explicit Operation(const std::string& name) : Node(name, Type::kOperation) {} -}; - } // namespace ir } // namespace framework } // namespace paddle From 2fa8df1cafa2caf4d25d115390c3ca5705c370c4 Mon Sep 17 00:00:00 2001 From: Xin Pan Date: Fri, 13 Jul 2018 14:49:28 +0800 Subject: [PATCH 11/22] separate graph building pass and graph-based pe builder --- .../details/broadcast_op_handle_test.cc | 10 +- .../details/gather_op_handle_test.cc | 10 +- .../details/multi_devices_graph_builder.cc | 260 ++++++++++-------- .../details/multi_devices_graph_builder.h | 25 +- .../details/reduce_op_handle_test.cc | 6 +- .../framework/details/ssa_graph_builder.cc | 26 +- .../framework/details/ssa_graph_builder.h | 12 +- .../framework/details/ssa_graph_checker.h | 6 +- .../framework/details/ssa_graph_printer.h | 6 +- paddle/fluid/framework/ir/graph.cc | 33 +++ paddle/fluid/framework/ir/graph.h | 21 +- paddle/fluid/framework/ir/node.h | 25 +- paddle/fluid/framework/ir/pass.h | 8 +- paddle/fluid/framework/parallel_executor.cc | 5 +- python/paddle/fluid/parallel_executor.py | 2 + 15 files changed, 273 insertions(+), 182 deletions(-) diff --git a/paddle/fluid/framework/details/broadcast_op_handle_test.cc b/paddle/fluid/framework/details/broadcast_op_handle_test.cc index 90ee3f7d93..1609b5965c 100644 --- a/paddle/fluid/framework/details/broadcast_op_handle_test.cc +++ b/paddle/fluid/framework/details/broadcast_op_handle_test.cc @@ -96,7 +96,7 @@ struct TestBroadcastOpHandle { } param_scopes_[input_scope_idx]->Var("input"); - std::unique_ptr n(new ir::Node(ir::Node::Type::kOperation)); + std::unique_ptr n(new ir::Node()); if (use_gpu_) { #ifdef PADDLE_WITH_CUDA op_handle_.reset(new BroadcastOpHandle(n.get(), local_scopes_, gpu_list_, @@ -114,7 +114,7 @@ struct TestBroadcastOpHandle { #endif } - std::unique_ptr v(new ir::Node(ir::Node::Type::kVariable)); + std::unique_ptr v(new ir::Node()); auto* in_var_handle = new VarHandle(v.get(), 1, 
input_scope_idx, "input", gpu_list_[input_scope_idx]); vars_.emplace_back(in_var_handle); @@ -122,7 +122,7 @@ struct TestBroadcastOpHandle { // add dummy var - std::unique_ptr v2(new ir::Node(ir::Node::Type::kVariable)); + std::unique_ptr v2(new ir::Node()); vars_.emplace_back(new DummyVarHandle(v2.get())); DummyVarHandle* dummy_var_handle = static_cast(vars_.back().get()); @@ -133,7 +133,7 @@ struct TestBroadcastOpHandle { if (!use_gpu_) { op_handle_->SetDeviceContext(gpu_list_[j], ctxs_[j].get()); } - std::unique_ptr v3(new ir::Node(ir::Node::Type::kVariable)); + std::unique_ptr v3(new ir::Node()); VarHandle* out_var_handle = new VarHandle(v3.get(), 2, j, "out", gpu_list_[j]); vars_.emplace_back(out_var_handle); @@ -141,7 +141,7 @@ struct TestBroadcastOpHandle { } // add dummy var - std::unique_ptr v4(new ir::Node(ir::Node::Type::kVariable)); + std::unique_ptr v4(new ir::Node()); vars_.emplace_back(new DummyVarHandle(v4.get())); DummyVarHandle* out_dummy_var_handle = static_cast(vars_.back().get()); diff --git a/paddle/fluid/framework/details/gather_op_handle_test.cc b/paddle/fluid/framework/details/gather_op_handle_test.cc index 5b11f8cdc7..f80cabf501 100644 --- a/paddle/fluid/framework/details/gather_op_handle_test.cc +++ b/paddle/fluid/framework/details/gather_op_handle_test.cc @@ -82,13 +82,13 @@ struct TestGatherOpHandle { } param_scopes_[input_scope_idx]->Var("out"); - nodes.emplace_back(new ir::Node(ir::Node::Type::kOperation)); + nodes.emplace_back(new ir::Node()); op_handle_.reset( new GatherOpHandle(nodes.back().get(), local_scopes_, gpu_list_)); // add input for (size_t j = 0; j < gpu_list_.size(); ++j) { op_handle_->SetDeviceContext(gpu_list_[j], ctxs_[j].get()); - nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable)); + nodes.emplace_back(new ir::Node()); auto* in_var_handle = new VarHandle(nodes.back().get(), 1, j, "input", gpu_list_[j]); vars_.emplace_back(in_var_handle); @@ -96,7 +96,7 @@ struct TestGatherOpHandle { } // add dummy var - nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable)); + nodes.emplace_back(new ir::Node()); vars_.emplace_back(new DummyVarHandle(nodes.back().get())); DummyVarHandle* in_dummy_var_handle = static_cast(vars_.back().get()); @@ -104,14 +104,14 @@ struct TestGatherOpHandle { op_handle_->AddInput(in_dummy_var_handle); // add output - nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable)); + nodes.emplace_back(new ir::Node()); auto* out_var_handle = new VarHandle(nodes.back().get(), 2, input_scope_idx, "out", gpu_list_[input_scope_idx]); vars_.emplace_back(out_var_handle); op_handle_->AddOutput(out_var_handle); // add dummy var - nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable)); + nodes.emplace_back(new ir::Node()); vars_.emplace_back(new DummyVarHandle(nodes.back().get())); DummyVarHandle* dummy_var_handle = static_cast(vars_.back().get()); diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.cc b/paddle/fluid/framework/details/multi_devices_graph_builder.cc index cb2ab90516..d66bc40090 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.cc @@ -67,30 +67,31 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder( } } -void MultiDevSSAGraphBuilder::CreateOpHandleIOs(Graph *result, const OpDesc &op, +void MultiDevSSAGraphBuilder::CreateOpHandleIOs(Graph *result, ir::Node *node, size_t place_id) const { auto p = places_[place_id]; auto *op_handle = result->Get("ops").back().get(); op_handle->SetDeviceContext(p, 
platform::DeviceContextPool::Instance().Get(p)); - for (auto &each_var_name : op.InputArgumentNames()) { - VarHandle *var = - CreateOrGetLatestVarHandle(result, each_var_name, p, place_id); + for (ir::Node *input : node->inputs) { + VarHandle *var = CreateOrGetLatestVarHandle(result, input, p, place_id); op_handle->AddInput(var); } - for (auto &each_var_name : op.OutputArgumentNames()) { - CreateOpOutput(result, op_handle, each_var_name, p, place_id); + for (ir::Node *output : node->outputs) { + CreateOpOutput(result, op_handle, output, p, place_id); } } std::vector MultiDevSSAGraphBuilder::FindDistTrainSendVars( - const ProgramDesc &program) const { + const std::vector> &nodes) const { std::vector send_vars; // since parameters are all in block 0, // it's enough to only scan send ops in block 0 - for (auto *op : program.Block(0).AllOps()) { + for (auto &node : nodes) { + if (!node->Op()) continue; + OpDesc *op = node->Op(); // TODO(Yancey1989): use a graceful method to find send op, // instead of the the hard code string if (op->Type() == "send") { @@ -104,9 +105,11 @@ std::vector MultiDevSSAGraphBuilder::FindDistTrainSendVars( } std::vector MultiDevSSAGraphBuilder::FindDistTrainRecvVars( - const ProgramDesc &program) const { + const std::vector> &nodes) const { std::vector recv_vars; - for (auto *op : program.Block(0).AllOps()) { + for (auto &node : nodes) { + if (!node->Op()) continue; + OpDesc *op = node->Op(); // TODO(Yancey1989): use a graceful method to find recv op, // instead of the hard code string if (op->Type() == "recv") { @@ -120,7 +123,7 @@ std::vector MultiDevSSAGraphBuilder::FindDistTrainRecvVars( } bool MultiDevSSAGraphBuilder::IsDistTrainOp( - const OpDesc &op, const std::vector &send_vars, + ir::Node *node, const std::vector &send_vars, const std::vector &recv_vars) const { if (send_vars.size() == 0 || recv_vars.size() == 0) { return false; @@ -143,8 +146,17 @@ bool MultiDevSSAGraphBuilder::IsDistTrainOp( return false; }; - return checker(op.OutputArgumentNames(), send_vars) || - checker(op.InputArgumentNames(), recv_vars); + std::vector input_var_names; + std::vector output_var_names; + for (ir::Node *input : node->inputs) { + input_var_names.push_back(input->Var()->Name()); + } + for (ir::Node *output : node->outputs) { + output_var_names.push_back(output->Var()->Name()); + } + + return checker(output_var_names, send_vars) || + checker(input_var_names, recv_vars); } size_t MultiDevSSAGraphBuilder::GetAppropriateDeviceID( @@ -167,11 +179,16 @@ size_t MultiDevSSAGraphBuilder::GetAppropriateDeviceID( return dev_id; } -std::unique_ptr MultiDevSSAGraphBuilder::Build( +std::unique_ptr MultiDevSSAGraphBuilder::Apply( std::unique_ptr graph) const { - const ProgramDesc &program = graph->Program(); - for (auto *var : program.Block(0).AllVars()) { - all_vars_.emplace(var->Name(), var); + auto nodes = std::move(graph->nodes); + graph->nodes.clear(); + LOG(ERROR) << "origin nodes count " << nodes.size(); + + for (auto &node : nodes) { + if (node->Var()) { + all_vars_.emplace(node->Var()->Name(), node->Var()); + } } Graph &result = *graph; @@ -181,10 +198,11 @@ std::unique_ptr MultiDevSSAGraphBuilder::Build( result.Set("vars", new GraphVars(places_.size())); result.Set("dep_vars", new GraphDepVars); result.Set("ops", new GraphOps); + // find send/recv vars so that we can place the distributed training // realted op in the place 0 - auto send_vars = FindDistTrainSendVars(program); - auto recv_vars = FindDistTrainRecvVars(program); + auto send_vars = FindDistTrainSendVars(nodes); + 
auto recv_vars = FindDistTrainRecvVars(nodes); std::vector> bcast_var_name_set; bcast_var_name_set.resize(places_.size()); @@ -192,14 +210,16 @@ std::unique_ptr MultiDevSSAGraphBuilder::Build( size_t cur_device_id = 0; bool is_forwarding = true; - for (auto *op : program.Block(0).AllOps()) { + // TODO(panyx0718): FIXME: nodes should be sorted by "program" order. + for (auto &node : nodes) { + if (!node->Op()) continue; if (boost::get( - op->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) == + node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) == static_cast(OpRole::kRPC)) { - CreateRPCOp(&result, *op); - } else if (IsDistTrainOp(*op, send_vars, recv_vars)) { - CreateDistTrainOp(&result, *op); - } else if (IsScaleLossOp(*op)) { + CreateRPCOp(&result, node.get()); + } else if (IsDistTrainOp(node.get(), send_vars, recv_vars)) { + CreateDistTrainOp(&result, node.get()); + } else if (IsScaleLossOp(node.get())) { // user can customize loss@grad if not use_default_grad_scale_ if (strategy_.gradient_scale_ != BuildStrategy::GradientScaleStrategy::kCustomized) { @@ -211,33 +231,35 @@ std::unique_ptr MultiDevSSAGraphBuilder::Build( // the block. is_forwarding = false; } else { - int op_dev_id = GetOpDeviceID(*op); + int op_dev_id = GetOpDeviceID(node.get()); if (op_dev_id != -1) { // This op only runs on one specific device. - CreateComputationalOp(&result, *op, op_dev_id); - for (auto &var_name : op->OutputArgumentNames()) { - var_name_on_devices_.emplace(var_name, op_dev_id); + CreateComputationalOp(&result, node.get(), op_dev_id); + for (ir::Node *n : node->outputs) { + var_name_on_devices_.emplace(n->Var()->Name(), op_dev_id); } } else { // This op runs on all devices, and its output may have parameter's // gradients. - if (op->Type() == "read" && strategy_.enable_data_balance_) { - op->SetAttr("throw_eof_exp", false); - CreateComputationalOps(&result, *op, places_.size()); - const auto &data_var_names = op->Output("Out"); + if (node->Op()->Type() == "read" && strategy_.enable_data_balance_) { + node->Op()->SetAttr("throw_eof_exp", false); + CreateComputationalOps(&result, node.get(), places_.size()); + // TODO(panyx0718): builder shouldn't depend on the out logic of + // a specific op. + const auto &data_var_names = node->Op()->Output("Out"); InsertDataBalanceOp(&result, data_var_names); } else { - CreateComputationalOps(&result, *op, places_.size()); + CreateComputationalOps(&result, node.get(), places_.size()); } if (!is_forwarding && places_.size() > 1) { // Currently, we assume that once gradient is generated, it can be // broadcast, and each gradient is only broadcast once. 
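(Illustrative aside, not part of the patch.) The loop above dispatches on each ir::Node by asking for its wrapped OpDesc, and the surrounding helpers repeatedly rebuild argument-name lists from a node's input and output edges. Below is a minimal, self-contained sketch of that pattern, assuming only the interfaces introduced in this patch (Node::Op(), Node::Var(), Node::inputs, Node::outputs); the helper names are hypothetical and the template arguments elided in the diff are filled in as assumed here.

#include <string>
#include <vector>
#include "paddle/fluid/framework/ir/node.h"

namespace {

// Names of the variable nodes feeding an operator node; this mirrors how the
// builder now recovers what OpDesc::InputArgumentNames() used to provide.
std::vector<std::string> InputVarNames(paddle::framework::ir::Node *op_node) {
  std::vector<std::string> names;
  for (paddle::framework::ir::Node *in : op_node->inputs) {
    names.push_back(in->Var()->Name());  // each input edge points at a variable node
  }
  return names;
}

// Names of the variable nodes produced by an operator node; the counterpart
// of OpDesc::OutputArgumentNames().
std::vector<std::string> OutputVarNames(paddle::framework::ir::Node *op_node) {
  std::vector<std::string> names;
  for (paddle::framework::ir::Node *out : op_node->outputs) {
    names.push_back(out->Var()->Name());
  }
  return names;
}

}  // namespace

Note that the later "polish graph" patch in this series switches these lookups from in->Var()->Name() to in->Name() once ir::Node carries a name of its own.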
- if (static_cast(boost::get(op->GetAttr( + if (static_cast(boost::get(node->Op()->GetAttr( OpProtoAndCheckerMaker::OpRoleAttrName())) & static_cast(OpRole::kBackward))) { try { - auto backward_vars = - boost::get>(op->GetNullableAttr( + auto backward_vars = boost::get>( + node->Op()->GetNullableAttr( OpProtoAndCheckerMaker::OpRoleVarAttrName())); PADDLE_ENFORCE_EQ(backward_vars.size() % 2, 0); @@ -328,13 +350,12 @@ void MultiDevSSAGraphBuilder::SetCommunicationContext( void MultiDevSSAGraphBuilder::CreateBroadcastOp(Graph *result, const std::string &p_name, size_t src_dev_id) const { - result->nodes.emplace_back(new ir::Node(ir::Node::Type::kOperation)); #ifdef PADDLE_WITH_CUDA - auto *op_handle = new BroadcastOpHandle(result->nodes.back().get(), + auto *op_handle = new BroadcastOpHandle(result->CreateOpNode(nullptr), local_scopes_, places_, nccl_ctxs_); #else - auto *op_handle = - new BroadcastOpHandle(result->nodes.back().get(), local_scopes_, places_); + auto *op_handle = new BroadcastOpHandle(result->CreateOpNode(nullptr), + local_scopes_, places_); #endif result->Get("ops").emplace_back(op_handle); @@ -345,33 +366,31 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(Graph *result, for (size_t i = 0; i < places_.size(); ++i) { auto &p = places_[i]; SetCommunicationContext(op_handle, p); - result->nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable)); auto &vars = result->Get("vars").at(i).at(p_name); auto *out_var = - new VarHandle(result->nodes.back().get(), vars.size(), i, p_name, p); + new VarHandle(result->CreateVarNode(p_name), vars.size(), i, p_name, p); vars.emplace_back(out_var); op_handle->AddOutput(out_var); } } void MultiDevSSAGraphBuilder::CreateComputationalOp(Graph *result, - const OpDesc &op, + ir::Node *node, int dev_id) const { - result->nodes.emplace_back(new ir::Node(ir::Node::Type::kOperation)); - result->Get("ops").emplace_back(new ComputationOpHandle( - result->nodes.back().get(), op, local_scopes_[dev_id], places_[dev_id])); - CreateOpHandleIOs(result, op, dev_id); + result->Get("ops").emplace_back( + new ComputationOpHandle(result->CreateOpNode(node->Op()), *node->Op(), + local_scopes_[dev_id], places_[dev_id])); + CreateOpHandleIOs(result, node, dev_id); } void MultiDevSSAGraphBuilder::InsertAllReduceOp(Graph *result, const std::string &og) const { - result->nodes.emplace_back(new ir::Node(ir::Node::Type::kOperation)); #ifdef PADDLE_WITH_CUDA result->Get("ops").emplace_back(new AllReduceOpHandle( - result->nodes.back().get(), local_scopes_, places_, nccl_ctxs_)); + result->CreateOpNode(nullptr), local_scopes_, places_, nccl_ctxs_)); #else result->Get("ops").emplace_back(new AllReduceOpHandle( - result->nodes.back().get(), local_scopes_, places_)); + result->CreateOpNode(nullptr), local_scopes_, places_)); #endif auto *op_handle = result->Get("ops").back().get(); @@ -383,8 +402,7 @@ void MultiDevSSAGraphBuilder::InsertAllReduceOp(Graph *result, auto &prev_grad = vars.back(); op_handle->AddInput(prev_grad.get()); - result->nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable)); - auto var = new VarHandle(result->nodes.back().get(), vars.size(), i, og, p); + auto var = new VarHandle(result->CreateVarNode(og), vars.size(), i, og, p); vars.emplace_back(var); op_handle->AddOutput(var); } @@ -392,13 +410,12 @@ void MultiDevSSAGraphBuilder::InsertAllReduceOp(Graph *result, void MultiDevSSAGraphBuilder::InsertDataBalanceOp( Graph *result, const std::vector &datas) const { - result->nodes.emplace_back(new ir::Node(ir::Node::Type::kOperation)); #ifdef 
PADDLE_WITH_CUDA result->Get("ops").emplace_back(new DataBalanceOpHandle( - result->nodes.back().get(), local_scopes_, places_, nccl_ctxs_)); + result->CreateOpNode(nullptr), local_scopes_, places_, nccl_ctxs_)); #else result->Get("ops").emplace_back(new DataBalanceOpHandle( - result->nodes.back().get(), local_scopes_, places_)); + result->CreateOpNode(nullptr), local_scopes_, places_)); #endif auto *op_handle = result->Get("ops").back().get(); for (size_t i = 0; i < places_.size(); ++i) { @@ -408,9 +425,8 @@ void MultiDevSSAGraphBuilder::InsertDataBalanceOp( auto &vars = result->Get("vars")[i][d_name]; PADDLE_ENFORCE(!vars.empty()); op_handle->AddInput(vars.back().get()); - result->nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable)); - auto var = - new VarHandle(result->nodes.back().get(), vars.size(), i, d_name, p); + auto var = new VarHandle(result->CreateVarNode(d_name), vars.size(), i, + d_name, p); vars.emplace_back(var); op_handle->AddOutput(var); } @@ -429,17 +445,17 @@ bool MultiDevSSAGraphBuilder::IsParameterGradientOnce( return is_pg_once; } -int MultiDevSSAGraphBuilder::GetOpDeviceID(const OpDesc &op) const { +int MultiDevSSAGraphBuilder::GetOpDeviceID(ir::Node *node) const { if (strategy_.reduce_ != BuildStrategy::ReduceStrategy::kReduce) { return -1; } int op_role = boost::get( - op.GetAttr(framework::OpProtoAndCheckerMaker::OpRoleAttrName())); + node->Op()->GetAttr(framework::OpProtoAndCheckerMaker::OpRoleAttrName())); if (op_role != static_cast(framework::OpRole::kOptimize)) { return -1; } auto param_grad = boost::get>( - op.GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName())); + node->Op()->.GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName())); PADDLE_ENFORCE_EQ(param_grad.size(), 2U); int dev_id = GetVarDeviceID(param_grad[1]); @@ -464,9 +480,8 @@ void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(Graph *result) const { auto *communication_dev_ctx = platform::DeviceContextPool::Instance().Get(platform::CPUPlace()); #endif - result->nodes.emplace_back(new ir::Node(ir::Node::Type::kOperation)); auto *op_handle = new ScaleLossGradOpHandle( - result->nodes.back().get(), local_scopes_.size(), local_scopes_[i], + result->CreateOpNode(nullptr), local_scopes_.size(), local_scopes_[i], places_[i], communication_dev_ctx); result->Get("ops").emplace_back(op_handle); @@ -476,34 +491,38 @@ void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(Graph *result) const { // loss->pending_ops_.emplace_back(op_handle); // op_handle->inputs_.emplace_back(loss); - CreateOpOutput(result, op_handle, GradVarName(loss_var_name_), places_[i], - i); + // TODO(panyx0718): GradVarName(loss_var_name_) + const std::string grad_var_name = GradVarName(loss_var_name_); + auto &vars = result->Get("vars")[i][grad_var_name]; + size_t version = vars.size(); + auto var = new VarHandle(result->CreateVarNode(grad_var_name), version, i, + grad_var_name, places_[i]); + vars.emplace_back(var); + op_handle->AddOutput(var); } } void MultiDevSSAGraphBuilder::CreateComputationalOps(Graph *result, - const OpDesc &op, + ir::Node *node, size_t num_places) const { for (size_t scope_idx = 0; scope_idx < num_places; ++scope_idx) { auto p = places_[scope_idx]; auto s = local_scopes_[scope_idx]; - result->nodes.emplace_back(new ir::Node(ir::Node::Type::kOperation)); - result->Get("ops").emplace_back( - new ComputationOpHandle(result->nodes.back().get(), op, s, p)); - CreateOpHandleIOs(result, op, scope_idx); + result->Get("ops").emplace_back(new ComputationOpHandle( + result->CreateOpNode(node->Op()), *node->Op(), s, 
p)); + CreateOpHandleIOs(result, node, scope_idx); } } VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(Graph *result, const std::string &og, int dst_dev_id) const { - result->nodes.emplace_back(new ir::Node(ir::Node::Type::kOperation)); #ifdef PADDLE_WITH_CUDA result->Get("ops").emplace_back(new ReduceOpHandle( - result->nodes.back().get(), local_scopes_, places_, nccl_ctxs_)); + result->CreateOpNode(nullptr), local_scopes_, places_, nccl_ctxs_)); #else - result->Get("ops").emplace_back( - new ReduceOpHandle(result->nodes.back().get(), local_scopes_, places_)); + result->Get("ops").emplace_back(new ReduceOpHandle( + result->CreateOpNode(nullptr), local_scopes_, places_)); #endif auto *op_handle = result->Get("ops").back().get(); @@ -516,8 +535,7 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(Graph *result, op_handle->AddInput(prev_grad.get()); } auto &vars = result->Get("vars")[dst_dev_id][og]; - result->nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable)); - auto var = new VarHandle(result->nodes.back().get(), vars.size(), dst_dev_id, + auto var = new VarHandle(result->CreateVarNode(og), vars.size(), dst_dev_id, og, places_[dst_dev_id]); vars.emplace_back(var); op_handle->AddOutput(var); @@ -530,8 +548,7 @@ void MultiDevSSAGraphBuilder::ConnectOp(Graph *result, OpHandleBase *op, const std::string &prev_op_name) const { for (auto &prev_op : result->Get("ops")) { if (prev_op->Name() == prev_op_name) { - result->nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable)); - auto *dep_var = new DummyVarHandle(result->nodes.back().get()); + auto *dep_var = new DummyVarHandle(result->CreateVarNode("dummy")); prev_op->AddOutput(dep_var); result->Get("dep_vars").emplace(dep_var); op->AddInput(dep_var); @@ -540,22 +557,32 @@ void MultiDevSSAGraphBuilder::ConnectOp(Graph *result, OpHandleBase *op, } void MultiDevSSAGraphBuilder::CreateDistTrainOp(Graph *result, - const OpDesc &op) const { + ir::Node *node) const { int op_dev_id = -1; - if (op.Type() == "split_byref" || op.Type() == "split_selected_rows") { - op_dev_id = GetVarDeviceID(op.InputArgumentNames()[0]); + std::vector input_var_names; + std::vector output_var_names; + for (ir::Node *input : node->inputs) { + input_var_names.push_back(input->Var()->Name()); + } + for (ir::Node *output : node->outputs) { + output_var_names.push_back(output->Var()->Name()); + } + + if (node->Op()->Type() == "split_byref" || + node->Op()->Type() == "split_selected_rows") { + op_dev_id = GetVarDeviceID(input_var_names[0]); if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) { - op_dev_id = GetAppropriateDeviceID(op.InputArgumentNames()); - for (auto &varname : op.InputArgumentNames()) { + op_dev_id = GetAppropriateDeviceID(input_var_names); + for (auto &varname : input_var_names) { var_name_on_devices_.emplace(varname, op_dev_id); } } - for (auto &varname : op.OutputArgumentNames()) { + for (auto &varname : output_var_names) { var_name_on_devices_.emplace(varname, op_dev_id); } - } else if (op.Type() == "concat") { - op_dev_id = GetVarDeviceID(op.InputArgumentNames()[0]); - for (auto &varname : op.OutputArgumentNames()) { + } else if (node->Op()->Type() == "concat") { + op_dev_id = GetVarDeviceID(input_var_names[0]); + for (auto &varname : output_var_names) { var_name_on_devices_.emplace(varname, op_dev_id); } } else { @@ -565,35 +592,43 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(Graph *result, } PADDLE_ENFORCE(op_dev_id != -1, - "can not find right place for distributed op: %s", op.Type()); + "can not find right place 
for distributed op: %s", + node->Op()->Type()); - CreateComputationalOp(result, op, op_dev_id); - if (op.Type() == "concat") { + CreateComputationalOp(result, node, op_dev_id); + if (node->Op()->Type() == "concat") { ConnectOp(result, result->Get("ops").back().get(), "fetch_barrier"); } } // Create RPC related op handles that connects its in ops and out ops. -void MultiDevSSAGraphBuilder::CreateRPCOp(Graph *result, - const OpDesc &op) const { +void MultiDevSSAGraphBuilder::CreateRPCOp(Graph *result, ir::Node *node) const { int op_dev_id = -1; - if (op.Type() == "send") { - op_dev_id = GetVarDeviceID(op.InputArgumentNames()[0]); + if (node->Op()->Type() == "send") { + op_dev_id = GetVarDeviceID(node->inputs[0]->Var()->Name()); // the variable name which contains .block means it was splited by // split_byref op // so that we can balance the variable blocks to all the pserver // instances. if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce && - op.InputArgumentNames()[0].find(".block") == std::string::npos) { - op_dev_id = GetAppropriateDeviceID(op.InputArgumentNames()); - for (auto &varname : op.InputArgumentNames()) { + node->inputs[0]->Var()->Name().find(".block") == std::string::npos) { + std::vector input_var_names; + for (ir::Node *n : node->inputs) { + input_var_names.push_back(n->Var()->Name()); + } + op_dev_id = GetAppropriateDeviceID(input_var_names); + for (auto &varname : input_var_names) { var_name_on_devices_.emplace(varname, op_dev_id); } } - } else if (op.Type() == "recv") { - op_dev_id = GetAppropriateDeviceID(op.OutputArgumentNames()); - for (auto &varname : op.OutputArgumentNames()) { + } else if (node->Op()->Type() == "recv") { + std::vector output_var_names; + for (ir::Node *n : node->outputs) { + output_var_names.push_back(n->Var()->Name()); + } + op_dev_id = GetAppropriateDeviceID(output_var_names); + for (auto &varname : output_var_names) { var_name_on_devices_.emplace(varname, op_dev_id); } } else { @@ -602,21 +637,20 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(Graph *result, } PADDLE_ENFORCE(op_dev_id != -1, "can not find the right place for rpc op: %s", - op.Type()); + node->Op()->Type()); - result->nodes.emplace_back(new ir::Node(ir::Node::Type::kOperation)); - result->Get("ops").emplace_back( - new RPCOpHandle(result->nodes.back().get(), op, local_scopes_[op_dev_id], - op.Type(), places_[op_dev_id])); + result->Get("ops").emplace_back(new RPCOpHandle( + result->CreateOpNode(node->Op()), *node->Op(), local_scopes_[op_dev_id], + node->Op()->Type(), places_[op_dev_id])); - if (op.Type() == "send_barrier") { + if (node->Op()->Type() == "send_barrier") { ConnectOp(result, result->Get("ops").back().get(), "send"); - } else if (op.Type() == "recv") { + } else if (node->Op()->Type() == "recv") { ConnectOp(result, result->Get("ops").back().get(), "send_barrier"); - } else if (op.Type() == "fetch_barrier") { + } else if (node->Op()->Type() == "fetch_barrier") { ConnectOp(result, result->Get("ops").back().get(), "recv"); - } else if (op.Type() == "send") { + } else if (node->Op()->Type() == "send") { // do nothing } else { PADDLE_THROW( @@ -624,12 +658,12 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(Graph *result, "send, send_barrier. 
recv, fetch_barrier]"); } - CreateOpHandleIOs(result, op, op_dev_id); + CreateOpHandleIOs(result, node, op_dev_id); } -bool MultiDevSSAGraphBuilder::IsScaleLossOp(const OpDesc &op) const { +bool MultiDevSSAGraphBuilder::IsScaleLossOp(ir::Node *node) const { return boost::get( - op.GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) == + node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) == (static_cast(OpRole::kBackward) | static_cast(OpRole::kLoss)) && !loss_var_name_.empty(); // If loss_var is empty. This is test mode diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.h b/paddle/fluid/framework/details/multi_devices_graph_builder.h index 248ea8ea62..2b7f4f586b 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.h +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.h @@ -46,13 +46,11 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { const std::vector &local_scopes, const BuildStrategy &strategy); #endif - - std::unique_ptr Build(std::unique_ptr graph) const override; + std::unique_ptr Apply(std::unique_ptr graph) const override; int GetVarDeviceID(const std::string &varname) const override; private: - void CreateOpHandleIOs(Graph *result, const OpDesc &op, - size_t device_id) const; + void CreateOpHandleIOs(Graph *result, ir::Node *node, size_t device_id) const; private: std::string loss_var_name_; @@ -64,40 +62,39 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { platform::NCCLContextMap *nccl_ctxs_; #endif - bool IsScaleLossOp(const OpDesc &op) const; + bool IsScaleLossOp(ir::Node *node) const; - void CreateRPCOp(Graph *result, const OpDesc &op) const; - void CreateDistTrainOp(Graph *result, const OpDesc &op) const; + void CreateRPCOp(Graph *result, ir::Node *node) const; + void CreateDistTrainOp(Graph *result, ir::Node *node) const; /** * Is this operator as the end-point operator before/after send operator. 
*/ - bool IsDistTrainOp(const OpDesc &op, - const std::vector &send_vars, + bool IsDistTrainOp(ir::Node *node, const std::vector &send_vars, const std::vector &recv_vars) const; std::vector FindDistTrainSendVars( - const ProgramDesc &program) const; + const std::vector> &nodes) const; std::vector FindDistTrainRecvVars( - const ProgramDesc &program) const; + const std::vector> &nodes) const; void ConnectOp(Graph *result, OpHandleBase *op, const std::string &prev_op_name) const; - void CreateComputationalOps(Graph *result, const OpDesc &op, + void CreateComputationalOps(Graph *result, ir::Node *node, size_t num_places) const; void CreateScaleLossGradOp(Graph *result) const; VarHandle *CreateReduceOp(Graph *result, const std::string &og, int dst_dev_id) const; - void CreateComputationalOp(Graph *result, const OpDesc &op, int dev_id) const; + void CreateComputationalOp(Graph *result, ir::Node *node, int dev_id) const; bool IsParameterGradientOnce( const std::string &og, std::unordered_set *og_has_been_broadcast) const; - int GetOpDeviceID(const OpDesc &op) const; + int GetOpDeviceID(ir::Node *node) const; void InsertAllReduceOp(Graph *result, const std::string &og) const; diff --git a/paddle/fluid/framework/details/reduce_op_handle_test.cc b/paddle/fluid/framework/details/reduce_op_handle_test.cc index d029dd9e15..e7c83ffd32 100644 --- a/paddle/fluid/framework/details/reduce_op_handle_test.cc +++ b/paddle/fluid/framework/details/reduce_op_handle_test.cc @@ -97,7 +97,7 @@ struct TestReduceOpHandle { } param_scopes_[out_scope_idx]->Var("out"); - nodes.emplace_back(new ir::Node(ir::Node::Type::kOperation)); + nodes.emplace_back(new ir::Node()); if (use_gpu_) { #ifdef PADDLE_WITH_CUDA op_handle_.reset(new ReduceOpHandle(nodes.back().get(), local_scopes_, @@ -121,7 +121,7 @@ struct TestReduceOpHandle { if (!use_gpu_) { op_handle_->SetDeviceContext(gpu_list_[j], ctxs_[j].get()); } - nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable)); + nodes.emplace_back(new ir::Node()); auto *in_var_handle = new VarHandle(nodes.back().get(), 1, j, "input", gpu_list_[j]); in_var_handle->ClearGeneratedOp(); @@ -137,7 +137,7 @@ struct TestReduceOpHandle { op_handle_->AddInput(in_dummy_var_handle); // add output - nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable)); + nodes.emplace_back(new ir::Node()); auto *out_var_handle = new VarHandle(nodes.back().get(), 2, out_scope_idx, "out", gpu_list_[out_scope_idx]); vars_.emplace_back(out_var_handle); diff --git a/paddle/fluid/framework/details/ssa_graph_builder.cc b/paddle/fluid/framework/details/ssa_graph_builder.cc index 846f98ddfa..6a8bd7875c 100644 --- a/paddle/fluid/framework/details/ssa_graph_builder.cc +++ b/paddle/fluid/framework/details/ssa_graph_builder.cc @@ -37,8 +37,7 @@ void SSAGraphBuilder::PolishGraphToSupportDataHazards(Graph *graph) { continue; } - graph->nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable)); - auto *dep_var = new DummyVarHandle(graph->nodes.back().get()); + auto *dep_var = new DummyVarHandle(graph->CreateVarNode("dummy")); read_op->AddOutput(dep_var); write_op->AddInput(dep_var); graph->Get("dep_vars").emplace(dep_var); @@ -49,15 +48,14 @@ void SSAGraphBuilder::PolishGraphToSupportDataHazards(Graph *graph) { } VarHandle *SSAGraphBuilder::CreateOrGetLatestVarHandle( - Graph *graph, const std::string &each_var_name, - const platform::Place &place, size_t place_offset) { + Graph *graph, ir::Node *node, const platform::Place &place, + size_t place_offset) { auto &var_holders = graph->Get("vars")[place_offset]; - auto 
&var_holder = var_holders[each_var_name]; + auto &var_holder = var_holders[node->Var()->Name()]; VarHandle *var = nullptr; if (var_holder.empty()) { - graph->nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable)); - var = new VarHandle(graph->nodes.back().get(), 0, place_offset, - each_var_name, place); + var = new VarHandle(graph->CreateVarNode(node->Var()), 0, place_offset, + node->Var()->Name(), place); var_holder.emplace_back(var); } else { var = var_holder.rbegin()->get(); @@ -66,14 +64,13 @@ VarHandle *SSAGraphBuilder::CreateOrGetLatestVarHandle( } void SSAGraphBuilder::CreateOpOutput(Graph *graph, OpHandleBase *op_handle, - const std::string &each_var_name, + ir::Node *node, const platform::Place &place, size_t place_offset) { - auto &vars = graph->Get("vars")[place_offset][each_var_name]; + auto &vars = graph->Get("vars")[place_offset][node->Var()->Name()]; size_t version = vars.size(); - graph->nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable)); - auto var = new VarHandle(graph->nodes.back().get(), version, place_offset, - each_var_name, place); + auto var = new VarHandle(graph->CreateVarNode(node->Var()), version, + place_offset, node->Var()->Name(), place); vars.emplace_back(var); op_handle->AddOutput(var); } @@ -85,8 +82,7 @@ void SSAGraphBuilder::AddOutputToLeafOps(Graph *graph) { if (!op->Outputs().empty()) { continue; } - graph->nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable)); - auto *dummy_leaf = new DummyVarHandle(graph->nodes.back().get()); + auto *dummy_leaf = new DummyVarHandle(graph->CreateVarNode("dummy")); graph->Get("dep_vars").emplace(dummy_leaf); op->AddOutput(dummy_leaf); } diff --git a/paddle/fluid/framework/details/ssa_graph_builder.h b/paddle/fluid/framework/details/ssa_graph_builder.h index 4fbf036241..9933bf32b7 100644 --- a/paddle/fluid/framework/details/ssa_graph_builder.h +++ b/paddle/fluid/framework/details/ssa_graph_builder.h @@ -23,6 +23,7 @@ #include "paddle/fluid/platform/place.h" #include "paddle/fluid/framework/ir/graph.h" +#include "paddle/fluid/framework/ir/pass.h" namespace paddle { namespace framework { @@ -34,11 +35,11 @@ typedef std::vector< typedef std::unordered_set> GraphDepVars; typedef std::vector> GraphOps; -class SSAGraphBuilder { +class SSAGraphBuilder : public ir::Pass { public: SSAGraphBuilder() {} virtual ~SSAGraphBuilder() {} - virtual std::unique_ptr Build(std::unique_ptr graph) const = 0; + virtual int GetVarDeviceID(const std::string &var_name) const = 0; DISABLE_COPY_AND_ASSIGN(SSAGraphBuilder); @@ -53,16 +54,15 @@ class SSAGraphBuilder { */ static void PolishGraphToSupportDataHazards(Graph *graph); - static VarHandle *CreateOrGetLatestVarHandle(Graph *graph, - const std::string &each_var_name, + static VarHandle *CreateOrGetLatestVarHandle(Graph *graph, ir::Node *node, const platform::Place &place, size_t place_offset); // Add an output variable (each_var_name, place, place_offset) to op_handle, // which belongs to graph static void CreateOpOutput(Graph *graph, OpHandleBase *op_handle, - const std::string &each_var_name, - const platform::Place &place, size_t place_offset); + ir::Node *node, const platform::Place &place, + size_t place_offset); static void AddOutputToLeafOps(Graph *graph); }; diff --git a/paddle/fluid/framework/details/ssa_graph_checker.h b/paddle/fluid/framework/details/ssa_graph_checker.h index 2c8b2e13c5..f108061038 100644 --- a/paddle/fluid/framework/details/ssa_graph_checker.h +++ b/paddle/fluid/framework/details/ssa_graph_checker.h @@ -28,10 +28,10 @@ class 
SSAGraghBuilderWithChecker : public SSAGraphBuilder { std::unique_ptr&& builder) : builder_(std::move(builder)) {} - std::unique_ptr Build(std::unique_ptr graph) const override { - auto new_graph = builder_->Build(std::move(graph)); + std::unique_ptr Apply(std::unique_ptr graph) const override { + auto new_graph = builder_->Apply(std::move(graph)); PADDLE_ENFORCE(IsValidGraph(new_graph.get())); - return new_graph; + return std::move(new_graph); } int GetVarDeviceID(const std::string& var_name) const override { diff --git a/paddle/fluid/framework/details/ssa_graph_printer.h b/paddle/fluid/framework/details/ssa_graph_printer.h index 35f2a1b4f0..411be02988 100644 --- a/paddle/fluid/framework/details/ssa_graph_printer.h +++ b/paddle/fluid/framework/details/ssa_graph_printer.h @@ -50,10 +50,10 @@ class SSAGraghBuilderWithPrinter : public SSAGraphBuilder { stream_ptr_(std::move(sout)), stream_ref_(*stream_ptr_) {} - std::unique_ptr Build(std::unique_ptr graph) const override { - auto new_graph = builder_->Build(std::move(graph)); + std::unique_ptr Apply(std::unique_ptr graph) const override { + auto new_graph = builder_->Apply(std::move(graph)); printer_->Print(*new_graph, stream_ref_); - return new_graph; + return std::move(new_graph); } int GetVarDeviceID(const std::string& var_name) const override { diff --git a/paddle/fluid/framework/ir/graph.cc b/paddle/fluid/framework/ir/graph.cc index 28ad4efc71..c1f8f917c4 100644 --- a/paddle/fluid/framework/ir/graph.cc +++ b/paddle/fluid/framework/ir/graph.cc @@ -13,12 +13,45 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/framework/ir/graph.h" +#include "paddle/fluid/framework/program_desc.h" +#include "paddle/fluid/framework/var_desc.h" namespace paddle { namespace framework { std::unique_ptr ProgramToGraph(const ProgramDesc &program) { std::unique_ptr graph(new Graph(program)); + + std::unordered_map all_vars; + for (auto *var : program.Block(0).AllVars()) { + all_vars.emplace(var->Name(), var); + } + + for (auto *op : program.Block(0).AllOps()) { + ir::Node *node = graph->CreateOpNode(op); + + for (auto &each_var_name : op->InputArgumentNames()) { + ir::Node *var = nullptr; + if (all_vars.count(each_var_name) != 0) { + var = graph->CreateVarNode(all_vars.at(each_var_name)); + } else { + var = graph->CreateVarNode(each_var_name); + } + node->inputs.push_back(var); + var->outputs.push_back(node); + } + + for (auto &each_var_name : op->OutputArgumentNames()) { + ir::Node *var = nullptr; + if (all_vars.count(each_var_name) != 0) { + var = graph->CreateVarNode(all_vars.at(each_var_name)); + } else { + var = graph->CreateVarNode(each_var_name); + } + node->outputs.push_back(var); + var->inputs.push_back(node); + } + } return std::move(graph); } diff --git a/paddle/fluid/framework/ir/graph.h b/paddle/fluid/framework/ir/graph.h index e83cb5a82a..ff4f31fb7a 100644 --- a/paddle/fluid/framework/ir/graph.h +++ b/paddle/fluid/framework/ir/graph.h @@ -39,8 +39,6 @@ class Graph { attr_dels_.clear(); } - const ProgramDesc& Program() const { return program_; } - template AttrType& Get(const std::string& attr_name) const { return *boost::any_cast(attrs_.at(attr_name)); @@ -63,11 +61,30 @@ class Graph { return attr; } + ir::Node* CreateVarNode(VarDesc* var_desc) { + nodes.emplace_back(new ir::Node(var_desc)); + return nodes.back().get(); + } + + ir::Node* CreateOpNode(OpDesc* op_desc) { + nodes.emplace_back(new ir::Node(op_desc)); + return nodes.back().get(); + } + + // TODO(panyx0718): 
Need to handle CreateOpNode(nullptr). + ir::Node* CreateVarNode(const std::string& var_name) { + var_descs_.emplace_back(new VarDesc(var_name)); + nodes.emplace_back(new ir::Node(var_descs_.back().get())); + return nodes.back().get(); + } + std::vector inputs; std::vector outputs; std::vector> nodes; + std::vector> var_descs_; private: + // NOTE: program_ shouldn't be exposed to user. const ProgramDesc& program_; std::map attrs_; std::map> attr_dels_; diff --git a/paddle/fluid/framework/ir/node.h b/paddle/fluid/framework/ir/node.h index 94ace92953..0e0b81a7b1 100644 --- a/paddle/fluid/framework/ir/node.h +++ b/paddle/fluid/framework/ir/node.h @@ -21,6 +21,8 @@ limitations under the License. */ #include #include #include +#include "paddle/fluid/framework/op_desc.h" +#include "paddle/fluid/framework/var_desc.h" #include "paddle/fluid/platform/macros.h" #include "paddle/fluid/platform/variant.h" @@ -32,10 +34,12 @@ class Node { public: enum class Type { kNone = -1, kOperation, kVariable }; + Node() : type_(Type::kNone) {} + explicit Node(Type type) : type_(type) {} virtual ~Node() { - for (auto &attr : attrs_) { + for (auto& attr : attrs_) { if (attr_dels_.find(attr.first) != attr_dels_.end()) { attr_dels_[attr.first](); } @@ -47,23 +51,34 @@ class Node { Type NodeType() const { return type_; } template - void Set(const std::string &name, AttrType attr) { + void Set(const std::string& name, AttrType attr) { attrs_[name] = attr; } template - void Set(const std::string &name, AttrType *attr, + void Set(const std::string& name, AttrType* attr, std::function attr_del) { attrs_[name] = attr; attr_dels_[name] = attr_del; } - std::vector inputs; - std::vector outputs; + VarDesc* Var() { return var_desc_; } + OpDesc* Op() { return op_desc_; } + + explicit Node(VarDesc* var_desc) + : var_desc_(var_desc), op_desc_(nullptr), type_(Type::kVariable) {} + + explicit Node(OpDesc* op_desc) + : var_desc_(nullptr), op_desc_(op_desc), type_(Type::kOperation) {} + + std::vector inputs; + std::vector outputs; protected: std::map attrs_; std::map> attr_dels_; + VarDesc* var_desc_; + OpDesc* op_desc_; Type type_; private: diff --git a/paddle/fluid/framework/ir/pass.h b/paddle/fluid/framework/ir/pass.h index 2fc26c053f..3f0fcff857 100644 --- a/paddle/fluid/framework/ir/pass.h +++ b/paddle/fluid/framework/ir/pass.h @@ -20,15 +20,15 @@ limitations under the License. 
*/ namespace paddle { namespace framework { +namespace ir { class Pass { public: Pass() = default; virtual ~Pass() {} - virtual std::unique_ptr Apply(std::unique_ptr graph) { - return std::move(graph); - } -}; + virtual std::unique_ptr Apply(std::unique_ptr graph) const = 0; +}; +} // namespace ir } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc index d30aba07a0..c9014ffdf5 100644 --- a/paddle/fluid/framework/parallel_executor.cc +++ b/paddle/fluid/framework/parallel_executor.cc @@ -131,13 +131,10 @@ ParallelExecutor::ParallelExecutor( PADDLE_THROW("Not compiled with CUDA."); #endif } - builder_ = builder_factory.Create(); - std::unique_ptr graph = builder_->Build(ProgramToGraph(main_program)); - + std::unique_ptr graph = builder_->Apply(ProgramToGraph(main_program)); member_->executor_.reset(new details::ThreadedSSAGraphExecutor( exec_strategy, member_->local_scopes_, places, std::move(graph))); - member_->executor_.reset(new details::ScopeBufferedSSAGraphExecutor( exec_strategy, member_->local_scopes_, std::move(var_infos), member_->places_, std::move(member_->executor_))); diff --git a/python/paddle/fluid/parallel_executor.py b/python/paddle/fluid/parallel_executor.py index 10028a8c6e..59789c6def 100644 --- a/python/paddle/fluid/parallel_executor.py +++ b/python/paddle/fluid/parallel_executor.py @@ -148,6 +148,7 @@ class ParallelExecutor(object): lambda var: var.persistable and var.type != core.VarDesc.VarType.RAW, main.list_vars()) ] + sys.stderr.write('!!!!!!!!before\n') self.executor = core.ParallelExecutor( self._places, @@ -158,6 +159,7 @@ class ParallelExecutor(object): set(self.persistable_vars), main.desc, loss_name if loss_name else '', scope, local_scopes, exec_strategy, build_strategy, num_trainers, trainer_id) + sys.stderr.write('!!!!!!!!after\n') self.scope = scope def run(self, fetch_list, feed=None, feed_dict=None, return_numpy=True): From 10786a243ebe33e425a6202bd541a180bc17c510 Mon Sep 17 00:00:00 2001 From: Xin Pan Date: Fri, 13 Jul 2018 19:29:43 +0800 Subject: [PATCH 12/22] polish graph --- .../details/broadcast_op_handle_test.cc | 10 +-- .../details/computation_op_handle.cc | 6 +- .../framework/details/computation_op_handle.h | 3 +- .../details/gather_op_handle_test.cc | 10 +-- .../details/multi_devices_graph_builder.cc | 83 ++++++++++--------- .../details/reduce_op_handle_test.cc | 6 +- .../framework/details/ssa_graph_builder.cc | 19 +++-- .../details/threaded_ssa_graph_executor.cc | 4 +- paddle/fluid/framework/ir/graph.cc | 10 +-- paddle/fluid/framework/ir/graph.h | 6 +- paddle/fluid/framework/ir/node.h | 58 ++++++------- python/paddle/fluid/parallel_executor.py | 2 - 12 files changed, 104 insertions(+), 113 deletions(-) diff --git a/paddle/fluid/framework/details/broadcast_op_handle_test.cc b/paddle/fluid/framework/details/broadcast_op_handle_test.cc index 1609b5965c..63a6ed9082 100644 --- a/paddle/fluid/framework/details/broadcast_op_handle_test.cc +++ b/paddle/fluid/framework/details/broadcast_op_handle_test.cc @@ -96,7 +96,7 @@ struct TestBroadcastOpHandle { } param_scopes_[input_scope_idx]->Var("input"); - std::unique_ptr n(new ir::Node()); + std::unique_ptr n(new ir::Node("node0")); if (use_gpu_) { #ifdef PADDLE_WITH_CUDA op_handle_.reset(new BroadcastOpHandle(n.get(), local_scopes_, gpu_list_, @@ -114,7 +114,7 @@ struct TestBroadcastOpHandle { #endif } - std::unique_ptr v(new ir::Node()); + std::unique_ptr v(new ir::Node("node1")); auto* in_var_handle = new 
VarHandle(v.get(), 1, input_scope_idx, "input", gpu_list_[input_scope_idx]); vars_.emplace_back(in_var_handle); @@ -122,7 +122,7 @@ struct TestBroadcastOpHandle { // add dummy var - std::unique_ptr v2(new ir::Node()); + std::unique_ptr v2(new ir::Node("node2")); vars_.emplace_back(new DummyVarHandle(v2.get())); DummyVarHandle* dummy_var_handle = static_cast(vars_.back().get()); @@ -133,7 +133,7 @@ struct TestBroadcastOpHandle { if (!use_gpu_) { op_handle_->SetDeviceContext(gpu_list_[j], ctxs_[j].get()); } - std::unique_ptr v3(new ir::Node()); + std::unique_ptr v3(new ir::Node("node3")); VarHandle* out_var_handle = new VarHandle(v3.get(), 2, j, "out", gpu_list_[j]); vars_.emplace_back(out_var_handle); @@ -141,7 +141,7 @@ struct TestBroadcastOpHandle { } // add dummy var - std::unique_ptr v4(new ir::Node()); + std::unique_ptr v4(new ir::Node("node4")); vars_.emplace_back(new DummyVarHandle(v4.get())); DummyVarHandle* out_dummy_var_handle = static_cast(vars_.back().get()); diff --git a/paddle/fluid/framework/details/computation_op_handle.cc b/paddle/fluid/framework/details/computation_op_handle.cc index 16ad30d491..b6282debdb 100644 --- a/paddle/fluid/framework/details/computation_op_handle.cc +++ b/paddle/fluid/framework/details/computation_op_handle.cc @@ -19,10 +19,10 @@ namespace paddle { namespace framework { namespace details { -ComputationOpHandle::ComputationOpHandle(ir::Node *node, const OpDesc &op_desc, - Scope *scope, platform::Place place) +ComputationOpHandle::ComputationOpHandle(ir::Node *node, Scope *scope, + platform::Place place) : OpHandleBase(node), - op_(framework::OpRegistry::CreateOp(op_desc)), + op_(framework::OpRegistry::CreateOp(*node->Op())), scope_(scope), place_(place) {} diff --git a/paddle/fluid/framework/details/computation_op_handle.h b/paddle/fluid/framework/details/computation_op_handle.h index 9ca1d927b8..d9fcd92427 100644 --- a/paddle/fluid/framework/details/computation_op_handle.h +++ b/paddle/fluid/framework/details/computation_op_handle.h @@ -28,8 +28,7 @@ namespace framework { namespace details { struct ComputationOpHandle : public OpHandleBase { public: - ComputationOpHandle(ir::Node *node, const OpDesc &op_desc, Scope *scope, - platform::Place place); + ComputationOpHandle(ir::Node *node, Scope *scope, platform::Place place); std::string Name() const override; diff --git a/paddle/fluid/framework/details/gather_op_handle_test.cc b/paddle/fluid/framework/details/gather_op_handle_test.cc index f80cabf501..e3806ac5e1 100644 --- a/paddle/fluid/framework/details/gather_op_handle_test.cc +++ b/paddle/fluid/framework/details/gather_op_handle_test.cc @@ -82,13 +82,13 @@ struct TestGatherOpHandle { } param_scopes_[input_scope_idx]->Var("out"); - nodes.emplace_back(new ir::Node()); + nodes.emplace_back(new ir::Node("node")); op_handle_.reset( new GatherOpHandle(nodes.back().get(), local_scopes_, gpu_list_)); // add input for (size_t j = 0; j < gpu_list_.size(); ++j) { op_handle_->SetDeviceContext(gpu_list_[j], ctxs_[j].get()); - nodes.emplace_back(new ir::Node()); + nodes.emplace_back(new ir::Node("node1")); auto* in_var_handle = new VarHandle(nodes.back().get(), 1, j, "input", gpu_list_[j]); vars_.emplace_back(in_var_handle); @@ -96,7 +96,7 @@ struct TestGatherOpHandle { } // add dummy var - nodes.emplace_back(new ir::Node()); + nodes.emplace_back(new ir::Node("node2")); vars_.emplace_back(new DummyVarHandle(nodes.back().get())); DummyVarHandle* in_dummy_var_handle = static_cast(vars_.back().get()); @@ -104,14 +104,14 @@ struct TestGatherOpHandle { 
op_handle_->AddInput(in_dummy_var_handle); // add output - nodes.emplace_back(new ir::Node()); + nodes.emplace_back(new ir::Node("node3")); auto* out_var_handle = new VarHandle(nodes.back().get(), 2, input_scope_idx, "out", gpu_list_[input_scope_idx]); vars_.emplace_back(out_var_handle); op_handle_->AddOutput(out_var_handle); // add dummy var - nodes.emplace_back(new ir::Node()); + nodes.emplace_back(new ir::Node("node4")); vars_.emplace_back(new DummyVarHandle(nodes.back().get())); DummyVarHandle* dummy_var_handle = static_cast(vars_.back().get()); diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.cc b/paddle/fluid/framework/details/multi_devices_graph_builder.cc index d66bc40090..035fb629a8 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.cc @@ -90,7 +90,7 @@ std::vector MultiDevSSAGraphBuilder::FindDistTrainSendVars( // since parameters are all in block 0, // it's enough to only scan send ops in block 0 for (auto &node : nodes) { - if (!node->Op()) continue; + if (node->NodeType() != ir::Node::Type::kOperation) continue; OpDesc *op = node->Op(); // TODO(Yancey1989): use a graceful method to find send op, // instead of the the hard code string @@ -108,7 +108,7 @@ std::vector MultiDevSSAGraphBuilder::FindDistTrainRecvVars( const std::vector> &nodes) const { std::vector recv_vars; for (auto &node : nodes) { - if (!node->Op()) continue; + if (node->NodeType() != ir::Node::Type::kOperation) continue; OpDesc *op = node->Op(); // TODO(Yancey1989): use a graceful method to find recv op, // instead of the hard code string @@ -149,10 +149,10 @@ bool MultiDevSSAGraphBuilder::IsDistTrainOp( std::vector input_var_names; std::vector output_var_names; for (ir::Node *input : node->inputs) { - input_var_names.push_back(input->Var()->Name()); + input_var_names.push_back(input->Name()); } for (ir::Node *output : node->outputs) { - output_var_names.push_back(output->Var()->Name()); + output_var_names.push_back(output->Name()); } return checker(output_var_names, send_vars) || @@ -181,13 +181,13 @@ size_t MultiDevSSAGraphBuilder::GetAppropriateDeviceID( std::unique_ptr MultiDevSSAGraphBuilder::Apply( std::unique_ptr graph) const { + // Rebuild the graph structure. auto nodes = std::move(graph->nodes); graph->nodes.clear(); - LOG(ERROR) << "origin nodes count " << nodes.size(); for (auto &node : nodes) { - if (node->Var()) { - all_vars_.emplace(node->Var()->Name(), node->Var()); + if (node->NodeType() == ir::Node::Type::kVariable) { + all_vars_.emplace(node->Name(), node->Var()); } } @@ -212,7 +212,7 @@ std::unique_ptr MultiDevSSAGraphBuilder::Apply( // TODO(panyx0718): FIXME: nodes should be sorted by "program" order. for (auto &node : nodes) { - if (!node->Op()) continue; + if (node->NodeType() != ir::Node::Type::kOperation) continue; if (boost::get( node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) == static_cast(OpRole::kRPC)) { @@ -235,7 +235,7 @@ std::unique_ptr MultiDevSSAGraphBuilder::Apply( if (op_dev_id != -1) { // This op only runs on one specific device. 
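// (Annotation, not part of the patch: per GetOpDeviceID() above, op_dev_id is
//  -1 unless the reduce strategy is kReduce and the op carries the kOptimize
//  role; in that case the op is pinned to the device that already holds the
//  corresponding parameter gradient, so only that device creates the op.)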
CreateComputationalOp(&result, node.get(), op_dev_id); for (ir::Node *n : node->outputs) { - var_name_on_devices_.emplace(n->Var()->Name(), op_dev_id); + var_name_on_devices_.emplace(n->Name(), op_dev_id); } } else { // This op runs on all devices, and its output may have parameter's @@ -351,10 +351,10 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(Graph *result, const std::string &p_name, size_t src_dev_id) const { #ifdef PADDLE_WITH_CUDA - auto *op_handle = new BroadcastOpHandle(result->CreateOpNode(nullptr), + auto *op_handle = new BroadcastOpHandle(result->CreateEmptyNode("broadcast"), local_scopes_, places_, nccl_ctxs_); #else - auto *op_handle = new BroadcastOpHandle(result->CreateOpNode(nullptr), + auto *op_handle = new BroadcastOpHandle(result->CreateEmptyNode("broadcast"), local_scopes_, places_); #endif result->Get("ops").emplace_back(op_handle); @@ -367,8 +367,8 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(Graph *result, auto &p = places_[i]; SetCommunicationContext(op_handle, p); auto &vars = result->Get("vars").at(i).at(p_name); - auto *out_var = - new VarHandle(result->CreateVarNode(p_name), vars.size(), i, p_name, p); + auto *out_var = new VarHandle(result->CreateEmptyNode(p_name), vars.size(), + i, p_name, p); vars.emplace_back(out_var); op_handle->AddOutput(out_var); } @@ -378,7 +378,7 @@ void MultiDevSSAGraphBuilder::CreateComputationalOp(Graph *result, ir::Node *node, int dev_id) const { result->Get("ops").emplace_back( - new ComputationOpHandle(result->CreateOpNode(node->Op()), *node->Op(), + new ComputationOpHandle(result->CreateOpNode(node->Op()), local_scopes_[dev_id], places_[dev_id])); CreateOpHandleIOs(result, node, dev_id); } @@ -386,11 +386,12 @@ void MultiDevSSAGraphBuilder::CreateComputationalOp(Graph *result, void MultiDevSSAGraphBuilder::InsertAllReduceOp(Graph *result, const std::string &og) const { #ifdef PADDLE_WITH_CUDA - result->Get("ops").emplace_back(new AllReduceOpHandle( - result->CreateOpNode(nullptr), local_scopes_, places_, nccl_ctxs_)); + result->Get("ops").emplace_back( + new AllReduceOpHandle(result->CreateEmptyNode("allreduce"), local_scopes_, + places_, nccl_ctxs_)); #else result->Get("ops").emplace_back(new AllReduceOpHandle( - result->CreateOpNode(nullptr), local_scopes_, places_)); + result->CreateEmptyNode("allreduce"), local_scopes_, places_)); #endif auto *op_handle = result->Get("ops").back().get(); @@ -402,7 +403,8 @@ void MultiDevSSAGraphBuilder::InsertAllReduceOp(Graph *result, auto &prev_grad = vars.back(); op_handle->AddInput(prev_grad.get()); - auto var = new VarHandle(result->CreateVarNode(og), vars.size(), i, og, p); + auto var = + new VarHandle(result->CreateEmptyNode(og), vars.size(), i, og, p); vars.emplace_back(var); op_handle->AddOutput(var); } @@ -411,11 +413,12 @@ void MultiDevSSAGraphBuilder::InsertAllReduceOp(Graph *result, void MultiDevSSAGraphBuilder::InsertDataBalanceOp( Graph *result, const std::vector &datas) const { #ifdef PADDLE_WITH_CUDA - result->Get("ops").emplace_back(new DataBalanceOpHandle( - result->CreateOpNode(nullptr), local_scopes_, places_, nccl_ctxs_)); + result->Get("ops").emplace_back( + new DataBalanceOpHandle(result->CreateEmptyNode("data_balance"), + local_scopes_, places_, nccl_ctxs_)); #else result->Get("ops").emplace_back(new DataBalanceOpHandle( - result->CreateOpNode(nullptr), local_scopes_, places_)); + result->CreateEmptyNode("data_balance"), local_scopes_, places_)); #endif auto *op_handle = result->Get("ops").back().get(); for (size_t i = 0; i < places_.size(); ++i) { @@ 
-425,7 +428,7 @@ void MultiDevSSAGraphBuilder::InsertDataBalanceOp( auto &vars = result->Get("vars")[i][d_name]; PADDLE_ENFORCE(!vars.empty()); op_handle->AddInput(vars.back().get()); - auto var = new VarHandle(result->CreateVarNode(d_name), vars.size(), i, + auto var = new VarHandle(result->CreateEmptyNode(d_name), vars.size(), i, d_name, p); vars.emplace_back(var); op_handle->AddOutput(var); @@ -455,12 +458,12 @@ int MultiDevSSAGraphBuilder::GetOpDeviceID(ir::Node *node) const { return -1; } auto param_grad = boost::get>( - node->Op()->.GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName())); + node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName())); PADDLE_ENFORCE_EQ(param_grad.size(), 2U); int dev_id = GetVarDeviceID(param_grad[1]); - PADDLE_ENFORCE_NE(dev_id, -1, "dev_id should not be -1.[%s, %s]", op.Type(), - param_grad[0]); + PADDLE_ENFORCE_NE(dev_id, -1, "dev_id should not be -1.[%s, %s]", + node->Op()->Type(), param_grad[0]); return dev_id; } @@ -481,8 +484,8 @@ void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(Graph *result) const { platform::DeviceContextPool::Instance().Get(platform::CPUPlace()); #endif auto *op_handle = new ScaleLossGradOpHandle( - result->CreateOpNode(nullptr), local_scopes_.size(), local_scopes_[i], - places_[i], communication_dev_ctx); + result->CreateEmptyNode("scale_loss_grad"), local_scopes_.size(), + local_scopes_[i], places_[i], communication_dev_ctx); result->Get("ops").emplace_back(op_handle); // FIXME: Currently ScaleLossGradOp only use device_count as scale @@ -495,7 +498,7 @@ void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(Graph *result) const { const std::string grad_var_name = GradVarName(loss_var_name_); auto &vars = result->Get("vars")[i][grad_var_name]; size_t version = vars.size(); - auto var = new VarHandle(result->CreateVarNode(grad_var_name), version, i, + auto var = new VarHandle(result->CreateEmptyNode(grad_var_name), version, i, grad_var_name, places_[i]); vars.emplace_back(var); op_handle->AddOutput(var); @@ -508,8 +511,8 @@ void MultiDevSSAGraphBuilder::CreateComputationalOps(Graph *result, for (size_t scope_idx = 0; scope_idx < num_places; ++scope_idx) { auto p = places_[scope_idx]; auto s = local_scopes_[scope_idx]; - result->Get("ops").emplace_back(new ComputationOpHandle( - result->CreateOpNode(node->Op()), *node->Op(), s, p)); + result->Get("ops").emplace_back( + new ComputationOpHandle(result->CreateOpNode(node->Op()), s, p)); CreateOpHandleIOs(result, node, scope_idx); } } @@ -519,10 +522,10 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(Graph *result, int dst_dev_id) const { #ifdef PADDLE_WITH_CUDA result->Get("ops").emplace_back(new ReduceOpHandle( - result->CreateOpNode(nullptr), local_scopes_, places_, nccl_ctxs_)); + result->CreateEmptyNode("reduce"), local_scopes_, places_, nccl_ctxs_)); #else result->Get("ops").emplace_back(new ReduceOpHandle( - result->CreateOpNode(nullptr), local_scopes_, places_)); + result->CreateEmptyNode("reduce"), local_scopes_, places_)); #endif auto *op_handle = result->Get("ops").back().get(); @@ -535,7 +538,7 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(Graph *result, op_handle->AddInput(prev_grad.get()); } auto &vars = result->Get("vars")[dst_dev_id][og]; - auto var = new VarHandle(result->CreateVarNode(og), vars.size(), dst_dev_id, + auto var = new VarHandle(result->CreateEmptyNode(og), vars.size(), dst_dev_id, og, places_[dst_dev_id]); vars.emplace_back(var); op_handle->AddOutput(var); @@ -548,7 +551,7 @@ void MultiDevSSAGraphBuilder::ConnectOp(Graph *result, 
OpHandleBase *op, const std::string &prev_op_name) const { for (auto &prev_op : result->Get("ops")) { if (prev_op->Name() == prev_op_name) { - auto *dep_var = new DummyVarHandle(result->CreateVarNode("dummy")); + auto *dep_var = new DummyVarHandle(result->CreateEmptyNode("dummy")); prev_op->AddOutput(dep_var); result->Get("dep_vars").emplace(dep_var); op->AddInput(dep_var); @@ -562,10 +565,10 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(Graph *result, std::vector input_var_names; std::vector output_var_names; for (ir::Node *input : node->inputs) { - input_var_names.push_back(input->Var()->Name()); + input_var_names.push_back(input->Name()); } for (ir::Node *output : node->outputs) { - output_var_names.push_back(output->Var()->Name()); + output_var_names.push_back(output->Name()); } if (node->Op()->Type() == "split_byref" || @@ -606,16 +609,16 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(Graph *result, void MultiDevSSAGraphBuilder::CreateRPCOp(Graph *result, ir::Node *node) const { int op_dev_id = -1; if (node->Op()->Type() == "send") { - op_dev_id = GetVarDeviceID(node->inputs[0]->Var()->Name()); + op_dev_id = GetVarDeviceID(node->inputs[0]->Name()); // the variable name which contains .block means it was splited by // split_byref op // so that we can balance the variable blocks to all the pserver // instances. if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce && - node->inputs[0]->Var()->Name().find(".block") == std::string::npos) { + node->inputs[0]->Name().find(".block") == std::string::npos) { std::vector input_var_names; for (ir::Node *n : node->inputs) { - input_var_names.push_back(n->Var()->Name()); + input_var_names.push_back(n->Name()); } op_dev_id = GetAppropriateDeviceID(input_var_names); for (auto &varname : input_var_names) { @@ -625,7 +628,7 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(Graph *result, ir::Node *node) const { } else if (node->Op()->Type() == "recv") { std::vector output_var_names; for (ir::Node *n : node->outputs) { - output_var_names.push_back(n->Var()->Name()); + output_var_names.push_back(n->Name()); } op_dev_id = GetAppropriateDeviceID(output_var_names); for (auto &varname : output_var_names) { diff --git a/paddle/fluid/framework/details/reduce_op_handle_test.cc b/paddle/fluid/framework/details/reduce_op_handle_test.cc index e7c83ffd32..3a9a584123 100644 --- a/paddle/fluid/framework/details/reduce_op_handle_test.cc +++ b/paddle/fluid/framework/details/reduce_op_handle_test.cc @@ -97,7 +97,7 @@ struct TestReduceOpHandle { } param_scopes_[out_scope_idx]->Var("out"); - nodes.emplace_back(new ir::Node()); + nodes.emplace_back(new ir::Node("node")); if (use_gpu_) { #ifdef PADDLE_WITH_CUDA op_handle_.reset(new ReduceOpHandle(nodes.back().get(), local_scopes_, @@ -121,7 +121,7 @@ struct TestReduceOpHandle { if (!use_gpu_) { op_handle_->SetDeviceContext(gpu_list_[j], ctxs_[j].get()); } - nodes.emplace_back(new ir::Node()); + nodes.emplace_back(new ir::Node("node1")); auto *in_var_handle = new VarHandle(nodes.back().get(), 1, j, "input", gpu_list_[j]); in_var_handle->ClearGeneratedOp(); @@ -137,7 +137,7 @@ struct TestReduceOpHandle { op_handle_->AddInput(in_dummy_var_handle); // add output - nodes.emplace_back(new ir::Node()); + nodes.emplace_back(new ir::Node("node2")); auto *out_var_handle = new VarHandle(nodes.back().get(), 2, out_scope_idx, "out", gpu_list_[out_scope_idx]); vars_.emplace_back(out_var_handle); diff --git a/paddle/fluid/framework/details/ssa_graph_builder.cc b/paddle/fluid/framework/details/ssa_graph_builder.cc index 
6a8bd7875c..884fc64555 100644 --- a/paddle/fluid/framework/details/ssa_graph_builder.cc +++ b/paddle/fluid/framework/details/ssa_graph_builder.cc @@ -37,7 +37,7 @@ void SSAGraphBuilder::PolishGraphToSupportDataHazards(Graph *graph) { continue; } - auto *dep_var = new DummyVarHandle(graph->CreateVarNode("dummy")); + auto *dep_var = new DummyVarHandle(graph->CreateEmptyNode("dummy")); read_op->AddOutput(dep_var); write_op->AddInput(dep_var); graph->Get("dep_vars").emplace(dep_var); @@ -51,11 +51,16 @@ VarHandle *SSAGraphBuilder::CreateOrGetLatestVarHandle( Graph *graph, ir::Node *node, const platform::Place &place, size_t place_offset) { auto &var_holders = graph->Get("vars")[place_offset]; - auto &var_holder = var_holders[node->Var()->Name()]; + auto &var_holder = var_holders[node->Name()]; VarHandle *var = nullptr; if (var_holder.empty()) { - var = new VarHandle(graph->CreateVarNode(node->Var()), 0, place_offset, - node->Var()->Name(), place); + if (node->NodeType() == ir::Node::Type::kVariable) { + var = new VarHandle(graph->CreateVarNode(node->Var()), 0, place_offset, + node->Name(), place); + } else { + var = new VarHandle(graph->CreateEmptyNode(node->Name()), 0, place_offset, + node->Name(), place); + } var_holder.emplace_back(var); } else { var = var_holder.rbegin()->get(); @@ -67,10 +72,10 @@ void SSAGraphBuilder::CreateOpOutput(Graph *graph, OpHandleBase *op_handle, ir::Node *node, const platform::Place &place, size_t place_offset) { - auto &vars = graph->Get("vars")[place_offset][node->Var()->Name()]; + auto &vars = graph->Get("vars")[place_offset][node->Name()]; size_t version = vars.size(); auto var = new VarHandle(graph->CreateVarNode(node->Var()), version, - place_offset, node->Var()->Name(), place); + place_offset, node->Name(), place); vars.emplace_back(var); op_handle->AddOutput(var); } @@ -82,7 +87,7 @@ void SSAGraphBuilder::AddOutputToLeafOps(Graph *graph) { if (!op->Outputs().empty()) { continue; } - auto *dummy_leaf = new DummyVarHandle(graph->CreateVarNode("dummy")); + auto *dummy_leaf = new DummyVarHandle(graph->CreateEmptyNode("dummy")); graph->Get("dep_vars").emplace(dummy_leaf); op->AddOutput(dummy_leaf); } diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc index 9a2413118e..8c9cb7cabb 100644 --- a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc @@ -173,7 +173,7 @@ void ThreadedSSAGraphExecutor::InsertFetchOps( auto &var_name = fetch_tensors[i]; auto &vars = fetched_vars.at(var_name); - ir::Node *fetch_n = new ir::Node(ir::Node::Type::kOperation); + ir::Node *fetch_n = new ir::Node("fetch"); auto *op = new FetchOpHandle(fetch_n, fetch_data, i, &local_scopes_); temp_nodes->emplace_back(fetch_n); fetch_ops->emplace_back(op); @@ -186,7 +186,7 @@ void ThreadedSSAGraphExecutor::InsertFetchOps( op->AddInput(var); } - ir::Node *dummy_n = new ir::Node(ir::Node::Type::kVariable); + ir::Node *dummy_n = new ir::Node("fetch"); auto *fetch_dummy = new DummyVarHandle(dummy_n); op->AddOutput(fetch_dummy); fetch_dependencies->emplace(fetch_dummy); diff --git a/paddle/fluid/framework/ir/graph.cc b/paddle/fluid/framework/ir/graph.cc index c1f8f917c4..14d697c509 100644 --- a/paddle/fluid/framework/ir/graph.cc +++ b/paddle/fluid/framework/ir/graph.cc @@ -35,19 +35,15 @@ std::unique_ptr ProgramToGraph(const ProgramDesc &program) { if (all_vars.count(each_var_name) != 0) { var = 
graph->CreateVarNode(all_vars.at(each_var_name)); } else { - var = graph->CreateVarNode(each_var_name); + LOG(ERROR) << "input var not in all_var list: " << each_var_name; + var = graph->CreateEmptyNode(each_var_name); } node->inputs.push_back(var); var->outputs.push_back(node); } for (auto &each_var_name : op->OutputArgumentNames()) { - ir::Node *var = nullptr; - if (all_vars.count(each_var_name) != 0) { - var = graph->CreateVarNode(all_vars.at(each_var_name)); - } else { - var = graph->CreateVarNode(each_var_name); - } + ir::Node *var = graph->CreateVarNode(all_vars.at(each_var_name)); node->outputs.push_back(var); var->inputs.push_back(node); } diff --git a/paddle/fluid/framework/ir/graph.h b/paddle/fluid/framework/ir/graph.h index ff4f31fb7a..8b185f9625 100644 --- a/paddle/fluid/framework/ir/graph.h +++ b/paddle/fluid/framework/ir/graph.h @@ -72,16 +72,14 @@ class Graph { } // TODO(panyx0718): Need to handle CreateOpNode(nullptr). - ir::Node* CreateVarNode(const std::string& var_name) { - var_descs_.emplace_back(new VarDesc(var_name)); - nodes.emplace_back(new ir::Node(var_descs_.back().get())); + ir::Node* CreateEmptyNode(const std::string& name) { + nodes.emplace_back(new ir::Node(name)); return nodes.back().get(); } std::vector inputs; std::vector outputs; std::vector> nodes; - std::vector> var_descs_; private: // NOTE: program_ shouldn't be exposed to user. diff --git a/paddle/fluid/framework/ir/node.h b/paddle/fluid/framework/ir/node.h index 0e0b81a7b1..d2d08bc461 100644 --- a/paddle/fluid/framework/ir/node.h +++ b/paddle/fluid/framework/ir/node.h @@ -32,51 +32,43 @@ namespace ir { class Node { public: - enum class Type { kNone = -1, kOperation, kVariable }; + enum class Type { kNone, kOperation, kVariable }; + explicit Node(const std::string& name) + : name_(name), + var_desc_(nullptr), + op_desc_(nullptr), + type_(Type::kNone) {} - Node() : type_(Type::kNone) {} - - explicit Node(Type type) : type_(type) {} + explicit Node(VarDesc* var_desc) + : name_(var_desc->Name()), + var_desc_(var_desc), + op_desc_(nullptr), + type_(Type::kVariable) {} - virtual ~Node() { - for (auto& attr : attrs_) { - if (attr_dels_.find(attr.first) != attr_dels_.end()) { - attr_dels_[attr.first](); - } - } - attr_dels_.clear(); - attrs_.clear(); - } + explicit Node(OpDesc* op_desc) + : name_(op_desc->Type()), + var_desc_(nullptr), + op_desc_(op_desc), + type_(Type::kOperation) {} Type NodeType() const { return type_; } - template - void Set(const std::string& name, AttrType attr) { - attrs_[name] = attr; - } + std::string Name() const { return name_; } - template - void Set(const std::string& name, AttrType* attr, - std::function attr_del) { - attrs_[name] = attr; - attr_dels_[name] = attr_del; + VarDesc* Var() { + PADDLE_ENFORCE(type_ == Type::kVariable); + return var_desc_; + } + OpDesc* Op() { + PADDLE_ENFORCE(type_ == Type::kOperation); + return op_desc_; } - - VarDesc* Var() { return var_desc_; } - OpDesc* Op() { return op_desc_; } - - explicit Node(VarDesc* var_desc) - : var_desc_(var_desc), op_desc_(nullptr), type_(Type::kVariable) {} - - explicit Node(OpDesc* op_desc) - : var_desc_(nullptr), op_desc_(op_desc), type_(Type::kOperation) {} std::vector inputs; std::vector outputs; protected: - std::map attrs_; - std::map> attr_dels_; + const std::string name_; VarDesc* var_desc_; OpDesc* op_desc_; Type type_; diff --git a/python/paddle/fluid/parallel_executor.py b/python/paddle/fluid/parallel_executor.py index 59789c6def..10028a8c6e 100644 --- a/python/paddle/fluid/parallel_executor.py +++ 
b/python/paddle/fluid/parallel_executor.py @@ -148,7 +148,6 @@ class ParallelExecutor(object): lambda var: var.persistable and var.type != core.VarDesc.VarType.RAW, main.list_vars()) ] - sys.stderr.write('!!!!!!!!before\n') self.executor = core.ParallelExecutor( self._places, @@ -159,7 +158,6 @@ class ParallelExecutor(object): set(self.persistable_vars), main.desc, loss_name if loss_name else '', scope, local_scopes, exec_strategy, build_strategy, num_trainers, trainer_id) - sys.stderr.write('!!!!!!!!after\n') self.scope = scope def run(self, fetch_list, feed=None, feed_dict=None, return_numpy=True): From 64eaa4c82926e6864bdb3b0868308bcc8a9fded7 Mon Sep 17 00:00:00 2001 From: Xin Pan Date: Sun, 15 Jul 2018 17:08:56 +0800 Subject: [PATCH 13/22] clean --- paddle/fluid/framework/details/CMakeLists.txt | 5 +- .../fluid/framework/details/op_handle_base.h | 2 + .../scope_buffered_ssa_graph_executor.h | 3 ++ paddle/fluid/framework/details/ssa_graph.cc | 15 ------ paddle/fluid/framework/details/ssa_graph.h | 49 ------------------- .../framework/details/ssa_graph_builder.cc | 4 +- .../framework/details/ssa_graph_builder.h | 14 +++++- .../framework/details/ssa_graph_checker.cc | 4 +- .../framework/details/ssa_graph_executor.h | 2 +- .../framework/details/ssa_graph_printer.cc | 2 +- .../details/threaded_ssa_graph_executor.cc | 11 ++--- paddle/fluid/framework/details/var_handle.h | 3 ++ paddle/fluid/framework/ir/graph.cc | 18 +++---- paddle/fluid/framework/ir/graph.h | 19 ++----- paddle/fluid/framework/ir/node.h | 6 --- paddle/fluid/framework/parallel_executor.cc | 5 +- 16 files changed, 49 insertions(+), 113 deletions(-) delete mode 100644 paddle/fluid/framework/details/ssa_graph.cc delete mode 100644 paddle/fluid/framework/details/ssa_graph.h diff --git a/paddle/fluid/framework/details/CMakeLists.txt b/paddle/fluid/framework/details/CMakeLists.txt index 4fb4ec38ee..e8057c35e8 100644 --- a/paddle/fluid/framework/details/CMakeLists.txt +++ b/paddle/fluid/framework/details/CMakeLists.txt @@ -5,8 +5,7 @@ cc_library(fetch_op_handle SRCS fetch_op_handle.cc DEPS op_handle_base scope lod cc_library(computation_op_handle SRCS computation_op_handle.cc DEPS framework_proto scope place operator op_registry) cc_library(rpc_op_handle SRCS rpc_op_handle.cc DEPS framework_proto scope place operator op_registry) -cc_library(ssa_graph SRCS ssa_graph.cc DEPS var_handle op_handle_base) -cc_library(ssa_graph_builder SRCS ssa_graph_builder.cc DEPS ssa_graph) +cc_library(ssa_graph_builder SRCS ssa_graph_builder.cc DEPS graph) cc_library(ssa_graph_printer SRCS ssa_graph_printer.cc DEPS ssa_graph_builder) cc_library(ssa_graph_checker SRCS ssa_graph_checker.cc DEPS ssa_graph_builder) @@ -35,7 +34,7 @@ cc_library(multi_devices_graph_builder SRCS multi_devices_graph_builder.cc DEPS cc_library(ssa_graph_builder_factory SRCS ssa_graph_builder_factory.cc DEPS multi_devices_graph_builder ssa_graph_printer ssa_graph_checker) -cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ssa_graph framework_proto) +cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS graph framework_proto) cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope simple_threadpool device_context) diff --git a/paddle/fluid/framework/details/op_handle_base.h b/paddle/fluid/framework/details/op_handle_base.h index 368a153711..2d7f189428 100644 --- a/paddle/fluid/framework/details/op_handle_base.h +++ b/paddle/fluid/framework/details/op_handle_base.h @@ -27,6 +27,8 @@ namespace 
details { constexpr char kLocalExecScopeName[] = "@LCOAL_SCOPE@"; +// Wraps ir::Node and provide helper utilities. +// It's responsible for populating necessary fields of ir::Node. class OpHandleBase { public: explicit OpHandleBase(ir::Node *node) : node_(node) {} diff --git a/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h b/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h index 20df7a4722..cbfbcb1c0c 100644 --- a/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h +++ b/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h @@ -17,6 +17,9 @@ #include #include #include +#include "paddle/fluid/framework/details/op_handle_base.h" +#include "paddle/fluid/framework/details/var_handle.h" + #include "paddle/fluid/framework/details/execution_strategy.h" #include "paddle/fluid/framework/details/ssa_graph_executor.h" #include "paddle/fluid/framework/scope.h" diff --git a/paddle/fluid/framework/details/ssa_graph.cc b/paddle/fluid/framework/details/ssa_graph.cc deleted file mode 100644 index 1b8c889449..0000000000 --- a/paddle/fluid/framework/details/ssa_graph.cc +++ /dev/null @@ -1,15 +0,0 @@ -// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/fluid/framework/details/ssa_graph.h" diff --git a/paddle/fluid/framework/details/ssa_graph.h b/paddle/fluid/framework/details/ssa_graph.h deleted file mode 100644 index e996a00c16..0000000000 --- a/paddle/fluid/framework/details/ssa_graph.h +++ /dev/null @@ -1,49 +0,0 @@ -// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include -#include -#include - -#include "paddle/fluid/framework/details/op_handle_base.h" -#include "paddle/fluid/framework/details/var_handle.h" - -namespace paddle { -namespace framework { -namespace details { - -// A SSA graph used by parallel executor. -struct SSAGraph { - // all variable in each devices. - // The outside vector is the device vector. Each element of this vector is a - // map from variable name to variables. The variables, who have the same name, - // will have a different version. The offset in the - // `std::vector>` is the version of varaibles. - std::vector< - std::unordered_map>>> - vars_; - - // aux variables to represent dependency. Useful to resolve data hazard. - std::unordered_set> dep_vars_; - - // all operators. 
NOTE that even we use a vector here, the operators is - // unordered. - std::vector> ops_; -}; - -} // namespace details -} // namespace framework -} // namespace paddle diff --git a/paddle/fluid/framework/details/ssa_graph_builder.cc b/paddle/fluid/framework/details/ssa_graph_builder.cc index 884fc64555..7de4426de8 100644 --- a/paddle/fluid/framework/details/ssa_graph_builder.cc +++ b/paddle/fluid/framework/details/ssa_graph_builder.cc @@ -81,9 +81,7 @@ void SSAGraphBuilder::CreateOpOutput(Graph *graph, OpHandleBase *op_handle, } void SSAGraphBuilder::AddOutputToLeafOps(Graph *graph) { - GraphOps &all_ops = graph->Get("ops"); - - for (auto &op : all_ops) { + for (auto &op : graph->Get("ops")) { if (!op->Outputs().empty()) { continue; } diff --git a/paddle/fluid/framework/details/ssa_graph_builder.h b/paddle/fluid/framework/details/ssa_graph_builder.h index 9933bf32b7..87749009ef 100644 --- a/paddle/fluid/framework/details/ssa_graph_builder.h +++ b/paddle/fluid/framework/details/ssa_graph_builder.h @@ -18,7 +18,9 @@ #include #include -#include "paddle/fluid/framework/details/ssa_graph.h" +#include "paddle/fluid/framework/details/op_handle_base.h" +#include "paddle/fluid/framework/details/var_handle.h" + #include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/platform/place.h" @@ -29,10 +31,20 @@ namespace paddle { namespace framework { namespace details { +// all variable in each devices. +// The outside vector is the device vector. Each element of this vector is a +// map from variable name to variables. The variables, who have the same name, +// will have a differsent version. The offset in the +// `std::vector>` is the version of varaibles. typedef std::vector< std::unordered_map>>> GraphVars; + +// aux variables to represent dependency. Useful to resolve data hazard. typedef std::unordered_set> GraphDepVars; + +// all operators. NOTE that even we use a vector here, the operators is +// unordered. typedef std::vector> GraphOps; class SSAGraphBuilder : public ir::Pass { diff --git a/paddle/fluid/framework/details/ssa_graph_checker.cc b/paddle/fluid/framework/details/ssa_graph_checker.cc index 6a211f52bb..7c79d7f1e8 100644 --- a/paddle/fluid/framework/details/ssa_graph_checker.cc +++ b/paddle/fluid/framework/details/ssa_graph_checker.cc @@ -12,9 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. 
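Editorial note: the GraphVars typedef moved into ssa_graph_builder.h above packs three levels of lookup into one declaration, which is easier to read with a short sketch. The code below is illustrative only and is not part of the patch: the fully expanded typedef is an assumed reading of the collapsed declaration, and LatestVersionOrNull is a hypothetical helper that mirrors the lookup done by CreateOrGetLatestVarHandle.

```c++
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "paddle/fluid/framework/details/var_handle.h"

namespace paddle {
namespace framework {
namespace details {

// Assumed expansion of the GraphVars typedef in ssa_graph_builder.h.
typedef std::vector<
    std::unordered_map<std::string, std::vector<std::unique_ptr<VarHandle>>>>
    GraphVars;

// Hypothetical helper showing how the three indices are used.
VarHandle *LatestVersionOrNull(GraphVars *all_vars, size_t place_offset,
                               const std::string &name) {
  // 1) Outer vector: one entry per device (place).
  auto &per_device = (*all_vars)[place_offset];
  // 2) Map: variable name -> all SSA versions of that variable.
  auto &versions = per_device[name];
  // 3) Inner vector: the index is the version; the newest write sits at the
  //    back, which is what CreateOrGetLatestVarHandle hands out.
  return versions.empty() ? nullptr : versions.back().get();
}

}  // namespace details
}  // namespace framework
}  // namespace paddle
```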
-#include "paddle/fluid/framework/details/ssa_graph.h" -#include #include "paddle/fluid/framework/details/ssa_graph_checker.h" +#include +#include "paddle/fluid/framework/ir/graph.h" namespace paddle { namespace framework { diff --git a/paddle/fluid/framework/details/ssa_graph_executor.h b/paddle/fluid/framework/details/ssa_graph_executor.h index 9580860336..8815ec89b2 100644 --- a/paddle/fluid/framework/details/ssa_graph_executor.h +++ b/paddle/fluid/framework/details/ssa_graph_executor.h @@ -18,8 +18,8 @@ #include #include -#include "paddle/fluid/framework/details/ssa_graph.h" #include "paddle/fluid/framework/feed_fetch_type.h" +#include "paddle/fluid/framework/ir/graph.h" namespace paddle { namespace framework { diff --git a/paddle/fluid/framework/details/ssa_graph_printer.cc b/paddle/fluid/framework/details/ssa_graph_printer.cc index 412b0a6ff2..6dd6fd262e 100644 --- a/paddle/fluid/framework/details/ssa_graph_printer.cc +++ b/paddle/fluid/framework/details/ssa_graph_printer.cc @@ -14,7 +14,7 @@ #include "paddle/fluid/framework/details/ssa_graph_printer.h" #include -#include "paddle/fluid/framework/details/ssa_graph.h" +#include "paddle/fluid/framework/ir/graph.h" namespace paddle { namespace framework { diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc index 8c9cb7cabb..ac77788365 100644 --- a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc @@ -173,9 +173,9 @@ void ThreadedSSAGraphExecutor::InsertFetchOps( auto &var_name = fetch_tensors[i]; auto &vars = fetched_vars.at(var_name); - ir::Node *fetch_n = new ir::Node("fetch"); - auto *op = new FetchOpHandle(fetch_n, fetch_data, i, &local_scopes_); - temp_nodes->emplace_back(fetch_n); + temp_nodes->emplace_back(new ir::Node("fetch")); + auto *op = new FetchOpHandle(temp_nodes->back().get(), fetch_data, i, + &local_scopes_); fetch_ops->emplace_back(op); for (auto &p : places_) { @@ -186,11 +186,10 @@ void ThreadedSSAGraphExecutor::InsertFetchOps( op->AddInput(var); } - ir::Node *dummy_n = new ir::Node("fetch"); - auto *fetch_dummy = new DummyVarHandle(dummy_n); + temp_nodes->emplace_back(new ir::Node("fetch")); + auto *fetch_dummy = new DummyVarHandle(temp_nodes->back().get()); op->AddOutput(fetch_dummy); fetch_dependencies->emplace(fetch_dummy); - temp_nodes->emplace_back(dummy_n); this->InsertPendingVar(pending_vars, ready_vars, fetch_dummy); this->InsertPendingOp(pending_ops, op); } diff --git a/paddle/fluid/framework/details/var_handle.h b/paddle/fluid/framework/details/var_handle.h index 8bd3db9203..ae23e3b1f8 100644 --- a/paddle/fluid/framework/details/var_handle.h +++ b/paddle/fluid/framework/details/var_handle.h @@ -28,6 +28,9 @@ namespace framework { namespace details { class OpHandleBase; +// Wraps ir::Node and provide helper utilities. +// It's responsible for populating necessary fields of ir::Node. +// // VarHandleBase is the var node in the dependency graph. // A variable can only be generated by a single operator. i.e. // This is a single assignment graph. diff --git a/paddle/fluid/framework/ir/graph.cc b/paddle/fluid/framework/ir/graph.cc index 14d697c509..1f6937658f 100644 --- a/paddle/fluid/framework/ir/graph.cc +++ b/paddle/fluid/framework/ir/graph.cc @@ -19,37 +19,35 @@ limitations under the License. 
*/ namespace paddle { namespace framework { -std::unique_ptr ProgramToGraph(const ProgramDesc &program) { - std::unique_ptr graph(new Graph(program)); - +Graph::Graph(const ProgramDesc &program) : program_(program) { std::unordered_map all_vars; for (auto *var : program.Block(0).AllVars()) { all_vars.emplace(var->Name(), var); } for (auto *op : program.Block(0).AllOps()) { - ir::Node *node = graph->CreateOpNode(op); + ir::Node *node = CreateOpNode(op); for (auto &each_var_name : op->InputArgumentNames()) { ir::Node *var = nullptr; if (all_vars.count(each_var_name) != 0) { - var = graph->CreateVarNode(all_vars.at(each_var_name)); + var = CreateVarNode(all_vars.at(each_var_name)); } else { - LOG(ERROR) << "input var not in all_var list: " << each_var_name; - var = graph->CreateEmptyNode(each_var_name); + // TODO(paddle-dev): Seems some assumption doesn't hold? + LOG(ERROR) << op->Type() + << " input var not in all_var list: " << each_var_name; + var = CreateEmptyNode(each_var_name); } node->inputs.push_back(var); var->outputs.push_back(node); } for (auto &each_var_name : op->OutputArgumentNames()) { - ir::Node *var = graph->CreateVarNode(all_vars.at(each_var_name)); + ir::Node *var = CreateVarNode(all_vars.at(each_var_name)); node->outputs.push_back(var); var->inputs.push_back(node); } } - return std::move(graph); } - } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/ir/graph.h b/paddle/fluid/framework/ir/graph.h index 8b185f9625..93db573559 100644 --- a/paddle/fluid/framework/ir/graph.h +++ b/paddle/fluid/framework/ir/graph.h @@ -29,7 +29,7 @@ namespace framework { class Graph { public: - explicit Graph(const ProgramDesc& program) : program_(program) {} + explicit Graph(const ProgramDesc& program); virtual ~Graph() { for (auto& attr : attrs_) { @@ -46,6 +46,7 @@ class Graph { template void Set(const std::string& attr_name, AttrType* attr) { + PADDLE_ENFORCE(attrs_.count(attr_name) == 0); attrs_[attr_name] = attr; attr_dels_[attr_name] = [attr, attr_name]() { VLOG(3) << "deleting " << attr_name; @@ -53,14 +54,6 @@ class Graph { }; } - template - AttrType* Erase(const std::string& attr_name) { - AttrType* attr = boost::any_cast(attrs_[attr_name]); - attrs_.erase(attr_name); - attr_dels_.erase(attr_name); - return attr; - } - ir::Node* CreateVarNode(VarDesc* var_desc) { nodes.emplace_back(new ir::Node(var_desc)); return nodes.back().get(); @@ -71,14 +64,14 @@ class Graph { return nodes.back().get(); } - // TODO(panyx0718): Need to handle CreateOpNode(nullptr). + // TODO(paddle-dev): There shouldn't be kNone nodes in the ir::Graph. + // node should either be a executable kOperation or a kVariable. kNone + // node is a temporary solution. ir::Node* CreateEmptyNode(const std::string& name) { nodes.emplace_back(new ir::Node(name)); return nodes.back().get(); } - std::vector inputs; - std::vector outputs; std::vector> nodes; private: @@ -88,7 +81,5 @@ class Graph { std::map> attr_dels_; }; -std::unique_ptr ProgramToGraph(const ProgramDesc& program); - } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/ir/node.h b/paddle/fluid/framework/ir/node.h index d2d08bc461..cb1d524c34 100644 --- a/paddle/fluid/framework/ir/node.h +++ b/paddle/fluid/framework/ir/node.h @@ -14,17 +14,11 @@ limitations under the License. 
*/ #pragma once -#include -#include -#include -#include #include -#include #include #include "paddle/fluid/framework/op_desc.h" #include "paddle/fluid/framework/var_desc.h" #include "paddle/fluid/platform/macros.h" -#include "paddle/fluid/platform/variant.h" namespace paddle { namespace framework { diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc index c9014ffdf5..1e5bba62b5 100644 --- a/paddle/fluid/framework/parallel_executor.cc +++ b/paddle/fluid/framework/parallel_executor.cc @@ -18,7 +18,7 @@ limitations under the License. */ #include #include -#include "paddle/fluid/framework/details/ssa_graph.h" +#include "paddle/fluid/framework/ir/graph.h" #ifdef PADDLE_WITH_CUDA #include "paddle/fluid/platform/nccl_helper.h" @@ -132,7 +132,8 @@ ParallelExecutor::ParallelExecutor( #endif } builder_ = builder_factory.Create(); - std::unique_ptr graph = builder_->Apply(ProgramToGraph(main_program)); + std::unique_ptr graph(new Graph(main_program)); + graph = builder_->Apply(std::move(graph)); member_->executor_.reset(new details::ThreadedSSAGraphExecutor( exec_strategy, member_->local_scopes_, places, std::move(graph))); member_->executor_.reset(new details::ScopeBufferedSSAGraphExecutor( From 9c9e28b57ba96b60fe6289678710e36ff87cece4 Mon Sep 17 00:00:00 2001 From: Xin Pan Date: Sun, 15 Jul 2018 22:07:57 +0800 Subject: [PATCH 14/22] fix program to graph --- .../details/multi_devices_graph_builder.cc | 5 ++++- paddle/fluid/framework/ir/graph.cc | 16 ++++++++++++++-- 2 files changed, 18 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.cc b/paddle/fluid/framework/details/multi_devices_graph_builder.cc index 035fb629a8..1e7ec95342 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.cc @@ -210,7 +210,10 @@ std::unique_ptr MultiDevSSAGraphBuilder::Apply( size_t cur_device_id = 0; bool is_forwarding = true; - // TODO(panyx0718): FIXME: nodes should be sorted by "program" order. + // NOTE: Currently, passes before SSAGraphBuilder cannot reorder + // forward, backward nodes. E.g. you can't append an forward node + // at the end of the node list. + // TODO(panyx0718): FIXME: Needs to sort by forward->backward order. for (auto &node : nodes) { if (node->NodeType() != ir::Node::Type::kOperation) continue; if (boost::get( diff --git a/paddle/fluid/framework/ir/graph.cc b/paddle/fluid/framework/ir/graph.cc index 1f6937658f..f8381af985 100644 --- a/paddle/fluid/framework/ir/graph.cc +++ b/paddle/fluid/framework/ir/graph.cc @@ -19,31 +19,43 @@ limitations under the License. */ namespace paddle { namespace framework { +// NOTE(paddle-dev): This graph contains circle. Graph::Graph(const ProgramDesc &program) : program_(program) { std::unordered_map all_vars; for (auto *var : program.Block(0).AllVars()) { all_vars.emplace(var->Name(), var); } + std::map var_nodes; for (auto *op : program.Block(0).AllOps()) { ir::Node *node = CreateOpNode(op); for (auto &each_var_name : op->InputArgumentNames()) { ir::Node *var = nullptr; - if (all_vars.count(each_var_name) != 0) { + if (var_nodes.find(each_var_name) != var_nodes.end()) { + var = var_nodes.at(each_var_name); + } else if (all_vars.count(each_var_name) != 0) { var = CreateVarNode(all_vars.at(each_var_name)); + var_nodes[each_var_name] = var; } else { // TODO(paddle-dev): Seems some assumption doesn't hold? 
LOG(ERROR) << op->Type() << " input var not in all_var list: " << each_var_name; var = CreateEmptyNode(each_var_name); + var_nodes[each_var_name] = var; } node->inputs.push_back(var); var->outputs.push_back(node); } for (auto &each_var_name : op->OutputArgumentNames()) { - ir::Node *var = CreateVarNode(all_vars.at(each_var_name)); + ir::Node *var = nullptr; + if (var_nodes.find(each_var_name) != var_nodes.end()) { + var = var_nodes.at(each_var_name); + } else { + var = CreateVarNode(all_vars.at(each_var_name)); + var_nodes[each_var_name] = var; + } node->outputs.push_back(var); var->inputs.push_back(node); } From 2487951ba358823d06f1b9da3ca8ee7899dd048c Mon Sep 17 00:00:00 2001 From: Xin Pan Date: Mon, 16 Jul 2018 09:32:46 +0800 Subject: [PATCH 15/22] add draft design doc --- doc/fluid/design/ir/draft.md | 81 ++++++++++++++++++++++++++++++++++++ 1 file changed, 81 insertions(+) create mode 100644 doc/fluid/design/ir/draft.md diff --git a/doc/fluid/design/ir/draft.md b/doc/fluid/design/ir/draft.md new file mode 100644 index 0000000000..ac788e06eb --- /dev/null +++ b/doc/fluid/design/ir/draft.md @@ -0,0 +1,81 @@ +## Motivation + +There is a ```gap``` between the ```Program``` defined by +user and the ```Executable``` that can be scheduled +efficiently on heterogeneous hardware, either locally +or distributedly. + +Usually, the ```gap``` is bridged by + +* A serious transformations with defined order. + +* The transformations usually invovle +```insert, delete, clustering, split, dependency analysis```. + +* Has a simple way to verify and debug each transformation. + +* Flexible to add, remove or customize transformations to fit +the requirements of various algorithms (models) and hardware secenarios. + +Some other events also push us to a better unified pattern. + +* The deep learning framework is built around the concepts of graphs. +To leverage tools such as compilation (e.g. TVM and nGraph) or +cross-framework conversion (e.g. ONNX), we also need a intermediate +representation that can be connected to the rest of the ecosystem. + + +We need a unified pattern to naturally support the requirements +described above. The pattern should fit both training, inference +and other offline serielized model transformations. +Learned from LLVM and other deep learning framework, we draft the +design below. + + +## Design + +### Major Concepts + +#### Node + +```Node``` represents an operation that performs some computation or +a variable that is input or output of operation. + +```Node```s are connected to other ```Node```s via inputs and outputs. + +#### Graph + +```Graph``` contains a list of ```Node```s. + +TODO: Better definitions for the graph. + +```Graph``` can also contain ```Attribute```s. ```Attribute```s +can be ``any`` thing. For example, it can be a list of "wraper" +nodes. The ```wrapper``` nodes compose ```Node```s and provide +helper method for execution. ```Attribute``` can also contain +other things that describe some properties of the ```Graph```. + +#### Pass + +```Pass``` represents a transformation of ```Graph```. Its input +is a ```Graph``` and its output is also a ```Graph```. For example, +a ```Pass``` can simply print out the ```Graph```. A ```Pass``` +can also fuse some ```Graph```'s ```Node```s. + +#### Optimize + +```Optimize``` contains a series of ```Pass``` with defined order. +```Optimize``` transforms a ```Graph``` that only contains raw +modeling logic to a ```Graph``` that can be run efficiently while +maintaining the original modeling logic. 
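Editorial note: to make the concepts above concrete, here is a minimal sketch of a user-defined Pass. It is not part of this patch; the class name CountOpPass is invented, and the Apply(std::unique_ptr<Graph>) signature is assumed from the way SSAGraphBuilder and ParallelExecutor use Pass elsewhere in this series.

```c++
#include <memory>
#include <glog/logging.h>

#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/framework/ir/pass.h"

namespace paddle {
namespace framework {

// A do-nothing transformation: it only inspects the Graph it is given and
// hands the same Graph to the next Pass in the pipeline.
class CountOpPass : public ir::Pass {
 public:
  // Signature assumed; the Pass base class is not shown in full here.
  std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> graph) const {
    int op_count = 0;
    for (auto &node : graph->nodes) {
      if (node->NodeType() == ir::Node::Type::kOperation) ++op_count;
    }
    VLOG(3) << "graph has " << op_count << " operation nodes";
    return graph;
  }
};

}  // namespace framework
}  // namespace paddle
```

A pipeline is then just repeated application: build a Graph from the ProgramDesc and feed the result of one Pass into the next, which is what the Workflow section below describes.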
+ + +### Workflow + + +* Program is first converted to Graph. +* Graph goes through a series of Pass +* Graph is transformed from raw model logic to a +form that is efficient to execute. + +Graph->Pass1->Graph->Pass2->Graph->Pass3->Executor From a323b26372c17221a19e4bc5483471c6e5e37df6 Mon Sep 17 00:00:00 2001 From: Xin Pan Date: Mon, 16 Jul 2018 15:10:36 +0800 Subject: [PATCH 16/22] clarify design --- doc/fluid/design/ir/draft.md | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/doc/fluid/design/ir/draft.md b/doc/fluid/design/ir/draft.md index ac788e06eb..f5148c5cbf 100644 --- a/doc/fluid/design/ir/draft.md +++ b/doc/fluid/design/ir/draft.md @@ -43,17 +43,26 @@ a variable that is input or output of operation. ```Node```s are connected to other ```Node```s via inputs and outputs. +Other properties (maybe device placement information) can be added +to ```Node``` in the future if it's a +common requirement of many other ```Pass```es. Otherwise, it should live +in a ```Node``` wrapper class that is private to some ```Pass``` or be +a local member of a ```Pass```. + #### Graph -```Graph``` contains a list of ```Node```s. +```Graph``` contains a list of ```Node```s, which are connected to +each other via inputs and outputs. TODO: Better definitions for the graph. ```Graph``` can also contain ```Attribute```s. ```Attribute```s can be ``any`` thing. For example, it can be a list of "wraper" nodes. The ```wrapper``` nodes compose ```Node```s and provide -helper method for execution. ```Attribute``` can also contain -other things that describe some properties of the ```Graph```. +helper method for execution or transformation. ```Attribute``` +can also contain other things that describe some properties of +the ```Graph``` or ```Graph``` nodes. ```Attribute``` can be passed +across ```Pass```. However, it should be used with care. #### Pass @@ -70,12 +79,11 @@ modeling logic to a ```Graph``` that can be run efficiently while maintaining the original modeling logic. -### Workflow - +### Optimize Process * Program is first converted to Graph. * Graph goes through a series of Pass * Graph is transformed from raw model logic to a form that is efficient to execute. -Graph->Pass1->Graph->Pass2->Graph->Pass3->Executor +Program->ProgramToGraph->Graph->Pass1->Graph->Pass2->Graph->Pass3->Executor From 62e2aa115d4fdab425ad4e6c0e406fe44ebedc85 Mon Sep 17 00:00:00 2001 From: Xin Pan Date: Tue, 17 Jul 2018 09:51:34 +0800 Subject: [PATCH 17/22] add a graph_test --- paddle/fluid/framework/ir/CMakeLists.txt | 6 +- paddle/fluid/framework/ir/graph_test.cc | 112 +++++++++++++++++++++++ 2 files changed, 116 insertions(+), 2 deletions(-) create mode 100644 paddle/fluid/framework/ir/graph_test.cc diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index 4cd373e8ea..e8ed06aa69 100644 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -1,3 +1,5 @@ -cc_library(graph SRCS graph.cc) +cc_library(graph SRCS graph.cc node) cc_library(node SRCS node.cc) -cc_library(pass SRCS pass.cc) +cc_library(pass SRCS pass.cc graph node) + +cc_test(graph_test SRCS graph_test.cc DEPS graph proto_desc op_registry) diff --git a/paddle/fluid/framework/ir/graph_test.cc b/paddle/fluid/framework/ir/graph_test.cc new file mode 100644 index 0000000000..857188ef0a --- /dev/null +++ b/paddle/fluid/framework/ir/graph_test.cc @@ -0,0 +1,112 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/framework/ir/graph.h" +#include "gtest/gtest.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/framework/program_desc.h" + +namespace paddle { +namespace framework { + +class NOP : public OperatorBase { + public: + NOP(const std::string &type, const VariableNameMap &inputs, + const VariableNameMap &outputs, const AttributeMap &attrs) + : OperatorBase(type, inputs, outputs, attrs) {} + + private: + void RunImpl(const Scope &scope, + const platform::Place &place) const override {} +}; + +class SumOpMaker : public OpProtoAndCheckerMaker { + public: + void Make() { + AddInput("X", "").AsDuplicable(); + AddOutput("Out", ""); + AddComment(""); + } +}; + +class SumOpVarTypeInference : public VarTypeInference { + public: + void operator()(const OpDesc &op_desc, BlockDesc *block) const override { + auto &inputs = op_desc.Input("X"); + auto default_var_type = proto::VarType::SELECTED_ROWS; + + bool any_input_is_lod_tensor = std::any_of( + inputs.begin(), inputs.end(), [block](const std::string &name) { + return block->Var(name)->GetType() == proto::VarType::LOD_TENSOR; + }); + if (any_input_is_lod_tensor) { + default_var_type = proto::VarType::LOD_TENSOR; + } + + auto out_var_name = op_desc.Output("Out").front(); + block->Var(out_var_name)->SetType(default_var_type); + } +}; +} // namespace framework +} // namespace paddle + +REGISTER_OPERATOR(sum, paddle::framework::NOP, paddle::framework::SumOpMaker, + paddle::framework::SumOpVarTypeInference); +REGISTER_OPERATOR(sum_without_infer_var_type, paddle::framework::NOP, + paddle::framework::SumOpMaker); + +namespace paddle { +namespace framework { + +TEST(GraphTest, Basic) { + ProgramDesc prog; + auto *op = prog.MutableBlock(0)->AppendOp(); + op->SetType("sum"); + op->SetInput("X", {"test_a", "test_b", "test_c"}); + op->SetOutput("Out", {"test_out"}); + + prog.MutableBlock(0)->Var("test_a")->SetType(proto::VarType::SELECTED_ROWS); + prog.MutableBlock(0)->Var("test_b")->SetType(proto::VarType::SELECTED_ROWS); + prog.MutableBlock(0)->Var("test_c")->SetType(proto::VarType::SELECTED_ROWS); + prog.MutableBlock(0)->Var("test_out"); + + op->InferVarType(prog.MutableBlock(0)); + + ASSERT_EQ(proto::VarType::SELECTED_ROWS, + prog.MutableBlock(0)->Var("test_out")->GetType()); + + prog.MutableBlock(0)->Var("test_b")->SetType(proto::VarType::LOD_TENSOR); + op->InferVarType(prog.MutableBlock(0)); + ASSERT_EQ(proto::VarType::LOD_TENSOR, + prog.MutableBlock(0)->Var("test_out")->GetType()); + + std::unique_ptr g(new Graph(prog)); + ASSERT_EQ(g->nodes[0]->Name(), "sum"); + ASSERT_EQ(g->nodes[0]->inputs[0]->Name(), "test_a"); + ASSERT_EQ(g->nodes[0]->inputs[1]->Name(), "test_b"); + ASSERT_EQ(g->nodes[0]->inputs[2]->Name(), "test_c"); + ASSERT_EQ(g->nodes[0]->outputs[0]->Name(), "test_out"); + ASSERT_EQ(g->nodes[1]->Name(), "test_a"); + ASSERT_EQ(g->nodes[1]->outputs[0]->Name(), "sum"); + ASSERT_EQ(g->nodes[2]->Name(), "test_b"); + 
ASSERT_EQ(g->nodes[2]->outputs[0]->Name(), "sum"); + ASSERT_EQ(g->nodes[3]->Name(), "test_c"); + ASSERT_EQ(g->nodes[3]->outputs[0]->Name(), "sum"); + ASSERT_EQ(g->nodes[4]->Name(), "test_out"); + ASSERT_EQ(g->nodes[4]->inputs[0]->Name(), "sum"); + ASSERT_EQ(g->nodes.size(), 5); +} +} // namespace framework +} // namespace paddle From a891708d4bdae7632b4926ecd26f1cba25126488 Mon Sep 17 00:00:00 2001 From: Xin Pan Date: Tue, 17 Jul 2018 10:04:40 +0800 Subject: [PATCH 18/22] polish design --- doc/fluid/design/ir/draft.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/fluid/design/ir/draft.md b/doc/fluid/design/ir/draft.md index f5148c5cbf..ac40c2efde 100644 --- a/doc/fluid/design/ir/draft.md +++ b/doc/fluid/design/ir/draft.md @@ -86,4 +86,4 @@ maintaining the original modeling logic. * Graph is transformed from raw model logic to a form that is efficient to execute. -Program->ProgramToGraph->Graph->Pass1->Graph->Pass2->Graph->Pass3->Executor +Program->ProgramToGraph->Graph->Pass1->Graph->Pass2->Graph->Pass3->Graph->Executor From ff5a7b67ed4d6f7e6e44048f51d52ad7f83bb481 Mon Sep 17 00:00:00 2001 From: Xin Pan Date: Tue, 17 Jul 2018 19:58:39 +0800 Subject: [PATCH 19/22] polish --- .../details/broadcast_op_handle_test.cc | 15 ++-- .../details/gather_op_handle_test.cc | 10 +-- .../details/multi_devices_graph_builder.cc | 83 +++++++++++-------- .../framework/details/ssa_graph_builder.cc | 21 +++-- .../framework/details/ssa_graph_builder.h | 2 +- .../details/threaded_ssa_graph_executor.cc | 4 +- paddle/fluid/framework/ir/graph.cc | 2 +- paddle/fluid/framework/ir/graph.h | 4 +- paddle/fluid/framework/ir/node.h | 9 +- 9 files changed, 85 insertions(+), 65 deletions(-) diff --git a/paddle/fluid/framework/details/broadcast_op_handle_test.cc b/paddle/fluid/framework/details/broadcast_op_handle_test.cc index 63a6ed9082..1413f7bd9a 100644 --- a/paddle/fluid/framework/details/broadcast_op_handle_test.cc +++ b/paddle/fluid/framework/details/broadcast_op_handle_test.cc @@ -96,7 +96,8 @@ struct TestBroadcastOpHandle { } param_scopes_[input_scope_idx]->Var("input"); - std::unique_ptr n(new ir::Node("node0")); + std::unique_ptr n( + new ir::Node("node0", ir::Node::Type::kOperation)); if (use_gpu_) { #ifdef PADDLE_WITH_CUDA op_handle_.reset(new BroadcastOpHandle(n.get(), local_scopes_, gpu_list_, @@ -114,7 +115,8 @@ struct TestBroadcastOpHandle { #endif } - std::unique_ptr v(new ir::Node("node1")); + std::unique_ptr v( + new ir::Node("node1", ir::Node::Type::kVariable)); auto* in_var_handle = new VarHandle(v.get(), 1, input_scope_idx, "input", gpu_list_[input_scope_idx]); vars_.emplace_back(in_var_handle); @@ -122,7 +124,8 @@ struct TestBroadcastOpHandle { // add dummy var - std::unique_ptr v2(new ir::Node("node2")); + std::unique_ptr v2( + new ir::Node("node2", ir::Node::Type::kVariable)); vars_.emplace_back(new DummyVarHandle(v2.get())); DummyVarHandle* dummy_var_handle = static_cast(vars_.back().get()); @@ -133,7 +136,8 @@ struct TestBroadcastOpHandle { if (!use_gpu_) { op_handle_->SetDeviceContext(gpu_list_[j], ctxs_[j].get()); } - std::unique_ptr v3(new ir::Node("node3")); + std::unique_ptr v3( + new ir::Node("node3", ir::Node::Type::kVariable)); VarHandle* out_var_handle = new VarHandle(v3.get(), 2, j, "out", gpu_list_[j]); vars_.emplace_back(out_var_handle); @@ -141,7 +145,8 @@ struct TestBroadcastOpHandle { } // add dummy var - std::unique_ptr v4(new ir::Node("node4")); + std::unique_ptr v4( + new ir::Node("node4", ir::Node::Type::kVariable)); vars_.emplace_back(new 
DummyVarHandle(v4.get())); DummyVarHandle* out_dummy_var_handle = static_cast(vars_.back().get()); diff --git a/paddle/fluid/framework/details/gather_op_handle_test.cc b/paddle/fluid/framework/details/gather_op_handle_test.cc index e3806ac5e1..c9b94d1e10 100644 --- a/paddle/fluid/framework/details/gather_op_handle_test.cc +++ b/paddle/fluid/framework/details/gather_op_handle_test.cc @@ -82,13 +82,13 @@ struct TestGatherOpHandle { } param_scopes_[input_scope_idx]->Var("out"); - nodes.emplace_back(new ir::Node("node")); + nodes.emplace_back(new ir::Node("node", ir::Node::Type::kOperation)); op_handle_.reset( new GatherOpHandle(nodes.back().get(), local_scopes_, gpu_list_)); // add input for (size_t j = 0; j < gpu_list_.size(); ++j) { op_handle_->SetDeviceContext(gpu_list_[j], ctxs_[j].get()); - nodes.emplace_back(new ir::Node("node1")); + nodes.emplace_back(new ir::Node("node1", ir::Node::Type::kVariable)); auto* in_var_handle = new VarHandle(nodes.back().get(), 1, j, "input", gpu_list_[j]); vars_.emplace_back(in_var_handle); @@ -96,7 +96,7 @@ struct TestGatherOpHandle { } // add dummy var - nodes.emplace_back(new ir::Node("node2")); + nodes.emplace_back(new ir::Node("node2", ir::Node::Type::kVariable)); vars_.emplace_back(new DummyVarHandle(nodes.back().get())); DummyVarHandle* in_dummy_var_handle = static_cast(vars_.back().get()); @@ -104,14 +104,14 @@ struct TestGatherOpHandle { op_handle_->AddInput(in_dummy_var_handle); // add output - nodes.emplace_back(new ir::Node("node3")); + nodes.emplace_back(new ir::Node("node3", ir::Node::Type::kVariable)); auto* out_var_handle = new VarHandle(nodes.back().get(), 2, input_scope_idx, "out", gpu_list_[input_scope_idx]); vars_.emplace_back(out_var_handle); op_handle_->AddOutput(out_var_handle); // add dummy var - nodes.emplace_back(new ir::Node("node4")); + nodes.emplace_back(new ir::Node("node4", ir::Node::Type::kVariable)); vars_.emplace_back(new DummyVarHandle(nodes.back().get())); DummyVarHandle* dummy_var_handle = static_cast(vars_.back().get()); diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.cc b/paddle/fluid/framework/details/multi_devices_graph_builder.cc index 1e7ec95342..c52980472d 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.cc @@ -80,7 +80,14 @@ void MultiDevSSAGraphBuilder::CreateOpHandleIOs(Graph *result, ir::Node *node, } for (ir::Node *output : node->outputs) { - CreateOpOutput(result, op_handle, output, p, place_id); + ir::Node *new_node = nullptr; + if (output->Var()) { + new_node = result->CreateVarNode(output->Var()); + } else { + new_node = + result->CreateEmptyNode(output->Name(), ir::Node::Type::kVariable); + } + CreateOpOutput(result, op_handle, new_node, p, place_id); } } @@ -246,7 +253,7 @@ std::unique_ptr MultiDevSSAGraphBuilder::Apply( if (node->Op()->Type() == "read" && strategy_.enable_data_balance_) { node->Op()->SetAttr("throw_eof_exp", false); CreateComputationalOps(&result, node.get(), places_.size()); - // TODO(panyx0718): builder shouldn't depend on the out logic of + // TODO(paddle-dev): builder shouldn't depend on the out logic of // a specific op. 
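Editorial note: a pattern repeats throughout the builder changes that follow: operations such as broadcast, allreduce, data_balance and scale_loss_grad have no OpDesc behind them, and the variables they produce often have no VarDesc, so both are represented with CreateEmptyNode plus an explicit ir::Node::Type. The sketch below condenses that pattern; it is not part of the patch, AddSyntheticOutput is a hypothetical helper, and the Get<GraphVars> template argument is assumed from the surrounding code.

```c++
// Hypothetical helper condensing the synthetic-node pattern used by the
// multi-device graph builder.
void AddSyntheticOutput(Graph *result, OpHandleBase *op_handle,
                        const std::string &name, size_t dev_id,
                        const platform::Place &place) {
  // No VarDesc exists for this variable, so an empty kVariable node carries
  // only its name.
  ir::Node *node = result->CreateEmptyNode(name, ir::Node::Type::kVariable);

  // The version of the new VarHandle is the current number of versions
  // recorded for this name on this device.
  auto &vars = result->Get<GraphVars>("vars")[dev_id][name];
  auto *out_var = new VarHandle(node, vars.size(), dev_id, name, place);
  vars.emplace_back(out_var);
  op_handle->AddOutput(out_var);
}
```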
const auto &data_var_names = node->Op()->Output("Out"); InsertDataBalanceOp(&result, data_var_names); @@ -354,11 +361,13 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(Graph *result, const std::string &p_name, size_t src_dev_id) const { #ifdef PADDLE_WITH_CUDA - auto *op_handle = new BroadcastOpHandle(result->CreateEmptyNode("broadcast"), - local_scopes_, places_, nccl_ctxs_); + auto *op_handle = new BroadcastOpHandle( + result->CreateEmptyNode("broadcast", ir::Node::Type::kOperation), + local_scopes_, places_, nccl_ctxs_); #else - auto *op_handle = new BroadcastOpHandle(result->CreateEmptyNode("broadcast"), - local_scopes_, places_); + auto *op_handle = new BroadcastOpHandle( + result->CreateEmptyNode("broadcast", ir::Node::Type::kOperation), + local_scopes_, places_); #endif result->Get("ops").emplace_back(op_handle); @@ -370,8 +379,9 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(Graph *result, auto &p = places_[i]; SetCommunicationContext(op_handle, p); auto &vars = result->Get("vars").at(i).at(p_name); - auto *out_var = new VarHandle(result->CreateEmptyNode(p_name), vars.size(), - i, p_name, p); + auto *out_var = new VarHandle( + result->CreateEmptyNode(p_name, ir::Node::Type::kVariable), vars.size(), + i, p_name, p); vars.emplace_back(out_var); op_handle->AddOutput(out_var); } @@ -389,12 +399,13 @@ void MultiDevSSAGraphBuilder::CreateComputationalOp(Graph *result, void MultiDevSSAGraphBuilder::InsertAllReduceOp(Graph *result, const std::string &og) const { #ifdef PADDLE_WITH_CUDA - result->Get("ops").emplace_back( - new AllReduceOpHandle(result->CreateEmptyNode("allreduce"), local_scopes_, - places_, nccl_ctxs_)); + result->Get("ops").emplace_back(new AllReduceOpHandle( + result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation), + local_scopes_, places_, nccl_ctxs_)); #else result->Get("ops").emplace_back(new AllReduceOpHandle( - result->CreateEmptyNode("allreduce"), local_scopes_, places_)); + result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation), + local_scopes_, places_)); #endif auto *op_handle = result->Get("ops").back().get(); @@ -407,7 +418,8 @@ void MultiDevSSAGraphBuilder::InsertAllReduceOp(Graph *result, op_handle->AddInput(prev_grad.get()); auto var = - new VarHandle(result->CreateEmptyNode(og), vars.size(), i, og, p); + new VarHandle(result->CreateEmptyNode(og, ir::Node::Type::kVariable), + vars.size(), i, og, p); vars.emplace_back(var); op_handle->AddOutput(var); } @@ -416,12 +428,13 @@ void MultiDevSSAGraphBuilder::InsertAllReduceOp(Graph *result, void MultiDevSSAGraphBuilder::InsertDataBalanceOp( Graph *result, const std::vector &datas) const { #ifdef PADDLE_WITH_CUDA - result->Get("ops").emplace_back( - new DataBalanceOpHandle(result->CreateEmptyNode("data_balance"), - local_scopes_, places_, nccl_ctxs_)); + result->Get("ops").emplace_back(new DataBalanceOpHandle( + result->CreateEmptyNode("data_balance", ir::Node::Type::kOperation), + local_scopes_, places_, nccl_ctxs_)); #else result->Get("ops").emplace_back(new DataBalanceOpHandle( - result->CreateEmptyNode("data_balance"), local_scopes_, places_)); + result->CreateEmptyNode("data_balance", ir::Node::Type::kOperation), + local_scopes_, places_)); #endif auto *op_handle = result->Get("ops").back().get(); for (size_t i = 0; i < places_.size(); ++i) { @@ -431,8 +444,9 @@ void MultiDevSSAGraphBuilder::InsertDataBalanceOp( auto &vars = result->Get("vars")[i][d_name]; PADDLE_ENFORCE(!vars.empty()); op_handle->AddInput(vars.back().get()); - auto var = new 
VarHandle(result->CreateEmptyNode(d_name), vars.size(), i, - d_name, p); + auto var = new VarHandle( + result->CreateEmptyNode(d_name, ir::Node::Type::kVariable), + vars.size(), i, d_name, p); vars.emplace_back(var); op_handle->AddOutput(var); } @@ -487,8 +501,9 @@ void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(Graph *result) const { platform::DeviceContextPool::Instance().Get(platform::CPUPlace()); #endif auto *op_handle = new ScaleLossGradOpHandle( - result->CreateEmptyNode("scale_loss_grad"), local_scopes_.size(), - local_scopes_[i], places_[i], communication_dev_ctx); + result->CreateEmptyNode("scale_loss_grad", ir::Node::Type::kOperation), + local_scopes_.size(), local_scopes_[i], places_[i], + communication_dev_ctx); result->Get("ops").emplace_back(op_handle); // FIXME: Currently ScaleLossGradOp only use device_count as scale @@ -497,14 +512,10 @@ void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(Graph *result) const { // loss->pending_ops_.emplace_back(op_handle); // op_handle->inputs_.emplace_back(loss); - // TODO(panyx0718): GradVarName(loss_var_name_) - const std::string grad_var_name = GradVarName(loss_var_name_); - auto &vars = result->Get("vars")[i][grad_var_name]; - size_t version = vars.size(); - auto var = new VarHandle(result->CreateEmptyNode(grad_var_name), version, i, - grad_var_name, places_[i]); - vars.emplace_back(var); - op_handle->AddOutput(var); + CreateOpOutput(result, op_handle, + result->CreateEmptyNode(GradVarName(loss_var_name_), + ir::Node::Type::kVariable), + places_[i], i); } } @@ -525,10 +536,12 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(Graph *result, int dst_dev_id) const { #ifdef PADDLE_WITH_CUDA result->Get("ops").emplace_back(new ReduceOpHandle( - result->CreateEmptyNode("reduce"), local_scopes_, places_, nccl_ctxs_)); + result->CreateEmptyNode("reduce", ir::Node::Type::kOperation), + local_scopes_, places_, nccl_ctxs_)); #else result->Get("ops").emplace_back(new ReduceOpHandle( - result->CreateEmptyNode("reduce"), local_scopes_, places_)); + result->CreateEmptyNode("reduce", ir::Node::Type::kOperation), + local_scopes_, places_)); #endif auto *op_handle = result->Get("ops").back().get(); @@ -541,8 +554,9 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(Graph *result, op_handle->AddInput(prev_grad.get()); } auto &vars = result->Get("vars")[dst_dev_id][og]; - auto var = new VarHandle(result->CreateEmptyNode(og), vars.size(), dst_dev_id, - og, places_[dst_dev_id]); + auto var = + new VarHandle(result->CreateEmptyNode(og, ir::Node::Type::kVariable), + vars.size(), dst_dev_id, og, places_[dst_dev_id]); vars.emplace_back(var); op_handle->AddOutput(var); return var; @@ -554,7 +568,8 @@ void MultiDevSSAGraphBuilder::ConnectOp(Graph *result, OpHandleBase *op, const std::string &prev_op_name) const { for (auto &prev_op : result->Get("ops")) { if (prev_op->Name() == prev_op_name) { - auto *dep_var = new DummyVarHandle(result->CreateEmptyNode("dummy")); + auto *dep_var = new DummyVarHandle( + result->CreateEmptyNode("dummy", ir::Node::Type::kVariable)); prev_op->AddOutput(dep_var); result->Get("dep_vars").emplace(dep_var); op->AddInput(dep_var); diff --git a/paddle/fluid/framework/details/ssa_graph_builder.cc b/paddle/fluid/framework/details/ssa_graph_builder.cc index 7de4426de8..7bc130ef6e 100644 --- a/paddle/fluid/framework/details/ssa_graph_builder.cc +++ b/paddle/fluid/framework/details/ssa_graph_builder.cc @@ -37,7 +37,8 @@ void SSAGraphBuilder::PolishGraphToSupportDataHazards(Graph *graph) { continue; } - auto *dep_var = new 
DummyVarHandle(graph->CreateEmptyNode("dummy")); + auto *dep_var = new DummyVarHandle( + graph->CreateEmptyNode("dummy", ir::Node::Type::kVariable)); read_op->AddOutput(dep_var); write_op->AddInput(dep_var); graph->Get("dep_vars").emplace(dep_var); @@ -54,12 +55,13 @@ VarHandle *SSAGraphBuilder::CreateOrGetLatestVarHandle( auto &var_holder = var_holders[node->Name()]; VarHandle *var = nullptr; if (var_holder.empty()) { - if (node->NodeType() == ir::Node::Type::kVariable) { + if (node->Var()) { var = new VarHandle(graph->CreateVarNode(node->Var()), 0, place_offset, node->Name(), place); } else { - var = new VarHandle(graph->CreateEmptyNode(node->Name()), 0, place_offset, - node->Name(), place); + var = new VarHandle( + graph->CreateEmptyNode(node->Name(), ir::Node::Type::kVariable), 0, + place_offset, node->Name(), place); } var_holder.emplace_back(var); } else { @@ -69,13 +71,13 @@ VarHandle *SSAGraphBuilder::CreateOrGetLatestVarHandle( } void SSAGraphBuilder::CreateOpOutput(Graph *graph, OpHandleBase *op_handle, - ir::Node *node, + ir::Node *new_node, const platform::Place &place, size_t place_offset) { - auto &vars = graph->Get("vars")[place_offset][node->Name()]; + auto &vars = graph->Get("vars")[place_offset][new_node->Name()]; size_t version = vars.size(); - auto var = new VarHandle(graph->CreateVarNode(node->Var()), version, - place_offset, node->Name(), place); + auto var = + new VarHandle(new_node, version, place_offset, new_node->Name(), place); vars.emplace_back(var); op_handle->AddOutput(var); } @@ -85,7 +87,8 @@ void SSAGraphBuilder::AddOutputToLeafOps(Graph *graph) { if (!op->Outputs().empty()) { continue; } - auto *dummy_leaf = new DummyVarHandle(graph->CreateEmptyNode("dummy")); + auto *dummy_leaf = new DummyVarHandle( + graph->CreateEmptyNode("dummy", ir::Node::Type::kVariable)); graph->Get("dep_vars").emplace(dummy_leaf); op->AddOutput(dummy_leaf); } diff --git a/paddle/fluid/framework/details/ssa_graph_builder.h b/paddle/fluid/framework/details/ssa_graph_builder.h index 87749009ef..e8e8acdb38 100644 --- a/paddle/fluid/framework/details/ssa_graph_builder.h +++ b/paddle/fluid/framework/details/ssa_graph_builder.h @@ -73,7 +73,7 @@ class SSAGraphBuilder : public ir::Pass { // Add an output variable (each_var_name, place, place_offset) to op_handle, // which belongs to graph static void CreateOpOutput(Graph *graph, OpHandleBase *op_handle, - ir::Node *node, const platform::Place &place, + ir::Node *new_node, const platform::Place &place, size_t place_offset); static void AddOutputToLeafOps(Graph *graph); diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc index ac77788365..38cde13fe2 100644 --- a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc @@ -173,7 +173,7 @@ void ThreadedSSAGraphExecutor::InsertFetchOps( auto &var_name = fetch_tensors[i]; auto &vars = fetched_vars.at(var_name); - temp_nodes->emplace_back(new ir::Node("fetch")); + temp_nodes->emplace_back(new ir::Node("fetch", ir::Node::Type::kOperation)); auto *op = new FetchOpHandle(temp_nodes->back().get(), fetch_data, i, &local_scopes_); fetch_ops->emplace_back(op); @@ -186,7 +186,7 @@ void ThreadedSSAGraphExecutor::InsertFetchOps( op->AddInput(var); } - temp_nodes->emplace_back(new ir::Node("fetch")); + temp_nodes->emplace_back(new ir::Node("fetch", ir::Node::Type::kOperation)); auto *fetch_dummy = new DummyVarHandle(temp_nodes->back().get()); 
     op->AddOutput(fetch_dummy);
     fetch_dependencies->emplace(fetch_dummy);
diff --git a/paddle/fluid/framework/ir/graph.cc b/paddle/fluid/framework/ir/graph.cc
index f8381af985..d384ac0d50 100644
--- a/paddle/fluid/framework/ir/graph.cc
+++ b/paddle/fluid/framework/ir/graph.cc
@@ -41,7 +41,7 @@ Graph::Graph(const ProgramDesc &program) : program_(program) {
         // TODO(paddle-dev): Seems some assumption doesn't hold?
         LOG(ERROR) << op->Type()
                    << " input var not in all_var list: " << each_var_name;
-        var = CreateEmptyNode(each_var_name);
+        var = CreateEmptyNode(each_var_name, ir::Node::Type::kVariable);
         var_nodes[each_var_name] = var;
       }
       node->inputs.push_back(var);
diff --git a/paddle/fluid/framework/ir/graph.h b/paddle/fluid/framework/ir/graph.h
index 93db573559..3c268682af 100644
--- a/paddle/fluid/framework/ir/graph.h
+++ b/paddle/fluid/framework/ir/graph.h
@@ -67,8 +67,8 @@ class Graph {
   // TODO(paddle-dev): There shouldn't be kNone nodes in the ir::Graph.
   // node should either be a executable kOperation or a kVariable. kNone
   // node is a temporary solution.
-  ir::Node* CreateEmptyNode(const std::string& name) {
-    nodes.emplace_back(new ir::Node(name));
+  ir::Node* CreateEmptyNode(const std::string& name, ir::Node::Type type) {
+    nodes.emplace_back(new ir::Node(name, type));
     return nodes.back().get();
   }
diff --git a/paddle/fluid/framework/ir/node.h b/paddle/fluid/framework/ir/node.h
index cb1d524c34..38080b4ec5 100644
--- a/paddle/fluid/framework/ir/node.h
+++ b/paddle/fluid/framework/ir/node.h
@@ -26,12 +26,9 @@ namespace ir {
 
 class Node {
  public:
-  enum class Type { kNone, kOperation, kVariable };
-  explicit Node(const std::string& name)
-      : name_(name),
-        var_desc_(nullptr),
-        op_desc_(nullptr),
-        type_(Type::kNone) {}
+  enum class Type { kOperation, kVariable };
+  explicit Node(const std::string& name, Type type)
+      : name_(name), var_desc_(nullptr), op_desc_(nullptr), type_(type) {}
 
   explicit Node(VarDesc* var_desc)
       : name_(var_desc->Name()),

From 5daad162184e1e7927ae5763eea8807c5118cfac Mon Sep 17 00:00:00 2001
From: Xin Pan
Date: Tue, 17 Jul 2018 20:44:24 +0800
Subject: [PATCH 20/22] polish

---
 paddle/fluid/framework/details/var_handle.h | 3 ++-
 paddle/fluid/framework/ir/graph.h           | 3 ---
 2 files changed, 2 insertions(+), 4 deletions(-)

diff --git a/paddle/fluid/framework/details/var_handle.h b/paddle/fluid/framework/details/var_handle.h
index ae23e3b1f8..ba8b38d1e7 100644
--- a/paddle/fluid/framework/details/var_handle.h
+++ b/paddle/fluid/framework/details/var_handle.h
@@ -56,7 +56,8 @@ struct VarHandleBase {
 
   void RemoveOutput(OpHandleBase* out, ir::Node* node) {
     pending_ops_.erase(out);
-    std::remove(node_->outputs.begin(), node_->outputs.end(), node);
+    node_->outputs.erase(
+        std::remove(node_->outputs.begin(), node_->outputs.end(), node));
   }
 
   void ClearGeneratedOp() {
diff --git a/paddle/fluid/framework/ir/graph.h b/paddle/fluid/framework/ir/graph.h
index 3c268682af..2ab018dd85 100644
--- a/paddle/fluid/framework/ir/graph.h
+++ b/paddle/fluid/framework/ir/graph.h
@@ -64,9 +64,6 @@ class Graph {
     return nodes.back().get();
   }
 
-  // TODO(paddle-dev): There shouldn't be kNone nodes in the ir::Graph.
-  // node should either be a executable kOperation or a kVariable. kNone
-  // node is a temporary solution.
   ir::Node* CreateEmptyNode(const std::string& name, ir::Node::Type type) {
     nodes.emplace_back(new ir::Node(name, type));
     return nodes.back().get();

From da5efa735a3db297131421e9b7dab703f1d5d2ae Mon Sep 17 00:00:00 2001
From: Xin Pan
Date: Tue, 17 Jul 2018 21:34:00 +0800
Subject: [PATCH 21/22] fix

---
 paddle/fluid/framework/details/var_handle.h | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/paddle/fluid/framework/details/var_handle.h b/paddle/fluid/framework/details/var_handle.h
index ba8b38d1e7..d8c2bc40b9 100644
--- a/paddle/fluid/framework/details/var_handle.h
+++ b/paddle/fluid/framework/details/var_handle.h
@@ -57,7 +57,8 @@ struct VarHandleBase {
   void RemoveOutput(OpHandleBase* out, ir::Node* node) {
     pending_ops_.erase(out);
     node_->outputs.erase(
-        std::remove(node_->outputs.begin(), node_->outputs.end(), node));
+        std::remove(node_->outputs.begin(), node_->outputs.end(), node),
+        node_->outputs.end());
   }
 
   void ClearGeneratedOp() {

From 950585f419ee286585d0b4f8e94bdda7bc5d8725 Mon Sep 17 00:00:00 2001
From: Xin Pan
Date: Wed, 18 Jul 2018 09:23:32 +0800
Subject: [PATCH 22/22] follow comments

---
 doc/fluid/design/ir/draft.md            | 2 +-
 paddle/fluid/framework/ir/graph.cc      | 2 +-
 paddle/fluid/framework/ir/graph.h       | 2 +-
 paddle/fluid/framework/ir/graph_test.cc | 2 +-
 paddle/fluid/framework/ir/node.cc       | 2 +-
 paddle/fluid/framework/ir/node.h        | 2 +-
 paddle/fluid/framework/ir/pass.cc       | 2 +-
 paddle/fluid/framework/ir/pass.h        | 2 +-
 8 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/doc/fluid/design/ir/draft.md b/doc/fluid/design/ir/draft.md
index ac40c2efde..a141dcbca5 100644
--- a/doc/fluid/design/ir/draft.md
+++ b/doc/fluid/design/ir/draft.md
@@ -9,7 +9,7 @@ Usually, the ```gap``` is bridged by
 
 * A serious transformations with defined order.
 
-* The transformations usually invovle
+* These transformations usually involve
   ```insert, delete, clustering, split, dependency analysis```.
 
 * Has a simple way to verify and debug each transformation.
diff --git a/paddle/fluid/framework/ir/graph.cc b/paddle/fluid/framework/ir/graph.cc
index d384ac0d50..688f7ba582 100644
--- a/paddle/fluid/framework/ir/graph.cc
+++ b/paddle/fluid/framework/ir/graph.cc
@@ -1,4 +1,4 @@
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
diff --git a/paddle/fluid/framework/ir/graph.h b/paddle/fluid/framework/ir/graph.h
index 2ab018dd85..b4ac135b02 100644
--- a/paddle/fluid/framework/ir/graph.h
+++ b/paddle/fluid/framework/ir/graph.h
@@ -1,4 +1,4 @@
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
diff --git a/paddle/fluid/framework/ir/graph_test.cc b/paddle/fluid/framework/ir/graph_test.cc
index 857188ef0a..4e23bf124f 100644
--- a/paddle/fluid/framework/ir/graph_test.cc
+++ b/paddle/fluid/framework/ir/graph_test.cc
@@ -1,4 +1,4 @@
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
diff --git a/paddle/fluid/framework/ir/node.cc b/paddle/fluid/framework/ir/node.cc
index ca83fa7a83..86376e7e8b 100644
--- a/paddle/fluid/framework/ir/node.cc
+++ b/paddle/fluid/framework/ir/node.cc
@@ -1,4 +1,4 @@
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
diff --git a/paddle/fluid/framework/ir/node.h b/paddle/fluid/framework/ir/node.h
index 38080b4ec5..b98c29b81d 100644
--- a/paddle/fluid/framework/ir/node.h
+++ b/paddle/fluid/framework/ir/node.h
@@ -1,4 +1,4 @@
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
diff --git a/paddle/fluid/framework/ir/pass.cc b/paddle/fluid/framework/ir/pass.cc
index 91b0decd25..c05d7d0bb5 100644
--- a/paddle/fluid/framework/ir/pass.cc
+++ b/paddle/fluid/framework/ir/pass.cc
@@ -1,4 +1,4 @@
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
diff --git a/paddle/fluid/framework/ir/pass.h b/paddle/fluid/framework/ir/pass.h
index 3f0fcff857..f52ba788d5 100644
--- a/paddle/fluid/framework/ir/pass.h
+++ b/paddle/fluid/framework/ir/pass.h
@@ -1,4 +1,4 @@
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.