|
|
|
@ -138,12 +138,14 @@ class Optimizer(Cell):
|
|
|
|
|
if self.is_group:
|
|
|
|
|
self.parameters = ParameterTuple(self.group_params)
|
|
|
|
|
self.weight_decay = tuple(self.group_weight_decay)
|
|
|
|
|
self.weight_decay_tensor_tuple = tuple(Tensor(x, mstype.float32) for x in self.group_weight_decay)
|
|
|
|
|
decay_filter = lambda x: x > 0
|
|
|
|
|
self.decay_flags = tuple(decay_filter(x) for x in self.weight_decay)
|
|
|
|
|
self.exec_weight_decay = any(self.decay_flags)
|
|
|
|
|
else:
|
|
|
|
|
self.parameters = ParameterTuple(parameters)
|
|
|
|
|
self.weight_decay = weight_decay * loss_scale
|
|
|
|
|
self.weight_decay_tensor = Tensor(self.weight_decay, mstype.float32)
|
|
|
|
|
decay_filter = lambda x: 'beta' not in x.name and 'gamma' not in x.name
|
|
|
|
|
self.decay_flags = tuple(decay_filter(x) for x in self.parameters)
|
|
|
|
|
self.exec_weight_decay = self.weight_decay > 0
|
|
|
|
@ -154,7 +156,8 @@ class Optimizer(Cell):
|
|
|
|
|
break
|
|
|
|
|
ps_filter = lambda x: x.is_param_ps
|
|
|
|
|
self.ps_parameters = tuple(ps_filter(x) for x in self.parameters)
|
|
|
|
|
self.reciprocal_scale = 1.0 / loss_scale
|
|
|
|
|
self.reciprocal_scale = Tensor(1.0 / loss_scale, mstype.float32)
|
|
|
|
|
self.need_scale = loss_scale != 1.0
|
|
|
|
|
self.param_length = len(self.parameters)
|
|
|
|
|
self.map_ = C.Map()
|
|
|
|
|
if context.get_auto_parallel_context("enable_parallel_optimizer"):
|
|
|
|
@ -222,10 +225,10 @@ class Optimizer(Cell):
|
|
|
|
|
if self.exec_weight_decay:
|
|
|
|
|
params = self.parameters
|
|
|
|
|
if self.is_group:
|
|
|
|
|
gradients = self.map_(F.partial(_apply_decay), self.weight_decay, self.decay_flags,
|
|
|
|
|
gradients = self.map_(F.partial(_apply_decay), self.weight_decay_tensor_tuple, self.decay_flags,
|
|
|
|
|
params, gradients)
|
|
|
|
|
else:
|
|
|
|
|
gradients = self.map_(F.partial(_apply_decay, self.weight_decay), self.decay_flags,
|
|
|
|
|
gradients = self.map_(F.partial(_apply_decay, self.weight_decay_tensor), self.decay_flags,
|
|
|
|
|
params, gradients)
|
|
|
|
|
|
|
|
|
|
return gradients
|
|
|
|
@ -245,7 +248,7 @@ class Optimizer(Cell):
|
|
|
|
|
tuple[Tensor], The gradients after loss scale.
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
if self.reciprocal_scale != 1.0:
|
|
|
|
|
if self.need_scale:
|
|
|
|
|
gradients = self.map_(F.partial(_grad_scale, self.reciprocal_scale), gradients)
|
|
|
|
|
|
|
|
|
|
return gradients
|
|
|
|
@ -529,11 +532,12 @@ class Optimizer(Cell):
|
|
|
|
|
|
|
|
|
|
op_add = P.AddN()
|
|
|
|
|
op_gather = P.GatherV2()
|
|
|
|
|
op_mul = P.Mul()
|
|
|
|
|
|
|
|
|
|
_apply_decay = C.MultitypeFuncGraph("apply_decay")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@_apply_decay.register("Number", "Bool", "Tensor", "RowTensor")
|
|
|
|
|
@_apply_decay.register("Tensor", "Bool", "Tensor", "RowTensor")
|
|
|
|
|
def _tensor_apply_decay_with_sparse(weight_decay, if_apply, weight, gradient):
|
|
|
|
|
"""Get grad with weight_decay."""
|
|
|
|
|
if if_apply:
|
|
|
|
@ -544,11 +548,11 @@ def _tensor_apply_decay_with_sparse(weight_decay, if_apply, weight, gradient):
|
|
|
|
|
return gradient
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@_apply_decay.register("Number", "Bool", "Tensor", "Tensor")
|
|
|
|
|
@_apply_decay.register("Tensor", "Bool", "Tensor", "Tensor")
|
|
|
|
|
def _tensor_apply_decay(weight_decay, if_apply, weight, gradient):
|
|
|
|
|
"""Get grad with weight_decay."""
|
|
|
|
|
if if_apply:
|
|
|
|
|
return op_add((weight * weight_decay, gradient))
|
|
|
|
|
return op_add((op_mul(weight, weight_decay), gradient))
|
|
|
|
|
return gradient
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -560,14 +564,16 @@ def tensor_grad_scale(scale, grad):
|
|
|
|
|
"""Get grad with scale."""
|
|
|
|
|
if scale == 1.0:
|
|
|
|
|
return grad
|
|
|
|
|
return grad * scale
|
|
|
|
|
return op_mul(grad, scale)
|
|
|
|
|
|
|
|
|
|
@_grad_scale.register("Tensor", "Tensor")
|
|
|
|
|
def tensor_grad_scale_with_tensor(scale, grad):
|
|
|
|
|
"""Get grad with scale."""
|
|
|
|
|
return op_mul(grad, scale)
|
|
|
|
|
|
|
|
|
|
@_grad_scale.register("Number", "RowTensor")
|
|
|
|
|
@_grad_scale.register("Tensor", "RowTensor")
|
|
|
|
|
def tensor_grad_scale_with_sparse(scale, grad):
|
|
|
|
|
"""Get grad with scale."""
|
|
|
|
|
if scale == 1.0:
|
|
|
|
|
return grad
|
|
|
|
|
return RowTensor(grad.indices, grad.values * scale, grad.dense_shape)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|