|
|
|
@ -14,6 +14,9 @@
|
|
|
|
|
|
|
|
|
|
import os
|
|
|
|
|
import numbers
|
|
|
|
|
import warnings
|
|
|
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
|
|
|
|
|
|
import paddle
|
|
|
|
|
from paddle.distributed import ParallelEnv
|
|
|
|
@ -22,7 +25,8 @@ from paddle.utils import try_import
|
|
|
|
|
from .progressbar import ProgressBar
|
|
|
|
|
|
|
|
|
|
# Public names exported by this module. Each name is listed exactly once and
# every entry is comma-separated, so no two adjacent string literals are
# implicitly concatenated into a bogus export name.
__all__ = [
    'Callback', 'ProgBarLogger', 'ModelCheckpoint', 'VisualDL', 'LRScheduler',
    'EarlyStopping'
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -45,6 +49,9 @@ def config_callbacks(callbacks=None,
|
|
|
|
|
if not any(isinstance(k, ModelCheckpoint) for k in cbks):
|
|
|
|
|
cbks = cbks + [ModelCheckpoint(save_freq, save_dir)]
|
|
|
|
|
|
|
|
|
|
for k in cbks:
|
|
|
|
|
if isinstance(k, EarlyStopping):
|
|
|
|
|
k.save_dir = save_dir
|
|
|
|
|
if not any(isinstance(k, LRScheduler) for k in cbks):
|
|
|
|
|
cbks = cbks + [LRScheduler()]
|
|
|
|
|
|
|
|
|
@ -581,6 +588,157 @@ class LRScheduler(Callback):
|
|
|
|
|
self.model._optimizer._learning_rate.step()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class EarlyStopping(Callback):
    """Stop training when the given monitor stopped improving during evaluation.

    The monitored quantity is read from the ``logs`` passed to
    ``on_eval_end``. Every time it improves, the wait counter is reset and,
    if ``save_best_model`` is True and a ``save_dir`` has been assigned to
    this callback, the model is saved to ``<save_dir>/best_model``.

    Args:
        monitor(str): Quantity to be monitored. Default: 'loss'.
        mode(str|None): Mode should be one of 'auto', 'min' or 'max'. In 'min'
            mode, training will stop when monitored quantity stops decreasing.
            In 'max' mode, training will stop when monitored quantity stops
            increasing. In 'auto' mode, exact mode can be inferred by the name
            of monitor. If 'acc' in monitor, the mode will be considered as
            'max', otherwise the mode will be set to 'min'. Any other value
            triggers a warning and falls back to 'auto'. Default: 'auto'.
        patience(int): Number of consecutive evaluations with no improvement
            after which training will be stopped. With ``patience=0`` training
            stops at the first evaluation that shows no improvement.
            Default: 0.
        verbose(int): The verbosity mode, should be 0 or 1. When verbose=0,
            logs will not be printed. When verbose=1, logs will be printed.
            Default: 1.
        min_delta(int|float): The minimum change of monitored quantity. If
            the change is less than min_delta, model could be considered as no
            improvement. Only the absolute value is used. Default: 0.
        baseline(int|float|None): Baseline value for the monitored quantity.
            Training will stop if the model doesn't show improvement over the
            baseline. Default: None.
        save_best_model(bool): Whether to save best model. Default: True.

    Examples:
        .. code-block:: python

            import paddle
            from paddle import Model
            from paddle.static import InputSpec
            from paddle.vision.models import LeNet
            from paddle.vision.datasets import MNIST
            from paddle.metric import Accuracy
            from paddle.nn.layer.loss import CrossEntropyLoss
            import paddle.vision.transforms as T

            device = paddle.set_device('cpu')
            sample_num = 200
            save_dir = './best_model_checkpoint'
            transform = T.Compose(
                [T.Transpose(), T.Normalize([127.5], [127.5])])
            train_dataset = MNIST(mode='train', transform=transform)
            val_dataset = MNIST(mode='test', transform=transform)
            net = LeNet()
            optim = paddle.optimizer.Adam(
                learning_rate=0.001, parameters=net.parameters())

            inputs = [InputSpec([None, 1, 28, 28], 'float32', 'x')]
            labels = [InputSpec([None, 1], 'int64', 'label')]

            model = Model(net, inputs=inputs, labels=labels)
            model.prepare(
                optim,
                loss=CrossEntropyLoss(reduction="sum"),
                metrics=[Accuracy()])
            callbacks = paddle.callbacks.EarlyStopping(
                'loss',
                mode='min',
                patience=1,
                verbose=1,
                min_delta=0,
                baseline=None,
                save_best_model=True)
            model.fit(train_dataset,
                      val_dataset,
                      batch_size=64,
                      log_freq=200,
                      save_freq=10,
                      save_dir=save_dir,
                      epochs=20,
                      callbacks=[callbacks])
    """

    def __init__(self,
                 monitor='loss',
                 mode='auto',
                 patience=0,
                 verbose=1,
                 min_delta=0,
                 baseline=None,
                 save_best_model=True):
        super(EarlyStopping, self).__init__()
        self.monitor = monitor
        self.patience = patience
        self.verbose = verbose
        self.baseline = baseline
        self.min_delta = abs(min_delta)
        self.wait_epoch = 0
        self.best_weights = None
        self.stopped_epoch = 0
        self.save_best_model = save_best_model
        self.save_dir = None  # `save_dir` is get from `config_callbacks`

        if mode not in ['auto', 'min', 'max']:
            warnings.warn('EarlyStopping mode %s is unknown, '
                          'fallback to auto mode.' % mode)
            mode = 'auto'

        # When mode == 'auto', the mode is inferred from `self.monitor`:
        # accuracy-like quantities are maximized, everything else minimized.
        if mode == 'auto':
            mode = 'max' if 'acc' in self.monitor else 'min'
        self.monitor_op = np.greater if mode == 'max' else np.less

        # The comparison in `on_eval_end` is
        # `monitor_op(current - min_delta, best)`, so when minimizing the
        # margin must be negative (current + |min_delta| < best).
        if self.monitor_op == np.less:
            self.min_delta *= -1

    def on_train_begin(self, logs=None):
        """Reset the early-stopping state at the start of every training run.

        Args:
            logs(dict|None): Unused.
        """
        self.wait_epoch = 0
        # Restart the evaluation counter so that a callback instance reused
        # for another training run reports the right epoch when it stops.
        self.stopped_epoch = 0
        self.best_weights = None
        if self.baseline is not None:
            self.best_value = self.baseline
        else:
            # Start from the worst possible value for the chosen direction.
            self.best_value = np.inf if self.monitor_op == np.less else -np.inf

    def on_eval_end(self, logs=None):
        """Update the best value and decide whether training should stop.

        Args:
            logs(dict|None): Evaluation results. It must contain
                ``self.monitor``; the value may be a number, or a list/tuple
                whose first element is used. If the key is missing a warning
                is issued and nothing else happens; values of any other type
                are silently ignored.
        """
        if logs is None or self.monitor not in logs:
            warnings.warn(
                'Monitor of EarlyStopping should be loss or metric name.')
            return
        current = logs[self.monitor]
        if isinstance(current, (list, tuple)):
            current = current[0]
        elif not isinstance(current, numbers.Number):
            return

        if self.monitor_op(current - self.min_delta, self.best_value):
            self.best_value = current
            self.wait_epoch = 0
            if self.save_best_model and self.save_dir is not None:
                path = os.path.join(self.save_dir, 'best_model')
                self.model.save(path)
        else:
            self.wait_epoch += 1
            # Only a non-improving evaluation may trigger the stop. Checking
            # after an improvement would stop immediately when patience == 0
            # (0 >= 0), even though the monitored quantity just got better.
            if self.wait_epoch >= self.patience:
                self.model.stop_training = True
                if self.verbose > 0:
                    print('Epoch %d: Early stopping.' %
                          (self.stopped_epoch + 1))
                    if self.save_best_model and self.save_dir is not None:
                        print('Best checkpoint has been saved at %s' %
                              (os.path.abspath(
                                  os.path.join(self.save_dir, 'best_model'))))
        # Counts the evaluations processed so far in this training run.
        self.stopped_epoch += 1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class VisualDL(Callback):
|
|
|
|
|
"""VisualDL callback function
|
|
|
|
|
Args:
|
|
|
|
|