!2147 Add a callback named SummaryCollector and delete SummaryStep callback
Merge pull request !2147 from ougongchang/master (pull/2147/MERGE)
commit
9514b52a23
@ -0,0 +1,128 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Define dataset graph related operations."""
|
||||
import json
|
||||
from importlib import import_module
|
||||
|
||||
from mindspore.train import lineage_pb2
|
||||
|
||||
|
||||
class DatasetGraph:
    """Handle the dataset graph and package it into a proto message."""

    def package_dataset_graph(self, dataset):
        """
        Package the dataset graph into a proto message.

        Args:
            dataset (MindData): Refer to MindDataset.

        Returns:
            DatasetGraph, an object of lineage_pb2.DatasetGraph.
        """
        dataset_package = import_module('mindspore.dataset')
        dataset_dict = dataset_package.serialize(dataset)
        # Round-trip through JSON so that only JSON-native types remain
        # (for example tuples become lists). The string is parsed back
        # immediately, so no pretty-printing is requested.
        dataset_dict = json.loads(json.dumps(dataset_dict))
        dataset_graph_proto = lineage_pb2.DatasetGraph()
        # "children" is packaged recursively; the remaining keys describe
        # the current node.
        children = dataset_dict.pop("children", None)
        if children:
            self._package_children(children=children, message=dataset_graph_proto)
        self._package_current_dataset(operation=dataset_dict, message=dataset_graph_proto)
        return dataset_graph_proto

    def _package_children(self, children, message):
        """
        Package children in dataset operation.

        Each non-empty child gets a new entry in `message.children`; its own
        "children" entry (if any) is removed from the dict and packaged
        recursively before the child's remaining keys are packaged.

        Args:
            children (list[dict]): Child operations. The dicts are mutated:
                their "children" key is popped.
            message (DatasetGraph): Children proto message.
        """
        for child in children:
            if not child:
                continue
            child_graph_message = message.children.add()
            # A leaf node may have no "children" key at all; treat that the
            # same as an empty list instead of raising KeyError.
            grandson = child.pop("children", None)
            if grandson:
                self._package_children(children=grandson, message=child_graph_message)
            # package other parameters
            self._package_current_dataset(operation=child, message=child_graph_message)

    def _package_current_dataset(self, operation, message):
        """
        Package operation parameters in event message.

        Args:
            operation (dict): Operation dict.
            message (Operation): Operation proto message.

        Raises:
            ValueError: Propagated from `_package_parameter` for values of an
                unsupported type.
        """
        for key, value in operation.items():
            if value and key == "operations":
                for operator in value:
                    self._package_enhancement_operation(operator, message.operations.add())
            elif value and key == "sampler":
                self._package_enhancement_operation(value, message.sampler)
            else:
                # NOTE(review): an empty "operations" list also lands here and
                # `_package_parameter` rejects it with ValueError -- confirm
                # whether that is intended.
                self._package_parameter(key, value, message.parameter)

    def _package_enhancement_operation(self, operation, message):
        """
        Package enhancement operation in MapDataset.

        List values are stored without their key: a list made only of ints
        extends `message.size`, any other list extends `message.weights`.
        Every non-list value is stored under its key in
        `message.operationParam`.

        Args:
            operation (dict): Enhancement operation.
            message (Operation): Enhancement operation proto message.
        """
        for key, value in operation.items():
            if not isinstance(value, list):
                self._package_parameter(key, value, message.operationParam)
            elif all(isinstance(ele, int) for ele in value):
                message.size.extend(value)
            else:
                message.weights.extend(value)

    @staticmethod
    def _package_parameter(key, value, message):
        """
        Package parameters in operation.

        Args:
            key (str): Operation name.
            value (Union[str, bool, int, float, list, None]): Operation args.
            message (OperationParameter): Operation proto message.

        Raises:
            ValueError: If the type of `value` is not supported, or if `key`
                is "operations" and `value` is a list.
        """
        if isinstance(value, str):
            message.mapStr[key] = value
        elif isinstance(value, bool):
            # bool must be tested before int because bool is a subclass of int.
            message.mapBool[key] = value
        elif isinstance(value, int):
            message.mapInt[key] = value
        elif isinstance(value, float):
            message.mapDouble[key] = value
        elif isinstance(value, list) and key != "operations":
            # Empty lists are skipped; None elements are stored as "".
            if value:
                replace_value_list = ["" if item is None else item for item in value]
                message.mapStrList[key].strValue.extend(replace_value_list)
        elif value is None:
            message.mapStr[key] = "None"
        else:
            raise ValueError(f"Parameter {key} is not supported in event package.")
|
||||
File diff suppressed because it is too large
Load Diff
@ -1,56 +0,0 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""SummaryStep Callback class."""
|
||||
|
||||
from ._callback import Callback
|
||||
|
||||
|
||||
class SummaryStep(Callback):
    """
    Callback that records summary data every `flush_step` steps.

    Args:
        summary (Object): Summary record object.
        flush_step (int): Number of interval steps to execute. Default: 10.

    Raises:
        ValueError: If `flush_step` is not an int (bool is rejected) or is
            not greater than 0.
    """

    def __init__(self, summary, flush_step=10):
        super().__init__()
        step_is_valid = (
            isinstance(flush_step, int)
            and not isinstance(flush_step, bool)
            and flush_step > 0
        )
        if not step_is_valid:
            raise ValueError("`flush_step` should be int and greater than 0")
        self._summary = summary
        self._flush_step = flush_step

    def __enter__(self):
        # Enter the wrapped summary object, but hand back the callback itself.
        self._summary.__enter__()
        return self

    def __exit__(self, *err):
        # Delegate exit handling (and its return value) to the summary object.
        return self._summary.__exit__(*err)

    def step_end(self, run_context):
        """
        Record the summary when the current step is a multiple of `flush_step`.

        Args:
            run_context (RunContext): Context of the train running.
        """
        cb_params = run_context.original_args()
        if cb_params.cur_step_num % self._flush_step != 0:
            return
        self._summary.record(cb_params.cur_step_num, cb_params.train_network)

    @property
    def summary_file_name(self):
        """Return the `full_file_name` of the wrapped summary object."""
        return self._summary.full_file_name
