!1069 Use a resident process to write summary files and SummaryRecord as context manager

Merge pull request !1069 from 李鸿章/context_manager
pull/1069/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit 6b68671805

@ -14,91 +14,74 @@
# ============================================================================ # ============================================================================
"""Writes events to disk in a logdir.""" """Writes events to disk in a logdir."""
import os import os
import time
import stat import stat
from mindspore import log as logger from collections import deque
from multiprocessing import Pool, Process, Queue, cpu_count
from ..._c_expression import EventWriter_ from ..._c_expression import EventWriter_
from ._summary_adapter import package_init_event from ._summary_adapter import package_summary_event
class _WrapEventWriter(EventWriter_): def _pack(result, step):
""" summary_event = package_summary_event(result, step)
Wrap the c++ EventWriter object. return summary_event.SerializeToString()
Args:
full_file_name (str): Include directory and file name.
"""
def __init__(self, full_file_name):
if full_file_name is not None:
EventWriter_.__init__(self, full_file_name)
class EventWriter(Process):
class EventRecord:
""" """
Creates a `EventFileWriter` and write event to file. Creates a `EventWriter` and write event to file.
Args: Args:
full_file_name (str): Summary event file path and file name. filepath (str): Summary event file path and file name.
flush_time (int): The flush seconds to flush the pending events to disk. Default: 120. flush_interval (int): The flush seconds to flush the pending events to disk. Default: 120.
""" """
def __init__(self, full_file_name: str, flush_time: int = 120):
self.full_file_name = full_file_name
# The first event will be flushed immediately.
self.flush_time = flush_time
self.next_flush_time = 0
# create event write object
self.event_writer = self._create_event_file()
self._init_event_file()
# count the events
self.event_count = 0
def _create_event_file(self):
"""Create the event write file."""
with open(self.full_file_name, 'w'):
os.chmod(self.full_file_name, stat.S_IWUSR | stat.S_IRUSR)
# create c++ event write object
event_writer = _WrapEventWriter(self.full_file_name)
return event_writer
def _init_event_file(self):
"""Send the init event to file."""
self.event_writer.Write((package_init_event()).SerializeToString())
self.flush()
return True
def write_event_to_file(self, event_str):
"""Write the event to file."""
self.event_writer.Write(event_str)
def get_data_count(self):
"""Return the event count."""
return self.event_count
def flush_cycle(self):
"""Flush file by timer."""
self.event_count = self.event_count + 1
# Flush the event writer every so often.
now = int(time.time())
if now > self.next_flush_time:
self.flush()
# update the flush time
self.next_flush_time = now + self.flush_time
def count_event(self): def __init__(self, filepath: str, flush_interval: int) -> None:
"""Count event.""" super().__init__()
logger.debug("Write the event count is %r", self.event_count) with open(filepath, 'w'):
self.event_count = self.event_count + 1 os.chmod(filepath, stat.S_IWUSR | stat.S_IRUSR)
return self.event_count self._writer = EventWriter_(filepath)
self._queue = Queue(cpu_count() * 2)
self.start()
def run(self):
with Pool() as pool:
deq = deque()
while True:
while deq and deq[0].ready():
self._writer.Write(deq.popleft().get())
if not self._queue.empty():
action, data = self._queue.get()
if action == 'WRITE':
if not isinstance(data, (str, bytes)):
deq.append(pool.apply_async(_pack, data))
else:
self._writer.Write(data)
elif action == 'FLUSH':
self._writer.Flush()
elif action == 'END':
break
for res in deq:
self._writer.Write(res.get())
self._writer.Shut()
def write(self, data) -> None:
"""
Write the event to file.
Args:
data (Optional[str, Tuple[list, int]]): The data to write.
"""
self._queue.put(('WRITE', data))
def flush(self): def flush(self):
"""Flush the event file to disk.""" """Flush the writer."""
self.event_writer.Flush() self._queue.put(('FLUSH', None))
def close(self): def close(self) -> None:
"""Flush the event file to disk and close the file.""" """Close the writer."""
self.flush() self._queue.put(('END', None))
self.event_writer.Shut() self.join()

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

