diff --git a/paddle/operators/while_op.cc b/paddle/operators/while_op.cc
index 65d827e0e0..3b78dd128f 100644
--- a/paddle/operators/while_op.cc
+++ b/paddle/operators/while_op.cc
@@ -211,59 +211,64 @@ class WhileGradOpDescMaker : public framework::SingleGradOpDescMaker {
 
  protected:
   std::unique_ptr<framework::OpDesc> Apply() const override {
-    auto *grad = new framework::OpDesc();
-    grad->SetType("while_grad");
-    grad->SetInput(kX, Input(kX));
+    auto *while_grad = new framework::OpDesc();
+    while_grad->SetType("while_grad");
+    while_grad->SetInput(kX, Input(kX));
+    while_grad->SetInput(kOutputs, Output(kOutputs));
+    while_grad->SetInput(kStepScopes, Output(kStepScopes));
+
+    auto *grad_block = this->grad_block_[0];
+    auto *fwd_block = grad_block->ParentBlock();
+    // auto *parent_block = fwd_block->ParentBlock();
 
     // Not all of IGs will be generated by inner gradient operators of while op.
     // Ignore IGs that is not generated by the inside block.
-    auto igs = InputGrad(kX, /*do not drop empty gradient*/ false);
-    std::unordered_set<std::string> all_outs;
-    for (size_t i = 0; i < grad_block_[0]->OpSize(); ++i) {
-      for (auto &oname : grad_block_[0]->Op(i)->OutputArgumentNames()) {
-        all_outs.insert(oname);
+    std::unordered_set<std::string> inner_op_outputs;
+    LOG(INFO) << "FUCK1";
+    for (const auto *op : grad_block->AllOps()) {
+      for (auto &oname : op->OutputArgumentNames()) {
+        inner_op_outputs.insert(oname);
       }
     }
+    LOG(INFO) << "FUCK2";
+    auto igs = InputGrad(kX, /*do not drop empty gradient*/ false);
     for (auto &each_ig : igs) {
-      if (all_outs.find(each_ig) == all_outs.end()) {
+      if (inner_op_outputs.find(each_ig) == inner_op_outputs.end()) {
         VLOG(10) << "Ignore " << each_ig;
         each_ig = framework::kEmptyVarName;
       }
     }
-
-    grad->SetOutput(framework::GradVarName(kX), igs);
-
-    grad->SetInput(kOutputs, Output(kOutputs));
+    while_grad->SetOutput(framework::GradVarName(kX), igs);
 
     // OG should be re-calculated by step blocks, since many outputs of while op
     // do not need to calculate gradients.
     std::unordered_set<std::string> block_ins;
-    auto *fwd_block = this->grad_block_[0]->ParentBlock();
-    {
-      for (auto &p : Input(kX)) {
-        block_ins.insert(p);
-      }
-      for (auto &o : Output(kOutputs)) {
-        block_ins.insert(o);
-      }
-    }
+    std::copy(Input(kX).begin(), Input(kX).end(),
+              std::inserter(block_ins, block_ins.end()));
+    std::copy(Output(kOutputs).begin(), Output(kOutputs).end(),
+              std::inserter(block_ins, block_ins.end()));
+
     std::unordered_set<std::string> extra_inputs;
-    for (size_t i = 0; i < grad_block_[0]->OpSize(); ++i) {
-      for (auto &input_name : grad_block_[0]->Op(i)->InputArgumentNames()) {
-        if (block_ins.find(input_name) != block_ins.end()) {
+    for (const auto *op : grad_block->AllOps()) {
+      for (auto &input_name : op->InputArgumentNames()) {
+        // If the input of Op has been recorded or is generated by the forward
+        // block, do not make it as input again.
+        if (block_ins.find(input_name) != block_ins.end() ||
+            fwd_block->FindVar(input_name) != nullptr) {
           continue;
         }
 
-        // If the input of Op is generated by the forward block, do not make it
-        // as input again.
-        if (fwd_block->FindVar(input_name) != nullptr) {
+        /*
+        if (parent_block->FindVarRecursive(input_name) == nullptr) {
+          VLOG(5) << "WARNING! Variable '" << input_name
+                  << "' is the input of '" << op->Type()
+                  << "'. But can not be found in any block.";
           continue;
         }
-
+        */
         extra_inputs.insert(input_name);
       }
-
-      for (auto &output_name : grad_block_[0]->Op(i)->OutputArgumentNames()) {
+      for (auto &output_name : op->OutputArgumentNames()) {
         block_ins.insert(output_name);
       }
     }
@@ -272,15 +277,15 @@ class WhileGradOpDescMaker : public framework::SingleGradOpDescMaker {
     extra_inputs_list.resize(extra_inputs.size());
     std::copy(extra_inputs.begin(), extra_inputs.end(),
               extra_inputs_list.begin());
-    grad->SetInput(framework::GradVarName(kOutputs), extra_inputs_list);
-    grad->SetInput(kStepScopes, Output(kStepScopes));
-    grad->SetAttrMap(this->Attrs());
-    grad->SetBlockAttr(kStepBlock, *grad_block_[0]);
+    while_grad->SetInput(framework::GradVarName(kOutputs), extra_inputs_list);
+
+    while_grad->SetAttrMap(this->Attrs());
+    while_grad->SetBlockAttr(kStepBlock, *grad_block);
     // record the original output gradient names, since the gradient name of
     // while operator could be renamed.
-    grad->SetAttr("original_output_grad", extra_inputs_list);
+    while_grad->SetAttr("original_output_grad", extra_inputs_list);
 
-    return std::unique_ptr<framework::OpDesc>(grad);
+    return std::unique_ptr<framework::OpDesc>(while_grad);
   }
 };