|
||||
@ -0,0 +1,43 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Summary's enumeration file."""
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class BaseEnum(Enum):
    """Common base for the summary enumerations."""

    @classmethod
    def to_list(cls):
        """Return the value of every entry in `__members__`, in definition order."""
        return [cls[name].value for name in cls.__members__]
|
||||
|
||||
|
||||
class PluginEnum(BaseEnum):
    """Plugins that the summary currently supports."""

    GRAPH = 'graph'
    SCALAR = 'scalar'
    IMAGE = 'image'
    TENSOR = 'tensor'
    HISTOGRAM = 'histogram'
    TRAIN_LINEAGE = 'train_lineage'
    EVAL_LINEAGE = 'eval_lineage'
    DATASET_GRAPH = 'dataset_graph'
|
||||
|
||||
|
||||
class ModeEnum(BaseEnum):
    """Modes that the summary currently supports."""

    TRAIN = 'train'
    EVAL = 'eval'
|
||||
@ -0,0 +1,184 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Test the exception parameter scenario for summary collector."""
|
||||
import os
|
||||
import tempfile
|
||||
import shutil
|
||||
import pytest
|
||||
|
||||
from mindspore.train.callback import SummaryCollector
|
||||
|
||||
|
||||
class TestSummaryCollector:
    """
    Test the exception parameter for summary collector.

    NOTE(review): the expected messages below contain the wording "bug got";
    presumably this mirrors the text raised by the library verbatim -- confirm
    before changing the spelling, since the assertions compare whole strings.
    """

    base_summary_dir = ''

    @classmethod
    def setup_class(cls):
        """Create the temporary base directory shared by the tests of this class."""
        cls.base_summary_dir = tempfile.mkdtemp(suffix='summary')

    @classmethod
    def teardown_class(cls):
        """Remove the temporary base directory after the tests of this class."""
        if os.path.exists(cls.base_summary_dir):
            shutil.rmtree(cls.base_summary_dir)

    @pytest.mark.parametrize("summary_dir", [1234, None, True, ''])
    def test_params_with_summary_dir_value_error(self, summary_dir):
        """Test the exception scenario for summary dir."""
        if isinstance(summary_dir, str):
            # The only str case is the empty string, which is a value error.
            with pytest.raises(ValueError) as exc:
                SummaryCollector(summary_dir=summary_dir)
            assert str(exc.value) == 'For `summary_dir` the value should be a valid string of path, ' \
                                     'but got empty string.'
        else:
            with pytest.raises(TypeError) as exc:
                SummaryCollector(summary_dir=summary_dir)
            assert 'For `summary_dir` the type should be a valid type' in str(exc.value)

    def test_params_with_summary_dir_not_dir(self):
        """Test the given summary dir parameter is not a directory."""
        summary_dir = tempfile.mkdtemp(dir=self.base_summary_dir)
        summary_file = os.path.join(summary_dir, 'temp_file.txt')
        with open(summary_file, 'w') as file_handle:
            file_handle.write('temp')
        with pytest.raises(NotADirectoryError):
            SummaryCollector(summary_dir=summary_file)

    @pytest.mark.parametrize("collect_freq", [None, 0, 0.01])
    def test_params_with_collect_freq_exception(self, collect_freq):
        """Test the exception scenario for collect freq."""
        summary_dir = tempfile.mkdtemp(dir=self.base_summary_dir)
        if isinstance(collect_freq, int):
            # An int that is not greater than 0 is a value error.
            with pytest.raises(ValueError) as exc:
                SummaryCollector(summary_dir=summary_dir, collect_freq=collect_freq)
            expected_msg = f'For `collect_freq` the value should be greater than 0, but got `{collect_freq}`.'
            assert expected_msg == str(exc.value)
        else:
            with pytest.raises(TypeError) as exc:
                SummaryCollector(summary_dir=summary_dir, collect_freq=collect_freq)
            expected_msg = f"For `collect_freq` the type should be a valid type of ['int'], " \
                           f'bug got {type(collect_freq).__name__}.'
            assert expected_msg == str(exc.value)

    @pytest.mark.parametrize("action", [None, 123, '', '123'])
    def test_params_with_action_exception(self, action):
        """Test the exception scenario for action."""
        summary_dir = tempfile.mkdtemp(dir=self.base_summary_dir)
        with pytest.raises(TypeError) as exc:
            SummaryCollector(summary_dir=summary_dir, keep_default_action=action)
        expected_msg = f"For `keep_default_action` the type should be a valid type of ['bool'], " \
                       f"bug got {type(action).__name__}."
        assert expected_msg == str(exc.value)

    @pytest.mark.parametrize("collect_specified_data", [123])
    def test_params_with_collect_specified_data_type_error(self, collect_specified_data):
        """Test type error scenario for collect specified data param."""
        summary_dir = tempfile.mkdtemp(dir=self.base_summary_dir)
        with pytest.raises(TypeError) as exc:
            SummaryCollector(summary_dir, collect_specified_data=collect_specified_data)

        expected_msg = f"For `collect_specified_data` the type should be a valid type of ['dict', 'NoneType'], " \
                       f"bug got {type(collect_specified_data).__name__}."

        assert expected_msg == str(exc.value)

    @pytest.mark.parametrize("collect_specified_data", [
        {
            123: 123
        },
        {
            None: True
        }
    ])
    def test_params_with_collect_specified_data_key_type_error(self, collect_specified_data):
        """Test the key of collect specified data param."""
        summary_dir = tempfile.mkdtemp(dir=self.base_summary_dir)
        with pytest.raises(TypeError) as exc:
            SummaryCollector(summary_dir, collect_specified_data=collect_specified_data)

        param_name = list(collect_specified_data)[0]
        expected_msg = f"For `{param_name}` the type should be a valid type of ['str'], " \
                       f"bug got {type(param_name).__name__}."
        assert expected_msg == str(exc.value)

    @pytest.mark.parametrize("collect_specified_data", [
        {
            'collect_metric': None
        },
        {
            'collect_graph': 123
        },
        {
            'histogram_regular': 123
        },
    ])
    def test_params_with_collect_specified_data_value_type_error(self, collect_specified_data):
        """Test the value of collect specified data param."""
        summary_dir = tempfile.mkdtemp(dir=self.base_summary_dir)
        with pytest.raises(TypeError) as exc:
            SummaryCollector(summary_dir, collect_specified_data=collect_specified_data)

        param_name = list(collect_specified_data)[0]
        param_value = collect_specified_data[param_name]
        expected_type = "['bool']" if param_name != 'histogram_regular' else "['str', 'NoneType']"
        expected_msg = f'For `{param_name}` the type should be a valid type of {expected_type}, ' \
                       f'bug got {type(param_value).__name__}.'

        assert expected_msg == str(exc.value)

    def test_params_with_collect_specified_data_unexpected_key(self):
        """Test the collect_specified_data parameter with unexpected key."""
        summary_dir = tempfile.mkdtemp(dir=self.base_summary_dir)
        data = {'unexpected_key': True}
        with pytest.raises(ValueError) as exc:
            SummaryCollector(summary_dir, collect_specified_data=data)
        expected_msg = f"For `collect_specified_data` the keys {set(data)} are unsupported."
        assert expected_msg == str(exc.value)

    @pytest.mark.parametrize("custom_lineage_data", [
        123,
        {
            'custom': {}
        },
        {
            'custom': None
        },
        {
            123: 'custom'
        }
    ])
    def test_params_with_custom_lineage_data_type_error(self, custom_lineage_data):
        """Test the custom lineage data parameter type error."""
        summary_dir = tempfile.mkdtemp(dir=self.base_summary_dir)
        with pytest.raises(TypeError) as exc:
            SummaryCollector(summary_dir, custom_lineage_data=custom_lineage_data)

        if not isinstance(custom_lineage_data, dict):
            expected_msg = f"For `custom_lineage_data` the type should be a valid type of ['dict', 'NoneType'], " \
                           f"bug got {type(custom_lineage_data).__name__}."
        else:
            param_name = list(custom_lineage_data)[0]
            param_value = custom_lineage_data[param_name]
            if not isinstance(param_name, str):
                # A non-str key is reported against the key itself.
                arg_name = f'custom_lineage_data -> {param_name}'
                expected_msg = f"For `{arg_name}` the type should be a valid type of ['str'], " \
                               f'bug got {type(param_name).__name__}.'
            else:
                # A str key with an unsupported value is reported against the value.
                arg_name = f'the value of custom_lineage_data -> {param_name}'
                expected_msg = f"For `{arg_name}` the type should be a valid type of ['int', 'str', 'float'], " \
                               f'bug got {type(param_value).__name__}.'

        assert expected_msg == str(exc.value)
|
||||
Loading…
Reference in new issue