|
|
@ -45,15 +45,15 @@ class Variable {
|
|
|
|
void InitializeVariable();
|
|
|
|
void InitializeVariable();
|
|
|
|
|
|
|
|
|
|
|
|
VariableHandle Grad() {
|
|
|
|
VariableHandle Grad() {
|
|
|
|
if (grad_ == nullptr) {
|
|
|
|
if (grad_.expired()) {
|
|
|
|
grad_.reset(new Variable(desc_.Name(), true));
|
|
|
|
VariableHandle new_grad(new Variable(desc_.Name(), true));
|
|
|
|
|
|
|
|
grad_ = new_grad;
|
|
|
|
|
|
|
|
return new_grad;
|
|
|
|
|
|
|
|
} else {
|
|
|
|
|
|
|
|
return VariableHandle(grad_);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
return grad_;
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
void ResetGrad() { grad_ = nullptr; }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Stochastic Gradient Descent with Momentum
|
|
|
|
// Stochastic Gradient Descent with Momentum
|
|
|
|
// VariableHandle Momentum ();
|
|
|
|
// VariableHandle Momentum ();
|
|
|
|
|
|
|
|
|
|
|
@ -79,7 +79,7 @@ class Variable {
|
|
|
|
framework::VarDesc desc_;
|
|
|
|
framework::VarDesc desc_;
|
|
|
|
framework::Variable var_;
|
|
|
|
framework::Variable var_;
|
|
|
|
|
|
|
|
|
|
|
|
VariableHandle grad_;
|
|
|
|
std::weak_ptr<Variable> grad_;
|
|
|
|
};
|
|
|
|
};
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|