|
|
|
@ -78,7 +78,7 @@ if __name__ == '__main__':
|
|
|
|
|
context.set_auto_parallel_context(device_num=args_opt.device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
|
|
|
|
|
gradients_mean=True)
|
|
|
|
|
if args_opt.net == "resnet50" or args_opt.net == "se-resnet50":
|
|
|
|
|
context.set_auto_parallel_context(all_reduce_fusion_config=[85, 150])
|
|
|
|
|
context.set_auto_parallel_context(all_reduce_fusion_config=[85, 160])
|
|
|
|
|
else:
|
|
|
|
|
context.set_auto_parallel_context(all_reduce_fusion_config=[180, 313])
|
|
|
|
|
init()
|
|
|
|
|