|
|
@ -49,7 +49,7 @@ def context_device_init(config):
|
|
|
|
if config.run_distribute:
|
|
|
|
if config.run_distribute:
|
|
|
|
context.set_auto_parallel_context(device_num=config.rank_size,
|
|
|
|
context.set_auto_parallel_context(device_num=config.rank_size,
|
|
|
|
parallel_mode=ParallelMode.DATA_PARALLEL,
|
|
|
|
parallel_mode=ParallelMode.DATA_PARALLEL,
|
|
|
|
gradients_mean=True, set_all_reduce_fusion_split_indices=[140])
|
|
|
|
gradients_mean=True, all_reduce_fusion_config=[140])
|
|
|
|
init()
|
|
|
|
init()
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
raise ValueError("Only support CPU, GPU and Ascend.")
|
|
|
|
raise ValueError("Only support CPU, GPU and Ascend.")
|
|
|
|