add sigmoid focal loss operator for supporting retinanet (#17895)
* test=develop add sigmoid_focal_loss for supporting retinanet
* test=develop add test_layers
* test=develop add API.spec
* test=develop alter sigmoid_focal_loss_op.cc
* test=develop alter detection.py
* test=develop alter API.spec
* test=develop alter round 1
* test=develop alter sigmoid_focal_loss
* test=develop alter sigmoid_focal_loss_op.cc
* test=develop alter test_layers.py
* test=develop alter paddle/fluid/API.spec
* test=develop alter sigmoid_focal_loss_op.cu
* test=develop alter paddle/fluid/operators/detection/sigmoid_focal_loss_op.cc
parent 9e4b9d9798
commit 0aee1f0074
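For context, the sketch below shows how the new operator is expected to be used from the Python API added alongside it (detection.py / API.spec, not reproduced here). The layer name fluid.layers.sigmoid_focal_loss and its parameter names are assumed from the op's inputs (X, Label, FgNum) and attributes (gamma, alpha) registered in this commit; treat it as an illustration, not the authoritative signature.

import paddle.fluid as fluid

num_classes = 80  # number of classes excluding background

# Logits [N, D], integer class labels [N, 1], and the foreground count [1].
x = fluid.layers.data(name='x', shape=[num_classes], dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int32')
fg_num = fluid.layers.data(
    name='fg_num', shape=[1], dtype='int32', append_batch_size=False)

# Assumed Python wrapper for the sigmoid_focal_loss op added in this commit.
loss = fluid.layers.sigmoid_focal_loss(
    x=x, label=label, fg_num=fg_num, gamma=2.0, alpha=0.25)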
@@ -0,0 +1,208 @@ paddle/fluid/operators/detection/sigmoid_focal_loss_op.cc
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/detection/sigmoid_focal_loss_op.h"
#include <memory>
#include <string>
#include <vector>

namespace paddle {
namespace operators {

using framework::Tensor;

class SigmoidFocalLossOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null.");
    PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should not be null.");
    PADDLE_ENFORCE(ctx->HasInput("FgNum"), "Input(FgNum) should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) should not be null.");

    auto x_dims = ctx->GetInputDim("X");
    auto labels_dims = ctx->GetInputDim("Label");
    auto fg_dims = ctx->GetInputDim("FgNum");

    int rank = x_dims.size();
    PADDLE_ENFORCE_EQ(rank, labels_dims.size(),
                      "Input(X) and Input(Label) shall have the same rank.");
    PADDLE_ENFORCE_EQ(fg_dims.size(), 1, "The rank of Input(FgNum) must be 1.");
    bool check = true;
    if ((!ctx->IsRuntime()) && (framework::product(x_dims) <= 0 ||
                                framework::product(labels_dims) <= 0)) {
      check = false;
    }

    if (check) {
      PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank - 1),
                        framework::slice_ddim(labels_dims, 0, rank - 1),
                        "Input(X) and Input(Label) shall have the same shape "
                        "except the last dimension.");
    }

    PADDLE_ENFORCE_EQ(labels_dims[rank - 1], 1UL,
                      "The last dimension of input(Label) should be 1.");

    ctx->ShareDim("X", /*->*/ "Out");
    ctx->ShareLoD("X", /*->*/ "Out");
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
                                   ctx.device_context());
  }
};

class SigmoidFocalLossGradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null.");
    PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should not be null.");
    PADDLE_ENFORCE(ctx->HasInput("FgNum"), "Input(FgNum) should not be null.");
    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
                   "Input(Out@GRAD) should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
                   "Output(X@GRAD) should not be null.");

    auto x_dims = ctx->GetInputDim("X");
    auto labels_dims = ctx->GetInputDim("Label");
    auto fg_dims = ctx->GetInputDim("FgNum");
    auto dout_dims = ctx->GetInputDim(framework::GradVarName("Out"));

    int rank = x_dims.size();
    PADDLE_ENFORCE_EQ(rank, labels_dims.size(),
                      "Input(X) and Input(Label) shall have the same rank.");
    PADDLE_ENFORCE_EQ(fg_dims.size(), 1, "The rank of Input(FgNum) must be 1.");
    bool check = true;
    if ((!ctx->IsRuntime()) && (framework::product(x_dims) <= 0 ||
                                framework::product(labels_dims) <= 0)) {
      check = false;
    }

    if (check) {
      PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank - 1),
                        framework::slice_ddim(labels_dims, 0, rank - 1),
                        "Input(X) and Input(Label) shall have the same shape.");

      PADDLE_ENFORCE_EQ(labels_dims[rank - 1], 1UL,
                        "The last dimension of input(Label) should be 1.");

      PADDLE_ENFORCE_EQ(
          framework::slice_ddim(x_dims, 0, rank),
          framework::slice_ddim(dout_dims, 0, rank),
          "Input(X) and Input(Out@Grad) shall have the same shape.");
    }

    ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
                                   ctx.device_context());
  }
};

class SigmoidFocalLossOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X",
             "(Tensor, default Tensor<float>), a 2-D tensor with shape [N, D], "
             "where N is the batch size and D is the number of classes "
             "(excluding background). This input is a tensor of logits "
             "computed by the previous operator.");
    AddInput("Label",
             "(Tensor, default Tensor<int>), a 2-D tensor with shape [N, 1]. "
             "This input is a tensor of ground-truth class labels.");
    AddInput("FgNum",
             "(Tensor, default Tensor<int>), a 1-D tensor with shape [1]. "
             "This input is the number of foreground examples.");
    AddOutput(
        "Out",
        "(Tensor, default Tensor<float>), a 2-D tensor with shape [N, D]. "
        "This output is the focal loss.");
    AddAttr<float>(
        "gamma",
        "Hyper-parameter of sigmoid focal loss op, which is to balance the "
        "easy and hard examples. "
        "A float scalar with default value 2.0.")
        .SetDefault(2.0);
    AddAttr<float>(
        "alpha",
        "Hyper-parameter of sigmoid focal loss op, which is to balance the "
        "positive and negative examples. "
        "A float scalar with default value 0.25.")
        .SetDefault(0.25);
AddComment(R"DOC(
|
||||
Sigmoid Focal Loss Operator.
|
||||
|
||||
Focal loss is used to address the foreground-background class imbalance existed
|
||||
on the training phase of one-stage detectors. This operator computes the sigmoid
|
||||
value for each element in the input tensor, after which focal loss is measured.
|
||||
|
||||
The focal loss is given as follows:
|
||||
|
||||
$$Loss_j = (-Label_j * alpha * \pow(1 - \sigma(X_j), gamma) * \log(\sigma(X_j)) -
|
||||
(1 - Labels_j) * (1 - alpha) * \pow(\sigma(X_j), gamma) * \log(1 - \sigma(X_j)))
|
||||
/ FgNum, j = 1,...,K$$
|
||||
|
||||
We know that $$\sigma(X_j) = \\frac{1}{1 + \exp(-X_j)}$$.
|
||||
|
||||
)DOC");
|
||||
}
|
||||
};

class SigmoidFocalLossGradOpDescMaker
    : public framework::SingleGradOpDescMaker {
 public:
  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;

 protected:
  std::unique_ptr<framework::OpDesc> Apply() const override {
    std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
    op->SetType("sigmoid_focal_loss_grad");
    op->SetInput("X", Input("X"));
    op->SetInput("Label", Input("Label"));
    op->SetInput("FgNum", Input("FgNum"));
    op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
    op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
    op->SetAttrMap(Attrs());
    return op;
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OPERATOR(sigmoid_focal_loss, ops::SigmoidFocalLossOp,
                  ops::SigmoidFocalLossOpMaker,
                  ops::SigmoidFocalLossGradOpDescMaker);
REGISTER_OPERATOR(sigmoid_focal_loss_grad, ops::SigmoidFocalLossGradOp);
REGISTER_OP_CPU_KERNEL(
    sigmoid_focal_loss,
    ops::SigmoidFocalLossKernel<paddle::platform::CPUDeviceContext, float>,
    ops::SigmoidFocalLossKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
    sigmoid_focal_loss_grad,
    ops::SigmoidFocalLossGradKernel<paddle::platform::CPUDeviceContext, float>,
    ops::SigmoidFocalLossGradKernel<paddle::platform::CPUDeviceContext,
                                    double>);
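As a reading aid for the formula in the DOC block above, the following NumPy sketch restates the forward computation performed by the kernels. It is an illustrative reference only: the variable names are ours and the FLT_MIN clamp is approximated with a small constant.

import numpy as np

def sigmoid_focal_loss_ref(x, label, fg_num, gamma=2.0, alpha=0.25):
    """x: logits [N, D]; label: int class ids [N, 1] (1..D, 0 = background,
    -1 = ignored); fg_num: scalar number of foreground examples."""
    n, d = x.shape
    p = 1.0 / (1.0 + np.exp(-x))                    # sigma(x)
    classes = np.arange(1, d + 1)                   # class id of each column
    c_pos = (label == classes).astype(x.dtype)      # positive indicator, [N, D]
    c_neg = ((label != -1) & (label != classes)).astype(x.dtype)
    fg = max(float(fg_num), 1.0)                    # same clamp as the kernels
    term_pos = (1.0 - p) ** gamma * np.log(np.maximum(p, 1e-38))
    term_neg = p ** gamma * np.log(np.maximum(1.0 - p, 1e-38))
    return (-c_pos * alpha * term_pos - c_neg * (1.0 - alpha) * term_neg) / fg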
@@ -0,0 +1,181 @@ paddle/fluid/operators/detection/sigmoid_focal_loss_op.cu
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "cub/cub.cuh"
#include "paddle/fluid/operators/detection/sigmoid_focal_loss_op.h"
#include "paddle/fluid/operators/math.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/hostdevice.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

static constexpr int kNumCUDAThreads = 512;
static constexpr int kNumMaxinumNumBlocks = 4096;

static inline int NumBlocks(const int N) {
  return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads,
                  kNumMaxinumNumBlocks);
}

#define CUDA_1D_KERNEL_LOOP(i, n)                              \
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
       i += blockDim.x * gridDim.x)

template <typename T>
__global__ void GPUSigmoidFocalLossForward(const T *x_data,
                                           const int *label_data,
                                           const int *fg_num_data,
                                           const T gamma, const T alpha,
                                           const int num_classes,
                                           const int limit, T *out_data) {
  CUDA_1D_KERNEL_LOOP(i, limit) {
    T x = x_data[i];
    int a = i / num_classes;  // current sample
    int d = i % num_classes;  // current class
    int g = label_data[a];    // target

    // check whether the input data is positive or negative
    // the target classes are in range 1-81
    // and the d is in range 0-80
    T c_pos = static_cast<T>(g == (d + 1));
    T c_neg = static_cast<T>((g != -1) & (g != (d + 1)));

    T fg_num = static_cast<T>((fg_num_data[0] > 1) ? fg_num_data[0] : 1);
    T s_neg = (1.0 - alpha) / fg_num;
    T s_pos = alpha / fg_num;

    // p = 1. / (1. + exp(-x))
    T p = 1. / (1. + real_exp(-x));

    // (1 - p)**gamma * log(p)
    T term_pos =
        std::pow((1. - p), gamma) * real_log(p > FLT_MIN ? p : FLT_MIN);
    // p**gamma * log(1 - p)
    T term_neg =
        std::pow(p, gamma) *
        (-1. * x * (x >= 0) - real_log(1. + real_exp(x - 2. * x * (x >= 0))));

    out_data[i] = 0.0;
    out_data[i] += -c_pos * term_pos * s_pos;
    out_data[i] += -c_neg * term_neg * s_neg;
  }
}

template <typename T>
__global__ void GPUSigmoidFocalLossBackward(
    const T *x_data, const int *label_data, const int *fg_num_data,
    const T gamma, const T alpha, const int num_classes, const T *dout_data,
    const int limit, T *dx_data) {
  CUDA_1D_KERNEL_LOOP(i, limit) {
    T x = x_data[i];
    T dout = dout_data[i];

    int a = i / num_classes;  // current sample
    int d = i % num_classes;  // current class

    T fg_num = static_cast<T>((fg_num_data[0] > 1) ? fg_num_data[0] : 1);
    T s_neg = (1.0 - alpha) / fg_num;
    T s_pos = alpha / fg_num;

    int g = label_data[a];
    T c_pos = static_cast<T>(g == (d + 1));
    T c_neg = static_cast<T>((g != -1) & (g != (d + 1)));

    T p = 1. / (1. + real_exp(-x));

    // (1 - p)**gamma * (1 - p - gamma * p * log(p))
    T term_pos = std::pow((1. - p), gamma) *
                 (1. - p - (p * gamma * real_log(p > FLT_MIN ? p : FLT_MIN)));
    // p**gamma * (gamma * (1 - p) * log(1 - p) - p)
    T term_neg =
        std::pow(p, gamma) *
        ((-1. * x * (x >= 0) - real_log(1. + real_exp(x - 2. * x * (x >= 0)))) *
             (1. - p) * gamma -
         p);

    dx_data[i] = 0.0;
    dx_data[i] += -c_pos * s_pos * term_pos;
    dx_data[i] += -c_neg * s_neg * term_neg;
    dx_data[i] = dx_data[i] * dout;
  }
}

template <typename DeviceContext, typename T>
class GPUSigmoidFocalLossKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &context) const override {
    const Tensor *X = context.Input<Tensor>("X");
    const Tensor *Labels = context.Input<Tensor>("Label");
    const Tensor *FgNum = context.Input<Tensor>("FgNum");
    Tensor *Out = context.Output<Tensor>("Out");
    T gamma = static_cast<T>(context.Attr<float>("gamma"));
    T alpha = static_cast<T>(context.Attr<float>("alpha"));
    auto x_dims = X->dims();
    int num_classes = static_cast<int>(x_dims[1]);
    auto out_data = Out->mutable_data<T>(context.GetPlace());

    auto &dev_ctx = context.cuda_device_context();

    int limit = Out->numel();
    int blocks = NumBlocks(limit);
    int threads = kNumCUDAThreads;
    GPUSigmoidFocalLossForward<T><<<blocks, threads, 0, dev_ctx.stream()>>>(
        X->data<T>(), Labels->data<int>(), FgNum->data<int>(), gamma, alpha,
        num_classes, limit, out_data);
  }
};

template <typename DeviceContext, typename T>
class GPUSigmoidFocalLossGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &context) const override {
    const Tensor *X = context.Input<Tensor>("X");
    const Tensor *Labels = context.Input<Tensor>("Label");
    const Tensor *FgNum = context.Input<Tensor>("FgNum");
    const Tensor *dOut = context.Input<Tensor>(framework::GradVarName("Out"));
    Tensor *dX = context.Output<Tensor>(framework::GradVarName("X"));
    auto dx_data = dX->mutable_data<T>(context.GetPlace());
    T gamma = static_cast<T>(context.Attr<float>("gamma"));
    T alpha = static_cast<T>(context.Attr<float>("alpha"));
    auto x_dims = X->dims();
    int num_classes = static_cast<int>(x_dims[1]);

    auto &dev_ctx = context.cuda_device_context();

    int limit = dX->numel();
    int blocks = NumBlocks(limit);
    int threads = kNumCUDAThreads;
    GPUSigmoidFocalLossBackward<T><<<blocks, threads, 0, dev_ctx.stream()>>>(
        X->data<T>(), Labels->data<int>(), FgNum->data<int>(), gamma, alpha,
        num_classes, dOut->data<T>(), limit, dx_data);
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
    sigmoid_focal_loss,
    ops::GPUSigmoidFocalLossKernel<paddle::platform::CUDADeviceContext, float>,
    ops::GPUSigmoidFocalLossKernel<paddle::platform::CUDADeviceContext,
                                   double>);
REGISTER_OP_CUDA_KERNEL(
    sigmoid_focal_loss_grad,
    ops::GPUSigmoidFocalLossGradKernel<paddle::platform::CUDADeviceContext,
                                       float>,
    ops::GPUSigmoidFocalLossGradKernel<paddle::platform::CUDADeviceContext,
                                       double>);
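The kernels above launch fixed 512-thread blocks and cap the grid at 4096 blocks; the grid-stride loop in CUDA_1D_KERNEL_LOOP then covers any elements beyond blocks * threads. The launch arithmetic, as a small Python check:

kNumCUDAThreads = 512
kNumMaxinumNumBlocks = 4096  # constant name kept as in the source

def num_blocks(n):
    # ceil(n / 512), capped at 4096; the grid-stride loop picks up the remainder
    return min((n + kNumCUDAThreads - 1) // kNumCUDAThreads, kNumMaxinumNumBlocks)

assert num_blocks(2000) == 4           # e.g. N=200, D=10 gives limit = 2000 elements
assert num_blocks(3_000_000) == 4096   # capped; each thread strides over 1-2 items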
@@ -0,0 +1,128 @@ paddle/fluid/operators/detection/sigmoid_focal_loss_op.h
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <algorithm>
#include <limits>
#include "paddle/fluid/framework/op_registry.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

template <typename DeviceContext, typename T>
class SigmoidFocalLossKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &context) const override {
    const Tensor *X = context.Input<Tensor>("X");
    const Tensor *Labels = context.Input<Tensor>("Label");
    const Tensor *FgNum = context.Input<Tensor>("FgNum");
    Tensor *Out = context.Output<Tensor>("Out");
    T gamma = static_cast<T>(context.Attr<float>("gamma"));
    T alpha = static_cast<T>(context.Attr<float>("alpha"));
    auto out_data = Out->mutable_data<T>(context.GetPlace());
    int limit = Out->numel();
    auto x_data = X->data<T>();
    auto label_data = Labels->data<int>();
    auto fg_num_data = FgNum->data<int>();
    auto x_dims = X->dims();
    int num_classes = static_cast<int>(x_dims[1]);

    for (int idx = 0; idx < limit; ++idx) {
      T x = x_data[idx];
      int a = idx / num_classes;  // current sample
      int d = idx % num_classes;  // current class
      int g = label_data[a];      // target

      // Check whether the input data is positive or negative
      // The target classes are in range 1-81
      // and the d is in range 0-80
      T c_pos = static_cast<T>(g == (d + 1));
      T c_neg = static_cast<T>((g != -1) & (g != (d + 1)));
      T fg_num = static_cast<T>((fg_num_data[0] > 1) ? fg_num_data[0] : 1);
      T s_neg = (1.0 - alpha) / fg_num;
      T s_pos = alpha / fg_num;

      // p = 1. / (1. + exp(-x))
      T p = 1. / (1. + std::exp(-x));

      // (1 - p)**gamma * log(p)
      T term_pos =
          std::pow((1. - p), gamma) * std::log(p > FLT_MIN ? p : FLT_MIN);
      // p**gamma * log(1 - p)
      T term_neg =
          std::pow(p, gamma) *
          (-1. * x * (x >= 0) - std::log(1. + std::exp(x - 2. * x * (x >= 0))));
      out_data[idx] = 0.0;
      out_data[idx] += -c_pos * term_pos * s_pos;
      out_data[idx] += -c_neg * term_neg * s_neg;
    }
  }
};

template <typename DeviceContext, typename T>
class SigmoidFocalLossGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &context) const override {
    const Tensor *X = context.Input<Tensor>("X");
    const Tensor *Labels = context.Input<Tensor>("Label");
    const Tensor *FgNum = context.Input<Tensor>("FgNum");
    const Tensor *dOut = context.Input<Tensor>(framework::GradVarName("Out"));
    Tensor *dX = context.Output<Tensor>(framework::GradVarName("X"));
    auto dx_data = dX->mutable_data<T>(context.GetPlace());
    T gamma = static_cast<T>(context.Attr<float>("gamma"));
    T alpha = static_cast<T>(context.Attr<float>("alpha"));
    auto x_dims = X->dims();
    int num_classes = static_cast<int>(x_dims[1]);

    int limit = dX->numel();
    auto x_data = X->data<T>();
    auto label_data = Labels->data<int>();
    auto fg_num_data = FgNum->data<int>();
    auto dout_data = dOut->data<T>();
    for (int idx = 0; idx < limit; ++idx) {
      T x = x_data[idx];
      int a = idx / num_classes;  // current sample
      int d = idx % num_classes;  // current class

      T fg_num = static_cast<T>((fg_num_data[0] > 1) ? fg_num_data[0] : 1);
      T s_neg = static_cast<T>((1.0 - alpha) / fg_num);
      T s_pos = alpha / fg_num;
      int g = label_data[a];

      T c_pos = static_cast<T>(g == (d + 1));
      T c_neg = static_cast<T>((g != -1) & (g != (d + 1)));
      T p = 1. / (1. + std::exp(-x));

      // (1 - p)**gamma * (1 - p - gamma * p * log(p))
      T term_pos = std::pow((1. - p), gamma) *
                   (1. - p - (p * gamma * std::log(p > FLT_MIN ? p : FLT_MIN)));
      // p**gamma * (gamma * (1 - p) * log(1 - p) - p)
      T term_neg = std::pow(p, gamma) *
                   ((-1. * x * (x >= 0) -
                     std::log(1. + std::exp(x - 2. * x * (x >= 0)))) *
                        (1. - p) * gamma -
                    p);

      dx_data[idx] = 0.0;
      dx_data[idx] += -c_pos * s_pos * term_pos;
      dx_data[idx] += -c_neg * s_neg * term_neg;
      dx_data[idx] = dx_data[idx] * dout_data[idx];
    }
  }
};

}  // namespace operators
}  // namespace paddle
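Both the CPU and GPU kernels obtain log(1 - sigma(x)) from the expression -x * (x >= 0) - log(1 + exp(x - 2 * x * (x >= 0))), which keeps the exp() argument non-positive so it cannot overflow for large positive x. A short Python check of that identity (illustrative only):

import numpy as np

def log1m_sigmoid_stable(x):
    # Same rewrite as term_neg in the kernels; equivalent to -log(1 + exp(x)).
    ind = (x >= 0).astype(x.dtype)
    return -x * ind - np.log(1.0 + np.exp(x - 2.0 * x * ind))

x = np.array([-50.0, -1.0, 0.0, 1.0, 50.0])
naive = np.log(1.0 - 1.0 / (1.0 + np.exp(-x)))  # the x = 50 entry collapses to -inf
stable = log1m_sigmoid_stable(x)                # finite everywhere
print(naive)
print(stable)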
@@ -0,0 +1,132 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest
import numpy as np
import math
import copy
from op_test import OpTest
from paddle.fluid import core


def sigmoid_focal_loss_forward(x_data, label_data, fg_num_data, gamma, alpha,
                               num_classes):
    x_data_t = copy.deepcopy(x_data)
    out_data = copy.deepcopy(x_data)
    x_width = len(x_data)
    x_height = len(x_data[0, :])
    x_data_t = x_data_t.flatten()
    out_data = out_data.flatten()
    for idx in range(len(x_data_t)):
        x = x_data_t[idx]
        a = int(idx / num_classes)
        d = int(idx % num_classes)
        label = label_data[a]
        c_pos = float((int(label) == int(d + 1)))
        c_neg = float(((int(label) != -1) & (int(label) != (d + 1))))
        fg_num = max(fg_num_data, 1)
        z_neg = (1.0 - alpha) / fg_num
        z_pos = alpha / fg_num

        p = 1. / (1. + math.exp(-x))
        FLT_MIN = 1.175494351e-38
        term_pos = math.pow((1. - p), gamma) * math.log(max(FLT_MIN, p))
        term_neg = math.pow(p, gamma) * (
            -1. * x * (x >= 0) - math.log(1. + math.exp(x - 2. * x * (x >= 0))))
        out_data[idx] = 0.0
        out_data[idx] += -c_pos * term_pos * z_pos
        out_data[idx] += -c_neg * term_neg * z_neg

    out_data = out_data.reshape(x_width, x_height)
    return out_data


class TestSigmoidFocalLossOp1(OpTest):
    def set_argument(self):
        self.num_anchors = 10
        self.num_classes = 10
        self.gamma = 2.0
        self.alpha = 0.25

    def setUp(self):
        self.set_argument()

        dims = (self.num_anchors, self.num_classes)
        X = np.random.standard_normal(dims).astype("float32")
        L = np.random.randint(0, self.num_classes + 1,
                              (dims[0], 1)).astype("int32")
        F = np.zeros(1)
        F[0] = len(np.where(L > 0)[0])
        F = F.astype("int32")

        self.op_type = "sigmoid_focal_loss"
        self.inputs = {
            'X': X,
            'Label': L,
            'FgNum': F,
        }
        self.attrs = {
            'gamma': self.gamma,
            'alpha': self.alpha,
        }
        loss = sigmoid_focal_loss_forward(
            self.inputs['X'], self.inputs['Label'], self.inputs['FgNum'],
            self.gamma, self.alpha, self.num_classes)
        self.outputs = {'Out': loss.astype('float32')}

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(['X'], 'Out')


@unittest.skipIf(not core.is_compiled_with_cuda(),
                 "core is not compiled with CUDA")
class TestSigmoidFocalLossOp2(TestSigmoidFocalLossOp1):
    def test_check_output(self):
        place = core.CUDAPlace(0)
        self.check_output_with_place(place, atol=2e-3)

    def test_check_grad(self):
        place = core.CUDAPlace(0)
        self.check_grad_with_place(
            place, ['X'], 'Out', max_relative_error=0.002)


class TestSigmoidFocalLossOp3(TestSigmoidFocalLossOp1):
    def set_argument(self):
        self.num_anchors = 200
        self.num_classes = 10
        self.gamma = 1.0
        self.alpha = 0.5


@unittest.skipIf(not core.is_compiled_with_cuda(),
                 "core is not compiled with CUDA")
class TestSigmoidFocalLossOp4(TestSigmoidFocalLossOp3):
    def test_check_output(self):
        place = core.CUDAPlace(0)
        self.check_output_with_place(place, atol=2e-3)

    def test_check_grad(self):
        place = core.CUDAPlace(0)
        self.check_grad_with_place(
            place, ['X'], 'Out', max_relative_error=0.002)


if __name__ == '__main__':
    unittest.main()