|
|
|
@ -178,20 +178,18 @@ def train_process(q, device_id, epoch_size, device_num, enable_hccl):
|
|
|
|
|
net.trainable_params()))
|
|
|
|
|
no_decayed_params = [param for param in net.trainable_params() if param not in decayed_params]
|
|
|
|
|
group_params = [{'params': decayed_params, 'weight_decay': config.weight_decay},
|
|
|
|
|
{'params': no_decayed_params},
|
|
|
|
|
{'params': no_decayed_params, 'weight_decay': 0.0},
|
|
|
|
|
{'order_params': net.trainable_params()}]
|
|
|
|
|
|
|
|
|
|
if config.use_lars:
|
|
|
|
|
momentum = nn.Momentum(group_params, lr, config.momentum,
|
|
|
|
|
weight_decay=config.weight_decay, loss_scale=config.loss_scale,
|
|
|
|
|
use_nesterov=config.use_nesterov)
|
|
|
|
|
loss_scale=config.loss_scale, use_nesterov=config.use_nesterov)
|
|
|
|
|
opt = nn.LARS(momentum, epsilon=config.lars_epsilon, coefficient=config.lars_coefficient,
|
|
|
|
|
lars_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name and 'bias' not in x.name)
|
|
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
opt = nn.Momentum(group_params, lr, config.momentum,
|
|
|
|
|
weight_decay=config.weight_decay, loss_scale=config.loss_scale,
|
|
|
|
|
use_nesterov=config.use_nesterov)
|
|
|
|
|
loss_scale=config.loss_scale, use_nesterov=config.use_nesterov)
|
|
|
|
|
|
|
|
|
|
# model
|
|
|
|
|
model = Model(net, loss_fn=loss, optimizer=opt,
|
|
|
|
|