|
|
|
@@ -104,7 +104,7 @@ class Momentum(Optimizer):
|
|
|
|
|
raise ValueError("learning_rate is not set")
|
|
|
|
|
if momentum is None:
|
|
|
|
|
raise ValueError("momentum is not set")
|
|
|
|
|
predicate = lambda regular: isinstance(regular, L2DecayRegularizer)
|
|
|
|
|
predicate = lambda regular: isinstance(regular, (L2DecayRegularizer, float))
|
|
|
|
|
py_regular = None if predicate(weight_decay) else weight_decay
|
|
|
|
|
super(Momentum, self).__init__(
|
|
|
|
|
learning_rate=learning_rate,
|
|
|
|
@@ -120,6 +120,9 @@ class Momentum(Optimizer):
|
|
|
|
|
if (isinstance(weight_decay, L2DecayRegularizer)):
|
|
|
|
|
self._regularization_method = "l2_decay"
|
|
|
|
|
self._regularization_coeff = weight_decay._regularization_coeff
|
|
|
|
|
if (isinstance(weight_decay, float)):
|
|
|
|
|
self._regularization_method = "l2_decay"
|
|
|
|
|
self._regularization_coeff = weight_decay
|
|
|
|
|
self._multi_precision = multi_precision
|
|
|
|
|
self._rescale_grad = rescale_grad
|
|
|
|
|
self._master_weights = {}
|
|
|
|
|