From 9ac785be396bd21d3f152a299f5fa7cb5e268e08 Mon Sep 17 00:00:00 2001
From: chengduoZH <zhaochengduo@163.com>
Date: Thu, 7 Jun 2018 15:40:58 +0800
Subject: [PATCH] check graph's validation

---
 .../details/multi_devices_graph_builder.cc    |  1 -
 .../framework/details/ssa_graph_builder.cc    | 70 ++++++++++++++++++-
 .../framework/details/ssa_graph_builder.h     |  3 +
 .../details/threaded_ssa_graph_executor.cc    |  1 +
 4 files changed, 73 insertions(+), 2 deletions(-)

diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.cc b/paddle/fluid/framework/details/multi_devices_graph_builder.cc
index 0c4d369e88..81d5b079b8 100644
--- a/paddle/fluid/framework/details/multi_devices_graph_builder.cc
+++ b/paddle/fluid/framework/details/multi_devices_graph_builder.cc
@@ -272,7 +272,6 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
    * Only variables should be the leaves of graph.
    */
   AddOutputToLeafOps(&result);
-
   return std::unique_ptr<SSAGraph>(graph);
 }
 
diff --git a/paddle/fluid/framework/details/ssa_graph_builder.cc b/paddle/fluid/framework/details/ssa_graph_builder.cc
index 211113c797..d70f95a9f5 100644
--- a/paddle/fluid/framework/details/ssa_graph_builder.cc
+++ b/paddle/fluid/framework/details/ssa_graph_builder.cc
@@ -11,8 +11,8 @@
 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 // See the License for the specific language governing permissions and
 // limitations under the License.
-
 #include "paddle/fluid/framework/details/ssa_graph_builder.h"
+#include <utility>
 
 namespace paddle {
 namespace framework {
@@ -83,6 +83,74 @@ void SSAGraphBuilder::AddOutputToLeafOps(SSAGraph *graph) {
     op->AddOutput(dummy_leaf);
   }
 }
+
+std::unique_ptr<SSAGraph> SSAGraphBuilder::BuildAndCheck(
+    const ProgramDesc &program) final {
+  std::unique_ptr<SSAGraph> graph = Build(program);
+  PADDLE_ENFORCE(IsValidGraph(graph.get()));
+  return std::move(graph);
+}
+
+bool SSAGraphBuilder::IsValidGraph(const SSAGraph *graph) const {
+  std::unordered_map<OpHandleBase *, size_t> pending_ops;
+  std::unordered_set<VarHandleBase *> pending_vars;
+  std::unordered_set<VarHandleBase *> ready_vars;
+  std::unordered_set<OpHandleBase *> ready_ops;
+
+  auto insert_pending_var = [&](VarHandleBase *var) {
+    pending_vars.insert(var);
+    if (var->generated_op_ == nullptr) {
+      ready_vars.emplace(var);
+    }
+  };
+
+  for (auto &var_map : graph->vars_) {
+    for (auto &name_pair : var_map) {
+      for (auto &version_pair : name_pair.second) {
+        insert_pending_var(version_pair.get());
+      }
+    }
+  }
+
+  for (auto &var : graph->dep_vars_) {
+    insert_pending_var(var.get());
+  }
+
+  for (auto &op : graph->ops_) {
+    if (op->Inputs().empty()) {
+      ready_ops.insert(op.get());
+    } else {
+      pending_ops.insert({op.get(), op.get()->NoDupInputSize()});
+    }
+  }
+
+  auto run_all_ops = [&](std::unordered_set<OpHandleBase *> &set) {
+    for (auto *op : set) {
+      for (auto out : op->Outputs()) {
+        ready_vars.emplace(out);
+      }
+    }
+    set.clear();
+  };
+
+  while (!pending_vars.empty()) {
+    run_all_ops(ready_ops);
+    if (ready_vars.empty()) {
+      return false;
+    }
+    for (auto ready_var : ready_vars.) {
+      pending_vars.erase(ready_var);
+      for (auto *op : ready_var->pending_ops_) {
+        auto &deps = --pending_ops[op];
+        if (deps == 0) {
+          ready_ops.insert(op);
+        }
+      }
+    }
+    ready_vars.clear();
+  }
+  return true;
+}
 }  // namespace details
 }  // namespace framework
 }  // namespace paddle
diff --git a/paddle/fluid/framework/details/ssa_graph_builder.h b/paddle/fluid/framework/details/ssa_graph_builder.h
index 5fc12a44b5..da9298ac8d 100644
--- a/paddle/fluid/framework/details/ssa_graph_builder.h
+++ b/paddle/fluid/framework/details/ssa_graph_builder.h
@@ -31,6 +31,8 @@ class SSAGraphBuilder {
   virtual ~SSAGraphBuilder() {}
   virtual std::unique_ptr<SSAGraph> Build(const ProgramDesc &program) const = 0;
 
+  std::unique_ptr<SSAGraph> BuildAndCheck(const ProgramDesc &program) final;
+
   DISABLE_COPY_AND_ASSIGN(SSAGraphBuilder);
 
  protected:
@@ -48,6 +50,7 @@ class SSAGraphBuilder {
                                                const platform::Place &place,
                                                size_t place_offset);
 
+  bool IsValidGraph(const SSAGraph *graph) const;
   // Add an output variable (each_var_name, place, place_offset) to op_handle,
   // which belongs to graph
   static void CreateOpOutput(SSAGraph *graph, OpHandleBase *op_handle,
diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc
index 496fadd04d..bcbf573626 100644
--- a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc
+++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc
@@ -185,6 +185,7 @@ void ThreadedSSAGraphExecutor::InsertPendingVar(
     ready_vars->Push(var);
   }
 }
+
 void ThreadedSSAGraphExecutor::RunOp(
     BlockingQueue<VarHandleBase *> *ready_var_q, details::OpHandleBase *op) {
   auto op_run = [ready_var_q, op, this] {