|
|
|
|
@ -44,16 +44,12 @@ class Autograd {
|
|
|
|
|
public:
|
|
|
|
|
explicit Autograd(framework::Scope* scope) : scope_(scope) {}
|
|
|
|
|
|
|
|
|
|
void RunBackward(VarBase* var, framework::Variable* grad) {
|
|
|
|
|
if (!var->pre_op_) {
|
|
|
|
|
var->ApplyGrad(scope_, grad);
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
void RunBackward(VarBase* var) {
|
|
|
|
|
PADDLE_ENFORCE(var->pre_op_->op_desc_);
|
|
|
|
|
// TODO(panyx0718): Only create vars that "require_grad"
|
|
|
|
|
std::vector<Variable*> op_grads =
|
|
|
|
|
CreateOpGrads(var->pre_op_->output_vars_->size());
|
|
|
|
|
op_grads[var->pre_op_out_idx_] = grad;
|
|
|
|
|
op_grads[var->pre_op_out_idx_] = var->grads_;
|
|
|
|
|
|
|
|
|
|
std::deque<std::pair<OpBase*, std::vector<Variable*>>> ready;
|
|
|
|
|
ready.push_back(std::make_pair(var->pre_op_, op_grads));
|
|
|
|
|
@ -238,8 +234,6 @@ std::vector<Variable*> OpBase::ApplyGrad(framework::Scope* scope) {
|
|
|
|
|
framework::Variable* var = scope->FindVar(outvar);
|
|
|
|
|
LOG(ERROR) << "apply grad " << outvar << " with origin "
|
|
|
|
|
<< origin_var;
|
|
|
|
|
// TODO(panyx0718): Accumulate.
|
|
|
|
|
// origin_in_var->grads_ = var;
|
|
|
|
|
origin_in_var->ApplyGrad(scope, var);
|
|
|
|
|
ret[i] = var;
|
|
|
|
|
// TODO(panyx0718): There might be 2 var with the same name. We
|
|
|
|
|
@ -254,15 +248,11 @@ std::vector<Variable*> OpBase::ApplyGrad(framework::Scope* scope) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void VarBase::RunBackward(framework::Scope* scope) {
|
|
|
|
|
// TODO(panyx0718): Might not be 0th, need to detect.
|
|
|
|
|
grads_ = CreateVariable(pre_op_->grad_op_desc_->InputArgumentNames()[0],
|
|
|
|
|
grads_ = CreateVariable(framework::GradVarName(var_desc_->Name()),
|
|
|
|
|
var_->Get<framework::LoDTensor>().dims(), 1.0, scope,
|
|
|
|
|
false);
|
|
|
|
|
framework::Variable* grad =
|
|
|
|
|
CreateVariable("init@imperative_grad",
|
|
|
|
|
var_->Get<framework::LoDTensor>().dims(), 1.0, scope);
|
|
|
|
|
|
|
|
|
|
Autograd(scope).RunBackward(this, grad);
|
|
|
|
|
if (!pre_op_) return;
|
|
|
|
|
Autograd(scope).RunBackward(this);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
} // namespace imperative
|
|
|
|
|
|