|
|
@ -580,10 +580,10 @@ class Auc(MetricBase):
|
|
|
|
self.tn_list = np.zeros((num_thresholds, ))
|
|
|
|
self.tn_list = np.zeros((num_thresholds, ))
|
|
|
|
self.fp_list = np.zeros((num_thresholds, ))
|
|
|
|
self.fp_list = np.zeros((num_thresholds, ))
|
|
|
|
|
|
|
|
|
|
|
|
def update(self, predictions, labels):
|
|
|
|
def update(self, preds, labels):
|
|
|
|
if not _is_numpy_(labels):
|
|
|
|
if not _is_numpy_(labels):
|
|
|
|
raise ValueError("The 'labels' must be a numpy ndarray.")
|
|
|
|
raise ValueError("The 'labels' must be a numpy ndarray.")
|
|
|
|
if not _is_numpy_(predictions):
|
|
|
|
if not _is_numpy_(preds):
|
|
|
|
raise ValueError("The 'predictions' must be a numpy ndarray.")
|
|
|
|
raise ValueError("The 'predictions' must be a numpy ndarray.")
|
|
|
|
|
|
|
|
|
|
|
|
kepsilon = 1e-7 # to account for floating point imprecisions
|
|
|
|
kepsilon = 1e-7 # to account for floating point imprecisions
|
|
|
@ -596,12 +596,12 @@ class Auc(MetricBase):
|
|
|
|
tp, fn, tn, fp = 0, 0, 0, 0
|
|
|
|
tp, fn, tn, fp = 0, 0, 0, 0
|
|
|
|
for i, lbl in enumerate(labels):
|
|
|
|
for i, lbl in enumerate(labels):
|
|
|
|
if lbl:
|
|
|
|
if lbl:
|
|
|
|
if predictions[i, 1] >= thresh:
|
|
|
|
if preds[i, 1] >= thresh:
|
|
|
|
tp += 1
|
|
|
|
tp += 1
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
fn += 1
|
|
|
|
fn += 1
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
if predictions[i, 1] >= thresh:
|
|
|
|
if preds[i, 1] >= thresh:
|
|
|
|
fp += 1
|
|
|
|
fp += 1
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
tn += 1
|
|
|
|
tn += 1
|
|
|
|