|  |  | @ -17,6 +17,7 @@ import atexit | 
			
		
	
		
		
			
				
					
					|  |  |  | import os |  |  |  | import os | 
			
		
	
		
		
			
				
					
					|  |  |  | import re |  |  |  | import re | 
			
		
	
		
		
			
				
					
					|  |  |  | import threading |  |  |  | import threading | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  | import time | 
			
		
	
		
		
			
				
					
					|  |  |  | from collections import defaultdict |  |  |  | from collections import defaultdict | 
			
		
	
		
		
			
				
					
					|  |  |  | 
 |  |  |  | 
 | 
			
		
	
		
		
			
				
					
					|  |  |  | from mindspore import log as logger |  |  |  | from mindspore import log as logger | 
			
		
	
	
		
		
			
				
					|  |  | @ -24,7 +25,7 @@ from mindspore.nn import Cell | 
			
		
	
		
		
			
				
					
					|  |  |  | 
 |  |  |  | 
 | 
			
		
	
		
		
			
				
					
					|  |  |  | from ..._c_expression import Tensor |  |  |  | from ..._c_expression import Tensor | 
			
		
	
		
		
			
				
					
					|  |  |  | from ..._checkparam import Validator |  |  |  | from ..._checkparam import Validator | 
			
		
	
		
		
			
				
					
					|  |  |  | from .._utils import _check_lineage_value, _check_to_numpy, _make_directory |  |  |  | from .._utils import _check_lineage_value, _check_to_numpy, _make_directory, check_value_type | 
			
				
				
			
		
	
		
		
	
		
		
			
				
					
					|  |  |  | from ._summary_adapter import get_event_file_name, package_graph_event |  |  |  | from ._summary_adapter import get_event_file_name, package_graph_event | 
			
		
	
		
		
			
				
					
					|  |  |  | from ._explain_adapter import check_explain_proto |  |  |  | from ._explain_adapter import check_explain_proto | 
			
		
	
		
		
			
				
					
					|  |  |  | from ._writer_pool import WriterPool |  |  |  | from ._writer_pool import WriterPool | 
			
		
	
	
		
		
			
				
					|  |  | @ -34,6 +35,9 @@ from ._writer_pool import WriterPool | 
			
		
	
		
		
			
				
					
					|  |  |  | _summary_lock = threading.Lock() |  |  |  | _summary_lock = threading.Lock() | 
			
		
	
		
		
			
				
					
					|  |  |  | # cache the summary data |  |  |  | # cache the summary data | 
			
		
	
		
		
			
				
					
					|  |  |  | _summary_tensor_cache = {} |  |  |  | _summary_tensor_cache = {} | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  | _DEFAULT_EXPORT_OPTIONS = { | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |     'tensor_format': 'npy', | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  | } | 
			
		
	
		
		
			
				
					
					|  |  |  | 
 |  |  |  | 
 | 
			
		
	
		
		
			
				
					
					|  |  |  | 
 |  |  |  | 
 | 
			
		
	
		
		
			
				
					
					|  |  |  | def _cache_summary_tensor_data(summary): |  |  |  | def _cache_summary_tensor_data(summary): | 
			
		
	
	
		
		
			
				
					|  |  | @ -57,6 +61,27 @@ def _get_summary_tensor_data(): | 
			
		
	
		
		
			
				
					
					|  |  |  |         return data |  |  |  |         return data | 
			
		
	
		
		
			
				
					
					|  |  |  | 
 |  |  |  | 
 | 
			
		
	
		
		
			
				
					
					|  |  |  | 
 |  |  |  | 
 | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  | def process_export_options(export_options): | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |     """Check specified data type and value.""" | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |     if export_options is None: | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |         return None | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  | 
 | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |     check_value_type('export_options', export_options, [dict, type(None)]) | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  | 
 | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |     for param_name in export_options: | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |         check_value_type(param_name, param_name, [str]) | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  | 
 | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |     unexpected_params = set(export_options) - set(_DEFAULT_EXPORT_OPTIONS) | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |     if unexpected_params: | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |         raise ValueError(f'For `export_options` the keys {unexpected_params} are unsupported, ' | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |                          f'expect the follow keys: {list(_DEFAULT_EXPORT_OPTIONS.keys())}') | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  | 
 | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |     for item in set(export_options): | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |         check_value_type(item, export_options.get(item), [str, type(None)]) | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  | 
 | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |     return export_options | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  | 
 | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  | 
 | 
			
		
	
		
		
			
				
					
					|  |  |  | class SummaryRecord: |  |  |  | class SummaryRecord: | 
			
		
	
		
		
			
				
					
					|  |  |  |     """ |  |  |  |     """ | 
			
		
	
		
		
			
				
					
					|  |  |  |     SummaryRecord is used to record the summary data and lineage data. |  |  |  |     SummaryRecord is used to record the summary data and lineage data. | 
			
		
	
	
		
		
			
				
					|  |  | @ -81,6 +106,13 @@ class SummaryRecord: | 
			
		
	
		
		
			
				
					
					|  |  |  |             Unlimited by default. For example, to write not larger than 4GB, specify `max_file_size=4 * 1024**3`. |  |  |  |             Unlimited by default. For example, to write not larger than 4GB, specify `max_file_size=4 * 1024**3`. | 
			
		
	
		
		
			
				
					
					|  |  |  |         raise_exception (bool, optional): Sets whether to throw an exception when an RuntimeError exception occurs |  |  |  |         raise_exception (bool, optional): Sets whether to throw an exception when an RuntimeError exception occurs | 
			
		
	
		
		
			
				
					
					|  |  |  |             in recording data. Default: False, this means that error logs are printed and no exception is thrown. |  |  |  |             in recording data. Default: False, this means that error logs are printed and no exception is thrown. | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |         export_options (Union[None, dict]): Perform custom operations on the export data. | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |             Default: None, it means there is no export data. | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |             You can customize the export data with a dictionary. For example, you can set {'tensor_format': 'npy'} | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |             to export tensor as npy file. The data that supports control is shown below. | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  | 
 | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |             - tensor_format (Union[str, None]): Customize the export tensor format. | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |               Default: None, it means there is no export tensor. | 
			
		
	
		
		
			
				
					
					|  |  |  | 
 |  |  |  | 
 | 
			
		
	
		
		
			
				
					
					|  |  |  |     Raises: |  |  |  |     Raises: | 
			
		
	
		
		
			
				
					
					|  |  |  |         TypeError: If the parameter type is incorrect. |  |  |  |         TypeError: If the parameter type is incorrect. | 
			
		
	
	
		
		
			
				
					|  |  | @ -99,7 +131,7 @@ class SummaryRecord: | 
			
		
	
		
		
			
				
					
					|  |  |  |     """ |  |  |  |     """ | 
			
		
	
		
		
			
				
					
					|  |  |  | 
 |  |  |  | 
 | 
			
		
	
		
		
			
				
					
					|  |  |  |     def __init__(self, log_dir, file_prefix="events", file_suffix="_MS", |  |  |  |     def __init__(self, log_dir, file_prefix="events", file_suffix="_MS", | 
			
		
	
		
		
			
				
					
					|  |  |  |                  network=None, max_file_size=None, raise_exception=False): |  |  |  |                  network=None, max_file_size=None, raise_exception=False, export_options=None): | 
			
				
				
			
		
	
		
		
	
		
		
			
				
					
					|  |  |  | 
 |  |  |  | 
 | 
			
		
	
		
		
			
				
					
					|  |  |  |         self._closed, self._event_writer = False, None |  |  |  |         self._closed, self._event_writer = False, None | 
			
		
	
		
		
			
				
					
					|  |  |  |         self._mode, self._data_pool = 'train', defaultdict(list) |  |  |  |         self._mode, self._data_pool = 'train', defaultdict(list) | 
			
		
	
	
		
		
			
				
					|  |  | @ -126,13 +158,20 @@ class SummaryRecord: | 
			
		
	
		
		
			
				
					
					|  |  |  |         self.network = network |  |  |  |         self.network = network | 
			
		
	
		
		
			
				
					
					|  |  |  |         self.has_graph = False |  |  |  |         self.has_graph = False | 
			
		
	
		
		
			
				
					
					|  |  |  | 
 |  |  |  | 
 | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |         seconds = str(int(time.time())) | 
			
		
	
		
		
			
				
					
					|  |  |  |         # create the summary writer file |  |  |  |         # create the summary writer file | 
			
		
	
		
		
			
				
					
					|  |  |  |         self.event_file_name = get_event_file_name(self.prefix, self.suffix) |  |  |  |         self.event_file_name = get_event_file_name(self.prefix, self.suffix, seconds) | 
			
				
				
			
		
	
		
		
	
		
		
			
				
					
					|  |  |  |         self.full_file_name = os.path.join(self.log_path, self.event_file_name) |  |  |  |         self.full_file_name = os.path.join(self.log_path, self.event_file_name) | 
			
		
	
		
		
			
				
					
					|  |  |  | 
 |  |  |  | 
 | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |         self._export_options = process_export_options(export_options) | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |         export_dir = '' | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |         if self._export_options is not None: | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |             export_dir = "export_{}".format(seconds) | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  | 
 | 
			
		
	
		
		
			
				
					
					|  |  |  |         filename_dict = dict(summary=self.full_file_name, |  |  |  |         filename_dict = dict(summary=self.full_file_name, | 
			
		
	
		
		
			
				
					
					|  |  |  |                              lineage=get_event_file_name(self.prefix, '_lineage'), |  |  |  |                              lineage=get_event_file_name(self.prefix, '_lineage'), | 
			
		
	
		
		
			
				
					
					|  |  |  |                              explainer=get_event_file_name(self.prefix, '_explain')) |  |  |  |                              explainer=get_event_file_name(self.prefix, '_explain'), | 
			
				
				
			
		
	
		
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |                              exporter=export_dir) | 
			
		
	
		
		
			
				
					
					|  |  |  |         self._event_writer = WriterPool(log_dir, |  |  |  |         self._event_writer = WriterPool(log_dir, | 
			
		
	
		
		
			
				
					
					|  |  |  |                                         max_file_size, |  |  |  |                                         max_file_size, | 
			
		
	
		
		
			
				
					
					|  |  |  |                                         raise_exception, |  |  |  |                                         raise_exception, | 
			
		
	
	
		
		
			
				
					|  |  | @ -211,7 +250,11 @@ class SummaryRecord: | 
			
		
	
		
		
			
				
					
					|  |  |  |             if name in {item['tag'] for item in self._data_pool[plugin]}: |  |  |  |             if name in {item['tag'] for item in self._data_pool[plugin]}: | 
			
		
	
		
		
			
				
					
					|  |  |  |                 entry = repr(f'{name}/{plugin}') |  |  |  |                 entry = repr(f'{name}/{plugin}') | 
			
		
	
		
		
			
				
					
					|  |  |  |                 logger.warning(f'{entry} has duplicate values. Only the newest one will be recorded.') |  |  |  |                 logger.warning(f'{entry} has duplicate values. Only the newest one will be recorded.') | 
			
		
	
		
		
			
				
					
					|  |  |  |             self._data_pool[plugin].append(dict(tag=name, value=np_value)) |  |  |  |             data = dict(tag=name, value=np_value) | 
			
				
				
			
		
	
		
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |             export_plugin = '{}_format'.format(plugin) | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |             if self._export_options is not None and export_plugin in self._export_options: | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |                 data['export_option'] = self._export_options.get(export_plugin) | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |             self._data_pool[plugin].append(data) | 
			
		
	
		
		
			
				
					
					|  |  |  | 
 |  |  |  | 
 | 
			
		
	
		
		
			
				
					
					|  |  |  |         elif plugin in ('train_lineage', 'eval_lineage', 'dataset_graph', 'custom_lineage_data'): |  |  |  |         elif plugin in ('train_lineage', 'eval_lineage', 'dataset_graph', 'custom_lineage_data'): | 
			
		
	
		
		
			
				
					
					|  |  |  |             _check_lineage_value(plugin, value) |  |  |  |             _check_lineage_value(plugin, value) | 
			
		
	
	
		
		
			
				
					|  |  | 
 |