|
|
|
@ -133,11 +133,11 @@ void Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
|
|
|
|
|
grad_in_vars.push_back(fwd_var_it->second->var_);
|
|
|
|
|
} else {
|
|
|
|
|
VarBase* var = vars[var_it->second];
|
|
|
|
|
if (!var->grads_->IsInitialized()) {
|
|
|
|
|
InitVar(var->var_, var->grads_);
|
|
|
|
|
if (!var->grads_->var_->IsInitialized()) {
|
|
|
|
|
InitVar(var->var_, var->grads_->var_);
|
|
|
|
|
}
|
|
|
|
|
// Douts.
|
|
|
|
|
grad_in_vars.push_back(var->grads_);
|
|
|
|
|
grad_in_vars.push_back(var->grads_->var_);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -149,10 +149,10 @@ void Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
|
|
|
|
|
auto var_it = grad_to_var->find(grad_outvar);
|
|
|
|
|
PADDLE_ENFORCE(var_it != grad_to_var->end());
|
|
|
|
|
VarBase* var = vars[var_it->second];
|
|
|
|
|
if (!var->grads_->IsInitialized()) {
|
|
|
|
|
InitVar(var->var_, var->grads_);
|
|
|
|
|
if (!var->grads_->var_->IsInitialized()) {
|
|
|
|
|
InitVar(var->var_, var->grads_->var_);
|
|
|
|
|
}
|
|
|
|
|
grad_out_vars.push_back(var->grads_);
|
|
|
|
|
grad_out_vars.push_back(var->grads_->var_);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -194,13 +194,13 @@ std::vector<VarBase*> Tracer::PyTrace(OpBase* op,
|
|
|
|
|
grad_input_vars.push_back(out->var_);
|
|
|
|
|
}
|
|
|
|
|
for (VarBase* out : outputs) {
|
|
|
|
|
grad_input_vars.push_back(out->grads_);
|
|
|
|
|
grad_input_vars.push_back(out->grads_->var_);
|
|
|
|
|
if (!grad_input_vars.back()->IsInitialized()) {
|
|
|
|
|
InitVar(out->var_, grad_input_vars.back());
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
for (const VarBase* inp : inputs) {
|
|
|
|
|
grad_output_vars.push_back(inp->grads_);
|
|
|
|
|
grad_output_vars.push_back(inp->grads_->var_);
|
|
|
|
|
if (!grad_output_vars.back()->IsInitialized()) {
|
|
|
|
|
InitVar(inp->var_, grad_output_vars.back());
|
|
|
|
|
}
|
|
|
|
|