|
|
|
@ -87,11 +87,8 @@ def load_mnist_data(imageFile):
|
|
|
|
|
else:
|
|
|
|
|
n = 10000
|
|
|
|
|
|
|
|
|
|
data = numpy.zeros((n, 28*28), dtype = "float32")
|
|
|
|
|
|
|
|
|
|
for i in range(n):
|
|
|
|
|
pixels = numpy.fromfile(f, 'ubyte', count=28*28)
|
|
|
|
|
data[i, :] = pixels / 255.0 * 2.0 - 1.0
|
|
|
|
|
data = numpy.fromfile(f, 'ubyte', count=n*28*28).reshape((n, 28*28))
|
|
|
|
|
data = data / 255.0 * 2.0 - 1.0
|
|
|
|
|
|
|
|
|
|
f.close()
|
|
|
|
|
return data
|
|
|
|
@ -235,7 +232,7 @@ def main():
|
|
|
|
|
else:
|
|
|
|
|
data_np = load_uniform_data()
|
|
|
|
|
|
|
|
|
|
# this create a gradient machine for discriminator
|
|
|
|
|
# this creates a gradient machine for discriminator
|
|
|
|
|
dis_training_machine = api.GradientMachine.createFromConfigProto(
|
|
|
|
|
dis_conf.model_config)
|
|
|
|
|
# this create a gradient machine for generator
|
|
|
|
@ -243,7 +240,7 @@ def main():
|
|
|
|
|
gen_conf.model_config)
|
|
|
|
|
|
|
|
|
|
# generator_machine is used to generate data only, which is used for
|
|
|
|
|
# training discrinator
|
|
|
|
|
# training discriminator
|
|
|
|
|
logger.info(str(generator_conf.model_config))
|
|
|
|
|
generator_machine = api.GradientMachine.createFromConfigProto(
|
|
|
|
|
generator_conf.model_config)
|
|
|
|
|