@ -227,7 +227,7 @@ class Precision(MetricBase):
metric . reset ( )
metric . reset ( )
for data in train_reader ( ) :
for data in train_reader ( ) :
loss , preds , labels = exe . run ( fetch_list = [ cost , preds , labels ] )
loss , preds , labels = exe . run ( fetch_list = [ cost , preds , labels ] )
metric . update ( preds = preds , labels = labels )
metric . update ( preds = preds , labels = labels )
numpy_precision = metric . eval ( )
numpy_precision = metric . eval ( )
"""
"""
@ -241,9 +241,10 @@ class Precision(MetricBase):
raise ValueError ( " The ' preds ' must be a numpy ndarray. " )
raise ValueError ( " The ' preds ' must be a numpy ndarray. " )
if not _is_numpy_ ( labels ) :
if not _is_numpy_ ( labels ) :
raise ValueError ( " The ' labels ' must be a numpy ndarray. " )
raise ValueError ( " The ' labels ' must be a numpy ndarray. " )
sample_num = labels [ 0 ]
sample_num = labels . shape [ 0 ]
preds = np . rint ( preds ) . astype ( " int32 " )
for i in range ( sample_num ) :
for i in range ( sample_num ) :
pred = preds [ i ] . astype ( " int32 " )
pred = preds [ i ]
label = labels [ i ]
label = labels [ i ]
if label == 1 :
if label == 1 :
if pred == label :
if pred == label :