|
|
|
@ -90,10 +90,8 @@ def load_mnist_data(imageFile):
|
|
|
|
|
data = numpy.zeros((n, 28*28), dtype = "float32")
|
|
|
|
|
|
|
|
|
|
for i in range(n):
|
|
|
|
|
pixels = []
|
|
|
|
|
for j in range(28 * 28):
|
|
|
|
|
pixels.append(float(ord(f.read(1))) / 255.0 * 2.0 - 1.0)
|
|
|
|
|
data[i, :] = pixels
|
|
|
|
|
pixels = numpy.fromfile(f, 'ubyte', count=28*28)
|
|
|
|
|
data[i, :] = pixels / 255.0 * 2.0 - 1.0
|
|
|
|
|
|
|
|
|
|
f.close()
|
|
|
|
|
return data
|
|
|
|
@ -129,7 +127,7 @@ def merge(images, size):
|
|
|
|
|
((images[idx, :].reshape((h, w, c), order="F").transpose(1, 0, 2) + 1.0) / 2.0 * 255.0)
|
|
|
|
|
return img.astype('uint8')
|
|
|
|
|
|
|
|
|
|
def saveImages(images, path):
|
|
|
|
|
def save_images(images, path):
|
|
|
|
|
merged_img = merge(images, [8, 8])
|
|
|
|
|
if merged_img.shape[2] == 1:
|
|
|
|
|
im = Image.fromarray(numpy.squeeze(merged_img)).convert('RGB')
|
|
|
|
@ -208,8 +206,14 @@ def main():
|
|
|
|
|
assert dataSource in ["mnist", "cifar", "uniform"]
|
|
|
|
|
assert useGpu in ["0", "1"]
|
|
|
|
|
|
|
|
|
|
if not os.path.exists("./%s_samples/" % dataSource):
|
|
|
|
|
os.makedirs("./%s_samples/" % dataSource)
|
|
|
|
|
|
|
|
|
|
if not os.path.exists("./%s_params/" % dataSource):
|
|
|
|
|
os.makedirs("./%s_params/" % dataSource)
|
|
|
|
|
|
|
|
|
|
api.initPaddle('--use_gpu=' + useGpu, '--dot_period=10', '--log_period=100',
|
|
|
|
|
'--gpu_id=' + args.gpuId)
|
|
|
|
|
'--gpu_id=' + args.gpuId, '--save_dir=' + "./%s_params/" % dataSource)
|
|
|
|
|
|
|
|
|
|
if dataSource == "uniform":
|
|
|
|
|
conf = "gan_conf.py"
|
|
|
|
@ -231,9 +235,6 @@ def main():
|
|
|
|
|
else:
|
|
|
|
|
data_np = load_uniform_data()
|
|
|
|
|
|
|
|
|
|
if not os.path.exists("./%s_samples/" % dataSource):
|
|
|
|
|
os.makedirs("./%s_samples/" % dataSource)
|
|
|
|
|
|
|
|
|
|
# this create a gradient machine for discriminator
|
|
|
|
|
dis_training_machine = api.GradientMachine.createFromConfigProto(
|
|
|
|
|
dis_conf.model_config)
|
|
|
|
@ -321,7 +322,7 @@ def main():
|
|
|
|
|
if dataSource == "uniform":
|
|
|
|
|
plot2DScatter(fake_samples, "./%s_samples/train_pass%s.png" % (dataSource, train_pass))
|
|
|
|
|
else:
|
|
|
|
|
saveImages(fake_samples, "./%s_samples/train_pass%s.png" % (dataSource, train_pass))
|
|
|
|
|
save_images(fake_samples, "./%s_samples/train_pass%s.png" % (dataSource, train_pass))
|
|
|
|
|
dis_trainer.finishTrain()
|
|
|
|
|
gen_trainer.finishTrain()
|
|
|
|
|
|
|
|
|
|