|
|
@ -3570,8 +3570,10 @@ class ExponentialMovingAverage(object):
|
|
|
|
# bias correction
|
|
|
|
# bias correction
|
|
|
|
with layers.control_flow.Switch() as switch:
|
|
|
|
with layers.control_flow.Switch() as switch:
|
|
|
|
with switch.case(global_step > 0):
|
|
|
|
with switch.case(global_step > 0):
|
|
|
|
layers.assign(output=ema, input=ema / (1.0 - decay_pow))
|
|
|
|
layers.assign(
|
|
|
|
layers.assign(input=ema, output=param)
|
|
|
|
output=param, input=ema / (1.0 - decay_pow))
|
|
|
|
|
|
|
|
with switch.default():
|
|
|
|
|
|
|
|
layers.assign(output=param, input=ema)
|
|
|
|
|
|
|
|
|
|
|
|
self.restore_program = Program()
|
|
|
|
self.restore_program = Program()
|
|
|
|
block = self.restore_program.global_block()
|
|
|
|
block = self.restore_program.global_block()
|
|
|
|