minor changes on demo/gan following lzhao4ever comments

avx_docs
wangyang59 9 years ago
parent 531e83542c
commit 5aa597960d

@ -9,4 +9,5 @@ Then you can run the command below. The flag -d specifies the training data (cif
$python gan_trainer.py -d cifar --useGpu 1 $python gan_trainer.py -d cifar --useGpu 1
The generated images will be stored in ./cifar_samples/ The generated images will be stored in ./cifar_samples/
The corresponding models will be stored in ./cifar_params/

@ -1,5 +1,5 @@
#!/usr/bin/env sh #!/usr/bin/env sh
# This scripts downloads the mnist data and unzips it. # This script downloads the mnist data and unzips it.
set -e set -e
DIR="$( cd "$(dirname "$0")" ; pwd -P )" DIR="$( cd "$(dirname "$0")" ; pwd -P )"
rm -rf "$DIR/mnist_data" rm -rf "$DIR/mnist_data"

@ -38,7 +38,7 @@ sample_dim = 2
settings( settings(
batch_size=128, batch_size=128,
learning_rate=1e-4, learning_rate=1e-4,
learning_method=AdamOptimizer(beta1=0.7) learning_method=AdamOptimizer(beta1=0.5)
) )
def discriminator(sample): def discriminator(sample):

@ -87,11 +87,8 @@ def load_mnist_data(imageFile):
else: else:
n = 10000 n = 10000
data = numpy.zeros((n, 28*28), dtype = "float32") data = numpy.fromfile(f, 'ubyte', count=n*28*28).reshape((n, 28*28))
data = data / 255.0 * 2.0 - 1.0
for i in range(n):
pixels = numpy.fromfile(f, 'ubyte', count=28*28)
data[i, :] = pixels / 255.0 * 2.0 - 1.0
f.close() f.close()
return data return data
@ -235,7 +232,7 @@ def main():
else: else:
data_np = load_uniform_data() data_np = load_uniform_data()
# this create a gradient machine for discriminator # this creates a gradient machine for discriminator
dis_training_machine = api.GradientMachine.createFromConfigProto( dis_training_machine = api.GradientMachine.createFromConfigProto(
dis_conf.model_config) dis_conf.model_config)
# this create a gradient machine for generator # this create a gradient machine for generator
@ -243,7 +240,7 @@ def main():
gen_conf.model_config) gen_conf.model_config)
# generator_machine is used to generate data only, which is used for # generator_machine is used to generate data only, which is used for
# training discrinator # training discriminator
logger.info(str(generator_conf.model_config)) logger.info(str(generator_conf.model_config))
generator_machine = api.GradientMachine.createFromConfigProto( generator_machine = api.GradientMachine.createFromConfigProto(
generator_conf.model_config) generator_conf.model_config)

Loading…
Cancel
Save