@@ -49,7 +49,15 @@ class Optimizer(Cell):
applied on the parameters if `weight_decay` > 0 and the 'beta' and 'gamma' are not in the name of parameters.

Args:
-learning_rate (float): A floating point value for the learning rate. Should be greater than 0.
+learning_rate (Union[float, Tensor, Iterable]): A value for the learning rate. When the learning_rate is
+    an Iterable or a one-dimensional Tensor,
+    a dynamic learning rate is used: the i-th step
+    takes the i-th value as the learning rate.
+    When the learning_rate is a float or a
+    zero-dimensional Tensor, a fixed learning rate is used.
+    Other cases are not supported. Should be greater than 0.
+    If the type of `learning_rate` is int, it will be
+    converted to float.
parameters (Union[list[Parameter], list[dict]]): When the `parameters` is a list of `Parameter` which will be
updated, the element in `parameters` should be class `Parameter`. When the `parameters` is a list of `dict`,
the "params", "lr" and "weight_decay" are the keys can be parsed.
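Likewise, a sketch of the list-of-dict form of `parameters`, using the "params", "lr" and "weight_decay" keys mentioned above (same assumptions: nn.Momentum and an existing `net`; the split by parameter name is purely illustrative):

    # Hypothetical split of the trainable parameters into two groups by name.
    conv_params = [p for p in net.trainable_params() if 'conv' in p.name]
    other_params = [p for p in net.trainable_params() if 'conv' not in p.name]

    group_params = [
        {'params': conv_params, 'lr': 0.05},            # this group overrides the learning rate
        {'params': other_params, 'weight_decay': 0.0},  # this group overrides the weight decay
    ]
    opt = nn.Momentum(group_params, learning_rate=0.1, momentum=0.9)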
@@ -96,6 +104,8 @@ class Optimizer(Cell):
self.is_group = False
self.is_group_lr = False
self.loss_scale = loss_scale
+if isinstance(learning_rate, int):
+    learning_rate = float(learning_rate)
if isinstance(learning_rate, float):
    self.dynamic_lr = False
    self.gather = None
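Putting the second hunk in context, a simplified sketch of the constructor's learning-rate handling (not the actual class body; the `else` branch is inferred from the docstring added above):

    # Simplified sketch of the logic around the added int-to-float conversion.
    if isinstance(learning_rate, int):
        learning_rate = float(learning_rate)   # int input is accepted and converted
    if isinstance(learning_rate, float):
        self.dynamic_lr = False                # fixed learning rate
        self.gather = None                     # no per-step lookup needed
    else:
        self.dynamic_lr = True                 # Iterable / 1-D Tensor: one value per step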