commit c4c5f0b8ca
@@ -0,0 +1,145 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/bpr_loss_op.h"
namespace paddle {
namespace operators {

class BprLossOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null.");
    PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("Y"), "Output(Y) should not be null.");

    auto x_dims = ctx->GetInputDim("X");
    auto label_dims = ctx->GetInputDim("Label");
    int rank = x_dims.size();
    PADDLE_ENFORCE_EQ(rank, label_dims.size(),
                      "Input(X) and Input(Label) shall have the same rank.");
    PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank - 1),
                      framework::slice_ddim(label_dims, 0, rank - 1),
                      "Input(X) and Input(Label) shall have the same shape "
                      "except the last dimension.");

    auto y_dims = x_dims;
    y_dims[rank - 1] = 1;
    ctx->SetOutputDim("Y", y_dims);
    ctx->ShareLoD("X", /*->*/ "Y");
  }

 protected:
  // Explicitly set that the data type of the computation kernel of Seq-bpr
  // is determined by its input "X".
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(
        framework::ToDataType(ctx.Input<Tensor>("X")->type()),
        platform::CPUPlace());
  }
};

class BprLossGradientOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null.");
    PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should not be null.");
    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Y")),
                   "Input(Y@GRAD) should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
                   "Output(X@GRAD) should not be null.");

    auto x_dims = ctx->GetInputDim("X");
    auto label_dims = ctx->GetInputDim("Label");
    auto dy_dims = ctx->GetInputDim(framework::GradVarName("Y"));
    int rank = x_dims.size();
    PADDLE_ENFORCE_EQ(dy_dims.size(), rank,
                      "Input(Y@Grad) and Input(X) should have the same rank.");
    PADDLE_ENFORCE_EQ(label_dims.size(), rank,
                      "Input(Label) and Input(X) should have the same rank.");
    PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank - 1),
                      framework::slice_ddim(label_dims, 0, rank - 1),
                      "The Input(X) and Input(Label) should have the same "
                      "shape except the last dimension.");
    PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank - 1),
                      framework::slice_ddim(dy_dims, 0, rank - 1),
                      "The Input(X) and Input(Y@Grad) should have the same "
                      "shape except the last dimension.");
    PADDLE_ENFORCE_EQ(dy_dims[rank - 1], 1,
                      "The last dimension of Input(Y@Grad) should be 1.");
    PADDLE_ENFORCE_EQ(label_dims[rank - 1], 1,
                      "The last dimension of Input(Label) should be 1.");
    ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
    ctx->ShareLoD("X", framework::GradVarName("X"));
  }

 protected:
  // Explicitly set that the data type of the computation kernel of Seq-bpr
  // is determined by its input "X".
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(
        framework::ToDataType(ctx.Input<Tensor>("X")->type()),
        platform::CPUPlace());
  }
};

class BprLossOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X",
             "(Tensor, default Tensor<float>), a tensor whose last dimension "
             "size is equal to the number of classes. This input is a "
             "real number.");
    AddInput(
        "Label",
        "(Tensor), the tensor which represents the ground truth. It has the "
        "same shape as 'X' except the last dimension. The last dimension "
        "size is 1.");
    AddOutput("Y",
              "(Tensor, default Tensor<float>), a tensor whose shape is the "
              "same as 'X' except that the last dimension size is 1. It "
              "represents the sequence bpr loss.");
    AddComment(R"DOC(
Bayesian Personalized Ranking Loss Operator.

This operator is a pairwise ranking loss. Label is the desired (positive) item.
The loss at a given point in one session is defined as:

$Y[i] = -\frac{1}{N_{i}} \sum_{j \ne Label[i]}\log(\sigma(X[i, Label[i]]-X[i, j]))$

where $N_{i}$ is the number of negative items, i.e. the number of classes minus one.

For more details, see the paper "Session-based Recommendations with Recurrent
Neural Networks" (https://arxiv.org/abs/1511.06939).

)DOC");
  }
};
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
using CPUCtx = paddle::platform::CPUDeviceContext;

REGISTER_OPERATOR(bpr_loss, ops::BprLossOp, ops::BprLossOpMaker,
                  paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(bpr_loss_grad, ops::BprLossGradientOp);
REGISTER_OP_CPU_KERNEL(bpr_loss, ops::BprLossOpKernel<CPUCtx, float>,
                       ops::BprLossOpKernel<CPUCtx, double>);
REGISTER_OP_CPU_KERNEL(bpr_loss_grad,
                       ops::BprLossGradientOpKernel<CPUCtx, float>,
                       ops::BprLossGradientOpKernel<CPUCtx, double>);
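For reference, the loss computed by the registered CPU kernel (BprLossOpKernel in bpr_loss_op.h below) can be reproduced in a few lines of NumPy. This is only an illustrative sketch of the formula from the DOC string above; the helper name bpr_loss_numpy is invented for this note and is not part of the Paddle API.

import numpy as np


def bpr_loss_numpy(x, label):
    """Sketch of the bpr_loss forward pass.

    x:     float array of shape [batch_size, class_num] (item scores).
    label: int64 array of shape [batch_size, 1] (positive item per row).
    Returns an array of shape [batch_size, 1] with one loss value per row.
    """
    batch_size, class_num = x.shape
    y = np.zeros((batch_size, 1), dtype=x.dtype)
    for i in range(batch_size):
        pos = int(label[i, 0])
        neg = [j for j in range(class_num) if j != pos]
        # log(sigmoid(x_pos - x_neg)), accumulated over all negative items.
        log_sig = -np.log(1.0 + np.exp(x[i, neg] - x[i, pos]))
        y[i, 0] = -np.sum(log_sig) / (class_num - 1)
    return y

The unit test at the end of this change computes the same reference values with explicit loops.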
@@ -0,0 +1,118 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/for_range.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
/* TODO: Find a way to adapt TolerableValue, using blas or eigen. */
// Clamps +/-INFINITY to +/-1e20 so that downstream arithmetic stays finite.
template <typename T>
struct TolerableValue {
  HOSTDEVICE T operator()(const T& x) const {
    PADDLE_ASSERT(std::is_floating_point<T>::value);
    const T kApproInf = 1e20;
    if (x == INFINITY) return kApproInf;
    if (x == -INFINITY) return -kApproInf;
    return x;
  }
};

template <typename DeviceContext, typename T>
class BprLossOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* x = ctx.Input<Tensor>("X");
    auto* label = ctx.Input<Tensor>("Label");
    auto* y = ctx.Output<Tensor>("Y");
    y->mutable_data<T>(ctx.GetPlace());
    int rank = x->dims().size();

    // Collapse all leading dimensions so the loss is computed row by row.
    Tensor x_2d = framework::ReshapeToMatrix(*x, rank - 1);
    Tensor labels_2d = framework::ReshapeToMatrix(*label, rank - 1);
    Tensor y_2d = framework::ReshapeToMatrix(*y, rank - 1);

    const framework::Tensor* logits = &x_2d;
    const framework::Tensor* labels = &labels_2d;
    framework::Tensor* out = &y_2d;

    const int step_size = logits->dims()[0];
    const int class_num = logits->dims()[1];
    const T* logits_data = logits->data<T>();
    T* loss_data = out->data<T>();

    const int64_t* label_data = labels->data<int64_t>();
    for (int i = 0; i < step_size; ++i) {
      int lbl_pos = label_data[i];
      PADDLE_ENFORCE_GE(lbl_pos, 0);
      PADDLE_ENFORCE_LT(lbl_pos, class_num);
      int index_pos = i * class_num + lbl_pos;
      T sum = static_cast<T>(0);
      // Accumulate log(sigmoid(x_pos - x_neg)) over every negative class j.
      for (int j = 0; j < class_num; j++) {
        if (j == lbl_pos) continue;
        int index_neg = i * class_num + j;
        sum += TolerableValue<T>()(-std::log(
            1.0f + TolerableValue<T>()(std::exp(logits_data[index_neg] -
                                                logits_data[index_pos]))));
      }
      // Average the negated sum over the (class_num - 1) negative classes.
      loss_data[i] = -sum / (class_num - 1);
    }
  }
};

template <typename DeviceContext, typename T>
class BprLossGradientOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* x = ctx.Input<Tensor>("X");
    auto* dy = ctx.Input<Tensor>(framework::GradVarName("Y"));
    auto* label = ctx.Input<Tensor>("Label");
    auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));

    const int step_size = x->dims()[0];
    const int num_classes = x->dims()[1];
    T* dx_data = dx->mutable_data<T>(ctx.GetPlace());
    const T* dy_data = dy->data<T>();
    const T* x_data = x->data<T>();
    const int64_t* label_data = label->data<int64_t>();

    for (size_t sample_id = 0; sample_id < step_size; sample_id++) {
      // Clear this sample's row of the gradient before accumulating into it.
      for (size_t x_offset = sample_id * num_classes;
           x_offset < (sample_id + 1) * num_classes; x_offset++) {
        dx_data[x_offset] = static_cast<T>(0);
      }
      auto p_index = sample_id * num_classes + label_data[sample_id];
      for (size_t ni = 0; ni < num_classes; ni++) {
        if (label_data[sample_id] == ni) continue;
        auto n_index = sample_id * num_classes + ni;
        // grad_ = -dY[i] * sigmoid(x_neg - x_pos) / (num_classes - 1); it is
        // added to the positive column and subtracted from the negative one.
        auto grad_ = -dy_data[sample_id] /
                     ((num_classes - 1) *
                      (1.0f + TolerableValue<T>()(std::exp(x_data[p_index] -
                                                           x_data[n_index]))));
        dx_data[p_index] += grad_;
        dx_data[n_index] -= grad_;
      }
    }
  }
};

}  // namespace operators
}  // namespace paddle
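The backward kernel above applies the derivative of the per-row loss: for each negative item j it forms grad = -dY[i] * sigmoid(x_j - x_pos) / (num_classes - 1), adds it to the positive column and subtracts it from column j. A matching NumPy sketch follows (bpr_loss_grad_numpy is again an invented name for illustration only, not a Paddle API).

import numpy as np


def bpr_loss_grad_numpy(x, label, dy):
    """Sketch of the bpr_loss backward pass.

    x:     [batch_size, class_num] scores, as fed to the forward pass.
    label: [batch_size, 1] int64 positive item per row.
    dy:    [batch_size, 1] upstream gradient of the per-row loss Y.
    Returns dx with the same shape as x.
    """
    batch_size, class_num = x.shape
    dx = np.zeros_like(x)
    for i in range(batch_size):
        pos = int(label[i, 0])
        for j in range(class_num):
            if j == pos:
                continue
            # -dY[i] * sigmoid(x_j - x_pos) / (class_num - 1)
            g = -dy[i, 0] / ((class_num - 1) * (1.0 + np.exp(x[i, pos] - x[i, j])))
            dx[i, pos] += g
            dx[i, j] -= g
    return dx

A finite-difference comparison of this sketch against the forward sketch is a cheap sanity check of the formula; the unit test below relies on OpTest.check_grad for the same purpose.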
@@ -0,0 +1,52 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest
import numpy as np
from op_test import OpTest, randomize_probability


class TestBprLossOp1(OpTest):
    """Test BprLoss with discrete one-hot labels.
    """

    def setUp(self):
        self.op_type = "bpr_loss"
        batch_size = 40
        class_num = 5
        X = randomize_probability(batch_size, class_num, dtype='float64')
        label = np.random.randint(0, class_num, (batch_size, 1), dtype="int64")
        bpr_loss_result = []
        for i in range(batch_size):
            sum = 0.0
            for j in range(class_num):
                if j == label[i][0]:
                    continue
                sum += (-np.log(1.0 + np.exp(X[i][j] - X[i][label[i][0]])))
            bpr_loss_result.append(-sum / (class_num - 1))
        bpr_loss = np.asmatrix([[x] for x in bpr_loss_result], dtype="float64")
        self.inputs = {"X": X, "Label": label}
        self.outputs = {"Y": bpr_loss}

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(["X"], "Y", numeric_grad_delta=0.001)


if __name__ == "__main__":
    unittest.main()
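Finally, a hedged sketch of how the new operator might be used from Python. It assumes a convenience wrapper fluid.layers.bpr_loss(input, label) is exposed for the operator registered above (such a wrapper may land in a separate change); everything else is standard fluid usage, and the shapes follow the unit test.

import numpy as np
import paddle.fluid as fluid

# Assumption: fluid.layers.bpr_loss wraps the bpr_loss operator added above.
scores = fluid.layers.data(name='item_scores', shape=[5], dtype='float64')
pos_item = fluid.layers.data(name='pos_item', shape=[1], dtype='int64')
loss = fluid.layers.bpr_loss(input=scores, label=pos_item)
avg_loss = fluid.layers.mean(loss)

exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())
out, = exe.run(
    feed={
        'item_scores': np.random.rand(4, 5).astype('float64'),
        'pos_item': np.random.randint(0, 5, (4, 1)).astype('int64'),
    },
    fetch_list=[avg_loss])
print(out)  # mean BPR loss over the batch of 4 sessions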