|
|
|
@ -45,7 +45,6 @@ def main():
|
|
|
|
|
config.model_config, api.CREATE_MODE_NORMAL, enable_types)
|
|
|
|
|
assert isinstance(m, api.GradientMachine)
|
|
|
|
|
init_parameter(network=m)
|
|
|
|
|
|
|
|
|
|
updater = api.ParameterUpdater.createLocalUpdater(opt_config)
|
|
|
|
|
assert isinstance(updater, api.ParameterUpdater)
|
|
|
|
|
updater.init(m)
|
|
|
|
@ -62,7 +61,7 @@ def main():
|
|
|
|
|
train_data_generator = input_order_converter(
|
|
|
|
|
read_from_mnist(train_file))
|
|
|
|
|
for data_batch in generator_to_batch(train_data_generator, 128):
|
|
|
|
|
inArgs = converter(data_batch)
|
|
|
|
|
trainRole = updater.startBatch(len(data_batch))
|
|
|
|
|
|
|
|
|
|
updater.finishPass()
|
|
|
|
|
|
|
|
|
|