minor changes on demo/gan following lzhao4ever comments

avx_docs
wangyang59 8 years ago
parent 531e83542c
commit 5aa597960d

@ -9,4 +9,5 @@ Then you can run the command below. The flag -d specifies the training data (cif
$python gan_trainer.py -d cifar --useGpu 1
The generated images will be stored in ./cifar_samples/
The generated images will be stored in ./cifar_samples/
The corresponding models will be stored in ./cifar_params/

@ -1,5 +1,5 @@
#!/usr/bin/env sh
# This scripts downloads the mnist data and unzips it.
# This script downloads the mnist data and unzips it.
set -e
DIR="$( cd "$(dirname "$0")" ; pwd -P )"
rm -rf "$DIR/mnist_data"

@ -38,7 +38,7 @@ sample_dim = 2
settings(
batch_size=128,
learning_rate=1e-4,
learning_method=AdamOptimizer(beta1=0.7)
learning_method=AdamOptimizer(beta1=0.5)
)
def discriminator(sample):

@ -87,11 +87,8 @@ def load_mnist_data(imageFile):
else:
n = 10000
data = numpy.zeros((n, 28*28), dtype = "float32")
for i in range(n):
pixels = numpy.fromfile(f, 'ubyte', count=28*28)
data[i, :] = pixels / 255.0 * 2.0 - 1.0
data = numpy.fromfile(f, 'ubyte', count=n*28*28).reshape((n, 28*28))
data = data / 255.0 * 2.0 - 1.0
f.close()
return data
@ -235,7 +232,7 @@ def main():
else:
data_np = load_uniform_data()
# this create a gradient machine for discriminator
# this creates a gradient machine for discriminator
dis_training_machine = api.GradientMachine.createFromConfigProto(
dis_conf.model_config)
# this create a gradient machine for generator
@ -243,7 +240,7 @@ def main():
gen_conf.model_config)
# generator_machine is used to generate data only, which is used for
# training discrinator
# training discriminator
logger.info(str(generator_conf.model_config))
generator_machine = api.GradientMachine.createFromConfigProto(
generator_conf.model_config)

Loading…
Cancel
Save