@ -53,14 +53,13 @@ def me_train_tensor(net, input_np, label_np, epoch_size=2):
_network = wrap.WithLossCell(net, loss) _network = wrap.WithLossCell(net, loss)
_train_net = MsWrapper(wrap.TrainOneStepCell(_network, opt)) _train_net = MsWrapper(wrap.TrainOneStepCell(_network, opt))
_train_net.set_train() _train_net.set_train()
summary_writer = SummaryRecord(SUMMARY_DIR, file_suffix="_MS_GRAPH", network=_train_net) with SummaryRecord(SUMMARY_DIR, file_suffix="_MS_GRAPH", network=_train_net) as summary_writer:
for epoch in range(0, epoch_size): for epoch in range(0, epoch_size):
print(f"epoch %d" % (epoch)) print(f"epoch %d" % (epoch))
output = _train_net(Tensor(input_np), Tensor(label_np)) output = _train_net(Tensor(input_np), Tensor(label_np))
summary_writer.record(i) summary_writer.record(i)
print("********output***********") print("********output***********")
print(output.asnumpy()) print(output.asnumpy())
summary_writer.close()
def me_infer_tensor(net, input_np): def me_infer_tensor(net, input_np):

@ -91,15 +91,14 @@ def train_summary_record_scalar_for_1(test_writer, steps, fwd_x, fwd_y):
def me_scalar_summary(steps, tag=None, value=None): def me_scalar_summary(steps, tag=None, value=None):
test_writer = SummaryRecord(SUMMARY_DIR_ME_TEMP) with SummaryRecord(SUMMARY_DIR_ME_TEMP) as test_writer:
x = Tensor(np.array([1.1]).astype(np.float32)) x = Tensor(np.array([1.1]).astype(np.float32))
y = Tensor(np.array([1.2]).astype(np.float32)) y = Tensor(np.array([1.2]).astype(np.float32))
out_me_dict = train_summary_record_scalar_for_1(test_writer, steps, x, y) out_me_dict = train_summary_record_scalar_for_1(test_writer, steps, x, y)
test_writer.close() return out_me_dict
return out_me_dict
@pytest.mark.level0 @pytest.mark.level0

@ -106,18 +106,17 @@ def test_graph_summary_sample():
optim = Momentum(net.trainable_params(), 0.1, 0.9) optim = Momentum(net.trainable_params(), 0.1, 0.9)
context.set_context(mode=context.GRAPH_MODE) context.set_context(mode=context.GRAPH_MODE)
model = Model(net, loss_fn=loss, optimizer=optim, metrics=None) model = Model(net, loss_fn=loss, optimizer=optim, metrics=None)
test_writer = SummaryRecord(SUMMARY_DIR, file_suffix="_MS_GRAPH", network=model._train_network) with SummaryRecord(SUMMARY_DIR, file_suffix="_MS_GRAPH", network=model._train_network) as test_writer:
model.train(2, dataset) model.train(2, dataset)
# step 2: create the Event # step 2: create the Event
for i in range(1, 5): for i in range(1, 5):
test_writer.record(i) test_writer.record(i)
# step 3: send the event to mq # step 3: send the event to mq
# step 4: accept the event and write the file # step 4: accept the event and write the file
test_writer.close()
log.debug("finished test_graph_summary_sample") log.debug("finished test_graph_summary_sample")
def test_graph_summary_callback(): def test_graph_summary_callback():
@ -127,9 +126,9 @@ def test_graph_summary_callback():
optim = Momentum(net.trainable_params(), 0.1, 0.9) optim = Momentum(net.trainable_params(), 0.1, 0.9)
context.set_context(mode=context.GRAPH_MODE) context.set_context(mode=context.GRAPH_MODE)
model = Model(net, loss_fn=loss, optimizer=optim, metrics=None) model = Model(net, loss_fn=loss, optimizer=optim, metrics=None)
test_writer = SummaryRecord(SUMMARY_DIR, file_suffix="_MS_GRAPH", network=model._train_network) with SummaryRecord(SUMMARY_DIR, file_suffix="_MS_GRAPH", network=model._train_network) as test_writer:
summary_cb = SummaryStep(test_writer, 1) summary_cb = SummaryStep(test_writer, 1)
model.train(2, dataset, callbacks=summary_cb) model.train(2, dataset, callbacks=summary_cb)
def test_graph_summary_callback2(): def test_graph_summary_callback2():
@ -139,6 +138,6 @@ def test_graph_summary_callback2():
optim = Momentum(net.trainable_params(), 0.1, 0.9) optim = Momentum(net.trainable_params(), 0.1, 0.9)
context.set_context(mode=context.GRAPH_MODE) context.set_context(mode=context.GRAPH_MODE)
model = Model(net, loss_fn=loss, optimizer=optim, metrics=None) model = Model(net, loss_fn=loss, optimizer=optim, metrics=None)
test_writer = SummaryRecord(SUMMARY_DIR, file_suffix="_MS_GRAPH", network=net) with SummaryRecord(SUMMARY_DIR, file_suffix="_MS_GRAPH", network=net) as test_writer:
summary_cb = SummaryStep(test_writer, 1) summary_cb = SummaryStep(test_writer, 1)
model.train(2, dataset, callbacks=summary_cb) model.train(2, dataset, callbacks=summary_cb)

