@@ -21,6 +21,7 @@ from mindspore.common.parameter import Parameter
 from mindspore.ops import operations as P
 from mindspore.ops import composite as C
 from mindspore.ops import functional as F
+from mindspore._checkparam import Validator as validator
 from .optimizer import grad_scale, Optimizer
 
 lars_opt = C.MultitypeFuncGraph("lars_opt")
@@ -41,6 +42,11 @@ def _tensor_run_opt(lars, weight_decay, learning_rate, gradient, weight, decay_f
     return gradient
 
 
+def _check_param_value(optimizer, epsilon, hyperpara, use_clip, prim_name):
+    validator.check_value_type("optimizer", optimizer, Optimizer, prim_name)
+    validator.check_value_type("epsilon", epsilon, [float], prim_name)
+    validator.check_value_type("hyperpara", hyperpara, [float], prim_name)
+    validator.check_value_type("use_clip", use_clip, [bool], prim_name)
+
+
 class LARS(Optimizer):
     """
@@ -79,9 +85,10 @@ class LARS(Optimizer):
     def __init__(self, optimizer, epsilon=1e-05, hyperpara=0.001, weight_decay=0.0, use_clip=False,
                  decay_filter=lambda x: 'LayerNorm' not in x.name and 'bias' not in x.name,
                  lars_filter=lambda x: 'LayerNorm' not in x.name and 'bias' not in x.name, loss_scale=1.0):
-        super(LARS, self).__init__(0.0, [Parameter(Tensor(0.0), name="trivial")])
+        super(LARS, self).__init__(0.0, [Parameter(Tensor(0.0), name="trivial")], weight_decay, loss_scale)
         if optimizer.is_group:
             raise RuntimeError(f"The {self.cls_name} optimizer cannot support group setting.")
+        _check_param_value(optimizer, epsilon, hyperpara, use_clip, self.cls_name)
         self.opt = optimizer
         self.parameters = optimizer.parameters
         self.learning_rate = optimizer.learning_rate
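
For reference, a minimal usage sketch of the LARS wrapper as its constructor looks after this patch: weight_decay and loss_scale are now forwarded to the Optimizer base class, and the new _check_param_value validates epsilon, hyperpara and use_clip. The inner Momentum optimizer, its hyperparameters, and the `net` cell are illustrative assumptions, not part of this diff.

    # Hypothetical example, assuming `net` is an already-constructed mindspore.nn.Cell.
    from mindspore import nn

    inner = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
    # LARS wraps the inner optimizer; group-configured optimizers are rejected by the is_group check above.
    optimizer = nn.LARS(inner, epsilon=1e-05, hyperpara=0.001, weight_decay=1e-4, loss_scale=1.0)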