!11648 fix param check for unexpected_format

From: @jiang-shuqiang
Reviewed-by: @yelihua
Signed-off-by:
pull/11648/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 78d250c7d5

@ -88,7 +88,7 @@ def _make_directory(path: str):
else: else:
logger.debug("The directory(%s) doesn't exist, will create it", path) logger.debug("The directory(%s) doesn't exist, will create it", path)
try: try:
os.makedirs(path, exist_ok=True) os.makedirs(path, exist_ok=True, mode=0o700)
real_path = path real_path = path
except PermissionError as e: except PermissionError as e:
logger.error("No write permission on the directory(%r), error = %r", path, e) logger.error("No write permission on the directory(%r), error = %r", path, e)

@ -138,7 +138,7 @@ class WriterPool(ctx.Process):
for writer in self._writers[:]: for writer in self._writers[:]:
try: try:
writer.write(plugin, data) writer.write(plugin, data)
except RuntimeError as exc: except (RuntimeError, OSError) as exc:
logger.error(str(exc)) logger.error(str(exc))
self._writers.remove(writer) self._writers.remove(writer)
writer.close() writer.close()

@ -36,7 +36,7 @@ _summary_lock = threading.Lock()
# cache the summary data # cache the summary data
_summary_tensor_cache = {} _summary_tensor_cache = {}
_DEFAULT_EXPORT_OPTIONS = { _DEFAULT_EXPORT_OPTIONS = {
'tensor_format': 'npy', 'tensor_format': {'npy'},
} }
@ -68,14 +68,22 @@ def process_export_options(export_options):
check_value_type('export_options', export_options, [dict, type(None)]) check_value_type('export_options', export_options, [dict, type(None)])
for param_name in export_options: for export_option, export_format in export_options.items():
check_value_type(param_name, param_name, [str]) check_value_type('export_option', export_option, [str])
check_value_type('export_format', export_format, [str])
unexpected_params = set(export_options) - set(_DEFAULT_EXPORT_OPTIONS) unexpected_params = set(export_options) - set(_DEFAULT_EXPORT_OPTIONS)
if unexpected_params: if unexpected_params:
raise ValueError(f'For `export_options` the keys {unexpected_params} are unsupported, ' raise ValueError(f'For `export_options` the keys {unexpected_params} are unsupported, '
f'expect the follow keys: {list(_DEFAULT_EXPORT_OPTIONS.keys())}') f'expect the follow keys: {list(_DEFAULT_EXPORT_OPTIONS.keys())}')
for export_option, export_format in export_options.items():
unexpected_format = {export_format} - _DEFAULT_EXPORT_OPTIONS.get(export_option)
if unexpected_format:
raise ValueError(
f'For `export_options`, the export_format {unexpected_format} are unsupported for {export_option}, '
f'expect the follow values: {list(_DEFAULT_EXPORT_OPTIONS.get(export_option))}')
for item in set(export_options): for item in set(export_options):
check_value_type(item, export_options.get(item), [str, type(None)]) check_value_type(item, export_options.get(item), [str, type(None)])

@ -28,11 +28,13 @@ from mindspore.train.callback import SummaryCollector
from mindspore.train.callback import _InternalCallbackParam from mindspore.train.callback import _InternalCallbackParam
from mindspore.train.summary.enums import ModeEnum, PluginEnum from mindspore.train.summary.enums import ModeEnum, PluginEnum
from mindspore.train.summary import SummaryRecord from mindspore.train.summary import SummaryRecord
from mindspore.train.summary.summary_record import _DEFAULT_EXPORT_OPTIONS
from mindspore.nn import Cell from mindspore.nn import Cell
from mindspore.nn.optim.optimizer import Optimizer from mindspore.nn.optim.optimizer import Optimizer
from mindspore.ops.operations import Add from mindspore.ops.operations import Add
_VALUE_CACHE = list() _VALUE_CACHE = list()
@ -143,6 +145,24 @@ class TestSummaryCollector:
assert expected_msg == str(exc.value) assert expected_msg == str(exc.value)
@pytest.mark.parametrize("export_options", [
{
"tensor_format": "npz"
}
])
def test_params_with_tensor_format_type_error(self, export_options):
"""Test type error scenario for collect specified data param."""
summary_dir = tempfile.mkdtemp(dir=self.base_summary_dir)
with pytest.raises(ValueError) as exc:
SummaryCollector(summary_dir, export_options=export_options)
unexpected_format = {export_options.get("tensor_format")}
expected_msg = f'For `export_options`, the export_format {unexpected_format} are ' \
f'unsupported for tensor_format, expect the follow values: ' \
f'{list(_DEFAULT_EXPORT_OPTIONS.get("tensor_format"))}'
assert expected_msg == str(exc.value)
@pytest.mark.parametrize("export_options", [123]) @pytest.mark.parametrize("export_options", [123])
def test_params_with_export_options_type_error(self, export_options): def test_params_with_export_options_type_error(self, export_options):
"""Test type error scenario for collect specified data param.""" """Test type error scenario for collect specified data param."""

Loading…
Cancel
Save