@ -52,12 +52,11 @@ def _wrap_test_data(input_data: Tensor):
def test_histogram_summary(): def test_histogram_summary():
"""Test histogram summary.""" """Test histogram summary."""
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
test_writer = SummaryRecord(tmp_dir, file_suffix="_MS_HISTOGRAM") with SummaryRecord(tmp_dir, file_suffix="_MS_HISTOGRAM") as test_writer:
test_data = _wrap_test_data(Tensor([[1, 2, 3], [4, 5, 6]])) test_data = _wrap_test_data(Tensor([[1, 2, 3], [4, 5, 6]]))
_cache_summary_tensor_data(test_data) _cache_summary_tensor_data(test_data)
test_writer.record(step=1) test_writer.record(step=1)
test_writer.close()
file_name = os.path.join(tmp_dir, test_writer.event_file_name) file_name = os.path.join(tmp_dir, test_writer.event_file_name)
reader = SummaryReader(file_name) reader = SummaryReader(file_name)
@ -68,20 +67,18 @@ def test_histogram_summary():
def test_histogram_multi_summary(): def test_histogram_multi_summary():
"""Test histogram multiple step.""" """Test histogram multiple step."""
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
test_writer = SummaryRecord(tmp_dir, file_suffix="_MS_HISTOGRAM") with SummaryRecord(tmp_dir, file_suffix="_MS_HISTOGRAM") as test_writer:
rng = np.random.RandomState(10)
size = 50
num_step = 5
for i in range(num_step): rng = np.random.RandomState(10)
arr = rng.normal(size=size) size = 50
num_step = 5
test_data = _wrap_test_data(Tensor(arr)) for i in range(num_step):
_cache_summary_tensor_data(test_data) arr = rng.normal(size=size)
test_writer.record(step=i)
test_writer.close() test_data = _wrap_test_data(Tensor(arr))
_cache_summary_tensor_data(test_data)
test_writer.record(step=i)
file_name = os.path.join(tmp_dir, test_writer.event_file_name) file_name = os.path.join(tmp_dir, test_writer.event_file_name)
reader = SummaryReader(file_name) reader = SummaryReader(file_name)
@ -93,12 +90,11 @@ def test_histogram_multi_summary():
def test_histogram_summary_scalar_tensor(): def test_histogram_summary_scalar_tensor():
"""Test histogram summary, input is a scalar tensor.""" """Test histogram summary, input is a scalar tensor."""
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
test_writer = SummaryRecord(tmp_dir, file_suffix="_MS_HISTOGRAM") with SummaryRecord(tmp_dir, file_suffix="_MS_HISTOGRAM") as test_writer:
test_data = _wrap_test_data(Tensor(1)) test_data = _wrap_test_data(Tensor(1))
_cache_summary_tensor_data(test_data) _cache_summary_tensor_data(test_data)
test_writer.record(step=1) test_writer.record(step=1)
test_writer.close()
file_name = os.path.join(tmp_dir, test_writer.event_file_name) file_name = os.path.join(tmp_dir, test_writer.event_file_name)
reader = SummaryReader(file_name) reader = SummaryReader(file_name)
@ -109,12 +105,11 @@ def test_histogram_summary_scalar_tensor():
def test_histogram_summary_empty_tensor(): def test_histogram_summary_empty_tensor():
"""Test histogram summary, input is an empty tensor.""" """Test histogram summary, input is an empty tensor."""
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
test_writer = SummaryRecord(tmp_dir, file_suffix="_MS_HISTOGRAM") with SummaryRecord(tmp_dir, file_suffix="_MS_HISTOGRAM") as test_writer:
test_data = _wrap_test_data(Tensor([])) test_data = _wrap_test_data(Tensor([]))
_cache_summary_tensor_data(test_data) _cache_summary_tensor_data(test_data)
test_writer.record(step=1) test_writer.record(step=1)
test_writer.close()
file_name = os.path.join(tmp_dir, test_writer.event_file_name) file_name = os.path.join(tmp_dir, test_writer.event_file_name)
reader = SummaryReader(file_name) reader = SummaryReader(file_name)
@ -125,15 +120,14 @@ def test_histogram_summary_empty_tensor():
def test_histogram_summary_same_value(): def test_histogram_summary_same_value():
"""Test histogram summary, input is an ones tensor.""" """Test histogram summary, input is an ones tensor."""
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
test_writer = SummaryRecord(tmp_dir, file_suffix="_MS_HISTOGRAM") with SummaryRecord(tmp_dir, file_suffix="_MS_HISTOGRAM") as test_writer:
dim1 = 100 dim1 = 100
dim2 = 100 dim2 = 100
test_data = _wrap_test_data(Tensor(np.ones([dim1, dim2]))) test_data = _wrap_test_data(Tensor(np.ones([dim1, dim2])))
_cache_summary_tensor_data(test_data) _cache_summary_tensor_data(test_data)
test_writer.record(step=1) test_writer.record(step=1)
test_writer.close()
file_name = os.path.join(tmp_dir, test_writer.event_file_name) file_name = os.path.join(tmp_dir, test_writer.event_file_name)
reader = SummaryReader(file_name) reader = SummaryReader(file_name)
@ -146,15 +140,14 @@ def test_histogram_summary_same_value():
def test_histogram_summary_high_dims(): def test_histogram_summary_high_dims():
"""Test histogram summary, input is a 4-dimension tensor.""" """Test histogram summary, input is a 4-dimension tensor."""
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
test_writer = SummaryRecord(tmp_dir, file_suffix="_MS_HISTOGRAM") with SummaryRecord(tmp_dir, file_suffix="_MS_HISTOGRAM") as test_writer:
dim = 10 dim = 10
rng = np.random.RandomState(0) rng = np.random.RandomState(0)
tensor_data = rng.normal(size=[dim, dim, dim, dim]) tensor_data = rng.normal(size=[dim, dim, dim, dim])
test_data = _wrap_test_data(Tensor(tensor_data)) test_data = _wrap_test_data(Tensor(tensor_data))
_cache_summary_tensor_data(test_data) _cache_summary_tensor_data(test_data)
test_writer.record(step=1) test_writer.record(step=1)
test_writer.close()
file_name = os.path.join(tmp_dir, test_writer.event_file_name) file_name = os.path.join(tmp_dir, test_writer.event_file_name)
reader = SummaryReader(file_name) reader = SummaryReader(file_name)
@ -167,20 +160,19 @@ def test_histogram_summary_high_dims():
def test_histogram_summary_nan_inf(): def test_histogram_summary_nan_inf():
"""Test histogram summary, input tensor has nan.""" """Test histogram summary, input tensor has nan."""
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
test_writer = SummaryRecord(tmp_dir, file_suffix="_MS_HISTOGRAM") with SummaryRecord(tmp_dir, file_suffix="_MS_HISTOGRAM") as test_writer:
dim1 = 100 dim1 = 100
dim2 = 100 dim2 = 100
arr = np.ones([dim1, dim2]) arr = np.ones([dim1, dim2])
arr[0][0] = np.nan arr[0][0] = np.nan
arr[0][1] = np.inf arr[0][1] = np.inf
arr[0][2] = -np.inf arr[0][2] = -np.inf
test_data = _wrap_test_data(Tensor(arr)) test_data = _wrap_test_data(Tensor(arr))
_cache_summary_tensor_data(test_data) _cache_summary_tensor_data(test_data)
test_writer.record(step=1) test_writer.record(step=1)
test_writer.close()
file_name = os.path.join(tmp_dir, test_writer.event_file_name) file_name = os.path.join(tmp_dir, test_writer.event_file_name)
reader = SummaryReader(file_name) reader = SummaryReader(file_name)
@ -193,12 +185,11 @@ def test_histogram_summary_nan_inf():
def test_histogram_summary_all_nan_inf(): def test_histogram_summary_all_nan_inf():
"""Test histogram summary, input tensor has no valid number.""" """Test histogram summary, input tensor has no valid number."""
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
test_writer = SummaryRecord(tmp_dir, file_suffix="_MS_HISTOGRAM") with SummaryRecord(tmp_dir, file_suffix="_MS_HISTOGRAM") as test_writer:
test_data = _wrap_test_data(Tensor(np.array([np.nan, np.nan, np.nan, np.inf, -np.inf]))) test_data = _wrap_test_data(Tensor(np.array([np.nan, np.nan, np.nan, np.inf, -np.inf])))
_cache_summary_tensor_data(test_data) _cache_summary_tensor_data(test_data)
test_writer.record(step=1) test_writer.record(step=1)
test_writer.close()
file_name = os.path.join(tmp_dir, test_writer.event_file_name) file_name = os.path.join(tmp_dir, test_writer.event_file_name)
reader = SummaryReader(file_name) reader = SummaryReader(file_name)

