|
|
|
@ -325,14 +325,14 @@ class Auc(MetricBase):
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def __init__(self, name, curve='ROC', num_thresholds=200):
|
|
|
|
|
super(MetricBase, self).__init__(name, curve, num_thresholds)
|
|
|
|
|
super(Auc, self).__init__(name=name)
|
|
|
|
|
self._curve = curve
|
|
|
|
|
self._num_thresholds = num_thresholds
|
|
|
|
|
self._epsilon = 1e-6
|
|
|
|
|
self.tp_list = np.ndarray((num_thresholds, ))
|
|
|
|
|
self.fn_list = np.ndarray((num_thresholds, ))
|
|
|
|
|
self.tn_list = np.ndarray((num_thresholds, ))
|
|
|
|
|
self.fp_list = np.ndarray((num_thresholds, ))
|
|
|
|
|
self.tp_list = np.zeros((num_thresholds, ))
|
|
|
|
|
self.fn_list = np.zeros((num_thresholds, ))
|
|
|
|
|
self.tn_list = np.zeros((num_thresholds, ))
|
|
|
|
|
self.fp_list = np.zeros((num_thresholds, ))
|
|
|
|
|
|
|
|
|
|
def update(self, labels, predictions, axis=1):
|
|
|
|
|
if not _is_numpy_(labels):
|
|
|
|
@ -350,12 +350,12 @@ class Auc(MetricBase):
|
|
|
|
|
tp, fn, tn, fp = 0, 0, 0, 0
|
|
|
|
|
for i, lbl in enumerate(labels):
|
|
|
|
|
if lbl:
|
|
|
|
|
if predictions[i, 0] >= thresh:
|
|
|
|
|
if predictions[i, 1] >= thresh:
|
|
|
|
|
tp += 1
|
|
|
|
|
else:
|
|
|
|
|
fn += 1
|
|
|
|
|
else:
|
|
|
|
|
if predictions[i, 0] >= thresh:
|
|
|
|
|
if predictions[i, 1] >= thresh:
|
|
|
|
|
fp += 1
|
|
|
|
|
else:
|
|
|
|
|
tn += 1
|
|
|
|
|