add soft_label and axis for CrossEntropyLoss and improve performance (#29024)

* add soft_label and axis for CrossEntropyLoss and improve performance,test=develop

* fix conflict in nn/functional/loss.py, test=develop
musl/disable_test_yolov3_temporarily
chajchaj 5 years ago committed by GitHub
parent 018e169923
commit b52427327d

@@ -128,6 +128,8 @@ from .loss import binary_cross_entropy  #DEFINE_ALIAS
 from .loss import binary_cross_entropy_with_logits  #DEFINE_ALIAS
 # from .loss import bpr_loss  #DEFINE_ALIAS
 # from .loss import center_loss  #DEFINE_ALIAS
+#from .loss import cross_entropy  #DEFINE_ALIAS
+from .loss import softmax_cross_entropy  #DEFINE_ALIAS
 from .loss import cross_entropy  #DEFINE_ALIAS
 from .loss import dice_loss  #DEFINE_ALIAS
 from .loss import hsigmoid_loss  #DEFINE_ALIAS

File diff suppressed because it is too large

@@ -141,30 +141,40 @@ class BCEWithLogitsLoss(fluid.dygraph.Layer):
 class CrossEntropyLoss(fluid.dygraph.Layer):
-    r"""
-    :alias_main: paddle.nn.CrossEntropyLoss
-    :alias: paddle.nn.CrossEntropyLoss,paddle.nn.layer.CrossEntropyLoss,paddle.nn.layer.loss.CrossEntropyLoss
-
-    This operator implements the cross entropy loss function. This OP combines ``LogSoftmax``,
-    and ``NLLLoss`` together.
-
-    It is useful when training a classification problem with ``C`` classes.
-    If provided, the optional argument ``weight`` should be a 1D Variable assigning
-    weight to each of the classes.
-
-    For predictions label, and target label, the loss is calculated as follows.
-
-    .. math::
-
-        loss_j = -\\text{input[class]} +
-        \\log\\left(\\sum_{i=0}^{K}\\exp(\\text{input}_i)\\right), j = 1,..., K
-
-    If weight is not ``None``:
-
-    .. math::
-
-        loss_j = \\text{weight[class]}(-\\text{input[class]} +
-        \\log\\left(\\sum_{i=0}^{K}\\exp(\\text{input}_i)\\right)), j = 1,..., K
+    """
+    This operator implements the cross entropy loss function with softmax. This function
+    combines the calculation of the softmax operation and the cross entropy loss function
+    to provide a more numerically stable gradient.
+
+    Because this operator performs a softmax on logits internally, it expects
+    unscaled logits. This operator should not be used with the output of
+    softmax operator since that would produce incorrect results.
+
+    When the attribute :attr:`soft_label` is set :attr:`False`, this operators
+    expects mutually exclusive hard labels, each sample in a batch is in exactly
+    one class with a probability of 1.0. Each sample in the batch will have a
+    single label.
+
+    The equation is as follows:
+
+    1) Hard label (one-hot label, so every sample has exactly one class)
+
+    .. math::
+
+        loss_j = -\\text{logits}_{label_j} +
+        \\log\\left(\\sum_{i=0}^{K}\\exp(\\text{logits}_i)\\right), j = 1,..., K
+
+    2) Soft label (each sample can have a distribution over all classes)
+
+    .. math::
+
+        loss_j = -\\sum_{i=0}^{K}\\text{label}_i
+        \\left(\\text{logits}_i - \\log\\left(\\sum_{i=0}^{K}
+        \\exp(\\text{logits}_i)\\right)\\right), j = 1,...,K
+
+    It is useful when training a classification problem with ``C`` classes.

     Parameters:
         input (Variable): Input tensor, the data type is float32, float64. Shape is
@@ -173,9 +183,9 @@ class CrossEntropyLoss(fluid.dygraph.Layer):
         label (Variable): Label tensor, the data type is int64. Shape is (N), where each
             value is 0 <= label[i] <= C-1, and if shape is more than 2D, this is
             (N, D1, D2,..., Dk), k >= 1.
-        weight (Variable, optional): Weight tensor, a manual rescaling weight given
-            to each class and the shape is (C). It has the same dimensions as class
-            number and the data type is float32, float64. Default is ``'None'``.
+        weight (Variable, optional): Weight tensor, a manual rescaling weight for each
+            sample relative to each class. It has the same shape as label.
+            and the data type is float32, float64. Default is ``'None'``.
         reduction (str, optional): Indicate how to average the loss by batch_size,
             the candicates are ``'none'`` | ``'mean'`` | ``'sum'``.
             If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned;
@@ -184,6 +194,12 @@ class CrossEntropyLoss(fluid.dygraph.Layer):
             Default is ``'mean'``.
         ignore_index (int64, optional): Specifies a target value that is ignored
             and does not contribute to the input gradient. Default is ``-100``.
+        soft_label (bool): indicate whether label is soft. Default False, meaning that
+            the label is hard. If soft_label=True, the label is soft.
+        axis (int, optional): The index of dimension to perform softmax calculations. It
+            should be in range :math:`[-1, rank - 1]`, while :math:`rank`
+            is the rank of input :attr:`logits`. Default: -1.
+
     Returns:
         The tensor variable storing the cross_entropy_loss of input and label.
@@ -192,64 +208,47 @@ class CrossEntropyLoss(fluid.dygraph.Layer):
     Examples:
         .. code-block:: python

-            # declarative mode
             import paddle
-            import paddle.fluid as fluid
             import numpy as np
-
-            input = fluid.data(name='input', shape=[5, 100], dtype='float64')
-            label = fluid.data(name='label', shape=[5], dtype='int64')
-            weight = fluid.data(name='weight', shape=[100], dtype='float64')
-            ce_loss = paddle.nn.loss.CrossEntropyLoss(weight=weight, reduction='mean')
-            output = ce_loss(input, label)
-            place = fluid.CPUPlace()
-            exe = fluid.Executor(place)
-            exe.run(fluid.default_startup_program())
-            input_data = np.random.random([5, 100]).astype("float64")
-            label_data = np.random.randint(0, 100, size=(5)).astype(np.int64)
-            weight_data = np.random.random([100]).astype("float64")
-            output = exe.run(fluid.default_main_program(),
-                    feed={"input": input_data, "label": label_data,"weight": weight_data},
-                    fetch_list=[output],
-                    return_numpy=True)
-            print(output)
-
-            # imperative mode
-            import paddle.fluid.dygraph as dg
-            with dg.guard(place) as g:
-                input = dg.to_variable(input_data)
-                label = dg.to_variable(label_data)
-                weight = dg.to_variable(weight_data)
-                ce_loss = paddle.nn.loss.CrossEntropyLoss(weight=weight, reduction='mean')
-                output = ce_loss(input, label)
-                print(output.numpy())
+
+            input_np = np.random.random([2, 4]).astype(np.float64)
+            label_np = np.random.randint(0, 4, size=(2, 1)).astype(np.int64)
+            weight_np = np.random.random([4]).astype(np.float64)  #shape:C
+            weight_ce = weight_np[label_np]  #shape:N,1
+            cross_entropy_loss = paddle.nn.loss.CrossEntropyLoss(
+                weight=paddle.to_tensor(weight_ce))
+            output = cross_entropy_loss(
+                paddle.to_tensor(input_np),
+                paddle.to_tensor(label_np))
+            print(output.numpy())  #[1.44375251]
     """

-    def __init__(self, weight=None, ignore_index=-100, reduction='mean'):
+    def __init__(self,
+                 weight=None,
+                 ignore_index=-100,
+                 reduction='mean',
+                 soft_label=False,
+                 axis=-1,
+                 name=None):
         super(CrossEntropyLoss, self).__init__()
         self.weight = weight
         self.reduction = reduction
         self.ignore_index = ignore_index
+        self.soft_label = soft_label
+        self.axis = axis
+        self.name = name

     def forward(self, input, label):
-        fluid.data_feeder.check_variable_and_dtype(
-            input, 'input', ['float32', 'float64'], 'cross_entropy_loss')
-        fluid.data_feeder.check_variable_and_dtype(label, 'label', ['int64'],
-                                                   'cross_entropy_loss')
-
-        if self.reduction not in ['sum', 'mean', 'none']:
-            raise ValueError(
-                "The value of 'reduction' in cross_entropy_loss should be 'sum', 'mean' or"
-                " 'none', but received %s, which is not allowed." %
-                self.reduction)
-
-        return paddle.nn.functional.cross_entropy(
+        ret = paddle.nn.functional.softmax_cross_entropy(
             input,
             label,
             weight=self.weight,
             ignore_index=self.ignore_index,
-            reduction=self.reduction)
+            reduction=self.reduction,
+            soft_label=self.soft_label,
+            axis=self.axis,
+            name=self.name)
+        return ret


 class HSigmoidLoss(fluid.dygraph.Layer):
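
The two formulas in the new docstring are easy to check outside of Paddle. Below is a small numpy-only sketch (not part of the diff; the logits and labels are made up) that evaluates the hard-label and soft-label losses for a single sample and confirms that a one-hot soft label reduces to the hard-label case:

.. code-block:: python

    import numpy as np

    # Made-up logits for one sample with 4 classes.
    logits = np.array([2.0, 0.5, -1.0, 0.3])
    log_z = np.log(np.exp(logits).sum())        # log(sum_i exp(logits_i))

    # 1) Hard label: the sample belongs to class 2 with probability 1.0.
    hard_label = 2
    hard_loss = -logits[hard_label] + log_z

    # 2) Soft label: a probability distribution over all classes.
    soft_label = np.array([0.1, 0.2, 0.6, 0.1])
    soft_loss = -np.sum(soft_label * (logits - log_z))

    # A one-hot soft label gives back the hard-label loss.
    one_hot = np.eye(4)[hard_label]
    assert np.isclose(hard_loss, -np.sum(one_hot * (logits - log_z)))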
@@ -491,29 +490,31 @@ class L1Loss(fluid.dygraph.Layer):
         If `reduction` is ``'mean'`` or ``'sum'``, the shape of output loss is [1].

     Examples:
         .. code-block:: python

             import paddle
+            import numpy as np

-            input = paddle.to_tensor([[1.5, 0.8], [0.2, 1.3]])
-            label = paddle.to_tensor([[1.7, 1.0], [0.4, 0.5]])
+            paddle.disable_static()
+            input_data = np.array([[1.5, 0.8], [0.2, 1.3]]).astype("float32")
+            label_data = np.array([[1.7, 1], [0.4, 0.5]]).astype("float32")
+            input = paddle.to_tensor(input_data)
+            label = paddle.to_tensor(label_data)

             l1_loss = paddle.nn.loss.L1Loss()
             output = l1_loss(input, label)
-            print(output)
+            print(output.numpy())
             # [0.35]

             l1_loss = paddle.nn.loss.L1Loss(reduction='sum')
             output = l1_loss(input, label)
-            print(output)
+            print(output.numpy())
             # [1.4]

             l1_loss = paddle.nn.loss.L1Loss(reduction='none')
             output = l1_loss(input, label)
-            print(output)
+            print(output.numpy())
             # [[0.20000005 0.19999999]
             # [0.2        0.79999995]]
     """

     def __init__(self, reduction='mean', name=None):
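
For reference, the three reduction modes shown in the updated example correspond to plain numpy operations on the element-wise absolute error. This is only an illustrative sketch using the same data as the docstring, not part of the diff:

.. code-block:: python

    import numpy as np

    input_data = np.array([[1.5, 0.8], [0.2, 1.3]], dtype="float32")
    label_data = np.array([[1.7, 1.0], [0.4, 0.5]], dtype="float32")

    elementwise = np.abs(input_data - label_data)  # reduction='none'
    print(elementwise)         # ~[[0.2, 0.2], [0.2, 0.8]]
    print(elementwise.mean())  # ~0.35 -> reduction='mean'
    print(elementwise.sum())   # ~1.4  -> reduction='sum'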
@@ -622,7 +623,9 @@ class BCELoss(fluid.dygraph.Layer):

 class NLLLoss(fluid.dygraph.Layer):
-    r"""
+    """
+    :alias_main: paddle.nn.NLLLoss
+    :alias: paddle.nn.NLLLoss,paddle.nn.layer.NLLLoss,paddle.nn.layer.loss.NLLLoss

     This class accepts input and target label and returns negative log likelihood
     cross error. It is useful to train a classification problem with C classes.
@@ -689,7 +692,7 @@ class NLLLoss(fluid.dygraph.Layer):
             import paddle
             import numpy as np

-            nll_loss = paddle.nn.NLLLoss()
+            nll_loss = paddle.nn.layer.NLLLoss()
             log_softmax = paddle.nn.LogSoftmax(axis=1)

             input_np = np.array([[0.88103855, 0.9908683 , 0.6226845 ],
@@ -699,11 +702,13 @@ class NLLLoss(fluid.dygraph.Layer):
                                  [0.05689114, 0.0862954 , 0.6325046 ]]).astype(np.float32)
             label_np = np.array([0, 2, 1, 1, 0]).astype(np.int64)

+            place = paddle.CPUPlace()
+            paddle.disable_static(place)
             input = paddle.to_tensor(input_np)
             log_out = log_softmax(input)
             label = paddle.to_tensor(label_np)
             result = nll_loss(log_out, label)
-            print(result)  # [1.0720209]
+            print(result.numpy())  # [1.0720209]
     """
@@ -999,7 +1004,7 @@ class SmoothL1Loss(fluid.dygraph.Layer):
         is the same as the shape of input.

     Returns:
-        The tensor storing the smooth_l1_loss of input and label.
+        The tensor variable storing the smooth_l1_loss of input and label.

     Return type: Tensor.
