|
|
|
@ -157,6 +157,7 @@ def main():
|
|
|
|
|
updater.finishBatch(cost)
|
|
|
|
|
|
|
|
|
|
# testing stage. use test data set to test current network.
|
|
|
|
|
updater.apply()
|
|
|
|
|
test_evaluator.start()
|
|
|
|
|
test_data_generator = input_order_converter(read_from_mnist(test_file))
|
|
|
|
|
for data_batch in generator_to_batch(test_data_generator, 128):
|
|
|
|
@ -167,6 +168,18 @@ def main():
|
|
|
|
|
# print error rate for test data set
|
|
|
|
|
print 'Pass', pass_id, ' test evaluator: ', test_evaluator
|
|
|
|
|
test_evaluator.finish()
|
|
|
|
|
updater.restore()
|
|
|
|
|
|
|
|
|
|
updater.catchUpWith()
|
|
|
|
|
params = m.getParameters()
|
|
|
|
|
for each_param in params:
|
|
|
|
|
assert isinstance(each_param, api.Parameter)
|
|
|
|
|
value = each_param.getBuf(api.PARAMETER_VALUE)
|
|
|
|
|
value = value.toNumpyArrayInplace()
|
|
|
|
|
|
|
|
|
|
# Here, we could save parameter to every where you want
|
|
|
|
|
print each_param.getName(), value
|
|
|
|
|
|
|
|
|
|
updater.finishPass()
|
|
|
|
|
|
|
|
|
|
m.finish()
|
|
|
|
|