|
|
|
@ -65,12 +65,13 @@ def test_group_lr():
|
|
|
|
|
|
|
|
|
|
opt = Momentum(group_params, learning_rate=default_lr, momentum=0.9)
|
|
|
|
|
assert opt.is_group is True
|
|
|
|
|
assert opt.is_group_lr is True
|
|
|
|
|
assert opt.dynamic_lr is False
|
|
|
|
|
for lr, param in zip(opt.learning_rate, opt.parameters):
|
|
|
|
|
if param in conv_params:
|
|
|
|
|
assert lr.data == Tensor(conv_lr, mstype.float32)
|
|
|
|
|
assert np.all(lr.data.asnumpy() == Tensor(conv_lr, mstype.float32).asnumpy())
|
|
|
|
|
else:
|
|
|
|
|
assert lr.data == Tensor(default_lr, mstype.float32)
|
|
|
|
|
assert np.all(lr.data.asnumpy() == Tensor(default_lr, mstype.float32).asnumpy())
|
|
|
|
|
|
|
|
|
|
net_with_loss = WithLossCell(net, loss)
|
|
|
|
|
train_network = TrainOneStepCell(net_with_loss, opt)
|
|
|
|
@ -96,9 +97,9 @@ def test_group_dynamic_1():
|
|
|
|
|
assert opt.dynamic_lr is True
|
|
|
|
|
for lr, param in zip(opt.learning_rate, opt.parameters):
|
|
|
|
|
if param in conv_params:
|
|
|
|
|
assert lr.data == Tensor(np.array([conv_lr] * 3).astype(np.float32))
|
|
|
|
|
assert np.all(lr.data.asnumpy() == Tensor(np.array([conv_lr] * 3).astype(np.float32)).asnumpy())
|
|
|
|
|
else:
|
|
|
|
|
assert lr.data == Tensor(np.array(list(default_lr)).astype(np.float32))
|
|
|
|
|
assert np.all(lr.data.asnumpy() == Tensor(np.array(list(default_lr)).astype(np.float32)).asnumpy())
|
|
|
|
|
|
|
|
|
|
net_with_loss = WithLossCell(net, loss)
|
|
|
|
|
train_network = TrainOneStepCell(net_with_loss, opt)
|
|
|
|
@ -124,9 +125,9 @@ def test_group_dynamic_2():
|
|
|
|
|
assert opt.dynamic_lr is True
|
|
|
|
|
for lr, param in zip(opt.learning_rate, opt.parameters):
|
|
|
|
|
if param in conv_params:
|
|
|
|
|
assert lr.data == Tensor(np.array(list(conv_lr)).astype(np.float32))
|
|
|
|
|
assert np.all(lr.data == Tensor(np.array(list(conv_lr)).astype(np.float32)))
|
|
|
|
|
else:
|
|
|
|
|
assert lr.data == Tensor(np.array([default_lr] * 3).astype(np.float32))
|
|
|
|
|
assert np.all(lr.data == Tensor(np.array([default_lr] * 3).astype(np.float32)))
|
|
|
|
|
|
|
|
|
|
net_with_loss = WithLossCell(net, loss)
|
|
|
|
|
train_network = TrainOneStepCell(net_with_loss, opt)
|
|
|
|
@ -184,6 +185,7 @@ def test_weight_decay():
|
|
|
|
|
|
|
|
|
|
opt = SGD(group_params, learning_rate=0.1, weight_decay=default_weight_decay)
|
|
|
|
|
assert opt.is_group is True
|
|
|
|
|
assert opt.is_group_lr is False
|
|
|
|
|
for weight_decay, decay_flags, param in zip(opt.weight_decay, opt.decay_flags, opt.parameters):
|
|
|
|
|
if param in conv_params:
|
|
|
|
|
assert weight_decay == conv_weight_decay
|
|
|
|
|