|
|
@ -28,11 +28,13 @@ from mindspore.train.callback import SummaryCollector
|
|
|
|
from mindspore.train.callback import _InternalCallbackParam
|
|
|
|
from mindspore.train.callback import _InternalCallbackParam
|
|
|
|
from mindspore.train.summary.enums import ModeEnum, PluginEnum
|
|
|
|
from mindspore.train.summary.enums import ModeEnum, PluginEnum
|
|
|
|
from mindspore.train.summary import SummaryRecord
|
|
|
|
from mindspore.train.summary import SummaryRecord
|
|
|
|
|
|
|
|
from mindspore.train.summary.summary_record import _DEFAULT_EXPORT_OPTIONS
|
|
|
|
from mindspore.nn import Cell
|
|
|
|
from mindspore.nn import Cell
|
|
|
|
from mindspore.nn.optim.optimizer import Optimizer
|
|
|
|
from mindspore.nn.optim.optimizer import Optimizer
|
|
|
|
from mindspore.ops.operations import Add
|
|
|
|
from mindspore.ops.operations import Add
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
_VALUE_CACHE = list()
|
|
|
|
_VALUE_CACHE = list()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -143,6 +145,24 @@ class TestSummaryCollector:
|
|
|
|
|
|
|
|
|
|
|
|
assert expected_msg == str(exc.value)
|
|
|
|
assert expected_msg == str(exc.value)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize("export_options", [
|
|
|
|
|
|
|
|
{
|
|
|
|
|
|
|
|
"tensor_format": "npz"
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
])
|
|
|
|
|
|
|
|
def test_params_with_tensor_format_type_error(self, export_options):
|
|
|
|
|
|
|
|
"""Test type error scenario for collect specified data param."""
|
|
|
|
|
|
|
|
summary_dir = tempfile.mkdtemp(dir=self.base_summary_dir)
|
|
|
|
|
|
|
|
with pytest.raises(ValueError) as exc:
|
|
|
|
|
|
|
|
SummaryCollector(summary_dir, export_options=export_options)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
unexpected_format = {export_options.get("tensor_format")}
|
|
|
|
|
|
|
|
expected_msg = f'For `export_options`, the export_format {unexpected_format} are ' \
|
|
|
|
|
|
|
|
f'unsupported for tensor_format, expect the follow values: ' \
|
|
|
|
|
|
|
|
f'{list(_DEFAULT_EXPORT_OPTIONS.get("tensor_format"))}'
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
assert expected_msg == str(exc.value)
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize("export_options", [123])
|
|
|
|
@pytest.mark.parametrize("export_options", [123])
|
|
|
|
def test_params_with_export_options_type_error(self, export_options):
|
|
|
|
def test_params_with_export_options_type_error(self, export_options):
|
|
|
|
"""Test type error scenario for collect specified data param."""
|
|
|
|
"""Test type error scenario for collect specified data param."""
|
|
|
|