diff --git a/model_zoo/official/cv/alexnet/train.py b/model_zoo/official/cv/alexnet/train.py index 59aa28f18e..cf0c2453eb 100644 --- a/model_zoo/official/cv/alexnet/train.py +++ b/model_zoo/official/cv/alexnet/train.py @@ -55,8 +55,12 @@ if __name__ == "__main__": parser.add_argument('--device_id', type=int, default=0, help='device id of GPU or Ascend. (Default: 0)') args = parser.parse_args() + device_num = int(os.environ.get("DEVICE_NUM", 1)) if args.dataset_name == "cifar10": cfg = alexnet_cifar10_cfg + if device_num > 1: + cfg.learning_rate = cfg.learning_rate * device_num + cfg.epoch_size = cfg.epoch_size * 2 elif args.dataset_name == "imagenet": cfg = alexnet_imagenet_cfg else: @@ -65,14 +69,11 @@ if __name__ == "__main__": device_target = args.device_target context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) context.set_context(save_graphs=False) - device_num = int(os.environ.get("DEVICE_NUM", 1)) if device_target == "Ascend": context.set_context(device_id=args.device_id) if device_num > 1: - cfg.learning_rate = cfg.learning_rate * device_num - cfg.epoch_size = cfg.epoch_size * 2 context.reset_auto_parallel_context() context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True)