Merge pull request #15584 from velconia/imperative_lr_scheduler

Support imperative learning rate scheduler
revert-16555-model_data_cryption_link_all_lib
Qiyang Min 6 years ago committed by GitHub
commit d8d73ff3db
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -32,6 +32,9 @@ from .profiler import *
from . import checkpoint
from .checkpoint import *
from . import learning_rate_scheduler
from .learning_rate_scheduler import *
__all__ = []
__all__ += layers.__all__
__all__ += base.__all__
@ -39,3 +42,4 @@ __all__ += nn.__all__
__all__ += tracer.__all__
__all__ += profiler.__all__
__all__ += checkpoint.__all__
__all__ += learning_rate_scheduler.__all__

@ -0,0 +1,224 @@
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import math
from .. import unique_name
__all__ = [
'NoamDecay', 'PiecewiseDecay', 'NaturalExpDecay', 'ExponentialDecay',
'InverseTimeDecay', 'PolynomialDecay', 'CosineDecay'
]
class LearningRateDecay(object):
"""
Base class of learning rate decay
"""
def __init__(self, begin=0, step=1, dtype='float32'):
self.step_num = begin
self.step_size = step
self.dtype = dtype
def __call__(self):
lr = self.step()
if isinstance(lr, float):
lr = self.create_lr_var(lr)
self.step_num += self.step_size
return lr
def create_lr_var(self, lr):
from .. import layers
lr = layers.create_global_var(
name=unique_name.generate("learning_rate"),
shape=[1],
value=float(lr),
dtype=self.dtype,
persistable=True)
return lr
def step(self):
raise NotImplementedError()
class PiecewiseDecay(LearningRateDecay):
def __init__(self, boundaries, values, begin, step=1, dtype='float32'):
super(PiecewiseDecay, self).__init__(begin, step, dtype)
self.boundaries = boundaries
self.values = values
self.vars = []
for value in values:
self.vars.append(self.create_lr_var(value))
def step(self):
for i in range(len(self.boundaries)):
if self.step_num < self.boundaries[i]:
return self.vars[i]
return self.vars[len(self.values) - 1]
class NaturalExpDecay(LearningRateDecay):
def __init__(self,
learning_rate,
decay_steps,
decay_rate,
staircase=False,
begin=0,
step=1,
dtype='float32'):
super(NaturalExpDecay, self).__init__(begin, step, dtype)
self.learning_rate = learning_rate
self.decay_steps = decay_steps
self.decay_rate = decay_rate
self.staircase = staircase
def step(self):
from .. import layers
div_res = self.create_lr_var(self.step_num / self.decay_steps)
if self.staircase:
div_res = layers.floor(div_res)
decayed_lr = self.learning_rate * layers.exp(-1 * self.decay_rate *
div_res)
return decayed_lr
class ExponentialDecay(LearningRateDecay):
def __init__(self,
learning_rate,
decay_steps,
decay_rate,
staircase=False,
begin=0,
step=1,
dtype='float32'):
super(ExponentialDecay, self).__init__(begin, step, dtype)
self.learning_rate = learning_rate
self.decay_steps = decay_steps
self.decay_rate = decay_rate
self.staircase = staircase
def step(self):
from .. import layers
div_res = self.create_lr_var(self.step_num / self.decay_steps)
if self.staircase:
div_res = layers.floor(div_res)
decayed_lr = self.learning_rate * (self.decay_rate**div_res)
return decayed_lr
class InverseTimeDecay(LearningRateDecay):
def __init__(self,
learning_rate,
decay_steps,
decay_rate,
staircase=False,
begin=0,
step=1,
dtype='float32'):
super(InverseTimeDecay, self).__init__(begin, step, dtype)
self.learning_rate = learning_rate
self.decay_steps = decay_steps
self.decay_rate = decay_rate
self.staircase = staircase
def step(self):
from .. import layers
div_res = self.create_lr_var(self.step_num / self.decay_steps)
if self.staircase:
div_res = layers.floor(div_res)
decayed_lr = self.learning_rate / (1 + self.decay_rate * div_res)
return decayed_lr
class PolynomialDecay(LearningRateDecay):
def __init__(self,
learning_rate,
decay_steps,
end_learning_rate=0.0001,
power=1.0,
cycle=False,
begin=0,
step=1,
dtype='float32'):
super(PolynomialDecay, self).__init__(begin, step, dtype)
self.learning_rate = learning_rate
self.decay_steps = decay_steps
self.end_learning_rate = end_learning_rate
self.power = power
self.cycle = cycle
def step(self):
from .. import layers
tmp_step_num = self.step_num
tmp_decay_steps = self.decay_steps
if self.cycle:
div_res = layers.ceil(
self.create_lr_var(tmp_step_num / float(self.decay_steps)))
if tmp_step_num == 0:
div_res = self.create_lr_var(1.0)
tmp_decay_steps = self.decay_steps * div_res
else:
tmp_step_num = self.create_lr_var(tmp_step_num
if tmp_step_num < self.decay_steps
else self.decay_steps)
decayed_lr = (self.learning_rate - self.end_learning_rate) * \
((1 - tmp_step_num / tmp_decay_steps) ** self.power) + self.end_learning_rate
return decayed_lr
class CosineDecay(LearningRateDecay):
def __init__(self,
learning_rate,
step_each_epoch,
epochs,
begin=0,
step=1,
dtype='float32'):
super(CosineDecay, self).__init__(begin, step, dtype)
self.learning_rate = learning_rate
self.step_each_epoch = step_each_epoch
self.epochs = epochs
def step(self):
from .. import layers
cur_epoch = layers.floor(
self.create_lr_var(self.step_num / self.step_each_epoch))
decayed_lr = self.learning_rate * 0.5 * (
layers.cos(cur_epoch * math.pi / self.epochs) + 1)
return decayed_lr
class NoamDecay(LearningRateDecay):
def __init__(self, d_model, warmup_steps, begin=1, step=1, dtype='float32'):
super(NoamDecay, self).__init__(begin, step, dtype)
self.d_model = d_model
self.warmup_steps = warmup_steps
def step(self):
from .. import layers
a = self.create_lr_var(self.step_num**-0.5)
b = self.create_lr_var((self.warmup_steps**-1.5) * self.step_num)
lr_value = (self.d_model**-0.5) * layers.elementwise_min(a, b)
return lr_value