@ -74,23 +74,21 @@ def test_image_summary_sample():
""" test_image_summary_sample """ """ test_image_summary_sample """
log.debug("begin test_image_summary_sample") log.debug("begin test_image_summary_sample")
# step 0: create the thread # step 0: create the thread
test_writer = SummaryRecord(SUMMARY_DIR, file_suffix="_MS_IMAGE") with SummaryRecord(SUMMARY_DIR, file_suffix="_MS_IMAGE") as test_writer:
# step 1: create the test data for summary # step 1: create the test data for summary
# step 2: create the Event # step 2: create the Event
for i in range(1, 5): for i in range(1, 5):
test_data = get_test_data(i) test_data = get_test_data(i)
_cache_summary_tensor_data(test_data) _cache_summary_tensor_data(test_data)
test_writer.record(i) test_writer.record(i)
test_writer.flush() test_writer.flush()
# step 3: send the event to mq # step 3: send the event to mq
# step 4: accept the event and write the file # step 4: accept the event and write the file
test_writer.close() log.debug("finished test_image_summary_sample")
log.debug("finished test_image_summary_sample")
class Net(nn.Cell): class Net(nn.Cell):
@ -174,23 +172,21 @@ def test_image_summary_train():
log.debug("begin test_image_summary_sample") log.debug("begin test_image_summary_sample")
# step 0: create the thread # step 0: create the thread
test_writer = SummaryRecord(SUMMARY_DIR, file_suffix="_MS_IMAGE") with SummaryRecord(SUMMARY_DIR, file_suffix="_MS_IMAGE") as test_writer:
# step 1: create the test data for summary
# step 2: create the Event # step 1: create the test data for summary
model = get_model() # step 2: create the Event
fn = ImageSummaryCallback(test_writer)
summary_recode = SummaryStep(fn, 1)
model.train(2, dataset, callbacks=summary_recode)
# step 3: send the event to mq model = get_model()
fn = ImageSummaryCallback(test_writer)
summary_recode = SummaryStep(fn, 1)
model.train(2, dataset, callbacks=summary_recode)
# step 4: accept the event and write the file # step 3: send the event to mq
test_writer.close()
log.debug("finished test_image_summary_sample") # step 4: accept the event and write the file
log.debug("finished test_image_summary_sample")
def test_image_summary_data(): def test_image_summary_data():
@ -209,18 +205,12 @@ def test_image_summary_data():
log.debug("begin test_image_summary_sample") log.debug("begin test_image_summary_sample")
# step 0: create the thread # step 0: create the thread
test_writer = SummaryRecord(SUMMARY_DIR, file_suffix="_MS_IMAGE") with SummaryRecord(SUMMARY_DIR, file_suffix="_MS_IMAGE") as test_writer:
# step 1: create the test data for summary
# step 2: create the Event
_cache_summary_tensor_data(test_data_list)
test_writer.record(1)
test_writer.flush()
# step 3: send the event to mq # step 1: create the test data for summary
# step 4: accept the event and write the file # step 2: create the Event
test_writer.close() _cache_summary_tensor_data(test_data_list)
test_writer.record(1)
log.debug("finished test_image_summary_sample") log.debug("finished test_image_summary_sample")

