|
|
|
@ -57,7 +57,9 @@ if __name__ == '__main__':
|
|
|
|
|
device_id = int(os.getenv('DEVICE_ID'))
|
|
|
|
|
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=device_id)
|
|
|
|
|
context.reset_auto_parallel_context()
|
|
|
|
|
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True)
|
|
|
|
|
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL,
|
|
|
|
|
gradients_mean=True,
|
|
|
|
|
all_reduce_fusion_config=[9, 11])
|
|
|
|
|
init()
|
|
|
|
|
rank_id = int(os.environ.get('RANK_ID'))
|
|
|
|
|
elif args_opt.device_target == "GPU":
|
|
|
|
|