@@ -26,10 +26,10 @@ from mindspore._checkparam import Validator as validator
 from mindspore._checkparam import Rel
 from .optimizer import Optimizer

-adam_opt = C.MultitypeFuncGraph("adam_opt")
+_adam_opt = C.MultitypeFuncGraph("adam_opt")


-@adam_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Bool")
+@_adam_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Bool")
 def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, param, m, v, gradient, decay_flag):
     """
     Update parameters.
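The hunk above only touches the registration name; the body of _update_run_op is not shown here. As a reference point, below is a plain-NumPy sketch of the update rule its signature implies (standard Adam moment estimates plus optional decoupled weight decay). Names and the exact formulation are illustrative, not taken from the file.

import numpy as np

def adam_weight_decay_step(beta1, beta2, eps, lr, weight_decay,
                           param, m, v, gradient, decay_flag):
    """Sketch of one AdamWeightDecay-style update on a single parameter array."""
    m = beta1 * m + (1.0 - beta1) * gradient              # first-moment estimate
    v = beta2 * v + (1.0 - beta2) * gradient * gradient   # second-moment estimate
    update = m / (np.sqrt(v) + eps)                       # Adam direction
    if decay_flag:
        update = update + weight_decay * param            # decoupled weight decay
    param = param - lr * update
    return param, m, v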
@@ -101,8 +101,8 @@ def _check_learning_rate_value(learning_rate, end_learning_rate, decay_steps, po
     validator.check_integer('decay_steps', decay_steps, 0, Rel.GT, prim_name)


-@adam_opt.register("Function", "Function", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tuple",
-                   "Tensor", "Tensor", "Tensor")
+@_adam_opt.register("Function", "Function", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tuple",
+                    "Tensor", "Tensor", "Tensor")
 def _run_opt_with_sparse(opt, sparse_opt, beta1_power, beta2_power, beta1, beta2, eps, lr, gradient, params,
                          moment1, moment2):
     """Apply sparse adam optimizer to the weight parameter when the gradient is sparse."""
@@ -112,8 +112,8 @@ def _run_opt_with_sparse(opt, sparse_opt, beta1_power, beta2_power, beta1, beta2
     return success


-@adam_opt.register("Function", "Function", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor",
-                   "Tensor", "Tensor", "Tensor")
+@_adam_opt.register("Function", "Function", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor",
+                    "Tensor", "Tensor", "Tensor")
 def _run_opt_with_one_number(opt, sparse_opt, beta1_power, beta2_power, beta1, beta2, eps, lr, gradient, params,
                              moment1, moment2):
     """Apply adam optimizer to the weight parameter using Tensor."""
@@ -261,11 +261,11 @@ class Adam(Optimizer):
         beta2_power = self.beta2_power * self.beta2
         self.beta2_power = beta2_power
         if self.is_group_lr:
-            success = self.map_(F.partial(adam_opt, self.opt, self.sparse_opt, beta1_power, beta2_power,
+            success = self.map_(F.partial(_adam_opt, self.opt, self.sparse_opt, beta1_power, beta2_power,
                                           self.beta1, self.beta2, self.eps),
                                 lr, gradients, params, moment1, moment2)
         else:
-            success = self.map_(F.partial(adam_opt, self.opt, self.sparse_opt, beta1_power, beta2_power,
+            success = self.map_(F.partial(_adam_opt, self.opt, self.sparse_opt, beta1_power, beta2_power,
                                           self.beta1, self.beta2, self.eps, lr),
                                 gradients, params, moment1, moment2)
         return success
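beta1_power and beta2_power in the context lines above are the running products beta1^t and beta2^t that the fused Adam kernels (self.opt, self.sparse_opt) consume for bias correction. A minimal plain-NumPy sketch of the correction they enable follows; the real arithmetic happens inside the fused operator, so this is illustrative only.

import numpy as np

def bias_corrected_step(beta1_power, beta2_power, beta1, beta2, eps, lr,
                        param, m, v, gradient):
    """Sketch of a bias-corrected Adam step driven by the running beta powers."""
    m = beta1 * m + (1.0 - beta1) * gradient
    v = beta2 * v + (1.0 - beta2) * gradient * gradient
    m_hat = m / (1.0 - beta1_power)   # undo the bias toward zero in early steps
    v_hat = v / (1.0 - beta2_power)
    param = param - lr * m_hat / (np.sqrt(v_hat) + eps)
    return param, m, v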
@@ -328,7 +328,7 @@ class AdamWeightDecay(Optimizer):

     def construct(self, gradients):
         lr = self.get_lr()
-        updated_velocity = self.hyper_map(F.partial(adam_opt, self.beta1, self.beta2, self.eps, lr,
+        updated_velocity = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr,
                                                     self.weight_decay_tensor),
                                           self.params, self.moments1, self.moments2, gradients, self.decay_flag)

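In the line changed above, F.partial pre-binds the scalars shared by every parameter (betas, eps, lr, weight decay), and hyper_map applies the bound _adam_opt overload across the tuples of parameters, moments, gradients and decay flags. A rough plain-Python analogue of that broadcasting pattern, using functools.partial and map rather than MindSpore's graph primitives (the fake_adam_opt helper is hypothetical):

from functools import partial

def fake_adam_opt(beta1, beta2, eps, lr, weight_decay, param, m, v, grad, decay_flag):
    # stand-in for the registered _adam_opt overload; returns an "updated" parameter
    return param - lr * grad

# shared scalars are bound once; per-parameter values are mapped over in lockstep
bound = partial(fake_adam_opt, 0.9, 0.999, 1e-6, 1e-3, 0.01)
params   = [1.0, 2.0, 3.0]
moments1 = [0.0, 0.0, 0.0]
moments2 = [0.0, 0.0, 0.0]
grads    = [0.1, 0.2, 0.3]
flags    = [True, True, False]
updated = list(map(bound, params, moments1, moments2, grads, flags))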
@@ -424,7 +424,7 @@ class AdamWeightDecayDynamicLR(Optimizer):
         warmup_lr = self.start_learning_rate * warmup_percent
         is_warmup = self.cast(self.greater(self.warmup_steps, self.global_step), mstype.float32)
         lr = (self.one - is_warmup) * lr + is_warmup * warmup_lr
-        updated_velocity = self.hyper_map(F.partial(adam_opt, self.beta1, self.beta2, self.eps, lr,
+        updated_velocity = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr,
                                                     self.weight_decay_tensor),
                                           self.params, self.moments1, self.moments2, gradients, self.decay_flag)

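The context lines in this last hunk show how AdamWeightDecayDynamicLR switches between warmup and decayed learning rates without graph control flow: the step comparison is cast to a 0/1 float mask and used to blend the two schedules. A small NumPy sketch of that pattern follows; the polynomial-decay formula below is an assumption for illustration, since only the linear warmup and the mask blend appear in the diff.

import numpy as np

def dynamic_lr(step, warmup_steps, decay_steps, start_lr, end_lr=0.0, power=1.0):
    """Warmup-then-decay schedule expressed as a branch-free blend."""
    warmup_percent = min(step, warmup_steps) / warmup_steps
    warmup_lr = start_lr * warmup_percent
    # assumed polynomial decay; this part of construct() is not shown in the hunk
    decay_ratio = min(step, decay_steps) / decay_steps
    decay_lr = (start_lr - end_lr) * (1.0 - decay_ratio) ** power + end_lr
    is_warmup = np.float32(warmup_steps > step)   # 1.0 during warmup, 0.0 afterwards
    return (1.0 - is_warmup) * decay_lr + is_warmup * warmup_lr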