add gpu_id flag in demo/gan

avx_docs
wangyang59 9 years ago
parent d8aada072b
commit 4878f0783b

@ -184,13 +184,16 @@ def main():
parser.add_argument("-d", "--dataSource", help="mnist or cifar")
parser.add_argument("--useGpu", default="1",
help="1 means use gpu for training")
parser.add_argument("--gpuId", default="0",
help="the gpu_id parameter")
args = parser.parse_args()
dataSource = args.dataSource
useGpu = args.useGpu
assert dataSource in ["mnist", "cifar"]
assert useGpu in ["0", "1"]
api.initPaddle('--use_gpu=' + useGpu, '--dot_period=10', '--log_period=100')
api.initPaddle('--use_gpu=' + useGpu, '--dot_period=10', '--log_period=100',
'--gpu_id=' + args.gpuId)
gen_conf = parse_config("gan_conf_image.py", "mode=generator_training,data=" + dataSource)
dis_conf = parse_config("gan_conf_image.py", "mode=discriminator_training,data=" + dataSource)
generator_conf = parse_config("gan_conf_image.py", "mode=generator,data=" + dataSource)

Loading…
Cancel
Save