add binary cross entropy with logit loss (#26468)

* add binary cross entropy with logit loss
test_feature_precision_test_c
Zhong Hui 5 years ago committed by GitHub
parent 4e0c6d91aa
commit f5d1349826
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -107,6 +107,7 @@ from .layer.extension import RowConv #DEFINE_ALIAS
# from .layer.learning_rate import PiecewiseDecay #DEFINE_ALIAS
# from .layer.learning_rate import PolynomialDecay #DEFINE_ALIAS
# from .layer.loss import NCELoss #DEFINE_ALIAS
from .layer.loss import BCEWithLogitsLoss #DEFINE_ALIAS
from .layer.loss import CrossEntropyLoss #DEFINE_ALIAS
from .layer.loss import MSELoss #DEFINE_ALIAS
from .layer.loss import L1Loss #DEFINE_ALIAS

@ -126,6 +126,7 @@ from .lod import hash #DEFINE_ALIAS
# from .lod import dynamic_lstm #DEFINE_ALIAS
# from .lod import dynamic_lstmp #DEFINE_ALIAS
from .loss import binary_cross_entropy #DEFINE_ALIAS
from .loss import binary_cross_entropy_with_logits #DEFINE_ALIAS
from .loss import bpr_loss #DEFINE_ALIAS
from .loss import center_loss #DEFINE_ALIAS
from .loss import cross_entropy #DEFINE_ALIAS

@ -49,6 +49,7 @@ from ...fluid.framework import Variable
__all__ = [
'binary_cross_entropy',
'binary_cross_entropy_with_logits',
'bpr_loss',
'center_loss',
'cross_entropy',
@ -214,6 +215,154 @@ def binary_cross_entropy(input, label, weight=None, reduction='mean',
return out
def binary_cross_entropy_with_logits(logit,
label,
weight=None,
reduction='mean',
pos_weight=None,
name=None):
"""
This operator combines the sigmoid layer and the :ref:`api_nn_loss_BCELoss` layer.
Also, we can see it as the combine of ``sigmoid_cross_entropy_with_logits``
layer and some reduce operations.
This measures the element-wise probability error in classification tasks
in which each class is independent.
This can be thought of as predicting labels for a data-point, where labels
are not mutually exclusive. For example, a news article can be about
politics, technology or sports at the same time or none of these.
First this operator calculate loss function as follows:
.. math::
Out = -Labels * \\log(\\sigma(Logit)) - (1 - Labels) * \\log(1 - \\sigma(Logit))
We know that :math:`\\sigma(Logit) = \\frac{1}{1 + \\e^{-Logit}}`. By substituting this we get:
.. math::
Out = Logit - Logit * Labels + \\log(1 + \\e^{-Logit})
For stability and to prevent overflow of :math:`\\e^{-Logit}` when Logit < 0,
we reformulate the loss as follows:
.. math::
Out = \\max(Logit, 0) - Logit * Labels + \\log(1 + \\e^{-\|Logit\|})
Then, if ``weight`` or ``pos_weight`` is not None, this operator multiply the
weight tensor on the loss `Out`. The ``weight`` tensor will attach different
weight on every items in the batch. The ``pos_weight`` will attach different
weight on the positive label of each class.
Finally, this operator applies reduce operation on the loss.
If :attr:`reduction` set to ``'none'``, the operator will return the original loss `Out`.
If :attr:`reduction` set to ``'mean'``, the reduced mean loss is :math:`Out = MEAN(Out)`.
If :attr:`reduction` set to ``'sum'``, the reduced sum loss is :math:`Out = SUM(Out)`.
Note that the target labels ``label`` should be numbers between 0 and 1.
Args:
logit (Tensor): The input predications tensor. 2-D tensor with shape: [N, *],
N is batch_size, `*` means number of additional dimensions. The ``logit``
is usually the output of Linear layer. Available dtype is float32, float64.
label (Tensor): The target labels tensor. 2-D tensor with the same shape as
``logit``. The target labels which values should be numbers between 0 and 1.
Available dtype is float32, float64.
weight (Tensor, optional): A manual rescaling weight given to the loss of each
batch element. If given, it has to be a 1D Tensor whose size is `[N, ]`,
The data type is float32, float64. Default is ``'None'``.
reduction (str, optional): Indicate how to average the loss by batch_size,
the candicates are ``'none'`` | ``'mean'`` | ``'sum'``.
If :attr:`reduction` is ``'none'``, the unreduced loss is returned;
If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned;
If :attr:`reduction` is ``'sum'``, the summed loss is returned.
Default is ``'mean'``.
pos_weight (Tensor, optional): A weight of positive examples. Must be a vector
with length equal to the number of classes. The data type is float32, float64.
Default is ``'None'``.
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
Returns:
output (Tensor): If ``reduction`` is ``'none'``, the shape of output is
same as ``logit`` , else the shape of output is scalar.
Examples:
.. code-block:: python
import paddle
paddle.disable_static()
logit = paddle.to_tensor([5.0, 1.0, 3.0], dtype="float32")
label = paddle.to_tensor([1.0, 0.0, 1.0], dtype="float32")
output = paddle.nn.functional.binary_cross_entropy_with_logits(logit, label)
print(output.numpy()) # [0.45618808]
"""
if reduction not in ['sum', 'mean', 'none']:
raise ValueError(
"The value of 'reduction' in binary_cross_entropy_with_logits "
"should be 'sum', 'mean' or 'none', but received %s, which is not allowed."
% reduction)
if in_dygraph_mode():
one = _varbase_creator(dtype=logit.dtype)
core.ops.fill_constant(one, 'value',
float(1.0), 'force_cpu', False, 'dtype',
one.dtype, 'str_value', '1.0', 'shape', [1])
out = core.ops.sigmoid_cross_entropy_with_logits(logit, label)
if pos_weight is not None:
log_weight = core.ops.elementwise_add(
core.ops.elementwise_mul(
label, core.ops.elementwise_sub(pos_weight, one)), one)
out = core.ops.elementwise_mul(out, log_weight)
if weight is not None:
out = core.ops.elementwise_mul(out, weight)
if reduction == "sum":
return core.ops.reduce_sum(out, 'reduce_all', True)
elif reduction == "mean":
return core.ops.mean(out)
else:
return out
fluid.data_feeder.check_variable_and_dtype(
logit, 'logit', ['float32', 'float64'],
'binary_cross_entropy_with_logits')
fluid.data_feeder.check_variable_and_dtype(
label, 'label', ['float32', 'float64'],
'binary_cross_entropy_with_logits')
sigmoid_name = None
if reduction == 'none' and pos_weight is None and weight is None:
sigmoid_name = name
out = paddle.nn.functional.sigmoid_cross_entropy_with_logits(
logit, label, name=sigmoid_name)
one = paddle.fill_constant(shape=[1], value=1.0, dtype=logit.dtype)
if pos_weight is not None:
fluid.data_feeder.check_variable_and_dtype(
pos_weight, 'pos_weight', ['float32', 'float64'],
'binary_cross_entropy_with_logits')
log_weight = paddle.add(
paddle.multiply(label, paddle.elementwise_sub(pos_weight, one)),
one)
pos_weight_name = name if reduction == 'none' and weight is None else None
out = paddle.multiply(out, log_weight, name=pos_weight_name)
if weight is not None:
fluid.data_feeder.check_variable_and_dtype(
weight, 'weight', ['float32', 'float64'],
'binary_cross_entropy_with_logits')
weight_name = name if reduction == 'none' else None
out = paddle.multiply(out, weight, name=weight_name)
if reduction == "sum":
return paddle.sum(out, name=name)
elif reduction == "mean":
return paddle.mean(out, name=name)
return out
def smooth_l1_loss(input, label, reduction='mean', delta=1.0, name=None):
"""
This operator calculates smooth_l1_loss. Creates a criterion that uses a squared

@ -72,6 +72,7 @@ from .extension import RowConv #DEFINE_ALIAS
# from .learning_rate import PiecewiseDecay #DEFINE_ALIAS
# from .learning_rate import PolynomialDecay #DEFINE_ALIAS
# from .loss import NCELoss #DEFINE_ALIAS
from .loss import BCEWithLogitsLoss #DEFINE_ALIAS
from .loss import CrossEntropyLoss #DEFINE_ALIAS
from .loss import MSELoss #DEFINE_ALIAS
from .loss import L1Loss #DEFINE_ALIAS

@ -21,6 +21,7 @@ from .. import functional as F
from paddle.fluid.framework import core, in_dygraph_mode, _varbase_creator
__all__ = [
'BCEWithLogitsLoss',
'CrossEntropyLoss',
'MSELoss',
'L1Loss',
@ -33,6 +34,111 @@ __all__ = [
]
class BCEWithLogitsLoss(fluid.dygraph.Layer):
"""
This operator combines the sigmoid layer and the :ref:`api_nn_loss_BCELoss` layer.
Also, we can see it as the combine of ``sigmoid_cross_entropy_with_logits``
layer and some reduce operations.
This measures the element-wise probability error in classification tasks
in which each class is independent.
This can be thought of as predicting labels for a data-point, where labels
are not mutually exclusive. For example, a news article can be about
politics, technology or sports at the same time or none of these.
First this operator calculate loss function as follows:
.. math::
Out = -Labels * \\log(\\sigma(Logit)) - (1 - Labels) * \\log(1 - \\sigma(Logit))
We know that :math:`\\sigma(Logit) = \\frac{1}{1 + \\e^{-Logit}}`. By substituting this we get:
.. math::
Out = Logit - Logit * Labels + \\log(1 + \\e^{-Logit})
For stability and to prevent overflow of :math:`\\e^{-Logit}` when Logit < 0,
we reformulate the loss as follows:
.. math::
Out = \\max(Logit, 0) - Logit * Labels + \\log(1 + \\e^{-\|Logit\|})
Then, if ``weight`` or ``pos_weight`` is not None, this operator multiply the
weight tensor on the loss `Out`. The ``weight`` tensor will attach different
weight on every items in the batch. The ``pos_weight`` will attach different
weight on the positive label of each class.
Finally, this operator applies reduce operation on the loss.
If :attr:`reduction` set to ``'none'``, the operator will return the original loss `Out`.
If :attr:`reduction` set to ``'mean'``, the reduced mean loss is :math:`Out = MEAN(Out)`.
If :attr:`reduction` set to ``'sum'``, the reduced sum loss is :math:`Out = SUM(Out)`.
Note that the target labels ``label`` should be numbers between 0 and 1.
Args:
weight (Tensor, optional): A manual rescaling weight given to the loss of each
batch element. If given, it has to be a 1D Tensor whose size is `[N, ]`,
The data type is float32, float64. Default is ``'None'``.
reduction (str, optional): Indicate how to average the loss by batch_size,
the candicates are ``'none'`` | ``'mean'`` | ``'sum'``.
If :attr:`reduction` is ``'none'``, the unreduced loss is returned;
If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned;
If :attr:`reduction` is ``'sum'``, the summed loss is returned.
Default is ``'mean'``.
pos_weight (Tensor, optional): A weight of positive examples. Must be a vector
with length equal to the number of classes. The data type is float32, float64.
Default is ``'None'``.
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
Shapes:
logit (Tensor): The input predications tensor. 2-D tensor with shape: [N, *],
N is batch_size, `*` means number of additional dimensions. The ``logit``
is usually the output of Linear layer. Available dtype is float32, float64.
label (Tensor): The target labels tensor. 2-D tensor with the same shape as
``logit``. The target labels which values should be numbers between 0 and 1.
Available dtype is float32, float64.
output (Tensor): If ``reduction`` is ``'none'``, the shape of output is
same as ``logit`` , else the shape of output is scalar.
Returns:
A callable object of BCEWithLogitsLoss.
Examples:
.. code-block:: python
import paddle
paddle.disable_static()
logit = paddle.to_tensor([5.0, 1.0, 3.0], dtype="float32")
label = paddle.to_tensor([1.0, 0.0, 1.0], dtype="float32")
bce_logit_loss = paddle.nn.BCEWithLogitsLoss()
output = bce_logit_loss(logit, label)
print(output.numpy()) # [0.45618808]
"""
def __init__(self,
weight=None,
reduction='mean',
pos_weight=None,
name=None):
if reduction not in ['sum', 'mean', 'none']:
raise ValueError(
"The value of 'reduction' in BCEWithLogitsLoss should be 'sum', 'mean' or 'none', but "
"received %s, which is not allowed." % reduction)
super(BCEWithLogitsLoss, self).__init__()
self.weight = weight
self.reduction = reduction
self.pos_weight = pos_weight
self.name = name
def forward(self, logit, label):
out = paddle.nn.functional.binary_cross_entropy_with_logits(
logit, label, self.weight, self.reduction, self.pos_weight,
self.name)
return out
class CrossEntropyLoss(fluid.dygraph.Layer):
"""
:alias_main: paddle.nn.CrossEntropyLoss
@ -678,9 +784,9 @@ class CTCLoss(fluid.dygraph.Layer):
:alias_main: paddle.nn.CTCLoss
:alias: paddle.nn.CTCLoss, paddle.nn.layer.CTCLoss, paddle.nn.layer.loss.CTCLoss
An operator integrating the open source Warp-CTC library (https://github.com/baidu-research/warp-ctc)
to compute Connectionist Temporal Classification (CTC) loss.
It can be aliased as softmax with CTC, since a native softmax activation
An operator integrating the open source Warp-CTC library (https://github.com/baidu-research/warp-ctc)
to compute Connectionist Temporal Classification (CTC) loss.
It can be aliased as softmax with CTC, since a native softmax activation
is interated to the Warp-CTC library to normalize values for each row of the input tensor.
Parameters:
@ -695,7 +801,7 @@ class CTCLoss(fluid.dygraph.Layer):
Returns:
Tensor, The Connectionist Temporal Classification (CTC) loss between ``log_probs`` and ``labels``. If attr:`reduction` is ``'none'``, the shape of loss is [batch_size], otherwise, the shape of loss is [1]. Data type is the same as ``log_probs``.
Examples:
.. code-block:: python
@ -739,13 +845,13 @@ class CTCLoss(fluid.dygraph.Layer):
input_lengths = paddle.to_variable(input_lengths)
label_lengths = paddle.to_variable(label_lengths)
loss = paddle.nn.CTCLoss(blank=0, reduction='none')(log_probs, labels,
input_lengths,
loss = paddle.nn.CTCLoss(blank=0, reduction='none')(log_probs, labels,
input_lengths,
label_lengths)
print(loss.numpy()) #[3.9179852 2.9076521]
loss = paddle.nn.CTCLoss(blank=0, reduction='mean')(log_probs, labels,
input_lengths,
loss = paddle.nn.CTCLoss(blank=0, reduction='mean')(log_probs, labels,
input_lengths,
label_lengths)
print(loss.numpy()) #[1.1376063]
"""

Loading…
Cancel
Save