From 94051dea6972aad97c3853c5dafb287b650c6546 Mon Sep 17 00:00:00 2001 From: hanjun996 Date: Wed, 5 Aug 2020 10:49:24 +0800 Subject: [PATCH] add filter_wieght --- model_zoo/official/cv/ssd/eval.py | 2 ++ model_zoo/official/cv/ssd/train.py | 4 +++- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/model_zoo/official/cv/ssd/eval.py b/model_zoo/official/cv/ssd/eval.py index 1ce66e38f1..37b5092206 100644 --- a/model_zoo/official/cv/ssd/eval.py +++ b/model_zoo/official/cv/ssd/eval.py @@ -78,6 +78,8 @@ if __name__ == '__main__': prefix = "ssd_eval.mindrecord" mindrecord_dir = config.mindrecord_dir mindrecord_file = os.path.join(mindrecord_dir, prefix + "0") + if args_opt.dataset == "voc": + config.coco_root = config.voc_root if not os.path.exists(mindrecord_file): if not os.path.isdir(mindrecord_dir): os.makedirs(mindrecord_dir) diff --git a/model_zoo/official/cv/ssd/train.py b/model_zoo/official/cv/ssd/train.py index a1739d0e4e..c38026103d 100644 --- a/model_zoo/official/cv/ssd/train.py +++ b/model_zoo/official/cv/ssd/train.py @@ -46,6 +46,7 @@ def main(): parser.add_argument("--pre_trained_epoch_size", type=int, default=0, help="Pretrained epoch size.") parser.add_argument("--save_checkpoint_epochs", type=int, default=10, help="Save checkpoint epochs, default is 5.") parser.add_argument("--loss_scale", type=int, default=1024, help="Loss scale, default is 1024.") + parser.add_argument("--filter_weight", type=bool, default=False, help="Filter weight parameters, default is False.") args_opt = parser.parse_args() context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id) @@ -117,7 +118,8 @@ def main(): if args_opt.pre_trained_epoch_size <= 0: raise KeyError("pre_trained_epoch_size must be greater than 0.") param_dict = load_checkpoint(args_opt.pre_trained) - filter_checkpoint_parameter(param_dict) + if args_opt.filter_weight: + filter_checkpoint_parameter(param_dict) load_param_into_net(net, param_dict) lr = Tensor(get_lr(global_step=config.global_step,