@ -31,8 +31,8 @@ def plot2DScatter(data, outputfile):
'''
x = data [ : , 0 ]
y = data [ : , 1 ]
print " The mean vector is %s " % numpy . mean ( data , 0 )
print " The std vector is %s " % numpy . std ( data , 0 )
logger . info ( " The mean vector is %s " % numpy . mean ( data , 0 ) )
logger . info ( " The std vector is %s " % numpy . std ( data , 0 ) )
heatmap , xedges , yedges = numpy . histogram2d ( x , y , bins = 50 )
extent = [ xedges [ 0 ] , xedges [ - 1 ] , yedges [ 0 ] , yedges [ - 1 ] ]
@ -192,42 +192,42 @@ def get_layer_size(model_conf, layer_name):
def main ( ) :
parser = argparse . ArgumentParser ( )
parser . add_argument ( " -d " , " --data S ource" , help = " mnist or cifar or uniform " )
parser . add_argument ( " --use G pu" , default = " 1 " ,
parser . add_argument ( " -d " , " --data _s ource" , help = " mnist or cifar or uniform " )
parser . add_argument ( " --use _g pu" , default = " 1 " ,
help = " 1 means use gpu for training " )
parser . add_argument ( " --gpu I d" , default = " 0 " ,
parser . add_argument ( " --gpu _i d" , default = " 0 " ,
help = " the gpu_id parameter " )
args = parser . parse_args ( )
data Source = args . dataS ource
use Gpu = args . useG pu
assert data S ource in [ " mnist " , " cifar " , " uniform " ]
assert use G pu in [ " 0 " , " 1 " ]
data _source = args . data_s ource
use _gpu = args . use_g pu
assert data _s ource in [ " mnist " , " cifar " , " uniform " ]
assert use _g pu in [ " 0 " , " 1 " ]
if not os . path . exists ( " ./ %s _samples/ " % data S ource) :
os . makedirs ( " ./ %s _samples/ " % data S ource)
if not os . path . exists ( " ./ %s _samples/ " % data _s ource) :
os . makedirs ( " ./ %s _samples/ " % data _s ource)
if not os . path . exists ( " ./ %s _params/ " % data S ource) :
os . makedirs ( " ./ %s _params/ " % data S ource)
if not os . path . exists ( " ./ %s _params/ " % data _s ource) :
os . makedirs ( " ./ %s _params/ " % data _s ource)
api . initPaddle ( ' --use_gpu= ' + use G pu, ' --dot_period=10 ' , ' --log_period=100 ' ,
' --gpu_id= ' + args . gpu I d, ' --save_dir= ' + " ./ %s _params/ " % data S ource)
api . initPaddle ( ' --use_gpu= ' + use _g pu, ' --dot_period=10 ' , ' --log_period=100 ' ,
' --gpu_id= ' + args . gpu _i d, ' --save_dir= ' + " ./ %s _params/ " % data _s ource)
if data S ource == " uniform " :
if data _s ource == " uniform " :
conf = " gan_conf.py "
num_iter = 10000
else :
conf = " gan_conf_image.py "
num_iter = 1000
gen_conf = parse_config ( conf , " mode=generator_training,data= " + data S ource)
dis_conf = parse_config ( conf , " mode=discriminator_training,data= " + data S ource)
generator_conf = parse_config ( conf , " mode=generator,data= " + data S ource)
gen_conf = parse_config ( conf , " mode=generator_training,data= " + data _s ource)
dis_conf = parse_config ( conf , " mode=discriminator_training,data= " + data _s ource)
generator_conf = parse_config ( conf , " mode=generator,data= " + data _s ource)
batch_size = dis_conf . opt_config . batch_size
noise_dim = get_layer_size ( gen_conf . model_config , " noise " )
if data S ource == " mnist " :
if data _s ource == " mnist " :
data_np = load_mnist_data ( " ./data/mnist_data/train-images-idx3-ubyte " )
elif data S ource == " cifar " :
elif data _s ource == " cifar " :
data_np = load_cifar_data ( " ./data/cifar-10-batches-py/ " )
else :
data_np = load_uniform_data ( )
@ -308,7 +308,9 @@ def main():
else :
curr_train = " gen "
curr_strike = 1
gen_trainer . trainOneDataBatch ( batch_size , data_batch_gen )
gen_trainer . trainOneDataBatch ( batch_size , data_batch_gen )
# TODO: add API for paddle to allow true parameter sharing between different GradientMachines
# so that we do not need to copy shared parameters.
copy_shared_parameters ( gen_training_machine , dis_training_machine )
copy_shared_parameters ( gen_training_machine , generator_machine )
@ -316,10 +318,10 @@ def main():
gen_trainer . finishTrainPass ( )
# At the end of each pass, save the generated samples/images
fake_samples = get_fake_samples ( generator_machine , batch_size , noise )
if data S ource == " uniform " :
plot2DScatter ( fake_samples , " ./ %s _samples/train_pass %s .png " % ( data S ource, train_pass ) )
if data _s ource == " uniform " :
plot2DScatter ( fake_samples , " ./ %s _samples/train_pass %s .png " % ( data _s ource, train_pass ) )
else :
save_images ( fake_samples , " ./ %s _samples/train_pass %s .png " % ( data S ource, train_pass ) )
save_images ( fake_samples , " ./ %s _samples/train_pass %s .png " % ( data _s ource, train_pass ) )
dis_trainer . finishTrain ( )
gen_trainer . finishTrain ( )