|
|
|
@ -58,11 +58,25 @@ def main():
|
|
|
|
|
|
|
|
|
|
for _ in xrange(100):
|
|
|
|
|
updater.startPass()
|
|
|
|
|
outArgs = api.Arguments.createArguments(0)
|
|
|
|
|
train_data_generator = input_order_converter(
|
|
|
|
|
read_from_mnist(train_file))
|
|
|
|
|
for data_batch in generator_to_batch(train_data_generator, 128):
|
|
|
|
|
for batch_id, data_batch in enumerate(
|
|
|
|
|
generator_to_batch(train_data_generator, 256)):
|
|
|
|
|
trainRole = updater.startBatch(len(data_batch))
|
|
|
|
|
|
|
|
|
|
def update_callback(param):
|
|
|
|
|
updater.update(param)
|
|
|
|
|
|
|
|
|
|
m.forwardBackward(
|
|
|
|
|
converter(data_batch), outArgs, trainRole, update_callback)
|
|
|
|
|
|
|
|
|
|
cost_vec = outArgs.getSlotValue(0)
|
|
|
|
|
cost_vec = cost_vec.copyToNumpyMat()
|
|
|
|
|
cost = cost_vec.sum() / len(data_batch)
|
|
|
|
|
print 'Batch id', batch_id, 'with cost=', cost
|
|
|
|
|
updater.finishBatch(cost)
|
|
|
|
|
|
|
|
|
|
updater.finishPass()
|
|
|
|
|
|
|
|
|
|
m.finish()
|
|
|
|
|