fix_bug_in_check_lamb_warmup_step

pull/1555/head
wangnan39@huawei.com 5 years ago
parent fb7e4eac76
commit 810ccf80d8

@ -111,7 +111,6 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, global_step, para
def _check_param_value(decay_steps, warmup_steps, start_learning_rate, def _check_param_value(decay_steps, warmup_steps, start_learning_rate,
end_learning_rate, power, beta1, beta2, eps, weight_decay, prim_name): end_learning_rate, power, beta1, beta2, eps, weight_decay, prim_name):
"""Check the type of inputs.""" """Check the type of inputs."""
_ = warmup_steps
validator.check_float_positive('start_learning_rate', start_learning_rate, prim_name) validator.check_float_positive('start_learning_rate', start_learning_rate, prim_name)
validator.check_float_legal_value('start_learning_rate', start_learning_rate, prim_name) validator.check_float_legal_value('start_learning_rate', start_learning_rate, prim_name)
validator.check_value_type("end_learning_rate", end_learning_rate, [float], prim_name) validator.check_value_type("end_learning_rate", end_learning_rate, [float], prim_name)
@ -119,7 +118,7 @@ def _check_param_value(decay_steps, warmup_steps, start_learning_rate,
validator.check_float_positive('power', power, prim_name) validator.check_float_positive('power', power, prim_name)
validator.check_float_legal_value('power', power, prim_name) validator.check_float_legal_value('power', power, prim_name)
validator.check_integer('decay_steps', decay_steps, 0, Rel.GT, prim_name) validator.check_integer('decay_steps', decay_steps, 0, Rel.GT, prim_name)
validator.check_integer('warmup_steps', decay_steps, 0, Rel.GT, prim_name) validator.check_integer('warmup_steps', warmup_steps, 0, Rel.GE, prim_name)
validator.check_value_type("beta1", beta1, [float], prim_name) validator.check_value_type("beta1", beta1, [float], prim_name)
validator.check_value_type("beta2", beta2, [float], prim_name) validator.check_value_type("beta2", beta2, [float], prim_name)
validator.check_value_type("eps", eps, [float], prim_name) validator.check_value_type("eps", eps, [float], prim_name)

@ -14,6 +14,7 @@
# ============================================================================ # ============================================================================
""" test lamb """ """ test lamb """
import numpy as np import numpy as np
import pytest
import mindspore.nn as nn import mindspore.nn as nn
from mindspore import Tensor, Parameter from mindspore import Tensor, Parameter
@ -50,29 +51,27 @@ class NetWithoutWeight(nn.Cell):
return x return x
def test_lamb_1(): def test_lamb_compile():
""" test_Lamb_1 """ """ test_Lamb_compile """
inputs = Tensor(np.ones([1, 64]).astype(np.float32)) inputs = Tensor(np.ones([1, 64]).astype(np.float32))
label = Tensor(np.zeros([1, 10]).astype(np.float32)) label = Tensor(np.zeros([1, 10]).astype(np.float32))
net = Net() net = Net()
net.set_train() net.set_train()
loss = nn.SoftmaxCrossEntropyWithLogits() loss = nn.SoftmaxCrossEntropyWithLogits()
optimizer = Lamb(net.trainable_params(), decay_steps=10, warmup_steps=5) optimizer = Lamb(net.trainable_params(), decay_steps=10)
net_with_loss = WithLossCell(net, loss) net_with_loss = WithLossCell(net, loss)
train_network = TrainOneStepCell(net_with_loss, optimizer) train_network = TrainOneStepCell(net_with_loss, optimizer)
_executor.compile(train_network, inputs, label) _executor.compile(train_network, inputs, label)
def test_lamb_2(): def test_lamb_error():
""" test_Lamb_2 """
inputs = Tensor(np.ones([1, 64]).astype(np.float32))
label = Tensor(np.zeros([1, 10]).astype(np.float32))
net = Net() net = Net()
net.set_train() with pytest.raises(TypeError):
loss = nn.SoftmaxCrossEntropyWithLogits() Lamb(net.get_parameters(), decay_steps=6, warmup_steps=5.0)
optimizer = Lamb(net.trainable_params(), decay_steps=10, warmup_steps=0)
net_with_loss = WithLossCell(net, loss) with pytest.raises(TypeError):
train_network = TrainOneStepCell(net_with_loss, optimizer) Lamb(net.get_parameters(), decay_steps=1.0)
_executor.compile(train_network, inputs, label)
with pytest.raises(ValueError):
Lamb(net.get_parameters(), decay_steps=0)

Loading…
Cancel
Save