Polish code

test=develop
revert-16555-model_data_cryption_link_all_lib
minqiyang 6 years ago
parent 35c89f38c3
commit 48f3cbdf55

@ -350,7 +350,7 @@ def cosine_decay(learning_rate, step_each_epoch, epochs):
following cosine decay strategy. following cosine decay strategy.
decayed_lr = learning_rate * 0.5 * (math.cos(epoch * math.pi / epochs) + 1) decayed_lr = learning_rate * 0.5 * (math.cos(epoch * math.pi / epochs) + 1)
Args: Args:
learning_rate(Variable|float): The initial learning rate. learning_rate(Variable|float): The initial learning rate.
step_each_epoch(int): the number of steps in an epoch. step_each_epoch(int): the number of steps in an epoch.

@ -94,13 +94,18 @@ class Optimizer(object):
if imperative_base.enabled(): if imperative_base.enabled():
# create learning rate Variable # create learning rate Variable
if isinstance(self._learning_rate, float): if isinstance(self._learning_rate, float):
self._learning_rate_map[framework.default_main_program( lr = self._global_learning_rate()
)] = layers.create_global_var(
name=unique_name.generate("learning_rate"), if isinstance(lr, framework.Variable):
shape=[1], return
value=float(self._learning_rate), else:
dtype='float32' if self._dtype is None else self._dtype, self._learning_rate_map[framework.default_main_program(
persistable=True) )] = layers.create_global_var(
name=unique_name.generate("learning_rate"),
shape=[1],
value=float(self._learning_rate),
dtype='float32' if self._dtype is None else self._dtype,
persistable=True)
# get learning rate Variable from LearningRateDecay # get learning rate Variable from LearningRateDecay
elif isinstance(self._learning_rate, LearningRateDecay): elif isinstance(self._learning_rate, LearningRateDecay):
self._learning_rate_map[framework.default_main_program( self._learning_rate_map[framework.default_main_program(
@ -114,11 +119,12 @@ class Optimizer(object):
if isinstance(lr, framework.Variable): if isinstance(lr, framework.Variable):
return return
else:
if not isinstance(self._learning_rate, float): if not isinstance(self._learning_rate, float):
raise TypeError( raise TypeError(
"learning rate variable is create outside optimizer," "learning rate variable is create outside optimizer,"
"can not create new learning rate variable for new program") "can not create new learning rate variable for new program"
)
# create learning rate in the current main program # create learning rate in the current main program
self._learning_rate_map[framework.default_main_program( self._learning_rate_map[framework.default_main_program(

Loading…
Cancel
Save