1. Support explainer raise an RuntimeError exception 2. fix the ut of SummaryRecordpull/10436/head
parent
280e127f59
commit
06be546b52
@ -1,133 +0,0 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
@File : test_summary_abnormal_input.py
|
||||
@Author:
|
||||
@Date : 2019-08-5
|
||||
@Desc : test summary function of abnormal input
|
||||
"""
|
||||
import logging
|
||||
import os
|
||||
import numpy as np
|
||||
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.train.summary.summary_record import SummaryRecord
|
||||
|
||||
CUR_DIR = os.getcwd()
|
||||
SUMMARY_DIR = CUR_DIR + "/test_temp_summary_event_file/"
|
||||
|
||||
log = logging.getLogger("test")
|
||||
log.setLevel(level=logging.ERROR)
|
||||
|
||||
|
||||
def get_test_data(step):
|
||||
""" get_test_data """
|
||||
test_data_list = []
|
||||
tag1 = "x1[:Scalar]"
|
||||
tag2 = "x2[:Scalar]"
|
||||
np1 = np.array(step + 1).astype(np.float32)
|
||||
np2 = np.array(step + 2).astype(np.float32)
|
||||
|
||||
dict1 = {}
|
||||
dict1["name"] = tag1
|
||||
dict1["data"] = Tensor(np1)
|
||||
|
||||
dict2 = {}
|
||||
dict2["name"] = tag2
|
||||
dict2["data"] = Tensor(np2)
|
||||
|
||||
test_data_list.append(dict1)
|
||||
test_data_list.append(dict2)
|
||||
|
||||
return test_data_list
|
||||
|
||||
|
||||
# Test: call method on parse graph code
|
||||
def test_summaryrecord_input_null_string():
|
||||
log.debug("begin test_summaryrecord_input_null_string")
|
||||
# step 0: create the thread
|
||||
try:
|
||||
with SummaryRecord(""):
|
||||
pass
|
||||
except:
|
||||
assert True
|
||||
else:
|
||||
assert False
|
||||
log.debug("finished test_summaryrecord_input_null_string")
|
||||
|
||||
|
||||
def test_summaryrecord_input_None():
|
||||
log.debug("begin test_summaryrecord_input_None")
|
||||
# step 0: create the thread
|
||||
try:
|
||||
with SummaryRecord(None):
|
||||
pass
|
||||
except:
|
||||
assert True
|
||||
else:
|
||||
assert False
|
||||
log.debug("finished test_summaryrecord_input_None")
|
||||
|
||||
|
||||
def test_summaryrecord_input_relative_dir_1():
|
||||
log.debug("begin test_summaryrecord_input_relative_dir_1")
|
||||
# step 0: create the thread
|
||||
try:
|
||||
with SummaryRecord("./test_temp_summary_event_file/"):
|
||||
pass
|
||||
except:
|
||||
assert False
|
||||
else:
|
||||
assert True
|
||||
log.debug("finished test_summaryrecord_input_relative_dir_1")
|
||||
|
||||
|
||||
def test_summaryrecord_input_relative_dir_2():
|
||||
log.debug("begin test_summaryrecord_input_relative_dir_2")
|
||||
# step 0: create the thread
|
||||
try:
|
||||
with SummaryRecord("../summary/"):
|
||||
pass
|
||||
except:
|
||||
assert False
|
||||
else:
|
||||
assert True
|
||||
log.debug("finished test_summaryrecord_input_relative_dir_2")
|
||||
|
||||
|
||||
def test_summaryrecord_input_invalid_type_dir():
|
||||
log.debug("begin test_summaryrecord_input_invalid_type_dir")
|
||||
# step 0: create the thread
|
||||
try:
|
||||
with SummaryRecord(32):
|
||||
pass
|
||||
except:
|
||||
assert True
|
||||
else:
|
||||
assert False
|
||||
log.debug("finished test_summaryrecord_input_invalid_type_dir")
|
||||
|
||||
|
||||
def test_mulit_layer_directory():
|
||||
log.debug("begin test_mulit_layer_directory")
|
||||
# step 0: create the thread
|
||||
try:
|
||||
with SummaryRecord("./test_temp_summary_event_file/test/t1/"):
|
||||
pass
|
||||
except:
|
||||
assert False
|
||||
else:
|
||||
assert True
|
||||
log.debug("finished test_mulit_layer_directory")
|
@ -0,0 +1,80 @@
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""test_summary_abnormal_input"""
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.train.summary.summary_record import SummaryRecord
|
||||
|
||||
|
||||
def get_test_data(step):
|
||||
""" get_test_data """
|
||||
test_data_list = []
|
||||
tag1 = "x1[:Scalar]"
|
||||
tag2 = "x2[:Scalar]"
|
||||
np1 = np.array(step + 1).astype(np.float32)
|
||||
np2 = np.array(step + 2).astype(np.float32)
|
||||
|
||||
dict1 = {}
|
||||
dict1["name"] = tag1
|
||||
dict1["data"] = Tensor(np1)
|
||||
|
||||
dict2 = {}
|
||||
dict2["name"] = tag2
|
||||
dict2["data"] = Tensor(np2)
|
||||
|
||||
test_data_list.append(dict1)
|
||||
test_data_list.append(dict2)
|
||||
|
||||
return test_data_list
|
||||
|
||||
|
||||
class TestSummaryRecord:
|
||||
"""Test SummaryRecord"""
|
||||
def setup_class(self):
|
||||
"""Run before test this class."""
|
||||
self.base_summary_dir = tempfile.mkdtemp(suffix='summary')
|
||||
|
||||
def teardown_class(self):
|
||||
"""Run after test this class."""
|
||||
if os.path.exists(self.base_summary_dir):
|
||||
shutil.rmtree(self.base_summary_dir)
|
||||
|
||||
@pytest.mark.parametrize("log_dir", ["", None, 32])
|
||||
def test_log_dir_with_type_error(self, log_dir):
|
||||
with pytest.raises(TypeError):
|
||||
with SummaryRecord(log_dir):
|
||||
pass
|
||||
|
||||
@pytest.mark.parametrize("raise_exception", ["", None, 32])
|
||||
def test_raise_exception_with_type_error(self, raise_exception):
|
||||
summary_dir = tempfile.mkdtemp(dir=self.base_summary_dir)
|
||||
with pytest.raises(TypeError) as exc:
|
||||
with SummaryRecord(log_dir=summary_dir, raise_exception=raise_exception):
|
||||
pass
|
||||
|
||||
assert "raise_exception" in str(exc.value)
|
||||
|
||||
@pytest.mark.parametrize("step", [False, 2.0, (1, 3), [2, 3], "str"])
|
||||
def test_step_of_record_with_type_error(self, step):
|
||||
summary_dir = tempfile.mkdtemp(dir=self.base_summary_dir)
|
||||
with pytest.raises(TypeError):
|
||||
with SummaryRecord(summary_dir) as sr:
|
||||
sr.record(step)
|
Loading…
Reference in new issue