I added a SummaryCollector to help users automatically collect information such as the network, loss, learning rate and so on, making it easier to collect this information. It also can collect train lineage and eval lineage information which is collected by TrainLineage Callback and EvalLineage Callback in MindInsight. I also add some UT for SummaryCollect to keep the code correct.pull/2147/head
							parent
							
								
									c55b81e94f
								
							
						
					
					
						commit
						939cd29d7e
					
				| @ -0,0 +1,128 @@ | ||||
| # Copyright 2020 Huawei Technologies Co., Ltd | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
| # | ||||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||||
| # | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| # ============================================================================ | ||||
| """Define dataset graph related operations.""" | ||||
| import json | ||||
| from importlib import import_module | ||||
| 
 | ||||
| from mindspore.train import lineage_pb2 | ||||
| 
 | ||||
| 
 | ||||
| class DatasetGraph: | ||||
|     """Handle the data graph and packages it into binary data.""" | ||||
|     def package_dataset_graph(self, dataset): | ||||
|         """ | ||||
|         packages dataset graph into binary data | ||||
| 
 | ||||
|         Args: | ||||
|             dataset (MindData): refer to MindDataset | ||||
| 
 | ||||
|         Returns: | ||||
|             DatasetGraph, a object of lineage_pb2.DatasetGraph. | ||||
|         """ | ||||
|         dataset_package = import_module('mindspore.dataset') | ||||
|         dataset_dict = dataset_package.serialize(dataset) | ||||
|         json_str = json.dumps(dataset_dict, indent=2) | ||||
|         dataset_dict = json.loads(json_str) | ||||
|         dataset_graph_proto = lineage_pb2.DatasetGraph() | ||||
|         if "children" in dataset_dict: | ||||
|             children = dataset_dict.pop("children") | ||||
|             if children: | ||||
|                 self._package_children(children=children, message=dataset_graph_proto) | ||||
|             self._package_current_dataset(operation=dataset_dict, message=dataset_graph_proto) | ||||
|         return dataset_graph_proto | ||||
| 
 | ||||
|     def _package_children(self, children, message): | ||||
|         """ | ||||
|         Package children in dataset operation. | ||||
| 
 | ||||
|         Args: | ||||
|             children (list[dict]): Child operations. | ||||
|             message (DatasetGraph): Children proto message. | ||||
|         """ | ||||
|         for child in children: | ||||
|             if child: | ||||
|                 child_graph_message = getattr(message, "children").add() | ||||
|                 grandson = child.pop("children") | ||||
|                 if grandson: | ||||
|                     self._package_children(children=grandson, message=child_graph_message) | ||||
|                 # package other parameters | ||||
|                 self._package_current_dataset(operation=child, message=child_graph_message) | ||||
| 
 | ||||
|     def _package_current_dataset(self, operation, message): | ||||
|         """ | ||||
|         Package operation parameters in event message. | ||||
| 
 | ||||
|         Args: | ||||
|             operation (dict): Operation dict. | ||||
|             message (Operation): Operation proto message. | ||||
|         """ | ||||
|         for key, value in operation.items(): | ||||
|             if value and key == "operations": | ||||
|                 for operator in value: | ||||
|                     self._package_enhancement_operation( | ||||
|                         operator, | ||||
|                         message.operations.add() | ||||
|                     ) | ||||
|             elif value and key == "sampler": | ||||
|                 self._package_enhancement_operation( | ||||
|                     value, | ||||
|                     message.sampler | ||||
|                 ) | ||||
|             else: | ||||
|                 self._package_parameter(key, value, message.parameter) | ||||
| 
 | ||||
|     def _package_enhancement_operation(self, operation, message): | ||||
|         """ | ||||
|         Package enhancement operation in MapDataset. | ||||
| 
 | ||||
|         Args: | ||||
|             operation (dict): Enhancement operation. | ||||
|             message (Operation): Enhancement operation proto message. | ||||
|         """ | ||||
|         for key, value in operation.items(): | ||||
|             if isinstance(value, list): | ||||
|                 if all(isinstance(ele, int) for ele in value): | ||||
|                     message.size.extend(value) | ||||
|                 else: | ||||
|                     message.weights.extend(value) | ||||
|             else: | ||||
|                 self._package_parameter(key, value, message.operationParam) | ||||
| 
 | ||||
|     @staticmethod | ||||
|     def _package_parameter(key, value, message): | ||||
|         """ | ||||
|         Package parameters in operation. | ||||
| 
 | ||||
|         Args: | ||||
|             key (str): Operation name. | ||||
|             value (Union[str, bool, int, float, list, None]): Operation args. | ||||
|             message (OperationParameter): Operation proto message. | ||||
|         """ | ||||
|         if isinstance(value, str): | ||||
|             message.mapStr[key] = value | ||||
|         elif isinstance(value, bool): | ||||
|             message.mapBool[key] = value | ||||
|         elif isinstance(value, int): | ||||
|             message.mapInt[key] = value | ||||
|         elif isinstance(value, float): | ||||
|             message.mapDouble[key] = value | ||||
|         elif isinstance(value, list) and key != "operations": | ||||
|             if value: | ||||
|                 replace_value_list = list(map(lambda x: "" if x is None else x, value)) | ||||
|                 message.mapStrList[key].strValue.extend(replace_value_list) | ||||
|         elif value is None: | ||||
|             message.mapStr[key] = "None" | ||||
|         else: | ||||
|             raise ValueError(f"Parameter {key} is not supported in event package.") | ||||
											
												
													File diff suppressed because it is too large
													Load Diff
												
											
										
									
								| @ -1,56 +0,0 @@ | ||||
| # Copyright 2020 Huawei Technologies Co., Ltd | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
| # | ||||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||||
| # | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| # ============================================================================ | ||||
| """SummaryStep Callback class.""" | ||||
| 
 | ||||
| from ._callback import Callback | ||||
| 
 | ||||
| 
 | ||||
| class SummaryStep(Callback): | ||||
|     """ | ||||
|     The summary callback class. | ||||
| 
 | ||||
|     Args: | ||||
|         summary (Object): Summary recode object. | ||||
|         flush_step (int): Number of interval steps to execute. Default: 10. | ||||
|     """ | ||||
| 
 | ||||
|     def __init__(self, summary, flush_step=10): | ||||
|         super(SummaryStep, self).__init__() | ||||
|         if not isinstance(flush_step, int) or isinstance(flush_step, bool) or flush_step <= 0: | ||||
|             raise ValueError("`flush_step` should be int and greater than 0") | ||||
|         self._summary = summary | ||||
|         self._flush_step = flush_step | ||||
| 
 | ||||
|     def __enter__(self): | ||||
|         self._summary.__enter__() | ||||
|         return self | ||||
| 
 | ||||
|     def __exit__(self, *err): | ||||
|         return self._summary.__exit__(*err) | ||||
| 
 | ||||
|     def step_end(self, run_context): | ||||
|         """ | ||||
|         Save summary. | ||||
| 
 | ||||
|         Args: | ||||
|             run_context (RunContext): Context of the train running. | ||||
|         """ | ||||
|         cb_params = run_context.original_args() | ||||
|         if cb_params.cur_step_num % self._flush_step == 0: | ||||
|             self._summary.record(cb_params.cur_step_num, cb_params.train_network) | ||||
| 
 | ||||
|     @property | ||||
|     def summary_file_name(self): | ||||
|         return self._summary.full_file_name | ||||
| @ -0,0 +1,43 @@ | ||||
| # Copyright 2020 Huawei Technologies Co., Ltd | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
| # | ||||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||||
| # | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| # ============================================================================ | ||||
| """Summary's enumeration file.""" | ||||
| from enum import Enum | ||||
| 
 | ||||
| 
 | ||||
| class BaseEnum(Enum): | ||||
|     """The base enum class.""" | ||||
| 
 | ||||
|     @classmethod | ||||
|     def to_list(cls): | ||||
|         """Converts the enumeration into a list.""" | ||||
|         return [member.value for member in cls.__members__.values()] | ||||
| 
 | ||||
| 
 | ||||
| class PluginEnum(BaseEnum): | ||||
|     """The list of plugins currently supported by the summary.""" | ||||
|     GRAPH = 'graph' | ||||
|     SCALAR = 'scalar' | ||||
|     IMAGE = 'image' | ||||
|     TENSOR = 'tensor' | ||||
|     HISTOGRAM = 'histogram' | ||||
|     TRAIN_LINEAGE = 'train_lineage' | ||||
|     EVAL_LINEAGE = 'eval_lineage' | ||||
|     DATASET_GRAPH = 'dataset_graph' | ||||
| 
 | ||||
| 
 | ||||
| class ModeEnum(BaseEnum): | ||||
|     """The modes currently supported by the summary.""" | ||||
|     TRAIN = 'train' | ||||
|     EVAL = 'eval' | ||||
| @ -0,0 +1,184 @@ | ||||
| # Copyright 2020 Huawei Technologies Co., Ltd | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
| # | ||||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||||
| # | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| # ============================================================================ | ||||
| """Test the exception parameter scenario for summary collector.""" | ||||
| import os | ||||
| import tempfile | ||||
| import shutil | ||||
| import pytest | ||||
| 
 | ||||
| from mindspore.train.callback import SummaryCollector | ||||
| 
 | ||||
| 
 | ||||
| class TestSummaryCollector: | ||||
|     """Test the exception parameter for summary collector.""" | ||||
|     base_summary_dir = '' | ||||
| 
 | ||||
|     def setup_class(self): | ||||
|         """Run before test this class.""" | ||||
|         self.base_summary_dir = tempfile.mkdtemp(suffix='summary') | ||||
| 
 | ||||
|     def teardown_class(self): | ||||
|         """Run after test this class.""" | ||||
|         if os.path.exists(self.base_summary_dir): | ||||
|             shutil.rmtree(self.base_summary_dir) | ||||
| 
 | ||||
|     @pytest.mark.parametrize("summary_dir", [1234, None, True, '']) | ||||
|     def test_params_with_summary_dir_value_error(self, summary_dir): | ||||
|         """Test the exception scenario for summary dir.""" | ||||
|         if isinstance(summary_dir, str): | ||||
|             with pytest.raises(ValueError) as exc: | ||||
|                 SummaryCollector(summary_dir=summary_dir) | ||||
|             assert str(exc.value) == 'For `summary_dir` the value should be a valid string of path, ' \ | ||||
|                                      'but got empty string.' | ||||
|         else: | ||||
|             with pytest.raises(TypeError) as exc: | ||||
|                 SummaryCollector(summary_dir=summary_dir) | ||||
|             assert 'For `summary_dir` the type should be a valid type' in str(exc.value) | ||||
| 
 | ||||
|     def test_params_with_summary_dir_not_dir(self): | ||||
|         """Test the given summary dir parameter is not a directory.""" | ||||
|         summary_dir = tempfile.mkdtemp(dir=self.base_summary_dir) | ||||
|         summary_file = os.path.join(summary_dir, 'temp_file.txt') | ||||
|         with open(summary_file, 'w') as file_handle: | ||||
|             file_handle.write('temp') | ||||
|         print(os.path.isfile(summary_file)) | ||||
|         with pytest.raises(NotADirectoryError): | ||||
|             SummaryCollector(summary_dir=summary_file) | ||||
| 
 | ||||
|     @pytest.mark.parametrize("collect_freq", [None, 0, 0.01]) | ||||
|     def test_params_with_collect_freq_exception(self, collect_freq): | ||||
|         """Test the exception scenario for collect freq.""" | ||||
|         summary_dir = tempfile.mkdtemp(dir=self.base_summary_dir) | ||||
|         if isinstance(collect_freq, int): | ||||
|             with pytest.raises(ValueError) as exc: | ||||
|                 SummaryCollector(summary_dir=summary_dir, collect_freq=collect_freq) | ||||
|             expected_msg = f'For `collect_freq` the value should be greater than 0, but got `{collect_freq}`.' | ||||
|             assert expected_msg == str(exc.value) | ||||
|         else: | ||||
|             with pytest.raises(TypeError) as exc: | ||||
|                 SummaryCollector(summary_dir=summary_dir, collect_freq=collect_freq) | ||||
|             expected_msg = f"For `collect_freq` the type should be a valid type of ['int'], " \ | ||||
|                            f'bug got {type(collect_freq).__name__}.' | ||||
|             assert expected_msg == str(exc.value) | ||||
| 
 | ||||
|     @pytest.mark.parametrize("action", [None, 123, '', '123']) | ||||
|     def test_params_with_action_exception(self, action): | ||||
|         """Test the exception scenario for action.""" | ||||
|         summary_dir = tempfile.mkdtemp(dir=self.base_summary_dir) | ||||
|         with pytest.raises(TypeError) as exc: | ||||
|             SummaryCollector(summary_dir=summary_dir, keep_default_action=action) | ||||
|         expected_msg = f"For `keep_default_action` the type should be a valid type of ['bool'], " \ | ||||
|                        f"bug got {type(action).__name__}." | ||||
|         assert expected_msg == str(exc.value) | ||||
| 
 | ||||
|     @pytest.mark.parametrize("collect_specified_data", [123]) | ||||
|     def test_params_with_collect_specified_data_type_error(self, collect_specified_data): | ||||
|         """Test type error scenario for collect specified data param.""" | ||||
|         summary_dir = tempfile.mkdtemp(dir=self.base_summary_dir) | ||||
|         with pytest.raises(TypeError) as exc: | ||||
|             SummaryCollector(summary_dir, collect_specified_data=collect_specified_data) | ||||
| 
 | ||||
|         expected_msg = f"For `collect_specified_data` the type should be a valid type of ['dict', 'NoneType'], " \ | ||||
|                        f"bug got {type(collect_specified_data).__name__}." | ||||
| 
 | ||||
|         assert expected_msg == str(exc.value) | ||||
| 
 | ||||
|     @pytest.mark.parametrize("collect_specified_data", [ | ||||
|         { | ||||
|             123: 123 | ||||
|         }, | ||||
|         { | ||||
|             None: True | ||||
|         } | ||||
|     ]) | ||||
|     def test_params_with_collect_specified_data_key_type_error(self, collect_specified_data): | ||||
|         """Test the key of collect specified data param.""" | ||||
|         summary_dir = tempfile.mkdtemp(dir=self.base_summary_dir) | ||||
|         with pytest.raises(TypeError) as exc: | ||||
|             SummaryCollector(summary_dir, collect_specified_data=collect_specified_data) | ||||
| 
 | ||||
|         param_name = list(collect_specified_data)[0] | ||||
|         expected_msg = f"For `{param_name}` the type should be a valid type of ['str'], " \ | ||||
|                        f"bug got {type(param_name).__name__}." | ||||
|         assert expected_msg == str(exc.value) | ||||
| 
 | ||||
|     @pytest.mark.parametrize("collect_specified_data", [ | ||||
|         { | ||||
|             'collect_metric': None | ||||
|         }, | ||||
|         { | ||||
|             'collect_graph': 123 | ||||
|         }, | ||||
|         { | ||||
|             'histogram_regular': 123 | ||||
|         }, | ||||
|     ]) | ||||
|     def test_params_with_collect_specified_data_value_type_error(self, collect_specified_data): | ||||
|         """Test the value of collect specified data param.""" | ||||
|         summary_dir = tempfile.mkdtemp(dir=self.base_summary_dir) | ||||
|         with pytest.raises(TypeError) as exc: | ||||
|             SummaryCollector(summary_dir, collect_specified_data=collect_specified_data) | ||||
| 
 | ||||
|         param_name = list(collect_specified_data)[0] | ||||
|         param_value = collect_specified_data[param_name] | ||||
|         expected_type = "['bool']" if param_name != 'histogram_regular' else "['str', 'NoneType']" | ||||
|         expected_msg = f'For `{param_name}` the type should be a valid type of {expected_type}, ' \ | ||||
|                        f'bug got {type(param_value).__name__}.' | ||||
| 
 | ||||
|         assert expected_msg == str(exc.value) | ||||
| 
 | ||||
|     def test_params_with_collect_specified_data_unexpected_key(self): | ||||
|         """Test the collect_specified_data parameter with unexpected key.""" | ||||
|         summary_dir = tempfile.mkdtemp(dir=self.base_summary_dir) | ||||
|         data = {'unexpected_key': True} | ||||
|         with pytest.raises(ValueError) as exc: | ||||
|             SummaryCollector(summary_dir, collect_specified_data=data) | ||||
|         expected_msg = f"For `collect_specified_data` the keys {set(data)} are unsupported." | ||||
|         assert expected_msg == str(exc.value) | ||||
| 
 | ||||
|     @pytest.mark.parametrize("custom_lineage_data", [ | ||||
|         123, | ||||
|         { | ||||
|             'custom': {} | ||||
|         }, | ||||
|         { | ||||
|             'custom': None | ||||
|         }, | ||||
|         { | ||||
|             123: 'custom' | ||||
|         } | ||||
|     ]) | ||||
|     def test_params_with_custom_lineage_data_type_error(self, custom_lineage_data): | ||||
|         """Test the custom lineage data parameter type error.""" | ||||
|         summary_dir = tempfile.mkdtemp(dir=self.base_summary_dir) | ||||
|         with pytest.raises(TypeError) as exc: | ||||
|             SummaryCollector(summary_dir, custom_lineage_data=custom_lineage_data) | ||||
| 
 | ||||
|         if not isinstance(custom_lineage_data, dict): | ||||
|             expected_msg = f"For `custom_lineage_data` the type should be a valid type of ['dict', 'NoneType'], " \ | ||||
|                            f"bug got {type(custom_lineage_data).__name__}." | ||||
|         else: | ||||
|             param_name = list(custom_lineage_data)[0] | ||||
|             param_value = custom_lineage_data[param_name] | ||||
|             if not isinstance(param_name, str): | ||||
|                 arg_name = f'custom_lineage_data -> {param_name}' | ||||
|                 expected_msg = f"For `{arg_name}` the type should be a valid type of ['str'], " \ | ||||
|                                f'bug got {type(param_name).__name__}.' | ||||
|             else: | ||||
|                 arg_name = f'the value of custom_lineage_data -> {param_name}' | ||||
|                 expected_msg = f"For `{arg_name}` the type should be a valid type of ['int', 'str', 'float'], " \ | ||||
|                                f'bug got {type(param_value).__name__}.' | ||||
| 
 | ||||
|         assert expected_msg == str(exc.value) | ||||
					Loading…
					
					
				
		Reference in new issue