diff --git a/tests/st/networks/models/resnet50/test_resnet50_imagenet.py b/tests/st/networks/models/resnet50/test_resnet50_imagenet.py index 64f4adda99..c88af6bcf7 100644 --- a/tests/st/networks/models/resnet50/test_resnet50_imagenet.py +++ b/tests/st/networks/models/resnet50/test_resnet50_imagenet.py @@ -178,20 +178,18 @@ def train_process(q, device_id, epoch_size, device_num, enable_hccl): net.trainable_params())) no_decayed_params = [param for param in net.trainable_params() if param not in decayed_params] group_params = [{'params': decayed_params, 'weight_decay': config.weight_decay}, - {'params': no_decayed_params}, + {'params': no_decayed_params, 'weight_decay': 0.0}, {'order_params': net.trainable_params()}] if config.use_lars: momentum = nn.Momentum(group_params, lr, config.momentum, - weight_decay=config.weight_decay, loss_scale=config.loss_scale, - use_nesterov=config.use_nesterov) + loss_scale=config.loss_scale, use_nesterov=config.use_nesterov) opt = nn.LARS(momentum, epsilon=config.lars_epsilon, coefficient=config.lars_coefficient, lars_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name and 'bias' not in x.name) else: opt = nn.Momentum(group_params, lr, config.momentum, - weight_decay=config.weight_decay, loss_scale=config.loss_scale, - use_nesterov=config.use_nesterov) + loss_scale=config.loss_scale, use_nesterov=config.use_nesterov) # model model = Model(net, loss_fn=loss, optimizer=opt,