|
|
|
@ -243,6 +243,7 @@ class Precision(MetricBase):
|
|
|
|
raise ValueError("The 'labels' must be a numpy ndarray.")
|
|
|
|
raise ValueError("The 'labels' must be a numpy ndarray.")
|
|
|
|
sample_num = labels.shape[0]
|
|
|
|
sample_num = labels.shape[0]
|
|
|
|
preds = np.rint(preds).astype("int32")
|
|
|
|
preds = np.rint(preds).astype("int32")
|
|
|
|
|
|
|
|
|
|
|
|
for i in range(sample_num):
|
|
|
|
for i in range(sample_num):
|
|
|
|
pred = preds[i]
|
|
|
|
pred = preds[i]
|
|
|
|
label = labels[i]
|
|
|
|
label = labels[i]
|
|
|
|
|