@@ -15,12 +15,15 @@
 import os
 import numbers
 
-from paddle.fluid.dygraph.parallel import ParallelEnv
+import paddle
+from paddle.distributed import ParallelEnv
 from paddle.utils import try_import
 
 from .progressbar import ProgressBar
 
-__all__ = ['Callback', 'ProgBarLogger', 'ModelCheckpoint', 'VisualDL']
+__all__ = [
+    'Callback', 'ProgBarLogger', 'ModelCheckpoint', 'VisualDL', 'LRScheduler'
+]
 
 
 def config_callbacks(callbacks=None,
@@ -42,6 +45,10 @@ def config_callbacks(callbacks=None,
     if not any(isinstance(k, ModelCheckpoint) for k in cbks):
         cbks = cbks + [ModelCheckpoint(save_freq, save_dir)]
 
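+    # append a default LRScheduler (updates by step) when none is provided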
+    if not any(isinstance(k, LRScheduler) for k in cbks):
+        cbks = cbks + [LRScheduler()]
+
     cbk_list = CallbackList(cbks)
     cbk_list.set_model(model)
     metrics = metrics or [] if mode != 'test' else []
@@ -485,6 +492,102 @@ class ModelCheckpoint(Callback):
             self.model.save(path)
 
 
+class LRScheduler(Callback):
+    """Learning rate scheduler callback function.
+    Args:
+        by_step(bool, optional): whether to update the learning rate scheduler
+            by step. Default: True.
+        by_epoch(bool, optional): whether to update the learning rate scheduler
+            by epoch. Default: False.
+
+    Examples:
+        .. code-block:: python
+
+            import paddle
+            import paddle.vision.transforms as T
+            from paddle.static import InputSpec
+
+            inputs = [InputSpec([-1, 1, 28, 28], 'float32', 'image')]
+            labels = [InputSpec([None, 1], 'int64', 'label')]
+
+            transform = T.Compose([
+                T.Transpose(),
+                T.Normalize([127.5], [127.5])
+            ])
+            train_dataset = paddle.vision.datasets.MNIST(mode='train', transform=transform)
+
+            lenet = paddle.vision.LeNet()
+            model = paddle.Model(lenet,
+                inputs, labels)
+
+            base_lr = 1e-3
+            boundaries = [5, 8]
+            warmup_steps = 4
+
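+            # chain the schedulers: LinearWarmup ramps the lr from base_lr / 5
+            # up to base_lr, then the wrapped PiecewiseDecay takes over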
+            def make_optimizer(parameters=None):
+                momentum = 0.9
+                weight_decay = 5e-4
+                values = [base_lr * (0.1**i) for i in range(len(boundaries) + 1)]
+                learning_rate = paddle.optimizer.lr.PiecewiseDecay(
+                    boundaries=boundaries, values=values)
+                learning_rate = paddle.optimizer.lr.LinearWarmup(
+                    learning_rate=learning_rate,
+                    warmup_steps=warmup_steps,
+                    start_lr=base_lr / 5.,
+                    end_lr=base_lr,
+                    verbose=True)
+                optimizer = paddle.optimizer.Momentum(
+                    learning_rate=learning_rate,
+                    weight_decay=weight_decay,
+                    momentum=momentum,
+                    parameters=parameters)
+                return optimizer
+
+            optim = make_optimizer(parameters=lenet.parameters())
+            model.prepare(optimizer=optim,
+                          loss=paddle.nn.CrossEntropyLoss(),
+                          metrics=paddle.metric.Accuracy())
+
+            # if no LRScheduler callback is set, an LRScheduler instance that
+            # updates by step will be created automatically.
+            model.fit(train_dataset, batch_size=64)
+
+            # create a learning rate scheduler callback that updates by epoch
+            callback = paddle.callbacks.LRScheduler(by_step=False, by_epoch=True)
+            model.fit(train_dataset, batch_size=64, callbacks=callback)
+    """
+
+    def __init__(self, by_step=True, by_epoch=False):
+        if by_step and by_epoch:
+            raise ValueError(
+                "by_step option is mutually exclusive with by_epoch")
+
+        self.by_step = by_step
+        self.by_epoch = by_epoch
+
+    def on_epoch_end(self, epoch, logs=None):
+        if self.by_epoch:
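+            # step the wrapped paddle.optimizer.lr.LRScheduler once per epoch;
+            # a plain float learning rate has no scheduler to step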
+            if self.model._optimizer and \
+                hasattr(self.model._optimizer, '_learning_rate') and \
+                isinstance(self.model._optimizer._learning_rate,
+                           paddle.optimizer.lr.LRScheduler):
+                self.model._optimizer._learning_rate.step()
+
+    def on_train_batch_end(self, step, logs=None):
+        if self.by_step:
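+            # step the wrapped paddle.optimizer.lr.LRScheduler after every
+            # optimizer update when updating by step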
+            if self.model._optimizer and \
+                hasattr(self.model._optimizer, '_learning_rate') and \
+                isinstance(self.model._optimizer._learning_rate,
+                           paddle.optimizer.lr.LRScheduler):
+                self.model._optimizer._learning_rate.step()
+
+
 class VisualDL(Callback):
     """VisualDL callback function
     Args: