From 626abfc33ac373b2c18ef1b26d0b1470eb8f94c0 Mon Sep 17 00:00:00 2001 From: Xin Pan Date: Thu, 9 Aug 2018 14:16:31 +0800 Subject: [PATCH] code clean up and renaming Reduce one level of inheritence. --- doc/fluid/design/ir/{draft.md => overview.md} | 4 +- paddle/fluid/framework/CMakeLists.txt | 2 +- paddle/fluid/framework/details/CMakeLists.txt | 8 +- ...r.cc => multi_devices_graph_check_pass.cc} | 4 +- ...ker.h => multi_devices_graph_check_pass.h} | 4 +- ...builder.cc => multi_devices_graph_pass.cc} | 90 ++++++++++++++- ...h_builder.h => multi_devices_graph_pass.h} | 4 +- ...r.cc => multi_devices_graph_print_pass.cc} | 4 +- ...ter.h => multi_devices_graph_print_pass.h} | 4 +- .../framework/details/multi_devices_helper.cc | 20 ++++ ...graph_builder.h => multi_devices_helper.h} | 27 ----- .../framework/details/ssa_graph_builder.cc | 107 ------------------ .../details/threaded_ssa_graph_executor.cc | 2 +- paddle/fluid/framework/parallel_executor.cc | 50 ++++---- paddle/fluid/framework/parallel_executor.h | 2 +- 15 files changed, 152 insertions(+), 180 deletions(-) rename doc/fluid/design/ir/{draft.md => overview.md} (97%) rename paddle/fluid/framework/details/{ssa_graph_checker.cc => multi_devices_graph_check_pass.cc} (95%) rename paddle/fluid/framework/details/{ssa_graph_checker.h => multi_devices_graph_check_pass.h} (89%) rename paddle/fluid/framework/details/{multi_devices_graph_builder.cc => multi_devices_graph_pass.cc} (90%) rename paddle/fluid/framework/details/{multi_devices_graph_builder.h => multi_devices_graph_pass.h} (96%) rename paddle/fluid/framework/details/{ssa_graph_printer.cc => multi_devices_graph_print_pass.cc} (95%) rename paddle/fluid/framework/details/{ssa_graph_printer.h => multi_devices_graph_print_pass.h} (92%) create mode 100644 paddle/fluid/framework/details/multi_devices_helper.cc rename paddle/fluid/framework/details/{ssa_graph_builder.h => multi_devices_helper.h} (68%) delete mode 100644 paddle/fluid/framework/details/ssa_graph_builder.cc diff --git a/doc/fluid/design/ir/draft.md b/doc/fluid/design/ir/overview.md similarity index 97% rename from doc/fluid/design/ir/draft.md rename to doc/fluid/design/ir/overview.md index c29337cba1..83ef97c99e 100644 --- a/doc/fluid/design/ir/draft.md +++ b/doc/fluid/design/ir/overview.md @@ -177,8 +177,8 @@ graph = PassRegistry::Instance().Get("op_fuse_pass").Apply(std::move(grah)); auto mem_opt_pass = PassRegistry::Instance().Get("memory_optimization_pass"); mem_opt_pass.SetNotOwned("optimize_level", 1); mem_opt_pass->Apply(std::move(graph)); -graph = PassRegistry::Instance().Get("multi_device_pass").Apply(std::move(grah)); -graph = PassRegistry::Instance().Get("multi_device_check_pass").Apply(std::move(grah)); +graph = PassRegistry::Instance().Get("multi_devices_pass").Apply(std::move(grah)); +graph = PassRegistry::Instance().Get("multi_devices_check_pass").Apply(std::move(grah)); Executor exe; exe.Run(graph); diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index 6440607dbe..1d62792b80 100644 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -100,7 +100,7 @@ else() endif() -cc_library(parallel_executor SRCS parallel_executor.cc DEPS threaded_ssa_graph_executor scope_buffered_ssa_graph_executor graph graph_viz_pass multi_devices_graph_builder ssa_graph_printer ssa_graph_checker) +cc_library(parallel_executor SRCS parallel_executor.cc DEPS threaded_ssa_graph_executor scope_buffered_ssa_graph_executor graph graph_viz_pass multi_devices_graph_pass multi_devices_graph_print_pass multi_devices_graph_check_pass) cc_library(prune SRCS prune.cc DEPS framework_proto) cc_test(prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context) diff --git a/paddle/fluid/framework/details/CMakeLists.txt b/paddle/fluid/framework/details/CMakeLists.txt index 5d652d3730..8f6c4163d6 100644 --- a/paddle/fluid/framework/details/CMakeLists.txt +++ b/paddle/fluid/framework/details/CMakeLists.txt @@ -5,9 +5,9 @@ cc_library(fetch_op_handle SRCS fetch_op_handle.cc DEPS op_handle_base scope lod cc_library(computation_op_handle SRCS computation_op_handle.cc DEPS framework_proto scope place operator op_registry) cc_library(rpc_op_handle SRCS rpc_op_handle.cc DEPS framework_proto scope place operator op_registry) -cc_library(ssa_graph_builder SRCS ssa_graph_builder.cc DEPS graph graph_helper) -cc_library(ssa_graph_printer SRCS ssa_graph_printer.cc DEPS ssa_graph_builder) -cc_library(ssa_graph_checker SRCS ssa_graph_checker.cc DEPS ssa_graph_builder) +cc_library(multi_devices_helper SRCS multi_devices_helper.cc DEPS graph graph_helper) +cc_library(multi_devices_graph_print_pass SRCS multi_devices_graph_print_pass.cc DEPS multi_devices_helper) +cc_library(multi_devices_graph_check_pass SRCS multi_devices_graph_check_pass.cc DEPS multi_devices_helper) cc_library(variable_visitor SRCS variable_visitor.cc DEPS lod_tensor selected_rows) @@ -28,7 +28,7 @@ cc_library(data_balance_op_handle SRCS data_balance_op_handle.cc DEPS op_handle_ cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor) cc_library(fuse_vars_op_handle SRCS fuse_vars_op_handle.cc DEPS op_handle_base scope) -cc_library(multi_devices_graph_builder SRCS multi_devices_graph_builder.cc DEPS ssa_graph_builder computation_op_handle +cc_library(multi_devices_graph_pass SRCS multi_devices_graph_pass.cc DEPS multi_devices_helper computation_op_handle scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle) cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS graph framework_proto) diff --git a/paddle/fluid/framework/details/ssa_graph_checker.cc b/paddle/fluid/framework/details/multi_devices_graph_check_pass.cc similarity index 95% rename from paddle/fluid/framework/details/ssa_graph_checker.cc rename to paddle/fluid/framework/details/multi_devices_graph_check_pass.cc index b9e1cda1f2..c9c255864a 100644 --- a/paddle/fluid/framework/details/ssa_graph_checker.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_check_pass.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/framework/details/ssa_graph_checker.h" +#include "paddle/fluid/framework/details/multi_devices_graph_check_pass.h" #include #include "paddle/fluid/framework/ir/graph.h" @@ -86,7 +86,7 @@ bool SSAGraghBuilderWithChecker::IsValidGraph(const ir::Graph *graph) const { } // namespace framework } // namespace paddle -REGISTER_PASS(multi_device_check_pass, +REGISTER_PASS(multi_devices_check_pass, paddle::framework::details::SSAGraghBuilderWithChecker) .RequireGraphAttr(paddle::framework::details::kGraphVars) .RequireGraphAttr(paddle::framework::details::kGraphDepVars) diff --git a/paddle/fluid/framework/details/ssa_graph_checker.h b/paddle/fluid/framework/details/multi_devices_graph_check_pass.h similarity index 89% rename from paddle/fluid/framework/details/ssa_graph_checker.h rename to paddle/fluid/framework/details/multi_devices_graph_check_pass.h index 0e861ecb23..1e2b1867c3 100644 --- a/paddle/fluid/framework/details/ssa_graph_checker.h +++ b/paddle/fluid/framework/details/multi_devices_graph_check_pass.h @@ -14,7 +14,7 @@ #pragma once -#include "paddle/fluid/framework/details/ssa_graph_builder.h" +#include "paddle/fluid/framework/details/multi_devices_helper.h" #include @@ -22,7 +22,7 @@ namespace paddle { namespace framework { namespace details { -class SSAGraghBuilderWithChecker : public SSAGraphBuilder { +class SSAGraghBuilderWithChecker : public ir::Pass { protected: std::unique_ptr ApplyImpl( std::unique_ptr graph) const override { diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.cc b/paddle/fluid/framework/details/multi_devices_graph_pass.cc similarity index 90% rename from paddle/fluid/framework/details/multi_devices_graph_builder.cc rename to paddle/fluid/framework/details/multi_devices_graph_pass.cc index a4fdbcb26d..c5a13e7e1f 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_pass.cc @@ -21,7 +21,7 @@ #include "paddle/fluid/framework/details/broadcast_op_handle.h" #include "paddle/fluid/framework/details/computation_op_handle.h" #include "paddle/fluid/framework/details/data_balance_op_handle.h" -#include "paddle/fluid/framework/details/multi_devices_graph_builder.h" +#include "paddle/fluid/framework/details/multi_devices_graph_pass.h" #include "paddle/fluid/framework/details/reduce_op_handle.h" #include "paddle/fluid/framework/details/rpc_op_handle.h" #include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h" @@ -33,6 +33,92 @@ namespace paddle { namespace framework { namespace details { +namespace { +void PolishGraphToSupportDataHazards(ir::Graph *graph) { + for (auto &var_map : graph->Get(kGraphVars)) { + for (auto &name_pair : var_map) { + if (name_pair.second.size() <= 1) { + continue; + } + auto it_new = name_pair.second.rbegin(); + auto it_old = name_pair.second.rbegin(); + ++it_old; + for (; it_old != name_pair.second.rend(); it_new = it_old, ++it_old) { + OpHandleBase *write_op = (*it_new)->GeneratedOp(); + const auto &read_ops = (*it_old)->PendingOps(); + + for (auto *read_op : read_ops) { + // Manually add a dependency var from read_op to write_op; + if (read_op == write_op) { + // Read Write is the same op. + continue; + } + bool has_dep = false; + for (auto *r_out : read_op->Outputs()) { + for (auto *w_in : write_op->Inputs()) { + if (r_out->Node() == w_in->Node()) { + has_dep = true; + break; + } + } + } + if (has_dep) continue; + + auto *dep_var = new DummyVarHandle(graph->CreateControlDepVar()); + read_op->AddOutput(dep_var); + write_op->AddInput(dep_var); + graph->Get(kGraphDepVars).emplace(dep_var); + } + } + } + } +} + +VarHandle *CreateOrGetLatestVarHandle(ir::Graph *graph, ir::Node *node, + const platform::Place &place, + size_t place_offset) { + auto &var_holders = graph->Get(kGraphVars)[place_offset]; + auto &var_holder = var_holders[node->Name()]; + VarHandle *var = nullptr; + if (var_holder.empty()) { + if (node->Var()) { + var = new VarHandle(graph->CreateVarNode(node->Var()), 0, place_offset, + node->Name(), place); + } else { + var = new VarHandle( + graph->CreateEmptyNode(node->Name(), ir::Node::Type::kVariable), 0, + place_offset, node->Name(), place); + } + var_holder.emplace_back(var); + } else { + var = var_holder.rbegin()->get(); + } + return var; +} + +void CreateOpOutput(ir::Graph *graph, OpHandleBase *op_handle, + ir::Node *new_node, const platform::Place &place, + size_t place_offset) { + auto &vars = + graph->Get(kGraphVars)[place_offset][new_node->Name()]; + size_t version = vars.size(); + auto var = + new VarHandle(new_node, version, place_offset, new_node->Name(), place); + vars.emplace_back(var); + op_handle->AddOutput(var); +} + +void AddOutputToLeafOps(ir::Graph *graph) { + for (auto &op : graph->Get(kGraphOps)) { + if (!op->Outputs().empty()) { + continue; + } + auto *dummy_leaf = new DummyVarHandle(graph->CreateControlDepVar()); + graph->Get(kGraphDepVars).emplace(dummy_leaf); + op->AddOutput(dummy_leaf); + } +} +} // namespace static const char kLossVarName[] = "loss_var_name"; static const char kPlaces[] = "places"; @@ -751,7 +837,7 @@ bool MultiDevSSAGraphBuilder::IsScaleLossOp(ir::Node *node) const { } // namespace framework } // namespace paddle -REGISTER_PASS(multi_device_pass, +REGISTER_PASS(multi_devices_pass, paddle::framework::details::MultiDevSSAGraphBuilder) .RequirePassAttr(paddle::framework::details::kLossVarName) .RequirePassAttr(paddle::framework::details::kPlaces) diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.h b/paddle/fluid/framework/details/multi_devices_graph_pass.h similarity index 96% rename from paddle/fluid/framework/details/multi_devices_graph_builder.h rename to paddle/fluid/framework/details/multi_devices_graph_pass.h index f2cb6bb1c8..7a6f238f9c 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.h +++ b/paddle/fluid/framework/details/multi_devices_graph_pass.h @@ -18,7 +18,7 @@ #include #include "paddle/fluid/framework/details/build_strategy.h" -#include "paddle/fluid/framework/details/ssa_graph_builder.h" +#include "paddle/fluid/framework/details/multi_devices_helper.h" #include "paddle/fluid/framework/ir/graph.h" namespace paddle { @@ -30,7 +30,7 @@ namespace framework { class Scope; namespace details { -class MultiDevSSAGraphBuilder : public SSAGraphBuilder { +class MultiDevSSAGraphBuilder : public ir::Pass { protected: std::unique_ptr ApplyImpl( std::unique_ptr graph) const override; diff --git a/paddle/fluid/framework/details/ssa_graph_printer.cc b/paddle/fluid/framework/details/multi_devices_graph_print_pass.cc similarity index 95% rename from paddle/fluid/framework/details/ssa_graph_printer.cc rename to paddle/fluid/framework/details/multi_devices_graph_print_pass.cc index ec3f31ab8d..69944a42b6 100644 --- a/paddle/fluid/framework/details/ssa_graph_printer.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_print_pass.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/framework/details/ssa_graph_printer.h" +#include "paddle/fluid/framework/details/multi_devices_graph_print_pass.h" #include #include "paddle/fluid/framework/ir/graph.h" @@ -82,5 +82,5 @@ void GraphvizSSAGraphPrinter::Print(const ir::Graph &graph, } // namespace framework } // namespace paddle -REGISTER_PASS(multi_device_print_pass, +REGISTER_PASS(multi_devices_print_pass, paddle::framework::details::SSAGraghBuilderWithPrinter); diff --git a/paddle/fluid/framework/details/ssa_graph_printer.h b/paddle/fluid/framework/details/multi_devices_graph_print_pass.h similarity index 92% rename from paddle/fluid/framework/details/ssa_graph_printer.h rename to paddle/fluid/framework/details/multi_devices_graph_print_pass.h index 5eafd1805c..c00685fa16 100644 --- a/paddle/fluid/framework/details/ssa_graph_printer.h +++ b/paddle/fluid/framework/details/multi_devices_graph_print_pass.h @@ -18,7 +18,7 @@ #include #include #include -#include "paddle/fluid/framework/details/ssa_graph_builder.h" +#include "paddle/fluid/framework/details/multi_devices_helper.h" namespace paddle { namespace framework { @@ -35,7 +35,7 @@ class GraphvizSSAGraphPrinter : public SSAGraphPrinter { void Print(const ir::Graph& graph, std::ostream& sout) const override; }; -class SSAGraghBuilderWithPrinter : public SSAGraphBuilder { +class SSAGraghBuilderWithPrinter : public ir::Pass { protected: std::unique_ptr ApplyImpl( std::unique_ptr graph) const override { diff --git a/paddle/fluid/framework/details/multi_devices_helper.cc b/paddle/fluid/framework/details/multi_devices_helper.cc new file mode 100644 index 0000000000..0242274a16 --- /dev/null +++ b/paddle/fluid/framework/details/multi_devices_helper.cc @@ -0,0 +1,20 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "paddle/fluid/framework/details/multi_devices_helper.h" + +namespace paddle { +namespace framework { +namespace details {} // namespace details +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/details/ssa_graph_builder.h b/paddle/fluid/framework/details/multi_devices_helper.h similarity index 68% rename from paddle/fluid/framework/details/ssa_graph_builder.h rename to paddle/fluid/framework/details/multi_devices_helper.h index 53a4ad003d..175c5a9950 100644 --- a/paddle/fluid/framework/details/ssa_graph_builder.h +++ b/paddle/fluid/framework/details/multi_devices_helper.h @@ -52,33 +52,6 @@ const char kGraphOps[] = "ops"; typedef std::unordered_map ShardedVarDevice; const char kShardedVarDevice[] = "sharded_var_device"; - -class SSAGraphBuilder : public ir::Pass { - public: - SSAGraphBuilder() {} - virtual ~SSAGraphBuilder() {} - - DISABLE_COPY_AND_ASSIGN(SSAGraphBuilder); - - protected: - /* - Dependency graph has been constructed. However, there are still data - hazards need to be handled. - */ - static void PolishGraphToSupportDataHazards(ir::Graph *graph); - - static VarHandle *CreateOrGetLatestVarHandle(ir::Graph *graph, ir::Node *node, - const platform::Place &place, - size_t place_offset); - - // Add an output variable (each_var_name, place, place_offset) to op_handle, - // which belongs to graph - static void CreateOpOutput(ir::Graph *graph, OpHandleBase *op_handle, - ir::Node *new_node, const platform::Place &place, - size_t place_offset); - - static void AddOutputToLeafOps(ir::Graph *graph); -}; } // namespace details } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/details/ssa_graph_builder.cc b/paddle/fluid/framework/details/ssa_graph_builder.cc deleted file mode 100644 index 575532540a..0000000000 --- a/paddle/fluid/framework/details/ssa_graph_builder.cc +++ /dev/null @@ -1,107 +0,0 @@ -// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -#include "paddle/fluid/framework/details/ssa_graph_builder.h" -#include - -namespace paddle { -namespace framework { -namespace details { -void SSAGraphBuilder::PolishGraphToSupportDataHazards(ir::Graph *graph) { - for (auto &var_map : graph->Get(kGraphVars)) { - for (auto &name_pair : var_map) { - if (name_pair.second.size() <= 1) { - continue; - } - auto it_new = name_pair.second.rbegin(); - auto it_old = name_pair.second.rbegin(); - ++it_old; - for (; it_old != name_pair.second.rend(); it_new = it_old, ++it_old) { - OpHandleBase *write_op = (*it_new)->GeneratedOp(); - const auto &read_ops = (*it_old)->PendingOps(); - - for (auto *read_op : read_ops) { - // Manually add a dependency var from read_op to write_op; - if (read_op == write_op) { - // Read Write is the same op. - continue; - } - bool has_dep = false; - for (auto *r_out : read_op->Outputs()) { - for (auto *w_in : write_op->Inputs()) { - if (r_out->Node() == w_in->Node()) { - has_dep = true; - break; - } - } - } - if (has_dep) continue; - - auto *dep_var = new DummyVarHandle(graph->CreateControlDepVar()); - read_op->AddOutput(dep_var); - write_op->AddInput(dep_var); - graph->Get(kGraphDepVars).emplace(dep_var); - } - } - } - } -} - -VarHandle *SSAGraphBuilder::CreateOrGetLatestVarHandle( - ir::Graph *graph, ir::Node *node, const platform::Place &place, - size_t place_offset) { - auto &var_holders = graph->Get(kGraphVars)[place_offset]; - auto &var_holder = var_holders[node->Name()]; - VarHandle *var = nullptr; - if (var_holder.empty()) { - if (node->Var()) { - var = new VarHandle(graph->CreateVarNode(node->Var()), 0, place_offset, - node->Name(), place); - } else { - var = new VarHandle( - graph->CreateEmptyNode(node->Name(), ir::Node::Type::kVariable), 0, - place_offset, node->Name(), place); - } - var_holder.emplace_back(var); - } else { - var = var_holder.rbegin()->get(); - } - return var; -} - -void SSAGraphBuilder::CreateOpOutput(ir::Graph *graph, OpHandleBase *op_handle, - ir::Node *new_node, - const platform::Place &place, - size_t place_offset) { - auto &vars = - graph->Get(kGraphVars)[place_offset][new_node->Name()]; - size_t version = vars.size(); - auto var = - new VarHandle(new_node, version, place_offset, new_node->Name(), place); - vars.emplace_back(var); - op_handle->AddOutput(var); -} - -void SSAGraphBuilder::AddOutputToLeafOps(ir::Graph *graph) { - for (auto &op : graph->Get(kGraphOps)) { - if (!op->Outputs().empty()) { - continue; - } - auto *dummy_leaf = new DummyVarHandle(graph->CreateControlDepVar()); - graph->Get(kGraphDepVars).emplace(dummy_leaf); - op->AddOutput(dummy_leaf); - } -} -} // namespace details -} // namespace framework -} // namespace paddle diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc index 0eaf9a9c95..994bb6492f 100644 --- a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc @@ -14,7 +14,7 @@ #include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h" -#include "paddle/fluid/framework/details/ssa_graph_builder.h" +#include "paddle/fluid/framework/details/multi_devices_helper.h" #include "paddle/fluid/platform/profiler.h" namespace paddle { diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc index b5f01a9a2b..275cb8c592 100644 --- a/paddle/fluid/framework/parallel_executor.cc +++ b/paddle/fluid/framework/parallel_executor.cc @@ -25,9 +25,9 @@ limitations under the License. */ #include "paddle/fluid/platform/nccl_helper.h" #endif +#include "paddle/fluid/framework/details/multi_devices_graph_check_pass.h" +#include "paddle/fluid/framework/details/multi_devices_graph_print_pass.h" #include "paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h" -#include "paddle/fluid/framework/details/ssa_graph_checker.h" -#include "paddle/fluid/framework/details/ssa_graph_printer.h" #include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h" #include "paddle/fluid/platform/profiler.h" @@ -57,39 +57,39 @@ std::unique_ptr ApplyParallelExecutorPass( } // Convert graph to run on multi-devices. - auto multi_device_pass = - ir::PassRegistry::Instance().Get("multi_device_pass"); - multi_device_pass->SetNotOwned>("places", - &places); - multi_device_pass->SetNotOwned("loss_var_name", - &loss_var_name); - multi_device_pass->SetNotOwned>( + auto multi_devices_pass = + ir::PassRegistry::Instance().Get("multi_devices_pass"); + multi_devices_pass->SetNotOwned>("places", + &places); + multi_devices_pass->SetNotOwned("loss_var_name", + &loss_var_name); + multi_devices_pass->SetNotOwned>( "params", ¶m_names); - multi_device_pass->SetNotOwned>("local_scopes", - &local_scopes); - multi_device_pass->SetNotOwned("strategy", &strategy); + multi_devices_pass->SetNotOwned>("local_scopes", + &local_scopes); + multi_devices_pass->SetNotOwned("strategy", &strategy); #ifdef PADDLE_WITH_CUDA platform::NCCLContextMap *nctx = use_cuda ? nccl_ctxs : nullptr; - multi_device_pass->SetNotOwned("nccl_ctxs", nctx); + multi_devices_pass->SetNotOwned("nccl_ctxs", nctx); #endif - graph = multi_device_pass->Apply(std::move(graph)); + graph = multi_devices_pass->Apply(std::move(graph)); // Apply a graph print pass to record a graph with device info. if (!strategy.debug_graphviz_path_.empty()) { - auto multi_device_print_pass = - ir::PassRegistry::Instance().Get("multi_device_print_pass"); - multi_device_print_pass->SetNotOwned( + auto multi_devices_print_pass = + ir::PassRegistry::Instance().Get("multi_devices_print_pass"); + multi_devices_print_pass->SetNotOwned( "debug_graphviz_path", &strategy.debug_graphviz_path_); - multi_device_print_pass->Set( + multi_devices_print_pass->Set( "graph_printer", new details::GraphvizSSAGraphPrinter); - graph = multi_device_print_pass->Apply(std::move(graph)); + graph = multi_devices_print_pass->Apply(std::move(graph)); } // Verify that the graph is correct for multi-device executor. - auto multi_device_check_pass = - ir::PassRegistry::Instance().Get("multi_device_check_pass"); - graph = multi_device_check_pass->Apply(std::move(graph)); + auto multi_devices_check_pass = + ir::PassRegistry::Instance().Get("multi_devices_check_pass"); + graph = multi_devices_check_pass->Apply(std::move(graph)); return graph; } @@ -354,6 +354,6 @@ ParallelExecutor::~ParallelExecutor() { } // namespace paddle USE_PASS(graph_viz_pass); -USE_PASS(multi_device_pass); -USE_PASS(multi_device_check_pass); -USE_PASS(multi_device_print_pass); +USE_PASS(multi_devices_pass); +USE_PASS(multi_devices_check_pass); +USE_PASS(multi_devices_print_pass); diff --git a/paddle/fluid/framework/parallel_executor.h b/paddle/fluid/framework/parallel_executor.h index d624956acd..5fb748fa20 100644 --- a/paddle/fluid/framework/parallel_executor.h +++ b/paddle/fluid/framework/parallel_executor.h @@ -19,7 +19,7 @@ limitations under the License. */ #include #include #include "paddle/fluid/framework/details/execution_strategy.h" -#include "paddle/fluid/framework/details/multi_devices_graph_builder.h" +#include "paddle/fluid/framework/details/multi_devices_graph_pass.h" #include "paddle/fluid/framework/executor.h" #include "paddle/fluid/framework/op_info.h" #include "paddle/fluid/framework/program_desc.h"