@ -65,22 +65,21 @@ def test_scalar_summary_sample():
""" test_scalar_summary_sample """ """ test_scalar_summary_sample """
log.debug("begin test_scalar_summary_sample") log.debug("begin test_scalar_summary_sample")
# step 0: create the thread # step 0: create the thread
test_writer = SummaryRecord(SUMMARY_DIR, file_suffix="_MS_SCALAR") with SummaryRecord(SUMMARY_DIR, file_suffix="_MS_SCALAR") as test_writer:
# step 1: create the test data for summary # step 1: create the test data for summary
# step 2: create the Event # step 2: create the Event
for i in range(1, 500): for i in range(1, 500):
test_data = get_test_data(i) test_data = get_test_data(i)
_cache_summary_tensor_data(test_data) _cache_summary_tensor_data(test_data)
test_writer.record(i) test_writer.record(i)
# step 3: send the event to mq # step 3: send the event to mq
# step 4: accept the event and write the file # step 4: accept the event and write the file
test_writer.close()
log.debug("finished test_scalar_summary_sample") log.debug("finished test_scalar_summary_sample")
def get_test_data_shape_1(step): def get_test_data_shape_1(step):
@ -110,22 +109,21 @@ def test_scalar_summary_sample_with_shape_1():
""" test_scalar_summary_sample_with_shape_1 """ """ test_scalar_summary_sample_with_shape_1 """
log.debug("begin test_scalar_summary_sample_with_shape_1") log.debug("begin test_scalar_summary_sample_with_shape_1")
# step 0: create the thread # step 0: create the thread
test_writer = SummaryRecord(SUMMARY_DIR, file_suffix="_MS_SCALAR") with SummaryRecord(SUMMARY_DIR, file_suffix="_MS_SCALAR") as test_writer:
# step 1: create the test data for summary # step 1: create the test data for summary
# step 2: create the Event # step 2: create the Event
for i in range(1, 100): for i in range(1, 100):
test_data = get_test_data_shape_1(i) test_data = get_test_data_shape_1(i)
_cache_summary_tensor_data(test_data) _cache_summary_tensor_data(test_data)
test_writer.record(i) test_writer.record(i)
# step 3: send the event to mq # step 3: send the event to mq
# step 4: accept the event and write the file # step 4: accept the event and write the file
test_writer.close()
log.debug("finished test_scalar_summary_sample") log.debug("finished test_scalar_summary_sample")
# Test: test with ge # Test: test with ge
@ -152,26 +150,24 @@ def test_scalar_summary_with_ge():
log.debug("begin test_scalar_summary_with_ge") log.debug("begin test_scalar_summary_with_ge")
# step 0: create the thread # step 0: create the thread
test_writer = SummaryRecord(SUMMARY_DIR, file_suffix="_MS_SCALAR") with SummaryRecord(SUMMARY_DIR, file_suffix="_MS_SCALAR") as test_writer:
# step 1: create the network for summary # step 1: create the network for summary
x = Tensor(np.array([1.1]).astype(np.float32)) x = Tensor(np.array([1.1]).astype(np.float32))
y = Tensor(np.array([1.2]).astype(np.float32)) y = Tensor(np.array([1.2]).astype(np.float32))
net = SummaryDemo() net = SummaryDemo()
net.set_train() net.set_train()
# step 2: create the Event # step 2: create the Event
steps = 100 steps = 100
for i in range(1, steps): for i in range(1, steps):
x = Tensor(np.array([1.1 + random.uniform(1, 10)]).astype(np.float32)) x = Tensor(np.array([1.1 + random.uniform(1, 10)]).astype(np.float32))
y = Tensor(np.array([1.2 + random.uniform(1, 10)]).astype(np.float32)) y = Tensor(np.array([1.2 + random.uniform(1, 10)]).astype(np.float32))
net(x, y) net(x, y)
test_writer.record(i) test_writer.record(i)
# step 3: close the writer
test_writer.close()
log.debug("finished test_scalar_summary_with_ge") log.debug("finished test_scalar_summary_with_ge")
# test the problem of two consecutive use cases going wrong # test the problem of two consecutive use cases going wrong
@ -180,55 +176,52 @@ def test_scalar_summary_with_ge_2():
log.debug("begin test_scalar_summary_with_ge_2") log.debug("begin test_scalar_summary_with_ge_2")
# step 0: create the thread # step 0: create the thread
test_writer = SummaryRecord(SUMMARY_DIR, file_suffix="_MS_SCALAR") with SummaryRecord(SUMMARY_DIR, file_suffix="_MS_SCALAR") as test_writer:
# step 1: create the network for summary
x = Tensor(np.array([1.1]).astype(np.float32))
y = Tensor(np.array([1.2]).astype(np.float32))
net = SummaryDemo()
net.set_train()
# step 2: create the Event # step 1: create the network for summary
steps = 100
for i in range(1, steps):
x = Tensor(np.array([1.1]).astype(np.float32)) x = Tensor(np.array([1.1]).astype(np.float32))
y = Tensor(np.array([1.2]).astype(np.float32)) y = Tensor(np.array([1.2]).astype(np.float32))
net(x, y) net = SummaryDemo()
test_writer.record(i) net.set_train()
# step 3: close the writer # step 2: create the Event
test_writer.close() steps = 100
for i in range(1, steps):
x = Tensor(np.array([1.1]).astype(np.float32))
y = Tensor(np.array([1.2]).astype(np.float32))
net(x, y)
test_writer.record(i)
log.debug("finished test_scalar_summary_with_ge_2")
log.debug("finished test_scalar_summary_with_ge_2")
def test_validate():
sr = SummaryRecord(SUMMARY_DIR)
with pytest.raises(ValueError): def test_validate():
SummaryStep(sr, 0) with SummaryRecord(SUMMARY_DIR) as sr:
with pytest.raises(ValueError):
SummaryStep(sr, -1) with pytest.raises(ValueError):
with pytest.raises(ValueError): SummaryStep(sr, 0)
SummaryStep(sr, 1.2) with pytest.raises(ValueError):
with pytest.raises(ValueError): SummaryStep(sr, -1)
SummaryStep(sr, True) with pytest.raises(ValueError):
with pytest.raises(ValueError): SummaryStep(sr, 1.2)
SummaryStep(sr, "str") with pytest.raises(ValueError):
sr.record(1) SummaryStep(sr, True)
with pytest.raises(ValueError): with pytest.raises(ValueError):
sr.record(False) SummaryStep(sr, "str")
with pytest.raises(ValueError): sr.record(1)
sr.record(2.0) with pytest.raises(ValueError):
with pytest.raises(ValueError): sr.record(False)
sr.record((1, 3)) with pytest.raises(ValueError):
with pytest.raises(ValueError): sr.record(2.0)
sr.record([2, 3]) with pytest.raises(ValueError):
with pytest.raises(ValueError): sr.record((1, 3))
sr.record("str") with pytest.raises(ValueError):
with pytest.raises(ValueError): sr.record([2, 3])
sr.record(sr) with pytest.raises(ValueError):
sr.close() sr.record("str")
with pytest.raises(ValueError):
sr.record(sr)
SummaryStep(sr, 1) SummaryStep(sr, 1)
with pytest.raises(ValueError): with pytest.raises(ValueError):

