!8635 Add SqueezeNet to model zoo

From: @GreyZzzzzzXh
Reviewed-by: 
Signed-off-by:
pull/8635/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 22528d117a

File diff suppressed because it is too large Load Diff

@ -0,0 +1,95 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""eval squeezenet."""
import os
import argparse
from mindspore import context
from mindspore.common import set_seed
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
from mindspore.train.model import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.CrossEntropySmooth import CrossEntropySmooth
parser = argparse.ArgumentParser(description='Image classification')
parser.add_argument('--net', type=str, default='squeezenet', choices=['squeezenet', 'squeezenet_residual'],
help='Model.')
parser.add_argument('--dataset', type=str, default='cifar10', choices=['cifar10', 'imagenet'], help='Dataset.')
parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path')
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
parser.add_argument('--device_target', type=str, default='Ascend', help='Device target')
args_opt = parser.parse_args()
set_seed(1)
if args_opt.net == "squeezenet":
from src.squeezenet import SqueezeNet as squeezenet
if args_opt.dataset == "cifar10":
from src.config import config1 as config
from src.dataset import create_dataset_cifar as create_dataset
else:
from src.config import config2 as config
from src.dataset import create_dataset_imagenet as create_dataset
else:
from src.squeezenet import SqueezeNet_Residual as squeezenet
if args_opt.dataset == "cifar10":
from src.config import config3 as config
from src.dataset import create_dataset_cifar as create_dataset
else:
from src.config import config4 as config
from src.dataset import create_dataset_imagenet as create_dataset
if __name__ == '__main__':
target = args_opt.device_target
# init context
device_id = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE,
device_target=target,
device_id=device_id)
# create dataset
dataset = create_dataset(dataset_path=args_opt.dataset_path,
do_train=False,
batch_size=config.batch_size,
target=target)
step_size = dataset.get_dataset_size()
# define net
net = squeezenet(num_classes=config.class_num)
# load checkpoint
param_dict = load_checkpoint(args_opt.checkpoint_path)
load_param_into_net(net, param_dict)
net.set_train(False)
# define loss
if args_opt.dataset == "imagenet":
if not config.use_label_smooth:
config.label_smooth_factor = 0.0
loss = CrossEntropySmooth(sparse=True,
reduction='mean',
smooth_factor=config.label_smooth_factor,
num_classes=config.class_num)
else:
loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
# define model
model = Model(net,
loss_fn=loss,
metrics={'top_1_accuracy', 'top_5_accuracy'})
# eval model
res = model.eval(dataset)
print("result:", res, "ckpt=", args_opt.checkpoint_path)

@ -0,0 +1,54 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
##############export checkpoint file into air and onnx models#################
python export.py --net squeezenet --dataset cifar10 --checkpoint_path squeezenet_cifar10-120_1562.ckpt
"""
import argparse
import numpy as np
from mindspore import Tensor
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Image classification')
parser.add_argument('--net', type=str, default='squeezenet', choices=['squeezenet', 'squeezenet_residual'],
help='Model.')
parser.add_argument('--dataset', type=str, default='cifar10', choices=['cifar10', 'imagenet'], help='Dataset.')
parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path')
args_opt = parser.parse_args()
if args_opt.net == "squeezenet":
from src.squeezenet import SqueezeNet as squeezenet
else:
from src.squeezenet import SqueezeNet_Residual as squeezenet
if args_opt.dataset == "cifar10":
num_classes = 10
else:
num_classes = 1000
onnx_filename = args_opt.net + '_' + args_opt.dataset + '.onnx'
air_filename = args_opt.net + '_' + args_opt.dataset + '.air'
net = squeezenet(num_classes=num_classes)
assert args_opt.checkpoint_path is not None, "checkpoint_path is None."
param_dict = load_checkpoint(args_opt.checkpoint_path)
load_param_into_net(net, param_dict)
input_arr = Tensor(np.zeros([1, 3, 227, 227], np.float32))
export(net, input_arr, file_name=onnx_filename, file_format="ONNX")
export(net, input_arr, file_name=air_filename, file_format="AIR")

