|
|
|
@ -55,8 +55,12 @@ if __name__ == "__main__":
|
|
|
|
|
parser.add_argument('--device_id', type=int, default=0, help='device id of GPU or Ascend. (Default: 0)')
|
|
|
|
|
args = parser.parse_args()
|
|
|
|
|
|
|
|
|
|
device_num = int(os.environ.get("DEVICE_NUM", 1))
|
|
|
|
|
if args.dataset_name == "cifar10":
|
|
|
|
|
cfg = alexnet_cifar10_cfg
|
|
|
|
|
if device_num > 1:
|
|
|
|
|
cfg.learning_rate = cfg.learning_rate * device_num
|
|
|
|
|
cfg.epoch_size = cfg.epoch_size * 2
|
|
|
|
|
elif args.dataset_name == "imagenet":
|
|
|
|
|
cfg = alexnet_imagenet_cfg
|
|
|
|
|
else:
|
|
|
|
@ -65,14 +69,11 @@ if __name__ == "__main__":
|
|
|
|
|
device_target = args.device_target
|
|
|
|
|
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
|
|
|
|
context.set_context(save_graphs=False)
|
|
|
|
|
device_num = int(os.environ.get("DEVICE_NUM", 1))
|
|
|
|
|
|
|
|
|
|
if device_target == "Ascend":
|
|
|
|
|
context.set_context(device_id=args.device_id)
|
|
|
|
|
|
|
|
|
|
if device_num > 1:
|
|
|
|
|
cfg.learning_rate = cfg.learning_rate * device_num
|
|
|
|
|
cfg.epoch_size = cfg.epoch_size * 2
|
|
|
|
|
context.reset_auto_parallel_context()
|
|
|
|
|
context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
|
|
|
|
|
gradients_mean=True)
|
|
|
|
|