teacher student sigmoid loss

revert-15207-remove_op_handle_lock_and_fix_var
heqiaozhi 6 years ago
parent a79a3ea2f0
commit 4f6e9e3ac3

File diff suppressed because it is too large

@@ -0,0 +1,25 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

}  // namespace operators
}  // namespace paddle

@@ -176,6 +176,7 @@ __all__ = [
'get_tensor_from_selected_rows',
'lstm',
'psroi_pool',
'teacher_student_sigmoid_loss',
]
kIgnoreIndex = -100
@@ -9184,6 +9185,47 @@ def log_loss(input, label, epsilon=1e-4, name=None):
    return loss


def teacher_student_sigmoid_loss(input,
                                 label,
                                 soft_max_up_bound=15.0,
                                 soft_max_lower_bound=-15.0):
    """
    **Teacher Student Log Loss Layer**

    This layer accepts predicted values and target labels and returns the
    teacher_student loss.

    .. math::

        loss = max(x, 0) - x * z + log(1 + exp(-abs(x))) + max(x, 0) - x * z' + log(1 + exp(-abs(x)))

    Args:
        input (Variable|list): a 2-D tensor with shape [N x 1], where N is the
                               batch size. This input is a probability computed
                               by the previous operator.
        label (Variable|list): the ground truth, a 2-D tensor with shape
                               [N x 1], where N is the batch size.
        soft_max_up_bound (float): if input > soft_max_up_bound, the input is
                                   clamped to this bound.
        soft_max_lower_bound (float): if input < soft_max_lower_bound, the input
                                      is clamped to this bound.

    Returns:
        Variable: A 2-D tensor with shape [N x 1], the teacher_student_sigmoid_loss.

    Examples:
        .. code-block:: python

            cost = fluid.layers.teacher_student_sigmoid_loss(input=similarity, label=label)
    """
    helper = LayerHelper('teacher_student_sigmoid_loss', **locals())
    out = helper.create_variable(dtype=input.dtype)
    helper.append_op(
        type='teacher_student_sigmoid_loss',
        inputs={'X': [input],
                'Label': [label]},
        outputs={'Y': [out]},
        attrs={"soft_max_lower_bound": float(soft_max_lower_bound),
               "soft_max_up_bound": float(soft_max_up_bound)})
    return out


def add_position_encoding(input, alpha, beta, name=None):
    """
    **Add Position Encoding Layer**

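The formula in the docstring above is the numerically stable form of binary cross-entropy with logits, applied twice: once against the hard 0/1 target z and once against the teacher score z'. A minimal NumPy sketch of that formula (the helper names are illustrative and not part of this commit):

    import numpy as np

    def sigmoid_ce_with_logits(x, z):
        # stable evaluation of max(x, 0) - x * z + log(1 + exp(-|x|))
        return np.maximum(x, 0.0) - x * z + np.log1p(np.exp(-np.abs(x)))

    def teacher_student_loss(x, z, z_teacher):
        # hard-target term plus teacher (soft-target) term, as in the docstring
        return sigmoid_ce_with_logits(x, z) + sigmoid_ce_with_logits(x, z_teacher)

For a raw logit x and targets in [0, 1], each term equals -z*log(sigmoid(x)) - (1 - z)*log(1 - sigmoid(x)), which is why the same expression appears twice in the loss.
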
@@ -0,0 +1,70 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
from math import log
from math import exp
from op_test import OpTest
from scipy.special import logit
from scipy.special import expit
import unittest


class TestTeacherStudentSigmoidLossOp(OpTest):
    """
    Test teacher_student_sigmoid_loss with discrete one-hot labels.
    """

    def setUp(self):
        """
        Build the op inputs and compute the expected loss per sample.
        """
        self.op_type = "teacher_student_sigmoid_loss"
        batch_size = 16
        num_classes = 1
        self.inputs = {
            'X': logit(
                np.random.uniform(0, 1, (batch_size, num_classes))
                .astype("float32")),
            'Label': np.random.uniform(0, 2, (batch_size, num_classes))
            .astype("float32")
        }
        outs = []
        for index, label in enumerate(self.inputs["Label"]):
            x = self.inputs["X"][index]
            if label < -1.0:
                outs.append(max(x, 0.0) + log(1.0 + exp(-abs(x))))
            elif label < 0.0:
                outs.append(max(x, 0.0) - x + log(1.0 + exp(-abs(x))))
            elif label < 1.0:
                outs.append(max(x, 0.0) + log(1.0 + exp(-abs(x))) + \
                            max(x, 0.0) - x * label + log(1.0 + exp(-abs(x))))
                #print "33 python x:", x, "python label:", label, "term1:", max(x, 0.0) + log(1.0 + exp(-abs(x))), "term2:", max(x, 0.0) - x * label + log(1.0 + exp(-abs(x)))
            else:
                outs.append(max(x, 0.0) - x + log(1.0 + exp(-abs(x))) + \
                            max(x, 0.0) - x * (label - 1.0) + log(1.0 + exp(-abs(x))))
                #print "44 python x:", x, "python label:", label, "term1:", max(x, 0.0) - x + log(1.0 + exp(-abs(x))), "term2:", max(x, 0.0) - x * (label - 1.0) + log(1.0 + exp(-abs(x)))
        self.outputs = {'Y': np.array(outs)}

    def test_check_output(self):
        """
        Compare the op output against the expected loss values.
        """
        self.check_output()

    def test_check_grad(self):
        """
        Check the analytic gradient of X against a numeric gradient.
        """
        self.check_grad(["X"], "Y", numeric_grad_delta=0.005)


if __name__ == "__main__":
    unittest.main()
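The branches in setUp imply a label encoding in which floor(label) selects the hard 0/1 target (even bucket -> 0, odd bucket -> 1) and, for label >= 0, the fractional part carries the teacher score. A vectorized restatement of the expected-output computation under that inferred encoding (the helper name is illustrative, not part of this commit):

    import numpy as np

    def expected_loss(x, label):
        def ce(z):
            # stable binary cross-entropy with logits
            return np.maximum(x, 0.0) - x * z + np.log1p(np.exp(-np.abs(x)))

        hard = np.floor(label) % 2      # [-2,-1) -> 0, [-1,0) -> 1, [0,1) -> 0, [1,2) -> 1
        soft = label - np.floor(label)  # teacher score, used only when label >= 0
        return np.where(label >= 0.0, ce(hard) + ce(soft), ce(hard))

Since the test draws labels from uniform(0, 2), only the last two branches of the loop are exercised, and this vectorized form reproduces the per-sample loop exactly for that range.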