@@ -344,7 +344,7 @@ def _context():
 @args_type_check(device_num=int, global_rank=int, gradients_mean=bool, gradient_fp32_sync=bool, parallel_mode=str,
                  auto_parallel_search_mode=str, parameter_broadcast=bool, strategy_ckpt_load_file=str,
                  strategy_ckpt_save_file=str, full_batch=bool, enable_parallel_optimizer=bool,
-                 all_reduce_fusion_config=list, pipeline_stages=int)
+                 all_reduce_fusion_config=list, pipeline_stages=int, grad_accumulation_step=int)
 def set_auto_parallel_context(**kwargs):
     r"""
     Set auto parallel context, which is valid only for Ascend and GPU target.
@@ -371,6 +371,7 @@ def set_auto_parallel_context(**kwargs):
     all_reduce_fusion_config     strategy_ckpt_save_file
     enable_parallel_optimizer    full_batch
                \                 pipeline_stages
+               \                 grad_accumulation_step
     ===========================  ===========================
 
     Args:
@@ -420,6 +421,8 @@ def set_auto_parallel_context(**kwargs):
             the devices are distributed along the pipeline. The total devices will be divided into
             'pipeline_stages' stages. This can currently only be used when
             parallel mode semi_auto_parallel is enabled. Default: 1.
+        grad_accumulation_step (int): Set the accumulation steps of gradients in auto and semi auto parallel mode.
+            This should be a positive int. Default: 1.
 
     Raises:
         ValueError: If the input key is not an attribute in auto parallel context.
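
For reference, a minimal usage sketch of the new keyword added by this change. Only the grad_accumulation_step argument comes from the diff above; the device target, device_num=8, and the accumulation value of 4 are illustrative assumptions, not part of this patch.

# Minimal sketch, assuming a MindSpore build that includes this change.
from mindspore import context

# Hypothetical setup: graph mode on Ascend (any supported target works).
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")

# Run semi auto parallel across 8 devices (illustrative value) and
# accumulate gradients over 4 steps; per the docstring added above,
# grad_accumulation_step must be a positive int.
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel",
                                  device_num=8,
                                  grad_accumulation_step=4)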