|
|
|
@ -14,13 +14,15 @@
|
|
|
|
|
|
|
|
|
|
import gzip
|
|
|
|
|
|
|
|
|
|
import paddle.v2.dataset.flowers as flowers
|
|
|
|
|
import paddle.v2.dataset.cifar as cifar
|
|
|
|
|
import paddle.v2 as paddle
|
|
|
|
|
import reader
|
|
|
|
|
import time
|
|
|
|
|
|
|
|
|
|
DATA_DIM = 3 * 224 * 224 # Use 3 * 331 * 331 or 3 * 299 * 299 for Inception-ResNet-v2.
|
|
|
|
|
CLASS_DIM = 102
|
|
|
|
|
DATA_DIM = 3 * 32 * 32
|
|
|
|
|
CLASS_DIM = 10
|
|
|
|
|
BATCH_SIZE = 128
|
|
|
|
|
ts = 0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def vgg(input, nums, class_dim):
|
|
|
|
@ -74,6 +76,7 @@ def vgg19(input, class_dim):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def main():
|
|
|
|
|
global ts
|
|
|
|
|
paddle.init(use_gpu=False, trainer_count=1)
|
|
|
|
|
image = paddle.layer.data(
|
|
|
|
|
name="image", type=paddle.data_type.dense_vector(DATA_DIM))
|
|
|
|
@ -100,13 +103,13 @@ def main():
|
|
|
|
|
|
|
|
|
|
train_reader = paddle.batch(
|
|
|
|
|
paddle.reader.shuffle(
|
|
|
|
|
flowers.train(),
|
|
|
|
|
cifar.train10(),
|
|
|
|
|
# To use other data, replace the above line with:
|
|
|
|
|
# reader.train_reader('train.list'),
|
|
|
|
|
buf_size=1000),
|
|
|
|
|
batch_size=BATCH_SIZE)
|
|
|
|
|
test_reader = paddle.batch(
|
|
|
|
|
flowers.valid(),
|
|
|
|
|
cifar.test10(),
|
|
|
|
|
# To use other data, replace the above line with:
|
|
|
|
|
# reader.test_reader('val.list'),
|
|
|
|
|
batch_size=BATCH_SIZE)
|
|
|
|
@ -120,10 +123,14 @@ def main():
|
|
|
|
|
|
|
|
|
|
# End batch and end pass event handler
|
|
|
|
|
def event_handler(event):
|
|
|
|
|
global ts
|
|
|
|
|
if isinstance(event, paddle.event.BeginIteration):
|
|
|
|
|
ts = time.time()
|
|
|
|
|
if isinstance(event, paddle.event.EndIteration):
|
|
|
|
|
if event.batch_id % 1 == 0:
|
|
|
|
|
print "\nPass %d, Batch %d, Cost %f, %s" % (
|
|
|
|
|
event.pass_id, event.batch_id, event.cost, event.metrics)
|
|
|
|
|
print "\nPass %d, Batch %d, Cost %f, %s, spent: %f" % (
|
|
|
|
|
event.pass_id, event.batch_id, event.cost, event.metrics,
|
|
|
|
|
time.time() - ts)
|
|
|
|
|
if isinstance(event, paddle.event.EndPass):
|
|
|
|
|
with gzip.open('params_pass_%d.tar.gz' % event.pass_id, 'w') as f:
|
|
|
|
|
trainer.save_parameter_to_tar(f)
|
|
|
|
@ -137,3 +144,4 @@ def main():
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
|
|
main()
|
|
|
|
|
|
|
|
|
|