@ -126,23 +126,21 @@ class HistogramSummaryNet(nn.Cell):
def run_case(net): def run_case(net):
""" run_case """ """ run_case """
# step 0: create the thread # step 0: create the thread
test_writer = SummaryRecord(SUMMARY_DIR) with SummaryRecord(SUMMARY_DIR) as test_writer:
# step 1: create the network for summary # step 1: create the network for summary
x = Tensor(np.array([1.1]).astype(np.float32)) x = Tensor(np.array([1.1]).astype(np.float32))
y = Tensor(np.array([1.2]).astype(np.float32)) y = Tensor(np.array([1.2]).astype(np.float32))
net.set_train() net.set_train()
# step 2: create the Event # step 2: create the Event
steps = 100 steps = 100
for i in range(1, steps): for i in range(1, steps):
x = Tensor(np.array([1.1 + random.uniform(1, 10)]).astype(np.float32)) x = Tensor(np.array([1.1 + random.uniform(1, 10)]).astype(np.float32))
y = Tensor(np.array([1.2 + random.uniform(1, 10)]).astype(np.float32)) y = Tensor(np.array([1.2 + random.uniform(1, 10)]).astype(np.float32))
net(x, y) net(x, y)
test_writer.record(i) test_writer.record(i)
# step 3: close the writer
test_writer.close()
# Test 1: use the repeat tag # Test 1: use the repeat tag

