@@ -68,14 +68,16 @@ class Optimizer(object):
                  regularization=None,
                  grad_clip=None,
                  name=None):
+        # Place the import inside the function body to avoid a circular import.
+        from paddle.optimizer.lr_scheduler import _LRScheduler
         self._parameter_list = list(
             parameter_list) if parameter_list is not None else None
         self._name = name
         if framework.in_dygraph_mode():
-            if not isinstance(learning_rate, float) and \
-                    not isinstance(learning_rate, LearningRateDecay):
+            if not isinstance(learning_rate,
+                              (float, LearningRateDecay, _LRScheduler)):
                 raise TypeError(
-                    "learning rate should be float or LearningRateDecay, got %s here"
+                    "learning rate should be float or _LRScheduler, got %s here"
                     % type(learning_rate))
             if self._parameter_list is None:
                 raise AttributeError(
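
For quick reference, the two chained not-isinstance checks above collapse into a single tuple-based check that also admits the new _LRScheduler. A minimal, self-contained sketch of that validation pattern, using stand-in classes rather than Paddle's real LearningRateDecay and _LRScheduler:

    # Stand-in classes; the real ones live in paddle.fluid.dygraph and
    # paddle.optimizer.lr_scheduler respectively.
    class LearningRateDecay:
        pass

    class _LRScheduler:
        pass

    def check_dygraph_lr(learning_rate):
        # Accept a constant float, the legacy LearningRateDecay, or the new
        # _LRScheduler; anything else raises the same TypeError as above.
        if not isinstance(learning_rate, (float, LearningRateDecay, _LRScheduler)):
            raise TypeError(
                "learning rate should be float or _LRScheduler, got %s here"
                % type(learning_rate))

    check_dygraph_lr(0.01)            # ok: plain float
    check_dygraph_lr(_LRScheduler())  # ok: scheduler instance
    # check_dygraph_lr("0.01")        # would raise TypeError
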
@@ -90,11 +92,11 @@ class Optimizer(object):
                             % regularization.__str__())
                         break
         else:
-            if not isinstance(learning_rate, float) and \
-                    not isinstance(learning_rate, framework.Variable):
+            if not isinstance(learning_rate,
+                              (float, framework.Variable, _LRScheduler)):
                 raise TypeError(
-                    "learning rate should be float or Variable, got %s here" %
-                    type(learning_rate))
+                    "learning rate should be float or _LRScheduler, got %s here"
+                    % type(learning_rate))
 
         if grad_clip is not None:
             if not isinstance(grad_clip, GradientClipBase):
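
The static-graph branch gets the same treatment, keeping framework.Variable in the accepted tuple. Assuming the existing fluid optimizers keep their learning_rate keyword, usage after this change could look like the sketch below; the scheduler value is a placeholder for any concrete _LRScheduler subclass:

    import paddle.fluid as fluid

    # A plain float is still accepted, exactly as before this patch.
    adam = fluid.optimizer.Adam(learning_rate=0.001)

    # With this patch an _LRScheduler instance is accepted as well; the
    # scheduler below is a placeholder for any concrete subclass defined in
    # paddle.optimizer.lr_scheduler.
    # scheduler = ...
    # adam = fluid.optimizer.Adam(learning_rate=scheduler)
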
@@ -144,11 +146,15 @@ class Optimizer(object):
                 state_dict = adam.state_dict()
 
         '''
+        from paddle.optimizer.lr_scheduler import _LRScheduler
         state_dict = {}
         for k, v in self._accumulators.items():
             for para_name, var_tmp in v.items():
                 state_dict[var_tmp.name] = var_tmp
         # global step if use lr decay
+        if isinstance(self._learning_rate, _LRScheduler):
+            state_dict["LR_Scheduler"] = self._learning_rate.state_dict()
+            return state_dict
         if isinstance(self._learning_rate, LearningRateDecay):
             state_dict["LR_Scheduler"] = self._learning_rate.state_dict()
@@ -192,6 +198,9 @@ class Optimizer(object):
                 adam.set_dict(opti_state_dict)
 
         '''
+        from paddle.optimizer.lr_scheduler import _LRScheduler
+        if isinstance(self._learning_rate, _LRScheduler):
+            self._learning_rate.set_dict(state_dict["LR_Scheduler"])
 
         if isinstance(self._learning_rate, LearningRateDecay):
             self._learning_rate.set_dict(state_dict["LR_Scheduler"])
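
A short round-trip sketch tying state_dict and set_dict together, modeled on the surrounding docstring examples (fluid.dygraph.Embedding and fluid.optimizer.Adam as used there); a float learning rate is shown, so no "LR_Scheduler" entry is produced here:

    import paddle.fluid as fluid

    with fluid.dygraph.guard():
        emb = fluid.dygraph.Embedding([10, 10])
        adam = fluid.optimizer.Adam(learning_rate=0.001,
                                    parameter_list=emb.parameters())
        # Save: the dict carries an "LR_Scheduler" entry whenever the
        # optimizer was constructed with an _LRScheduler (not the case here).
        opti_state_dict = adam.state_dict()
        # Restore: with this patch, set_dict() hands any "LR_Scheduler" entry
        # back to the scheduler before restoring the accumulator values.
        adam.set_dict(opti_state_dict)
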
@@ -252,6 +261,30 @@ class Optimizer(object):
         return self._opti_name_list
 
     def _create_global_learning_rate(self):
+        from paddle.optimizer.lr_scheduler import _LRScheduler
+        if isinstance(self._learning_rate, _LRScheduler):
+            lr_var = self._global_learning_rate()
+            # only create global lr_var once
+            if not isinstance(lr_var, framework.Variable):
+                lr_name = unique_name.generate('learning_rate')
+                self._learning_rate._var_name = lr_name
+                lr_var = self.helper.create_global_variable(
+                    name=lr_name,
+                    shape=[1],
+                    persistable=True,
+                    stop_gradient=True,
+                    dtype='float32' if self._dtype is None else self._dtype)
+                main_prog = framework.default_main_program()
+                main_prog.lr_sheduler = self._learning_rate
+                main_prog.lr_var = lr_var
+                self._learning_rate_map[framework.default_main_program(
+                )] = lr_var
+
+            lr_value = float(self._learning_rate())
+            self.helper.set_variable_initializer(
+                lr_var, initializer=Constant(value=lr_value))
+            return
+
         if imperative_base.enabled():
             # create learning rate Variable
             if isinstance(self._learning_rate, float):
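
The new branch creates the global learning-rate Variable once (persistable float32 of shape [1]), attaches the scheduler and the Variable to the default main program as lr_sheduler / lr_var (spelling as in the patch), initializes the Variable from the scheduler's current value, and returns early. A stand-alone sketch of that create-once-then-initialize pattern, with a dict standing in for the Program and a fake scheduler standing in for _LRScheduler:

    # Stand-ins only: a dict plays the role of the Program, a small class
    # plays the role of _LRScheduler; nothing here is Paddle API.
    class FakeScheduler:
        def __init__(self, base_lr):
            self.last_lr = base_lr

        def __call__(self):
            # current learning rate as a plain float
            return self.last_lr

    def create_global_learning_rate(program, scheduler, registry):
        # create the global lr entry only once per program, mirroring the
        # "only create global lr_var once" check above
        if program["name"] not in registry:
            registry[program["name"]] = {"shape": [1], "dtype": "float32",
                                         "persistable": True}
            # attach scheduler and variable to the program so a runtime
            # component could look them up later (assumed purpose)
            program["lr_sheduler"] = scheduler
            program["lr_var"] = registry[program["name"]]
        # initialize the variable with the scheduler's current value
        registry[program["name"]]["value"] = float(scheduler())

    prog, reg = {"name": "main_program"}, {}
    create_global_learning_rate(prog, FakeScheduler(0.001), reg)
    print(reg["main_program"]["value"])   # 0.001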