diff --git a/mindspore/nn/loss/loss.py b/mindspore/nn/loss/loss.py index 9a3de36f47..737ae78ec1 100644 --- a/mindspore/nn/loss/loss.py +++ b/mindspore/nn/loss/loss.py @@ -18,6 +18,8 @@ from mindspore.common.tensor import Tensor from mindspore.ops import operations as P from mindspore.ops import functional as F from mindspore.nn.cell import Cell +from mindspore._checkparam import Validator as validator +from mindspore._checkparam import Rel from ... import context @@ -215,6 +217,8 @@ class SoftmaxCrossEntropyWithLogits(_Loss): sparse (bool): Specifies whether labels use sparse format or not. Default: False. reduction (Union[str, None]): Type of reduction to apply to loss. Support 'sum' or 'mean' If None, do not reduction. Default: None. + smooth_factor (float): Label smoothing factor. It is a optional input. Default: 0. + num_classes (int): The number of classes in the task. It is a optional input Default: 2. Inputs: - **logits** (Tensor) - Tensor of shape :math:`(x_1, x_2, ..., x_R)`. @@ -235,14 +239,20 @@ class SoftmaxCrossEntropyWithLogits(_Loss): def __init__(self, is_grad=True, sparse=False, - reduction=None): + reduction=None, + smooth_factor=0, + num_classes=2): super(SoftmaxCrossEntropyWithLogits, self).__init__(reduction) self.is_grad = is_grad self.sparse = sparse + validator.check_integer("num_classes", num_classes, 1, Rel.GT, self.cls_name) + validator.check_number_range("smooth_factor", smooth_factor, 0, 1, Rel.INC_BOTH, self.cls_name) + self.smooth_factor = smooth_factor + self.num_classes = num_classes self.softmax_cross_entropy = P.SoftmaxCrossEntropyWithLogits() self.one_hot = P.OneHot() - self.on_value = Tensor(1.0, mstype.float32) - self.off_value = Tensor(0.0, mstype.float32) + self.on_value = Tensor(1.0 - self.smooth_factor, mstype.float32) + self.off_value = Tensor(1.0 * self.smooth_factor / (self.num_classes - 1), mstype.float32) self.is_cpugpu = context.get_context('device_target') in ["CPU", "GPU"] if self.is_cpugpu: diff --git a/mindspore/nn/optim/momentum.py b/mindspore/nn/optim/momentum.py index 166e8ae296..080377b71d 100755 --- a/mindspore/nn/optim/momentum.py +++ b/mindspore/nn/optim/momentum.py @@ -17,6 +17,7 @@ from mindspore.ops import functional as F, composite as C, operations as P from mindspore.common.parameter import Parameter from mindspore.common.tensor import Tensor import mindspore.common.dtype as mstype +from mindspore._checkparam import check_bool from .optimizer import Optimizer momentum_opt = C.MultitypeFuncGraph("momentum_opt") @@ -67,6 +68,7 @@ class Momentum(Optimizer): momentum (float): Hyperparameter of type float, means momentum for the moving average. weight_decay (float): Weight decay (L2 penalty). Default: 0.0. loss_scale (float): A floating point value for the loss scale. Default: 1.0. + use_nesterov (bool): Enable Nesterov momentum. Default: False. Inputs: - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`. @@ -95,15 +97,16 @@ class Momentum(Optimizer): >>> loss = nn.SoftmaxCrossEntropyWithLogits() >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None) """ - def __init__(self, params, learning_rate, momentum, weight_decay=0.0, loss_scale=1.0): + def __init__(self, params, learning_rate, momentum, weight_decay=0.0, loss_scale=1.0, use_nesterov=False): super(Momentum, self).__init__(learning_rate, params, weight_decay, loss_scale) if isinstance(momentum, float) and momentum < 0.0: raise ValueError("momentum should be at least 0.0, but got momentum {}".format(momentum)) self.momentum = Parameter(Tensor(momentum, mstype.float32), name="momentum") self.params = self.parameters + self.use_nesterov = check_bool(use_nesterov) self.moments = self.params.clone(prefix="moments", init='zeros') self.hyper_map = C.HyperMap() - self.opt = P.ApplyMomentum() + self.opt = P.ApplyMomentum(use_nesterov=self.use_nesterov) def construct(self, gradients): params = self.params diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py index 04b9f49c7c..d338316c61 100644 --- a/mindspore/ops/operations/nn_ops.py +++ b/mindspore/ops/operations/nn_ops.py @@ -1757,8 +1757,8 @@ class LayerNorm(Primitive): - **output_x** (Tensor) - The normalized input, has the same type and shape as the `input_x`. The shape is :math:`(N, C)`. - - **updated_gamma** (Tensor) - Tensor of shape :math:`(C,)`. - - **updated_beta** (Tensor) - Tensor of shape :math:`(C,)`. + - **mean** (Tensor) - Tensor of shape :math:`(C,)`. + - **variance** (Tensor) - Tensor of shape :math:`(C,)`. Examples: >>> input_x = Tensor(np.array([[1, 2, 3], [1, 2, 3]]), mindspore.float32)