|
|
|
|
@ -152,11 +152,13 @@ class VarBase {
|
|
|
|
|
|
|
|
|
|
void ClearGradient() {
|
|
|
|
|
VLOG(1) << "clear gradient of " << var_desc_->Name();
|
|
|
|
|
auto grads_t = grads_->var_->GetMutable<framework::LoDTensor>();
|
|
|
|
|
operators::math::set_constant(
|
|
|
|
|
*(platform::DeviceContextPool::Instance().Get(
|
|
|
|
|
grads_->var_->Get<framework::LoDTensor>().place())),
|
|
|
|
|
grads_t, 0.0);
|
|
|
|
|
if (grads_ && grads_->var_ && grads_->var_->IsInitialized()) {
|
|
|
|
|
auto grads_t = grads_->var_->GetMutable<framework::LoDTensor>();
|
|
|
|
|
operators::math::set_constant(
|
|
|
|
|
*(platform::DeviceContextPool::Instance().Get(
|
|
|
|
|
grads_->var_->Get<framework::LoDTensor>().place())),
|
|
|
|
|
grads_t, 0.0);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
framework::LoDTensor& GradValue();
|
|
|
|
|
|