improve performance of momentum (#30881)

revert-31068-fix_conv3d_windows
Zhang Ting 4 years ago committed by GitHub
parent 4b2d52a001
commit e97905c5fa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -104,7 +104,7 @@ class Momentum(Optimizer):
raise ValueError("learning_rate is not set")
if momentum is None:
raise ValueError("momentum is not set")
predicate = lambda regular: isinstance(regular, L2DecayRegularizer)
predicate = lambda regular: isinstance(regular, (L2DecayRegularizer, float))
py_regular = None if predicate(weight_decay) else weight_decay
super(Momentum, self).__init__(
learning_rate=learning_rate,
@ -120,6 +120,9 @@ class Momentum(Optimizer):
if (isinstance(weight_decay, L2DecayRegularizer)):
self._regularization_method = "l2_decay"
self._regularization_coeff = weight_decay._regularization_coeff
if (isinstance(weight_decay, float)):
self._regularization_method = "l2_decay"
self._regularization_coeff = weight_decay
self._multi_precision = multi_precision
self._rescale_grad = rescale_grad
self._master_weights = {}

Loading…
Cancel
Save