!10436 Support to control whether to throw runtime exceptions in SummaryRecord

From: @ouwenchang
Reviewed-by: 
Signed-off-by:
pull/10436/MERGE
mindspore-ci-bot committed via Gitee, 4 years ago
commit 1f06cd63f3

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -199,7 +199,7 @@ class ImageClassificationRunner:
"""
self._verify_data_n_settings(check_all=True)
with SummaryRecord(self._summary_dir) as summary:
with SummaryRecord(self._summary_dir, raise_exception=True) as summary:
print("Start running and writing......")
begin = time()
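The runner now opts into strict error handling: with raise_exception=True, a RuntimeError raised while writing summary data is expected to propagate to the caller instead of only being logged (per the SummaryRecord docstring further down). A minimal usage sketch assuming that behaviour; the directory name is illustrative:

from mindspore.train.summary.summary_record import SummaryRecord

# Hypothetical usage of the new flag: write failures surface as exceptions.
try:
    with SummaryRecord("./summary_dir", raise_exception=True) as summary_record:
        summary_record.record(step=1)
except RuntimeError as exc:
    print(f"summary recording failed: {exc}")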

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -207,7 +207,9 @@ class SummaryCollector(Callback):
self._dataset_sink_mode = True
def __enter__(self):
self._record = SummaryRecord(log_dir=self._summary_dir, max_file_size=self._max_file_size)
self._record = SummaryRecord(log_dir=self._summary_dir,
max_file_size=self._max_file_size,
raise_exception=False)
self._first_step, self._dataset_sink_mode = True, True
return self
@ -319,7 +321,14 @@ class SummaryCollector(Callback):
f'expect the follow keys: {list(self._DEFAULT_SPECIFIED_DATA.keys())}')
if 'histogram_regular' in specified_data:
check_value_type('histogram_regular', specified_data.get('histogram_regular'), (str, type(None)))
regular = specified_data.get('histogram_regular')
check_value_type('histogram_regular', regular, (str, type(None)))
if isinstance(regular, str):
try:
re.match(regular, '')
except re.error as exc:
raise ValueError(f'For `collect_specified_data`, the value of `histogram_regular` '
f'is not a valid regular expression. Detail: {str(exc)}.')
bool_items = set(self._DEFAULT_SPECIFIED_DATA) - {'histogram_regular'}
for item in bool_items:
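The added check compiles the user-supplied pattern against an empty string, so a malformed regular expression fails fast with a ValueError instead of surfacing later during collection. A standalone sketch of the same idea (the helper name is illustrative, not part of the change):

import re

def validate_histogram_regular(regular):
    """Reject a malformed `histogram_regular` pattern early (illustrative helper)."""
    if regular is not None and not isinstance(regular, str):
        raise TypeError('`histogram_regular` should be a str or None.')
    if isinstance(regular, str):
        try:
            re.match(regular, '')
        except re.error as exc:
            raise ValueError(f'`histogram_regular` is not a valid regular expression. Detail: {exc}.')
    return regular

validate_histogram_regular('conv.*weight')   # accepted
# validate_histogram_regular('*')            # would raise ValueError ('*' cannot be compiled)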

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -58,14 +58,17 @@ class WriterPool(ctx.Process):
Args:
base_dir (str): The base directory to hold all the files.
max_file_size (Optional[int]): The maximum size of each file that can be written to disk in bytes.
raise_exception (bool, optional): Sets whether to raise an exception when a RuntimeError occurs
while recording data. Default: False, which means that error logs are printed and no exception is raised.
filedict (dict): The mapping from plugin to filename.
"""
def __init__(self, base_dir, max_file_size, **filedict) -> None:
def __init__(self, base_dir, max_file_size, raise_exception=False, **filedict) -> None:
super().__init__()
self._base_dir, self._filedict = base_dir, filedict
self._queue, self._writers_ = ctx.Queue(ctx.cpu_count() * 2), None
self._max_file_size = max_file_size
self._raise_exception = raise_exception
self.start()
def run(self):
@ -124,8 +127,14 @@ class WriterPool(ctx.Process):
for writer in self._writers[:]:
try:
writer.write(plugin, data)
except RuntimeError as e:
logger.warning(e.args[0])
except RuntimeError as exc:
logger.error(str(exc))
self._writers.remove(writer)
writer.close()
if self._raise_exception:
raise
except RuntimeWarning as exc:
logger.warning(str(exc))
self._writers.remove(writer)
writer.close()
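The rewritten write loop distinguishes two failure classes: a RuntimeError (for example a failed disk write) is logged as an error, the offending writer is dropped, and it is re-raised only when raise_exception is set, while a RuntimeWarning (raised by BaseWriter when the max_file_size budget is exhausted, see the base_writer.py hunk below) is logged as a warning and never propagates. A simplified standalone sketch of that policy, assuming writer objects with write() and close() methods:

import logging

logger = logging.getLogger(__name__)

def write_all(writers, plugin, data, raise_exception=False):
    """Dispatch data to every writer, applying the policy shown in the diff above."""
    for writer in writers[:]:          # iterate over a copy so removal is safe
        try:
            writer.write(plugin, data)
        except RuntimeError as exc:    # hard failure: drop the writer, optionally re-raise
            logger.error(str(exc))
            writers.remove(writer)
            writer.close()
            if raise_exception:
                raise
        except RuntimeWarning as exc:  # soft failure (file-size cap reached): never re-raise
            logger.warning(str(exc))
            writers.remove(writer)
            writer.close()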

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -20,6 +20,7 @@ import threading
from collections import defaultdict
from mindspore import log as logger
from mindspore.nn import Cell
from ..._c_expression import Tensor
from ..._checkparam import Validator
@ -29,7 +30,7 @@ from ._explain_adapter import check_explain_proto
from ._writer_pool import WriterPool
# for the moment, this lock is for caution's sake,
# there are actually no any concurrencies happening.
# there is actually no concurrency happening.
_summary_lock = threading.Lock()
# cache the summary data
_summary_tensor_cache = {}
@ -56,10 +57,6 @@ def _get_summary_tensor_data():
return data
def _dictlist():
return defaultdict(list)
class SummaryRecord:
"""
SummaryRecord is used to record the summary data and lineage data.
@ -80,12 +77,13 @@ class SummaryRecord:
file_prefix (str): The prefix of file. Default: "events".
file_suffix (str): The suffix of file. Default: "_MS".
network (Cell): Obtain a pipeline through network for saving graph summary. Default: None.
max_file_size (Optional[int]): The maximum size of each file that can be written to disk (in bytes). \
max_file_size (int, optional): The maximum size of each file that can be written to disk (in bytes). \
Unlimited by default. For example, to write not larger than 4GB, specify `max_file_size=4 * 1024**3`.
raise_exception (bool, optional): Sets whether to raise an exception when a RuntimeError occurs
while recording data. Default: False, which means that error logs are printed and no exception is raised.
Raises:
TypeError: If the type of `max_file_size` is not int, or the type of `file_prefix` or `file_suffix` is not str.
RuntimeError: If the log_dir is not a normalized absolute path name.
TypeError: If the parameter type is incorrect.
Examples:
>>> # use in with statement to auto close
@ -100,10 +98,11 @@ class SummaryRecord:
... summary_record.close()
"""
def __init__(self, log_dir, file_prefix="events", file_suffix="_MS", network=None, max_file_size=None):
def __init__(self, log_dir, file_prefix="events", file_suffix="_MS",
network=None, max_file_size=None, raise_exception=False):
self._closed, self._event_writer = False, None
self._mode, self._data_pool = 'train', _dictlist()
self._mode, self._data_pool = 'train', defaultdict(list)
Validator.check_str_by_regular(file_prefix)
Validator.check_str_by_regular(file_suffix)
@ -120,6 +119,8 @@ class SummaryRecord:
logger.warning("The 'max_file_size' should be greater than 0.")
max_file_size = None
Validator.check_value_type(arg_name='raise_exception', arg_value=raise_exception, valid_types=bool)
self.prefix = file_prefix
self.suffix = file_suffix
self.network = network
@ -127,16 +128,15 @@ class SummaryRecord:
# create the summary writer file
self.event_file_name = get_event_file_name(self.prefix, self.suffix)
try:
self.full_file_name = os.path.join(self.log_path, self.event_file_name)
except Exception as ex:
raise RuntimeError(ex)
self.full_file_name = os.path.join(self.log_path, self.event_file_name)
filename_dict = dict(summary=self.full_file_name,
lineage=get_event_file_name(self.prefix, '_lineage'),
explainer=get_event_file_name(self.prefix, '_explain'))
self._event_writer = WriterPool(log_dir,
max_file_size,
summary=self.full_file_name,
lineage=get_event_file_name(self.prefix, '_lineage'),
explainer=get_event_file_name(self.prefix, '_explain'))
raise_exception,
**filename_dict)
_get_summary_tensor_data()
atexit.register(self.close)
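The constructor now validates raise_exception as a bool and passes it to the WriterPool together with the per-plugin file names. A hedged usage sketch combining the two documented options (the values are illustrative):

from mindspore.train.summary.summary_record import SummaryRecord

# Cap each event file at 4 GB and surface recording failures as exceptions.
summary_record = SummaryRecord(log_dir="./summary_dir",
                               max_file_size=4 * 1024**3,
                               raise_exception=True)
try:
    summary_record.record(step=1)
finally:
    summary_record.close()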
@ -195,8 +195,8 @@ class SummaryRecord:
- The data type of value should be a 'Explain' object when the plugin is 'explainer',
see mindspore/ccsrc/summary.proto.
Raises:
ValueError: When the name is not valid.
TypeError: When the value is not a Tensor.
ValueError: If the parameter value is invalid.
TypeError: If the parameter type is incorrect.
Examples:
>>> with SummaryRecord(log_dir="./summary_dir", file_prefix="xxx_", file_suffix="_yyy") as summary_record:
@ -238,6 +238,10 @@ class SummaryRecord:
Returns:
bool, whether the record process is successful or not.
Raises:
TypeError: If the parameter type is incorrect.
RuntimeError: If the disk space is insufficient.
Examples:
>>> with SummaryRecord(log_dir="./summary_dir", file_prefix="xxx_", file_suffix="_yyy") as summary_record:
... summary_record.record(step=2)
@ -245,11 +249,12 @@ class SummaryRecord:
True
"""
logger.debug("SummaryRecord step is %r.", step)
Validator.check_value_type(arg_name='step', arg_value=step, valid_types=int)
Validator.check_value_type(arg_name='train_network', arg_value=train_network, valid_types=[Cell, type(None)])
if self._closed:
logger.error("The record writer is closed.")
return False
if not isinstance(step, int) or isinstance(step, bool):
raise ValueError("`step` should be int")
# Set the current summary of train step
if self.network is not None and not self.has_graph:
graph_proto = self.network.get_func_graph_proto()
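Replacing the manual isinstance check with Validator.check_value_type also changes the error raised for a bad step from ValueError to TypeError, which is what the updated tests below expect. A hedged illustration:

from mindspore.train.summary.summary_record import SummaryRecord

with SummaryRecord("./summary_dir") as summary_record:
    summary_record.record(1)        # accepted: step is an int
    try:
        summary_record.record(2.0)  # rejected up front by Validator.check_value_type
    except TypeError as exc:
        print(f"invalid step: {exc}")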
@ -294,7 +299,7 @@ class SummaryRecord:
value['step'] = step
return self._data_pool
finally:
self._data_pool = _dictlist()
self._data_pool = defaultdict(list)
@property
def log_dir(self):

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -60,8 +60,8 @@ class BaseWriter:
self._max_file_size -= required_length
self.writer.Write(data)
else:
raise RuntimeError(f"'max_file_size' reached: There are {self._max_file_size} bytes remaining, "
f"but the '{self._filepath}' requires to write {required_length} bytes.")
raise RuntimeWarning(f"'max_file_size' reached: There are {self._max_file_size} bytes remaining, "
f"but the '{self._filepath}' requires to write {required_length} bytes.")
def flush(self):
"""Flush the writer."""

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -12,12 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
@File : test_image_summary.py
@Author:
@Date : 2019-07-4
@Desc : test summary function
"""
"""test_image_summary"""
import logging
import os
import numpy as np
@ -70,23 +65,14 @@ def get_test_data(step):
# Test: call method on parse graph code
def test_image_summary_sample():
""" test_image_summary_sample """
log.debug("begin test_image_summary_sample")
# step 0: create the thread
with SummaryRecord(SUMMARY_DIR, file_suffix="_MS_IMAGE") as test_writer:
# step 1: create the test data for summary
# step 2: create the Event
for i in range(1, 5):
test_data = get_test_data(i)
_cache_summary_tensor_data(test_data)
test_writer.record(i)
test_writer.flush()
# step 3: send the event to mq
# step 4: accept the event and write the file
log.debug("finished test_image_summary_sample")
class Net(nn.Cell):
""" Net definition """
@ -175,23 +161,11 @@ class ImageSummaryCallback(Callback):
def test_image_summary_train():
""" test_image_summary_train """
dataset = get_dataset()
log.debug("begin test_image_summary_sample")
# step 0: create the thread
with SummaryRecord(SUMMARY_DIR, file_suffix="_MS_IMAGE") as test_writer:
# step 1: create the test data for summary
# step 2: create the Event
model = get_model()
callback = ImageSummaryCallback(test_writer)
model.train(2, dataset, callbacks=[callback])
# step 3: send the event to mq
# step 4: accept the event and write the file
log.debug("finished test_image_summary_sample")
def test_image_summary_data():
""" test_image_summary_data """
@ -207,13 +181,6 @@ def test_image_summary_data():
test_data_list.append(dct)
i += 1
log.debug("begin test_image_summary_sample")
# step 0: create the thread
with SummaryRecord(SUMMARY_DIR, file_suffix="_MS_IMAGE") as test_writer:
# step 1: create the test data for summary
# step 2: create the Event
_cache_summary_tensor_data(test_data_list)
test_writer.record(1)
log.debug("finished test_image_summary_sample")

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -12,17 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
@File : test_summary.py
@Author:
@Date : 2019-07-4
@Desc : test summary function
"""
import logging
"""Test summary."""
import os
import random
import numpy as np
import pytest
import mindspore.nn as nn
from mindspore.common.tensor import Tensor
@ -32,9 +27,6 @@ from mindspore.train.summary.summary_record import SummaryRecord, _cache_summary
CUR_DIR = os.getcwd()
SUMMARY_DIR = CUR_DIR + "/test_temp_summary_event_file/"
log = logging.getLogger("test")
log.setLevel(level=logging.ERROR)
def get_test_data(step):
""" get_test_data """
@ -58,26 +50,14 @@ def get_test_data(step):
return test_data_list
# Test 1: summary sample of scalar
def test_scalar_summary_sample():
""" test_scalar_summary_sample """
log.debug("begin test_scalar_summary_sample")
# step 0: create the thread
with SummaryRecord(SUMMARY_DIR, file_suffix="_MS_SCALAR") as test_writer:
# step 1: create the test data for summary
# step 2: create the Event
for i in range(1, 500):
for i in range(1, 5):
test_data = get_test_data(i)
_cache_summary_tensor_data(test_data)
test_writer.record(i)
# step 3: send the event to mq
# step 4: accept the event and write the file
log.debug("finished test_scalar_summary_sample")
def get_test_data_shape_1(step):
""" get_test_data_shape_1 """
@ -104,23 +84,12 @@ def get_test_data_shape_1(step):
# Test: shape = (1,)
def test_scalar_summary_sample_with_shape_1():
""" test_scalar_summary_sample_with_shape_1 """
log.debug("begin test_scalar_summary_sample_with_shape_1")
# step 0: create the thread
with SummaryRecord(SUMMARY_DIR, file_suffix="_MS_SCALAR") as test_writer:
# step 1: create the test data for summary
# step 2: create the Event
for i in range(1, 100):
test_data = get_test_data_shape_1(i)
_cache_summary_tensor_data(test_data)
test_writer.record(i)
# step 3: send the event to mq
# step 4: accept the event and write the file
log.debug("finished test_scalar_summary_sample")
# Test: test with ge
class SummaryDemo(nn.Cell):
@ -143,13 +112,7 @@ class SummaryDemo(nn.Cell):
def test_scalar_summary_with_ge():
""" test_scalar_summary_with_ge """
log.debug("begin test_scalar_summary_with_ge")
# step 0: create the thread
with SummaryRecord(SUMMARY_DIR, file_suffix="_MS_SCALAR") as test_writer:
# step 1: create the network for summary
x = Tensor(np.array([1.1]).astype(np.float32))
y = Tensor(np.array([1.2]).astype(np.float32))
net = SummaryDemo()
net.set_train()
@ -161,45 +124,17 @@ def test_scalar_summary_with_ge():
net(x, y)
test_writer.record(i)
log.debug("finished test_scalar_summary_with_ge")
# test the problem of two consecutive use cases going wrong
def test_scalar_summary_with_ge_2():
""" test_scalar_summary_with_ge_2 """
log.debug("begin test_scalar_summary_with_ge_2")
# step 0: create the thread
with SummaryRecord(SUMMARY_DIR, file_suffix="_MS_SCALAR") as test_writer:
# step 1: create the network for summary
x = Tensor(np.array([1.1]).astype(np.float32))
y = Tensor(np.array([1.2]).astype(np.float32))
net = SummaryDemo()
net.set_train()
# step 2: create the Event
steps = 100
for i in range(1, steps):
x = Tensor(np.array([1.1]).astype(np.float32))
y = Tensor(np.array([1.2]).astype(np.float32))
net(x, y)
test_writer.record(i)
log.debug("finished test_scalar_summary_with_ge_2")
def test_validate():
with SummaryRecord(SUMMARY_DIR) as sr:
sr.record(1)
with pytest.raises(ValueError):
sr.record(False)
with pytest.raises(ValueError):
sr.record(2.0)
with pytest.raises(ValueError):
sr.record((1, 3))
with pytest.raises(ValueError):
sr.record([2, 3])
with pytest.raises(ValueError):
sr.record("str")
with pytest.raises(ValueError):
sr.record(sr)

@ -1,133 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
@File : test_summary_abnormal_input.py
@Author:
@Date : 2019-08-5
@Desc : test summary function of abnormal input
"""
import logging
import os
import numpy as np
from mindspore.common.tensor import Tensor
from mindspore.train.summary.summary_record import SummaryRecord
CUR_DIR = os.getcwd()
SUMMARY_DIR = CUR_DIR + "/test_temp_summary_event_file/"
log = logging.getLogger("test")
log.setLevel(level=logging.ERROR)
def get_test_data(step):
""" get_test_data """
test_data_list = []
tag1 = "x1[:Scalar]"
tag2 = "x2[:Scalar]"
np1 = np.array(step + 1).astype(np.float32)
np2 = np.array(step + 2).astype(np.float32)
dict1 = {}
dict1["name"] = tag1
dict1["data"] = Tensor(np1)
dict2 = {}
dict2["name"] = tag2
dict2["data"] = Tensor(np2)
test_data_list.append(dict1)
test_data_list.append(dict2)
return test_data_list
# Test: call method on parse graph code
def test_summaryrecord_input_null_string():
log.debug("begin test_summaryrecord_input_null_string")
# step 0: create the thread
try:
with SummaryRecord(""):
pass
except:
assert True
else:
assert False
log.debug("finished test_summaryrecord_input_null_string")
def test_summaryrecord_input_None():
log.debug("begin test_summaryrecord_input_None")
# step 0: create the thread
try:
with SummaryRecord(None):
pass
except:
assert True
else:
assert False
log.debug("finished test_summaryrecord_input_None")
def test_summaryrecord_input_relative_dir_1():
log.debug("begin test_summaryrecord_input_relative_dir_1")
# step 0: create the thread
try:
with SummaryRecord("./test_temp_summary_event_file/"):
pass
except:
assert False
else:
assert True
log.debug("finished test_summaryrecord_input_relative_dir_1")
def test_summaryrecord_input_relative_dir_2():
log.debug("begin test_summaryrecord_input_relative_dir_2")
# step 0: create the thread
try:
with SummaryRecord("../summary/"):
pass
except:
assert False
else:
assert True
log.debug("finished test_summaryrecord_input_relative_dir_2")
def test_summaryrecord_input_invalid_type_dir():
log.debug("begin test_summaryrecord_input_invalid_type_dir")
# step 0: create the thread
try:
with SummaryRecord(32):
pass
except:
assert True
else:
assert False
log.debug("finished test_summaryrecord_input_invalid_type_dir")
def test_mulit_layer_directory():
log.debug("begin test_mulit_layer_directory")
# step 0: create the thread
try:
with SummaryRecord("./test_temp_summary_event_file/test/t1/"):
pass
except:
assert False
else:
assert True
log.debug("finished test_mulit_layer_directory")

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -187,6 +187,14 @@ class TestSummaryCollector:
assert expected_msg == str(exc.value)
def test_params_with_histogram_regular_value_error(self):
"""Test histogram regular."""
summary_dir = tempfile.mkdtemp(dir=self.base_summary_dir)
with pytest.raises(ValueError) as exc:
SummaryCollector(summary_dir, collect_specified_data={'histogram_regular': '*'})
assert 'For `collect_specified_data`, the value of `histogram_regular`' in str(exc.value)
def test_params_with_collect_specified_data_unexpected_key(self):
"""Test the collect_specified_data parameter with unexpected key."""
summary_dir = tempfile.mkdtemp(dir=self.base_summary_dir)
@ -260,7 +268,7 @@ class TestSummaryCollector:
cb_params.train_dataset_element = image_data
with SummaryCollector((tempfile.mkdtemp(dir=self.base_summary_dir))) as summary_collector:
summary_collector._collect_input_data(cb_params)
# Note Here need to asssert the result and expected data
# Note Here need to assert the result and expected data
@mock.patch.object(SummaryRecord, 'add_value')
def test_collect_dataset_graph_success(self, mock_add_value):
@ -296,7 +304,6 @@ class TestSummaryCollector:
assert summary_collector._is_parse_loss_success
def test_get_optimizer_from_cb_params_success(self):
"""Test get optimizer success from cb params."""
cb_params = _InternalCallbackParam()

@ -0,0 +1,80 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test_summary_abnormal_input"""
import os
import shutil
import tempfile
import numpy as np
import pytest
from mindspore.common.tensor import Tensor
from mindspore.train.summary.summary_record import SummaryRecord
def get_test_data(step):
""" get_test_data """
test_data_list = []
tag1 = "x1[:Scalar]"
tag2 = "x2[:Scalar]"
np1 = np.array(step + 1).astype(np.float32)
np2 = np.array(step + 2).astype(np.float32)
dict1 = {}
dict1["name"] = tag1
dict1["data"] = Tensor(np1)
dict2 = {}
dict2["name"] = tag2
dict2["data"] = Tensor(np2)
test_data_list.append(dict1)
test_data_list.append(dict2)
return test_data_list
class TestSummaryRecord:
"""Test SummaryRecord"""
def setup_class(self):
"""Run before test this class."""
self.base_summary_dir = tempfile.mkdtemp(suffix='summary')
def teardown_class(self):
"""Run after test this class."""
if os.path.exists(self.base_summary_dir):
shutil.rmtree(self.base_summary_dir)
@pytest.mark.parametrize("log_dir", ["", None, 32])
def test_log_dir_with_type_error(self, log_dir):
with pytest.raises(TypeError):
with SummaryRecord(log_dir):
pass
@pytest.mark.parametrize("raise_exception", ["", None, 32])
def test_raise_exception_with_type_error(self, raise_exception):
summary_dir = tempfile.mkdtemp(dir=self.base_summary_dir)
with pytest.raises(TypeError) as exc:
with SummaryRecord(log_dir=summary_dir, raise_exception=raise_exception):
pass
assert "raise_exception" in str(exc.value)
@pytest.mark.parametrize("step", [False, 2.0, (1, 3), [2, 3], "str"])
def test_step_of_record_with_type_error(self, step):
summary_dir = tempfile.mkdtemp(dir=self.base_summary_dir)
with pytest.raises(TypeError):
with SummaryRecord(summary_dir) as sr:
sr.record(step)