Add deeplabv3 to modelzoo.

pull/5771/head
jzg 5 years ago
parent 9090812d53
commit 2169cf32a1

File diff suppressed because it is too large.

@@ -12,40 +12,202 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""eval deeplabv3."""
import os
import argparse
import numpy as np
import cv2
from mindspore import Tensor
import mindspore.common.dtype as mstype
import mindspore.nn as nn
from mindspore import context
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.nets import net_factory

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False,
                    device_id=int(os.getenv('DEVICE_ID')))


def parse_args():
    parser = argparse.ArgumentParser('mindspore deeplabv3 eval')

    # val data
    parser.add_argument('--data_root', type=str, default='', help='root path of val data')
    parser.add_argument('--data_lst', type=str, default='', help='list of val data')
    parser.add_argument('--batch_size', type=int, default=16, help='batch size')
    parser.add_argument('--crop_size', type=int, default=513, help='crop size')
    parser.add_argument('--image_mean', type=list, default=[103.53, 116.28, 123.675], help='image mean')
    parser.add_argument('--image_std', type=list, default=[57.375, 57.120, 58.395], help='image std')
    parser.add_argument('--scales', type=float, action='append', help='scales of evaluation')
    parser.add_argument('--flip', action='store_true', help='perform left-right flip')
    parser.add_argument('--ignore_label', type=int, default=255, help='ignore label')
    parser.add_argument('--num_classes', type=int, default=21, help='number of classes')

    # model
    parser.add_argument('--model', type=str, default='deeplab_v3_s16', help='select model')
    parser.add_argument('--freeze_bn', action='store_true', default=False, help='freeze bn')
    parser.add_argument('--ckpt_path', type=str, default='', help='model to evaluate')

    args, _ = parser.parse_known_args()
    return args


def cal_hist(a, b, n):
    k = (a >= 0) & (a < n)
    return np.bincount(n * a[k].astype(np.int32) + b[k], minlength=n ** 2).reshape(n, n)


def resize_long(img, long_size=513):
    h, w, _ = img.shape
    if h > w:
        new_h = long_size
        new_w = int(1.0 * long_size * w / h)
    else:
        new_w = long_size
        new_h = int(1.0 * long_size * h / w)
    imo = cv2.resize(img, (new_w, new_h))
    return imo


class BuildEvalNetwork(nn.Cell):
    def __init__(self, network):
        super(BuildEvalNetwork, self).__init__()
        self.network = network
        self.softmax = nn.Softmax(axis=1)

    def construct(self, input_data):
        output = self.network(input_data)
        output = self.softmax(output)
        return output


def pre_process(args, img_, crop_size=513):
    # resize
    img_ = resize_long(img_, crop_size)
    resize_h, resize_w, _ = img_.shape

    # mean, std
    image_mean = np.array(args.image_mean)
    image_std = np.array(args.image_std)
    img_ = (img_ - image_mean) / image_std

    # pad to crop_size
    pad_h = crop_size - img_.shape[0]
    pad_w = crop_size - img_.shape[1]
    if pad_h > 0 or pad_w > 0:
        img_ = cv2.copyMakeBorder(img_, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT, value=0)

    # hwc to chw
    img_ = img_.transpose((2, 0, 1))
    return img_, resize_h, resize_w


def eval_batch(args, eval_net, img_lst, crop_size=513, flip=True):
    result_lst = []
    batch_size = len(img_lst)
    batch_img = np.zeros((args.batch_size, 3, crop_size, crop_size), dtype=np.float32)
    resize_hw = []
    for l in range(batch_size):
        img_ = img_lst[l]
        img_, resize_h, resize_w = pre_process(args, img_, crop_size)
        batch_img[l] = img_
        resize_hw.append([resize_h, resize_w])

    batch_img = np.ascontiguousarray(batch_img)
    net_out = eval_net(Tensor(batch_img, mstype.float32))
    net_out = net_out.asnumpy()

    if flip:
        batch_img = batch_img[:, :, :, ::-1]
        net_out_flip = eval_net(Tensor(batch_img, mstype.float32))
        net_out += net_out_flip.asnumpy()[:, :, :, ::-1]

    for bs in range(batch_size):
        probs_ = net_out[bs][:, :resize_hw[bs][0], :resize_hw[bs][1]].transpose((1, 2, 0))
        ori_h, ori_w = img_lst[bs].shape[0], img_lst[bs].shape[1]
        probs_ = cv2.resize(probs_, (ori_w, ori_h))
        result_lst.append(probs_)

    return result_lst


def eval_batch_scales(args, eval_net, img_lst, scales,
                      base_crop_size=513, flip=True):
    sizes_ = [int((base_crop_size - 1) * sc) + 1 for sc in scales]
    probs_lst = eval_batch(args, eval_net, img_lst, crop_size=sizes_[0], flip=flip)
    print(sizes_)
    for crop_size_ in sizes_[1:]:
        probs_lst_tmp = eval_batch(args, eval_net, img_lst, crop_size=crop_size_, flip=flip)
        for pl, _ in enumerate(probs_lst):
            probs_lst[pl] += probs_lst_tmp[pl]

    result_msk = []
    for i in probs_lst:
        result_msk.append(i.argmax(axis=2))
    return result_msk


def net_eval():
    args = parse_args()

    # data list
    with open(args.data_lst) as f:
        img_lst = f.readlines()

    # network
    if args.model == 'deeplab_v3_s16':
        network = net_factory.nets_map[args.model]('eval', args.num_classes, 16, args.freeze_bn)
    elif args.model == 'deeplab_v3_s8':
        network = net_factory.nets_map[args.model]('eval', args.num_classes, 8, args.freeze_bn)
    else:
        raise NotImplementedError('model [{:s}] not recognized'.format(args.model))
    eval_net = BuildEvalNetwork(network)

    # load model
    param_dict = load_checkpoint(args.ckpt_path)
    load_param_into_net(eval_net, param_dict)
    eval_net.set_train(False)

    # evaluate
    hist = np.zeros((args.num_classes, args.num_classes))
    batch_img_lst = []
    batch_msk_lst = []
    bi = 0
    image_num = 0
    for i, line in enumerate(img_lst):
        img_path, msk_path = line.strip().split(' ')
        img_path = os.path.join(args.data_root, img_path)
        msk_path = os.path.join(args.data_root, msk_path)
        img_ = cv2.imread(img_path)
        msk_ = cv2.imread(msk_path, cv2.IMREAD_GRAYSCALE)
        batch_img_lst.append(img_)
        batch_msk_lst.append(msk_)
        bi += 1
        if bi == args.batch_size:
            batch_res = eval_batch_scales(args, eval_net, batch_img_lst, scales=args.scales,
                                          base_crop_size=args.crop_size, flip=args.flip)
            for mi in range(args.batch_size):
                hist += cal_hist(batch_msk_lst[mi].flatten(), batch_res[mi].flatten(), args.num_classes)
            bi = 0
            batch_img_lst = []
            batch_msk_lst = []
            print('processed {} images'.format(i + 1))
        image_num = i

    if bi > 0:
        batch_res = eval_batch_scales(args, eval_net, batch_img_lst, scales=args.scales,
                                      base_crop_size=args.crop_size, flip=args.flip)
        for mi in range(bi):
            hist += cal_hist(batch_msk_lst[mi].flatten(), batch_res[mi].flatten(), args.num_classes)
        print('processed {} images'.format(image_num + 1))

    print(hist)
    iu = np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist))
    print('per-class IoU', iu)
    print('mean IoU', np.nanmean(iu))


if __name__ == '__main__':
    net_eval()
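cal_hist above accumulates a num_classes x num_classes confusion matrix (rows index ground truth, columns index prediction), and net_eval reads per-class IoU off its diagonal. A minimal NumPy sketch of the same computation, with toy labels invented purely for illustration:

import numpy as np

def cal_hist(a, b, n):
    # a: ground-truth labels, b: predicted labels; entries outside [0, n) are masked out
    k = (a >= 0) & (a < n)
    return np.bincount(n * a[k].astype(np.int32) + b[k], minlength=n ** 2).reshape(n, n)

gt = np.array([0, 0, 1, 1, 2, 255])   # 255 plays the role of ignore_label here
pred = np.array([0, 1, 1, 1, 2, 2])
hist = cal_hist(gt, pred, 3)
iu = np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist))
print('per-class IoU', iu)            # [0.5, 0.667, 1.0]
print('mean IoU', np.nanmean(iu))     # ~0.722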

@@ -0,0 +1,4 @@
mindspore
numpy
Pillow
opencv-python

@@ -1,3 +1,4 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -12,25 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

export DEVICE_ID=7
python /PATH/TO/MODEL_ZOO_CODE/data/build_seg_data.py --data_root=/PATH/TO/DATA_ROOT \
                                                      --data_lst=/PATH/TO/DATA_lst.txt \
                                                      --dst_path=/PATH/TO/MINDRECORD_NAME.mindrecord \
                                                      --num_shards=8 \
                                                      --shuffle=True

@@ -1,68 +0,0 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash run_distribute_train.sh RANK_TABLE_FILE DATA_PATH"
echo "for example: bash run_distribute_train.sh RANK_TABLE_FILE DATA_PATH [PRETRAINED_CKPT_PATH](optional)"
echo "It is better to use absolute path."
echo "=============================================================================================================="

DATA_DIR=$2
export RANK_TABLE_FILE=$1
export RANK_SIZE=8
export DEVICE_NUM=8

PATH_CHECKPOINT=""
if [ $# == 3 ]
then
    PATH_CHECKPOINT=$3
fi

cores=`cat /proc/cpuinfo | grep "processor" | wc -l`
echo "the number of logical core" $cores
avg_core_per_rank=`expr $cores \/ $RANK_SIZE`
core_gap=`expr $avg_core_per_rank \- 1`
echo "avg_core_per_rank" $avg_core_per_rank
echo "core_gap" $core_gap

export SERVER_ID=0
rank_start=$((DEVICE_NUM * SERVER_ID))
for((i=0;i<DEVICE_NUM;i++))
do
    start=`expr $i \* $avg_core_per_rank`
    export DEVICE_ID=$i
    export RANK_ID=$((rank_start + i))
    export DEPLOY_MODE=0
    export GE_USE_STATIC_MEMORY=1
    end=`expr $start \+ $core_gap`
    cmdopt=$start"-"$end

    rm -rf train_parallel$i
    mkdir ./train_parallel$i
    cp *.py ./train_parallel$i
    cd ./train_parallel$i || exit
    echo "start training for rank $i, device $DEVICE_ID"

    mkdir -p ms_log
    CUR_DIR=`pwd`
    export GLOG_log_dir=${CUR_DIR}/ms_log
    export GLOG_logtostderr=0
    env > env.log
    taskset -c $cmdopt python ../train.py \
        --distribute="true" \
        --device_id=$DEVICE_ID \
        --checkpoint_url=$PATH_CHECKPOINT \
        --data_url=$DATA_DIR > log.txt 2>&1 &
    cd ../
done

@@ -0,0 +1,54 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
ulimit -c unlimited
train_path=/PATH/TO/EXPERIMENTS_DIR
export SLOG_PRINT_TO_STDOUT=0
train_code_path=/PATH/TO/MODEL_ZOO_CODE
export RANK_TABLE_FILE=${train_code_path}/src/tools/rank_table_8p.json
export RANK_SIZE=8
export RANK_START_ID=0
if [ -d ${train_path} ]; then
    rm -rf ${train_path}
fi
mkdir -p ${train_path}
mkdir ${train_path}/ckpt

for((i=0;i<=$RANK_SIZE-1;i++));
do
    export RANK_ID=${i}
    export DEVICE_ID=$((i + RANK_START_ID))
    echo 'start rank='${i}', device id='${DEVICE_ID}'...'
    mkdir ${train_path}/device${DEVICE_ID}
    cd ${train_path}/device${DEVICE_ID} || exit
    python ${train_code_path}/train.py --train_dir=${train_path}/ckpt \
                                       --data_file=/PATH/TO/MINDRECORD_NAME \
                                       --train_epochs=300 \
                                       --batch_size=32 \
                                       --crop_size=513 \
                                       --base_lr=0.08 \
                                       --lr_type=cos \
                                       --min_scale=0.5 \
                                       --max_scale=2.0 \
                                       --ignore_label=255 \
                                       --num_classes=21 \
                                       --model=deeplab_v3_s16 \
                                       --ckpt_pre_trained=/PATH/TO/PRETRAIN_MODEL \
                                       --is_distributed \
                                       --save_steps=410 \
                                       --keep_checkpoint_max=200 >log 2>&1 &
done

@@ -0,0 +1,55 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
ulimit -c unlimited
train_path=/PATH/TO/EXPERIMENTS_DIR
export SLOG_PRINT_TO_STDOUT=0
train_code_path=/PATH/TO/MODEL_ZOO_CODE
export RANK_TABLE_FILE=${train_code_path}/src/tools/rank_table_8p.json
export RANK_SIZE=8
export RANK_START_ID=0
if [ -d ${train_path} ]; then
    rm -rf ${train_path}
fi
mkdir -p ${train_path}
mkdir ${train_path}/ckpt

for((i=0;i<=$RANK_SIZE-1;i++));
do
    export RANK_ID=${i}
    export DEVICE_ID=$((i + RANK_START_ID))
    echo 'start rank='${i}', device id='${DEVICE_ID}'...'
    mkdir ${train_path}/device${DEVICE_ID}
    cd ${train_path}/device${DEVICE_ID} || exit
    python ${train_code_path}/train.py --train_dir=${train_path}/ckpt \
                                       --data_file=/PATH/TO/MINDRECORD_NAME \
                                       --train_epochs=800 \
                                       --batch_size=16 \
                                       --crop_size=513 \
                                       --base_lr=0.02 \
                                       --lr_type=cos \
                                       --min_scale=0.5 \
                                       --max_scale=2.0 \
                                       --ignore_label=255 \
                                       --num_classes=21 \
                                       --model=deeplab_v3_s8 \
                                       --loss_scale=2048 \
                                       --ckpt_pre_trained=/PATH/TO/PRETRAIN_MODEL \
                                       --is_distributed \
                                       --save_steps=820 \
                                       --keep_checkpoint_max=200 >log 2>&1 &
done

@@ -0,0 +1,55 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
ulimit -c unlimited
train_path=/PATH/TO/EXPERIMENTS_DIR
export SLOG_PRINT_TO_STDOUT=0
train_code_path=/PATH/TO/MODEL_ZOO_CODE
export RANK_TABLE_FILE=${train_code_path}/src/tools/rank_table_8p.json
export RANK_SIZE=8
export RANK_START_ID=0
if [ -d ${train_path} ]; then
    rm -rf ${train_path}
fi
mkdir -p ${train_path}
mkdir ${train_path}/ckpt

for((i=0;i<=$RANK_SIZE-1;i++));
do
    export RANK_ID=${i}
    export DEVICE_ID=$((i + RANK_START_ID))
    echo 'start rank='${i}', device id='${DEVICE_ID}'...'
    mkdir ${train_path}/device${DEVICE_ID}
    cd ${train_path}/device${DEVICE_ID} || exit
    python ${train_code_path}/train.py --train_dir=${train_path}/ckpt \
                                       --data_file=/PATH/TO/MINDRECORD_NAME \
                                       --train_epochs=300 \
                                       --batch_size=16 \
                                       --crop_size=513 \
                                       --base_lr=0.008 \
                                       --lr_type=cos \
                                       --min_scale=0.5 \
                                       --max_scale=2.0 \
                                       --ignore_label=255 \
                                       --num_classes=21 \
                                       --model=deeplab_v3_s8 \
                                       --loss_scale=2048 \
                                       --ckpt_pre_trained=/PATH/TO/PRETRAIN_MODEL \
                                       --is_distributed \
                                       --save_steps=110 \
                                       --keep_checkpoint_max=200 >log 2>&1 &
done

@@ -1,34 +0,0 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash run_eval.sh DEVICE_ID DATA_PATH PRETRAINED_CKPT_PATH"
echo "for example: bash run_eval.sh DEVICE_ID DATA_PATH PRETRAINED_CKPT_PATH"
echo "=============================================================================================================="

DEVICE_ID=$1
DATA_DIR=$2
PATH_CHECKPOINT=$3

mkdir -p ms_log
CUR_DIR=`pwd`
export GLOG_log_dir=${CUR_DIR}/ms_log
export GLOG_logtostderr=0
python eval.py \
    --device_id=$DEVICE_ID \
    --checkpoint_url=$PATH_CHECKPOINT \
    --data_url=$DATA_DIR > eval.log 2>&1 &

@@ -0,0 +1,37 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
export DEVICE_ID=3
export SLOG_PRINT_TO_STDOUT=0
train_code_path=/PATH/TO/MODEL_ZOO_CODE
eval_path=/PATH/TO/EVAL
if [ -d ${eval_path} ]; then
    rm -rf ${eval_path}
fi
mkdir -p ${eval_path}

python ${train_code_path}/eval.py --data_root=/PATH/TO/DATA \
                                  --data_lst=/PATH/TO/DATA_lst.txt \
                                  --batch_size=32 \
                                  --crop_size=513 \
                                  --ignore_label=255 \
                                  --num_classes=21 \
                                  --model=deeplab_v3_s16 \
                                  --scales=1.0 \
                                  --freeze_bn \
                                  --ckpt_path=/PATH/TO/PRETRAIN_MODEL >${eval_path}/eval_log 2>&1 &

@@ -0,0 +1,37 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
export DEVICE_ID=3
export SLOG_PRINT_TO_STDOUT=0
train_code_path=/PATH/TO/MODEL_ZOO_CODE
eval_path=/PATH/TO/EVAL
if [ -d ${eval_path} ]; then
    rm -rf ${eval_path}
fi
mkdir -p ${eval_path}

python ${train_code_path}/eval.py --data_root=/PATH/TO/DATA \
                                  --data_lst=/PATH/TO/DATA_lst.txt \
                                  --batch_size=16 \
                                  --crop_size=513 \
                                  --ignore_label=255 \
                                  --num_classes=21 \
                                  --model=deeplab_v3_s8 \
                                  --scales=1.0 \
                                  --freeze_bn \
                                  --ckpt_path=/PATH/TO/PRETRAIN_MODEL >${eval_path}/eval_log 2>&1 &

@@ -0,0 +1,41 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
export DEVICE_ID=3
export SLOG_PRINT_TO_STDOUT=0
train_code_path=/PATH/TO/MODEL_ZOO_CODE
eval_path=/PATH/TO/EVAL
if [ -d ${eval_path} ]; then
    rm -rf ${eval_path}
fi
mkdir -p ${eval_path}

python ${train_code_path}/eval.py --data_root=/PATH/TO/DATA \
                                  --data_lst=/PATH/TO/DATA_lst.txt \
                                  --batch_size=16 \
                                  --crop_size=513 \
                                  --ignore_label=255 \
                                  --num_classes=21 \
                                  --model=deeplab_v3_s8 \
                                  --scales=0.5 \
                                  --scales=0.75 \
                                  --scales=1.0 \
                                  --scales=1.25 \
                                  --scales=1.75 \
                                  --freeze_bn \
                                  --ckpt_path=/PATH/TO/PRETRAIN_MODEL >${eval_path}/eval_log 2>&1 &

@@ -0,0 +1,42 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
export DEVICE_ID=3
export SLOG_PRINT_TO_STDOUT=0
train_code_path=/PATH/TO/MODEL_ZOO_CODE
eval_path=/PATH/TO/EVAL
if [ -d ${eval_path} ]; then
    rm -rf ${eval_path}
fi
mkdir -p ${eval_path}

python ${train_code_path}/eval.py --data_root=/PATH/TO/DATA \
                                  --data_lst=/PATH/TO/DATA_lst.txt \
                                  --batch_size=16 \
                                  --crop_size=513 \
                                  --ignore_label=255 \
                                  --num_classes=21 \
                                  --model=deeplab_v3_s8 \
                                  --scales=0.5 \
                                  --scales=0.75 \
                                  --scales=1.0 \
                                  --scales=1.25 \
                                  --scales=1.75 \
                                  --flip \
                                  --freeze_bn \
                                  --ckpt_path=/PATH/TO/PRETRAIN_MODEL >${eval_path}/eval_log 2>&1 &

@@ -1,38 +1,44 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
export DEVICE_ID=5
export SLOG_PRINT_TO_STDOUT=0
train_path=/PATH/TO/EXPERIMENTS_DIR
train_code_path=/PATH/TO/MODEL_ZOO_CODE

if [ -d ${train_path} ]; then
    rm -rf ${train_path}
fi
mkdir -p ${train_path}
mkdir ${train_path}/device${DEVICE_ID}
mkdir ${train_path}/ckpt
cd ${train_path}/device${DEVICE_ID} || exit

python ${train_code_path}/train.py --data_file=/PATH/TO/MINDRECORD_NAME \
                                   --train_dir=${train_path}/ckpt \
                                   --train_epochs=200 \
                                   --batch_size=32 \
                                   --crop_size=513 \
                                   --base_lr=0.015 \
                                   --lr_type=cos \
                                   --min_scale=0.5 \
                                   --max_scale=2.0 \
                                   --ignore_label=255 \
                                   --num_classes=21 \
                                   --model=deeplab_v3_s16 \
                                   --ckpt_pre_trained=/PATH/TO/PRETRAIN_MODEL \
                                   --save_steps=1500 \
                                   --keep_checkpoint_max=200 >log 2>&1 &

@@ -1,23 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Init DeepLabv3."""
from .deeplabv3 import ASPP, DeepLabV3, deeplabv3_resnet50
from .backbone import *

__all__ = [
    "ASPP", "DeepLabV3", "deeplabv3_resnet50"
]
__all__.extend(backbone.__all__)

@@ -1,21 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Init backbone."""
from .resnet_deeplab import Subsample, DepthwiseConv2dNative, SpaceToBatch, BatchToSpace, ResNetV1, \
    RootBlockBeta, resnet50_dl

__all__ = [
    "Subsample", "DepthwiseConv2dNative", "SpaceToBatch", "BatchToSpace", "ResNetV1", "RootBlockBeta", "resnet50_dl"
]

@@ -0,0 +1,72 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import os
import argparse
import numpy as np
from mindspore.mindrecord import FileWriter
seg_schema = {"file_name": {"type": "string"}, "label": {"type": "bytes"}, "data": {"type": "bytes"}}
def parse_args():
    parser = argparse.ArgumentParser('mindrecord')

    parser.add_argument('--data_root', type=str, default='', help='root path of data')
    parser.add_argument('--data_lst', type=str, default='', help='list of data')
    parser.add_argument('--dst_path', type=str, default='', help='save path of mindrecords')
    parser.add_argument('--num_shards', type=int, default=8, help='number of shards')
    parser.add_argument('--shuffle', type=bool, default=True, help='shuffle or not')

    parser_args, _ = parser.parse_known_args()
    return parser_args


if __name__ == '__main__':
    args = parse_args()

    datas = []
    with open(args.data_lst) as f:
        lines = f.readlines()
    if args.shuffle:
        np.random.shuffle(lines)

    dst_dir = '/'.join(args.dst_path.split('/')[:-1])
    if not os.path.exists(dst_dir):
        os.makedirs(dst_dir)

    print('number of samples:', len(lines))
    writer = FileWriter(file_name=args.dst_path, shard_num=args.num_shards)
    writer.add_schema(seg_schema, "seg_schema")
    cnt = 0
    for l in lines:
        img_path, label_path = l.strip().split(' ')
        sample_ = {"file_name": img_path.split('/')[-1]}
        with open(os.path.join(args.data_root, img_path), 'rb') as f:
            sample_['data'] = f.read()
        with open(os.path.join(args.data_root, label_path), 'rb') as f:
            sample_['label'] = f.read()
        datas.append(sample_)
        cnt += 1
        if cnt % 1000 == 0:
            writer.write_raw_data(datas)
            print('number of samples written:', cnt)
            datas = []
    if datas:
        writer.write_raw_data(datas)
    writer.commit()
    print('number of samples written:', cnt)
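The --data_lst file read above is a plain-text list with one 'image_path label_path' pair per line, both relative to --data_root. A minimal sketch of producing such a list for a VOC-style layout; the directory names and output file name are assumptions for illustration, not part of this PR:

import os

data_root = '/PATH/TO/DATA_ROOT'                       # same root later passed as --data_root
img_dir, lbl_dir = 'JPEGImages', 'SegmentationClass'   # assumed VOC-style folders

with open('train_lst.txt', 'w') as f:
    for name in sorted(os.listdir(os.path.join(data_root, lbl_dir))):
        stem, ext = os.path.splitext(name)
        if ext != '.png':
            continue
        # one sample per line: <relative image path> <relative label path>
        f.write('{}/{}.jpg {}/{}\n'.format(img_dir, stem, lbl_dir, name))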

@@ -0,0 +1,92 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import cv2
import numpy as np
import mindspore.dataset as de
class SegDataset:
    def __init__(self,
                 image_mean,
                 image_std,
                 data_file='',
                 batch_size=32,
                 crop_size=512,
                 max_scale=2.0,
                 min_scale=0.5,
                 ignore_label=255,
                 num_classes=21,
                 num_readers=2,
                 num_parallel_calls=4,
                 shard_id=None,
                 shard_num=None):
        self.data_file = data_file
        self.batch_size = batch_size
        self.crop_size = crop_size
        self.image_mean = np.array(image_mean, dtype=np.float32)
        self.image_std = np.array(image_std, dtype=np.float32)
        self.max_scale = max_scale
        self.min_scale = min_scale
        self.ignore_label = ignore_label
        self.num_classes = num_classes
        self.num_readers = num_readers
        self.num_parallel_calls = num_parallel_calls
        self.shard_id = shard_id
        self.shard_num = shard_num
        assert max_scale > min_scale

    def preprocess_(self, image, label):
        # bgr image
        image_out = cv2.imdecode(np.frombuffer(image, dtype=np.uint8), cv2.IMREAD_COLOR)
        label_out = cv2.imdecode(np.frombuffer(label, dtype=np.uint8), cv2.IMREAD_GRAYSCALE)

        # random scale
        sc = np.random.uniform(self.min_scale, self.max_scale)
        new_h, new_w = int(sc * image_out.shape[0]), int(sc * image_out.shape[1])
        image_out = cv2.resize(image_out, (new_w, new_h), interpolation=cv2.INTER_CUBIC)
        label_out = cv2.resize(label_out, (new_w, new_h), interpolation=cv2.INTER_NEAREST)

        image_out = (image_out - self.image_mean) / self.image_std

        # pad and random crop
        h_, w_ = max(new_h, self.crop_size), max(new_w, self.crop_size)
        pad_h, pad_w = h_ - new_h, w_ - new_w
        if pad_h > 0 or pad_w > 0:
            image_out = cv2.copyMakeBorder(image_out, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT, value=0)
            label_out = cv2.copyMakeBorder(label_out, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT,
                                           value=self.ignore_label)
        offset_h = np.random.randint(0, h_ - self.crop_size + 1)
        offset_w = np.random.randint(0, w_ - self.crop_size + 1)
        image_out = image_out[offset_h: offset_h + self.crop_size, offset_w: offset_w + self.crop_size, :]
        label_out = label_out[offset_h: offset_h + self.crop_size, offset_w: offset_w + self.crop_size]

        # random left-right flip
        if np.random.uniform(0.0, 1.0) > 0.5:
            image_out = image_out[:, ::-1, :]
            label_out = label_out[:, ::-1]

        # hwc to chw
        image_out = image_out.transpose((2, 0, 1))
        image_out = image_out.copy()
        label_out = label_out.copy()
        return image_out, label_out

    def get_dataset(self, repeat=1):
        data_set = de.MindDataset(dataset_file=self.data_file, columns_list=["data", "label"],
                                  shuffle=True, num_parallel_workers=self.num_readers,
                                  num_shards=self.shard_num, shard_id=self.shard_id)
        transforms_list = self.preprocess_
        data_set = data_set.map(input_columns=["data", "label"], output_columns=["data", "label"],
                                operations=transforms_list, num_parallel_workers=self.num_parallel_calls)
        data_set = data_set.shuffle(buffer_size=self.batch_size * 10)
        data_set = data_set.batch(self.batch_size, drop_remainder=True)
        data_set = data_set.repeat(repeat)
        return data_set
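A usage sketch of SegDataset for a single-device run; the MindRecord path, shard settings, and printed shapes below are illustrative assumptions rather than part of this PR:

dataset = SegDataset(image_mean=[103.53, 116.28, 123.675],
                     image_std=[57.375, 57.120, 58.395],
                     data_file='/PATH/TO/MINDRECORD_NAME',
                     batch_size=32,
                     crop_size=513,
                     shard_id=0,
                     shard_num=1)
ds = dataset.get_dataset(repeat=1)
for batch in ds.create_dict_iterator():
    # expected: data (32, 3, 513, 513) float32, label (32, 513, 513)
    print(batch['data'].shape, batch['label'].shape)
    break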

File diff suppressed because it is too large.

@@ -1,84 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Process Dataset."""
import abc
import os
import time

from .utils.adapter import get_raw_samples, read_image


class BaseDataset:
    """
    Create dataset.

    Args:
        data_url (str): The path of data.
        usage (str): Whether to use train or eval (default='train').

    Returns:
        Dataset.
    """
    def __init__(self, data_url, usage):
        self.data_url = data_url
        self.usage = usage
        self.cur_index = 0
        self.samples = []
        _s_time = time.time()
        self._load_samples()
        _e_time = time.time()
        print(f"load samples success~, time cost = {_e_time - _s_time}")

    def __getitem__(self, item):
        sample = self.samples[item]
        return self._next_data(sample)

    def __len__(self):
        return len(self.samples)

    @staticmethod
    def _next_data(sample):
        image_path = sample[0]
        mask_image_path = sample[1]

        image = read_image(image_path)
        mask_image = read_image(mask_image_path)
        return [image, mask_image]

    @abc.abstractmethod
    def _load_samples(self):
        pass


class HwVocRawDataset(BaseDataset):
    """
    Create dataset with raw data.

    Args:
        data_url (str): The path of data.
        usage (str): Whether to use train or eval (default='train').

    Returns:
        Dataset.
    """
    def __init__(self, data_url, usage="train"):
        super().__init__(data_url, usage)

    def _load_samples(self):
        try:
            self.samples = get_raw_samples(os.path.join(self.data_url, self.usage))
        except Exception as e:
            print("load HwVocRawDataset failed!!!")
            raise e

@@ -0,0 +1,50 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from mindspore import Tensor
import mindspore.common.dtype as mstype
import mindspore.nn as nn
from mindspore.ops import operations as P
class SoftmaxCrossEntropyLoss(nn.Cell):
    def __init__(self, num_cls=21, ignore_label=255):
        super(SoftmaxCrossEntropyLoss, self).__init__()
        self.one_hot = P.OneHot(axis=-1)
        self.on_value = Tensor(1.0, mstype.float32)
        self.off_value = Tensor(0.0, mstype.float32)
        self.cast = P.Cast()
        self.ce = nn.SoftmaxCrossEntropyWithLogits()
        self.not_equal = P.NotEqual()
        self.num_cls = num_cls
        self.ignore_label = ignore_label
        self.mul = P.Mul()
        self.sum = P.ReduceSum(False)
        self.div = P.RealDiv()
        self.transpose = P.Transpose()
        self.reshape = P.Reshape()

    def construct(self, logits, labels):
        labels_int = self.cast(labels, mstype.int32)
        labels_int = self.reshape(labels_int, (-1,))
        logits_ = self.transpose(logits, (0, 2, 3, 1))
        logits_ = self.reshape(logits_, (-1, self.num_cls))
        weights = self.not_equal(labels_int, self.ignore_label)
        weights = self.cast(weights, mstype.float32)
        one_hot_labels = self.one_hot(labels_int, self.num_cls, self.on_value, self.off_value)
        loss = self.ce(logits_, one_hot_labels)
        loss = self.mul(weights, loss)
        loss = self.div(self.sum(loss), self.sum(weights))
        return loss
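The weighting above zeroes the per-pixel cross-entropy wherever the label equals ignore_label and divides by the number of valid pixels, so ignored pixels contribute nothing to the mean. A plain NumPy sketch of the same masked mean, with toy shapes invented for illustration:

import numpy as np

num_cls, ignore_label = 3, 255
logits = np.random.randn(6, num_cls)        # flattened (N*H*W, num_cls) logits
labels = np.array([0, 1, 2, 255, 1, 255])   # flattened labels; 255 marks ignored pixels

probs = np.exp(logits - logits.max(axis=1, keepdims=True))
probs /= probs.sum(axis=1, keepdims=True)   # row-wise softmax

valid = labels != ignore_label
per_pixel = -np.log(probs[valid, labels[valid]])
loss = per_pixel.sum() / valid.sum()        # mean over valid pixels only
print(loss)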

@@ -1,65 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""OhemLoss."""
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore.ops import operations as P
from mindspore.ops import functional as F
class OhemLoss(nn.Cell):
"""Ohem loss cell."""
def __init__(self, num, ignore_label):
super(OhemLoss, self).__init__()
self.mul = P.Mul()
self.shape = P.Shape()
self.one_hot = nn.OneHot(-1, num, 1.0, 0.0)
self.squeeze = P.Squeeze()
self.num = num
self.cross_entropy = P.SoftmaxCrossEntropyWithLogits()
self.mean = P.ReduceMean()
self.select = P.Select()
self.reshape = P.Reshape()
self.cast = P.Cast()
self.not_equal = P.NotEqual()
self.equal = P.Equal()
self.reduce_sum = P.ReduceSum(keep_dims=False)
self.fill = P.Fill()
self.transpose = P.Transpose()
self.ignore_label = ignore_label
self.loss_weight = 1.0
def construct(self, logits, labels):
if not self.training:
return 0
logits = self.transpose(logits, (0, 2, 3, 1))
logits = self.reshape(logits, (-1, self.num))
labels = F.cast(labels, mstype.int32)
labels = self.reshape(labels, (-1,))
one_hot_labels = self.one_hot(labels)
losses = self.cross_entropy(logits, one_hot_labels)[0]
weights = self.cast(self.not_equal(labels, self.ignore_label), mstype.float32) * self.loss_weight
weighted_losses = self.mul(losses, weights)
loss = self.reduce_sum(weighted_losses, (0,))
zeros = self.fill(mstype.float32, self.shape(weights), 0.0)
ones = self.fill(mstype.float32, self.shape(weights), 1.0)
present = self.select(self.equal(weights, zeros), zeros, ones)
present = self.reduce_sum(present, (0,))
zeros = self.fill(mstype.float32, self.shape(present), 0.0)
min_control = self.fill(mstype.float32, self.shape(present), 1.0)
present = self.select(self.equal(present, zeros), min_control, present)
loss = loss / present
return loss

Some files were not shown because too many files have changed in this diff.
