diff --git a/model_zoo/official/cv/faster_rcnn/train.py b/model_zoo/official/cv/faster_rcnn/train.py index 5ecba7eb1d..f07a5e7f2e 100644 --- a/model_zoo/official/cv/faster_rcnn/train.py +++ b/model_zoo/official/cv/faster_rcnn/train.py @@ -124,6 +124,25 @@ if __name__ == '__main__': load_path = args_opt.pre_trained if load_path != "": param_dict = load_checkpoint(load_path) + + key_mapping = {'down_sample_layer.1.beta': 'bn_down_sample.beta', + 'down_sample_layer.1.gamma': 'bn_down_sample.gamma', + 'down_sample_layer.0.weight': 'conv_down_sample.weight', + 'down_sample_layer.1.moving_mean': 'bn_down_sample.moving_mean', + 'down_sample_layer.1.moving_variance': 'bn_down_sample.moving_variance', + } + for oldkey in list(param_dict.keys()): + if not oldkey.startswith(('backbone', 'end_point', 'global_step', 'learning_rate', 'moments', 'momentum')): + data = param_dict.pop(oldkey) + newkey = 'backbone.' + oldkey + param_dict[newkey] = data + oldkey = newkey + for k, v in key_mapping.items(): + if k in oldkey: + newkey = oldkey.replace(k, v) + param_dict[newkey] = param_dict.pop(oldkey) + break + for item in list(param_dict.keys()): if not item.startswith('backbone'): param_dict.pop(item)