|
|
@ -257,6 +257,7 @@ class Optimizer(Cell):
|
|
|
|
logger.warning(f"The optimizer cannot parse '{key}' when setting parameter groups.")
|
|
|
|
logger.warning(f"The optimizer cannot parse '{key}' when setting parameter groups.")
|
|
|
|
|
|
|
|
|
|
|
|
for param in group_param['params']:
|
|
|
|
for param in group_param['params']:
|
|
|
|
|
|
|
|
validator.check_value_type("parameter", param, [Parameter], self.cls_name)
|
|
|
|
if param in params_store:
|
|
|
|
if param in params_store:
|
|
|
|
raise RuntimeError(f"The {param.name} parameter has appeared in parameter groups.")
|
|
|
|
raise RuntimeError(f"The {param.name} parameter has appeared in parameter groups.")
|
|
|
|
params_store.append(param)
|
|
|
|
params_store.append(param)
|
|
|
@ -286,6 +287,36 @@ class Optimizer(Cell):
|
|
|
|
F.control_depend(lr, self.assignadd(self.global_step, 1))
|
|
|
|
F.control_depend(lr, self.assignadd(self.global_step, 1))
|
|
|
|
return lr
|
|
|
|
return lr
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_lr_parameter(self, param):
    """
    Get the learning rate of parameter.

    When group learning rates are enabled (`self.is_group_lr`), the learning rate is read from
    `self.learning_rate` at the position the parameter occupies in `self.parameters`. Otherwise
    `self.learning_rate` itself is returned for every parameter.

    Args:
        param (Union[Parameter, list[Parameter]]): The `Parameter` or list of `Parameter`.

    Returns:
        Parameter, single `Parameter` or `list[Parameter]` according to the input type.

    Raises:
        TypeError: If `param` is neither a `Parameter` nor a `list`.
        ValueError: If group learning rates are enabled and a parameter cannot be found in
            `self.parameters`.
    """
    if not isinstance(param, (Parameter, list)):
        raise TypeError("The 'param' only support 'Parameter' or 'list' type.")

    def _lr_of(target):
        """Return the learning rate that applies to one parameter."""
        if not self.is_group_lr:
            return self.learning_rate
        try:
            position = self.parameters.index(target)
        except ValueError as err:
            # Same exception type as the bare index() failure, but the message now says
            # which parameter was not found.
            raise ValueError(f"The {target.name} parameter is not in the optimizer's parameters.") from err
        return self.learning_rate[position]

    if isinstance(param, list):
        lr = []
        for p in param:
            # Each element is checked right before its lookup, so the first invalid
            # element stops processing at the same point as before.
            validator.check_value_type("parameter", p, [Parameter], self.cls_name)
            lr.append(_lr_of(p))
        return lr

    return _lr_of(param)
|
|
|
|
|
|
|
|
|
|
|
|
# Stub: this implementation does nothing except raise NotImplementedError.
# NOTE(review): presumably concrete optimizers override it — confirm against subclasses.
def construct(self, *hyper_params):
|
|
|
|
def construct(self, *hyper_params):
|
|
|
|
raise NotImplementedError
|
|
|
|
raise NotImplementedError
|
|
|
|
|
|
|
|
|
|
|
|