|
|
|
@@ -40,7 +40,7 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay, param, m, v, gradient, d
|
|
|
|
|
beta2 (Tensor): The exponential decay rate for the 2nd moment estimates. Should be in range (0.0, 1.0).
|
|
|
|
|
eps (Tensor): Term added to the denominator to improve numerical stability. Should be greater than 0.
|
|
|
|
|
lr (Tensor): Learning rate.
|
|
|
|
|
weight_decay (Number): Weight decay. Should be in range [0.0, 1.0].
|
|
|
|
|
weight_decay (Number): Weight decay. Should be equal to or greater than 0.
|
|
|
|
|
param (Tensor): Parameters.
|
|
|
|
|
m (Tensor): m value of parameters.
|
|
|
|
|
v (Tensor): v value of parameters.
|
|
|
|
@@ -200,8 +200,8 @@ class Adam(Optimizer):
|
|
|
|
|
use_nesterov (bool): Whether to use Nesterov Accelerated Gradient (NAG) algorithm to update the gradients.
|
|
|
|
|
If True, updates the gradients using NAG.
|
|
|
|
|
If False, updates the gradients without using NAG. Default: False.
|
|
|
|
|
weight_decay (float): Weight decay (L2 penalty). It should be in range [0.0, 1.0]. Default: 0.0.
|
|
|
|
|
loss_scale (float): A floating point value for the loss scale. Should be not less than 1.0. Default: 1.0.
|
|
|
|
|
weight_decay (float): Weight decay (L2 penalty). It should be equal to or greater than 0. Default: 0.0.
|
|
|
|
|
loss_scale (float): A floating point value for the loss scale. Should be greater than 0. Default: 1.0.
|
|
|
|
|
|
|
|
|
|
Inputs:
|
|
|
|
|
- **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.
|
|
|
|
@@ -318,7 +318,7 @@ class AdamWeightDecay(Optimizer):
|
|
|
|
|
Should be in range (0.0, 1.0).
|
|
|
|
|
eps (float): Term added to the denominator to improve numerical stability. Default: 1e-6.
|
|
|
|
|
Should be greater than 0.
|
|
|
|
|
weight_decay (float): Weight decay (L2 penalty). It should be in range [0.0, 1.0]. Default: 0.0.
|
|
|
|
|
weight_decay (float): Weight decay (L2 penalty). It should be equal to or greater than 0. Default: 0.0.
|
|
|
|
|
|
|
|
|
|
Inputs:
|
|
|
|
|
- **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.
|
|
|
|
|