style changes and nonfunctional modifications

pull/713/head
李鸿章 5 years ago
parent 2860fd9338
commit 2ac60a1ad4

@@ -14,9 +14,11 @@
# ============================================================================
"""Writes events to disk in a logdir."""
import os
import time
import stat
import time
from mindspore import log as logger
from ..._c_expression import EventWriter_
from ._summary_adapter import package_init_event
@@ -28,6 +30,7 @@ class _WrapEventWriter(EventWriter_):
Args:
full_file_name (str): Include directory and file name.
"""
def __init__(self, full_file_name):
if full_file_name is not None:
EventWriter_.__init__(self, full_file_name)
@@ -41,6 +44,7 @@ class EventRecord:
full_file_name (str): Summary event file path and file name.
flush_time (int): The flush seconds to flush the pending events to disk. Default: 120.
"""
def __init__(self, full_file_name: str, flush_time: int = 120):
self.full_file_name = full_file_name

@@ -13,17 +13,17 @@
# limitations under the License.
# ============================================================================
"""Generate the summary event which conform to proto format."""
import time
import socket
import math
from enum import Enum, unique
import time
import numpy as np
from PIL import Image
from mindspore import log as logger
from ..summary_pb2 import Event
from ..anf_ir_pb2 import ModelProto, DataType
from ..._checkparam import _check_str_by_regular
from ..anf_ir_pb2 import DataType, ModelProto
from ..summary_pb2 import Event
# define the MindSpore image format
MS_IMAGE_TENSOR_FORMAT = 'NCHW'
@@ -37,6 +37,7 @@ EVENT_FILE_INIT_VERSION = 1
# |---[{"name": tag_name, "data": numpy}, {"name": tag_name, "data": numpy},...]
g_summary_data_dict = {}
def save_summary_data(data_id, data):
"""Save the global summary cache."""
global g_summary_data_dict
@@ -49,8 +50,8 @@ def del_summary_data(data_id):
if data_id in g_summary_data_dict:
del g_summary_data_dict[data_id]
else:
logger.warning("Can't del the data because data_id(%r) "
"does not have data in g_summary_data_dict", data_id)
logger.warning("Can't del the data because data_id(%r) " "does not have data in g_summary_data_dict", data_id)
def get_summary_data(data_id):
"""Save the global summary cache."""
@@ -62,26 +63,6 @@ def get_summary_data(data_id):
logger.warning("The data_id(%r) does not have data in g_summary_data_dict", data_id)
return ret
@unique
class SummaryType(Enum):
"""
Summary type.
Args:
SCALAR (Number): Summary Scalar enum.
TENSOR (Number): Summary TENSOR enum.
IMAGE (Number): Summary image enum.
GRAPH (Number): Summary graph enum.
HISTOGRAM (Number): Summary histogram enum.
INVALID (Number): Unknown type.
"""
SCALAR = 1 # Scalar summary
TENSOR = 2 # Tensor summary
IMAGE = 3 # Image summary
GRAPH = 4 # graph
HISTOGRAM = 5 # Histogram Summary
INVALID = 0xFF # unknown type
def get_event_file_name(prefix, suffix):
"""
@@ -156,43 +137,34 @@ def package_summary_event(data_id, step):
# create the event of summary
summary_event = Event()
summary = summary_event.summary
summary_event.wall_time = time.time()
summary_event.step = int(step)
for value in data_list:
tag = value["name"]
summary_type = value["_type"]
data = value["data"]
summary_type = value["type"]
tag = value["name"]
logger.debug("Now process %r summary, tag = %r", summary_type, tag)
summary_value = summary.value.add()
summary_value.tag = tag
# get the summary type and parse the tag
if summary_type is SummaryType.SCALAR:
logger.debug("Now process Scalar summary, tag = %r", tag)
summary_value = summary.value.add()
summary_value.tag = tag
if summary_type == 'Scalar':
summary_value.scalar_value = _get_scalar_summary(tag, data)
elif summary_type is SummaryType.TENSOR:
logger.debug("Now process Tensor summary, tag = %r", tag)
summary_value = summary.value.add()
summary_value.tag = tag
elif summary_type == 'Tensor':
summary_tensor = summary_value.tensor
_get_tensor_summary(tag, data, summary_tensor)
elif summary_type is SummaryType.IMAGE:
logger.debug("Now process Image summary, tag = %r", tag)
summary_value = summary.value.add()
summary_value.tag = tag
elif summary_type == 'Image':
summary_image = summary_value.image
_get_image_summary(tag, data, summary_image, MS_IMAGE_TENSOR_FORMAT)
elif summary_type is SummaryType.HISTOGRAM:
logger.debug("Now process Histogram summary, tag = %r", tag)
summary_value = summary.value.add()
summary_value.tag = tag
elif summary_type == 'Histogram':
summary_histogram = summary_value.histogram
_fill_histogram_summary(tag, data, summary_histogram)
else:
# The data is invalid, skip it
logger.error("Summary type is error, tag = %r", tag)
continue
logger.error("Summary type(%r) is error, tag = %r", summary_type, tag)
summary_event.wall_time = time.time()
summary_event.step = int(step)
return summary_event
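
For reference, the list consumed by the loop above is what the reworked _data_convert (further down) now emits; a minimal hand-built example of that shape, assuming numpy arrays as payloads:

import numpy as np

# each entry mirrors the dicts built in _data_convert: 'name', 'data', '_type'
data_list = [
    {'name': 'loss', 'data': np.array([0.1]), '_type': 'Scalar'},
    {'name': 'input', 'data': np.zeros((1, 3, 4, 4)), '_type': 'Image'},
]
for value in data_list:
    print(value['name'], value['_type'], value['data'].shape)
# -> loss Scalar (1,)
# -> input Image (1, 3, 4, 4)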
@@ -255,11 +227,11 @@ def _get_scalar_summary(tag: str, np_value):
# So consider a tensor with dim = 1 and shape = (1,) to be a scalar
scalar_value = np_value[0]
if np_value.shape != (1,):
logger.error("The tensor is not Scalar, tag = %r, Value = %r", tag, np_value)
logger.error("The tensor is not Scalar, tag = %r, Shape = %r", tag, np_value.shape)
else:
np_list = np_value.reshape(-1).tolist()
scalar_value = np_list[0]
logger.error("The value is not Scalar, tag = %r, Value = %r", tag, np_value)
logger.error("The value is not Scalar, tag = %r, ndim = %r", tag, np_value.ndim)
logger.debug("The tag(%r) value is: %r", tag, scalar_value)
return scalar_value
@@ -307,8 +279,7 @@ def _calc_histogram_bins(count):
Returns:
int, number of histogram bins.
"""
number_per_bucket = 10
max_bins = 90
max_bins, max_per_bin = 90, 10
if not count:
return 1
@@ -318,7 +289,7 @@ def _calc_histogram_bins(count):
return 3
if count <= 880:
# note that math.ceil(881/10) + 1 equals 90
return int(math.ceil(count / number_per_bucket) + 1)
return count // max_per_bin + 1
return max_bins
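
As a quick check of the new integer-division branch (a standalone sketch of just the lines visible above): for counts that are exact multiples of max_per_bin it matches the removed ceil-based formula, while for other counts it yields one bin fewer.

import math

max_per_bin = 10
for count in (10, 95, 880):
    old = int(math.ceil(count / max_per_bin) + 1)  # removed formula
    new = count // max_per_bin + 1                 # new formula
    print(count, old, new)
# -> 10 2 2
# -> 95 11 10
# -> 880 89 89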
@@ -407,7 +378,7 @@ def _get_image_summary(tag: str, np_value, summary_image, input_format='NCHW'):
"""
logger.debug("Set(%r) the image summary value", tag)
if np_value.ndim != 4:
logger.error("The value is not Image, tag = %r, Value = %r", tag, np_value)
logger.error("The value is not Image, tag = %r, ndim = %r", tag, np_value.ndim)
# convert the tensor format
tensor = _convert_image_format(np_value, input_format)
@@ -469,8 +440,8 @@ def _convert_image_format(np_tensor, input_format, out_format='HWC'):
"""
out_tensor = None
if np_tensor.ndim != len(input_format):
logger.error("The tensor(%r) can't convert the format(%r) because dim not same",
np_tensor, input_format)
logger.error("The tensor with dim(%r) can't convert the format(%r) because dim not same", np_tensor.ndim,
input_format)
return out_tensor
input_format = input_format.upper()
@@ -512,7 +483,7 @@ def _make_canvas_for_imgs(tensor, col_imgs=8):
# check the tensor format
if tensor.ndim != 4 or tensor.shape[1] != 3:
logger.error("The image tensor(%r) is not 'NCHW' format", tensor)
logger.error("The image tensor with ndim(%r) and shape(%r) is not 'NCHW' format", tensor.ndim, tensor.shape)
return out_canvas
# expand the N

@@ -14,18 +14,13 @@
# ============================================================================
"""Schedule the event writer process."""
import multiprocessing as mp
import re
from enum import Enum, unique
from mindspore import log as logger
from ..._c_expression import Tensor
from ._summary_adapter import SummaryType, package_summary_event, save_summary_data
# define the type of summary
FORMAT_SCALAR_STR = "Scalar"
FORMAT_TENSOR_STR = "Tensor"
FORMAT_IMAGE_STR = "Image"
FORMAT_HISTOGRAM_STR = "Histogram"
FORMAT_BEGIN_SLICE = "[:"
FORMAT_END_SLICE = "]"
from ..._c_expression import Tensor
from ._summary_adapter import package_summary_event, save_summary_data
# cache the summary data dict
# {id: SummaryData}
@@ -40,73 +35,22 @@ g_summary_file = {}
@unique
class ScheduleMethod(Enum):
"""Schedule method type."""
FORMAL_WORKER = 0 # use the formal worker that receives small data via the queue
TEMP_WORKER = 1 # use the temp worker that receives large data via a global value (avoids copying)
CACHE_DATA = 2 # cache the data until an idle worker can process it
FORMAL_WORKER = 0  # use the formal worker that receives small data via the queue
TEMP_WORKER = 1  # use the temp worker that receives large data via a global value (avoids copying)
CACHE_DATA = 2  # cache the data until an idle worker can process it
@unique
class WorkerStatus(Enum):
"""Worker status."""
WORKER_INIT = 0 # data exists but has not been processed
WORKER_INIT = 0  # data exists but has not been processed
WORKER_PROCESSING = 1 # data is being processed
WORKER_PROCESSED = 2 # data has already been processed
def _parse_tag_format(tag: str):
"""
Parse the tag.
Args:
tag (str): Format: xxx[:Scalar] xxx[:Image] xxx[:Tensor].
Returns:
Tuple, (SummaryType, summary_tag).
"""
summary_type = SummaryType.INVALID
summary_tag = tag
if tag is None:
logger.error("The tag is None")
return summary_type, summary_tag
# search the slice
slice_begin = FORMAT_BEGIN_SLICE
slice_end = FORMAT_END_SLICE
index = tag.rfind(slice_begin)
if index is -1:
logger.error("The tag(%s) have not the key slice.", tag)
return summary_type, summary_tag
# slice the tag
summary_tag = tag[:index]
# check the slice end
if tag[-1:] != slice_end:
logger.error("The tag(%s) end format is error", tag)
return summary_type, summary_tag
# check the type
type_str = tag[index + 2: -1]
logger.debug("The summary_tag is = %r", summary_tag)
logger.debug("The type_str value is = %r", type_str)
if type_str == FORMAT_SCALAR_STR:
summary_type = SummaryType.SCALAR
elif type_str == FORMAT_TENSOR_STR:
summary_type = SummaryType.TENSOR
elif type_str == FORMAT_IMAGE_STR:
summary_type = SummaryType.IMAGE
elif type_str == FORMAT_HISTOGRAM_STR:
summary_type = SummaryType.HISTOGRAM
else:
logger.error("The tag(%s) type is invalid.", tag)
summary_type = SummaryType.INVALID
return summary_type, summary_tag
WORKER_PROCESSED = 2  # data has already been processed
class SummaryDataManager:
"""Manage the summary global data cache."""
def __init__(self):
global g_summary_data_dict
self.size = len(g_summary_data_dict)
@@ -144,6 +88,7 @@ class WorkerScheduler:
Args:
writer_id (int): The index of writer.
"""
def __init__(self, writer_id):
# Create the process of write event file
self.write_lock = mp.Lock()
@@ -166,8 +111,8 @@ class WorkerScheduler:
bool, run successfully or not.
"""
# save the data to the global cache and convert the tensor to numpy
result, size, data = self._data_convert(data)
if result is False:
result = self._data_convert(data)
if result is None:
logger.error("The step(%r) summary data(%r) is invalid.", step, size)
return False
@@ -201,33 +146,47 @@
self._update_scheduler()
return True
def _data_convert(self, data_list):
def _data_convert(self, summary):
"""Convert the data."""
if data_list is None:
if summary is None:
logger.warning("The step does not have record data.")
return False, 0, None
return None
# convert the summary to numpy
size = 0
for v_dict in data_list:
tag = v_dict["name"]
result = []
for v_dict in summary:
name = v_dict["name"]
data = v_dict["data"]
# confirm the data is valid
summary_type, summary_tag = _parse_tag_format(tag)
if summary_type == SummaryType.INVALID:
logger.error("The data type is invalid, tag = %r, tensor = %r", tag, data)
return False, 0, None
summary_tag, summary_type = self._parse_from(name)
if summary_tag is None:
logger.error("The data type is invalid, name = %r, tensor = %r", name, data)
return None
if isinstance(data, Tensor):
# get the summary type and parse the tag
v_dict["name"] = summary_tag
v_dict["type"] = summary_type
v_dict["data"] = data.asnumpy()
size += v_dict["data"].size
result.append({'name': summary_tag, 'data': data.asnumpy(), '_type': summary_type})
else:
logger.error("The data type is invalid, tag = %r, tensor = %r", tag, data)
return False, 0, None
logger.error("The data type is invalid, name = %r, tensor = %r", name, data)
return None
return result
def _parse_from(self, name: str = None):
"""
Parse the tag and type from name.
return True, size, data_list
Args:
name (str): Format: TAG[:TYPE].
Returns:
Tuple, (summary_tag, summary_type).
"""
if name is None:
logger.error("The name is None")
return None, None
match = re.match(r'(.+)\[:(.+)\]', name)
if match:
return match.groups()
return None, None
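
The regular expression above replaces the removed rfind/slice parsing; a minimal standalone sketch of the same pattern, assuming names of the form TAG[:TYPE]:

import re

def parse_from(name):
    # 'loss[:Scalar]' -> ('loss', 'Scalar'); anything else -> (None, None)
    match = re.match(r'(.+)\[:(.+)\]', name)
    if match:
        return match.groups()
    return None, None

print(parse_from('loss[:Scalar]'))  # ('loss', 'Scalar')
print(parse_from('conv1[:Image]'))  # ('conv1', 'Image')
print(parse_from('plain_name'))     # (None, None)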
def _update_scheduler(self):
"""Check the worker status and update schedule table."""
@@ -261,6 +220,7 @@ class SummaryDataProcess(mp.Process):
write_lock (Lock): The process lock for writer same file.
writer_id (int): The index of writer.
"""
def __init__(self, step, data_id, write_lock, writer_id):
super(SummaryDataProcess, self).__init__()
self.daemon = True

@@ -15,16 +15,20 @@
"""Record the summary event."""
import os
import threading
from mindspore import log as logger
from ._summary_scheduler import WorkerScheduler, SummaryDataManager
from ._summary_adapter import get_event_file_name, package_graph_event
from ._event_writer import EventRecord
from .._utils import _make_directory
from ..._checkparam import _check_str_by_regular
from .._utils import _make_directory
from ._event_writer import EventRecord
from ._summary_adapter import get_event_file_name, package_graph_event
from ._summary_scheduler import SummaryDataManager, WorkerScheduler
# for the moment, this lock is for caution's sake,
# there is actually no concurrency happening.
_summary_lock = threading.Lock()
# cache the summary data
_summary_tensor_cache = {}
_summary_lock = threading.Lock()
def _cache_summary_tensor_data(summary):
@@ -34,14 +38,12 @@ def _cache_summary_tensor_data(summary):
Args:
summary (list): [{"name": tag_name, "data": tensor}, {"name": tag_name, "data": tensor},...].
"""
_summary_lock.acquire()
if "SummaryRecord" in _summary_tensor_cache:
for record in summary:
_summary_tensor_cache["SummaryRecord"].append(record)
else:
_summary_tensor_cache["SummaryRecord"] = summary
_summary_lock.release()
return True
with _summary_lock:
if "SummaryRecord" in _summary_tensor_cache:
_summary_tensor_cache["SummaryRecord"].extend(summary)
else:
_summary_tensor_cache["SummaryRecord"] = summary
return True
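
The with-statement is equivalent to the acquire/release pair it replaces, with the added guarantee that the lock is released even if the body raises; a minimal standalone sketch of the pattern (names are illustrative):

import threading

_lock = threading.Lock()
_cache = {}

def cache_records(records):
    with _lock:  # released automatically, even on exception
        if 'SummaryRecord' in _cache:
            _cache['SummaryRecord'].extend(records)
        else:
            _cache['SummaryRecord'] = records
    return True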
class SummaryRecord:
@@ -71,6 +73,7 @@ class SummaryRecord:
>>> summary_record = SummaryRecord(log_dir="/opt/log", queue_max_size=50, flush_time=6,
>>> file_prefix="xxx_", file_suffix="_yyy")
"""
def __init__(self,
log_dir,
queue_max_size=0,
@@ -101,27 +104,21 @@ class SummaryRecord:
self.prefix = file_prefix
self.suffix = file_suffix
self.network = network
self.has_graph = False
self._closed = False
self.step = 0
# create the summary writer file
self.event_file_name = get_event_file_name(self.prefix, self.suffix)
if self.log_path[-1:] == '/':
self.full_file_name = self.log_path + self.event_file_name
else:
self.full_file_name = self.log_path + '/' + self.event_file_name
try:
self.full_file_name = os.path.realpath(self.full_file_name)
self.full_file_name = os.path.join(self.log_path, self.event_file_name)
except Exception as ex:
raise RuntimeError(ex)
self.event_writer = EventRecord(self.full_file_name, self.flush_time)
self.writer_id = SummaryDataManager.summary_file_set(self.event_writer)
self.worker_scheduler = WorkerScheduler(self.writer_id)
self.step = 0
self._closed = False
self.network = network
self.has_graph = False
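
os.path.join covers the trailing-separator case that the removed branch checked by hand; a quick illustration with a hypothetical file name:

import os

print(os.path.join('/opt/log', 'events.summary'))   # /opt/log/events.summary
print(os.path.join('/opt/log/', 'events.summary'))  # /opt/log/events.summary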
def record(self, step, train_network=None):
"""
Record the summary.
@@ -147,15 +144,14 @@ class SummaryRecord:
# Set the current summary of train step
self.step = step
if self.network is not None and self.has_graph is False:
if self.network is not None and not self.has_graph:
graph_proto = self.network.get_func_graph_proto()
if graph_proto is None and train_network is not None:
graph_proto = train_network.get_func_graph_proto()
if graph_proto is None:
logger.error("Failed to get proto for graph")
else:
self.event_writer.write_event_to_file(
package_graph_event(graph_proto).SerializeToString())
self.event_writer.write_event_to_file(package_graph_event(graph_proto).SerializeToString())
self.event_writer.flush()
self.has_graph = True
data = _summary_tensor_cache.get("SummaryRecord")
