|
|
@ -60,20 +60,6 @@ class AucKernel : public framework::OpKernel<T> {
|
|
|
|
const T* inference_data = predict->data<T>();
|
|
|
|
const T* inference_data = predict->data<T>();
|
|
|
|
const auto* label_data = label->data<int64_t>();
|
|
|
|
const auto* label_data = label->data<int64_t>();
|
|
|
|
|
|
|
|
|
|
|
|
// check if states are inited.
|
|
|
|
|
|
|
|
auto* tp_in = ctx.Input<Tensor>("TP");
|
|
|
|
|
|
|
|
auto* fp_in = ctx.Input<Tensor>("FP");
|
|
|
|
|
|
|
|
auto* tn_in = ctx.Input<Tensor>("TN");
|
|
|
|
|
|
|
|
auto* fn_in = ctx.Input<Tensor>("FN");
|
|
|
|
|
|
|
|
PADDLE_ENFORCE(tp_in->IsInitialized(), "true_positive is not inited!");
|
|
|
|
|
|
|
|
PADDLE_ENFORCE(fp_in->IsInitialized(), "false_negative is not inited!");
|
|
|
|
|
|
|
|
PADDLE_ENFORCE(tn_in->IsInitialized(), "true_negative is not inited!");
|
|
|
|
|
|
|
|
PADDLE_ENFORCE(fn_in->IsInitialized(), "false_positive is not inited!");
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(tp_in->numel(), num_thresholds, "");
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(fp_in->numel(), num_thresholds, "");
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(tn_in->numel(), num_thresholds, "");
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(fn_in->numel(), num_thresholds, "");
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
auto* tp_data = true_positive->mutable_data<int64_t>(ctx.GetPlace());
|
|
|
|
auto* tp_data = true_positive->mutable_data<int64_t>(ctx.GetPlace());
|
|
|
|
auto* fn_data = false_negative->mutable_data<int64_t>(ctx.GetPlace());
|
|
|
|
auto* fn_data = false_negative->mutable_data<int64_t>(ctx.GetPlace());
|
|
|
|
auto* tn_data = true_negative->mutable_data<int64_t>(ctx.GetPlace());
|
|
|
|
auto* tn_data = true_negative->mutable_data<int64_t>(ctx.GetPlace());
|
|
|
|