|
|
|
@ -124,6 +124,25 @@ if __name__ == '__main__':
|
|
|
|
|
load_path = args_opt.pre_trained
|
|
|
|
|
if load_path != "":
|
|
|
|
|
param_dict = load_checkpoint(load_path)
|
|
|
|
|
|
|
|
|
|
key_mapping = {'down_sample_layer.1.beta': 'bn_down_sample.beta',
|
|
|
|
|
'down_sample_layer.1.gamma': 'bn_down_sample.gamma',
|
|
|
|
|
'down_sample_layer.0.weight': 'conv_down_sample.weight',
|
|
|
|
|
'down_sample_layer.1.moving_mean': 'bn_down_sample.moving_mean',
|
|
|
|
|
'down_sample_layer.1.moving_variance': 'bn_down_sample.moving_variance',
|
|
|
|
|
}
|
|
|
|
|
for oldkey in list(param_dict.keys()):
|
|
|
|
|
if not oldkey.startswith(('backbone', 'end_point', 'global_step', 'learning_rate', 'moments', 'momentum')):
|
|
|
|
|
data = param_dict.pop(oldkey)
|
|
|
|
|
newkey = 'backbone.' + oldkey
|
|
|
|
|
param_dict[newkey] = data
|
|
|
|
|
oldkey = newkey
|
|
|
|
|
for k, v in key_mapping.items():
|
|
|
|
|
if k in oldkey:
|
|
|
|
|
newkey = oldkey.replace(k, v)
|
|
|
|
|
param_dict[newkey] = param_dict.pop(oldkey)
|
|
|
|
|
break
|
|
|
|
|
|
|
|
|
|
for item in list(param_dict.keys()):
|
|
|
|
|
if not item.startswith('backbone'):
|
|
|
|
|
param_dict.pop(item)
|
|
|
|
|