@ -80,19 +80,18 @@ def test_tensor_summary_sample():
""" test_tensor_summary_sample """ """ test_tensor_summary_sample """
log.debug("begin test_tensor_summary_sample") log.debug("begin test_tensor_summary_sample")
# step 0: create the thread # step 0: create the thread
test_writer = SummaryRecord(SUMMARY_DIR, file_suffix="_MS_TENSOR") with SummaryRecord(SUMMARY_DIR, file_suffix="_MS_TENSOR") as test_writer:
# step 1: create the Event # step 1: create the Event
for i in range(1, 100): for i in range(1, 100):
test_data = get_test_data(i) test_data = get_test_data(i)
_cache_summary_tensor_data(test_data) _cache_summary_tensor_data(test_data)
test_writer.record(i) test_writer.record(i)
# step 2: accept the event and write the file # step 2: accept the event and write the file
test_writer.close()
log.debug("finished test_tensor_summary_sample") log.debug("finished test_tensor_summary_sample")
def get_test_data_check(step): def get_test_data_check(step):
@ -131,23 +130,20 @@ def test_tensor_summary_with_ge():
log.debug("begin test_tensor_summary_with_ge") log.debug("begin test_tensor_summary_with_ge")
# step 0: create the thread # step 0: create the thread
test_writer = SummaryRecord(SUMMARY_DIR) with SummaryRecord(SUMMARY_DIR) as test_writer:
# step 1: create the network for summary # step 1: create the network for summary
x = Tensor(np.array([1.1]).astype(np.float32)) x = Tensor(np.array([1.1]).astype(np.float32))
y = Tensor(np.array([1.2]).astype(np.float32)) y = Tensor(np.array([1.2]).astype(np.float32))
net = SummaryDemo() net = SummaryDemo()
net.set_train() net.set_train()
# step 2: create the Event # step 2: create the Event
steps = 100 steps = 100
for i in range(1, steps): for i in range(1, steps):
x = Tensor(np.array([[i], [i]]).astype(np.float32)) x = Tensor(np.array([[i], [i]]).astype(np.float32))
y = Tensor(np.array([[i + 1], [i + 1]]).astype(np.float32)) y = Tensor(np.array([[i + 1], [i + 1]]).astype(np.float32))
net(x, y) net(x, y)
test_writer.record(i) test_writer.record(i)
# step 3: close the writer log.debug("finished test_tensor_summary_with_ge")
test_writer.close()
log.debug("finished test_tensor_summary_with_ge")

Loading…
Cancel
Save