!239 Add dynamic learning rate decay and review optimizer code

Merge pull request !239 from fanglei/master
pull/239/MERGE
Committed by mindspore-ci-bot via Gitee
commit 60958d6b25

File diff suppressed because it is too large.

@@ -13,7 +13,6 @@
# limitations under the License.
# ============================================================================
"""adam"""
from typing import Iterable
import numpy as np
from mindspore.common import dtype as mstype
@@ -25,7 +24,7 @@ from mindspore.common.parameter import Parameter
from mindspore.common.tensor import Tensor
from mindspore._checkparam import ParamValidator as validator
from mindspore._checkparam import Rel
from .optimizer import Optimizer, apply_decay, grad_scale
from .optimizer import Optimizer
_learning_rate_update_func = ['linear', 'cos', 'sin']
@@ -168,22 +167,13 @@ class Adam(Optimizer):
def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-8, use_locking=False,
use_nesterov=False, weight_decay=0.0, loss_scale=1.0,
decay_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name):
super(Adam, self).__init__(learning_rate, params)
super(Adam, self).__init__(learning_rate, params, weight_decay, loss_scale, decay_filter)
_check_param_value(beta1, beta2, eps, weight_decay)
validator.check_type("use_locking", use_locking, [bool])
validator.check_type("use_nesterov", use_nesterov, [bool])
validator.check_type("loss_scale", loss_scale, [float])
validator.check_number_range("loss_scale", loss_scale, 1.0, float("inf"), Rel.INC_LEFT)
self.dynamic_lr = False
if isinstance(learning_rate, Iterable) or \
(isinstance(learning_rate, Tensor) and learning_rate.dim() == 1):
self.dynamic_lr = True
self.gather = P.GatherV2()
self.assignadd = P.AssignAdd()
self.global_step = Parameter(initializer(0, [1], mstype.int32), name="global_step")
self.axis = 0
self.beta1 = Tensor(beta1, mstype.float32)
self.beta2 = Tensor(beta2, mstype.float32)
self.beta1_power = Parameter(initializer(1, [1], mstype.float32), name="beta1_power")
@@ -196,8 +186,6 @@ class Adam(Optimizer):
self.decay_tf = tuple(decay_filter(x) for x in self.parameters)
self.hyper_map = C.HyperMap()
self.opt = P.Adam(use_locking, use_nesterov)
self.weight_decay = weight_decay * loss_scale
self.reciprocal_scale = 1.0 / loss_scale
self.pow = P.Pow()
self.sqrt = P.Sqrt()
@@ -208,15 +196,9 @@ class Adam(Optimizer):
params = self.parameters
moment1 = self.moment1
moment2 = self.moment2
if self.weight_decay > 0:
gradients = self.hyper_map(F.partial(apply_decay, self.weight_decay), self.decay_tf, params, gradients)
if self.reciprocal_scale != 1.0:
gradients = self.hyper_map(F.partial(grad_scale, self.reciprocal_scale), gradients)
lr = self.learning_rate
if self.dynamic_lr:
lr = self.gather(self.learning_rate, self.global_step, self.axis)
F.control_depend(lr, self.assignadd(self.global_step, self.one))
gradients = self.decay_weight(gradients)
gradients = self.scale_grad(gradients)
lr = self.get_lr()
beta1_power = self.beta1_power * self.beta1
self.beta1_power = beta1_power

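A minimal usage sketch (not part of this diff) of the refactored Adam: weight_decay, loss_scale and decay_filter are now forwarded to the Optimizer base class, and a per-step learning-rate list triggers the dynamic-LR path handled there. The parameter names below are hypothetical stand-ins for a network's weights.

import numpy as np
from mindspore import Tensor
from mindspore.common.parameter import Parameter
from mindspore.nn.optim import Adam

# Hypothetical trainable parameters standing in for a real network.
weight = Parameter(Tensor(np.ones([2, 3]).astype(np.float32)), name="fc.weight")
beta = Parameter(Tensor(np.zeros([3]).astype(np.float32)), name="fc.beta")

# A learning-rate list makes the base class gather lr by global_step each step;
# the default decay_filter excludes 'beta'/'gamma' parameters from weight decay.
opt = Adam([weight, beta], learning_rate=[1e-3, 5e-4, 1e-4],
           weight_decay=1e-4, loss_scale=1.0)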
@@ -13,14 +13,9 @@
# limitations under the License.
# ============================================================================
"""momentum"""
from typing import Iterable
from mindspore.ops import functional as F, composite as C, operations as P
from mindspore.common.initializer import initializer
from mindspore.common.parameter import Parameter
import mindspore.common.dtype as mstype
from mindspore.common import Tensor
from .optimizer import Optimizer, apply_decay, grad_scale
from .optimizer import Optimizer
momentum_opt = C.MultitypeFuncGraph("momentum_opt")
@@ -88,43 +83,20 @@ class Momentum(Optimizer):
"""
def __init__(self, params, learning_rate, momentum, weight_decay=0.0, loss_scale=1.0,
decay_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name):
super(Momentum, self).__init__(learning_rate, params)
super(Momentum, self).__init__(learning_rate, params, weight_decay, loss_scale, decay_filter)
if isinstance(momentum, float) and momentum < 0.0:
raise ValueError("momentum should be at least 0.0, but got momentum {}".format(momentum))
if isinstance(learning_rate, Iterable) or \
(isinstance(learning_rate, Tensor) and learning_rate.dim() == 1):
self.dynamic_lr = True
self.gather = P.GatherV2()
self.assignadd = P.AssignAdd()
self.global_step = Parameter(initializer(0, [1], mstype.int32), name="global_step")
self.axis = 0
else:
self.dynamic_lr = False
self.gather = None
self.assignadd = None
self.global_step = None
self.axis = None
self.momentum = Parameter(momentum, name="momentum")
self.params = self.parameters
self.moments = self.params.clone(prefix="moments", init='zeros')
self.decay_tf = tuple(decay_filter(x) for x in self.parameters)
self.hyper_map = C.HyperMap()
self.opt = P.ApplyMomentum()
self.weight_decay = weight_decay * loss_scale
self.reciprocal_scale = 1.0 / loss_scale
self.one = Tensor(1, mstype.int32)
def construct(self, gradients):
params = self.params
moments = self.moments
if self.weight_decay > 0:
gradients = self.hyper_map(F.partial(apply_decay, self.weight_decay), self.decay_tf, params, gradients)
if self.reciprocal_scale != 1.0:
gradients = self.hyper_map(F.partial(grad_scale, self.reciprocal_scale), gradients)
if self.dynamic_lr:
lr = self.gather(self.learning_rate, self.global_step, self.axis)
F.control_depend(lr, self.assignadd(self.global_step, self.one))
else:
lr = self.learning_rate
gradients = self.decay_weight(gradients)
gradients = self.scale_grad(gradients)
lr = self.get_lr()
success = self.hyper_map(F.partial(momentum_opt, self.opt, lr, self.momentum), gradients, params, moments)
return success

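Because decay_filter is now passed through to the base class, a caller can still choose which parameters receive weight decay. A hedged sketch with hypothetical parameter names:

import numpy as np
from mindspore import Tensor
from mindspore.common.parameter import Parameter
from mindspore.nn.optim import Momentum

conv_w = Parameter(Tensor(np.ones([8]).astype(np.float32)), name="conv.weight")
bn_gamma = Parameter(Tensor(np.ones([8]).astype(np.float32)), name="bn.gamma")

# Decay only convolution weights; 'gamma' (and anything else not matched) is skipped.
opt = Momentum([conv_w, bn_gamma], learning_rate=0.1, momentum=0.9,
               weight_decay=1e-4,
               decay_filter=lambda x: 'conv' in x.name)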
@@ -17,9 +17,11 @@ from typing import Iterable
import numpy as np
import mindspore
from mindspore.ops import functional as F, composite as C, operations as P
from mindspore.nn.cell import Cell
from mindspore.common.parameter import Parameter, ParameterTuple
from mindspore.common.initializer import initializer
from mindspore._checkparam import ParamValidator as validator
from mindspore._checkparam import Rel
from mindspore.common.tensor import Tensor
@@ -42,34 +44,110 @@ class Optimizer(Cell):
Args:
learning_rate (float): A floating point value for the learning rate. Should be greater than 0.
parameters (list): A list of parameter, which will be updated. The element in `parameters`
should be class mindspore.Parameter.
should be class mindspore.Parameter.
weight_decay (float): A floating point value for the weight decay. Default: 0.0.
loss_scale (float): A floating point value for the loss scale. Default: 1.0. Should be greater than 0.
decay_filter (Function): A function to determine whether to apply weight decay on parameters. Default: lambda
x: 'beta' not in x.name and 'gamma' not in x.name.
Raises:
ValueError: If the learning_rate is a Tensor, but the dims of tensor is greater than 1.
TypeError: If the learning_rate is not any of the three types: float, Tensor, Iterable.
"""
def __init__(self, learning_rate, parameters):
def __init__(self, learning_rate, parameters, weight_decay=0.0, loss_scale=1.0,
decay_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name):
super(Optimizer, self).__init__()
if isinstance(learning_rate, float):
self.dynamic_lr = False
self.gather = None
self.assignadd = None
self.global_step = None
validator.check_number_range("learning rate", learning_rate, 0.0, float("inf"), Rel.INC_LEFT)
elif isinstance(learning_rate, Iterable):
learning_rate = Tensor(np.array(list(learning_rate)).astype(np.float32))
elif isinstance(learning_rate, Tensor):
if learning_rate.dim() > 1:
raise ValueError("Learning rate should be a 0 or 1 dim `Tensor`,"
f"but got {learning_rate.dim()}.")
else:
raise TypeError("Learning rate should be float, Tensor or Iterable.")
self.dynamic_lr = True
self.gather = P.GatherV2()
self.assignadd = P.AssignAdd()
self.global_step = Parameter(initializer(0, [1], mindspore.int32), name='global_step')
if isinstance(learning_rate, Iterable):
learning_rate = Tensor(np.array(list(learning_rate)).astype(np.float32))
elif isinstance(learning_rate, Tensor):
if learning_rate.dim() > 1:
raise ValueError("Learning rate should be a 0 or 1 dim `Tensor`,"
f"but got {learning_rate.dim()}.")
if learning_rate.dim() == 1 and learning_rate.size() < 2:
logger.warning("If want to use the dynamic learning rate, please make sure that the number "
"of elements in the list, tuple or tensor passed is greater than 1.")
else:
raise TypeError("Learning rate should be float, Tensor or Iterable.")
if loss_scale <= 0.0:
raise ValueError("Loss scale should be greater than 0, but got {}".format(loss_scale))
if weight_decay < 0.0:
raise ValueError("Weight decay should be equal or greater than 0, but got {}".format(weight_decay))
if isinstance(learning_rate, Tensor) and learning_rate.dim() == 1 and learning_rate.size() < 2:
logger.warning("If want to use the dynamic learning rate, please make sure that "
"the number of elements in the list, tuple or tensor passed is greater than 1.")
self.learning_rate = Parameter(learning_rate, name="learning_rate")
self.parameters = ParameterTuple(parameters)
self.reciprocal_scale = 1.0 / loss_scale
self.weight_decay = weight_decay * loss_scale
self.decay_flags = tuple(decay_filter(x) for x in self.parameters)
if not self.parameters:
raise ValueError("optimizer got an empty parameter list.")
def decay_weight(self, gradients):
"""
Weight decay.
An approach to reduce the overfitting of a deep learning neural network model.
Args:
gradients (tuple[Tensor]): The gradients of `self.parameters`, and have the same shape with
`self.parameters`.
Returns:
tuple[Tensor], The gradients after weight decay.
"""
if self.weight_decay > 0:
params = self.params
gradients = self.hyper_map(F.partial(apply_decay, self.weight_decay), self.decay_flags, params, gradients)
return gradients
def scale_grad(self, gradients):
"""
Loss scale for mixed precision.
An approach of mixed precision training to improve the speed and energy efficiency of training deep neural
network.
Args:
gradients (tuple[Tensor]): The gradients of `self.parameters`, and have the same shape with
`self.parameters`.
Returns:
tuple[Tensor], The gradients after loss scale.
"""
if self.reciprocal_scale != 1.0:
gradients = self.hyper_map(F.partial(grad_scale, self.reciprocal_scale), gradients)
return gradients
def get_lr(self):
"""
Get the learning rate of current step.
Returns:
float, the learning rate of current step.
"""
lr = self.learning_rate
if self.dynamic_lr:
lr = self.gather(self.learning_rate, self.global_step, 0)
F.control_depend(lr, self.assignadd(self.global_step, 1))
return lr
def construct(self, *hyper_params):
raise NotImplementedError

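The new base-class helpers are meant to be called from a subclass's construct in the order the diffs above show: decay_weight, then scale_grad, then get_lr. A minimal hypothetical subclass (PlainSGD and _plain_update are illustrative names, not part of this PR) sketching that pattern:

from mindspore.ops import functional as F, composite as C, operations as P
from mindspore.nn.optim import Optimizer

_plain_update = C.MultitypeFuncGraph("plain_update")
_assign_sub = P.AssignSub()

@_plain_update.register("Tensor", "Tensor", "Tensor")
def _run_plain_update(lr, gradient, weight):
    """Plain gradient descent: weight <- weight - lr * gradient."""
    return _assign_sub(weight, lr * gradient)

class PlainSGD(Optimizer):
    """Hypothetical optimizer that reuses the shared decay/scale/lr helpers."""
    def __init__(self, params, learning_rate=0.1, weight_decay=0.0, loss_scale=1.0):
        super(PlainSGD, self).__init__(learning_rate, params, weight_decay, loss_scale)
        self.params = self.parameters      # decay_weight() reads self.params
        self.hyper_map = C.HyperMap()      # decay_weight()/scale_grad() read self.hyper_map

    def construct(self, gradients):
        gradients = self.decay_weight(gradients)   # weight decay on flagged parameters
        gradients = self.scale_grad(gradients)     # undo the loss scale
        lr = self.get_lr()                         # static float or per-step dynamic lr
        return self.hyper_map(F.partial(_plain_update, lr), gradients, self.params)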
@ -14,12 +14,8 @@
# ============================================================================
"""rmsprop"""
from mindspore.ops import functional as F, composite as C, operations as P
from mindspore.common.initializer import initializer
from mindspore.common.parameter import Parameter
from mindspore._checkparam import ParamValidator as validator
import mindspore.common.dtype as mstype
from mindspore.common import Tensor
from .optimizer import Optimizer, grad_scale, apply_decay
from .optimizer import Optimizer
rmsprop_opt = C.MultitypeFuncGraph("rmsprop_opt")
centered_rmsprop_opt = C.MultitypeFuncGraph("rmsprop_opt")
@@ -138,7 +134,7 @@ class RMSProp(Optimizer):
def __init__(self, params, learning_rate=0.1, decay=0.9, momentum=0.0, epsilon=1e-10,
use_locking=False, centered=False, loss_scale=1.0, weight_decay=0.0,
decay_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name):
super(RMSProp, self).__init__(learning_rate, params)
super(RMSProp, self).__init__(learning_rate, params, weight_decay, loss_scale, decay_filter)
if isinstance(momentum, float) and momentum < 0.0:
raise ValueError("momentum should be at least 0.0, but got momentum {}".format(momentum))
@@ -157,15 +153,6 @@ class RMSProp(Optimizer):
else:
self.opt = P.ApplyRMSProp(use_locking)
self.dynamic_lr = False
if not isinstance(learning_rate, float):
self.dynamic_lr = True
self.gather = P.GatherV2()
self.assignadd = P.AssignAdd()
self.global_step = Parameter(initializer(0, [1], mstype.int32), name="global_step")
self.axis = 0
self.one = Tensor(1, mstype.int32)
self.momentum = momentum
self.ms = self.parameters.clone(prefix="mean_square", init='zeros')
@@ -173,21 +160,12 @@ class RMSProp(Optimizer):
self.hyper_map = C.HyperMap()
self.decay = decay
self.decay_tf = tuple(decay_filter(x) for x in self.parameters)
self.reciprocal_scale = 1.0 / loss_scale
self.weight_decay = weight_decay * loss_scale
def construct(self, gradients):
params = self.parameters
if self.weight_decay > 0:
gradients = self.hyper_map(F.partial(apply_decay, self.weight_decay), self.decay_tf, params, gradients)
if self.reciprocal_scale != 1.0:
gradients = self.hyper_map(F.partial(grad_scale, self.reciprocal_scale), gradients)
if self.dynamic_lr:
lr = self.gather(self.learning_rate, self.global_step, self.axis)
F.control_depend(lr, self.assignadd(self.global_step, self.one))
else:
lr = self.learning_rate
gradients = self.decay_weight(gradients)
gradients = self.scale_grad(gradients)
lr = self.get_lr()
if self.centered:
success = self.hyper_map(F.partial(centered_rmsprop_opt, self.opt, lr, self.decay, self.epsilon,
self.momentum), params, self.mg, self.ms, self.moment, gradients)

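RMSProp inherits the same behaviour: a 1-D learning-rate Tensor now selects the dynamic-LR path in the base class's get_lr(). A brief hypothetical sketch:

import numpy as np
from mindspore import Tensor
from mindspore.common.parameter import Parameter
from mindspore.nn.optim import RMSProp

w = Parameter(Tensor(np.ones([4]).astype(np.float32)), name="w")  # hypothetical weight
lr = Tensor(np.array([0.1, 0.05, 0.01]).astype(np.float32))       # one entry per step
opt = RMSProp([w], learning_rate=lr, decay=0.9, momentum=0.9,
              centered=True, weight_decay=1e-4)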
@@ -14,11 +14,9 @@
# ============================================================================
"""sgd"""
from mindspore.ops import functional as F, composite as C, operations as P
from mindspore.common.initializer import initializer
from mindspore.common.parameter import Parameter
from mindspore._checkparam import ParamValidator as validator
import mindspore.common.dtype as mstype
from .optimizer import Optimizer, grad_scale
from .optimizer import Optimizer
sgd_opt = C.MultitypeFuncGraph("sgd_opt")
@@ -83,7 +81,7 @@ class SGD(Optimizer):
def __init__(self, params, learning_rate=0.1, momentum=0.0, dampening=0.0, weight_decay=0.0, nesterov=False,
loss_scale=1.0):
super(SGD, self).__init__(learning_rate, params)
super(SGD, self).__init__(learning_rate, params, weight_decay, loss_scale)
if isinstance(momentum, float) and momentum < 0.0:
raise ValueError("momentum should be at least 0.0, but got momentum {}".format(momentum))
@@ -92,44 +90,22 @@ class SGD(Optimizer):
raise ValueError("dampening should be at least 0.0, but got dampening {}".format(dampening))
self.dampening = dampening
if weight_decay < 0.0:
raise ValueError("weight_decay should be at least 0.0, but got weight_decay {}".format(weight_decay))
self.weight_decay = weight_decay
validator.check_type("nesterov", nesterov, [bool])
self.nesterov = nesterov
self.opt = P.SGD(dampening, weight_decay, nesterov)
self.dynamic_lr = False
self.gather = None
self.global_step = None
self.axis = None
if not isinstance(learning_rate, float):
self.dynamic_lr = True
self.gather = P.GatherV2()
self.assignadd = P.AssignAdd()
self.global_step = Parameter(initializer(0, [1], mstype.int32), name="global_step")
self.axis = 0
self.momentum = Parameter(momentum, name="momentum")
self.params = self.parameters
self.accum = self.params.clone(prefix="accum", init='zeros')
self.stat = self.params.clone(prefix="stat", init='ones')
self.accum = self.parameters.clone(prefix="accum", init='zeros')
self.stat = self.parameters.clone(prefix="stat", init='ones')
self.hyper_map = C.HyperMap()
self.weight_decay = weight_decay * loss_scale
self.reciprocal_scale = 1.0 / loss_scale
def construct(self, gradients):
params = self.params
params = self.parameters
accum = self.accum
stat = self.stat
if self.reciprocal_scale != 1.0:
gradients = self.hyper_map(F.partial(grad_scale, self.reciprocal_scale), gradients)
if self.dynamic_lr:
lr = self.gather(self.learning_rate, self.global_step, self.axis)
F.control_depend(lr, self.assignadd(self.global_step, 1))
else:
lr = self.learning_rate
gradients = self.decay_weight(gradients)
gradients = self.scale_grad(gradients)
lr = self.get_lr()
success = self.hyper_map(F.partial(sgd_opt, self.opt, lr, self.momentum), gradients, params, accum, stat)
return success

@@ -15,17 +15,11 @@
""" test optimizer """
import numpy as np
import pytest
from mindspore.nn.optim import Optimizer, SGD, Adam, AdamWeightDecay, AdamWeightDecayDynamicLR
from mindspore import Tensor
from mindspore.nn.optim import Optimizer, SGD, Adam, AdamWeightDecay, AdamWeightDecayDynamicLR
from mindspore.common.parameter import Parameter
gradient = Tensor(np.zeros([1, 2, 3]))
accumulation = gradient
variable = accumulation
paramsTensor = Tensor(np.zeros([1, 2, 3]))
class IterableObjc:
def __iter__(self):
cont = 0
@@ -56,6 +50,7 @@ class TestAdam():
def test_construct(self):
with pytest.raises(TypeError):
gradient = Tensor(np.zeros([1, 2, 3]))
adam = Adam(params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-8, use_locking=False,
use_nesterov=False, weight_decay=0.0, loss_scale=1.0)
adam.construct(gradient)
@@ -105,4 +100,5 @@ class TestUnsupportParam():
def test_Sgd_init(self):
with pytest.raises(TypeError):
paramsTensor = Tensor(np.zeros([1, 2, 3]))
SGD(paramsTensor)

@@ -0,0 +1,234 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" Test Dynamic Learning Rate """
import pytest
import mindspore
from mindspore.nn import dynamic_lr as dr
milestone = [10, 20, 30]
learning_rates = [0.1, 0.05, 0.01]
learning_rate = 0.1
end_learning_rate = 0.01
decay_rate = 0.9
total_step = 30
step_per_epoch = 3
decay_epoch = 2
min_lr = 0.01
max_lr = 0.1
power = 0.5
class TestInputs:
def test_milestone1(self):
milestone1 = 1
with pytest.raises(ValueError):
dr.piecewise_constant_lr(milestone1, learning_rates)
def test_milestone2(self):
milestone1 = [20, 10, 1]
with pytest.raises(ValueError):
dr.piecewise_constant_lr(milestone1, learning_rates)
milestone2 = [1.0, 2.0, True]
with pytest.raises(ValueError):
dr.piecewise_constant_lr(milestone2, learning_rates)
def test_learning_rates1(self):
lr = True
with pytest.raises(ValueError):
dr.piecewise_constant_lr(milestone, lr)
def test_learning_rates2(self):
lr = [1, 2, 1]
with pytest.raises(ValueError):
dr.piecewise_constant_lr(milestone, lr)
def test_learning_rate_type(self):
lr = True
with pytest.raises(TypeError):
dr.exponential_decay_lr(lr, decay_rate, total_step, step_per_epoch, decay_epoch)
with pytest.raises(TypeError):
dr.polynomial_decay_lr(lr, end_learning_rate, total_step, step_per_epoch, decay_epoch, power)
def test_learning_rate_value(self):
lr = -1.0
with pytest.raises(ValueError):
dr.exponential_decay_lr(lr, decay_rate, total_step, step_per_epoch, decay_epoch)
with pytest.raises(ValueError):
dr.polynomial_decay_lr(lr, end_learning_rate, total_step, step_per_epoch, decay_epoch, power)
def test_end_learning_rate_type(self):
lr = True
with pytest.raises(TypeError):
dr.polynomial_decay_lr(learning_rate, lr, total_step, step_per_epoch, decay_epoch, power)
def test_end_learning_rate_value(self):
lr = -1.0
with pytest.raises(ValueError):
dr.polynomial_decay_lr(learning_rate, lr, total_step, step_per_epoch, decay_epoch, power)
def test_decay_rate_type(self):
rate = 'a'
with pytest.raises(TypeError):
dr.exponential_decay_lr(learning_rate, rate, total_step, step_per_epoch, decay_epoch)
def test_decay_rate_value(self):
rate = -1.0
with pytest.raises(ValueError):
dr.exponential_decay_lr(learning_rate, rate, total_step, step_per_epoch, decay_epoch)
def test_total_step1(self):
total_step1 = 2.0
with pytest.raises(ValueError):
dr.exponential_decay_lr(learning_rate, decay_rate, total_step1, step_per_epoch, decay_epoch)
with pytest.raises(ValueError):
dr.cosine_decay_lr(min_lr, max_lr, total_step1, step_per_epoch, decay_epoch)
with pytest.raises(ValueError):
dr.polynomial_decay_lr(learning_rate, end_learning_rate, total_step1, step_per_epoch, decay_epoch, power)
def test_total_step2(self):
total_step1 = -1
with pytest.raises(ValueError):
dr.exponential_decay_lr(learning_rate, decay_rate, total_step1, step_per_epoch, decay_epoch)
with pytest.raises(ValueError):
dr.cosine_decay_lr(min_lr, max_lr, total_step1, step_per_epoch, decay_epoch)
with pytest.raises(ValueError):
dr.polynomial_decay_lr(learning_rate, end_learning_rate, total_step1, step_per_epoch, decay_epoch, power)
def test_step_per_epoch1(self):
step_per_epoch1 = True
with pytest.raises(ValueError):
dr.exponential_decay_lr(learning_rate, decay_rate, total_step, step_per_epoch1, decay_epoch)
with pytest.raises(ValueError):
dr.cosine_decay_lr(min_lr, max_lr, total_step, step_per_epoch1, decay_epoch)
with pytest.raises(ValueError):
dr.polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_epoch1, decay_epoch, power)
def test_step_per_epoch2(self):
step_per_epoch1 = -1
with pytest.raises(ValueError):
dr.exponential_decay_lr(learning_rate, decay_rate, total_step, step_per_epoch1, decay_epoch)
with pytest.raises(ValueError):
dr.cosine_decay_lr(min_lr, max_lr, total_step, step_per_epoch1, decay_epoch)
with pytest.raises(ValueError):
dr.polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_epoch1, decay_epoch, power)
def test_decay_epoch1(self):
decay_epoch1 = 'm'
with pytest.raises(ValueError):
dr.exponential_decay_lr(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch1)
with pytest.raises(ValueError):
dr.cosine_decay_lr(min_lr, max_lr, total_step, step_per_epoch, decay_epoch1)
with pytest.raises(ValueError):
dr.polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_epoch, decay_epoch1, power)
def test_decay_epoch2(self):
decay_epoch1 = -1
with pytest.raises(ValueError):
dr.exponential_decay_lr(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch1)
with pytest.raises(ValueError):
dr.cosine_decay_lr(min_lr, max_lr, total_step, step_per_epoch, decay_epoch1)
with pytest.raises(ValueError):
dr.polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_epoch, decay_epoch1, power)
def test_is_stair(self):
is_stair = 1
with pytest.raises(ValueError):
dr.exponential_decay_lr(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch, is_stair)
def test_min_lr_type(self):
min_lr1 = True
with pytest.raises(TypeError):
dr.cosine_decay_lr(min_lr1, max_lr, total_step, step_per_epoch, decay_epoch)
def test_min_lr_value(self):
min_lr1 = -1.0
with pytest.raises(ValueError):
dr.cosine_decay_lr(min_lr1, max_lr, total_step, step_per_epoch, decay_epoch)
def test_max_lr_type(self):
max_lr1 = 'a'
with pytest.raises(TypeError):
dr.cosine_decay_lr(min_lr, max_lr1, total_step, step_per_epoch, decay_epoch)
def test_max_lr_value(self):
max_lr1 = -1.0
with pytest.raises(ValueError):
dr.cosine_decay_lr(min_lr, max_lr1, total_step, step_per_epoch, decay_epoch)
def test_power(self):
power1 = True
with pytest.raises(ValueError):
dr.polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_epoch, decay_epoch, power1)
def test_update_decay_epoch(self):
update_decay_epoch = 1
with pytest.raises(ValueError):
dr.polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_epoch, decay_epoch,
power, update_decay_epoch)
def test_learning_rate():
lr = dr.piecewise_constant_lr(milestone, learning_rates)
assert len(lr) == milestone[-1]
def test_exponential_decay():
lr1 = dr.exponential_decay_lr(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch)
assert len(lr1) == total_step
lr2 = dr.exponential_decay_lr(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch, True)
assert len(lr2) == total_step
def test_enatural_exp_decay():
lr1 = dr.natural_exp_decay_lr(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch)
assert len(lr1) == total_step
lr2 = dr.natural_exp_decay_lr(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch, True)
assert len(lr2) == total_step
def test_inverse_decay():
lr1 = dr.inverse_decay_lr(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch)
assert len(lr1) == total_step
lr2 = dr.inverse_decay_lr(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch, True)
assert len(lr2) == total_step
def test_cosine_decay():
lr = dr.cosine_decay_lr(min_lr, max_lr, total_step, step_per_epoch, decay_epoch)
assert len(lr) == total_step
def test_polynomial_decay():
lr1 = dr.polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_epoch, decay_epoch, power)
assert len(lr1) == total_step
lr2 = dr.polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_epoch, decay_epoch, power,
True)
assert len(lr2) == total_step
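The schedules exercised above return plain per-step lists, so they can be fed directly to any of the refactored optimizers as a dynamic learning rate. A hypothetical end-to-end sketch using the same constants as the tests:

import numpy as np
from mindspore import Tensor
from mindspore.common.parameter import Parameter
from mindspore.nn import dynamic_lr as dr
from mindspore.nn.optim import Momentum

# 30 total steps, 3 steps per epoch, decay every 2 epochs.
lr_list = dr.exponential_decay_lr(0.1, 0.9, 30, 3, 2)
assert len(lr_list) == 30

w = Parameter(Tensor(np.ones([2, 2]).astype(np.float32)), name="w")  # hypothetical weight
opt = Momentum([w], learning_rate=lr_list, momentum=0.9, weight_decay=1e-4)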