Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into add_merge_model_scripts
commit ebf606a2a0
@@ -0,0 +1,85 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/auc_op.h"

namespace paddle {
namespace operators {

class AucOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(framework::InferShapeContextBase *ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("Inference"),
                   "Input of Inference must be initialized.");
    PADDLE_ENFORCE(ctx->HasInput("Label"),
                   "Input of Label must be initialized.");
    auto inference_dim = ctx->GetInputDim("Inference");
    auto label_dim = ctx->GetInputDim("Label");

    PADDLE_ENFORCE_EQ(inference_dim, label_dim,
                      "inference and label should have the same shape");

    ctx->SetOutputDim("AUC", {1});
    ctx->ShareLoD("Inference", /*->*/ "AUC");
  }
};

class AucOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  AucOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("Inference",
             "A floating point tensor of arbitrary shape whose values "
             "are in the range [0, 1].");
    AddInput("Label",
             "A tensor whose shape matches "
             "Inference. Will be cast to bool.");
    // TODO(typhoonzero): support weight input
    AddOutput("AUC",
              "A scalar representing the "
              "current area-under-curve.");

    AddAttr<std::string>("curve", "Curve type, can be 'ROC' or 'PR'.")
        .SetDefault("ROC");
    AddAttr<int>("num_thresholds",
                 "The number of thresholds to use when discretizing the"
                 " ROC curve.")
        .SetDefault(200);

    AddComment(
        R"DOC(Computes the AUC according to the forward output and label.
Best used for binary classification evaluation.

If the input label contains values other than 0 and 1, it will be cast
to bool.

You can find the definitions here:
https://en.wikipedia.org/wiki/Receiver_operating_characteristic#Area_under_the_curve

Possible curves are:
- ROC: Receiver operating characteristic
- PR: Precision Recall
)DOC");
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(auc, ops::AucOp, ops::AucOpMaker);
REGISTER_OP_CPU_KERNEL(auc, ops::AucKernel<paddle::platform::CPUPlace, float>);
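For orientation before the kernel definition that follows, here is a minimal NumPy sketch of the same computation AucKernel performs: evenly spaced thresholds padded by a small epsilon, per-threshold confusion counts, then a trapezoidal Riemann sum over the (FPR, TPR) points. The function name discretized_roc_auc and its defaults are illustrative only, not part of the operator's API.

import numpy as np

def discretized_roc_auc(pred, labels, num_thresholds=200, eps=1e-6):
    # Thresholds mirror the kernel: interior points i / (num_thresholds - 1),
    # padded with 0 - kepsilon and 1 + kepsilon at the ends.
    kepsilon = 1e-7
    thresholds = [(i + 1) * 1.0 / (num_thresholds - 1)
                  for i in range(num_thresholds - 2)]
    thresholds = [0.0 - kepsilon] + thresholds + [1.0 + kepsilon]

    labels = labels.astype(bool)
    tpr = np.empty(num_thresholds)
    fpr = np.empty(num_thresholds)
    for k, t in enumerate(thresholds):
        positive = pred >= t
        tp = np.sum(positive & labels)
        fp = np.sum(positive & ~labels)
        fn = np.sum(~positive & labels)
        tn = np.sum(~positive & ~labels)
        tpr[k] = (tp + eps) / (tp + fn + eps)
        fpr[k] = fp / (fp + tn + eps)

    # Trapezoid rule; FPR decreases as the threshold rises.
    dx = fpr[:-1] - fpr[1:]
    y = (tpr[:-1] + tpr[1:]) / 2.0
    return float(np.sum(dx * y))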
@@ -0,0 +1,135 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <string>
#include <vector>
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

template <typename T, int MajorType = Eigen::RowMajor,
          typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;

template <typename Place, typename T>
class AucKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* inference = ctx.Input<Tensor>("Inference");
    auto* label = ctx.Input<Tensor>("Label");
    auto* auc = ctx.Output<Tensor>("AUC");

    float* auc_data = auc->mutable_data<float>(ctx.GetPlace());

    std::string curve = ctx.Attr<std::string>("curve");
    int num_thresholds = ctx.Attr<int>("num_thresholds");
    // Evenly spaced thresholds, sized up front (reserve alone would leave
    // the vector empty and make operator[] below undefined behavior).
    std::vector<float> thresholds_list(num_thresholds);
    for (int i = 1; i < num_thresholds - 1; i++) {
      thresholds_list[i] = static_cast<float>(i) / (num_thresholds - 1);
    }
    // Pad the ends so every sample falls strictly on one side.
    const float kEpsilon = 1e-7;
    thresholds_list[0] = 0.0f - kEpsilon;
    thresholds_list[num_thresholds - 1] = 1.0f + kEpsilon;

    size_t num_samples = inference->numel();

    const T* inference_data = inference->data<T>();
    Tensor label_casted;
    label_casted.Resize(label->dims());
    bool* label_casted_data = label_casted.mutable_data<bool>(ctx.GetPlace());

    const int* label_data = label->data<int>();
    // cast label_data to bool
    for (size_t i = 0; i < num_samples; i++) {
      label_casted_data[i] = static_cast<bool>(label_data[i]);
    }

    // Create local tensors for storing the curve: TP, FN, TN, FP
    // TODO(typhoonzero): use eigen op to calculate these values.
    Tensor true_positive, false_positive, true_negative, false_negative;

    true_positive.Resize({num_thresholds});
    false_negative.Resize({num_thresholds});
    true_negative.Resize({num_thresholds});
    false_positive.Resize({num_thresholds});

    int* tp_data = true_positive.mutable_data<int>(ctx.GetPlace());
    int* fn_data = false_negative.mutable_data<int>(ctx.GetPlace());
    int* tn_data = true_negative.mutable_data<int>(ctx.GetPlace());
    int* fp_data = false_positive.mutable_data<int>(ctx.GetPlace());

    for (int idx_thresh = 0; idx_thresh < num_thresholds; idx_thresh++) {
      // calculate TP, FN, TN, FP for the current threshold
      int tp = 0, fn = 0, tn = 0, fp = 0;
      for (size_t i = 0; i < num_samples; i++) {
        if (label_casted_data[i]) {
          if (inference_data[i] >= thresholds_list[idx_thresh]) {
            tp++;
          } else {
            fn++;
          }
        } else {
          if (inference_data[i] >= thresholds_list[idx_thresh]) {
            fp++;
          } else {
            tn++;
          }
        }
      }
      // store counts
      tp_data[idx_thresh] = tp;
      fn_data[idx_thresh] = fn;
      tn_data[idx_thresh] = tn;
      fp_data[idx_thresh] = fp;
    }
    // epsilon to avoid divide by zero.
    float epsilon = 1e-6;
    // Riemann sum to calculate auc.
    Tensor tp_rate, fp_rate, rec_rate;
    tp_rate.Resize({num_thresholds});
    fp_rate.Resize({num_thresholds});
    rec_rate.Resize({num_thresholds});
    float* tp_rate_data = tp_rate.mutable_data<float>(ctx.GetPlace());
    float* fp_rate_data = fp_rate.mutable_data<float>(ctx.GetPlace());
    float* rec_rate_data = rec_rate.mutable_data<float>(ctx.GetPlace());
    for (int i = 0; i < num_thresholds; i++) {
      // tp_rate is the true positive rate (recall); rec_rate is precision.
      tp_rate_data[i] = (static_cast<float>(tp_data[i]) + epsilon) /
                        (tp_data[i] + fn_data[i] + epsilon);
      fp_rate_data[i] =
          static_cast<float>(fp_data[i]) / (fp_data[i] + tn_data[i] + epsilon);
      rec_rate_data[i] = (static_cast<float>(tp_data[i]) + epsilon) /
                         (tp_data[i] + fp_data[i] + epsilon);
    }
    *auc_data = 0.0f;
    if (curve == "ROC") {
      for (int i = 0; i < num_thresholds - 1; i++) {
        auto dx = fp_rate_data[i] - fp_rate_data[i + 1];
        auto y = (tp_rate_data[i] + tp_rate_data[i + 1]) / 2.0f;
        *auc_data = *auc_data + dx * y;
      }
    } else if (curve == "PR") {
      for (int i = 1; i < num_thresholds; i++) {
        auto dx = tp_rate_data[i] - tp_rate_data[i - 1];
        auto y = (rec_rate_data[i] + rec_rate_data[i - 1]) / 2.0f;
        *auc_data = *auc_data + dx * y;
      }
    }
  }
};

}  // namespace operators
}  // namespace paddle
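In equations, with TP_i, FP_i, FN_i, TN_i the confusion counts at threshold i, N the number of thresholds, and \epsilon the divide-by-zero guard, the kernel above computes

TPR_i = \frac{TP_i + \epsilon}{TP_i + FN_i + \epsilon}, \qquad
FPR_i = \frac{FP_i}{FP_i + TN_i + \epsilon}

AUC_{ROC} \approx \sum_{i=0}^{N-2} (FPR_i - FPR_{i+1}) \cdot \frac{TPR_i + TPR_{i+1}}{2}

The PR variant integrates with recall (tp_rate_data) on the x-axis and precision (stored in rec_rate_data) on the y-axis.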
@@ -0,0 +1,122 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/huber_loss_op.h"

namespace paddle {
namespace operators {

class HuberLossOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must be initialized.");
    PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) must be initialized.");

    auto x_dims = ctx->GetInputDim("X");
    auto y_dims = ctx->GetInputDim("Y");

    PADDLE_ENFORCE_EQ(x_dims, y_dims);
    PADDLE_ENFORCE_EQ(x_dims.size(), 2,
                      "The rank of Input(X) must be 2 and the shape is "
                      "[batch_size, 1].");
    PADDLE_ENFORCE_EQ(x_dims[1], 1,
                      "Each row of Input(X) contains a real value, "
                      "so the 2nd dimension of Input(X) must be 1.");

    ctx->SetOutputDim("Residual", x_dims);
    ctx->SetOutputDim("Out", {x_dims[0], 1});
    ctx->ShareLoD("X", "Out");
  }
};

template <typename AttrType>
class HuberLossOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  HuberLossOpMaker(framework::OpProto* proto,
                   framework::OpAttrChecker* op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("X",
             "The input value of huber loss op. "
             "X is a 2-D tensor with shape [batch_size, 1].");
    AddInput("Y",
             "The target value of huber loss op. "
             "Y is a 2-D tensor with shape [batch_size, 1].");
    AddOutput("Residual",
              "Intermediate tensor to cache the residual value between Y "
              "and X. Its shape is the same as Input(X) and it will be "
              "reused in the backward pass.")
        .AsIntermediate();
    AddOutput("Out",
              "The output tensor with shape [batch_size, 1] which represents "
              "the huber loss.");
    AddAttr<AttrType>("delta", "Hyperparameter in huber loss.");
    AddComment(R"DOC(
Huber loss is a loss function used in robust regression. We define X as the
input value and Y as the target value. Huber loss can evaluate the fitness of
X to Y. Different from MSE loss, Huber loss is more robust to outliers. The
shapes of X and Y are [batch_size, 1]. The equation is:

L_{\delta}(y, f(x)) =
\begin{cases}
0.5 * (y - f(x))^2, \quad |y - f(x)| \leq \delta \\
\delta * (|y - f(x)| - 0.5 * \delta), \quad otherwise
\end{cases}

)DOC");
  }
};

class HuberLossGradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null.");
    PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should not be null.");
    PADDLE_ENFORCE(ctx->HasInput("Residual"),
                   "Input(Residual) should not be null.");
    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
                   "Input(Out@GRAD) should not be null.");

    auto x_dims = ctx->GetInputDim("X");
    auto y_dims = ctx->GetInputDim("Y");
    auto residual_dims = ctx->GetInputDim("Residual");
    auto out_grad_dims = ctx->GetInputDim(framework::GradVarName("Out"));

    PADDLE_ENFORCE_EQ(residual_dims, x_dims);
    PADDLE_ENFORCE_EQ(out_grad_dims, x_dims);

    auto x_grad_name = framework::GradVarName("X");
    auto y_grad_name = framework::GradVarName("Y");
    if (ctx->HasOutput(x_grad_name)) {
      ctx->SetOutputDim(x_grad_name, x_dims);
    }
    if (ctx->HasOutput(y_grad_name)) {
      ctx->SetOutputDim(y_grad_name, y_dims);
    }
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP(huber_loss, ops::HuberLossOp, ops::HuberLossOpMaker<float>,
            huber_loss_grad, ops::HuberLossGradOp);
REGISTER_OP_CPU_KERNEL(huber_loss,
                       ops::HuberLossKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
    huber_loss_grad,
    ops::HuberLossGradKernel<paddle::platform::CPUPlace, float>);
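As a quick worked example of the equation in the DOC comment above (the helper name huber is illustrative only): the loss is quadratic for |y - f(x)| <= delta and linear beyond, and the two branches meet continuously at |y - f(x)| = delta, where both give 0.5 * delta^2.

import numpy as np

def huber(residual, delta):
    # Piecewise form from the DOC comment: quadratic inside |r| <= delta,
    # linear outside, continuous at the boundary.
    abs_r = np.abs(residual)
    return np.where(abs_r <= delta,
                    0.5 * residual ** 2,
                    delta * (abs_r - 0.5 * delta))

print(huber(np.array([0.5, 1.0, 2.0]), 1.0))  # -> [0.125 0.5 1.5]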
@@ -0,0 +1,23 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#define EIGEN_USE_GPU
#include "paddle/operators/huber_loss_op.h"

namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(huber_loss,
                       ops::HuberLossKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(
    huber_loss_grad,
    ops::HuberLossGradKernel<paddle::platform::GPUPlace, float>);
@@ -0,0 +1,119 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
#include "paddle/platform/hostdevice.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
template <typename T, int MajorType = Eigen::RowMajor,
          typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;

template <typename T>
struct HuberLossForward {
  HOSTDEVICE HuberLossForward(const T& delta) : delta(delta) {}

  HOSTDEVICE T operator()(const T& val) const {
    T abs_val = std::abs(val);
    if (abs_val <= delta) {
      return static_cast<T>(0.5) * val * val;
    } else {
      return delta * (abs_val - static_cast<T>(0.5) * delta);
    }
  }

  T delta;
};

template <typename Place, typename T, typename AttrType = T>
class HuberLossKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* in0 = context.Input<Tensor>("X");
    auto* in1 = context.Input<Tensor>("Y");
    auto* out0 = context.Output<Tensor>("Residual");
    auto* out1 = context.Output<Tensor>("Out");
    auto delta = static_cast<T>(context.Attr<AttrType>("delta"));
    auto place = context.GetEigenDevice<Place>();

    auto x = EigenVector<T>::Flatten(*in0);
    auto y = EigenVector<T>::Flatten(*in1);
    out0->mutable_data<T>(context.GetPlace());
    auto residual = EigenVector<T>::Flatten(*out0);
    // Cache the residual for reuse in the backward pass.
    residual.device(place) = y - x;
    out1->mutable_data<T>(context.GetPlace());
    auto loss = EigenVector<T>::Flatten(*out1);
    loss.device(place) = residual.unaryExpr(HuberLossForward<T>(delta));
  }
};

template <typename T>
struct HuberLossBackward {
  HOSTDEVICE HuberLossBackward(const T& delta, T sign)
      : sign(sign), delta(delta) {}

  HOSTDEVICE T operator()(const T& val) const {
    T abs_val = std::abs(val);
    if (abs_val <= delta) {
      return sign * val;
    } else {
      if (val > 0) {
        return sign * delta;
      } else {
        return -sign * delta;
      }
    }
  }

  T sign;
  T delta;
};

template <typename Place, typename T, typename AttrType = T>
class HuberLossGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* in0 = context.Input<Tensor>("Residual");
    auto* in1 = context.Input<Tensor>(framework::GradVarName("Out"));
    auto* out0 = context.Output<Tensor>(framework::GradVarName("X"));
    auto* out1 = context.Output<Tensor>(framework::GradVarName("Y"));
    auto delta = static_cast<T>(context.Attr<AttrType>("delta"));
    auto place = context.GetEigenDevice<Place>();

    auto residual = EigenVector<T>::Flatten(*in0);
    auto out_grad = EigenVector<T>::Flatten(*in1);

    if (out0) {
      out0->mutable_data<T>(context.GetPlace());
      auto x_grad = EigenVector<T>::Flatten(*out0);
      x_grad.device(place) =
          out_grad * residual.unaryExpr(HuberLossBackward<T>(delta, -1.0));
    }

    if (out1) {
      out1->mutable_data<T>(context.GetPlace());
      auto y_grad = EigenVector<T>::Flatten(*out1);
      y_grad.device(place) =
          out_grad * residual.unaryExpr(HuberLossBackward<T>(delta, 1.0));
    }
  }
};

}  // namespace operators
}  // namespace paddle
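Written out, the backward functor above is the derivative of the loss with respect to the residual r = y - f(x):

\frac{\partial L_{\delta}}{\partial r} =
\begin{cases}
r, \quad |r| \leq \delta \\
\delta \cdot \mathrm{sign}(r), \quad otherwise
\end{cases}

Because r = y - x, the chain rule gives dL/dx = -dL/dr and dL/dy = +dL/dr, which is why HuberLossGradKernel passes sign = -1.0 when filling X@GRAD and sign = 1.0 for Y@GRAD, multiplying by out_grad in both cases.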
@@ -0,0 +1,66 @@
import unittest
import numpy as np
from op_test import OpTest


class TestAucOp(OpTest):
    def setUp(self):
        self.op_type = "auc"
        pred = np.random.random((128, )).astype("float32")
        labels = np.random.randint(0, 2, (128, ))
        num_thresholds = 200
        self.inputs = {'Inference': pred, 'Label': labels}
        self.attrs = {'curve': 'ROC', 'num_thresholds': num_thresholds}
        # NOTE: sklearn uses a different way to generate thresholds,
        # which makes its result differ slightly:
        # from sklearn.metrics import roc_curve, auc
        # fpr, tpr, thresholds = roc_curve(labels, pred)
        # auc_value = auc(fpr, tpr)
        # so we calculate AUC again with numpy for testing.
        kepsilon = 1e-7  # to account for floating point imprecision
        thresholds = [(i + 1) * 1.0 / (num_thresholds - 1)
                      for i in range(num_thresholds - 2)]
        thresholds = [0.0 - kepsilon] + thresholds + [1.0 + kepsilon]

        # calculate TP, FN, TN, FP counts
        tp_list = np.ndarray((num_thresholds, ))
        fn_list = np.ndarray((num_thresholds, ))
        tn_list = np.ndarray((num_thresholds, ))
        fp_list = np.ndarray((num_thresholds, ))
        for idx_thresh, thresh in enumerate(thresholds):
            tp, fn, tn, fp = 0, 0, 0, 0
            for i, lbl in enumerate(labels):
                if lbl:
                    if pred[i] >= thresh:
                        tp += 1
                    else:
                        fn += 1
                else:
                    if pred[i] >= thresh:
                        fp += 1
                    else:
                        tn += 1
            tp_list[idx_thresh] = tp
            fn_list[idx_thresh] = fn
            tn_list[idx_thresh] = tn
            fp_list[idx_thresh] = fp

        epsilon = 1e-6
        tpr = (tp_list.astype("float32") + epsilon) / (
            tp_list + fn_list + epsilon)
        fpr = fp_list.astype("float32") / (fp_list + tn_list + epsilon)
        # precision; unused for the ROC curve but mirrors the kernel
        rec = (tp_list.astype("float32") + epsilon) / (
            tp_list + fp_list + epsilon)

        x = fpr[:num_thresholds - 1] - fpr[1:]
        y = (tpr[:num_thresholds - 1] + tpr[1:]) / 2.0
        auc_value = np.sum(x * y)

        self.outputs = {'AUC': auc_value}

    def test_check_output(self):
        self.check_output()


if __name__ == "__main__":
    unittest.main()
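The sklearn cross-check mentioned in the NOTE above can also be run standalone; a sketch (requires scikit-learn, and its result is expected to differ slightly from the op's output because roc_curve derives thresholds from the data rather than from a fixed grid):

import numpy as np
from sklearn.metrics import roc_curve, auc

pred = np.random.random(128).astype("float32")
labels = np.random.randint(0, 2, 128)
fpr, tpr, thresholds = roc_curve(labels, pred)
auc_value = auc(fpr, tpr)  # exact trapezoid over data-driven thresholds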
@@ -0,0 +1,47 @@
import unittest
import numpy as np
from op_test import OpTest


def huber_loss_forward(val, delta):
    abs_val = abs(val)
    if abs_val <= delta:
        return 0.5 * val * val
    else:
        return delta * (abs_val - 0.5 * delta)


class TestHuberLossOp(OpTest):
    def setUp(self):
        self.op_type = 'huber_loss'
        samples_num = 64
        delta = 1.0
        self.inputs = {
            'X': np.random.uniform(0, 1., (samples_num, 1)).astype('float32'),
            'Y': np.random.uniform(0, 1., (samples_num, 1)).astype('float32'),
        }
        residual = self.inputs['Y'] - self.inputs['X']
        loss = np.vectorize(huber_loss_forward)(residual, delta)
        self.attrs = {'delta': delta}
        self.outputs = {
            'Residual': residual,
            'Out': loss.reshape((samples_num, 1))
        }

    def test_check_output(self):
        self.check_output()

    def test_check_grad_normal(self):
        self.check_grad(['X', 'Y'], 'Out', max_relative_error=0.008)

    def test_check_grad_ignore_x(self):
        # set(['residual']), not set("residual"): the latter is a set of
        # single characters, not a set containing the variable name.
        self.check_grad(
            ['Y'], 'Out', max_relative_error=0.008,
            no_grad_set=set(['residual']))

    def test_check_grad_ignore_y(self):
        self.check_grad(
            ['X'], 'Out', max_relative_error=0.008,
            no_grad_set=set(['residual']))


if __name__ == '__main__':
    unittest.main()