@@ -43,7 +43,7 @@ __all__ = [
     'AdamaxOptimizer', 'DecayedAdagradOptimizer', 'RMSPropOptimizer',
     'FtrlOptimizer', 'Adadelta', 'ModelAverage', 'LarsMomentum',
     'LarsMomentumOptimizer', 'DGCMomentumOptimizer', 'LambOptimizer',
-    'ExponentialMovingAverage', 'PipelineOptimizer'
+    'ExponentialMovingAverage', 'PipelineOptimizer', 'LookaheadOptimizer'
 ]
@@ -2953,3 +2953,156 @@ class PipelineOptimizer(object):
"sync_steps": self._sync_steps,
|
|
|
|
|
"param_need_sync": param_need_sync
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
+class LookaheadOptimizer(object):
+    """
+    This implements the Lookahead optimizer from the
+    paper: https://arxiv.org/abs/1907.08610.
+
+    Lookahead keeps two sets of params: the fast_params and
+    the slow_params. The inner_optimizer updates the fast_params
+    at every training step, while Lookahead updates both the
+    slow_params and the fast_params every k training steps as follows:
+
+    .. math::
+
+        slow\_param_t &= slow\_param_{t-1} + \\alpha * (fast\_param_{t-1} - slow\_param_{t-1})
+
+        fast\_param_t &= slow\_param_t
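+
+    For example, with alpha = 0.5 and k = 5, every fifth step the slow params
+    move halfway toward the fast params and the fast params are reset to this
+    interpolated value; on all other steps only the fast params change.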
+
+    Args:
+        inner_optimizer (Optimizer): The optimizer that updates the fast params step by step.
+        alpha (float): The learning rate of Lookahead.
+        k (int): The slow params are updated every k steps.
+
+    Examples:
+        .. code-block:: python
+
+            import paddle
+            import paddle.fluid as fluid
+            import numpy as np
+
+            x = fluid.layers.data(name='x', shape=[2], dtype='float32')
+            label = fluid.layers.data(name="label", shape=[1], dtype="int64")
+            y = fluid.layers.fc(input=[x], size=2, act="softmax")
+            loss = fluid.layers.cross_entropy(input=y, label=label)
+            loss = fluid.layers.mean(x=loss)
+            sgd = fluid.optimizer.SGD(learning_rate=0.01)
+            optimizer = fluid.optimizer.LookaheadOptimizer(
+                sgd, alpha=0.5, k=5)
+            optimizer.minimize(loss)
+            main_program = fluid.default_main_program()
+            place = fluid.CPUPlace()
+            exe = fluid.Executor(place)
+            exe.run(fluid.default_startup_program())
+
+            feeder = fluid.DataFeeder(feed_list=[x, label], place=place)
+
+            step = 0
+            while step < 10:
+                step += 1
+                # random batch, just for demonstration
+                batch_data = [(np.random.random([2]).astype('float32'),
+                               np.random.randint(2, size=[1]).astype('int64'))]
+                exe.run(fluid.default_main_program(),
+                        feed=feeder.feed(batch_data))
+
"""
|
|
|
|
|
|
|
|
|
|
+    def __init__(self, inner_optimizer, alpha=0.5, k=5):
+
+        assert (inner_optimizer is not None), "inner optimizer cannot be None"
+        assert (
+            0.0 <= alpha <= 1.0
+        ), "alpha should be greater than or equal to 0.0 and less than or equal to 1.0"
+        assert (isinstance(k, int) and k > 0), "k should be a positive integer"
+
+        self.inner_optimizer = inner_optimizer
+        self.alpha = alpha
+        self.k = k
+        self.type = "lookahead"
+
+    def minimize(self, loss, startup_program=None):
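+        """
+        Apply the inner optimizer to ``loss``, create a persistent slow copy
+        of every parameter, and append the ops that perform the lookahead
+        update of the slow and fast params every ``k`` steps.
+        Returns the result of the inner optimizer's minimize call.
+        """
+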
+        # Apply inner optimizer to the main_program
+        mini_out = self.inner_optimizer.minimize(
+            loss, startup_program=startup_program)
+
+        # Get startup_program and main_program
+        if startup_program is None:
+            startup_program = default_startup_program()
+        main_block = loss.block
+
+        # Create a persistent slow copy of every parameter in the main_program
+        params = [param.name for param in main_block.all_parameters()]
+        param_to_slow = {}
+        for param in params:
+            fast_var = main_block.var(param)
+            assert (fast_var is not None)
+            slow_var = main_block.create_var(
+                name=param + "@SLOW",
+                shape=fast_var.shape,
+                dtype=fast_var.dtype,
+                persistable=True)
+            param_to_slow[param] = slow_var
+
+        # Create the slow-param variables in the startup_program and
+        # initialize them from the corresponding fast params
+        startup_block = startup_program.global_block()
+        for param in params:
+            fast_var = startup_block.var(param)
+            assert (fast_var is not None)
+            slow_var = startup_block.create_var(
+                name=param + "@SLOW",
+                shape=fast_var.shape,
+                dtype=fast_var.dtype,
+                persistable=True)
+
+            startup_block.append_op(
+                type="assign",
+                inputs={"X": fast_var},
+                outputs={"Out": slow_var})
+
+        # Add Var k to main prog and startup prog
+        k = layers.create_global_var(
+            name="lookahead_k",
+            shape=[1],
+            value=int(self.k),
+            dtype='int32',
+            persistable=True)
+
+        # Add Var alpha to main prog and startup prog
+        alpha = layers.create_global_var(
+            name="lookahead_alpha",
+            shape=[1],
+            value=float(self.alpha),
+            dtype='float32',
+            persistable=True)
+
+        # Add Var step
+        step = layers.create_global_var(
+            name="lookahead_step",
+            shape=[1],
+            value=int(0),
+            dtype='int32',
+            persistable=True)
+        layers.increment(x=step, value=1.0, in_place=True)
+
+        # Lookahead update: every k steps, blend the slow params toward the
+        # fast params and reset the fast params to the blended value
+        zero_var = layers.fill_constant(shape=[1], dtype='float32', value=0.0)
+
+        one_var = layers.fill_constant(shape=[1], dtype='float32', value=1.0)
+
+        mod = layers.elementwise_mod(step, k)
+        with layers.control_flow.Switch() as switch:
+            with switch.case(mod == zero_var):
+                for param_name in params:
+                    fast_var = main_block.var(param_name)
+                    slow_var = param_to_slow[param_name]
+                    tmp_var = layers.elementwise_add(
+                        layers.elementwise_mul(fast_var, alpha),
+                        layers.elementwise_mul(
+                            slow_var, layers.elementwise_sub(one_var, alpha)))
+                    layers.assign(input=tmp_var, output=slow_var)
+                    layers.assign(input=tmp_var, output=fast_var)
+            with switch.default():
+                pass
+        return mini_out
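
Note (illustration only, not part of the patch): a minimal plain-Python sketch of
the update that the Switch block above performs, using NumPy arrays as stand-ins
for the fast/slow parameter variables and a dummy inner-optimizer step:

    import numpy as np

    alpha, k = 0.5, 5
    fast = np.array([1.0, 2.0])   # params updated by the inner optimizer each step
    slow = fast.copy()            # persistent slow copy, initialized from the params

    for step in range(1, 21):
        fast = fast - 0.01 * fast                # stand-in for one inner-optimizer step
        if step % k == 0:                        # same gate as `mod == zero_var` above
            slow = slow + alpha * (fast - slow)  # slow_t = slow_{t-1} + alpha * (fast_t - slow_{t-1})
            fast = slow.copy()                   # fast params are reset to the slow params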