|
|
|
|
@ -14,22 +14,17 @@
|
|
|
|
|
# ============================================================================
|
|
|
|
|
"""Record the summary event."""
|
|
|
|
|
import os
|
|
|
|
|
import re
|
|
|
|
|
import threading
|
|
|
|
|
|
|
|
|
|
from mindspore import log as logger
|
|
|
|
|
|
|
|
|
|
from ..._c_expression import Tensor
|
|
|
|
|
from ..._checkparam import _check_str_by_regular
|
|
|
|
|
from ._summary_scheduler import WorkerScheduler, SummaryDataManager
|
|
|
|
|
from ._summary_adapter import get_event_file_name, package_graph_event
|
|
|
|
|
from ._event_writer import EventRecord
|
|
|
|
|
from .._utils import _make_directory
|
|
|
|
|
from ._event_writer import EventWriter
|
|
|
|
|
from ._summary_adapter import get_event_file_name, package_graph_event, package_init_event
|
|
|
|
|
from ..._checkparam import _check_str_by_regular
|
|
|
|
|
|
|
|
|
|
# for the moment, this lock is for caution's sake,
|
|
|
|
|
# there are actually no any concurrencies happening.
|
|
|
|
|
_summary_lock = threading.Lock()
|
|
|
|
|
# cache the summary data
|
|
|
|
|
_summary_tensor_cache = {}
|
|
|
|
|
_summary_lock = threading.Lock()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _cache_summary_tensor_data(summary):
|
|
|
|
|
@ -39,18 +34,14 @@ def _cache_summary_tensor_data(summary):
|
|
|
|
|
Args:
|
|
|
|
|
summary (list): [{"name": tag_name, "data": tensor}, {"name": tag_name, "data": tensor},...].
|
|
|
|
|
"""
|
|
|
|
|
with _summary_lock:
|
|
|
|
|
for item in summary:
|
|
|
|
|
_summary_tensor_cache[item['name']] = item['data']
|
|
|
|
|
return True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _get_summary_tensor_data():
|
|
|
|
|
global _summary_tensor_cache
|
|
|
|
|
with _summary_lock:
|
|
|
|
|
data = _summary_tensor_cache
|
|
|
|
|
_summary_tensor_cache = {}
|
|
|
|
|
return data
|
|
|
|
|
_summary_lock.acquire()
|
|
|
|
|
if "SummaryRecord" in _summary_tensor_cache:
|
|
|
|
|
for record in summary:
|
|
|
|
|
_summary_tensor_cache["SummaryRecord"].append(record)
|
|
|
|
|
else:
|
|
|
|
|
_summary_tensor_cache["SummaryRecord"] = summary
|
|
|
|
|
_summary_lock.release()
|
|
|
|
|
return True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SummaryRecord:
|
|
|
|
|
@ -80,7 +71,6 @@ class SummaryRecord:
|
|
|
|
|
>>> summary_record = SummaryRecord(log_dir="/opt/log", queue_max_size=50, flush_time=6,
|
|
|
|
|
>>> file_prefix="xxx_", file_suffix="_yyy")
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def __init__(self,
|
|
|
|
|
log_dir,
|
|
|
|
|
queue_max_size=0,
|
|
|
|
|
@ -111,18 +101,26 @@ class SummaryRecord:
|
|
|
|
|
|
|
|
|
|
self.prefix = file_prefix
|
|
|
|
|
self.suffix = file_suffix
|
|
|
|
|
self.network = network
|
|
|
|
|
self.has_graph = False
|
|
|
|
|
self._closed = False
|
|
|
|
|
|
|
|
|
|
# create the summary writer file
|
|
|
|
|
self.event_file_name = get_event_file_name(self.prefix, self.suffix)
|
|
|
|
|
if self.log_path[-1:] == '/':
|
|
|
|
|
self.full_file_name = self.log_path + self.event_file_name
|
|
|
|
|
else:
|
|
|
|
|
self.full_file_name = self.log_path + '/' + self.event_file_name
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
self.full_file_name = os.path.join(self.log_path, self.event_file_name)
|
|
|
|
|
self.full_file_name = os.path.realpath(self.full_file_name)
|
|
|
|
|
except Exception as ex:
|
|
|
|
|
raise RuntimeError(ex)
|
|
|
|
|
self.event_writer = EventWriter(self.full_file_name, self.flush_time)
|
|
|
|
|
self.event_writer.write(package_init_event().SerializeToString())
|
|
|
|
|
self.event_writer = EventRecord(self.full_file_name, self.flush_time)
|
|
|
|
|
self.writer_id = SummaryDataManager.summary_file_set(self.event_writer)
|
|
|
|
|
self.worker_scheduler = WorkerScheduler(self.writer_id)
|
|
|
|
|
|
|
|
|
|
self.step = 0
|
|
|
|
|
self._closed = False
|
|
|
|
|
self.network = network
|
|
|
|
|
self.has_graph = False
|
|
|
|
|
|
|
|
|
|
def record(self, step, train_network=None):
|
|
|
|
|
"""
|
|
|
|
|
@ -147,34 +145,42 @@ class SummaryRecord:
|
|
|
|
|
if not isinstance(step, int) or isinstance(step, bool):
|
|
|
|
|
raise ValueError("`step` should be int")
|
|
|
|
|
# Set the current summary of train step
|
|
|
|
|
self.step = step
|
|
|
|
|
|
|
|
|
|
if self.network is not None and not self.has_graph:
|
|
|
|
|
if self.network is not None and self.has_graph is False:
|
|
|
|
|
graph_proto = self.network.get_func_graph_proto()
|
|
|
|
|
if graph_proto is None and train_network is not None:
|
|
|
|
|
graph_proto = train_network.get_func_graph_proto()
|
|
|
|
|
if graph_proto is None:
|
|
|
|
|
logger.error("Failed to get proto for graph")
|
|
|
|
|
else:
|
|
|
|
|
self.event_writer.write(package_graph_event(graph_proto).SerializeToString())
|
|
|
|
|
self.event_writer.write_event_to_file(
|
|
|
|
|
package_graph_event(graph_proto).SerializeToString())
|
|
|
|
|
self.event_writer.flush()
|
|
|
|
|
self.has_graph = True
|
|
|
|
|
if not _summary_tensor_cache:
|
|
|
|
|
data = _summary_tensor_cache.get("SummaryRecord")
|
|
|
|
|
if data is None:
|
|
|
|
|
return True
|
|
|
|
|
|
|
|
|
|
data = _get_summary_tensor_data()
|
|
|
|
|
if not data:
|
|
|
|
|
logger.error("The step(%r) does not have record data.", step)
|
|
|
|
|
data = _summary_tensor_cache.get("SummaryRecord")
|
|
|
|
|
if data is None:
|
|
|
|
|
logger.error("The step(%r) does not have record data.", self.step)
|
|
|
|
|
return False
|
|
|
|
|
if self.queue_max_size > 0 and len(data) > self.queue_max_size:
|
|
|
|
|
logger.error("The size of data record is %r, which is greater than queue_max_size %r.", len(data),
|
|
|
|
|
self.queue_max_size)
|
|
|
|
|
|
|
|
|
|
# clean the data of cache
|
|
|
|
|
del _summary_tensor_cache["SummaryRecord"]
|
|
|
|
|
|
|
|
|
|
# process the data
|
|
|
|
|
result = self._data_convert(data)
|
|
|
|
|
if not result:
|
|
|
|
|
logger.error("The step(%r) summary data is invalid.", step)
|
|
|
|
|
return False
|
|
|
|
|
self.event_writer.write((result, step))
|
|
|
|
|
logger.debug("Send the summary data to scheduler for saving, step = %d", step)
|
|
|
|
|
self.worker_scheduler.dispatch(self.step, data)
|
|
|
|
|
|
|
|
|
|
# count & flush
|
|
|
|
|
self.event_writer.count_event()
|
|
|
|
|
self.event_writer.flush_cycle()
|
|
|
|
|
|
|
|
|
|
logger.debug("Send the summary data to scheduler for saving, step = %d", self.step)
|
|
|
|
|
return True
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
@ -190,7 +196,7 @@ class SummaryRecord:
|
|
|
|
|
Returns:
|
|
|
|
|
String, the full path of log file.
|
|
|
|
|
"""
|
|
|
|
|
return self.full_file_name
|
|
|
|
|
return self.event_writer.full_file_name
|
|
|
|
|
|
|
|
|
|
def flush(self):
|
|
|
|
|
"""
|
|
|
|
|
@ -218,44 +224,20 @@ class SummaryRecord:
|
|
|
|
|
>>> summary_record.close()
|
|
|
|
|
"""
|
|
|
|
|
if not self._closed:
|
|
|
|
|
self._check_data_before_close()
|
|
|
|
|
self.worker_scheduler.close()
|
|
|
|
|
# event writer flush and close
|
|
|
|
|
self.event_writer.close()
|
|
|
|
|
self._closed = True
|
|
|
|
|
|
|
|
|
|
def _data_convert(self, summary):
|
|
|
|
|
"""Convert the data."""
|
|
|
|
|
# convert the summary to numpy
|
|
|
|
|
result = []
|
|
|
|
|
for name, data in summary.items():
|
|
|
|
|
# confirm the data is valid
|
|
|
|
|
summary_tag, summary_type = SummaryRecord._parse_from(name)
|
|
|
|
|
if summary_tag is None:
|
|
|
|
|
logger.error("The data type is invalid, name = %r, tensor = %r", name, data)
|
|
|
|
|
return None
|
|
|
|
|
if isinstance(data, Tensor):
|
|
|
|
|
result.append({'name': summary_tag, 'data': data.asnumpy(), '_type': summary_type})
|
|
|
|
|
else:
|
|
|
|
|
logger.error("The data type is invalid, name = %r, tensor = %r", name, data)
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
return result
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def _parse_from(name: str = None):
|
|
|
|
|
"""
|
|
|
|
|
Parse the tag and type from name.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
name (str): Format: TAG[:TYPE].
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
Tuple, (summary_tag, summary_type).
|
|
|
|
|
"""
|
|
|
|
|
if name is None:
|
|
|
|
|
logger.error("The name is None")
|
|
|
|
|
return None, None
|
|
|
|
|
match = re.match(r'(.+)\[:(.+)\]', name)
|
|
|
|
|
if match:
|
|
|
|
|
return match.groups()
|
|
|
|
|
logger.error("The name(%r) format is invalid, expected 'TAG[:TYPE]'.", name)
|
|
|
|
|
return None, None
|
|
|
|
|
def __del__(self):
|
|
|
|
|
"""Process exit is called."""
|
|
|
|
|
if hasattr(self, "worker_scheduler"):
|
|
|
|
|
if self.worker_scheduler:
|
|
|
|
|
self.close()
|
|
|
|
|
|
|
|
|
|
def _check_data_before_close(self):
|
|
|
|
|
"Check whether there is any data in the cache, and if so, call record"
|
|
|
|
|
data = _summary_tensor_cache.get("SummaryRecord")
|
|
|
|
|
if data is not None:
|
|
|
|
|
self.record(self.step)
|
|
|
|
|
|