|
|
|
@ -19,7 +19,7 @@ __all__ = [
|
|
|
|
|
'CrossEntropyLoss',
|
|
|
|
|
# 'MSELoss',
|
|
|
|
|
'L1Loss',
|
|
|
|
|
# 'NLLLoss',
|
|
|
|
|
'NLLLoss',
|
|
|
|
|
'BCELoss'
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
@ -329,3 +329,145 @@ class BCELoss(fluid.dygraph.Layer):
|
|
|
|
|
return fluid.layers.reduce_mean(out)
|
|
|
|
|
else:
|
|
|
|
|
return out
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class NLLLoss(fluid.dygraph.Layer):
    r"""
    This op accepts input and target label and returns the negative log
    likelihood loss. It is useful to train a classification problem with C
    classes.

    The input for the loss is expected to contain log-probabilities of
    each class. It has to be a Tensor of size either (batch_size, C) or
    (batch_size, C, d1, d2, ..., dK) with K >= 1 for the K-dimensional case.
    The label for the loss should be a class index in the range [0, C-1]
    where C is the number of classes. If ignore_index is specified, the
    specified target value does not contribute to the input gradient.

    If the optional argument `weight` is provided, it should be a 1D Tensor
    assigning a weight to each of the classes. This is particularly useful
    when you have an unbalanced training set.

    The loss is calculated as follows.
    The unreduced (i.e. with :attr:`reduction` set to ``'none'``) loss can be described as:

    .. math::
        \ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad
        l_n = - w_{y_n} x_{n,y_n}, \quad
        w_{c} = \text{weight}[c] \cdot \mathbb{1}\{c \not= \text{ignore\_index}\},

    where :math:`N` is the batch size. If :attr:`reduction` is not ``'none'``
    (default ``'mean'``), then

    .. math::
        \ell(x, y) = \begin{cases}
            \sum_{n=1}^N \frac{1}{\sum_{n=1}^N w_{y_n}} l_n, &
            \text{if reduction} = \text{'mean';}\\
            \sum_{n=1}^N l_n, &
            \text{if reduction} = \text{'sum'.}
        \end{cases}

    Parameters:
        input (Variable): Input tensor, the data type is float32, float64.
        label (Variable): Label tensor, the data type is int64_t.
        weight (Variable, optional): Weight tensor, a manual rescaling weight given
            to each class. If given, it has to be a Tensor of size `C`. Otherwise,
            it is treated as if having all ones. The data type is
            float32, float64. Default is ``None``.
        reduction (str, optional): Indicate how to reduce the loss,
            the candidates are ``'none'`` | ``'mean'`` | ``'sum'``.
            If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned;
            if :attr:`reduction` is ``'sum'``, the summed loss is returned;
            if :attr:`reduction` is ``'none'``, the unreduced loss is returned.
            Default is ``'mean'``.
        ignore_index (int64, optional): Specifies a target value that is ignored
            and does not contribute to the input gradient. Default is ``-100``.

    Returns:
        The tensor variable storing the nll_loss.

    Return type: Variable.

    Raises:
        ValueError: If ``reduction`` is not one of ``'sum'``, ``'mean'`` or
            ``'none'``, or if ``input`` has fewer than 2 dimensions. Both
            checks are performed when the layer is called, not at construction.

    Examples:

        .. code-block:: python

            # declarative mode
            import paddle.fluid as fluid
            import numpy as np
            import paddle

            input_np = np.random.random(size=(10, 10)).astype(np.float32)
            label_np = np.random.randint(0, 10, size=(10,)).astype(np.int64)
            prog = fluid.Program()
            startup_prog = fluid.Program()
            place = fluid.CPUPlace()
            with fluid.program_guard(prog, startup_prog):
                input = fluid.data(name='input', shape=[10, 10], dtype='float32')
                label = fluid.data(name='label', shape=[10], dtype='int64')
                nll_loss = paddle.nn.loss.NLLLoss()
                res = nll_loss(input, label)

                exe = fluid.Executor(place)
                static_result = exe.run(
                    prog,
                    feed={"input": input_np,
                          "label": label_np},
                    fetch_list=[res])
            print(static_result)

            # imperative mode
            import paddle.fluid.dygraph as dg
            with dg.guard(place) as g:
                input = dg.to_variable(input_np)
                label = dg.to_variable(label_np)
                output = nll_loss(input, label)
                print(output.numpy())
    """

    def __init__(self, weight=None, reduction='mean', ignore_index=-100):
        super(NLLLoss, self).__init__()
        # The arguments are only stored here; `reduction` is validated on
        # every call to forward().
        self.weight = weight
        self.reduction = reduction
        self.ignore_index = ignore_index

    def forward(self, input, label):
        """Append an ``nll_loss`` op for ``input``/``label`` and return its output.

        Raises:
            ValueError: If ``self.reduction`` is not ``'sum'``, ``'mean'`` or
                ``'none'``, or if ``input`` has fewer than 2 dimensions.
        """
        # The return value is not needed; the call is kept in case the helper
        # validates the input as a side effect -- TODO confirm and drop if not.
        self._helper.input_dtype(input)

        fluid.data_feeder.check_variable_and_dtype(
            input, 'input', ['float32', 'float64'], 'nll_loss')
        fluid.data_feeder.check_variable_and_dtype(label, 'label', ['int64'],
                                                   'nll_loss')

        if self.reduction not in ['sum', 'mean', 'none']:
            raise ValueError(
                "The value of 'reduction' in nll_loss should be 'sum', 'mean' or 'none', but "
                "received %s, which is not allowed." % self.reduction)

        x_shape = list(input.shape)
        x_dims = len(x_shape)
        # Check the rank BEFORE indexing into the shape: otherwise a 0-D or
        # 1-D input fails with a bare IndexError and this message is never
        # reachable.
        if x_dims < 2:
            raise ValueError('Expected 2 or more dimensions (got {})'.format(
                x_dims))
        n = x_shape[0]
        c = x_shape[1]

        # Inputs that are neither 2-D nor 4-D are folded into a 4-D layout
        # (n, c, 1, -1) before the op is appended; the label is reshaped to
        # match.
        needs_reshape = x_dims != 2 and x_dims != 4
        if needs_reshape:
            input = fluid.layers.reshape(input, shape=[n, c, 1, -1])
            label = fluid.layers.reshape(label, shape=[n, 1, -1])
            # Shape of the per-element loss in the caller's original layout
            # (the class axis removed).
            out_shape = [n] + x_shape[2:]

        inputs = {'X': input, 'Label': label}
        attrs = {'reduction': self.reduction, 'ignore_index': self.ignore_index}

        # NOTE(review): a weight that is not a framework Variable (including
        # None) is not passed to the op at all, i.e. it is silently ignored.
        if isinstance(self.weight, fluid.framework.Variable):
            inputs['Weight'] = self.weight

        out = self._helper.create_variable_for_type_inference(dtype=input.dtype)
        total_weight = self._helper.create_variable_for_type_inference(
            dtype=input.dtype)
        outputs = {'Out': out, 'Total_weight': total_weight}

        self._helper.append_op(
            type='nll_loss', inputs=inputs, outputs=outputs, attrs=attrs)

        # Only the unreduced loss is reshaped back to the original layout;
        # for 'mean'/'sum' the op output is returned as-is.
        if needs_reshape and self.reduction == 'none':
            out = fluid.layers.reshape(out, shape=out_shape)

        return out
|
|
|
|
|