|
|
|
@ -185,13 +185,20 @@ class _AutoParallelContext:
|
|
|
|
|
self.check_context_handle()
|
|
|
|
|
return self._context_handle.get_parallel_mode()
|
|
|
|
|
|
|
|
|
|
def set_strategy_search_mode(self, strategy_search_mode):
|
|
|
|
|
def set_strategy_search_mode(self, auto_parallel_search_mode):
|
|
|
|
|
"""
|
|
|
|
|
Set search mode of strategy.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
auto_parallel_search_mode (str): The search mode of strategy.
|
|
|
|
|
"""
|
|
|
|
|
self.check_context_handle()
|
|
|
|
|
ret = self._context_handle.set_strategy_search_mode(strategy_search_mode)
|
|
|
|
|
ret = self._context_handle.set_strategy_search_mode(auto_parallel_search_mode)
|
|
|
|
|
if ret is False:
|
|
|
|
|
raise ValueError("Strategy search mode does not support {}".format(strategy_search_mode))
|
|
|
|
|
raise ValueError("Strategy search mode does not support {}".format(auto_parallel_search_mode))
|
|
|
|
|
|
|
|
|
|
def get_strategy_search_mode(self):
|
|
|
|
|
"""Get search mode of strategy."""
|
|
|
|
|
self.check_context_handle()
|
|
|
|
|
return self._context_handle.get_strategy_search_mode()
|
|
|
|
|
|
|
|
|
@ -422,6 +429,7 @@ _set_auto_parallel_context_func_map = {
|
|
|
|
|
"cast_before_mirror": auto_parallel_context().set_cast_before_mirror,
|
|
|
|
|
"loss_repeated_mean": auto_parallel_context().set_loss_repeated_mean,
|
|
|
|
|
"parallel_mode": auto_parallel_context().set_parallel_mode,
|
|
|
|
|
"auto_parallel_search_mode": auto_parallel_context().set_strategy_search_mode,
|
|
|
|
|
"parameter_broadcast": auto_parallel_context().set_parameter_broadcast,
|
|
|
|
|
"strategy_ckpt_load_file": auto_parallel_context().set_strategy_ckpt_load_file,
|
|
|
|
|
"strategy_ckpt_save_file": auto_parallel_context().set_strategy_ckpt_save_file,
|
|
|
|
@ -435,6 +443,7 @@ _get_auto_parallel_context_func_map = {
|
|
|
|
|
"cast_before_mirror": auto_parallel_context().get_cast_before_mirror,
|
|
|
|
|
"loss_repeated_mean": auto_parallel_context().get_loss_repeated_mean,
|
|
|
|
|
"parallel_mode": auto_parallel_context().get_parallel_mode,
|
|
|
|
|
"auto_parallel_search_mode": auto_parallel_context().get_strategy_search_mode,
|
|
|
|
|
"parameter_broadcast": auto_parallel_context().get_parameter_broadcast,
|
|
|
|
|
"strategy_ckpt_load_file": auto_parallel_context().get_strategy_ckpt_load_file,
|
|
|
|
|
"strategy_ckpt_save_file": auto_parallel_context().get_strategy_ckpt_save_file,
|
|
|
|
@ -442,8 +451,9 @@ _get_auto_parallel_context_func_map = {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@args_type_check(device_num=int, global_rank=int, mirror_mean=bool, cast_before_mirror=bool,
|
|
|
|
|
loss_repeated_mean=bool, parallel_mode=str, parameter_broadcast=bool,
|
|
|
|
|
strategy_ckpt_load_file=str, strategy_ckpt_save_file=str, full_batch=bool)
|
|
|
|
|
loss_repeated_mean=bool, parallel_mode=str, auto_parallel_search_mode=str,
|
|
|
|
|
parameter_broadcast=bool, strategy_ckpt_load_file=str,
|
|
|
|
|
strategy_ckpt_save_file=str, full_batch=bool)
|
|
|
|
|
def _set_auto_parallel_context(**kwargs):
|
|
|
|
|
"""
|
|
|
|
|
Set auto parallel context.
|
|
|
|
@ -471,6 +481,12 @@ def _set_auto_parallel_context(**kwargs):
|
|
|
|
|
setting parallel strategies.
|
|
|
|
|
|
|
|
|
|
- auto_parallel: Achieving parallelism automatically.
|
|
|
|
|
auto_parallel_search_mode (str): There are two kinds of search modes, "recursive_programming"
|
|
|
|
|
and "dynamic_programming".
|
|
|
|
|
|
|
|
|
|
- recursive_programming: Recursive programming search mode.
|
|
|
|
|
|
|
|
|
|
- dynamic_programming: Dynamic programming search mode.
|
|
|
|
|
parameter_broadcast (bool): Indicating whether to broadcast parameters before training.
|
|
|
|
|
"stand_alone", "semi_auto_parallel" and "auto_parallel" do not support parameter
|
|
|
|
|
broadcast. Default: False.
|
|
|
|
|