@ -0,0 +1,99 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 4 ] && [ $# != 5 ]
then
echo "Usage: sh scripts/run_distribute_train.sh [squeezenet|squeezenet_residual] [cifar10|imagenet] [RANK_TABLE_FILE] [DATASET_PATH] [PRETRAINED_CKPT_PATH](optional)"
exit 1
fi
if [ $1 != "squeezenet" ] && [ $1 != "squeezenet_residual" ]
then
echo "error: the selected net is neither squeezenet nor squeezenet_residual"
exit 1
fi
if [ $2 != "cifar10" ] && [ $2 != "imagenet" ]
then
echo "error: the selected dataset is neither cifar10 nor imagenet"
exit 1
fi
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
PATH1=$(get_real_path $3)
PATH2=$(get_real_path $4)
if [ $# == 5 ]
then
PATH3=$(get_real_path $5)
fi
if [ ! -f $PATH1 ]
then
echo "error: RANK_TABLE_FILE=$PATH1 is not a file"
exit 1
fi
if [ ! -d $PATH2 ]
then
echo "error: DATASET_PATH=$PATH2 is not a directory"
exit 1
fi
if [ $# == 5 ] && [ ! -f $PATH3 ]
then
echo "error: PRETRAINED_CKPT_PATH=$PATH3 is not a file"
exit 1
fi
ulimit -u unlimited
export DEVICE_NUM=8
export RANK_SIZE=8
export RANK_TABLE_FILE=$PATH1
export SERVER_ID=0
rank_start=$((DEVICE_NUM * SERVER_ID))
for((i=0; i<${DEVICE_NUM}; i++))
do
export DEVICE_ID=${i}
export RANK_ID=$((rank_start + i))
rm -rf ./train_parallel$i
mkdir ./train_parallel$i
cp ./train.py ./train_parallel$i
cp -r ./src ./train_parallel$i
cd ./train_parallel$i || exit
echo "start training for rank $RANK_ID, device $DEVICE_ID"
env > env.log
if [ $# == 4 ]
then
python train.py --net=$1 --dataset=$2 --run_distribute=True --device_num=$DEVICE_NUM --dataset_path=$PATH2 &> log &
fi
if [ $# == 5 ]
then
python train.py --net=$1 --dataset=$2 --run_distribute=True --device_num=$DEVICE_NUM --dataset_path=$PATH2 --pre_trained=$PATH3 &> log &
fi
cd ..
done

@ -0,0 +1,85 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 3 ] && [ $# != 4 ]
then
echo "Usage: sh scripts/run_distribute_train_gpu.sh [squeezenet|squeezenet_residual] [cifar10|imagenet] [DATASET_PATH] [PRETRAINED_CKPT_PATH](optional)"
exit 1
fi
if [ $1 != "squeezenet" ] && [ $1 != "squeezenet_residual" ]
then
echo "error: the selected net is neither squeezenet nor squeezenet_residual"
exit 1
fi
if [ $2 != "cifar10" ] && [ $2 != "imagenet" ]
then
echo "error: the selected dataset is neither cifar10 nor imagenet"
exit 1
fi
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
PATH1=$(get_real_path $3)
if [ $# == 4 ]
then
PATH2=$(get_real_path $4)
fi
if [ ! -d $PATH1 ]
then
echo "error: DATASET_PATH=$PATH1 is not a directory"
exit 1
fi
if [ $# == 5 ] && [ ! -f $PATH2 ]
then
echo "error: PRETRAINED_CKPT_PATH=$PATH2 is not a file"
exit 1
fi
ulimit -u unlimited
export DEVICE_NUM=8
export RANK_SIZE=8
rm -rf ./train_parallel
mkdir ./train_parallel
cp ./train.py ./train_parallel
cp -r ./src ./train_parallel
cd ./train_parallel || exit
if [ $# == 3 ]
then
mpirun --allow-run-as-root -n $RANK_SIZE --output-filename log_output --merge-stderr-to-stdout \
python train.py --net=$1 --dataset=$2 --run_distribute=True \
--device_num=$DEVICE_NUM --device_target="GPU" --dataset_path=$PATH1 &> log &
fi
if [ $# == 4 ]
then
mpirun --allow-run-as-root -n $RANK_SIZE --output-filename log_output --merge-stderr-to-stdout \
python train.py --net=$1 --dataset=$2 --run_distribute=True \
--device_num=$DEVICE_NUM --device_target="GPU" --dataset_path=$PATH1 --pre_trained=$PATH2 &> log &
fi

@ -0,0 +1,76 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 5 ]
then
echo "Usage: sh scripts/run_eval.sh [squeezenet|squeezenet_residual] [cifar10|imagenet] [DEVICE_ID] [DATASET_PATH] [CHECKPOINT_PATH]"
exit 1
fi
if [ $1 != "squeezenet" ] && [ $1 != "squeezenet_residual" ]
then
echo "error: the selected net is neither squeezenet nor squeezenet_residual"
exit 1
fi
if [ $2 != "cifar10" ] && [ $2 != "imagenet" ]
then
echo "error: the selected dataset is neither cifar10 nor imagenet"
exit 1
fi
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
PATH1=$(get_real_path $4)
PATH2=$(get_real_path $5)
if [ ! -d $PATH1 ]
then
echo "error: DATASET_PATH=$PATH1 is not a directory"
exit 1
fi
if [ ! -f $PATH2 ]
then
echo "error: CHECKPOINT_PATH=$PATH2 is not a file"
exit 1
fi
ulimit -u unlimited
export DEVICE_NUM=1
export DEVICE_ID=$3
export RANK_SIZE=$DEVICE_NUM
export RANK_ID=0
if [ -d "eval" ];
then
rm -rf ./eval
fi
mkdir ./eval
cp ./eval.py ./eval
cp -r ./src ./eval
cd ./eval || exit
env > env.log
echo "start evaluation for device $DEVICE_ID"
python eval.py --net=$1 --dataset=$2 --dataset_path=$PATH1 --checkpoint_path=$PATH2 &> log &
cd ..

@ -0,0 +1,76 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 5 ]
then
echo "Usage: sh scripts/run_eval_gpu.sh [squeezenet|squeezenet_residual] [cifar10|imagenet] [DEVICE_ID] [DATASET_PATH] [CHECKPOINT_PATH]"
exit 1
fi
if [ $1 != "squeezenet" ] && [ $1 != "squeezenet_residual" ]
then
echo "error: the selected net is neither squeezenet nor squeezenet_residual"
exit 1
fi
if [ $2 != "cifar10" ] && [ $2 != "imagenet" ]
then
echo "error: the selected dataset is neither cifar10 nor imagenet"
exit 1
fi
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
PATH1=$(get_real_path $4)
PATH2=$(get_real_path $5)
if [ ! -d $PATH1 ]
then
echo "error: DATASET_PATH=$PATH1 is not a directory"
exit 1
fi
if [ ! -f $PATH2 ]
then
echo "error: CHECKPOINT_PATH=$PATH2 is not a file"
exit 1
fi
ulimit -u unlimited
export DEVICE_NUM=1
export DEVICE_ID=$3
export RANK_SIZE=$DEVICE_NUM
export RANK_ID=0
if [ -d "eval" ];
then
rm -rf ./eval
fi
mkdir ./eval
cp ./eval.py ./eval
cp -r ./src ./eval
cd ./eval || exit
env > env.log
echo "start evaluation for device $DEVICE_ID"
python eval.py --net=$1 --dataset=$2 --dataset_path=$PATH1 --checkpoint_path=$PATH2 --device_target="GPU" &> log &
cd ..

@ -0,0 +1,87 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 4 ] && [ $# != 5 ]
then
echo "Usage: sh scripts/run_standalone_train.sh [squeezenet|squeezenet_residual] [cifar10|imagenet] [DEVICE_ID] [DATASET_PATH] [PRETRAINED_CKPT_PATH](optional)"
exit 1
fi
if [ $1 != "squeezenet" ] && [ $1 != "squeezenet_residual" ]
then
echo "error: the selected net is neither squeezenet nor squeezenet_residual"
exit 1
fi
if [ $2 != "cifar10" ] && [ $2 != "imagenet" ]
then
echo "error: the selected dataset is neither cifar10 nor imagenet"
exit 1
fi
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
PATH1=$(get_real_path $4)
if [ $# == 5 ]
then
PATH2=$(get_real_path $5)
fi
if [ ! -d $PATH1 ]
then
echo "error: DATASET_PATH=$PATH1 is not a directory"
exit 1
fi
if [ $# == 5 ] && [ ! -f $PATH2 ]
then
echo "error: PRETRAINED_CKPT_PATH=$PATH2 is not a file"
exit 1
fi
ulimit -u unlimited
export DEVICE_NUM=1
export DEVICE_ID=$3
export RANK_ID=0
export RANK_SIZE=1
if [ -d "train" ];
then
rm -rf ./train
fi
mkdir ./train
cp ./train.py ./train
cp -r ./src ./train
cd ./train || exit
echo "start training for device $DEVICE_ID"
env > env.log
if [ $# == 4 ]
then
python train.py --net=$1 --dataset=$2 --dataset_path=$PATH1 &> log &
fi
if [ $# == 5 ]
then
python train.py --net=$1 --dataset=$2 --dataset_path=$PATH1 --pre_trained=$PATH2 &> log &
fi
cd ..

@ -0,0 +1,87 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 4 ] && [ $# != 5 ]
then
echo "Usage: sh scripts/run_standalone_train_gpu.sh [squeezenet|squeezenet_residual] [cifar10|imagenet] [DEVICE_ID] [DATASET_PATH] [PRETRAINED_CKPT_PATH](optional)"
exit 1
fi
if [ $1 != "squeezenet" ] && [ $1 != "squeezenet_residual" ]
then
echo "error: the selected net is neither squeezenet nor squeezenet_residual"
exit 1
fi
if [ $2 != "cifar10" ] && [ $2 != "imagenet" ]
then
echo "error: the selected dataset is neither cifar10 nor imagenet"
exit 1
fi
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
PATH1=$(get_real_path $4)
if [ $# == 5 ]
then
PATH2=$(get_real_path $5)
fi
if [ ! -d $PATH1 ]
then
echo "error: DATASET_PATH=$PATH1 is not a directory"
exit 1
fi
if [ $# == 5 ] && [ ! -f $PATH2 ]
then
echo "error: PRETRAINED_CKPT_PATH=$PATH2 is not a file"
exit 1
fi
ulimit -u unlimited
export DEVICE_NUM=1
export DEVICE_ID=$3
export RANK_ID=0
export RANK_SIZE=1
if [ -d "train" ];
then
rm -rf ./train
fi
mkdir ./train
cp ./train.py ./train
cp -r ./src ./train
cd ./train || exit
echo "start training for device $DEVICE_ID"
env > env.log
if [ $# == 4 ]
then
python train.py --net=$1 --dataset=$2 --device_target="GPU" --dataset_path=$PATH1 &> log &
fi
if [ $# == 5 ]
then
python train.py --net=$1 --dataset=$2 --device_target="GPU" --dataset_path=$PATH1 --pre_trained=$PATH2 &> log &
fi
cd ..

@ -0,0 +1,38 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""define loss function for network"""
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.nn.loss.loss import _Loss
from mindspore.ops import functional as F
from mindspore.ops import operations as P
class CrossEntropySmooth(_Loss):
"""CrossEntropy"""
def __init__(self, sparse=True, reduction='mean', smooth_factor=0., num_classes=1000):
super(CrossEntropySmooth, self).__init__()
self.onehot = P.OneHot()
self.sparse = sparse
self.on_value = Tensor(1.0 - smooth_factor, mstype.float32)
self.off_value = Tensor(1.0 * smooth_factor / (num_classes - 1), mstype.float32)
self.ce = nn.SoftmaxCrossEntropyWithLogits(reduction=reduction)
def construct(self, logit, label):
if self.sparse:
label = self.onehot(label, F.shape(logit)[1], self.on_value, self.off_value)
loss = self.ce(logit, label)
return loss

@ -0,0 +1,102 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
network config setting, will be used in train.py and eval.py
"""
from easydict import EasyDict as ed
# config for squeezenet, cifar10
config1 = ed({
"class_num": 10,
"batch_size": 32,
"loss_scale": 1024,
"momentum": 0.9,
"weight_decay": 1e-4,
"epoch_size": 120,
"pretrain_epoch_size": 0,
"save_checkpoint": True,
"save_checkpoint_epochs": 1,
"keep_checkpoint_max": 10,
"save_checkpoint_path": "./",
"warmup_epochs": 5,
"lr_decay_mode": "poly",
"lr_init": 0,
"lr_end": 0,
"lr_max": 0.01
})
# config for squeezenet, imagenet
config2 = ed({
"class_num": 1000,
"batch_size": 32,
"loss_scale": 1024,
"momentum": 0.9,
"weight_decay": 7e-5,
"epoch_size": 200,
"pretrain_epoch_size": 0,
"save_checkpoint": True,
"save_checkpoint_epochs": 1,
"keep_checkpoint_max": 10,
"save_checkpoint_path": "./",
"warmup_epochs": 0,
"lr_decay_mode": "poly",
"use_label_smooth": True,
"label_smooth_factor": 0.1,
"lr_init": 0,
"lr_end": 0,
"lr_max": 0.01
})
# config for squeezenet_residual, cifar10
config3 = ed({
"class_num": 10,
"batch_size": 32,
"loss_scale": 1024,
"momentum": 0.9,
"weight_decay": 1e-4,
"epoch_size": 150,
"pretrain_epoch_size": 0,
"save_checkpoint": True,
"save_checkpoint_epochs": 1,
"keep_checkpoint_max": 10,
"save_checkpoint_path": "./",
"warmup_epochs": 5,
"lr_decay_mode": "linear",
"lr_init": 0,
"lr_end": 0,
"lr_max": 0.01
})
# config for squeezenet_residual, imagenet
config4 = ed({
"class_num": 1000,
"batch_size": 32,
"loss_scale": 1024,
"momentum": 0.9,
"weight_decay": 7e-5,
"epoch_size": 300,
"pretrain_epoch_size": 0,
"save_checkpoint": True,
"save_checkpoint_epochs": 1,
"keep_checkpoint_max": 10,
"save_checkpoint_path": "./",
"warmup_epochs": 0,
"lr_decay_mode": "cosine",
"use_label_smooth": True,
"label_smooth_factor": 0.1,
"lr_init": 0,
"lr_end": 0,
"lr_max": 0.01
})

@ -0,0 +1,191 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
create train or eval dataset.
"""
import os
import mindspore.common.dtype as mstype
import mindspore.dataset.engine as de
import mindspore.dataset.vision.c_transforms as C
import mindspore.dataset.transforms.c_transforms as C2
from mindspore.communication.management import init, get_rank, get_group_size
def create_dataset_cifar(dataset_path,
do_train,
repeat_num=1,
batch_size=32,
target="Ascend"):
"""
create a train or evaluate cifar10 dataset
Args:
dataset_path(string): the path of dataset.
do_train(bool): whether dataset is used for train or eval.
repeat_num(int): the repeat times of dataset. Default: 1
batch_size(int): the batch size of dataset. Default: 32
target(str): the device target. Default: Ascend
Returns:
dataset
"""
if target == "Ascend":
device_num, rank_id = _get_rank_info()
else:
init()
rank_id = get_rank()
device_num = get_group_size()
if device_num == 1:
ds = de.Cifar10Dataset(dataset_path,
num_parallel_workers=8,
shuffle=True)
else:
ds = de.Cifar10Dataset(dataset_path,
num_parallel_workers=8,
shuffle=True,
num_shards=device_num,
shard_id=rank_id)
# define map operations
if do_train:
trans = [
C.RandomCrop((32, 32), (4, 4, 4, 4)),
C.RandomHorizontalFlip(prob=0.5),
C.RandomColorAdjust(brightness=0.4, contrast=0.4, saturation=0.4),
C.Resize((227, 227)),
C.Rescale(1.0 / 255.0, 0.0),
C.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),
C.CutOut(112),
C.HWC2CHW()
]
else:
trans = [
C.Resize((227, 227)),
C.Rescale(1.0 / 255.0, 0.0),
C.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),
C.HWC2CHW()
]
type_cast_op = C2.TypeCast(mstype.int32)
ds = ds.map(operations=type_cast_op,
input_columns="label",
num_parallel_workers=8)
ds = ds.map(operations=trans,
input_columns="image",
num_parallel_workers=8)
# apply batch operations
ds = ds.batch(batch_size, drop_remainder=True)
# apply dataset repeat operation
ds = ds.repeat(repeat_num)
return ds
def create_dataset_imagenet(dataset_path,
do_train,
repeat_num=1,
batch_size=32,
target="Ascend"):
"""
create a train or eval imagenet dataset
Args:
dataset_path(string): the path of dataset.
do_train(bool): whether dataset is used for train or eval.
repeat_num(int): the repeat times of dataset. Default: 1
batch_size(int): the batch size of dataset. Default: 32
target(str): the device target. Default: Ascend
Returns:
dataset
"""
if target == "Ascend":
device_num, rank_id = _get_rank_info()
else:
init()
rank_id = get_rank()
device_num = get_group_size()
if device_num == 1:
ds = de.ImageFolderDataset(dataset_path,
num_parallel_workers=8,
shuffle=True)
else:
ds = de.ImageFolderDataset(dataset_path,
num_parallel_workers=8,
shuffle=True,
num_shards=device_num,
shard_id=rank_id)
image_size = 227
mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
std = [0.229 * 255, 0.224 * 255, 0.225 * 255]
# define map operations
if do_train:
trans = [
C.RandomCropDecodeResize(image_size,
scale=(0.08, 1.0),
ratio=(0.75, 1.333)),
C.RandomHorizontalFlip(prob=0.5),
C.RandomColorAdjust(brightness=0.4, contrast=0.4, saturation=0.4),
C.Normalize(mean=mean, std=std),
C.CutOut(112),
C.HWC2CHW()
]
else:
trans = [
C.Decode(),
C.Resize((256, 256)),
C.CenterCrop(image_size),
C.Normalize(mean=mean, std=std),
C.HWC2CHW()
]
type_cast_op = C2.TypeCast(mstype.int32)
ds = ds.map(operations=type_cast_op,
input_columns="label",
num_parallel_workers=8)
ds = ds.map(operations=trans,
input_columns="image",
num_parallel_workers=8)
# apply batch operations
ds = ds.batch(batch_size, drop_remainder=True)
# apply dataset repeat operation
ds = ds.repeat(repeat_num)
return ds
def _get_rank_info():
"""
get rank size and rank id
"""
rank_size = int(os.environ.get("RANK_SIZE", 1))
if rank_size > 1:
rank_size = get_group_size()
rank_id = get_rank()
else:
rank_size = 1
rank_id = 0
return rank_size, rank_id

@ -0,0 +1,106 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""learning rate generator"""
import math
import numpy as np
def get_lr(lr_init, lr_end, lr_max, total_epochs, warmup_epochs,
pretrain_epochs, steps_per_epoch, lr_decay_mode):
"""
generate learning rate array
Args:
lr_init(float): init learning rate
lr_end(float): end learning rate
lr_max(float): max learning rate
total_epochs(int): total epoch of training
warmup_epochs(int): number of warmup epochs
pretrain_epochs(int): number of pretrain epochs
steps_per_epoch(int): steps of one epoch
lr_decay_mode(string): learning rate decay mode,
including steps, poly, linear or cosine
Returns:
np.array, learning rate array
"""
lr_each_step = []
total_steps = steps_per_epoch * total_epochs
warmup_steps = steps_per_epoch * warmup_epochs
pretrain_steps = steps_per_epoch * pretrain_epochs
decay_steps = total_steps - warmup_steps
if lr_decay_mode == 'steps':
decay_epoch_index = [
0.3 * total_steps, 0.6 * total_steps, 0.8 * total_steps
]
for i in range(total_steps):
if i < decay_epoch_index[0]:
lr = lr_max
elif i < decay_epoch_index[1]:
lr = lr_max * 0.1
elif i < decay_epoch_index[2]:
lr = lr_max * 0.01
else:
lr = lr_max * 0.001
lr_each_step.append(lr)
elif lr_decay_mode == 'poly':
for i in range(total_steps):
if i < warmup_steps:
lr = linear_warmup_lr(i, warmup_steps, lr_max, lr_init)
else:
base = (1.0 - (i - warmup_steps) / decay_steps)
lr = lr_max * base * base
lr_each_step.append(lr)
elif lr_decay_mode == 'linear':
for i in range(total_steps):
if i < warmup_steps:
lr = linear_warmup_lr(i, warmup_steps, lr_max, lr_init)
else:
lr = lr_max - (lr_max - lr_end) * (i -
warmup_steps) / decay_steps
lr_each_step.append(lr)
elif lr_decay_mode == 'cosine':
for i in range(total_steps):
if i < warmup_steps:
lr = linear_warmup_lr(i, warmup_steps, lr_max, lr_init)
else:
linear_decay = (total_steps - i) / decay_steps
cosine_decay = 0.5 * (
1 + math.cos(math.pi * 2 * 0.47 *
(i - warmup_steps) / decay_steps))
decayed = linear_decay * cosine_decay + 0.00001
lr = lr_max * decayed
lr_each_step.append(lr)
else:
raise NotImplementedError(
'Learning rate decay mode [{:s}] cannot be recognized'.format(
lr_decay_mode))
lr_each_step = np.array(lr_each_step).astype(np.float32)
learning_rate = lr_each_step[pretrain_steps:]
return learning_rate
def linear_warmup_lr(current_step, warmup_steps, base_lr, init_lr):
lr_inc = (base_lr - init_lr) / warmup_steps
lr = init_lr + lr_inc * current_step
return lr

@ -0,0 +1,216 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Squeezenet."""
import mindspore.nn as nn
from mindspore.common import initializer as weight_init
from mindspore.ops import operations as P
class Fire(nn.Cell):
def __init__(self, inplanes, squeeze_planes, expand1x1_planes,
expand3x3_planes):
super(Fire, self).__init__()
self.inplanes = inplanes
self.squeeze = nn.Conv2d(inplanes,
squeeze_planes,
kernel_size=1,
has_bias=True)
self.squeeze_activation = nn.ReLU()
self.expand1x1 = nn.Conv2d(squeeze_planes,
expand1x1_planes,
kernel_size=1,
has_bias=True)
self.expand1x1_activation = nn.ReLU()
self.expand3x3 = nn.Conv2d(squeeze_planes,
expand3x3_planes,
kernel_size=3,
pad_mode='same',
has_bias=True)
self.expand3x3_activation = nn.ReLU()
self.concat = P.Concat(axis=1)
def construct(self, x):
x = self.squeeze_activation(self.squeeze(x))
return self.concat((self.expand1x1_activation(self.expand1x1(x)),
self.expand3x3_activation(self.expand3x3(x))))
class SqueezeNet(nn.Cell):
r"""SqueezeNet model architecture from the `"SqueezeNet: AlexNet-level
accuracy with 50x fewer parameters and <0.5MB model size"
<https://arxiv.org/abs/1602.07360>`_ paper.
Get SqueezeNet neural network.
Args:
num_classes (int): Class number.
Returns:
Cell, cell instance of SqueezeNet neural network.
Examples:
>>> net = SqueezeNet(10)
"""
def __init__(self, num_classes=10):
super(SqueezeNet, self).__init__()
self.features = nn.SequentialCell([
nn.Conv2d(3,
96,
kernel_size=7,
stride=2,
pad_mode='valid',
has_bias=True),
nn.ReLU(),
nn.MaxPool2d(kernel_size=3, stride=2),
Fire(96, 16, 64, 64),
Fire(128, 16, 64, 64),
Fire(128, 32, 128, 128),
nn.MaxPool2d(kernel_size=3, stride=2),
Fire(256, 32, 128, 128),
Fire(256, 48, 192, 192),
Fire(384, 48, 192, 192),
Fire(384, 64, 256, 256),
nn.MaxPool2d(kernel_size=3, stride=2),
Fire(512, 64, 256, 256),
])
# Final convolution is initialized differently from the rest
self.final_conv = nn.Conv2d(512,
num_classes,
kernel_size=1,
has_bias=True)
self.dropout = nn.Dropout(keep_prob=0.5)
self.relu = nn.ReLU()
self.mean = P.ReduceMean(keep_dims=True)
self.flatten = nn.Flatten()
self.custom_init_weight()
def custom_init_weight(self):
"""
Init the weight of Conv2d in the net.
"""
for _, cell in self.cells_and_names():
if isinstance(cell, nn.Conv2d):
if cell is self.final_conv:
cell.weight.set_data(
weight_init.initializer('normal', cell.weight.shape,
cell.weight.dtype))
else:
cell.weight.set_data(
weight_init.initializer('he_uniform',
cell.weight.shape,
cell.weight.dtype))
if cell.bias is not None:
cell.bias.set_data(
weight_init.initializer('zeros', cell.bias.shape,
cell.bias.dtype))
def construct(self, x):
x = self.features(x)
x = self.dropout(x)
x = self.final_conv(x)
x = self.relu(x)
x = self.mean(x, (2, 3))
x = self.flatten(x)
return x
class SqueezeNet_Residual(nn.Cell):
r"""SqueezeNet with simple bypass model architecture from the `"SqueezeNet:
AlexNet-level accuracy with 50x fewer parameters and <0.5MB model size"
<https://arxiv.org/abs/1602.07360>`_ paper.
Get SqueezeNet with simple bypass neural network.
Args:
num_classes (int): Class number.
Returns:
Cell, cell instance of SqueezeNet with simple bypass neural network.
Examples:
>>> net = SqueezeNet_Residual(10)
"""
def __init__(self, num_classes=10):
super(SqueezeNet_Residual, self).__init__()
self.conv1 = nn.Conv2d(3,
96,
kernel_size=7,
stride=2,
pad_mode='valid',
has_bias=True)
self.fire2 = Fire(96, 16, 64, 64)
self.fire3 = Fire(128, 16, 64, 64)
self.fire4 = Fire(128, 32, 128, 128)
self.fire5 = Fire(256, 32, 128, 128)
self.fire6 = Fire(256, 48, 192, 192)
self.fire7 = Fire(384, 48, 192, 192)
self.fire8 = Fire(384, 64, 256, 256)
self.fire9 = Fire(512, 64, 256, 256)
# Final convolution is initialized differently from the rest
self.conv10 = nn.Conv2d(512, num_classes, kernel_size=1, has_bias=True)
self.relu = nn.ReLU()
self.max_pool2d = nn.MaxPool2d(kernel_size=3, stride=2)
self.add = P.TensorAdd()
self.dropout = nn.Dropout(keep_prob=0.5)
self.mean = P.ReduceMean(keep_dims=True)
self.flatten = nn.Flatten()
self.custom_init_weight()
def custom_init_weight(self):
"""
Init the weight of Conv2d in the net.
"""
for _, cell in self.cells_and_names():
if isinstance(cell, nn.Conv2d):
if cell is self.conv10:
cell.weight.set_data(
weight_init.initializer('normal', cell.weight.shape,
cell.weight.dtype))
else:
cell.weight.set_data(
weight_init.initializer('xavier_uniform',
cell.weight.shape,
cell.weight.dtype))
if cell.bias is not None:
cell.bias.set_data(
weight_init.initializer('zeros', cell.bias.shape,
cell.bias.dtype))
def construct(self, x):
x = self.conv1(x)
x = self.relu(x)
x = self.max_pool2d(x)
x = self.fire2(x)
x = self.add(x, self.fire3(x))
x = self.fire4(x)
x = self.max_pool2d(x)
x = self.add(x, self.fire5(x))
x = self.fire6(x)
x = self.add(x, self.fire7(x))
x = self.fire8(x)
x = self.max_pool2d(x)
x = self.add(x, self.fire9(x))
x = self.dropout(x)
x = self.conv10(x)
x = self.relu(x)
x = self.mean(x, (2, 3))
x = self.flatten(x)
return x

@ -0,0 +1,169 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""train squeezenet."""
import os
import argparse
from mindspore import context
from mindspore import Tensor
from mindspore.nn.optim.momentum import Momentum
from mindspore.train.model import Model
from mindspore.context import ParallelMode
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
from mindspore.train.loss_scale_manager import FixedLossScaleManager
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.communication.management import init, get_rank, get_group_size
from mindspore.common import set_seed
from src.lr_generator import get_lr
from src.CrossEntropySmooth import CrossEntropySmooth
parser = argparse.ArgumentParser(description='Image classification')
parser.add_argument('--net', type=str, default='squeezenet', choices=['squeezenet', 'squeezenet_residual'],
help='Model.')
parser.add_argument('--dataset', type=str, default='cifar10', choices=['cifar10', 'imagenet'], help='Dataset.')
parser.add_argument('--run_distribute', type=bool, default=False, help='Run distribute')
parser.add_argument('--device_num', type=int, default=1, help='Device num.')
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
parser.add_argument('--device_target', type=str, default='Ascend', help='Device target')
parser.add_argument('--pre_trained', type=str, default=None, help='Pretrained checkpoint path')
args_opt = parser.parse_args()
set_seed(1)
if args_opt.net == "squeezenet":
from src.squeezenet import SqueezeNet as squeezenet
if args_opt.dataset == "cifar10":
from src.config import config1 as config
from src.dataset import create_dataset_cifar as create_dataset
else:
from src.config import config2 as config
from src.dataset import create_dataset_imagenet as create_dataset
else:
from src.squeezenet import SqueezeNet_Residual as squeezenet
if args_opt.dataset == "cifar10":
from src.config import config3 as config
from src.dataset import create_dataset_cifar as create_dataset
else:
from src.config import config4 as config
from src.dataset import create_dataset_imagenet as create_dataset
if __name__ == '__main__':
target = args_opt.device_target
ckpt_save_dir = config.save_checkpoint_path
# init context
context.set_context(mode=context.GRAPH_MODE,
device_target=target)
if args_opt.run_distribute:
if target == "Ascend":
device_id = int(os.getenv('DEVICE_ID'))
context.set_context(device_id=device_id,
enable_auto_mixed_precision=True)
context.set_auto_parallel_context(
device_num=args_opt.device_num,
parallel_mode=ParallelMode.DATA_PARALLEL,
gradients_mean=True)
init()
# GPU target
else:
init()
context.set_auto_parallel_context(
device_num=get_group_size(),
parallel_mode=ParallelMode.DATA_PARALLEL,
gradients_mean=True)
ckpt_save_dir = config.save_checkpoint_path + "ckpt_" + str(
get_rank()) + "/"
# create dataset
dataset = create_dataset(dataset_path=args_opt.dataset_path,
do_train=True,
repeat_num=1,
batch_size=config.batch_size,
target=target)
step_size = dataset.get_dataset_size()
# define net
net = squeezenet(num_classes=config.class_num)
# load checkpoint
if args_opt.pre_trained:
param_dict = load_checkpoint(args_opt.pre_trained)
load_param_into_net(net, param_dict)
# init lr
lr = get_lr(lr_init=config.lr_init,
lr_end=config.lr_end,
lr_max=config.lr_max,
total_epochs=config.epoch_size,
warmup_epochs=config.warmup_epochs,
pretrain_epochs=config.pretrain_epoch_size,
steps_per_epoch=step_size,
lr_decay_mode=config.lr_decay_mode)
lr = Tensor(lr)
# define loss
if args_opt.dataset == "imagenet":
if not config.use_label_smooth:
config.label_smooth_factor = 0.0
loss = CrossEntropySmooth(sparse=True,
reduction='mean',
smooth_factor=config.label_smooth_factor,
num_classes=config.class_num)
else:
loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
# define opt, model
if target == "Ascend":
loss_scale = FixedLossScaleManager(config.loss_scale,
drop_overflow_update=False)
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()),
lr,
config.momentum,
config.weight_decay,
config.loss_scale,
use_nesterov=True)
model = Model(net,
loss_fn=loss,
optimizer=opt,
loss_scale_manager=loss_scale,
metrics={'acc'},
amp_level="O2",
keep_batchnorm_fp32=False)
else:
# GPU target
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()),
lr,
config.momentum,
config.weight_decay,
use_nesterov=True)
model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'})
# define callbacks
time_cb = TimeMonitor(data_size=step_size)
loss_cb = LossMonitor()
cb = [time_cb, loss_cb]
if config.save_checkpoint:
config_ck = CheckpointConfig(
save_checkpoint_steps=config.save_checkpoint_epochs * step_size,
keep_checkpoint_max=config.keep_checkpoint_max)
ckpt_cb = ModelCheckpoint(prefix=args_opt.net + '_' + args_opt.dataset,
directory=ckpt_save_dir,
config=config_ck)
cb += [ckpt_cb]
# train model
model.train(config.epoch_size - config.pretrain_epoch_size,
dataset,
callbacks=cb)
Loading…
Cancel
Save