@@ -480,6 +480,26 @@ class _AutoParallelContext:
         self.check_context_handle()
         return self._context_handle.get_enable_parallel_optimizer()
 
+    def set_communi_parallel_mode(self, communi_parallel_mode):
+        """
+        Set communication parallel mode.
+
+        Args:
+            communi_parallel_mode (str): The communication parallel mode.
+
+        Raises:
+            ValueError: If parallel mode is not supported.
+        """
+        self.check_context_handle()
+        ret = self._context_handle.set_communi_parallel_mode(communi_parallel_mode)
+        if ret is False:
+            raise ValueError("Communication parallel mode does not support {}".format(communi_parallel_mode))
+
+    def get_communi_parallel_mode(self):
+        """Get communication parallel mode."""
+        self.check_context_handle()
+        return self._context_handle.get_communi_parallel_mode()
+
     def reset(self):
         """Reset all settings."""
         self.check_context_handle()
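
Reviewer note (not part of the patch): a minimal usage sketch of the two methods added above, assuming a build that includes this change; auto_parallel_context() is the module-level accessor already defined in this file.

    from mindspore.parallel._auto_parallel_context import auto_parallel_context

    # Choose how communication groups run; an unsupported string makes the
    # context handle return False, which this setter surfaces as ValueError.
    auto_parallel_context().set_communi_parallel_mode("same_server_group_parallel")
    assert auto_parallel_context().get_communi_parallel_mode() == "same_server_group_parallel"
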
@@ -518,7 +538,8 @@ _set_auto_parallel_context_func_map = {
     "full_batch": auto_parallel_context().set_full_batch,
     "enable_parallel_optimizer": auto_parallel_context().set_enable_parallel_optimizer,
     "grad_accumulation_step": auto_parallel_context().set_grad_accumulation_step,
-    "all_reduce_fusion_config": auto_parallel_context().set_all_reduce_fusion_split_indices}
+    "all_reduce_fusion_config": auto_parallel_context().set_all_reduce_fusion_split_indices,
+    "communi_parallel_mode": auto_parallel_context().set_communi_parallel_mode}
 
 
 _get_auto_parallel_context_func_map = {
@@ -536,14 +557,16 @@ _get_auto_parallel_context_func_map = {
     "full_batch": auto_parallel_context().get_full_batch,
     "enable_parallel_optimizer": auto_parallel_context().get_enable_parallel_optimizer,
     "grad_accumulation_step": auto_parallel_context().get_grad_accumulation_step,
-    "all_reduce_fusion_config": auto_parallel_context().get_all_reduce_fusion_split_indices}
+    "all_reduce_fusion_config": auto_parallel_context().get_all_reduce_fusion_split_indices,
+    "communi_parallel_mode": auto_parallel_context().get_communi_parallel_mode}
 
 
 @args_type_check(device_num=int, global_rank=int, gradients_mean=bool, gradient_fp32_sync=bool,
                  loss_repeated_mean=bool, parallel_mode=str, auto_parallel_search_mode=str,
                  parameter_broadcast=bool, strategy_ckpt_load_file=str,
                  strategy_ckpt_save_file=str, full_batch=bool, enable_parallel_optimizer=bool,
-                 grad_accumulation_step=int, all_reduce_fusion_config=list, group_ckpt_save_file=str)
+                 grad_accumulation_step=int, all_reduce_fusion_config=list, group_ckpt_save_file=str,
+                 communi_parallel_mode=str)
 def _set_auto_parallel_context(**kwargs):
     """
     Set auto parallel context.
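
Reviewer note (not part of the patch): these two map entries are what expose the new knob through the kwargs interface. A sketch of the dispatch they plug into, paraphrased from this module rather than quoted from the diff:

    # _set_auto_parallel_context(**kwargs) walks its kwargs and routes each
    # key through the setter map, so unrecognized keys fail fast.
    for key, value in kwargs.items():
        if key not in _set_auto_parallel_context_func_map:
            raise ValueError("Set context keyword %s is not recognized!" % key)
        _set_auto_parallel_context_func_map[key](value)
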
@@ -592,6 +615,14 @@ def _set_auto_parallel_context(**kwargs):
             the devices are distributed along the pipeline. The total devices will be divided into
             'pipeline_stages' stages. This currently could only be used when
             parallel mode semi_auto_parallel is enabled. Default: 0
+        communi_parallel_mode (str): There are three kinds of communication parallel modes, "all_group_parallel",
+                     "same_server_group_parallel" and "no_group_parallel". Default: "all_group_parallel".
+
+                     - all_group_parallel: All communication groups are in parallel.
+
+                     - same_server_group_parallel: Only the communication groups within the same server are parallel.
+
+                     - no_group_parallel: All communication groups are not parallel.
 
     Raises:
         ValueError: If input key is not attribute in auto parallel context.
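
Reviewer note (not part of the patch): a hedged end-to-end example of the new keyword. It calls the private helpers in this module directly; whether the public context.set_auto_parallel_context wrapper also forwards this kwarg is assumed, not shown in this diff.

    from mindspore.parallel._auto_parallel_context import (
        _set_auto_parallel_context, _get_auto_parallel_context)

    # Serialize communication groups entirely; any string outside the three
    # documented modes raises ValueError from set_communi_parallel_mode.
    _set_auto_parallel_context(communi_parallel_mode="no_group_parallel")
    assert _get_auto_parallel_context("communi_parallel_mode") == "no_group_parallel"
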