From 9ac785be396bd21d3f152a299f5fa7cb5e268e08 Mon Sep 17 00:00:00 2001 From: chengduoZH <zhaochengduo@163.com> Date: Thu, 7 Jun 2018 15:40:58 +0800 Subject: [PATCH 1/4] check graph's validation --- .../details/multi_devices_graph_builder.cc | 1 - .../framework/details/ssa_graph_builder.cc | 70 ++++++++++++++++++- .../framework/details/ssa_graph_builder.h | 3 + .../details/threaded_ssa_graph_executor.cc | 1 + 4 files changed, 73 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.cc b/paddle/fluid/framework/details/multi_devices_graph_builder.cc index 0c4d369e88..81d5b079b8 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.cc @@ -272,7 +272,6 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build( * Only variables should be the leaves of graph. */ AddOutputToLeafOps(&result); - return std::unique_ptr<SSAGraph>(graph); } diff --git a/paddle/fluid/framework/details/ssa_graph_builder.cc b/paddle/fluid/framework/details/ssa_graph_builder.cc index 211113c797..d70f95a9f5 100644 --- a/paddle/fluid/framework/details/ssa_graph_builder.cc +++ b/paddle/fluid/framework/details/ssa_graph_builder.cc @@ -11,8 +11,8 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. - #include "paddle/fluid/framework/details/ssa_graph_builder.h" +#include <utility> namespace paddle { namespace framework { @@ -83,6 +83,74 @@ void SSAGraphBuilder::AddOutputToLeafOps(SSAGraph *graph) { op->AddOutput(dummy_leaf); } } + +std::unique_ptr<SSAGraph> SSAGraphBuilder::BuildAndCheck( + const ProgramDesc &program) final { + std::unique_ptr<SSAGraph> graph = Build(program); + PADDLE_ENFORCE(IsValidGraph(graph.get())); + return std::move(graph); +} + +bool SSAGraphBuilder::IsValidGraph(const SSAGraph *graph) const { + std::unordered_map<OpHandleBase *, size_t> pending_ops; + std::unordered_set<VarHandleBase *> pending_vars; + std::unordered_set<VarHandleBase *> ready_vars; + std::unordered_set<OpHandleBase *> ready_ops; + + auto insert_pending_var = [&](VarHandleBase *var) { + pending_vars.insert(var); + if (var->generated_op_ == nullptr) { + ready_vars.emplace(var); + } + }; + + for (auto &var_map : graph->vars_) { + for (auto &name_pair : var_map) { + for (auto &version_pair : name_pair.second) { + insert_pending_var(version_pair.get()); + } + } + } + + for (auto &var : graph->dep_vars_) { + insert_pending_var(var.get()); + } + + for (auto &op : graph->ops_) { + if (op->Inputs().empty()) { + ready_ops.insert(op.get()); + } else { + pending_ops.insert({op.get(), op.get()->NoDupInputSize()}); + } + } + + auto run_all_ops = [&](std::unordered_set<OpHandleBase *> &set) { + for (auto *op : set) { + for (auto out : op->Outputs()) { + ready_vars.emplace(out); + } + } + set.clear(); + }; + + while (!pending_vars.empty()) { + run_all_ops(ready_ops); + if (ready_vars.empty()) { + return false; + } + for (auto ready_var : ready_vars.) { + pending_vars.erase(ready_var); + for (auto *op : ready_var->pending_ops_) { + auto &deps = --pending_ops[op]; + if (deps == 0) { + ready_ops.insert(op); + } + } + } + ready_vars.clear(); + } + return true; +} } // namespace details } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/details/ssa_graph_builder.h b/paddle/fluid/framework/details/ssa_graph_builder.h index 5fc12a44b5..da9298ac8d 100644 --- a/paddle/fluid/framework/details/ssa_graph_builder.h +++ b/paddle/fluid/framework/details/ssa_graph_builder.h @@ -31,6 +31,8 @@ class SSAGraphBuilder { virtual ~SSAGraphBuilder() {} virtual std::unique_ptr<SSAGraph> Build(const ProgramDesc &program) const = 0; + std::unique_ptr<SSAGraph> BuildAndCheck(const ProgramDesc &program) final; + DISABLE_COPY_AND_ASSIGN(SSAGraphBuilder); protected: @@ -48,6 +50,7 @@ class SSAGraphBuilder { const platform::Place &place, size_t place_offset); + bool IsValidGraph(const SSAGraph *graph) const; // Add an output variable (each_var_name, place, place_offset) to op_handle, // which belongs to graph static void CreateOpOutput(SSAGraph *graph, OpHandleBase *op_handle, diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc index 496fadd04d..bcbf573626 100644 --- a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc @@ -185,6 +185,7 @@ void ThreadedSSAGraphExecutor::InsertPendingVar( ready_vars->Push(var); } } + void ThreadedSSAGraphExecutor::RunOp( BlockingQueue<VarHandleBase *> *ready_var_q, details::OpHandleBase *op) { auto op_run = [ready_var_q, op, this] { From 8291b916d6cf053db779598b01dd59191fb5a1df Mon Sep 17 00:00:00 2001 From: chengduoZH <zhaochengduo@163.com> Date: Thu, 7 Jun 2018 16:24:23 +0800 Subject: [PATCH 2/4] replace graph_builder_factory with ssa_graph_builder_factory --- paddle/fluid/framework/CMakeLists.txt | 2 +- paddle/fluid/framework/details/CMakeLists.txt | 2 +- paddle/fluid/framework/details/multi_devices_graph_builder.cc | 1 + paddle/fluid/framework/details/ssa_graph_builder.cc | 4 ++-- paddle/fluid/framework/details/ssa_graph_builder.h | 2 +- ...{graph_builder_factory.cc => ssa_graph_builder_factory.cc} | 2 +- .../{graph_builder_factory.h => ssa_graph_builder_factory.h} | 0 paddle/fluid/framework/parallel_executor.cc | 4 ++-- 8 files changed, 9 insertions(+), 8 deletions(-) rename paddle/fluid/framework/details/{graph_builder_factory.cc => ssa_graph_builder_factory.cc} (96%) rename paddle/fluid/framework/details/{graph_builder_factory.h => ssa_graph_builder_factory.h} (100%) diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index 627370cd2d..4271e4c1bb 100644 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -87,7 +87,7 @@ cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method) -cc_library(parallel_executor SRCS parallel_executor.cc DEPS graph_builder_factory threaded_ssa_graph_executor scope_buffered_ssa_graph_executor) +cc_library(parallel_executor SRCS parallel_executor.cc DEPS ssa_graph_builder_factory threaded_ssa_graph_executor scope_buffered_ssa_graph_executor) cc_library(prune SRCS prune.cc DEPS framework_proto) cc_test(prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context) diff --git a/paddle/fluid/framework/details/CMakeLists.txt b/paddle/fluid/framework/details/CMakeLists.txt index c106761f72..ced063a097 100644 --- a/paddle/fluid/framework/details/CMakeLists.txt +++ b/paddle/fluid/framework/details/CMakeLists.txt @@ -30,7 +30,7 @@ cc_library(multi_devices_graph_builder SRCS multi_devices_graph_builder.cc DEPS scale_loss_grad_op_handle rpc_op_handle ${multi_devices_graph_builder_deps} reduce_op_handle broadcast_op_handle) -cc_library(graph_builder_factory SRCS graph_builder_factory.cc DEPS multi_devices_graph_builder ssa_graph_printer) +cc_library(ssa_graph_builder_factory SRCS ssa_graph_builder_factory.cc DEPS multi_devices_graph_builder ssa_graph_printer) cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ssa_graph framework_proto) cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.cc b/paddle/fluid/framework/details/multi_devices_graph_builder.cc index 81d5b079b8..0c4d369e88 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.cc @@ -272,6 +272,7 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build( * Only variables should be the leaves of graph. */ AddOutputToLeafOps(&result); + return std::unique_ptr<SSAGraph>(graph); } diff --git a/paddle/fluid/framework/details/ssa_graph_builder.cc b/paddle/fluid/framework/details/ssa_graph_builder.cc index d70f95a9f5..d24669a8f8 100644 --- a/paddle/fluid/framework/details/ssa_graph_builder.cc +++ b/paddle/fluid/framework/details/ssa_graph_builder.cc @@ -85,7 +85,7 @@ void SSAGraphBuilder::AddOutputToLeafOps(SSAGraph *graph) { } std::unique_ptr<SSAGraph> SSAGraphBuilder::BuildAndCheck( - const ProgramDesc &program) final { + const ProgramDesc &program) { std::unique_ptr<SSAGraph> graph = Build(program); PADDLE_ENFORCE(IsValidGraph(graph.get())); return std::move(graph); @@ -138,7 +138,7 @@ bool SSAGraphBuilder::IsValidGraph(const SSAGraph *graph) const { if (ready_vars.empty()) { return false; } - for (auto ready_var : ready_vars.) { + for (auto ready_var : ready_vars) { pending_vars.erase(ready_var); for (auto *op : ready_var->pending_ops_) { auto &deps = --pending_ops[op]; diff --git a/paddle/fluid/framework/details/ssa_graph_builder.h b/paddle/fluid/framework/details/ssa_graph_builder.h index da9298ac8d..e99a988407 100644 --- a/paddle/fluid/framework/details/ssa_graph_builder.h +++ b/paddle/fluid/framework/details/ssa_graph_builder.h @@ -31,7 +31,7 @@ class SSAGraphBuilder { virtual ~SSAGraphBuilder() {} virtual std::unique_ptr<SSAGraph> Build(const ProgramDesc &program) const = 0; - std::unique_ptr<SSAGraph> BuildAndCheck(const ProgramDesc &program) final; + std::unique_ptr<SSAGraph> BuildAndCheck(const ProgramDesc &program); DISABLE_COPY_AND_ASSIGN(SSAGraphBuilder); diff --git a/paddle/fluid/framework/details/graph_builder_factory.cc b/paddle/fluid/framework/details/ssa_graph_builder_factory.cc similarity index 96% rename from paddle/fluid/framework/details/graph_builder_factory.cc rename to paddle/fluid/framework/details/ssa_graph_builder_factory.cc index a04b9bb63c..b5e90d6b05 100644 --- a/paddle/fluid/framework/details/graph_builder_factory.cc +++ b/paddle/fluid/framework/details/ssa_graph_builder_factory.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/framework/details/graph_builder_factory.h" +#include "paddle/fluid/framework/details/ssa_graph_builder_factory.h" #include <fstream> #include "paddle/fluid/framework/details/multi_devices_graph_builder.h" #include "paddle/fluid/framework/details/ssa_graph_printer.h" diff --git a/paddle/fluid/framework/details/graph_builder_factory.h b/paddle/fluid/framework/details/ssa_graph_builder_factory.h similarity index 100% rename from paddle/fluid/framework/details/graph_builder_factory.h rename to paddle/fluid/framework/details/ssa_graph_builder_factory.h diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc index ce56f55e41..f1ab337070 100644 --- a/paddle/fluid/framework/parallel_executor.cc +++ b/paddle/fluid/framework/parallel_executor.cc @@ -22,8 +22,8 @@ limitations under the License. */ #include "paddle/fluid/platform/nccl_helper.h" #endif -#include "paddle/fluid/framework/details/graph_builder_factory.h" #include "paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h" +#include "paddle/fluid/framework/details/ssa_graph_builder_factory.h" #include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h" #include "paddle/fluid/platform/profiler.h" @@ -114,7 +114,7 @@ ParallelExecutor::ParallelExecutor( member_->executor_.reset(new details::ThreadedSSAGraphExecutor( exec_strategy, member_->local_scopes_, places, - builder_factory.Create()->Build(main_program))); + builder_factory.Create()->BuildAndCheck(main_program))); member_->executor_.reset(new details::ScopeBufferedSSAGraphExecutor( exec_strategy, member_->local_scopes_, std::move(var_infos), From 1076e85135d3eadd97324c6d06bf4a6a30852148 Mon Sep 17 00:00:00 2001 From: chengduoZH <zhaochengduo@163.com> Date: Thu, 7 Jun 2018 17:01:41 +0800 Subject: [PATCH 3/4] refine logic --- paddle/fluid/framework/details/ssa_graph_builder.cc | 2 ++ 1 file changed, 2 insertions(+) diff --git a/paddle/fluid/framework/details/ssa_graph_builder.cc b/paddle/fluid/framework/details/ssa_graph_builder.cc index d24669a8f8..c4ee088507 100644 --- a/paddle/fluid/framework/details/ssa_graph_builder.cc +++ b/paddle/fluid/framework/details/ssa_graph_builder.cc @@ -135,9 +135,11 @@ bool SSAGraphBuilder::IsValidGraph(const SSAGraph *graph) const { while (!pending_vars.empty()) { run_all_ops(ready_ops); + if (ready_vars.empty()) { return false; } + for (auto ready_var : ready_vars) { pending_vars.erase(ready_var); for (auto *op : ready_var->pending_ops_) { From 0c851cab2294356dd292b9b4458379d1bde4eadd Mon Sep 17 00:00:00 2001 From: chengduoZH <zhaochengduo@163.com> Date: Fri, 8 Jun 2018 15:02:33 +0800 Subject: [PATCH 4/4] add SSA graph checker --- paddle/fluid/framework/details/CMakeLists.txt | 3 +- .../framework/details/ssa_graph_builder.cc | 70 --------------- .../framework/details/ssa_graph_builder.h | 3 - .../details/ssa_graph_builder_factory.cc | 3 + .../framework/details/ssa_graph_checker.cc | 87 +++++++++++++++++++ .../framework/details/ssa_graph_checker.h | 44 ++++++++++ paddle/fluid/framework/parallel_executor.cc | 2 +- 7 files changed, 137 insertions(+), 75 deletions(-) create mode 100644 paddle/fluid/framework/details/ssa_graph_checker.cc create mode 100644 paddle/fluid/framework/details/ssa_graph_checker.h diff --git a/paddle/fluid/framework/details/CMakeLists.txt b/paddle/fluid/framework/details/CMakeLists.txt index ced063a097..dbd118d338 100644 --- a/paddle/fluid/framework/details/CMakeLists.txt +++ b/paddle/fluid/framework/details/CMakeLists.txt @@ -8,6 +8,7 @@ cc_library(rpc_op_handle SRCS rpc_op_handle.cc DEPS framework_proto scope place cc_library(ssa_graph SRCS ssa_graph.cc DEPS var_handle op_handle_base) cc_library(ssa_graph_builder SRCS ssa_graph_builder.cc DEPS ssa_graph) cc_library(ssa_graph_printer SRCS ssa_graph_printer.cc DEPS ssa_graph_builder) +cc_library(ssa_graph_checker SRCS ssa_graph_checker.cc DEPS ssa_graph_builder) cc_library(variable_visitor SRCS variable_visitor.cc DEPS lod_tensor selected_rows) @@ -30,7 +31,7 @@ cc_library(multi_devices_graph_builder SRCS multi_devices_graph_builder.cc DEPS scale_loss_grad_op_handle rpc_op_handle ${multi_devices_graph_builder_deps} reduce_op_handle broadcast_op_handle) -cc_library(ssa_graph_builder_factory SRCS ssa_graph_builder_factory.cc DEPS multi_devices_graph_builder ssa_graph_printer) +cc_library(ssa_graph_builder_factory SRCS ssa_graph_builder_factory.cc DEPS multi_devices_graph_builder ssa_graph_printer ssa_graph_checker) cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ssa_graph framework_proto) cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope diff --git a/paddle/fluid/framework/details/ssa_graph_builder.cc b/paddle/fluid/framework/details/ssa_graph_builder.cc index c4ee088507..88a21f4887 100644 --- a/paddle/fluid/framework/details/ssa_graph_builder.cc +++ b/paddle/fluid/framework/details/ssa_graph_builder.cc @@ -83,76 +83,6 @@ void SSAGraphBuilder::AddOutputToLeafOps(SSAGraph *graph) { op->AddOutput(dummy_leaf); } } - -std::unique_ptr<SSAGraph> SSAGraphBuilder::BuildAndCheck( - const ProgramDesc &program) { - std::unique_ptr<SSAGraph> graph = Build(program); - PADDLE_ENFORCE(IsValidGraph(graph.get())); - return std::move(graph); -} - -bool SSAGraphBuilder::IsValidGraph(const SSAGraph *graph) const { - std::unordered_map<OpHandleBase *, size_t> pending_ops; - std::unordered_set<VarHandleBase *> pending_vars; - std::unordered_set<VarHandleBase *> ready_vars; - std::unordered_set<OpHandleBase *> ready_ops; - - auto insert_pending_var = [&](VarHandleBase *var) { - pending_vars.insert(var); - if (var->generated_op_ == nullptr) { - ready_vars.emplace(var); - } - }; - - for (auto &var_map : graph->vars_) { - for (auto &name_pair : var_map) { - for (auto &version_pair : name_pair.second) { - insert_pending_var(version_pair.get()); - } - } - } - - for (auto &var : graph->dep_vars_) { - insert_pending_var(var.get()); - } - - for (auto &op : graph->ops_) { - if (op->Inputs().empty()) { - ready_ops.insert(op.get()); - } else { - pending_ops.insert({op.get(), op.get()->NoDupInputSize()}); - } - } - - auto run_all_ops = [&](std::unordered_set<OpHandleBase *> &set) { - for (auto *op : set) { - for (auto out : op->Outputs()) { - ready_vars.emplace(out); - } - } - set.clear(); - }; - - while (!pending_vars.empty()) { - run_all_ops(ready_ops); - - if (ready_vars.empty()) { - return false; - } - - for (auto ready_var : ready_vars) { - pending_vars.erase(ready_var); - for (auto *op : ready_var->pending_ops_) { - auto &deps = --pending_ops[op]; - if (deps == 0) { - ready_ops.insert(op); - } - } - } - ready_vars.clear(); - } - return true; -} } // namespace details } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/details/ssa_graph_builder.h b/paddle/fluid/framework/details/ssa_graph_builder.h index e99a988407..5fc12a44b5 100644 --- a/paddle/fluid/framework/details/ssa_graph_builder.h +++ b/paddle/fluid/framework/details/ssa_graph_builder.h @@ -31,8 +31,6 @@ class SSAGraphBuilder { virtual ~SSAGraphBuilder() {} virtual std::unique_ptr<SSAGraph> Build(const ProgramDesc &program) const = 0; - std::unique_ptr<SSAGraph> BuildAndCheck(const ProgramDesc &program); - DISABLE_COPY_AND_ASSIGN(SSAGraphBuilder); protected: @@ -50,7 +48,6 @@ class SSAGraphBuilder { const platform::Place &place, size_t place_offset); - bool IsValidGraph(const SSAGraph *graph) const; // Add an output variable (each_var_name, place, place_offset) to op_handle, // which belongs to graph static void CreateOpOutput(SSAGraph *graph, OpHandleBase *op_handle, diff --git a/paddle/fluid/framework/details/ssa_graph_builder_factory.cc b/paddle/fluid/framework/details/ssa_graph_builder_factory.cc index b5e90d6b05..b4b49d3de6 100644 --- a/paddle/fluid/framework/details/ssa_graph_builder_factory.cc +++ b/paddle/fluid/framework/details/ssa_graph_builder_factory.cc @@ -15,6 +15,7 @@ #include "paddle/fluid/framework/details/ssa_graph_builder_factory.h" #include <fstream> #include "paddle/fluid/framework/details/multi_devices_graph_builder.h" +#include "paddle/fluid/framework/details/ssa_graph_checker.h" #include "paddle/fluid/framework/details/ssa_graph_printer.h" namespace paddle { @@ -40,6 +41,8 @@ std::unique_ptr<SSAGraphBuilder> SSAGraphBuilderFactory::Create() { res.reset(new SSAGraghBuilderWithPrinter( std::move(fout), std::move(graphviz_printer), std::move(res))); } + res.reset(new SSAGraghBuilderWithChecker(std::move(res))); + return res; } } // namespace details diff --git a/paddle/fluid/framework/details/ssa_graph_checker.cc b/paddle/fluid/framework/details/ssa_graph_checker.cc new file mode 100644 index 0000000000..da5428946e --- /dev/null +++ b/paddle/fluid/framework/details/ssa_graph_checker.cc @@ -0,0 +1,87 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/details/ssa_graph.h" +#include <string> +#include "paddle/fluid/framework/details/ssa_graph_checker.h" + +namespace paddle { +namespace framework { +namespace details { + +bool SSAGraghBuilderWithChecker::IsValidGraph(const SSAGraph *graph) const { + std::unordered_map<OpHandleBase *, size_t> pending_ops; + std::unordered_set<VarHandleBase *> pending_vars; + std::unordered_set<VarHandleBase *> ready_vars; + std::unordered_set<OpHandleBase *> ready_ops; + + auto insert_pending_var = [&](VarHandleBase *var) { + pending_vars.insert(var); + if (var->generated_op_ == nullptr) { + ready_vars.emplace(var); + } + }; + + for (auto &var_map : graph->vars_) { + for (auto &name_pair : var_map) { + for (auto &version_pair : name_pair.second) { + insert_pending_var(version_pair.get()); + } + } + } + + for (auto &var : graph->dep_vars_) { + insert_pending_var(var.get()); + } + + for (auto &op : graph->ops_) { + if (op->Inputs().empty()) { + ready_ops.insert(op.get()); + } else { + pending_ops.insert({op.get(), op.get()->NoDupInputSize()}); + } + } + + auto run_all_ops = [&](std::unordered_set<OpHandleBase *> &set) { + for (auto *op : set) { + for (auto out : op->Outputs()) { + ready_vars.emplace(out); + } + } + set.clear(); + }; + + while (!pending_vars.empty()) { + run_all_ops(ready_ops); + + if (ready_vars.empty()) { + return false; + } + + for (auto ready_var : ready_vars) { + pending_vars.erase(ready_var); + for (auto *op : ready_var->pending_ops_) { + auto &deps = --pending_ops[op]; + if (deps == 0) { + ready_ops.insert(op); + } + } + } + ready_vars.clear(); + } + return true; +} +} // namespace details +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/details/ssa_graph_checker.h b/paddle/fluid/framework/details/ssa_graph_checker.h new file mode 100644 index 0000000000..542c4a1728 --- /dev/null +++ b/paddle/fluid/framework/details/ssa_graph_checker.h @@ -0,0 +1,44 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/framework/details/ssa_graph_builder.h" + +namespace paddle { +namespace framework { +namespace details { +class SSAGraph; + +class SSAGraghBuilderWithChecker : public SSAGraphBuilder { + public: + explicit SSAGraghBuilderWithChecker( + std::unique_ptr<SSAGraphBuilder>&& builder) + : builder_(std::move(builder)) {} + + std::unique_ptr<SSAGraph> Build(const ProgramDesc& program) const override { + auto graph = builder_->Build(program); + PADDLE_ENFORCE(IsValidGraph(graph.get())); + return graph; + } + + bool IsValidGraph(const SSAGraph* graph) const; + + private: + std::unique_ptr<SSAGraphBuilder> builder_; +}; + +} // namespace details +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc index f1ab337070..5d95dc214a 100644 --- a/paddle/fluid/framework/parallel_executor.cc +++ b/paddle/fluid/framework/parallel_executor.cc @@ -114,7 +114,7 @@ ParallelExecutor::ParallelExecutor( member_->executor_.reset(new details::ThreadedSSAGraphExecutor( exec_strategy, member_->local_scopes_, places, - builder_factory.Create()->BuildAndCheck(main_program))); + builder_factory.Create()->Build(main_program))); member_->executor_.reset(new details::ScopeBufferedSSAGraphExecutor( exec_strategy, member_->local_scopes_, std::move(var_infos),