step lr_scheduler on epoch end in hapi/model.py fit (#27730)

* step lr_scheduler on epoch end in hapi/model.py fit. test=develop
Kaipeng Deng 5 years ago committed by GitHub
parent d17681dcc7
commit b808979b8c

@@ -1461,6 +1461,11 @@ class Model(object):
                 cbks.on_end('eval', eval_logs)
+            # step learning rate scheduler on each epoch end
+            if isinstance(self._optimizer._learning_rate,
+                          paddle.optimizer.lr.LRScheduler):
+                self._optimizer._learning_rate.step()
+
         cbks.on_end('train', logs)
         self._test_dataloader = None
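
For context, the behavior this hunk adds can be sketched outside hapi as a plain training loop (the Linear layer and the decay boundaries/values below are illustrative assumptions, not part of the commit):

import paddle

# An optimizer whose learning_rate is an LRScheduler instance rather
# than a float; PiecewiseDecay drops the rate at the given boundaries.
linear = paddle.nn.Linear(20, 10)
scheduler = paddle.optimizer.lr.PiecewiseDecay(
    boundaries=[5, 8], values=[1e-3, 1e-4, 1e-5])
optimizer = paddle.optimizer.Momentum(
    learning_rate=scheduler, momentum=0.9, parameters=linear.parameters())

for epoch in range(10):
    # ... run the train (and optional eval) steps for this epoch ...
    # The same check-and-step that fit() now performs at each epoch end:
    if isinstance(optimizer._learning_rate, paddle.optimizer.lr.LRScheduler):
        optimizer._learning_rate.step()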

@@ -33,7 +33,7 @@ from paddle.nn.layer.loss import CrossEntropyLoss
 from paddle.metric import Accuracy
 from paddle.vision.datasets import MNIST
 from paddle.vision.models import LeNet
-from paddle.io import DistributedBatchSampler
+from paddle.io import DistributedBatchSampler, Dataset
 from paddle.hapi.model import prepare_distributed_context
 from paddle.fluid.dygraph.jit import declarative
 from paddle.fluid.dygraph.dygraph_to_static.program_translator import ProgramTranslator
@@ -295,6 +295,15 @@ class MyModel(paddle.nn.Layer):
         return y
 
+
+class MyDataset(Dataset):
+    def __getitem__(self, idx):
+        return np.random.random(size=(20,)).astype(np.float32), \
+            np.random.randint(0, 10, size=(1,)).astype(np.int64)
+
+    def __len__(self):
+        return 40
+
 
 class TestModelFunction(unittest.TestCase):
     def set_seed(self, seed=1024):
         paddle.manual_seed(seed)
@@ -599,6 +608,44 @@ class TestModelFunction(unittest.TestCase):
         shutil.rmtree(save_dir)
 
+
+class TestModelWithLRScheduler(unittest.TestCase):
+    def test_fit(self):
+        def make_optimizer(parameters=None):
+            base_lr = 1e-3
+            momentum = 0.9
+            weight_decay = 5e-4
+            boundaries = [5, 8]
+            values = [base_lr * (0.1**i) for i in range(len(boundaries) + 1)]
+            learning_rate = paddle.optimizer.lr.PiecewiseDecay(
+                boundaries=boundaries, values=values)
+            learning_rate = paddle.optimizer.lr.LinearWarmup(
+                learning_rate=learning_rate,
+                warmup_steps=4,
+                start_lr=base_lr / 5.,
+                end_lr=base_lr,
+                verbose=True)
+            optimizer = paddle.optimizer.Momentum(
+                learning_rate=learning_rate,
+                weight_decay=weight_decay,
+                momentum=momentum,
+                parameters=parameters)
+            return optimizer
+
+        device = paddle.set_device('cpu')
+        fluid.enable_dygraph(device)
+        net = MyModel()
+        inputs = [InputSpec([None, 20], 'float32', 'x')]
+        labels = [InputSpec([None, 1], 'int64', 'label')]
+        optim = make_optimizer(net.parameters())
+        model = Model(net, inputs, labels)
+        model.prepare(optimizer=optim, loss=CrossEntropyLoss(reduction="sum"))
+        dataset = MyDataset()
+        model.fit(dataset, dataset, batch_size=4, epochs=10, num_workers=0)
+        paddle.enable_static()
+
+
 class TestRaiseError(unittest.TestCase):
     def test_input_without_name(self):
         net = MyModel()
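
As a quick sanity check of the schedule used in the new test, a standalone sketch like the one below should show the linear warmup over the first four steps and then the piecewise drops, assuming LRScheduler exposes the rate in effect via last_lr:

import paddle

base_lr = 1e-3
sched = paddle.optimizer.lr.LinearWarmup(
    learning_rate=paddle.optimizer.lr.PiecewiseDecay(
        boundaries=[5, 8], values=[base_lr, base_lr * 0.1, base_lr * 0.01]),
    warmup_steps=4,
    start_lr=base_lr / 5.,
    end_lr=base_lr)

for epoch in range(10):
    print(epoch, sched.last_lr)  # rate in effect for this epoch
    sched.step()                 # advance once per epoch, as fit() now does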
