Support to control whether raise RuntimeError exception in SummaryRecord

1. Support explainer raise an RuntimeError exception
2. fix the ut of SummaryRecord
pull/10436/head
ougongchang 5 years ago
parent 280e127f59
commit 06be546b52

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd # Copyright 2020-2021 Huawei Technologies Co., Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -199,7 +199,7 @@ class ImageClassificationRunner:
""" """
self._verify_data_n_settings(check_all=True) self._verify_data_n_settings(check_all=True)
with SummaryRecord(self._summary_dir) as summary: with SummaryRecord(self._summary_dir, raise_exception=True) as summary:
print("Start running and writing......") print("Start running and writing......")
begin = time() begin = time()

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd # Copyright 2020-2021 Huawei Technologies Co., Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -207,7 +207,9 @@ class SummaryCollector(Callback):
self._dataset_sink_mode = True self._dataset_sink_mode = True
def __enter__(self): def __enter__(self):
self._record = SummaryRecord(log_dir=self._summary_dir, max_file_size=self._max_file_size) self._record = SummaryRecord(log_dir=self._summary_dir,
max_file_size=self._max_file_size,
raise_exception=False)
self._first_step, self._dataset_sink_mode = True, True self._first_step, self._dataset_sink_mode = True, True
return self return self
@ -319,7 +321,14 @@ class SummaryCollector(Callback):
f'expect the follow keys: {list(self._DEFAULT_SPECIFIED_DATA.keys())}') f'expect the follow keys: {list(self._DEFAULT_SPECIFIED_DATA.keys())}')
if 'histogram_regular' in specified_data: if 'histogram_regular' in specified_data:
check_value_type('histogram_regular', specified_data.get('histogram_regular'), (str, type(None))) regular = specified_data.get('histogram_regular')
check_value_type('histogram_regular', regular, (str, type(None)))
if isinstance(regular, str):
try:
re.match(regular, '')
except re.error as exc:
raise ValueError(f'For `collect_specified_data`, the value of `histogram_regular` '
f'is not a valid regular expression. Detail: {str(exc)}.')
bool_items = set(self._DEFAULT_SPECIFIED_DATA) - {'histogram_regular'} bool_items = set(self._DEFAULT_SPECIFIED_DATA) - {'histogram_regular'}
for item in bool_items: for item in bool_items:

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd # Copyright 2020-2021 Huawei Technologies Co., Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -57,14 +57,17 @@ class WriterPool(ctx.Process):
Args: Args:
base_dir (str): The base directory to hold all the files. base_dir (str): The base directory to hold all the files.
max_file_size (Optional[int]): The maximum size of each file that can be written to disk in bytes. max_file_size (Optional[int]): The maximum size of each file that can be written to disk in bytes.
raise_exception (bool, optional): Sets whether to throw an exception when an RuntimeError exception occurs
in recording data. Default: False, this means that error logs are printed and no exception is thrown.
filedict (dict): The mapping from plugin to filename. filedict (dict): The mapping from plugin to filename.
""" """
def __init__(self, base_dir, max_file_size, **filedict) -> None: def __init__(self, base_dir, max_file_size, raise_exception=False, **filedict) -> None:
super().__init__() super().__init__()
self._base_dir, self._filedict = base_dir, filedict self._base_dir, self._filedict = base_dir, filedict
self._queue, self._writers_ = ctx.Queue(ctx.cpu_count() * 2), None self._queue, self._writers_ = ctx.Queue(ctx.cpu_count() * 2), None
self._max_file_size = max_file_size self._max_file_size = max_file_size
self._raise_exception = raise_exception
self.start() self.start()
def run(self): def run(self):
@ -119,8 +122,14 @@ class WriterPool(ctx.Process):
for writer in self._writers[:]: for writer in self._writers[:]:
try: try:
writer.write(plugin, data) writer.write(plugin, data)
except RuntimeError as e: except RuntimeError as exc:
logger.warning(e.args[0]) logger.error(str(exc))
self._writers.remove(writer)
writer.close()
if self._raise_exception:
raise
except RuntimeWarning as exc:
logger.warning(str(exc))
self._writers.remove(writer) self._writers.remove(writer)
writer.close() writer.close()

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd # Copyright 2020-2021 Huawei Technologies Co., Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -20,6 +20,7 @@ import threading
from collections import defaultdict from collections import defaultdict
from mindspore import log as logger from mindspore import log as logger
from mindspore.nn import Cell
from ..._c_expression import Tensor from ..._c_expression import Tensor
from ..._checkparam import Validator from ..._checkparam import Validator
@ -29,7 +30,7 @@ from ._explain_adapter import check_explain_proto
from ._writer_pool import WriterPool from ._writer_pool import WriterPool
# for the moment, this lock is for caution's sake, # for the moment, this lock is for caution's sake,
# there are actually no any concurrencies happening. # there are actually no any concurrences happening.
_summary_lock = threading.Lock() _summary_lock = threading.Lock()
# cache the summary data # cache the summary data
_summary_tensor_cache = {} _summary_tensor_cache = {}
@ -56,10 +57,6 @@ def _get_summary_tensor_data():
return data return data
def _dictlist():
return defaultdict(list)
class SummaryRecord: class SummaryRecord:
""" """
SummaryRecord is used to record the summary data and lineage data. SummaryRecord is used to record the summary data and lineage data.
@ -80,12 +77,13 @@ class SummaryRecord:
file_prefix (str): The prefix of file. Default: "events". file_prefix (str): The prefix of file. Default: "events".
file_suffix (str): The suffix of file. Default: "_MS". file_suffix (str): The suffix of file. Default: "_MS".
network (Cell): Obtain a pipeline through network for saving graph summary. Default: None. network (Cell): Obtain a pipeline through network for saving graph summary. Default: None.
max_file_size (Optional[int]): The maximum size of each file that can be written to disk (in bytes). \ max_file_size (int, optional): The maximum size of each file that can be written to disk (in bytes). \
Unlimited by default. For example, to write not larger than 4GB, specify `max_file_size=4 * 1024**3`. Unlimited by default. For example, to write not larger than 4GB, specify `max_file_size=4 * 1024**3`.
raise_exception (bool, optional): Sets whether to throw an exception when an RuntimeError exception occurs
in recording data. Default: False, this means that error logs are printed and no exception is thrown.
Raises: Raises:
TypeError: If the type of `max_file_size` is not int, or the type of `file_prefix` or `file_suffix` is not str. TypeError: If the parameter type is incorrect.
RuntimeError: If the log_dir is not a normalized absolute path name.
Examples: Examples:
>>> # use in with statement to auto close >>> # use in with statement to auto close
@ -100,10 +98,11 @@ class SummaryRecord:
... summary_record.close() ... summary_record.close()
""" """
def __init__(self, log_dir, file_prefix="events", file_suffix="_MS", network=None, max_file_size=None): def __init__(self, log_dir, file_prefix="events", file_suffix="_MS",
network=None, max_file_size=None, raise_exception=False):
self._closed, self._event_writer = False, None self._closed, self._event_writer = False, None
self._mode, self._data_pool = 'train', _dictlist() self._mode, self._data_pool = 'train', defaultdict(list)
Validator.check_str_by_regular(file_prefix) Validator.check_str_by_regular(file_prefix)
Validator.check_str_by_regular(file_suffix) Validator.check_str_by_regular(file_suffix)
@ -120,6 +119,8 @@ class SummaryRecord:
logger.warning("The 'max_file_size' should be greater than 0.") logger.warning("The 'max_file_size' should be greater than 0.")
max_file_size = None max_file_size = None
Validator.check_value_type(arg_name='raise_exception', arg_value=raise_exception, valid_types=bool)
self.prefix = file_prefix self.prefix = file_prefix
self.suffix = file_suffix self.suffix = file_suffix
self.network = network self.network = network
@ -127,16 +128,15 @@ class SummaryRecord:
# create the summary writer file # create the summary writer file
self.event_file_name = get_event_file_name(self.prefix, self.suffix) self.event_file_name = get_event_file_name(self.prefix, self.suffix)
try: self.full_file_name = os.path.join(self.log_path, self.event_file_name)
self.full_file_name = os.path.join(self.log_path, self.event_file_name)
except Exception as ex:
raise RuntimeError(ex)
filename_dict = dict(summary=self.full_file_name,
lineage=get_event_file_name(self.prefix, '_lineage'),
explainer=get_event_file_name(self.prefix, '_explain'))
self._event_writer = WriterPool(log_dir, self._event_writer = WriterPool(log_dir,
max_file_size, max_file_size,
summary=self.full_file_name, raise_exception,
lineage=get_event_file_name(self.prefix, '_lineage'), **filename_dict)
explainer=get_event_file_name(self.prefix, '_explain'))
_get_summary_tensor_data() _get_summary_tensor_data()
atexit.register(self.close) atexit.register(self.close)
@ -195,8 +195,8 @@ class SummaryRecord:
- The data type of value should be a 'Explain' object when the plugin is 'explainer', - The data type of value should be a 'Explain' object when the plugin is 'explainer',
see mindspore/ccsrc/summary.proto. see mindspore/ccsrc/summary.proto.
Raises: Raises:
ValueError: When the name is not valid. ValueError: If the parameter value is invalid.
TypeError: When the value is not a Tensor. TypeError: If the parameter type is error.
Examples: Examples:
>>> with SummaryRecord(log_dir="./summary_dir", file_prefix="xxx_", file_suffix="_yyy") as summary_record: >>> with SummaryRecord(log_dir="./summary_dir", file_prefix="xxx_", file_suffix="_yyy") as summary_record:
@ -238,6 +238,10 @@ class SummaryRecord:
Returns: Returns:
bool, whether the record process is successful or not. bool, whether the record process is successful or not.
Raises:
TypeError: If the parameter type is error.
RuntimeError: If the disk space is insufficient.
Examples: Examples:
>>> with SummaryRecord(log_dir="./summary_dir", file_prefix="xxx_", file_suffix="_yyy") as summary_record: >>> with SummaryRecord(log_dir="./summary_dir", file_prefix="xxx_", file_suffix="_yyy") as summary_record:
... summary_record.record(step=2) ... summary_record.record(step=2)
@ -245,11 +249,12 @@ class SummaryRecord:
True True
""" """
logger.debug("SummaryRecord step is %r.", step) logger.debug("SummaryRecord step is %r.", step)
Validator.check_value_type(arg_name='step', arg_value=step, valid_types=int)
Validator.check_value_type(arg_name='train_network', arg_value=train_network, valid_types=[Cell, type(None)])
if self._closed: if self._closed:
logger.error("The record writer is closed.") logger.error("The record writer is closed.")
return False return False
if not isinstance(step, int) or isinstance(step, bool):
raise ValueError("`step` should be int")
# Set the current summary of train step # Set the current summary of train step
if self.network is not None and not self.has_graph: if self.network is not None and not self.has_graph:
graph_proto = self.network.get_func_graph_proto() graph_proto = self.network.get_func_graph_proto()
@ -294,7 +299,7 @@ class SummaryRecord:
value['step'] = step value['step'] = step
return self._data_pool return self._data_pool
finally: finally:
self._data_pool = _dictlist() self._data_pool = defaultdict(list)
@property @property
def log_dir(self): def log_dir(self):

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd # Copyright 2020-2021 Huawei Technologies Co., Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -60,8 +60,8 @@ class BaseWriter:
self._max_file_size -= required_length self._max_file_size -= required_length
self.writer.Write(data) self.writer.Write(data)
else: else:
raise RuntimeError(f"'max_file_size' reached: There are {self._max_file_size} bytes remaining, " raise RuntimeWarning(f"'max_file_size' reached: There are {self._max_file_size} bytes remaining, "
f"but the '{self._filepath}' requires to write {required_length} bytes.") f"but the '{self._filepath}' requires to write {required_length} bytes.")
def flush(self): def flush(self):
"""Flush the writer.""" """Flush the writer."""

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd # Copyright 2020-2021 Huawei Technologies Co., Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -12,12 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
""" """test_image_summary"""
@File : test_image_summary.py
@Author:
@Date : 2019-07-4
@Desc : test summary function
"""
import logging import logging
import os import os
import numpy as np import numpy as np
@ -70,23 +65,14 @@ def get_test_data(step):
# Test: call method on parse graph code # Test: call method on parse graph code
def test_image_summary_sample(): def test_image_summary_sample():
""" test_image_summary_sample """ """ test_image_summary_sample """
log.debug("begin test_image_summary_sample")
# step 0: create the thread
with SummaryRecord(SUMMARY_DIR, file_suffix="_MS_IMAGE") as test_writer: with SummaryRecord(SUMMARY_DIR, file_suffix="_MS_IMAGE") as test_writer:
# step 1: create the test data for summary
# step 2: create the Event
for i in range(1, 5): for i in range(1, 5):
test_data = get_test_data(i) test_data = get_test_data(i)
_cache_summary_tensor_data(test_data) _cache_summary_tensor_data(test_data)
test_writer.record(i) test_writer.record(i)
test_writer.flush() test_writer.flush()
# step 3: send the event to mq
# step 4: accept the event and write the file
log.debug("finished test_image_summary_sample")
class Net(nn.Cell): class Net(nn.Cell):
""" Net definition """ """ Net definition """
@ -175,23 +161,11 @@ class ImageSummaryCallback(Callback):
def test_image_summary_train(): def test_image_summary_train():
""" test_image_summary_train """ """ test_image_summary_train """
dataset = get_dataset() dataset = get_dataset()
log.debug("begin test_image_summary_sample")
# step 0: create the thread
with SummaryRecord(SUMMARY_DIR, file_suffix="_MS_IMAGE") as test_writer: with SummaryRecord(SUMMARY_DIR, file_suffix="_MS_IMAGE") as test_writer:
# step 1: create the test data for summary
# step 2: create the Event
model = get_model() model = get_model()
callback = ImageSummaryCallback(test_writer) callback = ImageSummaryCallback(test_writer)
model.train(2, dataset, callbacks=[callback]) model.train(2, dataset, callbacks=[callback])
# step 3: send the event to mq
# step 4: accept the event and write the file
log.debug("finished test_image_summary_sample")
def test_image_summary_data(): def test_image_summary_data():
""" test_image_summary_data """ """ test_image_summary_data """
@ -207,13 +181,6 @@ def test_image_summary_data():
test_data_list.append(dct) test_data_list.append(dct)
i += 1 i += 1
log.debug("begin test_image_summary_sample")
# step 0: create the thread
with SummaryRecord(SUMMARY_DIR, file_suffix="_MS_IMAGE") as test_writer: with SummaryRecord(SUMMARY_DIR, file_suffix="_MS_IMAGE") as test_writer:
# step 1: create the test data for summary
# step 2: create the Event
_cache_summary_tensor_data(test_data_list) _cache_summary_tensor_data(test_data_list)
test_writer.record(1) test_writer.record(1)
log.debug("finished test_image_summary_sample")

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd # Copyright 2020-2021 Huawei Technologies Co., Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -12,17 +12,12 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
""" """Test summary."""
@File : test_summary.py
@Author:
@Date : 2019-07-4
@Desc : test summary function
"""
import logging
import os import os
import random import random
import numpy as np import numpy as np
import pytest
import mindspore.nn as nn import mindspore.nn as nn
from mindspore.common.tensor import Tensor from mindspore.common.tensor import Tensor
@ -32,9 +27,6 @@ from mindspore.train.summary.summary_record import SummaryRecord, _cache_summary
CUR_DIR = os.getcwd() CUR_DIR = os.getcwd()
SUMMARY_DIR = CUR_DIR + "/test_temp_summary_event_file/" SUMMARY_DIR = CUR_DIR + "/test_temp_summary_event_file/"
log = logging.getLogger("test")
log.setLevel(level=logging.ERROR)
def get_test_data(step): def get_test_data(step):
""" get_test_data """ """ get_test_data """
@ -58,26 +50,14 @@ def get_test_data(step):
return test_data_list return test_data_list
# Test 1: summary sample of scalar
def test_scalar_summary_sample(): def test_scalar_summary_sample():
""" test_scalar_summary_sample """ """ test_scalar_summary_sample """
log.debug("begin test_scalar_summary_sample")
# step 0: create the thread
with SummaryRecord(SUMMARY_DIR, file_suffix="_MS_SCALAR") as test_writer: with SummaryRecord(SUMMARY_DIR, file_suffix="_MS_SCALAR") as test_writer:
# step 1: create the test data for summary for i in range(1, 5):
# step 2: create the Event
for i in range(1, 500):
test_data = get_test_data(i) test_data = get_test_data(i)
_cache_summary_tensor_data(test_data) _cache_summary_tensor_data(test_data)
test_writer.record(i) test_writer.record(i)
# step 3: send the event to mq
# step 4: accept the event and write the file
log.debug("finished test_scalar_summary_sample")
def get_test_data_shape_1(step): def get_test_data_shape_1(step):
""" get_test_data_shape_1 """ """ get_test_data_shape_1 """
@ -104,23 +84,12 @@ def get_test_data_shape_1(step):
# Test: shape = (1,) # Test: shape = (1,)
def test_scalar_summary_sample_with_shape_1(): def test_scalar_summary_sample_with_shape_1():
""" test_scalar_summary_sample_with_shape_1 """ """ test_scalar_summary_sample_with_shape_1 """
log.debug("begin test_scalar_summary_sample_with_shape_1")
# step 0: create the thread
with SummaryRecord(SUMMARY_DIR, file_suffix="_MS_SCALAR") as test_writer: with SummaryRecord(SUMMARY_DIR, file_suffix="_MS_SCALAR") as test_writer:
# step 1: create the test data for summary
# step 2: create the Event
for i in range(1, 100): for i in range(1, 100):
test_data = get_test_data_shape_1(i) test_data = get_test_data_shape_1(i)
_cache_summary_tensor_data(test_data) _cache_summary_tensor_data(test_data)
test_writer.record(i) test_writer.record(i)
# step 3: send the event to mq
# step 4: accept the event and write the file
log.debug("finished test_scalar_summary_sample")
# Test: test with ge # Test: test with ge
class SummaryDemo(nn.Cell): class SummaryDemo(nn.Cell):
@ -143,13 +112,7 @@ class SummaryDemo(nn.Cell):
def test_scalar_summary_with_ge(): def test_scalar_summary_with_ge():
""" test_scalar_summary_with_ge """ """ test_scalar_summary_with_ge """
log.debug("begin test_scalar_summary_with_ge")
# step 0: create the thread
with SummaryRecord(SUMMARY_DIR, file_suffix="_MS_SCALAR") as test_writer: with SummaryRecord(SUMMARY_DIR, file_suffix="_MS_SCALAR") as test_writer:
# step 1: create the network for summary
x = Tensor(np.array([1.1]).astype(np.float32))
y = Tensor(np.array([1.2]).astype(np.float32))
net = SummaryDemo() net = SummaryDemo()
net.set_train() net.set_train()
@ -161,45 +124,17 @@ def test_scalar_summary_with_ge():
net(x, y) net(x, y)
test_writer.record(i) test_writer.record(i)
log.debug("finished test_scalar_summary_with_ge")
# test the problem of two consecutive use cases going wrong # test the problem of two consecutive use cases going wrong
def test_scalar_summary_with_ge_2(): def test_scalar_summary_with_ge_2():
""" test_scalar_summary_with_ge_2 """ """ test_scalar_summary_with_ge_2 """
log.debug("begin test_scalar_summary_with_ge_2")
# step 0: create the thread
with SummaryRecord(SUMMARY_DIR, file_suffix="_MS_SCALAR") as test_writer: with SummaryRecord(SUMMARY_DIR, file_suffix="_MS_SCALAR") as test_writer:
# step 1: create the network for summary
x = Tensor(np.array([1.1]).astype(np.float32))
y = Tensor(np.array([1.2]).astype(np.float32))
net = SummaryDemo() net = SummaryDemo()
net.set_train() net.set_train()
# step 2: create the Event
steps = 100 steps = 100
for i in range(1, steps): for i in range(1, steps):
x = Tensor(np.array([1.1]).astype(np.float32)) x = Tensor(np.array([1.1]).astype(np.float32))
y = Tensor(np.array([1.2]).astype(np.float32)) y = Tensor(np.array([1.2]).astype(np.float32))
net(x, y) net(x, y)
test_writer.record(i) test_writer.record(i)
log.debug("finished test_scalar_summary_with_ge_2")
def test_validate():
with SummaryRecord(SUMMARY_DIR) as sr:
sr.record(1)
with pytest.raises(ValueError):
sr.record(False)
with pytest.raises(ValueError):
sr.record(2.0)
with pytest.raises(ValueError):
sr.record((1, 3))
with pytest.raises(ValueError):
sr.record([2, 3])
with pytest.raises(ValueError):
sr.record("str")
with pytest.raises(ValueError):
sr.record(sr)

@ -1,133 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
@File : test_summary_abnormal_input.py
@Author:
@Date : 2019-08-5
@Desc : test summary function of abnormal input
"""
import logging
import os
import numpy as np
from mindspore.common.tensor import Tensor
from mindspore.train.summary.summary_record import SummaryRecord
CUR_DIR = os.getcwd()
SUMMARY_DIR = CUR_DIR + "/test_temp_summary_event_file/"
log = logging.getLogger("test")
log.setLevel(level=logging.ERROR)
def get_test_data(step):
""" get_test_data """
test_data_list = []
tag1 = "x1[:Scalar]"
tag2 = "x2[:Scalar]"
np1 = np.array(step + 1).astype(np.float32)
np2 = np.array(step + 2).astype(np.float32)
dict1 = {}
dict1["name"] = tag1
dict1["data"] = Tensor(np1)
dict2 = {}
dict2["name"] = tag2
dict2["data"] = Tensor(np2)
test_data_list.append(dict1)
test_data_list.append(dict2)
return test_data_list
# Test: call method on parse graph code
def test_summaryrecord_input_null_string():
log.debug("begin test_summaryrecord_input_null_string")
# step 0: create the thread
try:
with SummaryRecord(""):
pass
except:
assert True
else:
assert False
log.debug("finished test_summaryrecord_input_null_string")
def test_summaryrecord_input_None():
log.debug("begin test_summaryrecord_input_None")
# step 0: create the thread
try:
with SummaryRecord(None):
pass
except:
assert True
else:
assert False
log.debug("finished test_summaryrecord_input_None")
def test_summaryrecord_input_relative_dir_1():
log.debug("begin test_summaryrecord_input_relative_dir_1")
# step 0: create the thread
try:
with SummaryRecord("./test_temp_summary_event_file/"):
pass
except:
assert False
else:
assert True
log.debug("finished test_summaryrecord_input_relative_dir_1")
def test_summaryrecord_input_relative_dir_2():
log.debug("begin test_summaryrecord_input_relative_dir_2")
# step 0: create the thread
try:
with SummaryRecord("../summary/"):
pass
except:
assert False
else:
assert True
log.debug("finished test_summaryrecord_input_relative_dir_2")
def test_summaryrecord_input_invalid_type_dir():
log.debug("begin test_summaryrecord_input_invalid_type_dir")
# step 0: create the thread
try:
with SummaryRecord(32):
pass
except:
assert True
else:
assert False
log.debug("finished test_summaryrecord_input_invalid_type_dir")
def test_mulit_layer_directory():
log.debug("begin test_mulit_layer_directory")
# step 0: create the thread
try:
with SummaryRecord("./test_temp_summary_event_file/test/t1/"):
pass
except:
assert False
else:
assert True
log.debug("finished test_mulit_layer_directory")

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd # Copyright 2020-2021 Huawei Technologies Co., Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -187,6 +187,14 @@ class TestSummaryCollector:
assert expected_msg == str(exc.value) assert expected_msg == str(exc.value)
def test_params_with_histogram_regular_value_error(self):
"""Test histogram regular."""
summary_dir = tempfile.mkdtemp(dir=self.base_summary_dir)
with pytest.raises(ValueError) as exc:
SummaryCollector(summary_dir, collect_specified_data={'histogram_regular': '*'})
assert 'For `collect_specified_data`, the value of `histogram_regular`' in str(exc.value)
def test_params_with_collect_specified_data_unexpected_key(self): def test_params_with_collect_specified_data_unexpected_key(self):
"""Test the collect_specified_data parameter with unexpected key.""" """Test the collect_specified_data parameter with unexpected key."""
summary_dir = tempfile.mkdtemp(dir=self.base_summary_dir) summary_dir = tempfile.mkdtemp(dir=self.base_summary_dir)
@ -260,7 +268,7 @@ class TestSummaryCollector:
cb_params.train_dataset_element = image_data cb_params.train_dataset_element = image_data
with SummaryCollector((tempfile.mkdtemp(dir=self.base_summary_dir))) as summary_collector: with SummaryCollector((tempfile.mkdtemp(dir=self.base_summary_dir))) as summary_collector:
summary_collector._collect_input_data(cb_params) summary_collector._collect_input_data(cb_params)
# Note Here need to asssert the result and expected data # Note Here need to assert the result and expected data
@mock.patch.object(SummaryRecord, 'add_value') @mock.patch.object(SummaryRecord, 'add_value')
def test_collect_dataset_graph_success(self, mock_add_value): def test_collect_dataset_graph_success(self, mock_add_value):
@ -296,7 +304,6 @@ class TestSummaryCollector:
assert summary_collector._is_parse_loss_success assert summary_collector._is_parse_loss_success
def test_get_optimizer_from_cb_params_success(self): def test_get_optimizer_from_cb_params_success(self):
"""Test get optimizer success from cb params.""" """Test get optimizer success from cb params."""
cb_params = _InternalCallbackParam() cb_params = _InternalCallbackParam()

@ -0,0 +1,80 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test_summary_abnormal_input"""
import os
import shutil
import tempfile
import numpy as np
import pytest
from mindspore.common.tensor import Tensor
from mindspore.train.summary.summary_record import SummaryRecord
def get_test_data(step):
""" get_test_data """
test_data_list = []
tag1 = "x1[:Scalar]"
tag2 = "x2[:Scalar]"
np1 = np.array(step + 1).astype(np.float32)
np2 = np.array(step + 2).astype(np.float32)
dict1 = {}
dict1["name"] = tag1
dict1["data"] = Tensor(np1)
dict2 = {}
dict2["name"] = tag2
dict2["data"] = Tensor(np2)
test_data_list.append(dict1)
test_data_list.append(dict2)
return test_data_list
class TestSummaryRecord:
"""Test SummaryRecord"""
def setup_class(self):
"""Run before test this class."""
self.base_summary_dir = tempfile.mkdtemp(suffix='summary')
def teardown_class(self):
"""Run after test this class."""
if os.path.exists(self.base_summary_dir):
shutil.rmtree(self.base_summary_dir)
@pytest.mark.parametrize("log_dir", ["", None, 32])
def test_log_dir_with_type_error(self, log_dir):
with pytest.raises(TypeError):
with SummaryRecord(log_dir):
pass
@pytest.mark.parametrize("raise_exception", ["", None, 32])
def test_raise_exception_with_type_error(self, raise_exception):
summary_dir = tempfile.mkdtemp(dir=self.base_summary_dir)
with pytest.raises(TypeError) as exc:
with SummaryRecord(log_dir=summary_dir, raise_exception=raise_exception):
pass
assert "raise_exception" in str(exc.value)
@pytest.mark.parametrize("step", [False, 2.0, (1, 3), [2, 3], "str"])
def test_step_of_record_with_type_error(self, step):
summary_dir = tempfile.mkdtemp(dir=self.base_summary_dir)
with pytest.raises(TypeError):
with SummaryRecord(summary_dir) as sr:
sr.record(step)
Loading…
Cancel
Save