|
|
@ -332,6 +332,17 @@ class _Context:
|
|
|
|
def check_bprop(self, check_bprop_flag):
|
|
|
|
def check_bprop(self, check_bprop_flag):
|
|
|
|
self._context_handle.set_check_bprop_flag(check_bprop_flag)
|
|
|
|
self._context_handle.set_check_bprop_flag(check_bprop_flag)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
|
|
|
def max_device_memory(self):
|
|
|
|
|
|
|
|
return self._context_handle.get_max_device_memory()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@max_device_memory.setter
|
|
|
|
|
|
|
|
def max_device_memory(self, max_device_memory):
|
|
|
|
|
|
|
|
if not check_input_format(max_device_memory):
|
|
|
|
|
|
|
|
raise ValueError("Context param max_device_memory should be in correct format! Such as \"3.5GB\"")
|
|
|
|
|
|
|
|
max_device_memory_value = float(max_device_memory[:-2])
|
|
|
|
|
|
|
|
self._context_handle.set_max_device_memory(max_device_memory_value)
|
|
|
|
|
|
|
|
|
|
|
|
def check_input_format(x):
|
|
|
|
def check_input_format(x):
|
|
|
|
import re
|
|
|
|
import re
|
|
|
|
pattern = r'[1-9][0-9]*(\.)?[0-9]*GB|0\.[0-9]*GB'
|
|
|
|
pattern = r'[1-9][0-9]*(\.)?[0-9]*GB|0\.[0-9]*GB'
|
|
|
@ -459,7 +470,7 @@ def reset_auto_parallel_context():
|
|
|
|
save_graphs_path=str, save_ms_model=bool, save_ms_model_path=str, enable_dump=bool,
|
|
|
|
save_graphs_path=str, save_ms_model=bool, save_ms_model_path=str, enable_dump=bool,
|
|
|
|
save_dump_path=str, enable_reduce_precision=bool, variable_memory_max_size=str,
|
|
|
|
save_dump_path=str, enable_reduce_precision=bool, variable_memory_max_size=str,
|
|
|
|
enable_profiling=bool, profiling_options=str, enable_auto_mixed_precision=bool,
|
|
|
|
enable_profiling=bool, profiling_options=str, enable_auto_mixed_precision=bool,
|
|
|
|
check_bprop=bool)
|
|
|
|
check_bprop=bool, max_device_memory=str)
|
|
|
|
def set_context(**kwargs):
|
|
|
|
def set_context(**kwargs):
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
Sets context for running environment.
|
|
|
|
Sets context for running environment.
|
|
|
@ -511,6 +522,7 @@ def set_context(**kwargs):
|
|
|
|
separated by colons; single operator can choose op_trace, op_trace cannot be combined with
|
|
|
|
separated by colons; single operator can choose op_trace, op_trace cannot be combined with
|
|
|
|
training_trace and task_trace. Default: "training_trace".
|
|
|
|
training_trace and task_trace. Default: "training_trace".
|
|
|
|
check_bprop (bool): Whether to check bprop. Default: False.
|
|
|
|
check_bprop (bool): Whether to check bprop. Default: False.
|
|
|
|
|
|
|
|
max_device_memory (str): Sets the maximum memory available for device. Default: "1024GB".
|
|
|
|
|
|
|
|
|
|
|
|
Raises:
|
|
|
|
Raises:
|
|
|
|
ValueError: If input key is not an attribute in context.
|
|
|
|
ValueError: If input key is not an attribute in context.
|
|
|
@ -530,6 +542,7 @@ def set_context(**kwargs):
|
|
|
|
>>> device_target="Ascend",device_id=0, save_graphs=True,
|
|
|
|
>>> device_target="Ascend",device_id=0, save_graphs=True,
|
|
|
|
>>> save_graphs_path="/mindspore")
|
|
|
|
>>> save_graphs_path="/mindspore")
|
|
|
|
>>> context.set_context(enable_profiling=True, profiling_options="training_trace")
|
|
|
|
>>> context.set_context(enable_profiling=True, profiling_options="training_trace")
|
|
|
|
|
|
|
|
>>> context.set_context(max_device_memory="3.5GB")
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
for key, value in kwargs.items():
|
|
|
|
for key, value in kwargs.items():
|
|
|
|
if not hasattr(_context(), key):
|
|
|
|
if not hasattr(_context(), key):
|
|
|
|