'add f1 test'

mobile_baidu
Dong Zhihong 7 years ago
parent 8d9b33412d
commit c4ac7fab5e

@ -121,18 +121,14 @@ class Accuracy(Evaluator):
return executor.run(eval_program, fetch_list=[eval_out])
# This is demo for composing low level op to compute metric
# Demo for composing low level op to compute the F1 metric
class F1(Evaluator):
def __init__(self, input, label, **kwargs):
super(F1, self).__init__("F1", **kwargs)
super(Accuracy, self).__init__("accuracy", **kwargs)
g_total = helper.create_global_variable(
name=unique_name("Total"),
persistable=True,
dtype="int64",
shape=[1])
g_correct = helper.create_global_variable(
name=unique_name("Correct"),
persistable=True,
dtype="int64",
shape=[1])
g_tp = helper.create_global_variable(
name=unique_name("Tp"), persistable=True, dtype="int64", shape=[1])
g_fp = helper.create_global_variable(
name=unique_name("Fp"), persistable=True, dtype="int64", shape=[1])
self._states["Tp"] = g_tp
self._states["Fp"] = g_fp

@ -61,6 +61,7 @@ PASS_NUM = 100
for pass_id in range(PASS_NUM):
save_persistables(exe, "./fit_a_line.model/", main_program=main_program)
load_persistables(exe, "./fit_a_line.model/", main_program=main_program)
accuracy.reset(exe)
for data in train_reader():
x_data = np.array(map(lambda x: x[0], data)).astype("float32")
y_data = np.array(map(lambda x: x[1], data)).astype("float32")
@ -75,8 +76,10 @@ for pass_id in range(PASS_NUM):
outs = exe.run(main_program,
feed={'x': tensor_x,
'y': tensor_y},
fetch_list=[avg_cost])
fetch_list=[avg_cost, accuracy])
out = np.array(outs[0])
pass_acc = accuracy.eval(exe)
print pass_acc
if out[0] < 10.0:
exit(0) # if avg cost less than 10.0, we think our code is good.

Loading…
Cancel
Save