@@ -95,23 +95,23 @@ class _AutoParallelContext:
         self.check_context_handle()
         return self._context_handle.get_global_rank()
 
-    def set_mirror_mean(self, mirror_mean):
+    def set_gradients_mean(self, gradients_mean):
         """
-        Set mirror_mean flag.
+        Set gradients_mean flag.
 
         Note:
-            If mirror_mean is true, it will insert a div operator after parameter gradients allreduce.
+            If gradients_mean is true, it will insert a div operator after parameter gradients allreduce.
 
         Args:
-            mirror_mean (bool): The mirror_mean flag.
+            gradients_mean (bool): The gradients_mean flag.
         """
         self.check_context_handle()
-        self._context_handle.set_mirror_mean(mirror_mean)
+        self._context_handle.set_gradients_mean(gradients_mean)
 
-    def get_mirror_mean(self):
-        """Get mirror_mean flag."""
+    def get_gradients_mean(self):
+        """Get gradients_mean flag."""
         self.check_context_handle()
-        return self._context_handle.get_mirror_mean()
+        return self._context_handle.get_gradients_mean()
 
     def set_gradient_fp32_sync(self, gradient_fp32_sync):
         """
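
A minimal usage sketch of the renamed accessor pair above, assuming the module path mindspore.parallel._auto_parallel_context that this file lives at:

    # Sketch only: exercises the renamed setter/getter from the hunk above.
    from mindspore.parallel._auto_parallel_context import auto_parallel_context

    ctx = auto_parallel_context()
    ctx.set_gradients_mean(True)   # inserts a div operator after parameter gradients allreduce
    assert ctx.get_gradients_mean() is True
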
@@ -453,7 +453,7 @@ def auto_parallel_context():
 _set_auto_parallel_context_func_map = {
     "device_num": auto_parallel_context().set_device_num,
     "global_rank": auto_parallel_context().set_global_rank,
-    "mirror_mean": auto_parallel_context().set_mirror_mean,
+    "gradients_mean": auto_parallel_context().set_gradients_mean,
     "gradient_fp32_sync": auto_parallel_context().set_gradient_fp32_sync,
     "loss_repeated_mean": auto_parallel_context().set_loss_repeated_mean,
     "parallel_mode": auto_parallel_context().set_parallel_mode,
@@ -468,7 +468,7 @@ _set_auto_parallel_context_func_map = {
 _get_auto_parallel_context_func_map = {
     "device_num": auto_parallel_context().get_device_num,
     "global_rank": auto_parallel_context().get_global_rank,
-    "mirror_mean": auto_parallel_context().get_mirror_mean,
+    "gradients_mean": auto_parallel_context().get_gradients_mean,
     "gradient_fp32_sync": auto_parallel_context().get_gradient_fp32_sync,
     "loss_repeated_mean": auto_parallel_context().get_loss_repeated_mean,
     "parallel_mode": auto_parallel_context().get_parallel_mode,
@@ -480,7 +480,7 @@ _get_auto_parallel_context_func_map = {
     "enable_parallel_optimizer": auto_parallel_context().get_enable_parallel_optimizer}
 
 
-@args_type_check(device_num=int, global_rank=int, mirror_mean=bool, gradient_fp32_sync=bool,
+@args_type_check(device_num=int, global_rank=int, gradients_mean=bool, gradient_fp32_sync=bool,
                  loss_repeated_mean=bool, parallel_mode=str, auto_parallel_search_mode=str,
                  parameter_broadcast=bool, strategy_ckpt_load_file=str,
                  strategy_ckpt_save_file=str, full_batch=bool, enable_parallel_optimizer=bool)
@@ -495,7 +495,7 @@ def _set_auto_parallel_context(**kwargs):
     Args:
         device_num (int): Available device number, the value must be in [1, 4096]. Default: 1.
         global_rank (int): Global rank id, the value must be in [0, 4095]. Default: 0.
-        mirror_mean (bool): Whether to perform mean operator after all-reduce of mirror. Default: False.
+        gradients_mean (bool): Whether to perform mean operator after all-reduce of mirror. Default: False.
         loss_repeated_mean (bool): Whether to perform mean operator in backward in the case of repeated
                         calculations. Default: True.
         gradient_fp32_sync (bool): Gradients allreduce by fp32 even though gradients is fp16 if this flag is True.
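
For illustration, a hedged sketch of the renamed keyword flowing through the module-level wrappers, assuming _get_auto_parallel_context(attr_key) dispatches through the getter map shown above:

    # Sketch only: keyword dispatch through the func maps.
    from mindspore.parallel._auto_parallel_context import (
        _get_auto_parallel_context, _set_auto_parallel_context)

    _set_auto_parallel_context(device_num=8, global_rank=0, gradients_mean=True)
    assert _get_auto_parallel_context("gradients_mean") is True
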
@@ -562,7 +562,7 @@ def _reset_auto_parallel_context():
 
     - device_num: 1.
     - global_rank: 0.
-    - mirror_mean: False.
+    - gradients_mean: False.
     - gradient_fp32_sync: True.
     - parallel_mode: "stand_alone".
     - parameter_broadcast: False.
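
Callers of the public wrapper migrate by swapping the keyword; a sketch, assuming mindspore.context.set_auto_parallel_context forwards keywords to the internal maps as usual:

    # Before this change: context.set_auto_parallel_context(mirror_mean=True)
    from mindspore import context

    context.set_auto_parallel_context(gradients_mean=True)
    assert context.get_auto_parallel_context("gradients_mean") is True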