File diff suppressed because it is too large Load Diff

@ -30,6 +30,8 @@ from .initializer import Constant
from .layer_helper import LayerHelper
from .layers import ops
from .regularizer import append_regularization_ops
from .dygraph import base as imperative_base
from .dygraph.learning_rate_scheduler import LearningRateDecay
from paddle.fluid import core
from paddle.fluid.layers import tensor
from functools import reduce
@ -53,9 +55,19 @@ class Optimizer(object):
"""
def __init__(self, learning_rate, regularization=None, name=None):
if not isinstance(learning_rate, float) and \
not isinstance(learning_rate, framework.Variable):
raise TypeError("learning rate should be float or Variable")
if framework._in_dygraph_mode():
if not isinstance(learning_rate, float) and \
not isinstance(learning_rate, LearningRateDecay):
raise TypeError(
"learning rate should be float or LearningRateDecay, got %s here"
% type(learning_rate))
else:
if not isinstance(learning_rate, float) and \
not isinstance(learning_rate, framework.Variable):
raise TypeError(
"learning rate should be float or Variable, got %s here" %
type(learning_rate))
self._name = name
self.regularization = regularization
self._learning_rate = learning_rate
@ -79,24 +91,49 @@ class Optimizer(object):
return self._opti_name_list
def _create_global_learning_rate(self):
lr = self._global_learning_rate()
if imperative_base.enabled():
# create learning rate Variable
if isinstance(self._learning_rate, float):
lr = self._global_learning_rate()
if isinstance(lr, framework.Variable):
return
else:
if not isinstance(self._learning_rate, float):
if isinstance(lr, framework.Variable):
return
else:
self._learning_rate_map[framework.default_main_program(
)] = layers.create_global_var(
name=unique_name.generate("learning_rate"),
shape=[1],
value=float(self._learning_rate),
dtype='float32' if self._dtype is None else self._dtype,
persistable=True)
# get learning rate Variable from LearningRateDecay
elif isinstance(self._learning_rate, LearningRateDecay):
self._learning_rate_map[framework.default_main_program(
)] = self._learning_rate()
else:
raise TypeError(
"learning rate variable is create outside optimizer,"
"can not create new learning rate variable for new program")
"optimizer's learning rate must be float or LearningRateDecay"
)
else:
lr = self._global_learning_rate()
# create learning rate in the current main program
self._learning_rate_map[framework.default_main_program(
)] = layers.create_global_var(
name=unique_name.generate("learning_rate"),
shape=[1],
value=float(self._learning_rate),
dtype='float32' if self._dtype is None else self._dtype,
persistable=True)
if isinstance(lr, framework.Variable):
return
else:
if not isinstance(self._learning_rate, float):
raise TypeError(
"learning rate variable is create outside optimizer,"
"can not create new learning rate variable for new program"
)
# create learning rate in the current main program
self._learning_rate_map[framework.default_main_program(
)] = layers.create_global_var(
name=unique_name.generate("learning_rate"),
shape=[1],
value=float(self._learning_rate),
dtype='float32' if self._dtype is None else self._dtype,
persistable=True)
def _global_learning_rate(self, program=None):
"""
@ -605,10 +642,10 @@ class DGCMomentumOptimizer(MomentumOptimizer):
DGC also uses momentum factor masking and warmup training to overcome the staleness problem caused by reduced communication.
This optimizer will do two things:
1. Compress the gradient by get TopK import value from tensor \
and use it for allreduce to reduce network bandwidth.
2. Call momentum to optimize on the cost.
Args:

