@@ -140,6 +140,8 @@ class VarBase {
   }
   inline bool IsStopGradient() const { return stop_gradient_; }
 
+  void RunBackward();
+
   void TrackPreOp(OpBase* pre_op, const std::string& pre_op_out_name,
                   int pre_op_out_idx, bool pre_op_stop_gradient) {
     pre_op_ = pre_op;
@@ -150,22 +152,6 @@ class VarBase {
     }
   }
 
-  void RunBackward() {
-    if (!pre_op_) return;
-
-    VLOG(3) << "start backward";
-    auto grads_t = grads_->var_->GetMutable<framework::LoDTensor>();
-    operators::math::set_constant(
-        *(platform::DeviceContextPool::Instance().Get(
-            var_->GetMutable<framework::LoDTensor>()->place())),
-        grads_t, 1.0);
-
-    PADDLE_ENFORCE(
-        grads_ ==
-        pre_op_->output_vars_[pre_op_out_name_][pre_op_out_idx_]->grads_);
-    Autograd().RunBackward(this);
-  }
-
   void ClearGradient() {
     VLOG(1) << "clear gradient of " << var_desc_->Name();
     if (grads_ && grads_->var_ && grads_->var_->IsInitialized()) {
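Note: the inline body deleted from layer.h above is replaced by the `void RunBackward();` declaration added in the first hunk, so the definition presumably moves into the matching layer.cc change (not shown in this fragment). A minimal sketch of that out-of-line definition, assuming the body moved verbatim; the comments are explanatory additions and not part of the patch itself:

// paddle/fluid/imperative/layer.cc (assumed destination of the moved body)
void VarBase::RunBackward() {
  // A leaf variable with no creator op has nothing to backpropagate.
  if (!pre_op_) return;

  VLOG(3) << "start backward";
  // Seed this variable's gradient tensor with ones before walking the graph.
  auto grads_t = grads_->var_->GetMutable<framework::LoDTensor>();
  operators::math::set_constant(
      *(platform::DeviceContextPool::Instance().Get(
          var_->GetMutable<framework::LoDTensor>()->place())),
      grads_t, 1.0);

  // Sanity check: our gradient slot must be the same one registered on the
  // op output that produced this variable.
  PADDLE_ENFORCE(
      grads_ ==
      pre_op_->output_vars_[pre_op_out_name_][pre_op_out_idx_]->grads_);
  Autograd().RunBackward(this);
}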