|
|
@ -28,7 +28,7 @@ def TrainWrap(net, loss_fn=None, optimizer=None, weights=None):
|
|
|
|
if weights is None:
|
|
|
|
if weights is None:
|
|
|
|
weights = ParameterTuple(net.trainable_params())
|
|
|
|
weights = ParameterTuple(net.trainable_params())
|
|
|
|
if optimizer is None:
|
|
|
|
if optimizer is None:
|
|
|
|
optimizer = nn.Adam(weights, learning_rate=1e-2, beta1=0.9, beta2=0.999, eps=1e-8, use_locking=False,
|
|
|
|
optimizer = nn.Adam(weights, learning_rate=0.003, beta1=0.9, beta2=0.999, eps=1e-5, use_locking=False,
|
|
|
|
use_nesterov=False, weight_decay=0.0, loss_scale=1.0)
|
|
|
|
use_nesterov=False, weight_decay=4e-5, loss_scale=1.0)
|
|
|
|
train_net = nn.TrainOneStepCell(loss_net, optimizer)
|
|
|
|
train_net = nn.TrainOneStepCell(loss_net, optimizer)
|
|
|
|
return train_net
|
|
|
|
return train_net
|
|
|
|