@ -78,7 +78,7 @@ list(REMOVE_ITEM TEST_OPS test_image_classification_resnet)
list(REMOVE_ITEM TEST_OPS test_bilinear_interp_op)
list(REMOVE_ITEM TEST_OPS test_nearest_interp_op)
list(REMOVE_ITEM TEST_OPS test_imperative_resnet)
list(REMOVE_ITEM TEST_OPS test_imperative_optimizer)
list(REMOVE_ITEM TEST_OPS test_imperative_mnist)
list(REMOVE_ITEM TEST_OPS test_ir_memory_optimize_transformer)
foreach(TEST_OP ${TEST_OPS})
py_test_modules(${TEST_OP} MODULES ${TEST_OP})
@ -89,7 +89,7 @@ py_test_modules(test_bilinear_interp_op MODULES test_bilinear_interp_op SERIAL)
py_test_modules(test_nearest_interp_op MODULES test_nearest_interp_op SERIAL)
py_test_modules(test_imperative_resnet MODULES test_imperative_resnet ENVS
FLAGS_cudnn_deterministic=1)
py_test_modules(test_imperative_optimizer MODULES test_imperative_optimizer ENVS
py_test_modules(test_imperative_mnist MODULES test_imperative_mnist ENVS
FLAGS_cudnn_deterministic=1)
if(WITH_DISTRIBUTE)
py_test_modules(test_dist_train MODULES test_dist_train SERIAL)

@ -0,0 +1,217 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import contextlib
import unittest
import numpy as np
import six
import paddle
import paddle.fluid as fluid
from paddle.fluid import core
from paddle.fluid.optimizer import SGDOptimizer
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, FC
from paddle.fluid.dygraph.base import to_variable
from test_imperative_base import new_program_scope
class SimpleImgConvPool(fluid.dygraph.Layer):
def __init__(self,
name_scope,
num_channels,
num_filters,
filter_size,
pool_size,
pool_stride,
pool_padding=0,
pool_type='max',
global_pooling=False,
conv_stride=1,
conv_padding=0,
conv_dilation=1,
conv_groups=1,
act=None,
use_cudnn=False,
param_attr=None,
bias_attr=None):
super(SimpleImgConvPool, self).__init__(name_scope)
self._conv2d = Conv2D(
self.full_name(),
num_channels=num_channels,
num_filters=num_filters,
filter_size=filter_size,
stride=conv_stride,
padding=conv_padding,
dilation=conv_dilation,
groups=conv_groups,
param_attr=None,
bias_attr=None,
use_cudnn=use_cudnn)
self._pool2d = Pool2D(
self.full_name(),
pool_size=pool_size,
pool_type=pool_type,
pool_stride=pool_stride,
pool_padding=pool_padding,
global_pooling=global_pooling,
use_cudnn=use_cudnn)
def forward(self, inputs):
x = self._conv2d(inputs)
x = self._pool2d(x)
return x
class MNIST(fluid.dygraph.Layer):
def __init__(self, name_scope):
super(MNIST, self).__init__(name_scope)
self._simple_img_conv_pool_1 = SimpleImgConvPool(
self.full_name(), 1, 20, 5, 2, 2, act="relu")
self._simple_img_conv_pool_2 = SimpleImgConvPool(
self.full_name(), 20, 50, 5, 2, 2, act="relu")
pool_2_shape = 50 * 4 * 4
SIZE = 10
scale = (2.0 / (pool_2_shape**2 * SIZE))**0.5
self._fc = FC(self.full_name(),
10,
param_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.NormalInitializer(
loc=0.0, scale=scale)),
act="softmax")
def forward(self, inputs):
x = self._simple_img_conv_pool_1(inputs)
x = self._simple_img_conv_pool_2(x)
x = self._fc(x)
return x
class TestImperativeMnist(unittest.TestCase):
def test_mnist_float32(self):
seed = 90
epoch_num = 1
with fluid.dygraph.guard():
fluid.default_startup_program().random_seed = seed
fluid.default_main_program().random_seed = seed
mnist = MNIST("mnist")
sgd = SGDOptimizer(learning_rate=1e-3)
train_reader = paddle.batch(
paddle.dataset.mnist.train(), batch_size=128, drop_last=True)
dy_param_init_value = {}
for epoch in range(epoch_num):
for batch_id, data in enumerate(train_reader()):
dy_x_data = np.array(
[x[0].reshape(1, 28, 28)
for x in data]).astype('float32')
y_data = np.array(
[x[1] for x in data]).astype('int64').reshape(128, 1)
img = to_variable(dy_x_data)
label = to_variable(y_data)
label._stop_gradient = True
cost = mnist(img)
loss = fluid.layers.cross_entropy(cost, label)
avg_loss = fluid.layers.mean(loss)
dy_out = avg_loss._numpy()
if epoch == 0 and batch_id == 0:
for param in mnist.parameters():
dy_param_init_value[param.name] = param._numpy()
avg_loss._backward()
sgd.minimize(avg_loss)
mnist.clear_gradients()
dy_param_value = {}
for param in mnist.parameters():
dy_param_value[param.name] = param._numpy()
with new_program_scope():
fluid.default_startup_program().random_seed = seed
fluid.default_main_program().random_seed = seed
exe = fluid.Executor(fluid.CPUPlace(
) if not core.is_compiled_with_cuda() else fluid.CUDAPlace(0))
mnist = MNIST("mnist")
sgd = SGDOptimizer(learning_rate=1e-3)
train_reader = paddle.batch(
paddle.dataset.mnist.train(), batch_size=128, drop_last=True)
img = fluid.layers.data(
name='pixel', shape=[1, 28, 28], dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
cost = mnist(img)
loss = fluid.layers.cross_entropy(cost, label)
avg_loss = fluid.layers.mean(loss)
sgd.minimize(avg_loss)
# initialize params and fetch them
static_param_init_value = {}
static_param_name_list = []
for param in mnist.parameters():
static_param_name_list.append(param.name)
out = exe.run(fluid.default_startup_program(),
fetch_list=static_param_name_list)
for i in range(len(static_param_name_list)):
static_param_init_value[static_param_name_list[i]] = out[i]
for epoch in range(epoch_num):
for batch_id, data in enumerate(train_reader()):
static_x_data = np.array(
[x[0].reshape(1, 28, 28)
for x in data]).astype('float32')
y_data = np.array(
[x[1] for x in data]).astype('int64').reshape([128, 1])
fetch_list = [avg_loss.name]
fetch_list.extend(static_param_name_list)
out = exe.run(
fluid.default_main_program(),
feed={"pixel": static_x_data,
"label": y_data},
fetch_list=fetch_list)
static_param_value = {}
static_out = out[0]
for i in range(1, len(out)):
static_param_value[static_param_name_list[i - 1]] = out[
i]
self.assertTrue(np.allclose(dy_x_data.all(), static_x_data.all()))
for key, value in six.iteritems(static_param_init_value):
self.assertTrue(np.allclose(value, dy_param_init_value[key]))
self.assertTrue(np.allclose(static_out, dy_out))
for key, value in six.iteritems(static_param_value):
self.assertTrue(np.allclose(value, dy_param_value[key], atol=1e-5))
if __name__ == '__main__':
unittest.main()
Loading…
Cancel
Save