From 69d3abfdd3df8dc5c5172cd4b2d32f04b27bac50 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=8E=E9=B8=BF=E7=AB=A0?= Date: Thu, 30 Apr 2020 17:31:26 +0800 Subject: [PATCH] reduce dead step(step % flush_step > 0) summary --- mindspore/train/summary/summary_record.py | 25 ++++++++--------------- 1 file changed, 9 insertions(+), 16 deletions(-) diff --git a/mindspore/train/summary/summary_record.py b/mindspore/train/summary/summary_record.py index 356a6dfc21..43baebccf9 100644 --- a/mindspore/train/summary/summary_record.py +++ b/mindspore/train/summary/summary_record.py @@ -40,18 +40,17 @@ def _cache_summary_tensor_data(summary): summary (list): [{"name": tag_name, "data": tensor}, {"name": tag_name, "data": tensor},...]. """ with _summary_lock: - if "SummaryRecord" in _summary_tensor_cache: - _summary_tensor_cache["SummaryRecord"].extend(summary) - else: - _summary_tensor_cache["SummaryRecord"] = summary + for item in summary: + _summary_tensor_cache[item['name']] = item['data'] return True def _get_summary_tensor_data(): - if 'SummaryRecord' not in _summary_tensor_cache: - return None + global _summary_tensor_cache with _summary_lock: - return _summary_tensor_cache.pop('SummaryRecord') + data = _summary_tensor_cache + _summary_tensor_cache = {} + return data class SummaryRecord: @@ -158,11 +157,11 @@ class SummaryRecord: else: self.event_writer.write(package_graph_event(graph_proto).SerializeToString()) self.has_graph = True - if _summary_tensor_cache.get('SummaryRecord') is None: + if not _summary_tensor_cache: return True data = _get_summary_tensor_data() - if data is None: + if not data: logger.error("The step(%r) does not have record data.", step) return False if self.queue_max_size > 0 and len(data) > self.queue_max_size: @@ -225,15 +224,9 @@ class SummaryRecord: def _data_convert(self, summary): """Convert the data.""" - if summary is None: - logger.warning("The step does not have record data.") - return None - # convert the summary to numpy result = [] - for v_dict in summary: - name = v_dict["name"] - data = v_dict["data"] + for name, data in summary.items(): # confirm the data is valid summary_tag, summary_type = SummaryRecord._parse_from(name) if summary_tag is None: