@ -120,9 +120,9 @@ class TestLearningRateDecay(unittest.TestCase):
self . assertAlmostEqual (
python_decayed_lr ,
lr_val [ 0 ] ,
msg = ' Failed fn is {0} , Python result is {1} , Fluid result is {2 }' .
msg = ' Failed lr scheduler is {0} , step {1} , Python result is {2} , Fluid result is {3 }' .
format ( python_decay_fn . __name__ ,
str ( python_decayed_lr) , str ( lr_val [ 0 ] ) ) )
str ( step) , str ( python_decayed_lr) , str ( lr_val [ 0 ] ) ) )
def test_decay ( self ) :
common_kwargs_true = {
@ -164,12 +164,53 @@ class TestLearningRateDecay(unittest.TestCase):
]
for py_decay_fn , fluid_decay_fn , kwargs in decay_fns :
print ( " decay_fn= " + py_decay_fn . __name__ + " kwargs= " + str ( kwargs ) )
print ( " class= " + self . __class__ . __name__ + " decay_fn= " +
py_decay_fn . __name__ + " kwargs= " + str ( kwargs ) )
main_program = framework . Program ( )
startup_program = framework . Program ( )
with framework . program_guard ( main_program , startup_program ) :
self . check_decay ( py_decay_fn , fluid_decay_fn , kwargs )
def linear_lr_warmup ( global_step , warmup_steps , start_lr , end_lr ) :
linear_step = end_lr - start_lr
decayed_lr = start_lr + linear_step * ( global_step / warmup_steps )
return decayed_lr
class TestLinearWamrupLearningRateDecay ( TestLearningRateDecay ) :
def check_decay_with_place ( self , place , python_decay_fn , fluid_decay_fn ,
kwargs ) :
main_prog = fluid . Program ( )
startup_prog = fluid . Program ( )
warmup_steps = 10
start_lr = 1. / 3.
end_lr = 0.1
with fluid . program_guard ( main_prog , startup_prog ) :
decayed_lr = layers . linear_lr_warmup (
fluid_decay_fn ( * * kwargs ) , warmup_steps , start_lr , end_lr )
place = fluid . CPUPlace ( )
exe = fluid . Executor ( place )
exe . run ( startup_prog )
for step in range ( 20 ) :
lr_val , = exe . run ( main_prog , feed = { } , fetch_list = [ decayed_lr ] )
if step < warmup_steps :
python_decayed_lr = linear_lr_warmup (
float ( step ) , warmup_steps , start_lr , end_lr )
else :
python_decayed_lr = python_decay_fn (
global_step = float ( step ) , * * kwargs )
self . assertAlmostEqual (
python_decayed_lr ,
lr_val [ 0 ] ,
msg = ' Test {0} Failed, step {1} , Python result is {2} , Fluid result is {3} ' .
format ( python_decay_fn . __name__ ,
str ( step ) , str ( python_decayed_lr ) , str ( lr_val [ 0 ] ) ) )
if __name__ == ' __main__ ' :
unittest . main ( )