|
|
|
@ -184,13 +184,16 @@ def main():
|
|
|
|
|
parser.add_argument("-d", "--dataSource", help="mnist or cifar")
|
|
|
|
|
parser.add_argument("--useGpu", default="1",
|
|
|
|
|
help="1 means use gpu for training")
|
|
|
|
|
parser.add_argument("--gpuId", default="0",
|
|
|
|
|
help="the gpu_id parameter")
|
|
|
|
|
args = parser.parse_args()
|
|
|
|
|
dataSource = args.dataSource
|
|
|
|
|
useGpu = args.useGpu
|
|
|
|
|
assert dataSource in ["mnist", "cifar"]
|
|
|
|
|
assert useGpu in ["0", "1"]
|
|
|
|
|
|
|
|
|
|
api.initPaddle('--use_gpu=' + useGpu, '--dot_period=10', '--log_period=100')
|
|
|
|
|
api.initPaddle('--use_gpu=' + useGpu, '--dot_period=10', '--log_period=100',
|
|
|
|
|
'--gpu_id=' + args.gpuId)
|
|
|
|
|
gen_conf = parse_config("gan_conf_image.py", "mode=generator_training,data=" + dataSource)
|
|
|
|
|
dis_conf = parse_config("gan_conf_image.py", "mode=discriminator_training,data=" + dataSource)
|
|
|
|
|
generator_conf = parse_config("gan_conf_image.py", "mode=generator,data=" + dataSource)
|
|
|
|
|