|
|
|
@ -227,7 +227,7 @@ class Precision(MetricBase):
|
|
|
|
|
metric.reset()
|
|
|
|
|
for data in train_reader():
|
|
|
|
|
loss, preds, labels = exe.run(fetch_list=[cost, preds, labels])
|
|
|
|
|
metric.update(preds=preds, labels=labels)
|
|
|
|
|
metric.update(preds=preds, labels=labels)
|
|
|
|
|
numpy_precision = metric.eval()
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
@ -241,9 +241,11 @@ class Precision(MetricBase):
|
|
|
|
|
raise ValueError("The 'preds' must be a numpy ndarray.")
|
|
|
|
|
if not _is_numpy_(labels):
|
|
|
|
|
raise ValueError("The 'labels' must be a numpy ndarray.")
|
|
|
|
|
sample_num = labels[0]
|
|
|
|
|
sample_num = labels.shape[0]
|
|
|
|
|
preds = np.rint(preds).astype("int32")
|
|
|
|
|
|
|
|
|
|
for i in range(sample_num):
|
|
|
|
|
pred = preds[i].astype("int32")
|
|
|
|
|
pred = preds[i]
|
|
|
|
|
label = labels[i]
|
|
|
|
|
if label == 1:
|
|
|
|
|
if pred == label:
|
|
|
|
|