@ -614,7 +614,7 @@ class TestLookaheadOptimizer(unittest.TestCase):
class TestRecomputeOptimizer ( unittest . TestCase ) :
def net ( self ):
def net ( self , return_input = False ):
program = framework . Program ( )
block = program . global_block ( )
mul_x = block . create_parameter (
@ -652,6 +652,8 @@ class TestRecomputeOptimizer(unittest.TestCase):
block . append_op (
type = " mean " , inputs = { " X " : b2_out } , outputs = { " Out " : mean_out } )
if return_input == True :
return mul_x , mul_out , b1_out , b2_out , mean_out
return mul_out , b1_out , b2_out , mean_out
def test_no_checkpoint ( self ) :
@ -723,6 +725,42 @@ class TestRecomputeOptimizer(unittest.TestCase):
" elementwise_add_grad " , " mul_grad " , " sgd " , " sgd " , " sgd "
] )
def test_out_of_order_checkpoint(self):
    """Check the op sequence when checkpoints are given as [b2_out, mul_out].

    The network from ``self.net()`` starts with four forward ops. After
    ``minimize`` the block is expected to hold thirteen ops: the forward
    ops, then the backward section (which contains one additional
    ``elementwise_add`` ahead of the grad ops), then three ``sgd`` ops.
    """
    # NOTE(review): op-type literals are kept exactly as they appear in
    # this view of the file, including the padding inside the quotes.
    forward_types = [" mul ", " elementwise_add ", " elementwise_add ", " mean "]
    after_minimize_types = forward_types + [
        " fill_constant ",
        " mean_grad ",
        " elementwise_add ",
        " elementwise_add_grad ",
        " elementwise_add_grad ",
        " mul_grad ",
        " sgd ",
        " sgd ",
        " sgd ",
    ]

    mul_out, b1_out, b2_out, mean_out = self.net()

    def current_op_types():
        # Re-read the block on every call so ops appended later are seen.
        return [op.type for op in mean_out.block.ops]

    # Forward-only program before the optimizer touches it.
    self.assertEqual(len(mean_out.block.ops), 4)
    self.assertEqual(current_op_types(), forward_types)

    base_sgd = optimizer.SGD(learning_rate=1.0)
    recompute = optimizer.RecomputeOptimizer(base_sgd)
    # Checkpoints are passed with the later variable first on purpose.
    recompute._set_checkpoints([b2_out, mul_out])
    opts, params_grads = recompute.minimize(mean_out)

    self.assertEqual(len(mean_out.block.ops), 13)
    self.assertEqual(current_op_types(), after_minimize_types)
def test_input_as_checkpoints(self):
    """Check the op sequence when the input ``mul_x`` is itself a checkpoint.

    The network is built with ``return_input=True`` so that ``mul_x`` is
    returned alongside the intermediate outputs. With checkpoints
    ``[mul_x, b2_out]`` the block is expected to hold fourteen ops after
    ``minimize``: the four forward ops, a backward section that contains
    an additional ``mul`` and ``elementwise_add`` ahead of the grad ops,
    and three ``sgd`` ops.
    """
    # NOTE(review): op-type literals are kept exactly as they appear in
    # this view of the file, including the padding inside the quotes.
    forward_types = [" mul ", " elementwise_add ", " elementwise_add ", " mean "]
    after_minimize_types = forward_types + [
        " fill_constant ",
        " mean_grad ",
        " mul ",
        " elementwise_add ",
        " elementwise_add_grad ",
        " elementwise_add_grad ",
        " mul_grad ",
        " sgd ",
        " sgd ",
        " sgd ",
    ]

    mul_x, mul_out, b1_out, b2_out, mean_out = self.net(return_input=True)

    def current_op_types():
        # Re-read the block on every call so ops appended later are seen.
        return [op.type for op in mean_out.block.ops]

    # Forward-only program before the optimizer touches it.
    self.assertEqual(len(mean_out.block.ops), 4)
    self.assertEqual(current_op_types(), forward_types)

    base_sgd = optimizer.SGD(learning_rate=1.0)
    recompute = optimizer.RecomputeOptimizer(base_sgd)
    recompute._set_checkpoints([mul_x, b2_out])
    opts, params_grads = recompute.minimize(mean_out)

    self.assertEqual(len(mean_out.block.ops), 14)
    self.assertEqual(current_op_types(), after_minimize_types)
def test_apply_gradients ( self ) :
mul_out , b1_out , b2_out , mean_out = self . net ( )
sgd_optimizer = optimizer . SGD ( learning_rate = 1.0 )