|
|
|
@ -234,6 +234,17 @@ class _Context:
|
|
|
|
|
if not success:
|
|
|
|
|
raise RuntimeError("Device id set failed!!!")
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def max_call_depth(self):
|
|
|
|
|
return self._context_handle.get_max_call_depth()
|
|
|
|
|
|
|
|
|
|
@max_call_depth.setter
|
|
|
|
|
def max_call_depth(self, max_call_depth):
|
|
|
|
|
if max_call_depth <= 0:
|
|
|
|
|
raise ValueError(
|
|
|
|
|
"Max call depth must be greater than 0, but got {}".format(max_call_depth))
|
|
|
|
|
self._context_handle.set_max_call_depth(max_call_depth)
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def enable_auto_mixed_precision(self):
|
|
|
|
|
return self._context_handle.get_auto_mixed_precision_flag()
|
|
|
|
@ -475,6 +486,7 @@ def set_auto_parallel_context(**kwargs):
|
|
|
|
|
full_batch (bool): Whether to load the whole batch on each device. Default: False.
|
|
|
|
|
enable_parallel_optimizer(bool): This is a developing feature, which shards the weight update computation in
|
|
|
|
|
data parallel training in the benefit of time and memory saving.
|
|
|
|
|
max_call_depth(int): Specify the function call depth limit. Default: 1000.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Raises:
|
|
|
|
@ -490,6 +502,7 @@ def set_auto_parallel_context(**kwargs):
|
|
|
|
|
>>> context.set_auto_parallel_context(parameter_broadcast=False)
|
|
|
|
|
>>> context.set_auto_parallel_context(strategy_ckpt_load_file="./strategy_stage1.ckpt")
|
|
|
|
|
>>> context.set_auto_parallel_context(strategy_ckpt_save_file="./strategy_stage1.ckpt")
|
|
|
|
|
>>> context.set_auto_parallel_context(max_call_depth=80)
|
|
|
|
|
"""
|
|
|
|
|
_set_auto_parallel_context(**kwargs)
|
|
|
|
|
|
|
|
|
@ -532,7 +545,7 @@ def reset_auto_parallel_context():
|
|
|
|
|
save_dump_path=str, enable_reduce_precision=bool, variable_memory_max_size=str,
|
|
|
|
|
enable_profiling=bool, profiling_options=str, enable_auto_mixed_precision=bool,
|
|
|
|
|
enable_graph_kernel=bool, check_bprop=bool, max_device_memory=str, print_file_path=str,
|
|
|
|
|
enable_sparse=bool)
|
|
|
|
|
enable_sparse=bool, max_call_depth=int)
|
|
|
|
|
def set_context(**kwargs):
|
|
|
|
|
"""
|
|
|
|
|
Sets context for running environment.
|
|
|
|
|