add filter weight when fine-tune in mobilenetv2

pull/9927/head
zhaoting 4 years ago
parent e37591ded5
commit bed1859921

@ -4,17 +4,17 @@
- [Model Architecture](#model-architecture) - [Model Architecture](#model-architecture)
- [Dataset](#dataset) - [Dataset](#dataset)
- [Features](#features) - [Features](#features)
- [Mixed Precision](#mixed-precision(ascend)) - [Mixed Precision](#mixed-precision(ascend))
- [Environment Requirements](#environment-requirements) - [Environment Requirements](#environment-requirements)
- [Script Description](#script-description) - [Script Description](#script-description)
- [Script and Sample Code](#script-and-sample-code) - [Script and Sample Code](#script-and-sample-code)
- [Training Process](#training-process) - [Training Process](#training-process)
- [Evaluation Process](#eval-process) - [Evaluation Process](#eval-process)
- [Export MindIR](#export-mindir) - [Export MindIR](#export-mindir)
- [Model Description](#model-description) - [Model Description](#model-description)
- [Performance](#performance) - [Performance](#performance)
- [Training Performance](#training-performance) - [Training Performance](#training-performance)
- [Evaluation Performance](#evaluation-performance) - [Evaluation Performance](#evaluation-performance)
- [Description of Random Situation](#description-of-random-situation) - [Description of Random Situation](#description-of-random-situation)
- [ModelZoo Homepage](#modelzoo-homepage) - [ModelZoo Homepage](#modelzoo-homepage)
@ -35,10 +35,10 @@ The overall network architecture of MobileNetV2 is show below:
Dataset used: [imagenet](http://www.image-net.org/) Dataset used: [imagenet](http://www.image-net.org/)
- Dataset size: ~125G, 1.2W colorful images in 1000 classes - Dataset size: ~125G, 1.2W colorful images in 1000 classes
- Train: 120G, 1.2W images - Train: 120G, 1.2W images
- Test: 5G, 50000 images - Test: 5G, 50000 images
- Data format: RGB images. - Data format: RGB images.
- Note: Data will be processed in src/dataset.py - Note: Data will be processed in src/dataset.py
# [Features](#contents) # [Features](#contents)
@ -50,12 +50,12 @@ For FP16 operators, if the input data type is FP32, the backend of MindSpore wil
# [Environment Requirements](#contents) # [Environment Requirements](#contents)
- HardwareAscend/GPU/CPU - HardwareAscend/GPU/CPU
- Prepare hardware environment with Ascend, GPU or CPU processor. If you want to try Ascend, please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. Once approved, you can get the resources. - Prepare hardware environment with Ascend, GPU or CPU processor. If you want to try Ascend, please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. Once approved, you can get the resources.
- Framework - Framework
- [MindSpore](https://www.mindspore.cn/install/en) - [MindSpore](https://www.mindspore.cn/install/en)
- For more information, please check the resources below - For more information, please check the resources below
- [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html) - [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
- [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html) - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)
# [Script description](#contents) # [Script description](#contents)
@ -87,9 +87,11 @@ For FP16 operators, if the input data type is FP32, the backend of MindSpore wil
You can start training using python or shell scripts. The usage of shell scripts as follows: You can start training using python or shell scripts. The usage of shell scripts as follows:
- Ascend: sh run_train.sh Ascend [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [RANK_TABLE_FILE] [DATASET_PATH] [CKPT_PATH] [FREEZE_LAYER] - Ascend: sh run_train.sh Ascend [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [RANK_TABLE_FILE] [DATASET_PATH] [CKPT_PATH] [FREEZE_LAYER] [FILTER_HEAD]
- GPU: sh run_trian.sh GPU [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH] [CKPT_PATH] [FREEZE_LAYER] - GPU: sh run_trian.sh GPU [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH] [CKPT_PATH] [FREEZE_LAYER] [FILTER_HEAD]
- CPU: sh run_trian.sh CPU [DATASET_PATH] [CKPT_PATH] [FREEZE_LAYER] - CPU: sh run_trian.sh CPU [DATASET_PATH] [CKPT_PATH] [FREEZE_LAYER] [FILTER_HEAD]
`CKPT_PATH` `FREEZE_LAYER` and `FILTER_HEAD` are optional, when set `CKPT_PATH`, `FREEZE_LAYER` must be set. `FREEZE_LAYER` should be in ["none", "backbone"], and if you set `FREEZE_LAYER`="backbone", the parameter in backbone will be freezed when training and the parameter in head will not be load from checkpoint. if `FILTER_HEAD`=True, the parameter in head will not be load from checkpoint.
> RANK_TABLE_FILE is HCCL configuration file when running on Ascend. > RANK_TABLE_FILE is HCCL configuration file when running on Ascend.
> The common restrictions on using the distributed service are as follows. For details, see the HCCL documentation. > The common restrictions on using the distributed service are as follows. For details, see the HCCL documentation.
@ -113,14 +115,14 @@ You can start training using python or shell scripts. The usage of shell scripts
# fine tune whole network example # fine tune whole network example
python: python:
Ascend: python train.py --platform Ascend --dataset_path [TRAIN_DATASET_PATH] --pretrain_ckpt [CKPT_PATH] --freeze_layer none Ascend: python train.py --platform Ascend --dataset_path [TRAIN_DATASET_PATH] --pretrain_ckpt [CKPT_PATH] --freeze_layer none --filter_head True
GPU: python train.py --platform GPU --dataset_path [TRAIN_DATASET_PATH] --pretrain_ckpt [CKPT_PATH] --freeze_layer none GPU: python train.py --platform GPU --dataset_path [TRAIN_DATASET_PATH] --pretrain_ckpt [CKPT_PATH] --freeze_layer none --filter_head True
CPU: python train.py --platform CPU --dataset_path [TRAIN_DATASET_PATH] --pretrain_ckpt [CKPT_PATH] --freeze_layer none CPU: python train.py --platform CPU --dataset_path [TRAIN_DATASET_PATH] --pretrain_ckpt [CKPT_PATH] --freeze_layer none --filter_head True
shell: shell:
Ascend: sh run_train.sh Ascend 8 0,1,2,3,4,5,6,7 hccl_config.json [TRAIN_DATASET_PATH] [CKPT_PATH] none Ascend: sh run_train.sh Ascend 8 0,1,2,3,4,5,6,7 hccl_config.json [TRAIN_DATASET_PATH] [CKPT_PATH] none True
GPU: sh run_train.sh GPU 8 0,1,2,3,4,5,6,7 [TRAIN_DATASET_PATH] [CKPT_PATH] none GPU: sh run_train.sh GPU 8 0,1,2,3,4,5,6,7 [TRAIN_DATASET_PATH] [CKPT_PATH] none True
CPU: sh run_train.sh CPU [TRAIN_DATASET_PATH] [CKPT_PATH] none CPU: sh run_train.sh CPU [TRAIN_DATASET_PATH] [CKPT_PATH] none True
# fine tune full connected layers example # fine tune full connected layers example
python: python:
@ -184,7 +186,7 @@ result: {'acc': 0.71976314102564111} ckpt=./ckpt_0/mobilenet-200_625.ckpt
Change the export mode and export file in `src/config.py`, and run `export.py`. Change the export mode and export file in `src/config.py`, and run `export.py`.
``` ```shell
python export.py --platform [PLATFORM] --pretrain_ckpt [CKPT_PATH] python export.py --platform [PLATFORM] --pretrain_ckpt [CKPT_PATH]
``` ```

@ -1,5 +1,4 @@
# 目录 # 目录
<!-- TOC -->
- [目录](#目录) - [目录](#目录)
- [MobileNetV2描述](#mobilenetv2描述) - [MobileNetV2描述](#mobilenetv2描述)
@ -25,8 +24,6 @@
- [随机情况说明](#随机情况说明) - [随机情况说明](#随机情况说明)
- [ModelZoo主页](#modelzoo主页) - [ModelZoo主页](#modelzoo主页)
<!-- /TOC -->
# MobileNetV2描述 # MobileNetV2描述
MobileNetV2结合硬件感知神经网络架构搜索NAS和NetAdapt算法已经可以移植到手机CPU上运行后续随新架构进一步优化改进。2019年11月20日 MobileNetV2结合硬件感知神经网络架构搜索NAS和NetAdapt算法已经可以移植到手机CPU上运行后续随新架构进一步优化改进。2019年11月20日
@ -44,10 +41,10 @@ MobileNetV2总体网络架构如下
使用的数据集:[imagenet](http://www.image-net.org/) 使用的数据集:[imagenet](http://www.image-net.org/)
- 数据集大小125G共1000个类、1.2万张彩色图像 - 数据集大小125G共1000个类、1.2万张彩色图像
- 训练集120G共1.2万张图像 - 训练集120G共1.2万张图像
- 测试集5G共5万张图像 - 测试集5G共5万张图像
- 数据格式RGB - 数据格式RGB
- 注数据在src/dataset.py中处理。 - 注数据在src/dataset.py中处理。
# 特性 # 特性
@ -59,12 +56,12 @@ MobileNetV2总体网络架构如下
# 环境要求 # 环境要求
- 硬件Ascend/GPU/CPU - 硬件Ascend/GPU/CPU
- 使用Ascend、GPU或CPU处理器来搭建硬件环境。如需试用Ascend处理器请发送[申请表](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx)至ascend@huawei.com审核通过即可获得资源。 - 使用Ascend、GPU或CPU处理器来搭建硬件环境。如需试用Ascend处理器请发送[申请表](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx)至ascend@huawei.com审核通过即可获得资源。
- 框架 - 框架
- [MindSpore](https://www.mindspore.cn/install/en) - [MindSpore](https://www.mindspore.cn/install/en)
- 如需查看详情,请参见如下资源: - 如需查看详情,请参见如下资源:
- [MindSpore教程](https://www.mindspore.cn/tutorial/training/zh-CN/master/index.html) - [MindSpore教程](https://www.mindspore.cn/tutorial/training/zh-CN/master/index.html)
- [MindSpore Python API](https://www.mindspore.cn/doc/api_python/zh-CN/master/index.html) - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/zh-CN/master/index.html)
# 脚本说明 # 脚本说明
@ -96,9 +93,11 @@ MobileNetV2总体网络架构如下
使用python或shell脚本开始训练。shell脚本的使用方法如下 使用python或shell脚本开始训练。shell脚本的使用方法如下
- Ascend: sh run_train.sh Ascend [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [RANK_TABLE_FILE] [DATASET_PATH] [CKPT_PATH] [FREEZE_LAYER] - Ascend: sh run_train.sh Ascend [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [RANK_TABLE_FILE] [DATASET_PATH] [CKPT_PATH] [FREEZE_LAYER] [FILTER_HEAD]
- GPU: sh run_trian.sh GPU [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH] [CKPT_PATH] [FREEZE_LAYER] - GPU: sh run_trian.sh GPU [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH] [CKPT_PATH] [FREEZE_LAYER] [FILTER_HEAD]
- CPU: sh run_trian.sh CPU [DATASET_PATH] [CKPT_PATH] [FREEZE_LAYER] - CPU: sh run_trian.sh CPU [DATASET_PATH] [CKPT_PATH] [FREEZE_LAYER] [FILTER_HEAD]
`CKPT_PATH` `FREEZE_LAYER``FILTER_HEAD` 是可选择的选项, 如果设置`CKPT_PATH`, `FREEZE_LAYER` 也必须同时设置. `FREEZE_LAYER` 可以是 ["none", "backbone"], 如果设置 `FREEZE_LAYER`="backbone", 训练过程中backbone中的参数会被冻结同时不会从checkpoint中加载head部分的参数. 如果`FILTER_HEAD`=True, 不会从checkpoint中加载head部分的参数.
> RANK_TABLE_FILE 是在Ascned上运行分布式任务时HCCL的配置文件 > RANK_TABLE_FILE 是在Ascned上运行分布式任务时HCCL的配置文件
> 我们列出使用分布式服务常见的使用限制详细的可以查看HCCL对应的使用文档。 > 我们列出使用分布式服务常见的使用限制详细的可以查看HCCL对应的使用文档。
@ -122,14 +121,14 @@ MobileNetV2总体网络架构如下
# 全网微调示例 # 全网微调示例
python: python:
Ascend: python train.py --platform Ascend --dataset_path [TRAIN_DATASET_PATH] --pretrain_ckpt [CKPT_PATH] --freeze_layer none Ascend: python train.py --platform Ascend --dataset_path [TRAIN_DATASET_PATH] --pretrain_ckpt [CKPT_PATH] --freeze_layer none --filter_head True
GPU: python train.py --platform GPU --dataset_path [TRAIN_DATASET_PATH] --pretrain_ckpt [CKPT_PATH] --freeze_layer none GPU: python train.py --platform GPU --dataset_path [TRAIN_DATASET_PATH] --pretrain_ckpt [CKPT_PATH] --freeze_layer none --filter_head True
CPU: python train.py --platform CPU --dataset_path [TRAIN_DATASET_PATH] --pretrain_ckpt [CKPT_PATH] --freeze_layer none CPU: python train.py --platform CPU --dataset_path [TRAIN_DATASET_PATH] --pretrain_ckpt [CKPT_PATH] --freeze_layer none --filter_head True
shell: shell:
Ascend: sh run_train.sh Ascend 8 0,1,2,3,4,5,6,7 hccl_config.json [TRAIN_DATASET_PATH] [CKPT_PATH] none Ascend: sh run_train.sh Ascend 8 0,1,2,3,4,5,6,7 hccl_config.json [TRAIN_DATASET_PATH] [CKPT_PATH] none True
GPU: sh run_train.sh GPU 8 0,1,2,3,4,5,6,7 [TRAIN_DATASET_PATH] [CKPT_PATH] none GPU: sh run_train.sh GPU 8 0,1,2,3,4,5,6,7 [TRAIN_DATASET_PATH] [CKPT_PATH] none True
CPU: sh run_train.sh CPU [TRAIN_DATASET_PATH] [CKPT_PATH] none CPU: sh run_train.sh CPU [TRAIN_DATASET_PATH] [CKPT_PATH] none True
# 全连接层微调示例 # 全连接层微调示例
python: python:
@ -193,7 +192,7 @@ result:{'acc':0.71976314102564111} ckpt=./ckpt_0/mobilenet-200_625.ckpt
修改`src/config.py`文件中的`export_mode`和`export_file`, 运行`export.py`。 修改`src/config.py`文件中的`export_mode`和`export_file`, 运行`export.py`。
``` ```shell
python export.py --platform [PLATFORM] --pretrain_ckpt [CKPT_PATH] python export.py --platform [PLATFORM] --pretrain_ckpt [CKPT_PATH]
``` ```

@ -16,6 +16,25 @@
run_ascend() run_ascend()
{ {
if [ $# = 5 ] ; then
PRETRAINED_CKPT=""
FREEZE_LAYER="none"
FILTER_HEAD="False"
elif [ $# = 7 ] ; then
PRETRAINED_CKPT=$6
FREEZE_LAYER=$7
FILTER_HEAD="False"
elif [ $# = 8 ] ; then
PRETRAINED_CKPT=$6
FREEZE_LAYER=$7
FILTER_HEAD=$8
else
echo "Usage:
Ascend: sh run_train.sh Ascend [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [RANK_TABLE_FILE] [DATASET_PATH] [CKPT_PATH](optional) [FREEZE_LAYER](optional) [FILTER_HEAD](optional)
Ascend: sh run_train.sh Ascend [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [RANK_TABLE_FILE] [DATASET_PATH]"
exit 1
fi;
if [ $2 -lt 1 ] && [ $2 -gt 8 ] if [ $2 -lt 1 ] && [ $2 -gt 8 ]
then then
echo "error: DEVICE_NUM=$2 is not in (1-8)" echo "error: DEVICE_NUM=$2 is not in (1-8)"
@ -59,8 +78,9 @@ run_ascend()
python train.py \ python train.py \
--platform=$1 \ --platform=$1 \
--dataset_path=$5 \ --dataset_path=$5 \
--pretrain_ckpt=$6 \ --pretrain_ckpt=$PRETRAINED_CKPT \
--freeze_layer=$7 \ --freeze_layer=$FREEZE_LAYER \
--filter_head=$FILTER_HEAD \
&> log$i.log & &> log$i.log &
cd .. cd ..
done done
@ -68,6 +88,24 @@ run_ascend()
run_gpu() run_gpu()
{ {
if [ $# = 4 ] ; then
PRETRAINED_CKPT=""
FREEZE_LAYER="none"
FILTER_HEAD="False"
elif [ $# = 6 ] ; then
PRETRAINED_CKPT=$5
FREEZE_LAYER=$6
FILTER_HEAD="False"
elif [ $# = 7 ] ; then
PRETRAINED_CKPT=$5
FREEZE_LAYER=$6
FILTER_HEAD=$7
else
echo "Usage:
GPU: sh run_train.sh GPU [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH] [CKPT_PATH](optional) [FREEZE_LAYER](optional) [FILTER_HEAD](optional)
GPU: sh run_train.sh GPU [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH]"
exit 1
fi;
if [ $2 -lt 1 ] && [ $2 -gt 8 ] if [ $2 -lt 1 ] && [ $2 -gt 8 ]
then then
echo "error: DEVICE_NUM=$2 is not in (1-8)" echo "error: DEVICE_NUM=$2 is not in (1-8)"
@ -94,14 +132,32 @@ run_gpu()
python ${BASEPATH}/../train.py \ python ${BASEPATH}/../train.py \
--platform=$1 \ --platform=$1 \
--dataset_path=$4 \ --dataset_path=$4 \
--pretrain_ckpt=$5 \ --pretrain_ckpt=$PRETRAINED_CKPT \
--freeze_layer=$6 \ --freeze_layer=$FREEZE_LAYER \
--filter_head=$FILTER_HEAD \
&> ../train.log & # dataset train folder &> ../train.log & # dataset train folder
} }
run_cpu() run_cpu()
{ {
if [ $# = 2 ] ; then
PRETRAINED_CKPT=""
FREEZE_LAYER="none"
FILTER_HEAD="False"
elif [ $# = 4 ] ; then
PRETRAINED_CKPT=$3
FREEZE_LAYER=$4
FILTER_HEAD="False"
elif [ $# = 5 ] ; then
PRETRAINED_CKPT=$3
FREEZE_LAYER=$4
FILTER_HEAD=$5
else
echo "Usage:
CPU: sh run_train.sh CPU [DATASET_PATH]
CPU: sh run_train.sh CPU [DATASET_PATH] [CKPT_PATH](optional) [FREEZE_LAYER](optional) [FILTER_HEAD](optional)"
exit 1
fi;
if [ ! -d $2 ] if [ ! -d $2 ]
then then
echo "error: DATASET_PATH=$2 is not a directory" echo "error: DATASET_PATH=$2 is not a directory"
@ -120,22 +176,12 @@ run_cpu()
python ${BASEPATH}/../train.py \ python ${BASEPATH}/../train.py \
--platform=$1 \ --platform=$1 \
--dataset_path=$2 \ --dataset_path=$2 \
--pretrain_ckpt=$3 \ --pretrain_ckpt=$PRETRAINED_CKPT \
--freeze_layer=$4 \ --freeze_layer=$FREEZE_LAYER \
--filter_head=$FILTER_HEAD \
&> ../train.log & # dataset train folder &> ../train.log & # dataset train folder
} }
if [ $# -gt 7 ] || [ $# -lt 4 ]
then
echo "Usage:
Ascend: sh run_train.sh Ascend [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [RANK_TABLE_FILE] [DATASET_PATH] [CKPT_PATH] [FREEZE_LAYER]
Ascend: sh run_train.sh Ascend [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [RANK_TABLE_FILE] [DATASET_PATH]
GPU: sh run_train.sh GPU [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH] [CKPT_PATH] [FREEZE_LAYER]
GPU: sh run_train.sh GPU [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH]
CPU: sh run_train.sh CPU [DATASET_PATH] [CKPT_PATH] [FREEZE_LAYER]"
exit 1
fi
if [ $1 = "Ascend" ] ; then if [ $1 = "Ascend" ] ; then
run_ascend "$@" run_ascend "$@"
elif [ $1 = "GPU" ] ; then elif [ $1 = "GPU" ] ; then

@ -26,6 +26,8 @@ def train_parse_args():
train_parser.add_argument('--freeze_layer', type=str, default="", choices=["", "none", "backbone"], \ train_parser.add_argument('--freeze_layer', type=str, default="", choices=["", "none", "backbone"], \
help="freeze the weights of network from start to which layers") help="freeze the weights of network from start to which layers")
train_parser.add_argument('--run_distribute', type=ast.literal_eval, default=True, help='Run distribute') train_parser.add_argument('--run_distribute', type=ast.literal_eval, default=True, help='Run distribute')
train_parser.add_argument('--filter_head', type=ast.literal_eval, default=False,\
help='Filter head weight parameters when load checkpoint, default is False.')
train_args = train_parser.parse_args() train_args = train_parser.parse_args()
train_args.is_training = True train_args.is_training = True
if train_args.platform == "CPU": if train_args.platform == "CPU":

@ -109,18 +109,10 @@ class Monitor(Callback):
1, cb_params.epoch_num, cur_step_in_epoch, cb_params.batch_num, step_loss, 1, cb_params.epoch_num, cur_step_in_epoch, cb_params.batch_num, step_loss,
np.mean(self.losses), step_mseconds, self.lr_init[cb_params.cur_step_num - 1])) np.mean(self.losses), step_mseconds, self.lr_init[cb_params.cur_step_num - 1]))
def load_ckpt(network, pretrain_ckpt_path, trainable=True): def load_ckpt(network, pretrain_ckpt_path, trainable=True):
""" """load checkpoint into network."""
incremental_learning or not
"""
param_dict = load_checkpoint(pretrain_ckpt_path) param_dict = load_checkpoint(pretrain_ckpt_path)
if hasattr(network, "head"):
head_param = network.head.parameters_dict()
for k, v in head_param.items():
if param_dict[k].shape != v.shape:
param_dict.pop(k)
param_dict.pop(f"moments.{k}")
print(f"Filter {k} don't load weights from checkpoint.")
load_param_into_net(network, param_dict) load_param_into_net(network, param_dict)
if not trainable: if not trainable:
for param in network.get_parameters(): for param in network.get_parameters():

@ -59,6 +59,8 @@ if __name__ == '__main__':
if args_opt.freeze_layer == "backbone": if args_opt.freeze_layer == "backbone":
load_ckpt(backbone_net, args_opt.pretrain_ckpt, trainable=False) load_ckpt(backbone_net, args_opt.pretrain_ckpt, trainable=False)
step_size = extract_features(backbone_net, args_opt.dataset_path, config) step_size = extract_features(backbone_net, args_opt.dataset_path, config)
elif args_opt.filter_head:
load_ckpt(backbone_net, args_opt.pretrain_ckpt)
else: else:
load_ckpt(net, args_opt.pretrain_ckpt) load_ckpt(net, args_opt.pretrain_ckpt)
if step_size == 0: if step_size == 0:

Loading…
Cancel
Save