@@ -13,12 +13,14 @@
 # limitations under the License.
 
 import os
+import numbers
 
 from paddle.fluid.dygraph.parallel import ParallelEnv
+from paddle.utils import try_import
 from .progressbar import ProgressBar
 
-__all__ = ['Callback', 'ProgBarLogger', 'ModelCheckpoint']
+__all__ = ['Callback', 'ProgBarLogger', 'ModelCheckpoint', 'VisualDL']
 
 
 def config_callbacks(callbacks=None,
@@ -471,3 +473,111 @@ class ModelCheckpoint(Callback):
             path = '{}/final'.format(self.save_dir)
             print('save checkpoint at {}'.format(os.path.abspath(path)))
             self.model.save(path)
+
+
+class VisualDL(Callback):
+    """VisualDL callback class. Records the metrics produced during
+    training and evaluation as VisualDL scalar logs.
+
+    Args:
+        log_dir (str): The directory where the VisualDL log files are saved.
+
+    Examples:
+        .. code-block:: python
+
+            import paddle
+            from paddle.static import InputSpec
+
+            inputs = [InputSpec([-1, 1, 28, 28], 'float32', 'image')]
+            labels = [InputSpec([None, 1], 'int64', 'label')]
+
+            train_dataset = paddle.vision.datasets.MNIST(mode='train')
+            eval_dataset = paddle.vision.datasets.MNIST(mode='test')
+
+            net = paddle.vision.LeNet()
+            model = paddle.Model(net, inputs, labels)
+
+            optim = paddle.optimizer.Adam(0.001, parameters=net.parameters())
+            model.prepare(optimizer=optim,
+                          loss=paddle.nn.CrossEntropyLoss(),
+                          metrics=paddle.metric.Accuracy())
+
+            # Uncomment the following lines to fit the model with the
+            # VisualDL callback attached:
+            # callback = paddle.callbacks.VisualDL(log_dir='visualdl_log_dir')
+            # model.fit(train_dataset, eval_dataset, batch_size=64, callbacks=callback)
+    """
+
+    def __init__(self, log_dir):
+        self.log_dir = log_dir
+        self.epochs = None
+        self.steps = None
+        self.epoch = 0
+
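+    # In multi-process (data parallel) training, only the process with
+    # local rank 0 writes logs, to avoid duplicate records.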
+    def _is_write(self):
+        return ParallelEnv().local_rank == 0
+
+    def on_train_begin(self, logs=None):
+        self.epochs = self.params['epochs']
+        assert self.epochs
+        self.train_metrics = self.params['metrics']
+        assert self.train_metrics
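+        # Mark that fit() is running so that on_eval_end does not close
+        # the writer for evaluations that happen inside fit().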
+        self._is_fit = True
+        self.train_step = 0
+
+    def on_epoch_begin(self, epoch=None, logs=None):
+        self.steps = self.params['steps']
+        self.epoch = epoch
+
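+    # Write every metric in `logs` as a scalar under the tag '<mode>/<metric>'.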
+    def _updates(self, logs, mode):
+        if not self._is_write():
+            return
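+        # Create the writer lazily on first use; try_import surfaces a
+        # clear error message if the visualdl package is not installed.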
+        if not hasattr(self, 'writer'):
+            visualdl = try_import('visualdl')
+            self.writer = visualdl.LogWriter(self.log_dir)
+
+        metrics = getattr(self, '%s_metrics' % (mode))
+        current_step = getattr(self, '%s_step' % (mode))
+
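+        # Training scalars advance with the cumulative batch count, while
+        # eval scalars are logged once per epoch against the epoch index.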
+        if mode == 'train':
+            total_step = current_step
+        else:
+            total_step = self.epoch
+
+        for k in metrics:
+            if k in logs:
+                temp_tag = mode + '/' + k
+
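+                # A metric value may be a list/tuple (only the first element
+                # is logged) or a plain number; other types are skipped.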
+                if isinstance(logs[k], (list, tuple)):
+                    temp_value = logs[k][0]
+                elif isinstance(logs[k], numbers.Number):
+                    temp_value = logs[k]
+                else:
+                    continue
+
+                self.writer.add_scalar(
+                    tag=temp_tag, step=total_step, value=temp_value)
+
+    def on_train_batch_end(self, step, logs=None):
+        logs = logs or {}
+        self.train_step += 1
+
+        if self._is_write():
+            self._updates(logs, 'train')
+
+    def on_eval_begin(self, logs=None):
+        self.eval_steps = logs.get('steps', None)
+        self.eval_metrics = logs.get('metrics', [])
+        self.eval_step = 0
+        self.evaled_samples = 0
+
+    def on_train_end(self, logs=None):
+        if hasattr(self, 'writer'):
+            self.writer.close()
+            delattr(self, 'writer')
+
+    def on_eval_end(self, logs=None):
+        if self._is_write():
+            self._updates(logs, 'eval')
+
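+        # When evaluate() is called on its own (no fit() in progress),
+        # close the writer here; during fit() it is closed in on_train_end.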
+        if (not hasattr(self, '_is_fit')) and hasattr(self, 'writer'):
+            self.writer.close()
+            delattr(self, 'writer')