commit 51f1148921
@@ -0,0 +1,114 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/modified_huber_loss_op.h"

namespace paddle {
namespace operators {

class ModifiedHuberLossOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(const framework::InferShapeContext& context) const override {
    PADDLE_ENFORCE_NOT_NULL(context.InputVar("X"), "X must be initialized.");
    PADDLE_ENFORCE_NOT_NULL(context.InputVar("Y"), "Y must be initialized.");

    auto* x = context.Input<Tensor>("X");
    auto* y = context.Input<Tensor>("Y");

    PADDLE_ENFORCE_EQ(x->dims(), y->dims(),
                      "The shape of X and Y must be the same.");
    PADDLE_ENFORCE_EQ(x->dims().size(), 2, "The tensor rank of X must be 2.");
    PADDLE_ENFORCE_EQ(x->dims()[1], 1, "The 2nd dimension of X must be 1.");

    context.Output<framework::LoDTensor>("IntermediateVal")->Resize(x->dims());
    context.Output<framework::LoDTensor>("Out")->Resize({x->dims()[0], 1});
  }
};

class ModifiedHuberLossOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  ModifiedHuberLossOpMaker(framework::OpProto* proto,
                           framework::OpAttrChecker* op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("X",
             "The input tensor of modified huber loss op. "
             "X is a 2-D tensor with shape [batch_size, 1].");
    AddInput("Y",
             "The target labels of modified huber loss op. "
             "The shape of Y is the same as X. Values of Y must be 0 or 1.");
    AddOutput("IntermediateVal",
              "Variable to save the intermediate result which will be reused "
              "in backward processing.")
        .AsIntermediate();
    AddOutput("Out", "Classification loss for X.");
    AddComment(R"DOC(
Modified Huber loss is used in binary classification problems. The shapes of
input X and target Y are both [N, 1], and so is the shape of the output loss.
Since target Y is not differentiable, calculating the gradient for Y is
illegal. The formulation of modified Huber loss is:

L(y, f(x)) = max(0, 1 - yf(x))^2  for yf(x) >= -1,
             -4yf(x)              otherwise.

Make sure the values of target label Y are in {0, 1} here. The operator will
scale values of Y to {-1, +1} when computing losses and gradients.
)DOC");
  }
};

class ModifiedHuberLossGradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(const framework::InferShapeContext& context) const override {
    auto* x = context.Input<Tensor>("X");
    auto* y = context.Input<Tensor>("Y");
    auto* intermediate_val = context.Input<Tensor>("IntermediateVal");
    auto* out_grad = context.Input<Tensor>(framework::GradVarName("Out"));
    auto* x_grad =
        context.Output<framework::LoDTensor>(framework::GradVarName("X"));

    PADDLE_ENFORCE_NOT_NULL(x, "X must be initialized.");
    PADDLE_ENFORCE_NOT_NULL(y, "Y must be initialized.");
    PADDLE_ENFORCE_NOT_NULL(intermediate_val,
                            "Intermediate value must not be null.");
    PADDLE_ENFORCE_NOT_NULL(out_grad, "Input(Out@Grad) must not be null.");

    PADDLE_ENFORCE_EQ(
        intermediate_val->dims(), x->dims(),
        "The shape of X and intermediate value must be the same.");
    PADDLE_ENFORCE_EQ(out_grad->dims(), x->dims(),
                      "The shape of Input(Out@Grad) and X must be the same.");

    if (x_grad) x_grad->Resize(x->dims());
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP(modified_huber_loss, ops::ModifiedHuberLossOp,
            ops::ModifiedHuberLossOpMaker, modified_huber_loss_grad,
            ops::ModifiedHuberLossGradOp);

REGISTER_OP_CPU_KERNEL(
    modified_huber_loss,
    ops::ModifiedHuberLossKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(modified_huber_loss_grad,
                       ops::ModifiedHuberLossGradCPUKernel<float>);
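For reference, the piecewise formulation in the DOC string above can be sanity-checked in a few lines of NumPy. This is a minimal illustrative sketch, not part of the patch; the function name and sample values are made up:

import numpy as np

def modified_huber_loss(x, y01):
    # illustrative sketch, not part of the patch
    # labels arrive in {0, 1}; the operator scales them to {-1, +1}
    z = x * (2 * y01 - 1)  # this is the saved "IntermediateVal"
    # quadratic hinge for z >= -1, linear branch for z < -1
    return np.where(z < -1, -4 * z, np.maximum(0.0, 1 - z) ** 2)

# z = -2 -> 8, z = 0.5 -> 0.25, z = 2 -> 0
print(modified_huber_loss(np.array([-2.0, 0.5, 2.0]), np.array([1, 1, 1])))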
@@ -0,0 +1,78 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <thrust/device_ptr.h>
#include <thrust/device_vector.h>
#include <thrust/for_each.h>
#include <thrust/tuple.h>
#include "paddle/framework/op_registry.h"
#include "paddle/operators/modified_huber_loss_op.h"
#include "paddle/platform/hostdevice.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

struct ModifiedHuberLossBackward {
  template <typename Tuple>
  HOSTDEVICE void operator()(Tuple t) const {
    auto inter_val = thrust::get<1>(t);
    auto y_val = thrust::get<2>(t);
    auto out_grad = thrust::get<3>(t);
    if (inter_val < -1) {
      thrust::get<0>(t) = -4 * (2 * y_val - 1) * out_grad;
    } else if (inter_val < 1) {
      thrust::get<0>(t) = -2 * (1 - inter_val) * (2 * y_val - 1) * out_grad;
    } else {
      thrust::get<0>(t) = 0;
    }
  }
};

template <typename T>
class ModifiedHuberLossGradGPUKernel : public framework::OpKernel {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* in0 = context.Input<Tensor>("Y");
    auto* in1 = context.Input<Tensor>("IntermediateVal");
    auto* in2 = context.Input<Tensor>(framework::GradVarName("Out"));
    auto* out0 = context.Output<Tensor>(framework::GradVarName("X"));

    if (out0) {
      auto counts = framework::product(in1->dims());
      auto y_ptr = thrust::device_pointer_cast(in0->data<T>());
      auto inter_val_ptr = thrust::device_pointer_cast(in1->data<T>());
      auto out_grad_ptr = thrust::device_pointer_cast(in2->data<T>());
      thrust::device_ptr<T> x_grad_ptr(
          out0->mutable_data<T>(context.GetPlace()));

      auto iter_begin = thrust::make_zip_iterator(
          thrust::make_tuple(x_grad_ptr, inter_val_ptr, y_ptr, out_grad_ptr));

      auto iter_end = thrust::make_zip_iterator(
          thrust::make_tuple(x_grad_ptr + counts, inter_val_ptr + counts,
                             y_ptr + counts, out_grad_ptr + counts));

      thrust::for_each(iter_begin, iter_end, ModifiedHuberLossBackward());
    }
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(
    modified_huber_loss,
    ops::ModifiedHuberLossKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(modified_huber_loss_grad,
                       ops::ModifiedHuberLossGradGPUKernel<float>);
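The three branches in ModifiedHuberLossBackward come from differentiating the loss with respect to x. With z = IntermediateVal = x * (2*y01 - 1), dL/dz is -4 for z < -1, -2(1 - z) for -1 <= z < 1, and 0 for z >= 1; the chain rule then multiplies by dz/dx = 2*y01 - 1 and by the incoming Out gradient. A minimal NumPy sketch of the same rule (illustrative only, names are not part of the patch):

import numpy as np

def modified_huber_loss_grad(y01, inter_val, out_grad):
    # illustrative sketch, not part of the patch
    # dL/dz, branch for branch with the Thrust functor above
    dldz = np.where(inter_val < -1, -4.0,
                    np.where(inter_val < 1, -2.0 * (1.0 - inter_val), 0.0))
    # chain rule: z = x * (2*y01 - 1), so dz/dx = 2*y01 - 1
    return dldz * (2 * y01 - 1) * out_grad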
@@ -0,0 +1,107 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
#include "paddle/platform/hostdevice.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
template <typename T, int MajorType = Eigen::RowMajor,
          typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;

template <typename T>
struct CheckLabelValue {
  HOSTDEVICE T operator()(const T& val) const {
    PADDLE_ASSERT(val == static_cast<T>(0) || val == static_cast<T>(1));
    return val;
  }
};

template <typename T>
struct ModifiedHuberLossForward {
  HOSTDEVICE T operator()(const T& val) const {
    if (val < -1) {
      return -4 * val;
    } else if (val < 1) {
      return (1 - val) * (1 - val);
    } else {
      return static_cast<T>(0);
    }
  }
};

template <typename Place, typename T>
class ModifiedHuberLossKernel : public framework::OpKernel {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* in0 = context.Input<Tensor>("X");
    auto* in1 = context.Input<Tensor>("Y");
    auto* out0 = context.Output<framework::LoDTensor>("IntermediateVal");
    auto* out1 = context.Output<framework::LoDTensor>("Out");

    out0->mutable_data<T>(context.GetPlace());
    out1->mutable_data<T>(context.GetPlace());
    auto place = context.GetEigenDevice<Place>();

    auto x = EigenVector<T>::Flatten(*in0);
    auto y = EigenVector<T>::Flatten(*in1);
    // make sure the values of Y are in {0, 1}
    y.unaryExpr(CheckLabelValue<T>());
    auto inter_val = EigenVector<T>::Flatten(*out0);
    // scale y to {-1, +1} and compute x * y
    inter_val.device(place) = x * (2 * y - static_cast<T>(1));
    auto loss = EigenVector<T>::Flatten(*out1);
    loss.device(place) = inter_val.unaryExpr(ModifiedHuberLossForward<T>());
  }
};

// CPU backward kernel
template <typename T>
class ModifiedHuberLossGradCPUKernel : public framework::OpKernel {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* in0 = context.Input<Tensor>("Y");
    auto* in1 = context.Input<framework::LoDTensor>("IntermediateVal");
    auto* in2 =
        context.Input<framework::LoDTensor>(framework::GradVarName("Out"));
    auto* out0 =
        context.Output<framework::LoDTensor>(framework::GradVarName("X"));

    if (out0) {
      const T* y_ptr = in0->data<T>();
      const T* inter_val_ptr = in1->data<T>();
      const T* out_grad_ptr = in2->data<T>();
      size_t counts = static_cast<size_t>(framework::product(in1->dims()));
      T* x_grad_ptr = out0->mutable_data<T>(context.GetPlace());
      for (size_t i = 0; i < counts; ++i) {
        if (inter_val_ptr[i] < -1) {
          x_grad_ptr[i] = -4 * (2 * y_ptr[i] - 1) * out_grad_ptr[i];
        } else if (inter_val_ptr[i] < 1) {
          x_grad_ptr[i] = -2 * (1 - inter_val_ptr[i]) * (2 * y_ptr[i] - 1) *
                          out_grad_ptr[i];
        } else {
          x_grad_ptr[i] = 0;
        }
      }
    }
  }
};

}  // namespace operators
}  // namespace paddle
@@ -0,0 +1,39 @@
import unittest
import numpy as np
from op_test import OpTest


def modified_huber_loss_forward(val):
    if val < -1:
        return -4 * val
    elif val < 1:
        return (1 - val) * (1 - val)
    else:
        return 0


class TestModifiedHuberLossOp(OpTest):
    def setUp(self):
        self.op_type = 'modified_huber_loss'
        samples_num = 32
        self.inputs = {
            'X': np.random.uniform(-1, 1., (samples_num, 1)).astype('float32'),
            'Y': np.random.choice([0, 1], samples_num).reshape((samples_num, 1))
        }
        product_res = self.inputs['X'] * (2 * self.inputs['Y'] - 1)
        loss = np.vectorize(modified_huber_loss_forward)(product_res)

        self.outputs = {
            'IntermediateVal': product_res,
            'Out': loss.reshape((samples_num, 1))
        }

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(['X'], 'Out', max_relative_error=0.005)


if __name__ == '__main__':
    unittest.main()