|
|
|
@ -261,35 +261,37 @@ class WhileGradOpDescMaker : public framework::SingleGradOpDescMaker {
|
|
|
|
|
for (auto &o : Output(kOutputs)) {
|
|
|
|
|
block_ins.insert(o);
|
|
|
|
|
}
|
|
|
|
|
std::unordered_set<std::string> extra_inputs;
|
|
|
|
|
std::unordered_set<std::string> output_grads;
|
|
|
|
|
for (const auto *op : grad_block->AllOps()) {
|
|
|
|
|
for (auto &input_name : op->InputArgumentNames()) {
|
|
|
|
|
// If the input of Op has been recorded or is generated by the forward
|
|
|
|
|
// block, do not make it as input again.
|
|
|
|
|
|
|
|
|
|
// The input is located in I/O or other op's outputs or the variable is
|
|
|
|
|
// located in grad_block's parents
|
|
|
|
|
if (block_ins.find(input_name) != block_ins.end() ||
|
|
|
|
|
fwd_block->FindVar(input_name) != nullptr ||
|
|
|
|
|
parent_block->FindVar(input_name) != nullptr) {
|
|
|
|
|
(fwd_block->FindVarRecursive(input_name) != nullptr ||
|
|
|
|
|
parent_block->FindVarRecursive(input_name) != nullptr)) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
extra_inputs.insert(input_name);
|
|
|
|
|
output_grads.insert(input_name);
|
|
|
|
|
}
|
|
|
|
|
for (auto &output_name : op->OutputArgumentNames()) {
|
|
|
|
|
block_ins.insert(output_name);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<std::string> extra_inputs_list;
|
|
|
|
|
extra_inputs_list.resize(extra_inputs.size());
|
|
|
|
|
std::copy(extra_inputs.begin(), extra_inputs.end(),
|
|
|
|
|
extra_inputs_list.begin());
|
|
|
|
|
while_grad->SetInput(framework::GradVarName(kOutputs), extra_inputs_list);
|
|
|
|
|
std::vector<std::string> output_grads_list;
|
|
|
|
|
output_grads_list.resize(output_grads.size());
|
|
|
|
|
std::copy(output_grads.begin(), output_grads.end(),
|
|
|
|
|
output_grads_list.begin());
|
|
|
|
|
while_grad->SetInput(framework::GradVarName(kOutputs), output_grads_list);
|
|
|
|
|
|
|
|
|
|
while_grad->SetAttrMap(this->Attrs());
|
|
|
|
|
while_grad->SetBlockAttr(kStepBlock, *grad_block);
|
|
|
|
|
// record the original output gradient names, since the gradient name of
|
|
|
|
|
// while operator could be renamed.
|
|
|
|
|
while_grad->SetAttr("original_output_grad", extra_inputs_list);
|
|
|
|
|
while_grad->SetAttr("original_output_grad", output_grads_list);
|
|
|
|
|
|
|
|
|
|
return std::unique_ptr<framework::OpDesc>(while_grad);
|
|
|
|
|
}
|
|
|
|
|