|
|
|
@ -108,13 +108,15 @@ if __name__ == '__main__':
|
|
|
|
|
load_path = args_opt.pre_trained
|
|
|
|
|
if load_path != "":
|
|
|
|
|
param_dict = load_checkpoint(load_path)
|
|
|
|
|
for item in list(param_dict.keys()):
|
|
|
|
|
if not (item.startswith('backbone') or item.startswith('rcnn_mask')):
|
|
|
|
|
param_dict.pop(item)
|
|
|
|
|
if config.pretrain_epoch_size == 0:
|
|
|
|
|
for item in list(param_dict.keys()):
|
|
|
|
|
if not (item.startswith('backbone') or item.startswith('rcnn_mask')):
|
|
|
|
|
param_dict.pop(item)
|
|
|
|
|
load_param_into_net(net, param_dict)
|
|
|
|
|
|
|
|
|
|
loss = LossNet()
|
|
|
|
|
lr = Tensor(dynamic_lr(config, rank_size=device_num), mstype.float32)
|
|
|
|
|
lr = Tensor(dynamic_lr(config, rank_size=device_num, start_steps=config.pretrain_epoch_size * dataset_size),
|
|
|
|
|
mstype.float32)
|
|
|
|
|
opt = SGD(params=net.trainable_params(), learning_rate=lr, momentum=config.momentum,
|
|
|
|
|
weight_decay=config.weight_decay, loss_scale=config.loss_scale)
|
|
|
|
|
|
|
|
|
|