parent 0be349496f
commit 1c81d57938

paddle/operators/huber_loss_op.cc
@@ -0,0 +1,108 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/huber_loss_op.h"

namespace paddle {
namespace operators {

class HuberLossOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(const framework::InferShapeContext& ctx) const override {
    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "X must be initialized.");
    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Y"), "Y must be initialized.");

    auto* x = ctx.Input<Tensor>("X");
    auto* y = ctx.Input<Tensor>("Y");

    PADDLE_ENFORCE_EQ(x->dims(), y->dims(),
                      "Dimensions of X and Y must be the same.");
    // We constrain the shape of X to (N, 1); this may be relaxed to
    // (N, x, ...) if needed.
    PADDLE_ENFORCE_EQ(framework::arity(x->dims()), 2,
                      "Tensor rank of X must be 2.");
    PADDLE_ENFORCE_EQ(x->dims()[1], 1, "Second dimension of X must be 1.");

    ctx.Output<Tensor>("residual")->Resize(x->dims());
    ctx.Output<Tensor>("Out")->Resize({x->dims()[0], 1});
  }
};

template <typename AttrType>
class HuberLossOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  HuberLossOpMaker(framework::OpProto* proto,
                   framework::OpAttrChecker* op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("X", "The input value of HuberLossOp.");
    AddInput("Y", "The target value of HuberLossOp.");
    AddOutput("residual",
              "Intermediate tensor that saves the residual between Y and X. "
              "It is reused in the backward pass.")
        .AsIntermediate();
    AddOutput("Out", "The Huber loss between the input and the target.");
    AddAttr<AttrType>("delta", "The hyperparameter of the Huber loss.");
    AddComment(R"DOC(
Huber loss is a loss function used in robust regression. We constrain the
shape of the input to (N, 1). The formulation is:

L_delta(y, f(x)) =
    0.5 * (y - f(x))^2,                  if |y - f(x)| <= delta
    delta * (|y - f(x)| - 0.5 * delta),  otherwise
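
For example, with delta = 1, a residual of 0.5 gives a loss of
0.5 * 0.5^2 = 0.125, while a residual of 2 gives a loss of
1 * (2 - 0.5) = 1.5, so large residuals are penalized only linearly.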

)DOC");
  }
};

class HuberLossGradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(const framework::InferShapeContext& ctx) const override {
    auto* x = ctx.Input<Tensor>("X");
    auto* y = ctx.Input<Tensor>("Y");
    auto* residual = ctx.Input<Tensor>("residual");
    auto* out_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
    auto* x_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
    auto* y_grad = ctx.Output<Tensor>(framework::GradVarName("Y"));

    PADDLE_ENFORCE_NOT_NULL(x, "Input X must not be null.");
    PADDLE_ENFORCE_NOT_NULL(y, "Target Y must not be null.");
    PADDLE_ENFORCE_NOT_NULL(residual, "Residual value must not be null.");
    PADDLE_ENFORCE_NOT_NULL(out_grad, "Out gradient must not be null.");

    PADDLE_ENFORCE_EQ(residual->dims(), x->dims(),
                      "Dimensions of X and the residual must be the same.");
    PADDLE_ENFORCE_EQ(
        out_grad->dims(), x->dims(),
        "Dimensions of the Out gradient and X must be the same (N x 1).");

    if (x_grad) x_grad->Resize(x->dims());
    if (y_grad) y_grad->Resize(y->dims());
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP(huber_loss, ops::HuberLossOp, ops::HuberLossOpMaker<float>,
            huber_loss_grad, ops::HuberLossGradOp);
REGISTER_OP_CPU_KERNEL(huber_loss,
                       ops::HuberLossKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
    huber_loss_grad,
    ops::HuberLossGradKernel<paddle::platform::CPUPlace, float>);
paddle/operators/huber_loss_op.cu
@@ -0,0 +1,23 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#define EIGEN_USE_GPU
#include "paddle/operators/huber_loss_op.h"

namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(huber_loss,
                       ops::HuberLossKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(
    huber_loss_grad,
    ops::HuberLossGradKernel<paddle::platform::GPUPlace, float>);
paddle/operators/huber_loss_op.h
@@ -0,0 +1,120 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
#include "paddle/platform/hostdevice.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
template <typename T, int MajorType = Eigen::RowMajor,
          typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
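
// Elementwise functor for the forward pass: given the residual r = y - x,
// it returns 0.5 * r^2 when |r| <= delta and delta * (|r| - 0.5 * delta)
// otherwise, so the loss is quadratic near zero and linear in the tails.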
template <typename T>
struct HuberLossForward {
  HOSTDEVICE HuberLossForward(const T& delta) : delta(delta) {}

  HOSTDEVICE T operator()(const T& val) const {
    T abs_val = std::abs(val);
    if (abs_val <= delta) {
      return 0.5 * val * val;
    } else {
      return delta * (abs_val - 0.5 * delta);
    }
  }

  T delta;
};

template <typename Place, typename T, typename AttrType = T>
class HuberLossKernel : public framework::OpKernel {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* in0 = context.Input<Tensor>("X");
    auto* in1 = context.Input<Tensor>("Y");
    auto* out0 = context.Output<Tensor>("residual");
    auto* out1 = context.Output<Tensor>("Out");
    auto delta = static_cast<T>(context.op().Attr<AttrType>("delta"));
    auto place = context.GetEigenDevice<Place>();

    auto x = EigenVector<T>::Flatten(*in0);
    auto y = EigenVector<T>::Flatten(*in1);
    out0->mutable_data<T>(context.GetPlace());
    auto residual = EigenVector<T>::Flatten(*out0);
    // Cache the residual y - x so the backward pass can reuse it.
    residual.device(place) = y - x;
    out1->mutable_data<T>(context.GetPlace());
    auto loss = EigenVector<T>::Flatten(*out1);
    loss.device(place) = residual.unaryExpr(HuberLossForward<T>(delta));
  }
};

template <typename T>
struct HuberLossBackward {
  HOSTDEVICE HuberLossBackward(const T& delta, bool is_x)
      : is_x(is_x), delta(delta) {}

  HOSTDEVICE T operator()(const T& val) const {
    T sign = is_x ? -1.0 : 1.0;
    T abs_val = std::abs(val);
    if (abs_val <= delta) {
      return sign * val;
    } else {
      return val > 0 ? sign * delta : -sign * delta;
    }
  }

  bool is_x;
  T delta;
};
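
// A note on the gradient: with r = y - x, dL/dr equals r for |r| <= delta
// and delta * sign(r) otherwise. Since dr/dx = -1 and dr/dy = +1,
// HuberLossBackward flips the sign when is_x is true, and the grad kernel
// below scales the result by the incoming gradient of Out.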
template <typename Place, typename T, typename AttrType = T>
class HuberLossGradKernel : public framework::OpKernel {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* in0 = context.Input<Tensor>("residual");
    auto* in1 = context.Input<Tensor>(framework::GradVarName("Out"));
    auto* out0 = context.Output<Tensor>(framework::GradVarName("X"));
    auto* out1 = context.Output<Tensor>(framework::GradVarName("Y"));
    auto delta = static_cast<T>(context.op().Attr<AttrType>("delta"));
    auto place = context.GetEigenDevice<Place>();

    auto residual = EigenVector<T>::Flatten(*in0);
    auto out_grad = EigenVector<T>::Flatten(*in1);

    if (out0) {
      out0->mutable_data<T>(context.GetPlace());
      auto x_grad = EigenVector<T>::Flatten(*out0);
      x_grad.device(place) =
          out_grad * residual.unaryExpr(HuberLossBackward<T>(delta, true));
    }

    if (out1) {
      out1->mutable_data<T>(context.GetPlace());
      auto y_grad = EigenVector<T>::Flatten(*out1);
      y_grad.device(place) =
          out_grad * residual.unaryExpr(HuberLossBackward<T>(delta, false));
    }
  }
};

}  // namespace operators
}  // namespace paddle
test_huber_loss_op.py
@@ -0,0 +1,56 @@
import unittest
from op_test_util import OpTestMeta
from gradient_checker import GradientChecker, create_op
from paddle.v2.framework.op import Operator
import numpy as np


def huber_loss_forward(val, delta):
    abs_val = abs(val)
    if abs_val <= delta:
        return 0.5 * val * val
    else:
        return delta * (abs_val - 0.5 * delta)
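
# For delta = 1: huber_loss_forward(0.5, 1.0) == 0.125 and
# huber_loss_forward(2.0, 1.0) == 1.5, matching the formula in the op's doc.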


class TestHuberLossOp(unittest.TestCase):
    __metaclass__ = OpTestMeta

    def setUp(self):
        self.type = 'huber_loss'
        samples_num = 64
        delta = 1.0
        self.inputs = {
            'X': np.random.uniform(0, 1., (samples_num, 1)).astype('float32'),
            'Y': np.random.uniform(0, 1., (samples_num, 1)).astype('float32'),
        }
        residual = self.inputs['Y'] - self.inputs['X']
        loss = np.vectorize(huber_loss_forward)(residual, delta)
        self.attrs = {'delta': delta}
        self.outputs = {
            'residual': residual,
            'Out': loss.reshape((samples_num, 1))
        }


class TestHuberLossGradOp(GradientChecker):
    def test_huber_loss(self):
        samples_num = 10
        delta = 1.0
        inputs = {
            'X': np.random.uniform(-1, 1, (samples_num, 1)).astype('float32'),
            'Y': np.random.uniform(-1, 1, (samples_num, 1)).astype('float32')
        }
        op = Operator(
            "huber_loss",
            X='X',
            Y='Y',
            residual='residual',
            delta=delta,
            Out='Out')
        # 'residual' is an intermediate output, so it is excluded from the
        # gradient check.
        self.compare_grad(op, inputs, no_grad_set=set(['residual']))
        self.check_grad(op, inputs, set(["X", "Y"]), "Out")


if __name__ == '__main__':
    unittest.main()