You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
Paddle/demo/mnist/api_train.py

43 lines
1.3 KiB

8 years ago
import py_paddle.swig_paddle as api
import paddle.trainer.config_parser
import numpy as np
def init_parameter(network, low=-1.0, high=1.0):
    """Overwrite every parameter of ``network`` with uniform random values.

    Each parameter's PARAMETER_VALUE buffer is obtained as a NumPy array via
    ``toNumpyArrayInplace()`` and filled, in place, with samples drawn from
    ``np.random.uniform(low, high)``.

    NOTE(review): the writes only reach the network if the array returned by
    ``toNumpyArrayInplace()`` is a view onto the Paddle buffer rather than a
    copy. The name suggests so and the original code relied on it, but this
    is not verifiable from this file.

    :param network: gradient machine whose parameters are initialized.
    :type network: api.GradientMachine
    :param low: lower bound of the uniform distribution (default -1.0).
    :param high: upper bound of the uniform distribution (default 1.0).
    """
    assert isinstance(network, api.GradientMachine)
    for each_param in network.getParameters():
        assert isinstance(each_param, api.Parameter)
        array = each_param.getBuf(api.PARAMETER_VALUE).toNumpyArrayInplace()
        assert isinstance(array, np.ndarray)
        # One vectorized draw per parameter instead of one Python-level
        # np.random.uniform call per element. Assigning through ``[...]``
        # writes into the existing array rather than rebinding the name.
        array[...] = np.random.uniform(low, high, size=array.shape)
8 years ago
def main():
    """Build the MNIST network from its config and run a skeleton pass loop.

    Steps:

    1. Initialize Paddle with ``-use_gpu=false`` and ``-trainer_count=4``.
    2. Parse ``simple_mnist_network.py``.
    3. Create a GradientMachine and a local ParameterUpdater from the parsed
       config, and randomize the parameters via ``init_parameter``.
    4. Run 100 passes.

    The passes do not yet feed any data (see the TODO in the loop).
    """
    api.initPaddle("-use_gpu=false", "-trainer_count=4")  # use 4 cpu cores
    config = paddle.trainer.config_parser.parse_config(
        'simple_mnist_network.py', '')

    opt_config = api.OptimizationConfig.createFromProto(config.opt_config)
    # A throwaway optimizer is created only to query which parameter buffer
    # types it needs; those types are handed to the GradientMachine below.
    _temp_optimizer_ = api.ParameterOptimizer.create(opt_config)
    enable_types = _temp_optimizer_.getParameterTypes()

    m = api.GradientMachine.createFromConfigProto(
        config.model_config, api.CREATE_MODE_NORMAL, enable_types)
    assert isinstance(m, api.GradientMachine)
    init_parameter(network=m)

    updater = api.ParameterUpdater.createLocalUpdater(opt_config)
    assert isinstance(updater, api.ParameterUpdater)
    updater.init(m)

    m.start()
    try:
        for _ in range(100):
            updater.startPass()
            # TODO: no data is fed yet -- the per-batch training steps still
            # have to be added between startPass() and finishPass().
            updater.finishPass()  # balance every startPass()
    finally:
        # Pair m.start() with m.finish() even if a pass raises.
        m.finish()
8 years ago
if __name__ == "__main__":
    # Run the demo only when executed as a script, not when imported.
    main()