@@ -28,7 +28,7 @@ from mindspore.common import set_seed
 
 from src.config import nasnet_a_mobile_config_gpu as cfg
 from src.dataset import create_dataset
-from src.nasnet_a_mobile import NASNetAMobileWithLoss, NASNetAMobileTrainOneStepWithClipGradient
+from src.nasnet_a_mobile import NASNetAMobile, CrossEntropy
 from src.lr_generator import get_lr
 
 
@@ -68,10 +68,13 @@ if __name__ == '__main__':
     batches_per_epoch = dataset.get_dataset_size()
 
     # network
-    net_with_loss = NASNetAMobileWithLoss(cfg)
+    net = NASNetAMobile(cfg.num_classes)
     if args_opt.resume:
         ckpt = load_checkpoint(args_opt.resume)
-        load_param_into_net(net_with_loss, ckpt)
+        load_param_into_net(net, ckpt)
+
+    #loss
+    loss = CrossEntropy(smooth_factor=cfg.label_smooth_factor, num_classes=cfg.num_classes, factor=cfg.aux_factor)
 
     # learning rate schedule
     lr = get_lr(lr_init=cfg.lr_init, lr_decay_rate=cfg.lr_decay_rate,
@@ -82,20 +85,18 @@ if __name__ == '__main__':
     # optimizer
     decayed_params = []
     no_decayed_params = []
-    for param in net_with_loss.trainable_params():
+    for param in net.trainable_params():
         if 'beta' not in param.name and 'gamma' not in param.name and 'bias' not in param.name:
             decayed_params.append(param)
         else:
             no_decayed_params.append(param)
     group_params = [{'params': decayed_params, 'weight_decay': cfg.weight_decay},
                     {'params': no_decayed_params},
-                    {'order_params': net_with_loss.trainable_params()}]
+                    {'order_params': net.trainable_params()}]
     optimizer = RMSProp(group_params, lr, decay=cfg.rmsprop_decay, weight_decay=cfg.weight_decay,
                         momentum=cfg.momentum, epsilon=cfg.opt_eps, loss_scale=cfg.loss_scale)
 
-    net_with_grads = NASNetAMobileTrainOneStepWithClipGradient(net_with_loss, optimizer)
-    net_with_grads.set_train()
-    model = Model(net_with_grads)
+    model = Model(net, loss_fn=loss, optimizer=optimizer)
 
     print("============== Starting Training ==============")
     loss_cb = LossMonitor(per_print_times=batches_per_epoch)
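Note on the switch to an external loss: the patch replaces the loss-wrapped network (`NASNetAMobileWithLoss`) and the custom `NASNetAMobileTrainOneStepWithClipGradient` train step with a plain `NASNetAMobile` plus a standalone `CrossEntropy` cell handed to `Model`, so MindSpore's standard training pipeline now builds the forward/backward graph (the old wrapper's gradient clipping is dropped along with it). As a rough guide to what `src/nasnet_a_mobile.py` presumably implements, the sketch below is a label-smoothed softmax cross entropy over the main and auxiliary logits, with the auxiliary term weighted by `factor` (wired to `cfg.aux_factor` in the patch). It is a minimal reconstruction under those assumptions, not the verified repository code.

```python
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import Tensor
from mindspore.common import dtype as mstype


class CrossEntropy(nn.Cell):
    """Sketch: label-smoothed cross entropy over (main logits, aux logits)."""

    def __init__(self, smooth_factor=0.1, num_classes=1000, factor=0.4):
        super(CrossEntropy, self).__init__()
        self.factor = factor  # weight of the auxiliary-head term
        self.onehot = ops.OneHot()
        self.shape = ops.Shape()
        # label smoothing: the true class gets 1 - smooth_factor,
        # the remaining classes share smooth_factor evenly
        self.on_value = Tensor(1.0 - smooth_factor, mstype.float32)
        self.off_value = Tensor(smooth_factor / (num_classes - 1), mstype.float32)
        self.ce = nn.SoftmaxCrossEntropyWithLogits()  # per-sample losses
        self.mean = ops.ReduceMean(False)

    def construct(self, logits, label):
        # assumes NASNetAMobile returns (main logits, aux logits) in training mode
        logit, aux = logits
        one_hot = self.onehot(label, self.shape(logit)[1], self.on_value, self.off_value)
        loss = self.mean(self.ce(logit, one_hot), 0)
        one_hot_aux = self.onehot(label, self.shape(aux)[1], self.on_value, self.off_value)
        loss_aux = self.mean(self.ce(aux, one_hot_aux), 0)
        return loss + self.factor * loss_aux
```

With the criterion factored out, `Model(net, loss_fn=loss, optimizer=optimizer)` wraps the network, loss, and optimizer into MindSpore's usual `WithLossCell`/`TrainOneStepCell` pair, which is also why the explicit `net_with_grads.set_train()` call is no longer needed: `Model.train` sets train mode itself.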