|
|
|
@ -18,6 +18,7 @@ import os
|
|
|
|
|
import stat
|
|
|
|
|
import shutil
|
|
|
|
|
import time
|
|
|
|
|
from contextlib import ExitStack
|
|
|
|
|
import numpy as np
|
|
|
|
|
|
|
|
|
|
import mindspore.context as context
|
|
|
|
@ -282,80 +283,11 @@ def _summary_cb_for_save_op(summary_list):
|
|
|
|
|
return ret
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _build_callbacks(callbacks):
|
|
|
|
|
"""
|
|
|
|
|
Contain a list of callback.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
callbacks (list): Callback functions list, Support None, a single Callback object, or a list.
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
List, a list of callback functions.
|
|
|
|
|
"""
|
|
|
|
|
if callbacks:
|
|
|
|
|
if isinstance(callbacks, tuple):
|
|
|
|
|
raise TypeError("Callbacks cannot be a tuple. Please check it.")
|
|
|
|
|
if not isinstance(callbacks, list):
|
|
|
|
|
callbacks = [callbacks]
|
|
|
|
|
else:
|
|
|
|
|
callbacks = []
|
|
|
|
|
|
|
|
|
|
excute_callbacks = []
|
|
|
|
|
for cb in callbacks:
|
|
|
|
|
if cb is None or not isinstance(cb, Callback):
|
|
|
|
|
raise TypeError("Callback must inheriting base class Callback. Some callback is Wrong. Please check it.")
|
|
|
|
|
excute_callbacks.append(cb)
|
|
|
|
|
|
|
|
|
|
return _ListCallback(excute_callbacks)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class _ListCallback:
|
|
|
|
|
"""
|
|
|
|
|
Sequential execution of callback functions.
|
|
|
|
|
|
|
|
|
|
Execute Callback functions at certain points.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
callbacks (list): Callback functions list.
|
|
|
|
|
"""
|
|
|
|
|
def __init__(self, callbacks):
|
|
|
|
|
super(_ListCallback, self).__init__()
|
|
|
|
|
self._callbacks = callbacks
|
|
|
|
|
|
|
|
|
|
def begin(self, run_context):
|
|
|
|
|
"""Called once before network training."""
|
|
|
|
|
for cb in self._callbacks:
|
|
|
|
|
cb.begin(run_context)
|
|
|
|
|
|
|
|
|
|
def epoch_begin(self, run_context):
|
|
|
|
|
"""Called before each epoch begin."""
|
|
|
|
|
for cb in self._callbacks:
|
|
|
|
|
cb.epoch_begin(run_context)
|
|
|
|
|
|
|
|
|
|
def epoch_end(self, run_context):
|
|
|
|
|
"""Called after each epoch finished."""
|
|
|
|
|
for cb in self._callbacks:
|
|
|
|
|
cb.epoch_end(run_context)
|
|
|
|
|
|
|
|
|
|
def step_begin(self, run_context):
|
|
|
|
|
"""Called before each epoch begin."""
|
|
|
|
|
for cb in self._callbacks:
|
|
|
|
|
cb.step_begin(run_context)
|
|
|
|
|
|
|
|
|
|
def step_end(self, run_context):
|
|
|
|
|
"""Called after each step finished."""
|
|
|
|
|
for cb in self._callbacks:
|
|
|
|
|
cb.step_end(run_context)
|
|
|
|
|
|
|
|
|
|
def end(self, run_context):
|
|
|
|
|
"""Called once after network training."""
|
|
|
|
|
for cb in self._callbacks:
|
|
|
|
|
cb.end(run_context)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Callback:
|
|
|
|
|
"""
|
|
|
|
|
Abstract base class used to build a callback function.
|
|
|
|
|
Abstract base class used to build a callback class. Callbacks are context managers
|
|
|
|
|
which will be entered and exited when passing into the Model.
|
|
|
|
|
You can leverage this mechanism to init and release resources automatically.
|
|
|
|
|
|
|
|
|
|
Callback function will execution some operating to the current step or epoch.
|
|
|
|
|
|
|
|
|
@ -369,8 +301,13 @@ class Callback:
|
|
|
|
|
>>> print_cb = Print_info()
|
|
|
|
|
>>> model.train(epoch, dataset, callbacks=print_cb)
|
|
|
|
|
"""
|
|
|
|
|
def __init__(self):
|
|
|
|
|
pass
|
|
|
|
|
|
|
|
|
|
def __enter__(self):
|
|
|
|
|
"""Return the enter target."""
|
|
|
|
|
return self
|
|
|
|
|
|
|
|
|
|
def __exit__(self, *err):
|
|
|
|
|
"""Release resources here if have any."""
|
|
|
|
|
|
|
|
|
|
def begin(self, run_context):
|
|
|
|
|
"""
|
|
|
|
@ -421,6 +358,67 @@ class Callback:
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class _CallbackManager(Callback):
|
|
|
|
|
"""
|
|
|
|
|
Sequential execution of callback functions.
|
|
|
|
|
|
|
|
|
|
Execute Callback functions at certain points.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
callbacks (Optional[list[Callback], Callback]): None, callback, or callbacks list.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def __init__(self, callbacks):
|
|
|
|
|
self._callbacks, self._stack = [], None
|
|
|
|
|
if isinstance(callbacks, Callback):
|
|
|
|
|
self._callbacks.append(callbacks)
|
|
|
|
|
elif callbacks is not None:
|
|
|
|
|
for cb in callbacks:
|
|
|
|
|
if not isinstance(cb, Callback):
|
|
|
|
|
raise TypeError("%r is not an instance of %r" % (cb, Callback))
|
|
|
|
|
self._callbacks.append(cb)
|
|
|
|
|
|
|
|
|
|
def __enter__(self):
|
|
|
|
|
if self._stack is None:
|
|
|
|
|
self._stack = ExitStack().__enter__()
|
|
|
|
|
self._callbacks = [self._stack.enter_context(cb) for cb in self._callbacks]
|
|
|
|
|
return self
|
|
|
|
|
|
|
|
|
|
def __exit__(self, *err):
|
|
|
|
|
return self._stack.__exit__(*err)
|
|
|
|
|
|
|
|
|
|
def begin(self, run_context):
|
|
|
|
|
"""Called once before network training."""
|
|
|
|
|
for cb in self._callbacks:
|
|
|
|
|
cb.begin(run_context)
|
|
|
|
|
|
|
|
|
|
def epoch_begin(self, run_context):
|
|
|
|
|
"""Called before each epoch begin."""
|
|
|
|
|
for cb in self._callbacks:
|
|
|
|
|
cb.epoch_begin(run_context)
|
|
|
|
|
|
|
|
|
|
def epoch_end(self, run_context):
|
|
|
|
|
"""Called after each epoch finished."""
|
|
|
|
|
for cb in self._callbacks:
|
|
|
|
|
cb.epoch_end(run_context)
|
|
|
|
|
|
|
|
|
|
def step_begin(self, run_context):
|
|
|
|
|
"""Called before each epoch begin."""
|
|
|
|
|
for cb in self._callbacks:
|
|
|
|
|
cb.step_begin(run_context)
|
|
|
|
|
|
|
|
|
|
def step_end(self, run_context):
|
|
|
|
|
"""Called after each step finished."""
|
|
|
|
|
for cb in self._callbacks:
|
|
|
|
|
cb.step_end(run_context)
|
|
|
|
|
|
|
|
|
|
def end(self, run_context):
|
|
|
|
|
"""Called once after network training."""
|
|
|
|
|
for cb in self._callbacks:
|
|
|
|
|
cb.end(run_context)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SummaryStep(Callback):
|
|
|
|
|
"""
|
|
|
|
|
The summary callback class.
|
|
|
|
@ -435,6 +433,13 @@ class SummaryStep(Callback):
|
|
|
|
|
raise ValueError("`flush_step` should be int and greater than 0")
|
|
|
|
|
self._summary = summary
|
|
|
|
|
self._flush_step = flush_step
|
|
|
|
|
def __enter__(self):
|
|
|
|
|
self._summary.__enter__()
|
|
|
|
|
return self
|
|
|
|
|
|
|
|
|
|
def __exit__(self, *err):
|
|
|
|
|
return self._summary.__exit__(*err)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def step_end(self, run_context):
|
|
|
|
|
"""
|
|
|
|
|