parent c1feb27f75
commit 2763f3e32f
@ -0,0 +1,119 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/smooth_l1_loss_op.h"

namespace paddle {
namespace operators {

class SmoothL1LossOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(const framework::InferShapeContext& ctx) const override {
    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"),
                            "Input of SmoothL1LossOp must be initialized.");
    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Y"),
                            "Target of SmoothL1LossOp must be initialized.");

    auto* x = ctx.Input<framework::Tensor>("X");
    auto* y = ctx.Input<framework::Tensor>("Y");
    PADDLE_ENFORCE_EQ(x->dims(), y->dims(),
                      "Dimensions of SmoothL1LossOp's input and target "
                      "must be the same.");
    PADDLE_ENFORCE_GE(framework::arity(x->dims()), 2,
                      "Tensor rank of SmoothL1LossOp's input must be "
                      "at least 2.");
    auto* inside_weight = ctx.Input<framework::Tensor>("InsideWeight");
    if (inside_weight) {
      auto* outside_weight = ctx.Input<framework::Tensor>("OutsideWeight");
      PADDLE_ENFORCE_NOT_NULL(outside_weight,
                              "If weights are provided, must specify both "
                              "inside and outside weights.");
      PADDLE_ENFORCE_EQ(inside_weight->dims(), x->dims(),
                        "Dimensions of inside weight must be the same as the input.");
      PADDLE_ENFORCE_EQ(
          outside_weight->dims(), x->dims(),
          "Dimensions of outside weight must be the same as the input.");
    }

    auto* diff = ctx.Output<framework::Tensor>("diff");
    auto* out = ctx.Output<framework::Tensor>("Out");
    diff->Resize(x->dims());
    // loss is a rank-2 tensor of shape [batch_size, 1]
    out->Resize({x->dims()[0], 1});
  }
};

template <typename AttrType>
class SmoothL1LossOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  SmoothL1LossOpMaker(framework::OpProto* proto,
                      framework::OpAttrChecker* op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("X", "Input of SmoothL1LossOp.");
    AddInput("Y", "Target of SmoothL1LossOp.");
    AddInput("InsideWeight", "Optional input to scale (X - Y).");
    AddInput("OutsideWeight", "Optional input to scale the smooth l1 loss.");
    AddOutput("diff", "Intermediate variable to cache InsideWeight * (X - Y).")
        .AsIntermediate();
    AddOutput("Out", "Final smooth l1 loss of inputs.");
    AddComment(R"DOC(
Compute SmoothL1Loss for input and target.

The equation is:
    Out = 0.5 * (sigma * (X - Y))^2    if abs(X - Y) < 1 / sigma^2
    Out = abs(X - Y) - 0.5 / sigma^2   otherwise
)DOC");
    AddAttr<AttrType>("sigma", "Hyper parameter, default value is 3.0.")
        .SetDefault(3.0);
  }
};

class SmoothL1LossGradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(const framework::InferShapeContext& ctx) const override {
    auto in_dims = ctx.Input<framework::Tensor>("X")->dims();
    auto out_dims =
        ctx.Input<framework::Tensor>(framework::GradVarName("Out"))->dims();
    auto* x_grad = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
    auto* y_grad = ctx.Output<framework::Tensor>(framework::GradVarName("Y"));

    PADDLE_ENFORCE_GE(framework::arity(out_dims), 2,
                      "Tensor rank of the output gradient should be 2.");
    PADDLE_ENFORCE_EQ(out_dims[0], in_dims[0],
                      "First dimension of the output gradient must be "
                      "the same as the input.");
    PADDLE_ENFORCE_EQ(out_dims[1], 1,
                      "Second dimension of the output gradient must be 1.");

    if (x_grad) x_grad->Resize(in_dims);
    if (y_grad) y_grad->Resize(in_dims);
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP(smooth_l1_loss, ops::SmoothL1LossOp,
            ops::SmoothL1LossOpMaker<float>, ops::SmoothL1LossGradOp);
REGISTER_OP_CPU_KERNEL(
    smooth_l1_loss, ops::SmoothL1LossKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
    smooth_l1_loss_grad,
    ops::SmoothL1LossGradKernel<paddle::platform::CPUPlace, float>);
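For reference, the forward computation documented in the DOC comment above can be written as a minimal NumPy sketch. This assumes 2-D inputs and mirrors the CPU/GPU kernel defined in the header further below; the helper name smooth_l1_loss is only illustrative and is not part of Paddle's public API.

# Minimal NumPy sketch of the forward pass (assumed 2-D inputs, illustrative only).
import numpy as np

def smooth_l1_loss(x, y, sigma=3.0, inside_weight=None, outside_weight=None):
    sigma2 = sigma * sigma
    diff = x - y
    if inside_weight is not None:
        diff = diff * inside_weight          # this is what the op caches as 'diff'
    abs_diff = np.abs(diff)
    errors = np.where(abs_diff < 1.0 / sigma2,
                      0.5 * diff * diff * sigma2,
                      abs_diff - 0.5 / sigma2)
    if outside_weight is not None:
        errors = errors * outside_weight
    return errors.sum(axis=1, keepdims=True)  # one loss value per sample, shape (N, 1)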
@ -0,0 +1,24 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#define EIGEN_USE_GPU

#include "paddle/operators/smooth_l1_loss_op.h"

namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(
    smooth_l1_loss, ops::SmoothL1LossKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(
    smooth_l1_loss_grad,
    ops::SmoothL1LossGradKernel<paddle::platform::GPUPlace, float>);
@ -0,0 +1,184 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
template <typename T, int MajorType = Eigen::RowMajor,
          typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
template <typename T, int MajorType = Eigen::RowMajor,
          typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;

template <typename T>
struct SmoothL1LossForward {
  __host__ __device__ SmoothL1LossForward(const T& sigma2) : sigma2(sigma2) {}

  __host__ __device__ T operator()(const T& val) const {
    T abs_val = std::abs(val);
    if (abs_val < 1.0 / sigma2) {
      return 0.5 * val * val * sigma2;
    } else {
      return abs_val - 0.5 / sigma2;
    }
  }

  T sigma2;
};

template <typename Place, typename T, typename AttrType = T>
class SmoothL1LossKernel : public framework::OpKernel {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* in0 = context.Input<Tensor>("X");
    auto* in1 = context.Input<Tensor>("Y");
    auto* in2 = context.Input<Tensor>("InsideWeight");
    auto* in3 = context.Input<Tensor>("OutsideWeight");
    auto* out0 = context.Output<Tensor>("diff");
    auto* out1 = context.Output<Tensor>("Out");

    out0->mutable_data<T>(context.GetPlace());
    out1->mutable_data<T>(context.GetPlace());
    auto place = context.GetEigenDevice<Place>();

    auto sigma = static_cast<T>(context.op_.GetAttr<AttrType>("sigma"));
    T sigma2 = sigma * sigma;
    bool has_weight = (in2 != nullptr) && (in3 != nullptr);

    auto x = EigenVector<T>::Flatten(*in0);
    auto y = EigenVector<T>::Flatten(*in1);
    auto diff = EigenVector<T>::Flatten(*out0);

    diff.device(place) = x - y;
    // multiply inside weight
    if (has_weight) {
      auto inside_weight = EigenVector<T>::Flatten(*in2);
      // cache diff, reused in backward
      diff.device(place) = diff * inside_weight;
    }

    auto in_counts = framework::product(in0->dims());
    Tensor paddle_errors;
    paddle_errors.mutable_data<T>({static_cast<int>(in_counts)},
                                  context.GetPlace());
    auto errors = EigenVector<T>::Flatten(paddle_errors);
    // apply smooth l1 forward
    errors.device(place) = diff.unaryExpr(SmoothL1LossForward<T>(sigma2));

    // multiply outside weight
    if (has_weight) {
      auto outside_weight = EigenVector<T>::Flatten(*in3);
      errors.device(place) = errors * outside_weight;
    }
    auto loss = EigenMatrix<T>::From(*out1, {in0->dims()[0], 1});
    // first dimension of 'X' is the number of samples
    auto errors_mat_view = EigenMatrix<T>::From(paddle_errors, in0->dims());
    loss.device(place) = errors_mat_view.sum(Eigen::array<int, 1>({1}));
  }
};

template <typename T>
struct SmoothL1LossBackward {
  __host__ __device__ SmoothL1LossBackward(const T& sigma2) : sigma2(sigma2) {}

  __host__ __device__ T operator()(const T& val) const {
    T abs_val = std::abs(val);
    if (abs_val < 1.0 / sigma2) {
      return sigma2 * val;
    } else {
      return (0 < val) - (val < 0);
    }
  }

  T sigma2;
};

template <typename Place, typename T, typename AttrType = T>
class SmoothL1LossGradKernel : public framework::OpKernel {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* in0 = context.Input<Tensor>("InsideWeight");
    auto* in1 = context.Input<Tensor>("OutsideWeight");
    auto* in2 = context.Input<Tensor>("diff");
    auto* og = context.Input<Tensor>(framework::GradVarName("Out"));
    auto sigma = static_cast<T>(context.op_.GetAttr<AttrType>("sigma"));
    T sigma2 = sigma * sigma;
    bool has_weight = (in0 != nullptr) && (in1 != nullptr);

    auto place = context.GetEigenDevice<Place>();

    auto in_dims = in2->dims();
    auto counts = framework::product(in_dims);
    auto cols = counts / in_dims[0];
    auto mat_dims = framework::make_ddim(
        {static_cast<int>(in_dims[0]), static_cast<int>(cols)});

    Tensor paddle_diff;
    paddle_diff.mutable_data<T>({static_cast<int>(counts)}, context.GetPlace());
    auto diff = EigenVector<T>::Flatten(paddle_diff);
    // apply smooth l1 backward
    diff.device(place) = EigenVector<T>::Flatten(*in2).unaryExpr(
        SmoothL1LossBackward<T>(sigma2));

    auto* out0 = context.Output<Tensor>(framework::GradVarName("X"));
    auto* out1 = context.Output<Tensor>(framework::GradVarName("Y"));

    // compute weights
    Tensor paddle_weights;
    paddle_weights.mutable_data<T>(mat_dims, context.GetPlace());
    auto weights = EigenMatrix<T>::From(paddle_weights);
    // initialize to 1.0
    if (platform::is_cpu_place(context.GetPlace())) {
      weights.setConstant(static_cast<T>(1.0));
    } else {
      Tensor paddle_cpu_weights;
      paddle_cpu_weights.mutable_data<T>(mat_dims, platform::CPUPlace());
      EigenMatrix<T>::From(paddle_cpu_weights).setConstant(static_cast<T>(1.0));
      paddle_weights.CopyFrom<T>(paddle_cpu_weights, context.GetPlace());
    }
    if (has_weight) {
      auto inside_weight = EigenMatrix<T>::From(*in0, mat_dims);
      auto outside_weight = EigenMatrix<T>::From(*in1, mat_dims);
      weights.device(place) = inside_weight * outside_weight;
    }

    // compute gradients
    auto out_grad = EigenMatrix<T>::From(*og);
    auto diff_mat_view = EigenMatrix<T>::From(paddle_diff, mat_dims);
    auto gradients =
        out_grad.broadcast(Eigen::array<int, 2>({1, static_cast<int>(cols)})) *
        weights * diff_mat_view;

    if (out0) {
      out0->mutable_data<T>(context.GetPlace());
      auto x_grad = EigenMatrix<T>::From(*out0, mat_dims);
      x_grad.device(place) = gradients;
    }

    if (out1) {
      out1->mutable_data<T>(context.GetPlace());
      auto y_grad = EigenMatrix<T>::From(*out1, mat_dims);
      y_grad.device(place) = -1 * gradients;
    }
  }
};

}  // namespace operators
}  // namespace paddle
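For reference, the gradient kernel above reduces to the following minimal NumPy sketch in the 2-D case. Names are illustrative only; 'diff' is the cached InsideWeight * (X - Y) produced by the forward pass, and 'out_grad' is the gradient of "Out" with shape (N, 1).

# Minimal NumPy sketch of the backward pass (assumed 2-D inputs, illustrative only).
import numpy as np

def smooth_l1_loss_grad(diff, out_grad, sigma=3.0,
                        inside_weight=None, outside_weight=None):
    sigma2 = sigma * sigma
    # derivative of the piecewise loss w.r.t. the cached diff
    d = np.where(np.abs(diff) < 1.0 / sigma2, sigma2 * diff, np.sign(diff))
    weights = np.ones_like(diff)
    if inside_weight is not None and outside_weight is not None:
        weights = inside_weight * outside_weight
    gradients = out_grad * weights * d   # out_grad (N, 1) broadcasts over columns
    return gradients, -gradients         # (dX, dY)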
@ -0,0 +1,106 @@
import unittest
from op_test_util import OpTestMeta
from gradient_checker import GradientChecker, create_op
import functools
import numpy as np
from paddle.v2.framework.op import Operator


def smooth_l1_loss_forward(val, sigma2):
    abs_val = abs(val)
    if abs_val < 1.0 / sigma2:
        return 0.5 * val * val * sigma2
    else:
        return abs_val - 0.5 / sigma2


class TestSmoothL1LossOp_f0(unittest.TestCase):
    __metaclass__ = OpTestMeta

    def setUp(self):
        self.type = "smooth_l1_loss"
        dims = (32, 64)
        self.inputs = {
            'X': np.random.random(dims).astype("float32"),
            'Y': np.random.random(dims).astype("float32")
        }
        sigma = 3.0
        self.attrs = {'sigma': sigma}
        sigma2 = sigma * sigma
        diff = self.inputs['X'] - self.inputs['Y']
        loss = np.vectorize(smooth_l1_loss_forward)(diff, sigma2).sum(1)
        loss = loss.reshape((dims[0], 1))
        self.outputs = {'diff': diff, 'Out': loss}


class TestSmoothL1LossOp_f1(unittest.TestCase):
    __metaclass__ = OpTestMeta

    def setUp(self):
        self.type = "smooth_l1_loss"
        dims = (32, 64)
        self.inputs = {
            'X': np.random.random(dims).astype("float32"),
            'Y': np.random.random(dims).astype("float32"),
            'InsideWeight': np.random.random(dims).astype("float32"),
            'OutsideWeight': np.random.random(dims).astype("float32")
        }
        sigma = 3.0
        self.attrs = {'sigma': sigma}
        sigma2 = sigma * sigma
        diff = self.inputs['X'] - self.inputs['Y']
        diff = diff * self.inputs['InsideWeight']
        loss = np.vectorize(smooth_l1_loss_forward)(diff, sigma2)
        loss = loss * self.inputs['OutsideWeight']
        loss = loss.sum(1).reshape((dims[0], 1))
        self.outputs = {'diff': diff, 'Out': loss}


class SmoothL1LossGradOpTest(GradientChecker):
    def test_smooth_l1_loss_b0(self):
        dims = (5, 7)
        X = np.random.random(dims).astype("float32")
        Y = np.random.random(dims).astype("float32")
        InsideWeight = np.random.random(dims).astype("float32")
        OutsideWeight = np.random.random(dims).astype("float32")
        inputs = {
            'X': X,
            'Y': Y,
            'InsideWeight': InsideWeight,
            'OutsideWeight': OutsideWeight
        }
        op = Operator(
            "smooth_l1_loss",
            X='X',
            Y='Y',
            InsideWeight='InsideWeight',
            OutsideWeight='OutsideWeight',
            diff="diff",
            Out="Out",
            sigma=3.0)
        self.compare_grad(
            op, inputs, no_grad_set=set(['InsideWeight', 'OutsideWeight']))
        self.check_grad(
            op, inputs, set(["X", "Y"]), "Out", max_relative_error=0.08)

    def test_smooth_l1_loss_b1(self):
        dims = (5, 7)
        X = np.random.random(dims).astype("float32")
        Y = np.random.random(dims).astype("float32")
        inputs = {'X': X, 'Y': Y}
        op = Operator(
            "smooth_l1_loss",
            X='X',
            Y='Y',
            InsideWeight='InsideWeight',
            OutsideWeight='OutsideWeight',
            diff="diff",
            Out="Out",
            sigma=3.0)
        self.compare_grad(
            op, inputs, no_grad_set=set(['InsideWeight', 'OutsideWeight']))
        self.check_grad(op, inputs, set(["X", "Y"]), "Out")


if __name__ == '__main__':
    unittest.main()