|
|
@ -27,6 +27,8 @@ from mindspore._checkparam import args_type_check
|
|
|
|
from mindspore.parallel._auto_parallel_context import _set_auto_parallel_context, _get_auto_parallel_context, \
|
|
|
|
from mindspore.parallel._auto_parallel_context import _set_auto_parallel_context, _get_auto_parallel_context, \
|
|
|
|
_reset_auto_parallel_context
|
|
|
|
_reset_auto_parallel_context
|
|
|
|
from mindspore.parallel._ps_context import _set_ps_context, _get_ps_context, _reset_ps_context
|
|
|
|
from mindspore.parallel._ps_context import _set_ps_context, _get_ps_context, _reset_ps_context
|
|
|
|
|
|
|
|
from .device_target import __device_target__
|
|
|
|
|
|
|
|
from .package_name import __package_name__
|
|
|
|
|
|
|
|
|
|
|
|
__all__ = ['GRAPH_MODE', 'PYNATIVE_MODE', 'set_context', 'get_context', 'set_auto_parallel_context',
|
|
|
|
__all__ = ['GRAPH_MODE', 'PYNATIVE_MODE', 'set_context', 'get_context', 'set_auto_parallel_context',
|
|
|
|
'get_auto_parallel_context', 'reset_auto_parallel_context', 'ParallelMode', 'set_ps_context',
|
|
|
|
'get_auto_parallel_context', 'reset_auto_parallel_context', 'ParallelMode', 'set_ps_context',
|
|
|
@ -560,6 +562,10 @@ def set_context(**kwargs):
|
|
|
|
# set device target first
|
|
|
|
# set device target first
|
|
|
|
if 'device_target' in kwargs:
|
|
|
|
if 'device_target' in kwargs:
|
|
|
|
ctx.set_device_target(kwargs['device_target'])
|
|
|
|
ctx.set_device_target(kwargs['device_target'])
|
|
|
|
|
|
|
|
device = kwargs['device_target']
|
|
|
|
|
|
|
|
if not device.lower() in __device_target__:
|
|
|
|
|
|
|
|
raise ValueError(f"Error, package type {__package_name__} support device type {__device_target__}, "
|
|
|
|
|
|
|
|
f"but got device target {device}")
|
|
|
|
device = ctx.get_param(ms_ctx_param.device_target)
|
|
|
|
device = ctx.get_param(ms_ctx_param.device_target)
|
|
|
|
for key, value in kwargs.items():
|
|
|
|
for key, value in kwargs.items():
|
|
|
|
if not _check_target_specific_cfgs(device, key):
|
|
|
|
if not _check_target_specific_cfgs(device, key):
|
|
|
|