!11029 FCN8s
From: @zhu_wenyong Reviewed-by: @linqingke,@oacjiewen Signed-off-by: @linqingkepull/11029/MERGE
commit
d4ef0452a6
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,213 @@
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""eval FCN8s."""
|
||||
|
||||
import argparse
|
||||
import numpy as np
|
||||
import cv2
|
||||
from PIL import Image
|
||||
from mindspore import Tensor
|
||||
import mindspore.common.dtype as mstype
|
||||
import mindspore.nn as nn
|
||||
from mindspore import context
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from src.nets.FCN8s import FCN8s
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser('mindspore FCN8s eval')
|
||||
|
||||
# val data
|
||||
parser.add_argument('--data_root', type=str, default='../VOCdevkit/VOC2012/', help='root path of val data')
|
||||
parser.add_argument('--batch_size', type=int, default=16, help='batch size')
|
||||
parser.add_argument('--data_lst', type=str, default='../VOCdevkit/VOC2012/ImageSets/Segmentation/val.txt',
|
||||
help='list of val data')
|
||||
parser.add_argument('--crop_size', type=int, default=512, help='crop size')
|
||||
parser.add_argument('--image_mean', type=list, default=[103.53, 116.28, 123.675], help='image mean')
|
||||
parser.add_argument('--image_std', type=list, default=[57.375, 57.120, 58.395], help='image std')
|
||||
parser.add_argument('--scales', type=float, default=[1.0], action='append', help='scales of evaluation')
|
||||
parser.add_argument('--flip', type=bool, default=False, help='perform left-right flip')
|
||||
parser.add_argument('--ignore_label', type=int, default=255, help='ignore label')
|
||||
parser.add_argument('--num_classes', type=int, default=21, help='number of classes')
|
||||
|
||||
# model
|
||||
parser.add_argument('--model', type=str, default='FCN8s', help='select model')
|
||||
parser.add_argument('--freeze_bn', action='store_true', default=False, help='freeze bn')
|
||||
parser.add_argument('--ckpt_path', type=str, default='model_new/FCN8s-500_82.ckpt', help='model to evaluate')
|
||||
|
||||
parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU'],
|
||||
help='device where the code will be implemented (default: Ascend)')
|
||||
parser.add_argument('--device_id', type=int, default=0, help='device id of GPU or Ascend. (Default: None)')
|
||||
|
||||
args, _ = parser.parse_known_args()
|
||||
return args
|
||||
|
||||
|
||||
def cal_hist(a, b, n):
|
||||
k = (a >= 0) & (a < n)
|
||||
return np.bincount(n * a[k].astype(np.int32) + b[k], minlength=n ** 2).reshape(n, n)
|
||||
|
||||
|
||||
def resize_long(img, long_size=513):
|
||||
h, w, _ = img.shape
|
||||
if h > w:
|
||||
new_h = long_size
|
||||
new_w = int(1.0 * long_size * w / h)
|
||||
else:
|
||||
new_w = long_size
|
||||
new_h = int(1.0 * long_size * h / w)
|
||||
imo = cv2.resize(img, (new_w, new_h))
|
||||
return imo
|
||||
|
||||
|
||||
class BuildEvalNetwork(nn.Cell):
|
||||
def __init__(self, network):
|
||||
super(BuildEvalNetwork, self).__init__()
|
||||
self.network = network
|
||||
self.softmax = nn.Softmax(axis=1)
|
||||
|
||||
def construct(self, input_data):
|
||||
output = self.network(input_data)
|
||||
output = self.softmax(output)
|
||||
return output
|
||||
|
||||
|
||||
def pre_process(args, img_, crop_size=512):
|
||||
# resize
|
||||
img_ = resize_long(img_, crop_size)
|
||||
resize_h, resize_w, _ = img_.shape
|
||||
|
||||
# mean, std
|
||||
image_mean = np.array(args.image_mean)
|
||||
image_std = np.array(args.image_std)
|
||||
img_ = (img_ - image_mean) / image_std
|
||||
|
||||
# pad to crop_size
|
||||
pad_h = crop_size - img_.shape[0]
|
||||
pad_w = crop_size - img_.shape[1]
|
||||
if pad_h > 0 or pad_w > 0:
|
||||
img_ = cv2.copyMakeBorder(img_, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT, value=0)
|
||||
|
||||
# hwc to chw
|
||||
img_ = img_.transpose((2, 0, 1))
|
||||
return img_, resize_h, resize_w
|
||||
|
||||
|
||||
def eval_batch(args, eval_net, img_lst, crop_size=512, flip=True):
|
||||
result_lst = []
|
||||
batch_size = len(img_lst)
|
||||
batch_img = np.zeros((args.batch_size, 3, crop_size, crop_size), dtype=np.float32)
|
||||
resize_hw = []
|
||||
for l in range(batch_size):
|
||||
img_ = img_lst[l]
|
||||
img_, resize_h, resize_w = pre_process(args, img_, crop_size)
|
||||
batch_img[l] = img_
|
||||
resize_hw.append([resize_h, resize_w])
|
||||
|
||||
batch_img = np.ascontiguousarray(batch_img)
|
||||
net_out = eval_net(Tensor(batch_img, mstype.float32))
|
||||
net_out = net_out.asnumpy()
|
||||
|
||||
if flip:
|
||||
batch_img = batch_img[:, :, :, ::-1]
|
||||
net_out_flip = eval_net(Tensor(batch_img, mstype.float32))
|
||||
net_out += net_out_flip.asnumpy()[:, :, :, ::-1]
|
||||
|
||||
for bs in range(batch_size):
|
||||
probs_ = net_out[bs][:, :resize_hw[bs][0], :resize_hw[bs][1]].transpose((1, 2, 0))
|
||||
ori_h, ori_w = img_lst[bs].shape[0], img_lst[bs].shape[1]
|
||||
probs_ = cv2.resize(probs_, (ori_w, ori_h))
|
||||
result_lst.append(probs_)
|
||||
|
||||
return result_lst
|
||||
|
||||
|
||||
def eval_batch_scales(args, eval_net, img_lst, scales,
|
||||
base_crop_size=512, flip=True):
|
||||
sizes_ = [int((base_crop_size - 1) * sc) + 1 for sc in scales]
|
||||
probs_lst = eval_batch(args, eval_net, img_lst, crop_size=sizes_[0], flip=flip)
|
||||
print(sizes_)
|
||||
for crop_size_ in sizes_[1:]:
|
||||
probs_lst_tmp = eval_batch(args, eval_net, img_lst, crop_size=crop_size_, flip=flip)
|
||||
for pl, _ in enumerate(probs_lst):
|
||||
probs_lst[pl] += probs_lst_tmp[pl]
|
||||
|
||||
result_msk = []
|
||||
for i in probs_lst:
|
||||
result_msk.append(i.argmax(axis=2))
|
||||
return result_msk
|
||||
|
||||
|
||||
def net_eval():
|
||||
args = parse_args()
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id,
|
||||
save_graphs=False)
|
||||
|
||||
# data list
|
||||
with open(args.data_lst) as f:
|
||||
img_lst = f.readlines()
|
||||
|
||||
net = FCN8s(n_class=args.num_classes)
|
||||
|
||||
# load model
|
||||
param_dict = load_checkpoint(args.ckpt_path)
|
||||
load_param_into_net(net, param_dict)
|
||||
|
||||
# evaluate
|
||||
hist = np.zeros((args.num_classes, args.num_classes))
|
||||
batch_img_lst = []
|
||||
batch_msk_lst = []
|
||||
bi = 0
|
||||
image_num = 0
|
||||
for i, line in enumerate(img_lst):
|
||||
|
||||
img_name = line.strip('\n')
|
||||
data_root = args.data_root
|
||||
img_path = data_root + '/JPEGImages/' + str(img_name) + '.jpg'
|
||||
msk_path = data_root + '/SegmentationClass/' + str(img_name) + '.png'
|
||||
|
||||
img_ = np.array(Image.open(img_path), dtype=np.uint8)
|
||||
msk_ = np.array(Image.open(msk_path), dtype=np.uint8)
|
||||
|
||||
batch_img_lst.append(img_)
|
||||
batch_msk_lst.append(msk_)
|
||||
bi += 1
|
||||
if bi == args.batch_size:
|
||||
batch_res = eval_batch_scales(args, net, batch_img_lst, scales=args.scales,
|
||||
base_crop_size=args.crop_size, flip=args.flip)
|
||||
for mi in range(args.batch_size):
|
||||
hist += cal_hist(batch_msk_lst[mi].flatten(), batch_res[mi].flatten(), args.num_classes)
|
||||
|
||||
bi = 0
|
||||
batch_img_lst = []
|
||||
batch_msk_lst = []
|
||||
print('processed {} images'.format(i+1))
|
||||
image_num = i
|
||||
|
||||
if bi > 0:
|
||||
batch_res = eval_batch_scales(args, net, batch_img_lst, scales=args.scales,
|
||||
base_crop_size=args.crop_size, flip=args.flip)
|
||||
for mi in range(bi):
|
||||
hist += cal_hist(batch_msk_lst[mi].flatten(), batch_res[mi].flatten(), args.num_classes)
|
||||
print('processed {} images'.format(image_num + 1))
|
||||
|
||||
print(hist)
|
||||
iu = np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist))
|
||||
print('per-class IoU', iu)
|
||||
print('mean IoU', np.nanmean(iu))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
net_eval()
|
@ -0,0 +1,22 @@
|
||||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
export DEVICE_ID=0
|
||||
python src/data/build_seg_data.py --data_root=/home/sun/data/Mindspore/benchmark_RELEASE/dataset \
|
||||
--data_lst=/home/sun/data/Mindspore/benchmark_RELEASE/dataset/trainaug.txt \
|
||||
--dst_path=dataset/MINDRECORED_NAME.mindrecord \
|
||||
--num_shards=1 \
|
||||
--shuffle=True
|
@ -0,0 +1,43 @@
|
||||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
|
||||
echo "=============================================================================================================="
|
||||
echo "Please run the scipt as: "
|
||||
echo "sh run_distribute_eval.sh DEVICE_NUM RANK_TABLE_FILE DATASET CKPT_PATH"
|
||||
echo "for example: sh run_eval.sh [RANK_TABLE_FILE] /path/to/dataset /path/to/ckpt device_id"
|
||||
echo "It is better to use absolute path."
|
||||
echo "================================================================================================================="
|
||||
|
||||
|
||||
export DATA_PATH=$1
|
||||
CKPT_PATH=$2
|
||||
DEVICE_ID=$3
|
||||
|
||||
rm -rf eval
|
||||
mkdir ./eval
|
||||
cp ./*.py ./eval
|
||||
cp -r ./src ./eval
|
||||
cd ./eval || exit
|
||||
echo "start testing"
|
||||
env > env.log
|
||||
python eval.py \
|
||||
--device_id=$DEVICE_ID \
|
||||
--data_path=$DATA_PATH \
|
||||
--ckpt_path=$CKPT_PATH #> log.txt 2>&1 &
|
||||
|
||||
cd ../
|
||||
|
@ -0,0 +1,52 @@
|
||||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [ $# != 2 ]
|
||||
then
|
||||
echo "Usage: sh run_train.sh [device_num][RANK_TABLE_FILE]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ ! -f $2 ]
|
||||
then
|
||||
echo "error: RANK_TABLE_FILE=$2 is not a file"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
||||
ulimit -u unlimited
|
||||
export DEVICE_NUM=$1
|
||||
export RANK_SIZE=$1
|
||||
RANK_TABLE_FILE=$(realpath $2)
|
||||
export RANK_TABLE_FILE
|
||||
echo "RANK_TABLE_FILE=${RANK_TABLE_FILE}"
|
||||
|
||||
export SERVER_ID=0
|
||||
rank_start=$((DEVICE_NUM * SERVER_ID))
|
||||
for((i=0; i<$1; i++))
|
||||
do
|
||||
export DEVICE_ID=$i
|
||||
export RANK_ID=$((rank_start + i))
|
||||
rm -rf ./train_parallel$i
|
||||
mkdir ./train_parallel$i
|
||||
cp -r ./src ./train_parallel$i
|
||||
cp ./train.py ./train_parallel$i
|
||||
echo "start training for rank $RANK_ID, device $DEVICE_ID"
|
||||
cd ./train_parallel$i ||exit
|
||||
env > env.log
|
||||
python train.py --device_id=$i > log 2>&1 &
|
||||
cd ..
|
||||
done
|
@ -0,0 +1,48 @@
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
network config setting, will be used in train.py
|
||||
"""
|
||||
|
||||
from easydict import EasyDict as edict
|
||||
|
||||
|
||||
FCN8s_VOC2012_cfg = edict({
|
||||
# dataset
|
||||
'data_file': '/data/workspace/mindspore_dataset/FCN/FCN/dataset/MINDRECORED_NAME.mindrecord',
|
||||
'batch_size': 32,
|
||||
'crop_size': 512,
|
||||
'image_mean': [103.53, 116.28, 123.675],
|
||||
'image_std': [57.375, 57.120, 58.395],
|
||||
'min_scale': 0.5,
|
||||
'max_scale': 2.0,
|
||||
'ignore_label': 255,
|
||||
'num_classes': 21,
|
||||
|
||||
# optimizer
|
||||
'train_epochs': 500,
|
||||
'base_lr': 0.015,
|
||||
'loss_scale': 1024.0,
|
||||
|
||||
# model
|
||||
'model': 'FCN8s',
|
||||
'ckpt_vgg16': '/data/workspace/mindspore_dataset/FCN/FCN/model/0-150_5004.ckpt',
|
||||
'ckpt_pre_trained': '/data/workspace/mindspore_dataset/FCN/FCN/model_new/FCN8s-500_82.ckpt',
|
||||
|
||||
# train
|
||||
'save_steps': 330,
|
||||
'keep_checkpoint_max': 500,
|
||||
'train_dir': '/data/workspace/mindspore_dataset/FCN/FCN/model_new/',
|
||||
})
|
@ -0,0 +1,78 @@
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import os
|
||||
import argparse
|
||||
import numpy as np
|
||||
from mindspore.mindrecord import FileWriter
|
||||
|
||||
|
||||
seg_schema = {"file_name": {"type": "string"}, "label": {"type": "bytes"}, "data": {"type": "bytes"}}
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser('mindrecord')
|
||||
|
||||
parser.add_argument('--data_root', type=str, default='', help='root path of data')
|
||||
parser.add_argument('--data_lst', type=str, default='', help='list of data')
|
||||
parser.add_argument('--dst_path', type=str, default='', help='save path of mindrecords')
|
||||
parser.add_argument('--num_shards', type=int, default=8, help='number of shards')
|
||||
parser.add_argument('--shuffle', type=bool, default=True, help='shuffle or not')
|
||||
|
||||
parser_args, _ = parser.parse_known_args()
|
||||
return parser_args
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
args = parse_args()
|
||||
|
||||
datas = []
|
||||
with open(args.data_lst) as f:
|
||||
lines = f.readlines()
|
||||
if args.shuffle:
|
||||
np.random.shuffle(lines)
|
||||
|
||||
dst_dir = '/'.join(args.dst_path.split('/')[:-1])
|
||||
if not os.path.exists(dst_dir):
|
||||
os.makedirs(dst_dir)
|
||||
|
||||
print('number of samples:', len(lines))
|
||||
writer = FileWriter(file_name=args.dst_path, shard_num=args.num_shards)
|
||||
writer.add_schema(seg_schema, "seg_schema")
|
||||
cnt = 0
|
||||
|
||||
for l in lines:
|
||||
img_name = l.strip('\n')
|
||||
|
||||
img_path = 'img/' + str(img_name) + '.jpg'
|
||||
label_path = 'cls_png/' + str(img_name) + '.png'
|
||||
|
||||
sample_ = {"file_name": img_path.split('/')[-1]}
|
||||
|
||||
with open(os.path.join(args.data_root, img_path), 'rb') as f:
|
||||
sample_['data'] = f.read()
|
||||
with open(os.path.join(args.data_root, label_path), 'rb') as f:
|
||||
sample_['label'] = f.read()
|
||||
datas.append(sample_)
|
||||
cnt += 1
|
||||
if cnt % 1000 == 0:
|
||||
writer.write_raw_data(datas)
|
||||
print('number of samples written:', cnt)
|
||||
datas = []
|
||||
|
||||
if datas:
|
||||
writer.write_raw_data(datas)
|
||||
writer.commit()
|
||||
print('number of samples written:', cnt)
|
@ -0,0 +1,94 @@
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import cv2
|
||||
import mindspore.dataset as de
|
||||
cv2.setNumThreads(0)
|
||||
|
||||
|
||||
class SegDataset:
|
||||
def __init__(self,
|
||||
image_mean,
|
||||
image_std,
|
||||
data_file='',
|
||||
batch_size=32,
|
||||
crop_size=512,
|
||||
max_scale=2.0,
|
||||
min_scale=0.5,
|
||||
ignore_label=255,
|
||||
num_classes=21,
|
||||
num_readers=2,
|
||||
num_parallel_calls=4,
|
||||
shard_id=None,
|
||||
shard_num=None):
|
||||
|
||||
self.data_file = data_file
|
||||
self.batch_size = batch_size
|
||||
self.crop_size = crop_size
|
||||
self.image_mean = np.array(image_mean, dtype=np.float32)
|
||||
self.image_std = np.array(image_std, dtype=np.float32)
|
||||
self.max_scale = max_scale
|
||||
self.min_scale = min_scale
|
||||
self.ignore_label = ignore_label
|
||||
self.num_classes = num_classes
|
||||
self.num_readers = num_readers
|
||||
self.num_parallel_calls = num_parallel_calls
|
||||
self.shard_id = shard_id
|
||||
self.shard_num = shard_num
|
||||
assert max_scale > min_scale
|
||||
|
||||
def preprocess_(self, image, label):
|
||||
# bgr image
|
||||
image_out = cv2.imdecode(np.frombuffer(image, dtype=np.uint8), cv2.IMREAD_COLOR)
|
||||
label_out = cv2.imdecode(np.frombuffer(label, dtype=np.uint8), cv2.IMREAD_GRAYSCALE)
|
||||
|
||||
sc = np.random.uniform(self.min_scale, self.max_scale)
|
||||
new_h, new_w = int(sc * image_out.shape[0]), int(sc * image_out.shape[1])
|
||||
image_out = cv2.resize(image_out, (new_w, new_h), interpolation=cv2.INTER_CUBIC)
|
||||
label_out = cv2.resize(label_out, (new_w, new_h), interpolation=cv2.INTER_NEAREST)
|
||||
|
||||
image_out = (image_out - self.image_mean) / self.image_std
|
||||
h_, w_ = max(new_h, self.crop_size), max(new_w, self.crop_size)
|
||||
pad_h, pad_w = h_ - new_h, w_ - new_w
|
||||
if pad_h > 0 or pad_w > 0:
|
||||
image_out = cv2.copyMakeBorder(image_out, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT, value=0)
|
||||
label_out = cv2.copyMakeBorder(label_out, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT, value=self.ignore_label)
|
||||
offset_h = np.random.randint(0, h_ - self.crop_size + 1)
|
||||
offset_w = np.random.randint(0, w_ - self.crop_size + 1)
|
||||
image_out = image_out[offset_h: offset_h + self.crop_size, offset_w: offset_w + self.crop_size, :]
|
||||
label_out = label_out[offset_h: offset_h + self.crop_size, offset_w: offset_w+self.crop_size]
|
||||
|
||||
if np.random.uniform(0.0, 1.0) > 0.5:
|
||||
image_out = image_out[:, ::-1, :]
|
||||
label_out = label_out[:, ::-1]
|
||||
|
||||
image_out = image_out.transpose((2, 0, 1))
|
||||
image_out = image_out.copy()
|
||||
label_out = label_out.copy()
|
||||
return image_out, label_out
|
||||
|
||||
def get_dataset(self, repeat=1):
|
||||
data_set = de.MindDataset(dataset_file=self.data_file, columns_list=["data", "label"],
|
||||
shuffle=True, num_parallel_workers=self.num_readers,
|
||||
num_shards=self.shard_num, shard_id=self.shard_id)
|
||||
transforms_list = self.preprocess_
|
||||
data_set = data_set.map(operations=transforms_list, input_columns=["data", "label"],
|
||||
output_columns=["data", "label"],
|
||||
num_parallel_workers=self.num_parallel_calls)
|
||||
data_set = data_set.shuffle(buffer_size=self.batch_size * 10)
|
||||
data_set = data_set.batch(self.batch_size, drop_remainder=True)
|
||||
data_set = data_set.repeat(repeat)
|
||||
return data_set
|
@ -0,0 +1,51 @@
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
from mindspore import Tensor
|
||||
import mindspore.common.dtype as mstype
|
||||
import mindspore.nn as nn
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
|
||||
class SoftmaxCrossEntropyLoss(nn.Cell):
|
||||
def __init__(self, num_cls=21, ignore_label=255):
|
||||
super(SoftmaxCrossEntropyLoss, self).__init__()
|
||||
self.one_hot = P.OneHot(axis=-1)
|
||||
self.on_value = Tensor(1.0, mstype.float32)
|
||||
self.off_value = Tensor(0.0, mstype.float32)
|
||||
self.cast = P.Cast()
|
||||
self.ce = nn.SoftmaxCrossEntropyWithLogits()
|
||||
self.not_equal = P.NotEqual()
|
||||
self.num_cls = num_cls
|
||||
self.ignore_label = ignore_label
|
||||
self.mul = P.Mul()
|
||||
self.sum = P.ReduceSum(False)
|
||||
self.div = P.RealDiv()
|
||||
self.transpose = P.Transpose()
|
||||
self.reshape = P.Reshape()
|
||||
|
||||
def construct(self, logits, labels):
|
||||
labels_int = self.cast(labels, mstype.int32)
|
||||
labels_int = self.reshape(labels_int, (-1,))
|
||||
logits_ = self.transpose(logits, (0, 2, 3, 1))
|
||||
logits_ = self.reshape(logits_, (-1, self.num_cls))
|
||||
weights = self.not_equal(labels_int, self.ignore_label)
|
||||
weights = self.cast(weights, mstype.float32)
|
||||
one_hot_labels = self.one_hot(labels_int, self.num_cls, self.on_value, self.off_value)
|
||||
logits_ = self.cast(logits_, mstype.float32)
|
||||
loss = self.ce(logits_, one_hot_labels)
|
||||
loss = self.mul(weights, loss)
|
||||
loss = self.div(self.sum(loss), self.sum(weights))
|
||||
return loss
|
@ -0,0 +1,206 @@
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import mindspore.nn as nn
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
|
||||
class FCN8s(nn.Cell):
|
||||
def __init__(self, n_class):
|
||||
super().__init__()
|
||||
self.n_class = n_class
|
||||
self.conv1 = nn.SequentialCell(
|
||||
nn.Conv2d(in_channels=3,
|
||||
out_channels=64,
|
||||
kernel_size=3,
|
||||
weight_init='xavier_uniform'),
|
||||
nn.BatchNorm2d(64),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(in_channels=64,
|
||||
out_channels=64,
|
||||
kernel_size=3,
|
||||
weight_init='xavier_uniform'),
|
||||
nn.BatchNorm2d(64),
|
||||
nn.ReLU()
|
||||
)
|
||||
|
||||
self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
|
||||
|
||||
self.conv2 = nn.SequentialCell(
|
||||
nn.Conv2d(in_channels=64,
|
||||
out_channels=128,
|
||||
kernel_size=3,
|
||||
weight_init='xavier_uniform'),
|
||||
nn.BatchNorm2d(128),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(in_channels=128,
|
||||
out_channels=128,
|
||||
kernel_size=3,
|
||||
weight_init='xavier_uniform'),
|
||||
nn.BatchNorm2d(128),
|
||||
nn.ReLU()
|
||||
)
|
||||
|
||||
self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
|
||||
|
||||
self.conv3 = nn.SequentialCell(
|
||||
nn.Conv2d(in_channels=128,
|
||||
out_channels=256,
|
||||
kernel_size=3,
|
||||
weight_init='xavier_uniform'),
|
||||
nn.BatchNorm2d(256),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(in_channels=256,
|
||||
out_channels=256,
|
||||
kernel_size=3,
|
||||
weight_init='xavier_uniform'),
|
||||
nn.BatchNorm2d(256),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(in_channels=256,
|
||||
out_channels=256,
|
||||
kernel_size=3,
|
||||
weight_init='xavier_uniform'),
|
||||
nn.BatchNorm2d(256),
|
||||
nn.ReLU()
|
||||
)
|
||||
|
||||
self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
|
||||
|
||||
self.conv4 = nn.SequentialCell(
|
||||
nn.Conv2d(in_channels=256,
|
||||
out_channels=512,
|
||||
kernel_size=3,
|
||||
weight_init='xavier_uniform'),
|
||||
nn.BatchNorm2d(512),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(in_channels=512,
|
||||
out_channels=512,
|
||||
kernel_size=3,
|
||||
weight_init='xavier_uniform'),
|
||||
nn.BatchNorm2d(512),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(in_channels=512,
|
||||
out_channels=512,
|
||||
kernel_size=3,
|
||||
weight_init='xavier_uniform'),
|
||||
nn.BatchNorm2d(512),
|
||||
nn.ReLU()
|
||||
)
|
||||
|
||||
self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)
|
||||
|
||||
self.conv5 = nn.SequentialCell(
|
||||
nn.Conv2d(in_channels=512,
|
||||
out_channels=512,
|
||||
kernel_size=3,
|
||||
weight_init='xavier_uniform'),
|
||||
nn.BatchNorm2d(512),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(in_channels=512,
|
||||
out_channels=512,
|
||||
kernel_size=3,
|
||||
weight_init='xavier_uniform'),
|
||||
nn.BatchNorm2d(512),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(in_channels=512,
|
||||
out_channels=512,
|
||||
kernel_size=3,
|
||||
weight_init='xavier_uniform'),
|
||||
nn.BatchNorm2d(512),
|
||||
nn.ReLU()
|
||||
)
|
||||
|
||||
self.pool5 = nn.MaxPool2d(kernel_size=2, stride=2)
|
||||
|
||||
self.conv6 = nn.SequentialCell(
|
||||
nn.Conv2d(in_channels=512,
|
||||
out_channels=4096,
|
||||
kernel_size=7,
|
||||
weight_init='xavier_uniform'),
|
||||
nn.BatchNorm2d(4096),
|
||||
nn.ReLU(),
|
||||
)
|
||||
|
||||
self.conv7 = nn.SequentialCell(
|
||||
nn.Conv2d(in_channels=4096,
|
||||
out_channels=4096,
|
||||
kernel_size=1,
|
||||
weight_init='xavier_uniform'),
|
||||
nn.BatchNorm2d(4096),
|
||||
nn.ReLU(),
|
||||
)
|
||||
|
||||
self.score_fr = nn.Conv2d(in_channels=4096,
|
||||
out_channels=self.n_class,
|
||||
kernel_size=1,
|
||||
weight_init='xavier_uniform')
|
||||
|
||||
self.upscore2 = nn.Conv2dTranspose(in_channels=self.n_class,
|
||||
out_channels=self.n_class,
|
||||
kernel_size=4,
|
||||
stride=2,
|
||||
weight_init='xavier_uniform')
|
||||
|
||||
self.score_pool4 = nn.Conv2d(in_channels=512,
|
||||
out_channels=self.n_class,
|
||||
kernel_size=1,
|
||||
weight_init='xavier_uniform')
|
||||
|
||||
self.upscore_pool4 = nn.Conv2dTranspose(in_channels=self.n_class,
|
||||
out_channels=self.n_class,
|
||||
kernel_size=4,
|
||||
stride=2,
|
||||
weight_init='xavier_uniform')
|
||||
|
||||
self.score_pool3 = nn.Conv2d(in_channels=256,
|
||||
out_channels=self.n_class,
|
||||
kernel_size=1,
|
||||
weight_init='xavier_uniform')
|
||||
|
||||
self.upscore8 = nn.Conv2dTranspose(in_channels=self.n_class,
|
||||
out_channels=self.n_class,
|
||||
kernel_size=16,
|
||||
stride=8,
|
||||
weight_init='xavier_uniform')
|
||||
self.shape = P.Shape()
|
||||
self.cast = P.Cast()
|
||||
|
||||
def construct(self, x):
|
||||
x1 = self.conv1(x)
|
||||
p1 = self.pool1(x1)
|
||||
x2 = self.conv2(p1)
|
||||
p2 = self.pool2(x2)
|
||||
x3 = self.conv3(p2)
|
||||
p3 = self.pool3(x3)
|
||||
x4 = self.conv4(p3)
|
||||
p4 = self.pool4(x4)
|
||||
x5 = self.conv5(p4)
|
||||
p5 = self.pool5(x5)
|
||||
|
||||
x6 = self.conv6(p5)
|
||||
x7 = self.conv7(x6)
|
||||
|
||||
sf = self.score_fr(x7)
|
||||
u2 = self.upscore2(sf)
|
||||
|
||||
s4 = self.score_pool4(p4)
|
||||
f4 = s4 + u2
|
||||
u4 = self.upscore_pool4(f4)
|
||||
|
||||
s3 = self.score_pool3(p3)
|
||||
f3 = s3 + u4
|
||||
out = self.upscore8(f3)
|
||||
|
||||
return out
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,137 @@
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""train FCN8s."""
|
||||
|
||||
import os
|
||||
import argparse
|
||||
from mindspore import context, Tensor
|
||||
from mindspore.train.model import Model
|
||||
from mindspore.context import ParallelMode
|
||||
import mindspore.nn as nn
|
||||
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from mindspore.communication.management import init, get_rank, get_group_size
|
||||
from mindspore.train.callback import LossMonitor, TimeMonitor
|
||||
from mindspore.train.loss_scale_manager import FixedLossScaleManager
|
||||
from mindspore.common import set_seed
|
||||
from src.data import dataset as data_generator
|
||||
from src.loss import loss
|
||||
from src.utils.lr_scheduler import CosineAnnealingLR
|
||||
from src.nets.FCN8s import FCN8s
|
||||
from src.config import FCN8s_VOC2012_cfg
|
||||
|
||||
set_seed(1)
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser('mindspore FCN training')
|
||||
parser.add_argument('--device_id', type=int, default=0, help='device id of GPU or Ascend. (Default: None)')
|
||||
args, _ = parser.parse_known_args()
|
||||
return args
|
||||
|
||||
|
||||
def train():
|
||||
args = parse_args()
|
||||
cfg = FCN8s_VOC2012_cfg
|
||||
device_num = int(os.environ.get("DEVICE_NUM", 1))
|
||||
# init multicards training
|
||||
if device_num > 1:
|
||||
parallel_mode = ParallelMode.DATA_PARALLEL
|
||||
context.set_auto_parallel_context(parallel_mode=parallel_mode, gradients_mean=True, device_num=device_num)
|
||||
init()
|
||||
args.rank = get_rank()
|
||||
args.group_size = get_group_size()
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True, save_graphs=False,
|
||||
device_target="Ascend", device_id=args.device_id)
|
||||
|
||||
# dataset
|
||||
dataset = data_generator.SegDataset(image_mean=cfg.image_mean,
|
||||
image_std=cfg.image_std,
|
||||
data_file=cfg.data_file,
|
||||
batch_size=cfg.batch_size,
|
||||
crop_size=cfg.crop_size,
|
||||
max_scale=cfg.max_scale,
|
||||
min_scale=cfg.min_scale,
|
||||
ignore_label=cfg.ignore_label,
|
||||
num_classes=cfg.num_classes,
|
||||
num_readers=2,
|
||||
num_parallel_calls=4,
|
||||
shard_id=args.rank,
|
||||
shard_num=args.group_size)
|
||||
dataset = dataset.get_dataset(repeat=1)
|
||||
|
||||
net = FCN8s(n_class=cfg.num_classes)
|
||||
loss_ = loss.SoftmaxCrossEntropyLoss(cfg.num_classes, cfg.ignore_label)
|
||||
|
||||
# load pretrained vgg16 parameters to init FCN8s
|
||||
if cfg.ckpt_vgg16:
|
||||
param_vgg = load_checkpoint(cfg.ckpt_vgg16)
|
||||
param_dict = {}
|
||||
for layer_id in range(1, 6):
|
||||
sub_layer_num = 2 if layer_id < 3 else 3
|
||||
for sub_layer_id in range(sub_layer_num):
|
||||
# conv param
|
||||
y_weight = 'conv{}.{}.weight'.format(layer_id, 3 * sub_layer_id)
|
||||
x_weight = 'vgg16_feature_extractor.conv{}_{}.0.weight'.format(layer_id, sub_layer_id + 1)
|
||||
param_dict[y_weight] = param_vgg[x_weight]
|
||||
# BatchNorm param
|
||||
y_gamma = 'conv{}.{}.gamma'.format(layer_id, 3 * sub_layer_id + 1)
|
||||
y_beta = 'conv{}.{}.beta'.format(layer_id, 3 * sub_layer_id + 1)
|
||||
x_gamma = 'vgg16_feature_extractor.conv{}_{}.1.gamma'.format(layer_id, sub_layer_id + 1)
|
||||
x_beta = 'vgg16_feature_extractor.conv{}_{}.1.beta'.format(layer_id, sub_layer_id + 1)
|
||||
param_dict[y_gamma] = param_vgg[x_gamma]
|
||||
param_dict[y_beta] = param_vgg[x_beta]
|
||||
load_param_into_net(net, param_dict)
|
||||
# load pretrained FCN8s
|
||||
elif cfg.ckpt_pre_trained:
|
||||
param_dict = load_checkpoint(cfg.ckpt_pre_trained)
|
||||
load_param_into_net(net, param_dict)
|
||||
|
||||
# optimizer
|
||||
iters_per_epoch = dataset.get_dataset_size()
|
||||
|
||||
lr_scheduler = CosineAnnealingLR(cfg.base_lr,
|
||||
cfg.train_epochs,
|
||||
iters_per_epoch,
|
||||
cfg.train_epochs,
|
||||
warmup_epochs=0,
|
||||
eta_min=0)
|
||||
lr = Tensor(lr_scheduler.get_lr())
|
||||
|
||||
# loss scale
|
||||
manager_loss_scale = FixedLossScaleManager(cfg.loss_scale, drop_overflow_update=False)
|
||||
|
||||
optimizer = nn.Momentum(params=net.trainable_params(), learning_rate=lr, momentum=0.9, weight_decay=0.0001,
|
||||
loss_scale=cfg.loss_scale)
|
||||
|
||||
model = Model(net, loss_fn=loss_, loss_scale_manager=manager_loss_scale, optimizer=optimizer, amp_level="O3")
|
||||
|
||||
# callback for saving ckpts
|
||||
time_cb = TimeMonitor(data_size=iters_per_epoch)
|
||||
loss_cb = LossMonitor()
|
||||
cbs = [time_cb, loss_cb]
|
||||
|
||||
if args.rank == 0:
|
||||
config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_steps,
|
||||
keep_checkpoint_max=cfg.keep_checkpoint_max)
|
||||
ckpoint_cb = ModelCheckpoint(prefix=cfg.model, directory=cfg.train_dir, config=config_ck)
|
||||
cbs.append(ckpoint_cb)
|
||||
|
||||
model.train(cfg.train_epochs, dataset, callbacks=cbs)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
train()
|
Loading…
Reference in new issue