|
|
|
@ -31,7 +31,7 @@ template <typename DeviceContext, typename T>
|
|
|
|
|
class AucKernel : public framework::OpKernel<T> {
|
|
|
|
|
public:
|
|
|
|
|
void Compute(const framework::ExecutionContext& ctx) const override {
|
|
|
|
|
auto* inference = ctx.Input<Tensor>("Out");
|
|
|
|
|
auto* predict = ctx.Input<Tensor>("Predict");
|
|
|
|
|
auto* label = ctx.Input<Tensor>("Label");
|
|
|
|
|
auto* auc = ctx.Output<Tensor>("AUC");
|
|
|
|
|
// Only use output var for now, make sure it's persistable and
|
|
|
|
@ -41,24 +41,24 @@ class AucKernel : public framework::OpKernel<T> {
|
|
|
|
|
auto* true_negative = ctx.Output<Tensor>("TNOut");
|
|
|
|
|
auto* false_negative = ctx.Output<Tensor>("FNOut");
|
|
|
|
|
|
|
|
|
|
float* auc_data = auc->mutable_data<float>(ctx.GetPlace());
|
|
|
|
|
auto* auc_data = auc->mutable_data<double>(ctx.GetPlace());
|
|
|
|
|
|
|
|
|
|
std::string curve = ctx.Attr<std::string>("curve");
|
|
|
|
|
int num_thresholds = ctx.Attr<int>("num_thresholds");
|
|
|
|
|
std::vector<float> thresholds_list;
|
|
|
|
|
std::vector<double> thresholds_list;
|
|
|
|
|
thresholds_list.reserve(num_thresholds);
|
|
|
|
|
for (int i = 1; i < num_thresholds - 1; i++) {
|
|
|
|
|
thresholds_list[i] = static_cast<float>(i) / (num_thresholds - 1);
|
|
|
|
|
thresholds_list[i] = static_cast<double>(i) / (num_thresholds - 1);
|
|
|
|
|
}
|
|
|
|
|
const float kEpsilon = 1e-7;
|
|
|
|
|
const double kEpsilon = 1e-7;
|
|
|
|
|
thresholds_list[0] = 0.0f - kEpsilon;
|
|
|
|
|
thresholds_list[num_thresholds - 1] = 1.0f + kEpsilon;
|
|
|
|
|
|
|
|
|
|
size_t batch_size = inference->dims()[0];
|
|
|
|
|
size_t inference_width = inference->dims()[1];
|
|
|
|
|
size_t batch_size = predict->dims()[0];
|
|
|
|
|
size_t inference_width = predict->dims()[1];
|
|
|
|
|
|
|
|
|
|
const T* inference_data = inference->data<T>();
|
|
|
|
|
const int64_t* label_data = label->data<int64_t>();
|
|
|
|
|
const T* inference_data = predict->data<T>();
|
|
|
|
|
const auto* label_data = label->data<int64_t>();
|
|
|
|
|
|
|
|
|
|
auto* tp_data = true_positive->mutable_data<int64_t>(ctx.GetPlace());
|
|
|
|
|
auto* fn_data = false_negative->mutable_data<int64_t>(ctx.GetPlace());
|
|
|
|
@ -66,20 +66,19 @@ class AucKernel : public framework::OpKernel<T> {
|
|
|
|
|
auto* fp_data = false_positive->mutable_data<int64_t>(ctx.GetPlace());
|
|
|
|
|
|
|
|
|
|
for (int idx_thresh = 0; idx_thresh < num_thresholds; idx_thresh++) {
|
|
|
|
|
// caculate TP, FN, TN, FP for current thresh
|
|
|
|
|
// calculate TP, FN, TN, FP for current thresh
|
|
|
|
|
int64_t tp = 0, fn = 0, tn = 0, fp = 0;
|
|
|
|
|
for (size_t i = 0; i < batch_size; i++) {
|
|
|
|
|
// NOTE: label_data used as bool, labels >0 will be treated as true.
|
|
|
|
|
// NOTE: label_data used as bool, labels > 0 will be treated as true.
|
|
|
|
|
if (label_data[i]) {
|
|
|
|
|
// use first(max) data in each row
|
|
|
|
|
if (inference_data[i * inference_width] >=
|
|
|
|
|
if (inference_data[i * inference_width + 1] >=
|
|
|
|
|
(thresholds_list[idx_thresh])) {
|
|
|
|
|
tp++;
|
|
|
|
|
} else {
|
|
|
|
|
fn++;
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
if (inference_data[i * inference_width] >=
|
|
|
|
|
if (inference_data[i * inference_width + 1] >=
|
|
|
|
|
(thresholds_list[idx_thresh])) {
|
|
|
|
|
fp++;
|
|
|
|
|
} else {
|
|
|
|
@ -94,21 +93,21 @@ class AucKernel : public framework::OpKernel<T> {
|
|
|
|
|
fp_data[idx_thresh] += fp;
|
|
|
|
|
}
|
|
|
|
|
// epsilon to avoid divide by zero.
|
|
|
|
|
float epsilon = 1e-6;
|
|
|
|
|
double epsilon = 1e-6;
|
|
|
|
|
// Riemann sum to caculate auc.
|
|
|
|
|
Tensor tp_rate, fp_rate, rec_rate;
|
|
|
|
|
tp_rate.Resize({num_thresholds});
|
|
|
|
|
fp_rate.Resize({num_thresholds});
|
|
|
|
|
rec_rate.Resize({num_thresholds});
|
|
|
|
|
float* tp_rate_data = tp_rate.mutable_data<float>(ctx.GetPlace());
|
|
|
|
|
float* fp_rate_data = fp_rate.mutable_data<float>(ctx.GetPlace());
|
|
|
|
|
float* rec_rate_data = rec_rate.mutable_data<float>(ctx.GetPlace());
|
|
|
|
|
auto* tp_rate_data = tp_rate.mutable_data<double>(ctx.GetPlace());
|
|
|
|
|
auto* fp_rate_data = fp_rate.mutable_data<double>(ctx.GetPlace());
|
|
|
|
|
auto* rec_rate_data = rec_rate.mutable_data<double>(ctx.GetPlace());
|
|
|
|
|
for (int i = 0; i < num_thresholds; i++) {
|
|
|
|
|
tp_rate_data[i] = (static_cast<float>(tp_data[i]) + epsilon) /
|
|
|
|
|
tp_rate_data[i] = (static_cast<double>(tp_data[i]) + epsilon) /
|
|
|
|
|
(tp_data[i] + fn_data[i] + epsilon);
|
|
|
|
|
fp_rate_data[i] =
|
|
|
|
|
static_cast<float>(fp_data[i]) / (fp_data[i] + tn_data[i] + epsilon);
|
|
|
|
|
rec_rate_data[i] = (static_cast<float>(tp_data[i]) + epsilon) /
|
|
|
|
|
static_cast<double>(fp_data[i]) / (fp_data[i] + tn_data[i] + epsilon);
|
|
|
|
|
rec_rate_data[i] = (static_cast<double>(tp_data[i]) + epsilon) /
|
|
|
|
|
(tp_data[i] + fp_data[i] + epsilon);
|
|
|
|
|
}
|
|
|
|
|
*auc_data = 0.0f;
|
|
|
|
|