From 7be57de9434053e7aa2e7b1d78da62ee1cb41ba7 Mon Sep 17 00:00:00 2001
From: fengjiayi <fengjiayi@baidu.com>
Date: Tue, 2 Jan 2018 16:55:51 +0800
Subject: [PATCH 1/3] enhance no_grad_var handling

---
 python/paddle/v2/fluid/backward.py | 16 ++++++++++++++--
 1 file changed, 14 insertions(+), 2 deletions(-)

diff --git a/python/paddle/v2/fluid/backward.py b/python/paddle/v2/fluid/backward.py
index f11c83f59c..43e9abc354 100644
--- a/python/paddle/v2/fluid/backward.py
+++ b/python/paddle/v2/fluid/backward.py
@@ -57,6 +57,8 @@ def _all_in_set_(cands, s):
     """
     Test if all elements of 'cands' are in set 's'
     """
+    if len(cands) == 0:
+        return False
     for c in cands:
         if not c in s:
             return False
@@ -138,10 +140,20 @@ def _remove_no_grad_branch_(op_descs, no_grad_set):
         1. all outputs of the grad op are in 'no_grad_set'
         2. (TODO) all grad inputs of the grad op are in 'no_grad_set'
     """
+
+    def _op_can_be_removed_(op_desc, no_grad_set):
+        if _all_in_set_(op_desc.output_arg_names(), no_grad_set):
+            return True
+        if _all_in_set_(
+                filter(lambda name: name.find(core.grad_var_suffix()) != -1,
+                       op_desc.input_arg_names()), no_grad_set):
+            no_grad_set.union(op_desc.output_arg_names())
+            return True
+        return False
+
     # Remove ops whose outputs are all in no_grad_dict
     op_descs = filter(
-        lambda op_desc: not _all_in_set_(op_desc.output_arg_names(), no_grad_set),
-        op_descs)
+        lambda op_desc: not _op_can_be_removed_(op_desc, no_grad_set), op_descs)
     # Insert fill_zeros_like_op
     to_insert = []
     for idx, op_desc in enumerate(op_descs):

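A minimal sketch of the pruning rule this patch introduces, using a plain-Python stand-in for `OpDesc` (the `FakeOpDesc` class and `GRAD_SUFFIX` constant below are hypothetical, not part of the fluid API; the real suffix comes from `core.grad_var_suffix()`). Note the sketch uses `set.update` so that case 2 actually adds the outputs to the set in place:

```python
GRAD_SUFFIX = "@GRAD"  # assumed to mirror core.grad_var_suffix()


class FakeOpDesc(object):
    """A toy stand-in for the C++ OpDesc handled in backward.py."""

    def __init__(self, inputs, outputs):
        self._inputs = inputs
        self._outputs = outputs

    def input_arg_names(self):
        return self._inputs

    def output_arg_names(self):
        return self._outputs


def _all_in_set_(cands, s):
    # Mirrors the patched helper: an empty candidate list no longer counts as
    # "all in set", so ops with no gradient inputs are not pruned by case 2.
    if len(cands) == 0:
        return False
    return all(c in s for c in cands)


def op_can_be_removed(op_desc, no_grad_set):
    # Case 1: every output of the grad op is unwanted.
    if _all_in_set_(op_desc.output_arg_names(), no_grad_set):
        return True
    # Case 2: every gradient input is unwanted, so every output is zero;
    # mark the outputs as no-grad as well (set.update mutates in place).
    grad_inputs = [n for n in op_desc.input_arg_names() if GRAD_SUFFIX in n]
    if _all_in_set_(grad_inputs, no_grad_set):
        no_grad_set.update(op_desc.output_arg_names())
        return True
    return False


# The only gradient input of this grad op is marked no-grad, so the op is
# removable and its output gradient joins the no-grad set.
no_grad = set(["y@GRAD"])
op = FakeOpDesc(inputs=["x", "y@GRAD"], outputs=["x@GRAD"])
assert op_can_be_removed(op, no_grad)
assert "x@GRAD" in no_grad
```
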
From 8d4a607fb35a6eb9b5eacf9999f955bde911e2ad Mon Sep 17 00:00:00 2001
From: fengjiayi <fengjiayi@baidu.com>
Date: Tue, 2 Jan 2018 17:30:40 +0800
Subject: [PATCH 2/3] update backward doc

---
 doc/design/backward.md             | 6 ++++--
 python/paddle/v2/fluid/backward.py | 2 +-
 2 files changed, 5 insertions(+), 3 deletions(-)

diff --git a/doc/design/backward.md b/doc/design/backward.md
index 35f03692bb..20fda7a98f 100644
--- a/doc/design/backward.md
+++ b/doc/design/backward.md
@@ -106,9 +106,11 @@ See function `_addup_repetitive_outputs_` in `backward.py` for implementation de
 
 In our framework, variables can be marked as *no_gradient*, which means that the gradient of the variable is unnecessary and can be considered as zero in model training. Obviously, when all the outputs of some `grad_op` are marked as *no_gradient*, the `grad_op` itself can be skipped in the backward pass.
 
-But these unnecessary gradients still need to be creating and initialized by something, otherwise following `grad_op`s who take these gradients as inputs take the risk of using uninitialized memory. In our code, we employ `fill_zeros_like_op` to initialize them as all zeros. 
+Another situation is that all the gradient inputs of some `grad_op` are marked as *no_gradient*, which means all of them can be considered as zeros. Since a `grad_op` is in essence the propagation of gradients, all its outputs are definitely zeros when all of its gradient inputs are zeros. Therefore, the `grad_op` can also be skipped.
 
-This features are implemented in function `_remove_no_grad_branch_`. It checks new created `grad_op`s one-by-one, removes whose outputs are all in `no_grad_set` or inserts `fill_zeros_like_op` when its necessary. We can get the `no_grad_set` from the `_append_backward_ops_` argument `no_grad_dict` or generate it on the fly by scanning all variables' `no_gradient` attribute(True or False). 
+It should be noted that all these zero gradients still need to be created and initialized by something, otherwise following `grad_op`s that take these gradients as inputs run the risk of using uninitialized memory. In our code, we employ `fill_zeros_like_op` to initialize them as all zeros.
+
+These features are implemented in the function `_remove_no_grad_branch_`. It checks newly created `grad_op`s one by one, removes those that can be skipped, and inserts `fill_zeros_like_op` when necessary. We can get the `no_grad_set` from the `_append_backward_ops_` argument `no_grad_dict` or generate it on the fly by scanning all variables' `no_gradient` attribute (True or False).
 
 ### Creating Backward Variables
 
diff --git a/python/paddle/v2/fluid/backward.py b/python/paddle/v2/fluid/backward.py
index 43e9abc354..a1be768daa 100644
--- a/python/paddle/v2/fluid/backward.py
+++ b/python/paddle/v2/fluid/backward.py
@@ -138,7 +138,7 @@ def _remove_no_grad_branch_(op_descs, no_grad_set):
     Remove unnecessary grad ops
     A grad op can be removed in two cases:
         1. all outputs of the grad op are in 'no_grad_set'
-        2. (TODO) all grad inputs of the grad op are in 'no_grad_set'
+        2. all grad inputs of the grad op are in 'no_grad_set'
     """
 
     def _op_can_be_removed_(op_desc, no_grad_set):

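A minimal numpy sketch of the `fill_zeros_like_op` idea described in the doc update above: any gradient that was skipped by the no-grad pruning but is still read by a surviving `grad_op` must exist and be all zeros, so downstream ops never touch uninitialized memory. The `fill_zeros_like` function below is a toy stand-in, not the real fluid operator:

```python
import numpy as np


def fill_zeros_like(var):
    # Toy stand-in for fill_zeros_like_op: a zero gradient with the same
    # shape and dtype as the corresponding forward variable.
    return np.zeros_like(var)


y = np.random.rand(3, 4).astype("float32")  # forward variable marked no_gradient
y_grad = fill_zeros_like(y)                 # stands in for the pruned y@GRAD

# A surviving grad op can safely read y_grad: propagating an explicit zero
# gradient never reads uninitialized memory and contributes nothing to the
# accumulated gradients of the other inputs.
assert y_grad.shape == y.shape
assert not y_grad.any()
```
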
From 33e75201e9d3c14945bbe556267b8bae069de327 Mon Sep 17 00:00:00 2001
From: fengjiayi <fengjiayi@baidu.com>
Date: Tue, 2 Jan 2018 20:00:00 +0800
Subject: [PATCH 3/3] fix bugs

---
 python/paddle/v2/fluid/backward.py | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/python/paddle/v2/fluid/backward.py b/python/paddle/v2/fluid/backward.py
index a1be768daa..ac60bf5436 100644
--- a/python/paddle/v2/fluid/backward.py
+++ b/python/paddle/v2/fluid/backward.py
@@ -142,12 +142,13 @@ def _remove_no_grad_branch_(op_descs, no_grad_set):
     """
 
     def _op_can_be_removed_(op_desc, no_grad_set):
-        if _all_in_set_(op_desc.output_arg_names(), no_grad_set):
+        out_arg_names = op_desc.output_arg_names()
+        if len(out_arg_names) == 0 or _all_in_set_(out_arg_names, no_grad_set):
             return True
         if _all_in_set_(
                 filter(lambda name: name.find(core.grad_var_suffix()) != -1,
                        op_desc.input_arg_names()), no_grad_set):
-            no_grad_set.union(op_desc.output_arg_names())
+            no_grad_set.update(out_arg_names)
             return True
         return False
 
@@ -296,7 +297,9 @@ def append_backward(loss, parameter_list=None, no_grad_set=None):
                     block_no_grad_set.add(_append_grad_suffix_(var.name))
             no_grad_dict[block.idx] = block_no_grad_set
     elif isinstance(no_grad_set, set):
-        no_grad_dict = {0: no_grad_set}
+        no_grad_dict = {
+            0: set([_append_grad_suffix_(name) for name in no_grad_set])
+        }
     else:
         raise ValueError("'no_grad_set' should be a set or None.")
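
A minimal sketch of the conversion added to `append_backward` in the hunk above: a user-supplied `no_grad_set` holds forward-variable names, while the pruning logic works on gradient-variable names, so each name gets the gradient suffix appended before it is stored in `no_grad_dict`. The suffix is assumed to be `"@GRAD"` (mirroring `core.grad_var_suffix()`), `_append_grad_suffix_` is a local stand-in, and the variable names are hypothetical:

```python
def _append_grad_suffix_(name, suffix="@GRAD"):
    # Local stand-in: the real helper derives the suffix from core.grad_var_suffix().
    return name + suffix


# Hypothetical forward-variable names supplied by a user:
user_no_grad_set = set(["fc_0.w_0", "fc_0.b_0"])

# append_backward keys the dict by block index (0 = the root block) and
# stores gradient-variable names rather than forward-variable names.
no_grad_dict = {0: set(_append_grad_suffix_(n) for n in user_no_grad_set)}

assert no_grad_dict[0] == set(["fc_0.w_0@GRAD", "fc_0.b_0@GRAD"])
```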