revert-11610-move_hooks
qiaolongfei 7 years ago
parent 4ec9ecae59
commit 4aa5da0550

@ -325,14 +325,14 @@ class Auc(MetricBase):
"""
def __init__(self, name, curve='ROC', num_thresholds=200):
super(MetricBase, self).__init__(name, curve, num_thresholds)
super(Auc, self).__init__(name=name)
self._curve = curve
self._num_thresholds = num_thresholds
self._epsilon = 1e-6
self.tp_list = np.ndarray((num_thresholds, ))
self.fn_list = np.ndarray((num_thresholds, ))
self.tn_list = np.ndarray((num_thresholds, ))
self.fp_list = np.ndarray((num_thresholds, ))
self.tp_list = np.zeros((num_thresholds, ))
self.fn_list = np.zeros((num_thresholds, ))
self.tn_list = np.zeros((num_thresholds, ))
self.fp_list = np.zeros((num_thresholds, ))
def update(self, labels, predictions, axis=1):
if not _is_numpy_(labels):
@ -350,12 +350,12 @@ class Auc(MetricBase):
tp, fn, tn, fp = 0, 0, 0, 0
for i, lbl in enumerate(labels):
if lbl:
if predictions[i, 0] >= thresh:
if predictions[i, 1] >= thresh:
tp += 1
else:
fn += 1
else:
if predictions[i, 0] >= thresh:
if predictions[i, 1] >= thresh:
fp += 1
else:
tn += 1

Loading…
Cancel
Save