From 27533b64237528e0de0166b45a322d4ab6fee276 Mon Sep 17 00:00:00 2001
From: Yu Yang <yuyang18@baidu.com>
Date: Wed, 4 Apr 2018 12:56:32 +0800
Subject: [PATCH] Fix Leaf Ops in Graph

All leaves must be variables. When all variables are ready, the
execution will be completed. If a operator has no output, the `Op::Run`
might not be started when the execution of graph has been complete.
---
 .../framework/details/multi_devices_graph_builder.cc  |  8 ++++++++
 paddle/fluid/framework/details/ssa_graph_builder.cc   | 11 +++++++++++
 paddle/fluid/framework/details/ssa_graph_builder.h    |  8 +++++---
 .../framework/details/threaded_ssa_graph_executor.cc  |  8 ++++++--
 4 files changed, 30 insertions(+), 5 deletions(-)

diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.cc b/paddle/fluid/framework/details/multi_devices_graph_builder.cc
index c277bd7cb6..128a5344fb 100644
--- a/paddle/fluid/framework/details/multi_devices_graph_builder.cc
+++ b/paddle/fluid/framework/details/multi_devices_graph_builder.cc
@@ -21,6 +21,9 @@
 #include "paddle/fluid/framework/details/nccl_all_reduce_op_handle.h"
 #endif
 
+#include <string>
+#include <vector>
+
 namespace paddle {
 namespace framework {
 namespace details {
@@ -168,6 +171,11 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
    */
   PolishGraphToSupportDataHazards(&result);
 
+  /*
+   * Only variables should be the leaves of graph.
+   */
+  AddOutputToLeafOps(&result);
+
   if (VLOG_IS_ON(10)) {
     std::ostringstream sout;
     PrintGraphviz(*graph, sout);
diff --git a/paddle/fluid/framework/details/ssa_graph_builder.cc b/paddle/fluid/framework/details/ssa_graph_builder.cc
index 361ba6d397..0a4febd22f 100644
--- a/paddle/fluid/framework/details/ssa_graph_builder.cc
+++ b/paddle/fluid/framework/details/ssa_graph_builder.cc
@@ -136,6 +136,17 @@ void SSAGraphBuilder::PrintGraphviz(const SSAGraph &graph, std::ostream &sout) {
 
   sout << "}\n";
 }
+
+void SSAGraphBuilder::AddOutputToLeafOps(SSAGraph *graph) {
+  for (auto &op : graph->ops_) {
+    if (!op->outputs_.empty()) {
+      continue;
+    }
+    auto *dummy_leaf = new DummyVarHandle();
+    graph->dep_vars_.emplace(dummy_leaf);
+    op->AddOutput(dummy_leaf);
+  }
+}
 }  // namespace details
 }  // namespace framework
 }  // namespace paddle
diff --git a/paddle/fluid/framework/details/ssa_graph_builder.h b/paddle/fluid/framework/details/ssa_graph_builder.h
index bf20e7164a..be1f0460e4 100644
--- a/paddle/fluid/framework/details/ssa_graph_builder.h
+++ b/paddle/fluid/framework/details/ssa_graph_builder.h
@@ -14,13 +14,13 @@
 
 #pragma once
 
+#include <memory>
+#include <string>
+
 #include "paddle/fluid/framework/details/ssa_graph.h"
 #include "paddle/fluid/framework/program_desc.h"
 #include "paddle/fluid/platform/place.h"
 
-#include <memory>
-#include <string>
-
 namespace paddle {
 namespace framework {
 namespace details {
@@ -52,6 +52,8 @@ class SSAGraphBuilder {
                              const std::string &each_var_name,
                              const platform::Place &place, size_t place_offset);
 
+  static void AddOutputToLeafOps(SSAGraph *graph);
+
   static void PrintGraphviz(const SSAGraph &graph, std::ostream &sout);
 };
 }  // namespace details
diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc
index 1f96b9dc62..596e573186 100644
--- a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc
+++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc
@@ -87,7 +87,6 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
 
   // Step 2. Insert FetchOps
   std::vector<std::unique_ptr<FetchOpHandle>> fetch_ops;
-  std::vector<DummyVarHandle> dummy_vars;
   FeedFetchList fetch_data(fetch_tensors.size());
 
   std::unordered_map<std::string, std::vector<VarHandleBase *>> fetched_vars;
@@ -101,13 +100,13 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
     }
   }
 
+  std::unordered_set<std::unique_ptr<VarHandleBase>> fetch_dependencies;
   for (size_t i = 0; i < fetch_tensors.size(); ++i) {
     auto &var_name = fetch_tensors[i];
     auto &vars = fetched_vars.at(var_name);
     auto *op = new FetchOpHandle(&fetch_data, i, &local_scopes_);
     fetch_ops.emplace_back(op);
 
-    // FIXME: Use new device context
     for (auto &p : places_) {
       op->dev_ctxes_[p] = fetch_ctxs_.Get(p);
     }
@@ -115,6 +114,11 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
     for (auto *var : vars) {
       op->AddInput(var);
     }
+
+    auto *fetch_dummy = new DummyVarHandle();
+    op->AddOutput(fetch_dummy);
+    fetch_dependencies.emplace(fetch_dummy);
+    InsertPendingVar(*fetch_dummy);
     InsertPendingOp(*op);
   }