@@ -14,6 +14,7 @@
from __future__ import print_function

import numpy as np
from collections import defaultdict
from functools import reduce
@@ -2175,19 +2176,41 @@ class ExponentialMovingAverage(object):

    .. math::

        \\text{EMA}_t = \\text{decay} * \\text{EMA}_{t-1} + (1 - \\text{decay}) * \\theta_t

        \\text{EMA}_0 & = 0

    The average results will be saved in temporary variables which can be
    applied to parameters of current model by calling `apply()` method. And
    the `restore()` method is used to restore the parameters.

        \\text{EMA}_t & = \\text{decay} * \\text{EMA}_{t-1} + (1 - \\text{decay}) * \\theta_t
    The average results will be saved in temporary variables which are created
    and maintained by the object, and can be applied to the parameters of the
    current model by calling the **apply()** method. The **restore()** method
    is used to restore the parameters.
    **Bias correction**. All EMAs are initialized to :math:`0` and hence they
    will be zero-biased. This can be corrected by dividing by a factor
    :math:`(1 - \\text{decay}^t)`, i.e., the actual EMAs applied to the
    parameters when calling the **apply()** method would be

    .. math::

        \\widehat{\\text{EMA}}_t = \\frac{\\text{EMA}_t}{1 - \\text{decay}^t}
    **Decay rate scheduling**. A large decay rate very close to 1 would make
    the averages move very slowly, so a better strategy is to use a relatively
    small decay rate in the very beginning. The argument **thres_steps** allows
    users to pass a Variable to schedule the decay rate; in this case, the
    actual decay rate becomes

    .. math::

        \\min(\\text{decay}, \\frac{1 + \\text{thres_steps}}{10 + \\text{thres_steps}})

    Usually **thres_steps** can be the global training steps.
    Args:
        decay (float|Variable): The exponential decay rate. Can be scheduled like
            learning rate.
        zero_init (bool): Whether using zero to initialize EMA Variable. If set to
            `True`, :math:`\\text{EMA}_0 = 0.0` else :math:`\\text{EMA}_0 = \\theta_0`.
        decay (float): The exponential decay rate, usually close to 1, such as
            0.999, 0.9999, ... .
        thres_steps (Variable|None): If not `None`, schedule the decay rate.
        name (str|None): An optional name prefix.
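As a quick illustration of the update and bias correction described above, here is a minimal NumPy sketch (an editor's example, not the Paddle implementation; the toy parameter theta and its random "training" updates are assumptions):

import numpy as np

decay = 0.999
theta = np.random.rand(100, 10)    # stand-in for a model parameter
ema = np.zeros_like(theta)         # EMA_0 = 0, hence the zero bias

for t in range(1, 201):
    theta = theta + 0.01 * np.random.randn(*theta.shape)   # pretend optimizer step
    ema = decay * ema + (1.0 - decay) * theta               # EMA_t
    ema_hat = ema / (1.0 - decay ** t)                      # bias-corrected EMA
    if t == 1:
        # EMA_1 = (1 - decay) * theta_1, so the corrected value recovers theta_1
        assert np.allclose(ema_hat, theta)

At t = 1 the corrected average equals the parameter itself regardless of the decay value, which is why zero initialization needs no special handling once the correction is applied.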
@@ -2204,25 +2227,35 @@ class ExponentialMovingAverage(object):

            optimizer = fluid.optimizer.Adam(learning_rate=0.001)
            optimizer.minimize(cost)

            ema = fluid.optimizer.ExponentialMovingAverage(0.99)
            global_steps = fluid.layers.learning_rate_scheduler._decay_step_counter()
            ema = fluid.optimizer.ExponentialMovingAverage(0.999, thres_steps=global_steps)

            # pseudo code
            for pass_id in range(args.pass_num):
                for data in train_reader():
                    exe.run(fluid.default_main_program()...)

                # usage 1
                with ema.apply(exe):
                    for data in test_reader():
                        exe.run(inference_program...)

                # usage 2
                with ema.apply(exe, need_restore=False):
                    for data in test_reader():
                        exe.run(inference_program...)
                ...
                ema.restore(exe)
    """
    def __init__(self, decay=0.999, zero_init=False, name=None):
    def __init__(self, decay=0.999, thres_steps=None, name=None):
        self._decay = decay
        self._zero_init = zero_init
        self._thres_steps = thres_steps
        self._name = name if name is not None else ''
        self._decay_var = self._get_ema_decay()

        self.params_tmps = []
        for param in framework.default_main_program().global_block(
        ).all_parameters():
        for param in default_main_program().global_block().all_parameters():
            if param.do_model_average != False:
                tmp = param.block.create_var(
                    name=unique_name.generate(".".join(
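The self._decay_var = self._get_ema_decay() call above builds the scheduled decay rate described in the docstring. In plain Python the schedule amounts to the following (a sketch assuming thres_steps counts global training steps; the helper name ema_decay_at is illustrative):

def ema_decay_at(thres_steps, decay=0.999):
    # min(decay, (1 + thres_steps) / (10 + thres_steps)) from the docstring:
    # a small decay early in training, the configured decay afterwards
    warmup = (1.0 + thres_steps) / (10.0 + thres_steps)
    return min(decay, warmup)

# e.g. ema_decay_at(0) == 0.1, ema_decay_at(90) == 0.91, ema_decay_at(10000) == 0.999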
@@ -2232,22 +2265,23 @@ class ExponentialMovingAverage(object):
                    stop_gradient=True)
                self.params_tmps.append((param, tmp))

        startup_block = default_startup_program().global_block()
        ema_vars = {}
        for param, tmp in self.params_tmps:
            with param.block.program._optimized_guard(
                [param, tmp]), name_scope('moving_average'):
                ema_vars[param.name] = self._append_ema_ops(startup_block,
                                                            param)
                ema_vars[param.name] = self._append_ema_ops(param)

        self.apply_program = Program()
        block = self.apply_program.global_block()
        with program_guard(main_program=self.apply_program):
            decay_pow = self._get_decay_pow(block)
            for param, tmp in self.params_tmps:
                param = block._clone_variable(param)
                tmp = block._clone_variable(tmp)
                ema = block._clone_variable(ema_vars[param.name])
                layers.assign(input=param, output=tmp)
                # bias correction
                ema = ema / (1.0 - decay_pow)
                layers.assign(input=ema, output=param)

        self.restore_program = Program()
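The apply_program built above saves each parameter into its temporary variable and then overwrites the parameter with the bias-corrected EMA; the restore_program (whose body appears in the next hunk) copies the saved values back. A rough pure-Python sketch of that swap (the dict-based params/ema/tmp state and the decay_pow argument are assumptions for illustration, not Paddle's data structures):

from contextlib import contextmanager

@contextmanager
def apply_ema(params, ema, tmp, decay_pow, need_restore=True):
    # save current parameters and load the bias-corrected averages
    for name in params:
        tmp[name] = params[name]
        params[name] = ema[name] / (1.0 - decay_pow)   # bias correction
    try:
        yield
    finally:
        if need_restore:
            # put the training-time parameters back
            for name in params:
                params[name] = tmp[name]

This mirrors the "with ema.apply(exe)" / "ema.restore(exe)" usage shown in the docstring example.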
@@ -2258,25 +2292,43 @@ class ExponentialMovingAverage(object):
                param = block._clone_variable(param)
                layers.assign(input=tmp, output=param)

    def _append_ema_ops(self, startup_block, param):
    def _get_ema_decay(self):
        with default_main_program()._lr_schedule_guard():
            decay_var = layers.tensor.create_global_var(
                shape=[1],
                value=self._decay,
                dtype='float32',
                persistable=True,
                name="scheduled_ema_decay_rate")
            if self._thres_steps is not None:
                decay_t = (self._thres_steps + 1.0) / (self._thres_steps + 10.0)
                with layers.control_flow.Switch() as switch:
                    with switch.case(decay_t < self._decay):
                        layers.tensor.assign(decay_t, decay_var)
                    with switch.default():
                        layers.tensor.assign(
                            np.array(
                                [self._decay], dtype=np.float32),
                            decay_var)
        return decay_var
    def _get_decay_pow(self, block):
        global_steps = layers.learning_rate_scheduler._decay_step_counter()
        decay_var = block._clone_variable(self._decay_var)
        decay_pow_acc = layers.elementwise_pow(decay_var, global_steps + 1)
        return decay_pow_acc
    def _append_ema_ops(self, param):
        param_ema = layers.create_global_var(
            name=unique_name.generate(self._name + param.name + '_ema'),
            shape=param.shape,
            value=0.0,
            dtype=param.dtype,
            persistable=True)
        # t = 0
        if self._zero_init is not True:
            startup_p_ema = startup_block._clone_variable(param_ema)
            startup_p = startup_block.var(param.name)
            startup_block.append_op(
                type="assign",
                inputs={"X": startup_p},
                outputs={"Out": startup_p_ema})
        # t > 0
        ema_t = param_ema * self._decay - param * (self._decay - 1)
        layers.assign(input=ema_t, output=param_ema)

        ema_t = param_ema * self._decay_var + param * (1 - self._decay_var)
        layers.assign(input=ema_t, output=param_ema)
        return param_ema

    @signature_safe_contextmanager