@@ -212,3 +212,41 @@ def test_group_repeat_param():
                     {'params': no_conv_params}]
     with pytest.raises(RuntimeError):
         Adam(group_params, learning_rate=default_lr)
+
+
+def test_get_lr_parameter_with_group():
+    net = LeNet5()
+    conv_lr = 0.1
+    default_lr = 0.3
+    conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
+    no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
+    group_params = [{'params': conv_params, 'lr': conv_lr},
+                    {'params': no_conv_params, 'lr': default_lr}]
+    opt = SGD(group_params)
+    assert opt.is_group_lr is True
+    for param in opt.parameters:
+        lr = opt.get_lr_parameter(param)
+        assert lr.name == 'lr_' + param.name
+
+    lr_list = opt.get_lr_parameter(conv_params)
+    for lr, param in zip(lr_list, conv_params):
+        assert lr.name == 'lr_' + param.name
+
+
+def test_get_lr_parameter_with_no_group():
+    net = LeNet5()
+    conv_weight_decay = 0.8
+
+    conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
+    no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
+    group_params = [{'params': conv_params, 'weight_decay': conv_weight_decay},
+                    {'params': no_conv_params}]
+    opt = SGD(group_params)
+    assert opt.is_group_lr is False
+    for param in opt.parameters:
+        lr = opt.get_lr_parameter(param)
+        assert lr.name == opt.learning_rate.name
+
+    params_error = [1, 2, 3]
+    with pytest.raises(TypeError):
+        opt.get_lr_parameter(params_error)