|
|
|
@ -365,51 +365,7 @@ class WhileGradOpDescMaker : public framework::SingleGradOpDescMaker {
|
|
|
|
|
// while operator could be renamed.
|
|
|
|
|
while_grad->SetAttr("original_output_grad", output_grads_list);
|
|
|
|
|
|
|
|
|
|
/* The following codes are used in eager deletion mode */
|
|
|
|
|
std::unordered_set<std::string> bwd_skip_vars;
|
|
|
|
|
if (framework::GetEagerDeletionThreshold() >= 0) {
|
|
|
|
|
std::unordered_set<std::string> fwd_skip_vars;
|
|
|
|
|
for (auto *op_desc : grad_block->AllOps()) {
|
|
|
|
|
auto skippable = [&](const std::string &name) {
|
|
|
|
|
return !grad_block->HasVar(name) &&
|
|
|
|
|
(fwd_block->HasVarRecursive(name) ||
|
|
|
|
|
parent_block->HasVarRecursive(name));
|
|
|
|
|
};
|
|
|
|
|
for (auto &in_arg_name : op_desc->InputArgumentNames()) {
|
|
|
|
|
if (skippable(in_arg_name)) {
|
|
|
|
|
fwd_skip_vars.insert(in_arg_name);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for (auto &out_arg_name : op_desc->OutputArgumentNames()) {
|
|
|
|
|
if (skippable(out_arg_name)) {
|
|
|
|
|
fwd_skip_vars.insert(out_arg_name);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (!fwd_skip_vars.empty()) {
|
|
|
|
|
// FIXME(zjl): ugly const_cast here, maybe we should find a better way
|
|
|
|
|
// to modify forward while_op
|
|
|
|
|
auto &fwd_while_op = const_cast<framework::OpDesc &>(ForwardOp());
|
|
|
|
|
fwd_while_op.SetAttr(kSkipEagerDeletionVars,
|
|
|
|
|
std::vector<std::string>(fwd_skip_vars.begin(),
|
|
|
|
|
fwd_skip_vars.end()));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Find backward skip vars
|
|
|
|
|
auto fwd_input = Input(kX);
|
|
|
|
|
for (size_t i = 0; i < igs.size(); ++i) {
|
|
|
|
|
if (igs[i] == framework::kEmptyVarName) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
bwd_skip_vars.insert(igs[i]);
|
|
|
|
|
bwd_skip_vars.insert(framework::GradVarName(fwd_input[i]));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
while_grad->SetAttr(
|
|
|
|
|
kSkipEagerDeletionVars,
|
|
|
|
|
std::vector<std::string>(bwd_skip_vars.begin(), bwd_skip_vars.end()));
|
|
|
|
|
while_grad->SetAttr(kSkipEagerDeletionVars, std::vector<std::string>());
|
|
|
|
|
|
|
|
|
|
return std::unique_ptr<framework::OpDesc>(while_grad);
|
|
|
|
|
}
|
|
|
|
|