Clear all parameters' gradient

test=develop
revert-15470-feature/imperative
minqiyang 7 years ago
parent 49a7fba848
commit 07822fef2c

@ -152,11 +152,13 @@ class VarBase {
void ClearGradient() {
VLOG(1) << "clear gradient of " << var_desc_->Name();
auto grads_t = grads_->var_->GetMutable<framework::LoDTensor>();
operators::math::set_constant(
*(platform::DeviceContextPool::Instance().Get(
grads_->var_->Get<framework::LoDTensor>().place())),
grads_t, 0.0);
if (grads_ && grads_->var_ && grads_->var_->IsInitialized()) {
auto grads_t = grads_->var_->GetMutable<framework::LoDTensor>();
operators::math::set_constant(
*(platform::DeviceContextPool::Instance().Get(
grads_->var_->Get<framework::LoDTensor>().place())),
grads_t, 0.0);
}
}
framework::LoDTensor& GradValue();

@ -52,8 +52,7 @@ class Layer(core.Layer):
def clear_gradients(self):
for p in self.parameters():
if not p._stop_gradient:
p._clear_gradient()
p._clear_gradient()
def _build_once(self, inputs):
pass

Loading…
Cancel
Save