|
|
@ -49,7 +49,7 @@ for pass_id in range(PASS_NUM):
|
|
|
|
avg_loss_value, = exe.run(fluid.default_main_program(),
|
|
|
|
avg_loss_value, = exe.run(fluid.default_main_program(),
|
|
|
|
feed=feeder.feed(data),
|
|
|
|
feed=feeder.feed(data),
|
|
|
|
fetch_list=[avg_cost])
|
|
|
|
fetch_list=[avg_cost])
|
|
|
|
|
|
|
|
print(avg_loss_value)
|
|
|
|
if avg_loss_value[0] < 10.0:
|
|
|
|
if avg_loss_value[0] < 10.0:
|
|
|
|
exit(0) # if avg cost less than 10.0, we think our code is good.
|
|
|
|
exit(0) # if avg cost less than 10.0, we think our code is good.
|
|
|
|
exit(1)
|
|
|
|
exit(1)
|
|
|
|