@@ -113,24 +113,24 @@ class _AutoParallelContext:
         self.check_context_handle()
         return self._context_handle.get_mirror_mean()

-    def set_cast_before_mirror(self, cast_before_mirror):
+    def set_gradient_fp32_sync(self, gradient_fp32_sync):
         """
-        Set cast_before_mirror.
+        Set gradient_fp32_sync.

         Note:
-            If cast_before_mirror is true,
+            If gradient_fp32_sync is true,
             it will convert tensor type from fp16 to fp32 before parameter gradients allreduce.

         Args:
-            cast_before_mirror (bool): The cast_before_mirror flag.
+            gradient_fp32_sync (bool): The gradient_fp32_sync flag.
         """
         self.check_context_handle()
-        self._context_handle.set_cast_before_mirror(cast_before_mirror)
+        self._context_handle.set_gradient_fp32_sync(gradient_fp32_sync)

-    def get_cast_before_mirror(self):
-        """Get cast_before_mirror flag."""
+    def get_gradient_fp32_sync(self):
+        """Get gradient_fp32_sync flag."""
         self.check_context_handle()
-        return self._context_handle.get_cast_before_mirror()
+        return self._context_handle.get_gradient_fp32_sync()

     def set_loss_repeated_mean(self, loss_repeated_mean):
         """
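The Note above fixes the semantics behind the rename: with the flag enabled, fp16 gradients are cast to fp32 before the mirror allreduce, so the cross-device sum accumulates in fp32. A minimal NumPy sketch of that behavior; `simulated_allreduce` is a stand-in for the real collective and is not part of this diff:

```python
import numpy as np

# Stand-in for the parameter-gradient allreduce described in the Note:
# with gradient_fp32_sync=True the fp16 gradients are cast to fp32 first,
# so the cross-device sum accumulates in fp32.
def simulated_allreduce(grads_per_device, gradient_fp32_sync=True):
    if gradient_fp32_sync:
        grads_per_device = [g.astype(np.float32) for g in grads_per_device]
    return sum(grads_per_device)

fp16_grads = [np.full(1024, 0.01, dtype=np.float16) for _ in range(8)]
print(simulated_allreduce(fp16_grads, gradient_fp32_sync=True).dtype)   # float32
print(simulated_allreduce(fp16_grads, gradient_fp32_sync=False).dtype)  # float16
```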
@@ -152,21 +152,6 @@ class _AutoParallelContext:
         self.check_context_handle()
         return self._context_handle.get_loss_repeated_mean()

-    def set_communication_backend(self, communication_backend):
-        """
-        Set communication backend.
-
-        Args:
-            communication_backend (str): The communication backend.
-        """
-        self.check_context_handle()
-        self._context_handle.set_communication_backend(communication_backend)
-
-    def get_communication_backend(self):
-        """Get communication backend."""
-        self.check_context_handle()
-        return self._context_handle.get_communication_backend()
-
     def set_parallel_mode(self, parallel_mode):
         """
         Set parallel mode for auto parallel.
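This hunk removes `set_communication_backend` and `get_communication_backend` outright, so any external caller of these accessors breaks on versions that include this change. A hedged compatibility sketch for such downstream code; the `set_backend_if_supported` helper and the `"hccl"` default are illustrative assumptions, not part of this diff:

```python
# Defensive shim for code that must run both before and after this change;
# `ctx` is assumed to be the object returned by auto_parallel_context().
def set_backend_if_supported(ctx, backend="hccl"):
    setter = getattr(ctx, "set_communication_backend", None)
    if callable(setter):
        setter(backend)  # older versions still expose the accessor
        return True
    return False         # removed here; the backend is configured elsewhere
```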
@@ -469,7 +454,7 @@ _set_auto_parallel_context_func_map = {
     "device_num": auto_parallel_context().set_device_num,
     "global_rank": auto_parallel_context().set_global_rank,
     "mirror_mean": auto_parallel_context().set_mirror_mean,
-    "cast_before_mirror": auto_parallel_context().set_cast_before_mirror,
+    "gradient_fp32_sync": auto_parallel_context().set_gradient_fp32_sync,
     "loss_repeated_mean": auto_parallel_context().set_loss_repeated_mean,
     "parallel_mode": auto_parallel_context().set_parallel_mode,
     "auto_parallel_search_mode": auto_parallel_context().set_strategy_search_mode,
@@ -484,7 +469,7 @@ _get_auto_parallel_context_func_map = {
     "device_num": auto_parallel_context().get_device_num,
     "global_rank": auto_parallel_context().get_global_rank,
     "mirror_mean": auto_parallel_context().get_mirror_mean,
-    "cast_before_mirror": auto_parallel_context().get_cast_before_mirror,
+    "gradient_fp32_sync": auto_parallel_context().get_gradient_fp32_sync,
     "loss_repeated_mean": auto_parallel_context().get_loss_repeated_mean,
     "parallel_mode": auto_parallel_context().get_parallel_mode,
     "auto_parallel_search_mode": auto_parallel_context().get_strategy_search_mode,
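Both tables feed the keyword-dispatch loop used by `_set_auto_parallel_context` and `_get_auto_parallel_context`, which is why the rename only has to touch the dict keys. A simplified sketch of that dispatch pattern, assuming a loop of roughly this shape rather than the module's exact body:

```python
# Keyword dispatch over a name -> setter map: unknown keys fail fast,
# recognized keys route the value to the bound setter.
def _set_from_map(func_map, **kwargs):
    for key, value in kwargs.items():
        if key not in func_map:
            raise ValueError("Set context keyword %s is not recognized!" % key)
        func_map[key](value)

# After this diff, callers use the new key, e.g.:
# _set_from_map(_set_auto_parallel_context_func_map, gradient_fp32_sync=False)
```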
@@ -495,7 +480,7 @@ _get_auto_parallel_context_func_map = {
     "enable_parallel_optimizer": auto_parallel_context().get_enable_parallel_optimizer}


-@args_type_check(device_num=int, global_rank=int, mirror_mean=bool, cast_before_mirror=bool,
+@args_type_check(device_num=int, global_rank=int, mirror_mean=bool, gradient_fp32_sync=bool,
                 loss_repeated_mean=bool, parallel_mode=str, auto_parallel_search_mode=str,
                 parameter_broadcast=bool, strategy_ckpt_load_file=str,
                 strategy_ckpt_save_file=str, full_batch=bool, enable_parallel_optimizer=bool)
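The decorator change above is presumably what makes a mistyped keyword fail early, e.g. `gradient_fp32_sync=1` rejected while `gradient_fp32_sync=True` passes. A minimal sketch of an `args_type_check`-style decorator, written only from the behavior implied here; the real implementation ships with MindSpore's parameter-checking utilities:

```python
import functools

def args_type_check(**type_map):
    """Reject keyword arguments whose values don't match the declared types."""
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            for name, value in kwargs.items():
                expected = type_map.get(name)
                if expected is not None and not isinstance(value, expected):
                    raise TypeError("%s must be %s, but got %s" %
                                    (name, expected.__name__, type(value).__name__))
            return func(*args, **kwargs)
        return wrapper
    return decorator
```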
@@ -512,8 +497,9 @@ def _set_auto_parallel_context(**kwargs):
         global_rank (int): Global rank id, the value must be in [0, 4095]. Default: 0.
         mirror_mean (bool): Whether to perform mean operator after all-reduce of mirror. Default: False.
         loss_repeated_mean (bool): Whether to perform mean operator in backward in the case of repeated
             calculations. Default: True.
-        cast_before_mirror (bool): Insert Mirror Op after the cast if this flag is True. Default: True.
+        gradient_fp32_sync (bool): Gradients are allreduced in fp32 even if they are fp16, when this flag
+            is True. Default: True.
         parallel_mode (str): There are five kinds of parallel modes, "stand_alone", "data_parallel",
             "hybrid_parallel", "semi_auto_parallel" and "auto_parallel". Default: "stand_alone".
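For users the migration is a one-for-one keyword swap through the public wrapper. A hedged usage sketch, assuming a MindSpore build that includes this change:

```python
from mindspore import context

# cast_before_mirror=True becomes gradient_fp32_sync=True; all other
# keywords are unchanged.
context.set_auto_parallel_context(parallel_mode="data_parallel",
                                  mirror_mean=True,
                                  gradient_fp32_sync=True)
```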
@@ -577,7 +563,7 @@ def _reset_auto_parallel_context():
     - device_num: 1.
     - global_rank: 0.
     - mirror_mean: False.
-    - cast_before_mirror: True.
+    - gradient_fp32_sync: True.
     - parallel_mode: "stand_alone".
     - parameter_broadcast: False.
     - strategy_ckpt_load_file: ""
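The reset path keeps the old default value, so only the name changes. A quick check, again assuming a build with this change; `get_auto_parallel_context` is queried by attribute key:

```python
from mindspore import context

context.reset_auto_parallel_context()
# The renamed flag resets to True, matching the old cast_before_mirror default.
assert context.get_auto_parallel_context("gradient_fp32_sync")
```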