|  |  |  | @ -261,35 +261,37 @@ class WhileGradOpDescMaker : public framework::SingleGradOpDescMaker { | 
			
		
	
		
			
				
					|  |  |  |  |     for (auto &o : Output(kOutputs)) { | 
			
		
	
		
			
				
					|  |  |  |  |       block_ins.insert(o); | 
			
		
	
		
			
				
					|  |  |  |  |     } | 
			
		
	
		
			
				
					|  |  |  |  |     std::unordered_set<std::string> extra_inputs; | 
			
		
	
		
			
				
					|  |  |  |  |     std::unordered_set<std::string> output_grads; | 
			
		
	
		
			
				
					|  |  |  |  |     for (const auto *op : grad_block->AllOps()) { | 
			
		
	
		
			
				
					|  |  |  |  |       for (auto &input_name : op->InputArgumentNames()) { | 
			
		
	
		
			
				
					|  |  |  |  |         // If the input of Op has been recorded or is generated by the forward
 | 
			
		
	
		
			
				
					|  |  |  |  |         // block, do not make it as input again.
 | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  |         // The input is located in I/O or other op's outputs or the variable is
 | 
			
		
	
		
			
				
					|  |  |  |  |         // located in grad_block's parents
 | 
			
		
	
		
			
				
					|  |  |  |  |         if (block_ins.find(input_name) != block_ins.end() || | 
			
		
	
		
			
				
					|  |  |  |  |             fwd_block->FindVar(input_name) != nullptr || | 
			
		
	
		
			
				
					|  |  |  |  |             parent_block->FindVar(input_name) != nullptr) { | 
			
		
	
		
			
				
					|  |  |  |  |             (fwd_block->FindVarRecursive(input_name) != nullptr || | 
			
		
	
		
			
				
					|  |  |  |  |              parent_block->FindVarRecursive(input_name) != nullptr)) { | 
			
		
	
		
			
				
					|  |  |  |  |           continue; | 
			
		
	
		
			
				
					|  |  |  |  |         } | 
			
		
	
		
			
				
					|  |  |  |  |         extra_inputs.insert(input_name); | 
			
		
	
		
			
				
					|  |  |  |  |         output_grads.insert(input_name); | 
			
		
	
		
			
				
					|  |  |  |  |       } | 
			
		
	
		
			
				
					|  |  |  |  |       for (auto &output_name : op->OutputArgumentNames()) { | 
			
		
	
		
			
				
					|  |  |  |  |         block_ins.insert(output_name); | 
			
		
	
		
			
				
					|  |  |  |  |       } | 
			
		
	
		
			
				
					|  |  |  |  |     } | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  |     std::vector<std::string> extra_inputs_list; | 
			
		
	
		
			
				
					|  |  |  |  |     extra_inputs_list.resize(extra_inputs.size()); | 
			
		
	
		
			
				
					|  |  |  |  |     std::copy(extra_inputs.begin(), extra_inputs.end(), | 
			
		
	
		
			
				
					|  |  |  |  |               extra_inputs_list.begin()); | 
			
		
	
		
			
				
					|  |  |  |  |     while_grad->SetInput(framework::GradVarName(kOutputs), extra_inputs_list); | 
			
		
	
		
			
				
					|  |  |  |  |     std::vector<std::string> output_grads_list; | 
			
		
	
		
			
				
					|  |  |  |  |     output_grads_list.resize(output_grads.size()); | 
			
		
	
		
			
				
					|  |  |  |  |     std::copy(output_grads.begin(), output_grads.end(), | 
			
		
	
		
			
				
					|  |  |  |  |               output_grads_list.begin()); | 
			
		
	
		
			
				
					|  |  |  |  |     while_grad->SetInput(framework::GradVarName(kOutputs), output_grads_list); | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  |     while_grad->SetAttrMap(this->Attrs()); | 
			
		
	
		
			
				
					|  |  |  |  |     while_grad->SetBlockAttr(kStepBlock, *grad_block); | 
			
		
	
		
			
				
					|  |  |  |  |     // record the original output gradient names, since the gradient name of
 | 
			
		
	
		
			
				
					|  |  |  |  |     // while operator could be renamed.
 | 
			
		
	
		
			
				
					|  |  |  |  |     while_grad->SetAttr("original_output_grad", extra_inputs_list); | 
			
		
	
		
			
				
					|  |  |  |  |     while_grad->SetAttr("original_output_grad", output_grads_list); | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  |     return std::unique_ptr<framework::OpDesc>(while_grad); | 
			
		
	
		
			
				
					|  |  |  |  |   } | 
			
		
	
	
		
			
				
					|  |  |  | 
 |