From: @wukesong
Reviewed-by: @yingjy,@oacjiewen
Signed-off-by: @yingjy
pull/8924/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 0e58beea2f

@ -55,8 +55,12 @@ if __name__ == "__main__":
parser.add_argument('--device_id', type=int, default=0, help='device id of GPU or Ascend. (Default: 0)')
args = parser.parse_args()
device_num = int(os.environ.get("DEVICE_NUM", 1))
if args.dataset_name == "cifar10":
cfg = alexnet_cifar10_cfg
if device_num > 1:
cfg.learning_rate = cfg.learning_rate * device_num
cfg.epoch_size = cfg.epoch_size * 2
elif args.dataset_name == "imagenet":
cfg = alexnet_imagenet_cfg
else:
@ -65,14 +69,11 @@ if __name__ == "__main__":
device_target = args.device_target
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
context.set_context(save_graphs=False)
device_num = int(os.environ.get("DEVICE_NUM", 1))
if device_target == "Ascend":
context.set_context(device_id=args.device_id)
if device_num > 1:
cfg.learning_rate = cfg.learning_rate * device_num
cfg.epoch_size = cfg.epoch_size * 2
context.reset_auto_parallel_context()
context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
gradients_mean=True)

Loading…
Cancel
Save