|
|
|
@ -145,7 +145,7 @@ std::set<std::string> Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
|
|
|
|
|
prepared_op.func(framework::ExecutionContext(
|
|
|
|
|
prepared_op.op, scope, *prepared_op.dev_ctx, prepared_op.ctx));
|
|
|
|
|
|
|
|
|
|
std::set<std::string> grad_deps_var;
|
|
|
|
|
std::set<std::string> vars_saved_for_backward;
|
|
|
|
|
|
|
|
|
|
if (!stop_gradient) {
|
|
|
|
|
std::unique_ptr<std::unordered_map<std::string, std::string>> grad_to_var(
|
|
|
|
@ -166,7 +166,7 @@ std::set<std::string> Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
|
|
|
|
|
PADDLE_ENFORCE(fwd_var_it != vars.end());
|
|
|
|
|
// Forward inputs or outputs.
|
|
|
|
|
grad_in_vars.push_back(fwd_var_it->second->var_);
|
|
|
|
|
grad_deps_var.insert(it.first);
|
|
|
|
|
vars_saved_for_backward.insert(it.first);
|
|
|
|
|
} else {
|
|
|
|
|
VarBase* var = vars[var_it->second];
|
|
|
|
|
if (!var->grads_->var_->IsInitialized()) {
|
|
|
|
@ -200,7 +200,7 @@ std::set<std::string> Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
op->block_ = block;
|
|
|
|
|
return grad_deps_var;
|
|
|
|
|
return vars_saved_for_backward;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<VarBase*> Tracer::PyTrace(OpBase* op,
|
|
|
|
|