You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
75 lines
2.3 KiB
75 lines
2.3 KiB
from paddle.trainer_config_helpers import *
|
|
from paddle.trainer.PyDataProvider2 import dense_vector, integer_value
|
|
import paddle.v2 as paddle_v2
|
|
import numpy
|
|
import mnist_util
|
|
|
|
|
|
def train_reader():
    """
    Yield training samples one at a time.

    Samples come from ``mnist_util.read_from_mnist`` applied to the raw
    training file under ``./data/raw_data``; each item is passed through
    unchanged.

    :return: generator over the items produced by ``read_from_mnist``
    """
    for sample in mnist_util.read_from_mnist('./data/raw_data/train'):
        yield sample
|
|
|
|
|
|
def network_config():
    """
    Declare the classifier topology used by this demo.

    The network is a 784-wide ``pixel`` data layer, followed by two
    fully connected layers of 200 units each and a 10-way fully connected
    layer with softmax activation.  The classification cost between that
    prediction and the ``label`` data layer is registered as the output.
    """
    pixel_input = data_layer(name='pixel', size=784)
    first_hidden = fc_layer(input=pixel_input, size=200)
    second_hidden = fc_layer(input=first_hidden, size=200)
    prediction = fc_layer(
        input=second_hidden, size=10, act=SoftmaxActivation())

    # The label layer is declared after the prediction layer, matching the
    # order in which the layers were originally created.
    label_input = data_layer(name='label', size=10)
    outputs(classification_cost(input=prediction, label=label_input))
|
|
|
|
|
|
def event_handler(event):
    """
    Report training progress for events passed in by the trainer.

    When ``event`` is a ``paddle_v2.trainer.CompleteTrainOneBatch`` its pass
    id, batch id and cost are printed; any other event is ignored.

    :param event: event object handed to the ``event_handler`` callback
    :return: None
    """
    if isinstance(event, paddle_v2.trainer.CompleteTrainOneBatch):
        # A single parenthesised argument is valid both as a Python 2 print
        # statement and as a Python 3 print() call, and prints the same text.
        print("Pass %d, Batch %d, Cost %f" % (event.pass_id, event.batch_id,
                                              event.cost))
|
|
|
|
|
|
def nag(v, g, vel_t_1):
    """
    Nesterov accelerated gradient (NAG) update, applied in place.

    An optimizer that the Paddle C++ core does not implement.
    Follows https://arxiv.org/pdf/1212.0901v2.pdf eq.6 and eq.7:

        vel_t = mu * vel_{t-1} - e * g
        v_t   = v_{t-1} + mu^2 * vel_{t-1} - (1 + mu) * e * g

    :param v: parameter value, overwritten with the updated parameter
    :param g: parameter gradient (not modified)
    :param vel_t_1: velocity of step t-1, overwritten with the new velocity
    :return: None
    """
    mu = 0.09  # momentum coefficient
    e = 0.0001  # learning rate

    vel_t = mu * vel_t_1 - e * g

    # Eq.7 uses the PREVIOUS velocity in the mu^2 term; using the freshly
    # computed vel_t here would apply the momentum factor one time too many.
    # This must run before vel_t_1 is overwritten below.
    v[:] = v + (mu ** 2) * vel_t_1 - (1 + mu) * e * g
    vel_t_1[:] = vel_t


def main():
    """
    Train the network declared by ``network_config`` with the NAG optimizer.

    Steps: initialise Paddle for CPU with one trainer, parse the network
    configuration, create the parameter pool and fill every parameter with
    uniform random values in [-1, 1), then run ``SGDTrainer.train`` with
    ``nag`` as the update equation and ``event_handler`` for progress output.
    """
    paddle_v2.init(use_gpu=False, trainer_count=1)
    model_config = parse_network_config(network_config)
    pool = paddle_v2.parameters.create(model_config)

    # Random initialisation, written in place into each parameter array.
    for param_name in pool.get_names():
        array = pool.get_parameter(param_name)
        array[:] = numpy.random.uniform(low=-1.0, high=1.0, size=array.shape)

    trainer = paddle_v2.trainer.SGDTrainer(update_equation=nag)

    trainer.train(
        train_data_reader=train_reader,
        topology=model_config,
        parameters=pool,
        event_handler=event_handler,
        batch_size=32,  # batch size should be refactored into the data reader
        data_types={  # data_types will be removed; it should live in the
            # network topology
            'pixel': dense_vector(784),
            'label': integer_value(10)
        })
|
|
|
|
|
|
# Start training only when this file is executed as a script, so that
# importing the module has no side effects.
if __name__ == '__main__':
    main()
|