|
|
|
@ -138,14 +138,14 @@ class Optimizer(Cell):
|
|
|
|
|
if self.is_group:
|
|
|
|
|
self.parameters = ParameterTuple(self.group_params)
|
|
|
|
|
self.weight_decay = tuple(self.group_weight_decay)
|
|
|
|
|
self.weight_decay_tensor_tuple = tuple(Tensor(x, mstype.float16) for x in self.group_weight_decay)
|
|
|
|
|
self.weight_decay_tensor_tuple = tuple(Tensor(x, mstype.float32) for x in self.group_weight_decay)
|
|
|
|
|
decay_filter = lambda x: x > 0
|
|
|
|
|
self.decay_flags = tuple(decay_filter(x) for x in self.weight_decay)
|
|
|
|
|
self.exec_weight_decay = any(self.decay_flags)
|
|
|
|
|
else:
|
|
|
|
|
self.parameters = ParameterTuple(parameters)
|
|
|
|
|
self.weight_decay = weight_decay * loss_scale
|
|
|
|
|
self.weight_decay_tensor = Tensor(self.weight_decay, mstype.float16)
|
|
|
|
|
self.weight_decay_tensor = Tensor(self.weight_decay, mstype.float32)
|
|
|
|
|
decay_filter = lambda x: 'beta' not in x.name and 'gamma' not in x.name
|
|
|
|
|
self.decay_flags = tuple(decay_filter(x) for x in self.parameters)
|
|
|
|
|
self.exec_weight_decay = self.weight_decay > 0
|
|
|
|
@ -156,8 +156,9 @@ class Optimizer(Cell):
|
|
|
|
|
break
|
|
|
|
|
ps_filter = lambda x: x.is_param_ps
|
|
|
|
|
self.ps_parameters = tuple(ps_filter(x) for x in self.parameters)
|
|
|
|
|
self.reciprocal_scale = Tensor(1.0 / loss_scale, mstype.float16)
|
|
|
|
|
self.reciprocal_scale = Tensor(1.0 / loss_scale, mstype.float32)
|
|
|
|
|
self.need_scale = loss_scale != 1.0
|
|
|
|
|
self.global_step_increase_tensor = Tensor(1, mstype.int32)
|
|
|
|
|
self.param_length = len(self.parameters)
|
|
|
|
|
self.map_ = C.Map()
|
|
|
|
|
if context.get_auto_parallel_context("enable_parallel_optimizer"):
|
|
|
|
@ -441,7 +442,7 @@ class Optimizer(Cell):
|
|
|
|
|
else:
|
|
|
|
|
lr = self.learning_rate(self.global_step)
|
|
|
|
|
|
|
|
|
|
F.control_depend(lr, self.assignadd(self.global_step, 1))
|
|
|
|
|
F.control_depend(lr, self.assignadd(self.global_step, self.global_step_increase_tensor))
|
|
|
|
|
return lr
|
|
|
|
|
|
|
|
|
|
def get_lr_parameter(self, param):
|
|
|
|
@ -542,7 +543,7 @@ def _tensor_apply_decay_with_sparse(weight_decay, if_apply, weight, gradient):
|
|
|
|
|
"""Get grad with weight_decay."""
|
|
|
|
|
if if_apply:
|
|
|
|
|
indices = gradient.indices
|
|
|
|
|
values = op_add((op_gather(weight, indices, 0) * weight_decay, gradient.values))
|
|
|
|
|
values = op_add((op_gather(weight, indices, 0) * F.cast(weight_decay, F.dtype(weight)), gradient.values))
|
|
|
|
|
shape = gradient.dense_shape
|
|
|
|
|
return RowTensor(indices, values, shape)
|
|
|
|
|
return gradient
|
|
|
|
@ -552,7 +553,7 @@ def _tensor_apply_decay_with_sparse(weight_decay, if_apply, weight, gradient):
|
|
|
|
|
def _tensor_apply_decay(weight_decay, if_apply, weight, gradient):
|
|
|
|
|
"""Get grad with weight_decay."""
|
|
|
|
|
if if_apply:
|
|
|
|
|
return op_add((op_mul(weight, weight_decay), gradient))
|
|
|
|
|
return op_add((op_mul(weight, F.cast(weight_decay, F.dtype(weight))), gradient))
|
|
|
|
|
return gradient
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -564,17 +565,17 @@ def tensor_grad_scale(scale, grad):
|
|
|
|
|
"""Get grad with scale."""
|
|
|
|
|
if scale == 1.0:
|
|
|
|
|
return grad
|
|
|
|
|
return op_mul(grad, scale)
|
|
|
|
|
return op_mul(grad, F.cast(scale, F.dtype(grad)))
|
|
|
|
|
|
|
|
|
|
@_grad_scale.register("Tensor", "Tensor")
|
|
|
|
|
def tensor_grad_scale_with_tensor(scale, grad):
|
|
|
|
|
"""Get grad with scale."""
|
|
|
|
|
return op_mul(grad, scale)
|
|
|
|
|
return op_mul(grad, F.cast(scale, F.dtype(grad)))
|
|
|
|
|
|
|
|
|
|
@_grad_scale.register("Tensor", "RowTensor")
|
|
|
|
|
def tensor_grad_scale_with_sparse(scale, grad):
|
|
|
|
|
"""Get grad with scale."""
|
|
|
|
|
return RowTensor(grad.indices, grad.values * scale, grad.dense_shape)
|
|
|
|
|
return RowTensor(grad.indices, grad.values * F.cast(scale, F.dtype(grad.values)), grad.dense_shape)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@_indices_deduplicate.register("RowTensor")
|
|
|
|
|