|
|
|
@ -23,7 +23,7 @@ namespace operators {
|
|
|
|
|
using Tensor = framework::Tensor;
|
|
|
|
|
|
|
|
|
|
template <typename Place, typename T>
|
|
|
|
|
class AccuracyKernel : public framework::OpKernel {
|
|
|
|
|
class AucKernel : public framework::OpKernel {
|
|
|
|
|
public:
|
|
|
|
|
void Compute(const framework::ExecutionContext& ctx) const override {
|
|
|
|
|
auto* inference = ctx.Input<Tensor>("Inference");
|
|
|
|
@ -45,7 +45,7 @@ class AccuracyKernel : public framework::OpKernel {
|
|
|
|
|
thresholds_list[num_thresholds - 1] = 1.0f + kEpsilon;
|
|
|
|
|
|
|
|
|
|
const int* inference_data = inference->data<int>();
|
|
|
|
|
const T* inference_prob_data = inference->data<T>();
|
|
|
|
|
const T* inference_prob_data = inference_prob->data<T>();
|
|
|
|
|
const T* label_data = label->data<T>();
|
|
|
|
|
|
|
|
|
|
size_t num_samples = inference->dims()[0];
|
|
|
|
@ -54,17 +54,17 @@ class AccuracyKernel : public framework::OpKernel {
|
|
|
|
|
// create local tensor for storing the curve: TP, FN, TN, FP
|
|
|
|
|
// TODO(typhoonzero): put these tensors in Scope
|
|
|
|
|
// TODO(typhoonzero): use op to caculate these values.
|
|
|
|
|
Tensor true_positive, false_positeve, true_negative, false_negative;
|
|
|
|
|
Tensor true_positive, false_positive, true_negative, false_negative;
|
|
|
|
|
|
|
|
|
|
true_positive.Resize({num_thresholds});
|
|
|
|
|
false_negative.Resize({num_thresholds});
|
|
|
|
|
true_negative.Resize({num_thresholds});
|
|
|
|
|
false_positive.Resize({num_thresholds});
|
|
|
|
|
|
|
|
|
|
int* tp_data = true_positive.mutable_data<int>();
|
|
|
|
|
int* fn_data = false_negative.mutable_data<int>();
|
|
|
|
|
int* tn_data = true_negative.mutable_data<int>();
|
|
|
|
|
int* fp_data = false_positive.mutable_data<int>();
|
|
|
|
|
int* tp_data = true_positive.mutable_data<int>(ctx.GetPlace());
|
|
|
|
|
int* fn_data = false_negative.mutable_data<int>(ctx.GetPlace());
|
|
|
|
|
int* tn_data = true_negative.mutable_data<int>(ctx.GetPlace());
|
|
|
|
|
int* fp_data = false_positive.mutable_data<int>(ctx.GetPlace());
|
|
|
|
|
|
|
|
|
|
for (auto thresh = thresholds_list.begin(); thresh != thresholds_list.end();
|
|
|
|
|
thresh++) {
|
|
|
|
@ -101,15 +101,15 @@ class AccuracyKernel : public framework::OpKernel {
|
|
|
|
|
tp_rate.Resize({num_thresholds});
|
|
|
|
|
fp_rate.Resize({num_thresholds});
|
|
|
|
|
rec_rate.Resize({num_thresholds});
|
|
|
|
|
float* tp_rate_data = tp_rate.mutable_data<float>();
|
|
|
|
|
float* fp_rate_data = fp_rate.mutable_data<float>();
|
|
|
|
|
float* rec_rate_data = rec_rate.mutable_data<float>();
|
|
|
|
|
float* tp_rate_data = tp_rate.mutable_data<float>(ctx.GetPlace());
|
|
|
|
|
float* fp_rate_data = fp_rate.mutable_data<float>(ctx.GetPlace());
|
|
|
|
|
float* rec_rate_data = rec_rate.mutable_data<float>(ctx.GetPlace());
|
|
|
|
|
for (int i = 0; i < num_thresholds; i++) {
|
|
|
|
|
tp_rate_data[i] = ((float)tp_data[i + epsilon) / (tp_data[i] + fn_data[i] + epsilon);
|
|
|
|
|
fp_rate_data[i] =
|
|
|
|
|
(float)fp_data[i] / (fp_data[i] + tn_data[i] + epsilon);
|
|
|
|
|
rec_rate_data[i] =
|
|
|
|
|
((float)tp_data[i] + epsilon) / (tp_data[i] + fp_data[i] + epsilon);
|
|
|
|
|
tp_rate_data[i] =
|
|
|
|
|
((float)tp_data[i] + epsilon) / (tp_data[i] + fn_data[i] + epsilon);
|
|
|
|
|
fp_rate_data[i] = (float)fp_data[i] / (fp_data[i] + tn_data[i] + epsilon);
|
|
|
|
|
rec_rate_data[i] =
|
|
|
|
|
((float)tp_data[i] + epsilon) / (tp_data[i] + fp_data[i] + epsilon);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (curve == "ROC") {
|
|
|
|
@ -118,7 +118,7 @@ class AccuracyKernel : public framework::OpKernel {
|
|
|
|
|
auto y = (tp_rate_data[i] + tp_rate_data[i - 1]) / 2.0f;
|
|
|
|
|
*auc_data = *auc_data + dx * y;
|
|
|
|
|
}
|
|
|
|
|
} else if (curve = "PR") {
|
|
|
|
|
} else if (curve == "PR") {
|
|
|
|
|
for (int i = 1; i < num_thresholds; i++) {
|
|
|
|
|
auto dx = tp_rate_data[i] - tp_rate_data[i - 1];
|
|
|
|
|
auto y = (rec_rate_data[i] + rec_rate_data[i - 1]) / 2.0f;
|
|
|
|
|