@@ -276,17 +276,17 @@ class Optimizer(Cell):
             learning_rate = float(learning_rate)
             validator.check_non_negative_float(learning_rate, "learning rate", self.cls_name)
             return learning_rate
-        if isinstance(learning_rate, Tensor) and learning_rate.dim() == 0:
+        if isinstance(learning_rate, Tensor) and learning_rate.ndim == 0:
             return learning_rate

         self.dynamic_lr = True
         if isinstance(learning_rate, Iterable):
             return Tensor(np.array(list(learning_rate)).astype(np.float32))
         if isinstance(learning_rate, Tensor):
-            if learning_rate.dim() > 1:
+            if learning_rate.ndim > 1:
                 raise ValueError("The dim of `Tensor` type Learning rate should be a 0 or 1,"
-                                 f"but got {learning_rate.dim()}.")
-            if learning_rate.dim() == 1 and learning_rate.size() < 2:
+                                 f"but got {learning_rate.ndim}.")
+            if learning_rate.ndim == 1 and learning_rate.size < 2:
                 logger.warning("If use `Tensor` type dynamic learning rate, please make sure that the number"
                                "of elements in the tensor passed is greater than 1.")
             return learning_rate
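The hunk above swaps the removed Tensor methods dim() and size() for the NumPy-style properties ndim and size; the branching itself is unchanged. A minimal sketch of that branching, using NumPy arrays as a stand-in for mindspore.Tensor (both expose ndim and size as attributes); classify_lr is an illustrative helper, not part of MindSpore:

import numpy as np

# Sketch of the preprocessing branches above, on NumPy arrays rather than
# mindspore.Tensor.  `ndim` and `size` are attributes, not methods.
def classify_lr(lr):
    if isinstance(lr, float):
        return "static scalar"
    if lr.ndim == 0:                      # 0-D tensor: still a fixed learning rate
        return "static scalar"
    if lr.ndim > 1:
        raise ValueError(f"expected a 0-D or 1-D learning rate, got {lr.ndim}-D")
    if lr.ndim == 1 and lr.size < 2:      # `size` is the element count
        return "dynamic, but only one step"
    return "dynamic schedule (one value per step)"

print(classify_lr(0.1))                              # static scalar
print(classify_lr(np.array([0.1, 0.01, 0.001])))     # dynamic schedule (one value per step)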
@@ -301,12 +301,12 @@ class Optimizer(Cell):
             if self.is_group_lr and self.dynamic_lr:
                 learning_rate = _ConvertToCell(learning_rate)
             return learning_rate
-        if isinstance(learning_rate, Tensor) and learning_rate.dim() == 0:
+        if isinstance(learning_rate, Tensor) and learning_rate.ndim == 0:
             learning_rate = Parameter(learning_rate, name)
             if self.is_group_lr and self.dynamic_lr:
                 learning_rate = _ConvertToCell(learning_rate)
             return learning_rate
-        if isinstance(learning_rate, Tensor) and learning_rate.dim() == 1:
+        if isinstance(learning_rate, Tensor) and learning_rate.ndim == 1:
             return _IteratorLearningRate(learning_rate, name)
         return learning_rate

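For context, a hedged usage sketch of how the two branches above are typically reached from user code: a float (or 0-D Tensor) learning rate is stored as a single Parameter, while a 1-D Tensor is wrapped as a per-step schedule. The layer and optimizer below (nn.Dense, nn.Momentum) are only illustrative choices from MindSpore's public 1.x API, not part of this patch:

import numpy as np
from mindspore import Tensor, nn

net = nn.Dense(3, 2)

# Scalar learning rate -> handled by the float / 0-D branch above.
opt_static = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)

# 1-D Tensor learning rate -> wrapped as a per-step schedule (_IteratorLearningRate).
lr_steps = Tensor(np.array([0.1, 0.05, 0.01], dtype=np.float32))
opt_dynamic = nn.Momentum(net.trainable_params(), learning_rate=lr_steps, momentum=0.9)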
@@ -336,8 +336,8 @@ class Optimizer(Cell):
     def _parse_group_params(self, parameters, learning_rate):
         """Parse group params."""
         self._check_group_params(parameters)
-        if isinstance(learning_rate, Tensor) and learning_rate.dim() == 1:
-            tensor_lr_length = learning_rate.size()
+        if isinstance(learning_rate, Tensor) and learning_rate.ndim == 1:
+            tensor_lr_length = learning_rate.size
         else:
             tensor_lr_length = 0

@@ -355,8 +355,8 @@ class Optimizer(Cell):
                 self.is_group_lr = True
                 group_lr = self._preprocess_single_lr(group_param['lr'])

-                if isinstance(group_lr, Tensor) and group_lr.dim() == 1:
-                    group_lr_length = group_lr.size()
+                if isinstance(group_lr, Tensor) and group_lr.ndim == 1:
+                    group_lr_length = group_lr.size
                     if tensor_lr_length == 0:
                         tensor_lr_length = group_lr_length
                     elif group_lr_length != tensor_lr_length:
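These two hunks only change how the length of a 1-D learning-rate tensor is read (size as a property instead of a call); the consistency rule itself is unchanged: when several parameter groups use 1-D dynamic learning rates, they must all cover the same number of steps. A plain-NumPy sketch of that rule, with check_group_lr_lengths as an illustrative helper rather than MindSpore API:

import numpy as np

# Sketch of the length-consistency check enforced above for group learning rates.
def check_group_lr_lengths(group_lrs):
    tensor_lr_length = 0
    for lr in group_lrs:
        if isinstance(lr, np.ndarray) and lr.ndim == 1:
            if tensor_lr_length == 0:
                tensor_lr_length = lr.size            # first 1-D schedule sets the length
            elif lr.size != tensor_lr_length:
                raise ValueError("Group learning rates must have the same number of steps.")
    return tensor_lr_length

check_group_lr_lengths([np.array([0.1, 0.01]), np.array([0.2, 0.02])])   # ok -> 2
# check_group_lr_lengths([np.array([0.1, 0.01]), np.array([0.2])])       # raises ValueError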
@@ -615,9 +615,9 @@ class _IteratorLearningRate(LearningRateSchedule):
     def __init__(self, learning_rate, name):
         super(_IteratorLearningRate, self).__init__()
         if isinstance(learning_rate, Tensor):
-            if learning_rate.dim() != 1:
+            if learning_rate.ndim != 1:
                 raise ValueError("The dim of `Tensor` type dynamic learning rate should be a 1,"
-                                 f"but got {learning_rate.dim()}.")
+                                 f"but got {learning_rate.ndim}.")
         else:
             raise TypeError("Learning rate should be Tensor.")

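Finally, _IteratorLearningRate keeps accepting only a 1-D Tensor, now validated via ndim. A rough plain-NumPy illustration of what such a per-step schedule represents (one value per training step, looked up by the step index); this is a sketch of the idea, not the _IteratorLearningRate implementation itself:

import numpy as np

# One learning-rate value per training step; the schedule must be 1-D.
lr_schedule = np.array([0.1, 0.05, 0.01], dtype=np.float32)
assert lr_schedule.ndim == 1            # mirrors the check enforced in __init__ above

for step in range(lr_schedule.size):
    lr = lr_schedule[step]              # value looked up by the current step
    print(f"step {step}: lr = {lr}")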