add for connect table; fix some bugs; fix pylint; fix for create dataset; fix dataset bug; fix for create dataset problem; add svt and icdar2015 convert scripts; fix for ctpn problem; fix for vgg16 (pull/12121/head)
parent 78c733ffbe
commit ae27c383fa
File diff suppressed because it is too large
@@ -0,0 +1,118 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Evaluation for CTPN"""
import os
import argparse
import time
import numpy as np
from PIL import Image, ImageDraw
from mindspore import context
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.common import set_seed
from src.ctpn import CTPN
from src.config import config
from src.dataset import create_ctpn_dataset
from src.text_connector.detector import detect
set_seed(1)

parser = argparse.ArgumentParser(description="CTPN evaluation")
parser.add_argument("--dataset_path", type=str, default="", help="Dataset path.")
parser.add_argument("--image_path", type=str, default="", help="Image path.")
parser.add_argument("--checkpoint_path", type=str, default="", help="Checkpoint file path.")
parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
args_opt = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id)

def ctpn_infer_test(dataset_path='', ckpt_path='', img_dir=''):
    """ctpn infer."""
    print("ckpt path is {}".format(ckpt_path))
    ds = create_ctpn_dataset(dataset_path, batch_size=config.test_batch_size, repeat_num=1, is_training=False)
    config.batch_size = config.test_batch_size
    total = ds.get_dataset_size()
    print("*************total dataset size is {}".format(total))
    net = CTPN(config, is_training=False)
    param_dict = load_checkpoint(ckpt_path)
    load_param_into_net(net, param_dict)
    net.set_train(False)
    eval_iter = 0

    print("\n========================================\n")
    print("Processing, please wait a moment.")
    img_basenames = []
    output_dir = os.path.join(os.getcwd(), "submit")
    if not os.path.exists(output_dir):
        os.mkdir(output_dir)
    for file in os.listdir(img_dir):
        img_basenames.append(os.path.basename(file))
    for data in ds.create_dict_iterator():
        img_data = data['image']
        img_metas = data['image_shape']
        gt_bboxes = data['box']
        gt_labels = data['label']
        gt_num = data['valid_num']

        start = time.time()
        # run net
        output = net(img_data, img_metas, gt_bboxes, gt_labels, gt_num)
        gt_bboxes = gt_bboxes.asnumpy()
        gt_labels = gt_labels.asnumpy()
        gt_num = gt_num.asnumpy().astype(bool)
        end = time.time()
        proposal = output[0]
        proposal_mask = output[1]
        print("start to draw pic")
        for j in range(config.test_batch_size):
            img = img_basenames[config.test_batch_size * eval_iter + j]
            all_box_tmp = proposal[j].asnumpy()
            all_mask_tmp = np.expand_dims(proposal_mask[j].asnumpy(), axis=1)
            using_boxes_mask = all_box_tmp * all_mask_tmp
            textsegs = using_boxes_mask[:, 0:4].astype(np.float32)
            scores = using_boxes_mask[:, 4].astype(np.float32)
            shape = img_metas.asnumpy()[0][:2].astype(np.int32)
            bboxes = detect(textsegs, scores[:, np.newaxis], shape)
            im = Image.open(img_dir + '/' + img)
            draw = ImageDraw.Draw(im)
            image_h = img_metas.asnumpy()[j][2]
            image_w = img_metas.asnumpy()[j][3]
            gt_boxs = gt_bboxes[j][gt_num[j], :]
            for gt_box in gt_boxs:
                gt_x1 = gt_box[0] / image_w
                gt_y1 = gt_box[1] / image_h
                gt_x2 = gt_box[2] / image_w
                gt_y2 = gt_box[3] / image_h
                draw.line([(gt_x1, gt_y1), (gt_x1, gt_y2), (gt_x2, gt_y2), (gt_x2, gt_y1), (gt_x1, gt_y1)],\
                    fill='green', width=2)
            file_name = "res_" + img.replace("jpg", "txt")
            output_file = os.path.join(output_dir, file_name)
            f = open(output_file, 'w')
            for bbox in bboxes:
                x1 = bbox[0] / image_w
                y1 = bbox[1] / image_h
                x2 = bbox[2] / image_w
                y2 = bbox[3] / image_h
                draw.line([(x1, y1), (x1, y2), (x2, y2), (x2, y1), (x1, y1)], fill='red', width=2)
                str_tmp = str(int(x1)) + "," + str(int(y1)) + "," + str(int(x2)) + "," + str(int(y2))
                f.write(str_tmp)
                f.write("\n")
            f.close()
            im.save(img)
        percent = round(eval_iter / total * 100, 2)
        eval_iter = eval_iter + 1
        print("Iter {} cost time {}".format(eval_iter, end - start))
        print(' %s [%d/%d]' % (str(percent) + '%', eval_iter, total), end='\r')

if __name__ == '__main__':
    ctpn_infer_test(args_opt.dataset_path, args_opt.checkpoint_path, img_dir=args_opt.image_path)
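
For reference, a minimal sketch of driving the evaluation routine above directly from Python rather than via the command line; it assumes this file is saved as eval.py (as the run_eval script below suggests), and all paths are placeholders, not values from this change:

# Hypothetical direct call of ctpn_infer_test; replace the placeholder
# paths with a real mindrecord dataset, checkpoint file and image directory.
from eval import ctpn_infer_test

ctpn_infer_test(dataset_path="/path/to/eval.mindrecord",
                ckpt_path="/path/to/ctpn.ckpt",
                img_dir="/path/to/test_images")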
@@ -0,0 +1,67 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

if [ $# -ne 3 ]
then
    echo "Usage: sh run_distribute_train_ascend.sh [RANK_TABLE_FILE] [TASK_TYPE] [PRETRAINED_PATH]"
    exit 1
fi

get_real_path(){
    if [ "${1:0:1}" == "/" ]; then
        echo "$1"
    else
        echo "$(realpath -m $PWD/$1)"
    fi
}

PATH1=$(get_real_path $1)
echo $PATH1

if [ ! -f $PATH1 ]
then
    echo "error: RANK_TABLE_FILE=$PATH1 is not a file"
    exit 1
fi
TASK_TYPE=$2
PATH2=$(get_real_path $3)
echo $PATH2
if [ ! -f $PATH2 ]
then
    echo "error: PRETRAINED_PATH=$PATH2 is not a file"
    exit 1
fi

ulimit -u unlimited
export DEVICE_NUM=8
export RANK_SIZE=8
export RANK_TABLE_FILE=$PATH1

for((i=0; i<${DEVICE_NUM}; i++))
do
    export DEVICE_ID=$i
    export RANK_ID=$i
    rm -rf ./train_parallel$i
    mkdir ./train_parallel$i
    cp ../*.py ./train_parallel$i
    cp *.sh ./train_parallel$i
    cp -r ../src ./train_parallel$i
    cd ./train_parallel$i || exit
    echo "start training for rank $RANK_ID, device $DEVICE_ID"
    env > env.log
    python train.py --device_id=$i --rank_id=$i --run_distribute=True --device_num=$DEVICE_NUM --task_type=$TASK_TYPE --pre_trained=$PATH2 &> log &
    cd ..
done
@@ -0,0 +1,80 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

if [ $# != 3 ]
then
    echo "Usage: sh run_eval_ascend.sh [IMAGE_PATH] [DATASET_PATH] [CHECKPOINT_PATH]"
    exit 1
fi

get_real_path(){
    if [ "${1:0:1}" == "/" ]; then
        echo "$1"
    else
        echo "$(realpath -m $PWD/$1)"
    fi
}
IMAGE_PATH=$(get_real_path $1)
DATASET_PATH=$(get_real_path $2)
CHECKPOINT_PATH=$(get_real_path $3)
echo $IMAGE_PATH
echo $DATASET_PATH
echo $CHECKPOINT_PATH

if [ ! -d $IMAGE_PATH ]
then
    echo "error: IMAGE_PATH=$IMAGE_PATH is not a directory"
    exit 1
fi

if [ ! -f $DATASET_PATH ]
then
    echo "error: DATASET_PATH=$DATASET_PATH is not a file"
    exit 1
fi

if [ ! -d $CHECKPOINT_PATH ]
then
    echo "error: CHECKPOINT_PATH=$CHECKPOINT_PATH is not a directory"
    exit 1
fi

ulimit -u unlimited
export DEVICE_NUM=1
export RANK_SIZE=$DEVICE_NUM
export DEVICE_ID=0
export RANK_ID=0
for file in "${CHECKPOINT_PATH}"/*.ckpt
do
    if [ -d "eval" ];
    then
        rm -rf ./eval
    fi
    mkdir ./eval
    cp ../*.py ./eval
    cp *.sh ./eval
    cp -r ../src ./eval
    cd ./eval || exit
    env > env.log
    CHECKPOINT_FILE_PATH=$file
    echo "start eval for checkpoint file: ${CHECKPOINT_FILE_PATH}"
    python eval.py --device_id=$DEVICE_ID --image_path=$IMAGE_PATH --dataset_path=$DATASET_PATH --checkpoint_path=$CHECKPOINT_FILE_PATH &> log
    echo "end eval for checkpoint file: ${CHECKPOINT_FILE_PATH}"
    cd ./submit
    file_base_name=$(basename $file)
    zip -r ../../submit_${file_base_name%.*}.zip *.txt
    cd ../../
done
@@ -0,0 +1,54 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# -ne 2 ]
then
    echo "Usage: sh run_standalone_train_ascend.sh [TASK_TYPE] [PRETRAINED_PATH]"
    exit 1
fi

get_real_path(){
    if [ "${1:0:1}" == "/" ]; then
        echo "$1"
    else
        echo "$(realpath -m $PWD/$1)"
    fi
}

TASK_TYPE=$1
PRETRAINED_PATH=$(get_real_path $2)
echo $PRETRAINED_PATH
if [ ! -f $PRETRAINED_PATH ]
then
    echo "error: PRETRAINED_PATH=$PRETRAINED_PATH is not a file"
    exit 1
fi

ulimit -u unlimited
export DEVICE_NUM=1
export DEVICE_ID=0
export RANK_ID=0
export RANK_SIZE=1

rm -rf ./train
mkdir ./train
cp ../*.py ./train
cp *.sh ./train
cp -r ../src ./train
cd ./train || exit
echo "start training for device $DEVICE_ID"
env > env.log
python train.py --device_id=$DEVICE_ID --task_type=$TASK_TYPE --pre_trained=$PRETRAINED_PATH &> log &
cd ..
@@ -0,0 +1,55 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import mindspore.nn as nn
from mindspore.ops import operations as P

class BoundingBoxDecode(nn.Cell):
    """
    BoundingBox decoder.

    Returns:
        pred_box(Tensor): decoded bounding boxes.
    """
    def __init__(self):
        super(BoundingBoxDecode, self).__init__()
        self.split = P.Split(axis=1, output_num=4)
        self.ones = 1.0
        self.half = 0.5
        self.log = P.Log()
        self.exp = P.Exp()
        self.concat = P.Concat(axis=1)

    def construct(self, bboxes, deltas):
        """
        bboxes(Tensor): bounding boxes.
        deltas(Tensor): deltas between bounding boxes and anchors.
        """
        x1, y1, x2, y2 = self.split(bboxes)
        width = x2 - x1 + self.ones
        height = y2 - y1 + self.ones
        ctr_x = x1 + self.half * width
        ctr_y = y1 + self.half * height
        _, dy, _, dh = self.split(deltas)
        pred_ctr_x = ctr_x
        pred_ctr_y = dy * height + ctr_y
        pred_w = width
        pred_h = self.exp(dh) * height

        x1 = pred_ctr_x - self.half * pred_w
        y1 = pred_ctr_y - self.half * pred_h
        x2 = pred_ctr_x + self.half * pred_w
        y2 = pred_ctr_y + self.half * pred_h
        pred_box = self.concat((x1, y1, x2, y2))
        return pred_box
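
A plain-NumPy restatement of the decode rule above may make it easier to read: only the vertical center (dy) and height (dh) are regressed, since CTPN anchors have a fixed 16-pixel width, so the x terms pass through the center/width round trip unchanged. Illustrative sketch only, not part of the diff:

import numpy as np

def decode_np(boxes, deltas):
    # boxes, deltas: (N, 4) arrays; mirrors the MindSpore construct above.
    x1, y1, x2, y2 = np.split(boxes.astype(np.float32), 4, axis=1)
    w = x2 - x1 + 1.0
    h = y2 - y1 + 1.0
    ctr_x = x1 + 0.5 * w
    ctr_y = y1 + 0.5 * h
    dy = deltas[:, 1:2]
    dh = deltas[:, 3:4]
    pred_ctr_y = dy * h + ctr_y          # only y is shifted
    pred_h = np.exp(dh) * h              # only height is rescaled
    return np.concatenate([ctr_x - 0.5 * w, pred_ctr_y - 0.5 * pred_h,
                           ctr_x + 0.5 * w, pred_ctr_y + 0.5 * pred_h], axis=1)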
@@ -0,0 +1,55 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import mindspore.nn as nn
from mindspore.ops import operations as P

class BoundingBoxEncode(nn.Cell):
    """
    BoundingBox encoder.

    Returns:
        deltas(Tensor): encoded deltas between anchor boxes and ground-truth boxes.
    """
    def __init__(self):
        super(BoundingBoxEncode, self).__init__()
        self.split = P.Split(axis=1, output_num=4)
        self.ones = 1.0
        self.half = 0.5
        self.log = P.Log()
        self.concat = P.Concat(axis=1)

    def construct(self, anchor_box, gt_box):
        """
        anchor_box(Tensor): anchor boxes.
        gt_box(Tensor): ground-truth boxes.
        """
        x1, y1, x2, y2 = self.split(anchor_box)
        width = x2 - x1 + self.ones
        height = y2 - y1 + self.ones
        ctr_x = x1 + self.half * width
        ctr_y = y1 + self.half * height
        gt_x1, gt_y1, gt_x2, gt_y2 = self.split(gt_box)
        gt_width = gt_x2 - gt_x1 + self.ones
        gt_height = gt_y2 - gt_y1 + self.ones
        ctr_gt_x = gt_x1 + self.half * gt_width
        ctr_gt_y = gt_y1 + self.half * gt_height

        target_dx = (ctr_gt_x - ctr_x) / width
        target_dy = (ctr_gt_y - ctr_y) / height
        dw = gt_width / width
        dh = gt_height / height
        target_dw = self.log(dw)
        target_dh = self.log(dh)
        deltas = self.concat((target_dx, target_dy, target_dw, target_dh))
        return deltas
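
A small worked example of the encoding (NumPy, illustrative values only): for a 16-pixel-high anchor and a 24-pixel-high ground-truth box, dy is the center offset divided by the anchor height and dh is the log height ratio, exactly inverting the decoder in the previous file:

import numpy as np

anchor = np.array([[0., 0., 15., 15.]])  # height = 15 - 0 + 1 = 16
gt = np.array([[0., 4., 15., 27.]])      # height = 27 - 4 + 1 = 24
h_a = anchor[:, 3] - anchor[:, 1] + 1.0
h_g = gt[:, 3] - gt[:, 1] + 1.0
dy = ((gt[:, 1] + 0.5 * h_g) - (anchor[:, 1] + 0.5 * h_a)) / h_a  # (16 - 8) / 16 = 0.5
dh = np.log(h_g / h_a)                                            # log(24 / 16)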
@@ -0,0 +1,73 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""FasterRcnn anchor generator."""
import numpy as np

class AnchorGenerator():
    """Anchor generator for FasterRcnn."""
    def __init__(self, config):
        """Anchor generator init method."""
        self.base_size = config.anchor_base
        self.num_anchor = config.num_anchors
        self.anchor_height = config.anchor_height
        self.anchor_width = config.anchor_width
        self.size = self.gen_anchor_size()
        self.base_anchors = self.gen_base_anchors()

    def gen_base_anchors(self):
        """Generate the base anchors, one per (height, width) size."""
        base_anchor = np.array([0, 0, self.base_size - 1, self.base_size - 1], np.int32)
        anchors = np.zeros((len(self.size), 4), np.int32)
        index = 0
        for h, w in self.size:
            anchors[index] = self.scale_anchor(base_anchor, h, w)
            index += 1
        return anchors

    def gen_anchor_size(self):
        """Generate the list of anchor sizes."""
        size = []
        for width in self.anchor_width:
            for height in self.anchor_height:
                size.append((height, width))
        return size

    def scale_anchor(self, anchor, h, w):
        """Scale the base anchor to height h and width w around its center."""
        x_ctr = (anchor[0] + anchor[2]) * 0.5
        y_ctr = (anchor[1] + anchor[3]) * 0.5
        scaled_anchor = anchor.copy()
        scaled_anchor[0] = x_ctr - w / 2  # xmin
        scaled_anchor[2] = x_ctr + w / 2  # xmax
        scaled_anchor[1] = y_ctr - h / 2  # ymin
        scaled_anchor[3] = y_ctr + h / 2  # ymax
        return scaled_anchor

    def _meshgrid(self, x, y):
        """Generate grid."""
        xx = np.repeat(x.reshape(1, len(x)), len(y), axis=0).reshape(-1)
        yy = np.repeat(y, len(x))
        return xx, yy

    def grid_anchors(self, featmap_size, stride=16):
        """Generate anchor list."""
        base_anchors = self.base_anchors
        feat_h, feat_w = featmap_size
        shift_x = np.arange(0, feat_w) * stride
        shift_y = np.arange(0, feat_h) * stride
        shift_xx, shift_yy = self._meshgrid(shift_x, shift_y)
        shifts = np.stack([shift_xx, shift_yy, shift_xx, shift_yy], axis=-1)
        shifts = shifts.astype(base_anchors.dtype)
        all_anchors = base_anchors[None, :, :] + shifts[:, None, :]
        all_anchors = all_anchors.reshape(-1, 4)
        return all_anchors
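
A quick trace of the anchor layout, runnable alongside the class above; the stand-in config mirrors the values in src/config.py from this change (a 36x60 feature map at stride 16 with 14 anchors per location gives 30240 anchors). Illustrative only:

import numpy as np
from types import SimpleNamespace

# Stand-in for the real config object; values copied from src/config.py.
cfg = SimpleNamespace(anchor_base=16, num_anchors=14, anchor_width=[16],
                      anchor_height=[2, 4, 7, 11, 16, 23, 33, 48, 68, 97,
                                     139, 198, 283, 406])
gen = AnchorGenerator(cfg)
anchors = gen.grid_anchors((36, 60), stride=16)
print(anchors.shape)  # (30240, 4)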
@@ -0,0 +1,152 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""FasterRcnn positive and negative sample screening for RPN."""

import numpy as np
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.common.tensor import Tensor
import mindspore.common.dtype as mstype
from src.CTPN.BoundingBoxEncode import BoundingBoxEncode


class BboxAssignSample(nn.Cell):
    """
    Bbox assigner and sampler definition.

    Args:
        config (dict): Config.
        batch_size (int): Batchsize.
        num_bboxes (int): The anchor nums.
        add_gt_as_proposals (bool): add gt bboxes as proposals flag.

    Returns:
        Tensor, output tensor.
        bbox_targets: bbox location, (batch_size, num_bboxes, 4)
        bbox_weights: bbox weights, (batch_size, num_bboxes, 1)
        labels: label for every bboxes, (batch_size, num_bboxes, 1)
        label_weights: label weight for every bboxes, (batch_size, num_bboxes, 1)

    Examples:
        BboxAssignSample(config, 2, 1024, True)
    """

    def __init__(self, config, batch_size, num_bboxes, add_gt_as_proposals):
        super(BboxAssignSample, self).__init__()
        cfg = config
        self.batch_size = batch_size

        self.neg_iou_thr = Tensor(cfg.neg_iou_thr, mstype.float16)
        self.pos_iou_thr = Tensor(cfg.pos_iou_thr, mstype.float16)
        self.min_pos_iou = Tensor(cfg.min_pos_iou, mstype.float16)
        self.zero_thr = Tensor(0.0, mstype.float16)

        self.num_bboxes = num_bboxes
        self.num_gts = cfg.num_gts
        self.num_expected_pos = cfg.num_expected_pos
        self.num_expected_neg = cfg.num_expected_neg
        self.add_gt_as_proposals = add_gt_as_proposals

        if self.add_gt_as_proposals:
            self.label_inds = Tensor(np.arange(1, self.num_gts + 1))

        self.concat = P.Concat(axis=0)
        self.max_gt = P.ArgMaxWithValue(axis=0)
        self.max_anchor = P.ArgMaxWithValue(axis=1)
        self.sum_inds = P.ReduceSum()
        self.iou = P.IOU()
        self.greaterequal = P.GreaterEqual()
        self.greater = P.Greater()
        self.select = P.Select()
        self.gatherND = P.GatherNd()
        self.squeeze = P.Squeeze()
        self.cast = P.Cast()
        self.logicaland = P.LogicalAnd()
        self.less = P.Less()
        self.random_choice_with_mask_pos = P.RandomChoiceWithMask(self.num_expected_pos)
        self.random_choice_with_mask_neg = P.RandomChoiceWithMask(self.num_expected_neg)
        self.reshape = P.Reshape()
        self.equal = P.Equal()
        self.bounding_box_encode = BoundingBoxEncode()
        self.scatterNdUpdate = P.ScatterNdUpdate()
        self.scatterNd = P.ScatterNd()
        self.logicalnot = P.LogicalNot()
        self.tile = P.Tile()
        self.zeros_like = P.ZerosLike()

        self.assigned_gt_inds = Tensor(np.array(-1 * np.ones(num_bboxes), dtype=np.int32))
        self.assigned_gt_zeros = Tensor(np.array(np.zeros(num_bboxes), dtype=np.int32))
        self.assigned_gt_ones = Tensor(np.array(np.ones(num_bboxes), dtype=np.int32))
        self.assigned_gt_ignores = Tensor(np.array(-1 * np.ones(num_bboxes), dtype=np.int32))
        self.assigned_pos_ones = Tensor(np.array(np.ones(self.num_expected_pos), dtype=np.int32))

        self.check_neg_mask = Tensor(np.array(np.ones(self.num_expected_neg - self.num_expected_pos), dtype=np.bool_))
        self.range_pos_size = Tensor(np.arange(self.num_expected_pos).astype(np.float16))
        self.check_gt_one = Tensor(np.array(-1 * np.ones((self.num_gts, 4)), dtype=np.float16))
        self.check_anchor_two = Tensor(np.array(-2 * np.ones((self.num_bboxes, 4)), dtype=np.float16))
        self.print = P.Print()


    def construct(self, gt_bboxes_i, gt_labels_i, valid_mask, bboxes, gt_valids):
        gt_bboxes_i = self.select(self.cast(self.tile(self.reshape(self.cast(gt_valids, mstype.int32), \
            (self.num_gts, 1)), (1, 4)), mstype.bool_), gt_bboxes_i, self.check_gt_one)
        bboxes = self.select(self.cast(self.tile(self.reshape(self.cast(valid_mask, mstype.int32), \
            (self.num_bboxes, 1)), (1, 4)), mstype.bool_), bboxes, self.check_anchor_two)
        overlaps = self.iou(bboxes, gt_bboxes_i)
        max_overlaps_w_gt_index, max_overlaps_w_gt = self.max_gt(overlaps)
        _, max_overlaps_w_ac = self.max_anchor(overlaps)
        neg_sample_iou_mask = self.logicaland(self.greaterequal(max_overlaps_w_gt, self.zero_thr), \
            self.less(max_overlaps_w_gt, self.neg_iou_thr))
        assigned_gt_inds2 = self.select(neg_sample_iou_mask, self.assigned_gt_zeros, self.assigned_gt_inds)
        pos_sample_iou_mask = self.greaterequal(max_overlaps_w_gt, self.pos_iou_thr)
        assigned_gt_inds3 = self.select(pos_sample_iou_mask, \
            max_overlaps_w_gt_index + self.assigned_gt_ones, assigned_gt_inds2)
        assigned_gt_inds4 = assigned_gt_inds3
        for j in range(self.num_gts):
            max_overlaps_w_ac_j = max_overlaps_w_ac[j:j+1:1]
            overlaps_w_gt_j = self.squeeze(overlaps[j:j+1:1, ::])

            pos_mask_j = self.logicaland(self.greaterequal(max_overlaps_w_ac_j, self.min_pos_iou), \
                self.equal(overlaps_w_gt_j, max_overlaps_w_ac_j))
            assigned_gt_inds4 = self.select(pos_mask_j, self.assigned_gt_ones + j, assigned_gt_inds4)
        assigned_gt_inds5 = self.select(valid_mask, assigned_gt_inds4, self.assigned_gt_ignores)
        pos_index, valid_pos_index = self.random_choice_with_mask_pos(self.greater(assigned_gt_inds5, 0))
        pos_check_valid = self.cast(self.greater(assigned_gt_inds5, 0), mstype.float16)
        pos_check_valid = self.sum_inds(pos_check_valid, -1)
        valid_pos_index = self.less(self.range_pos_size, pos_check_valid)
        pos_index = pos_index * self.reshape(self.cast(valid_pos_index, mstype.int32), (self.num_expected_pos, 1))
        pos_assigned_gt_index = self.gatherND(assigned_gt_inds5, pos_index) - self.assigned_pos_ones
        pos_assigned_gt_index = pos_assigned_gt_index * self.cast(valid_pos_index, mstype.int32)
        pos_assigned_gt_index = self.reshape(pos_assigned_gt_index, (self.num_expected_pos, 1))
        neg_index, valid_neg_index = self.random_choice_with_mask_neg(self.equal(assigned_gt_inds5, 0))

        num_pos = self.cast(self.logicalnot(valid_pos_index), mstype.float16)
        num_pos = self.sum_inds(num_pos, -1)
        unvalid_pos_index = self.less(self.range_pos_size, num_pos)
        valid_neg_index = self.logicaland(self.concat((self.check_neg_mask, unvalid_pos_index)), valid_neg_index)
        pos_bboxes_ = self.gatherND(bboxes, pos_index)
        pos_gt_bboxes_ = self.gatherND(gt_bboxes_i, pos_assigned_gt_index)
        pos_gt_labels = self.gatherND(gt_labels_i, pos_assigned_gt_index)
        pos_bbox_targets_ = self.bounding_box_encode(pos_bboxes_, pos_gt_bboxes_)
        valid_pos_index = self.cast(valid_pos_index, mstype.int32)
        valid_neg_index = self.cast(valid_neg_index, mstype.int32)
        bbox_targets_total = self.scatterNd(pos_index, pos_bbox_targets_, (self.num_bboxes, 4))
        bbox_weights_total = self.scatterNd(pos_index, valid_pos_index, (self.num_bboxes,))
        labels_total = self.scatterNd(pos_index, pos_gt_labels, (self.num_bboxes,))
        total_index = self.concat((pos_index, neg_index))
        total_valid_index = self.concat((valid_pos_index, valid_neg_index))
        label_weights_total = self.scatterNd(total_index, total_valid_index, (self.num_bboxes,))
        return bbox_targets_total, self.cast(bbox_weights_total, mstype.bool_), \
               labels_total, self.cast(label_weights_total, mstype.bool_)
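
The core assignment rule buried in the graph ops above reduces to a few lines of NumPy; the thresholds are the config values neg_iou_thr=0.5 and pos_iou_thr=0.7, while the min_pos_iou rescue loop and the random subsampling are omitted for brevity. Illustrative sketch, not part of the diff:

import numpy as np

def assign(overlaps, neg_thr=0.5, pos_thr=0.7):
    # overlaps: (num_gts, num_anchors) IoU matrix, as produced by P.IOU above.
    best_gt = overlaps.argmax(axis=0)   # best gt index per anchor
    best_iou = overlaps.max(axis=0)     # best IoU per anchor
    assigned = -np.ones(overlaps.shape[1], dtype=np.int32)  # -1 = ignored
    assigned[(best_iou >= 0) & (best_iou < neg_thr)] = 0    # 0 = negative
    pos = best_iou >= pos_thr
    assigned[pos] = best_gt[pos] + 1    # positives store gt index + 1
    return assigned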
@@ -0,0 +1,190 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""FasterRcnn proposal generator."""

import numpy as np
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore.ops import operations as P
from mindspore import Tensor
from src.CTPN.BoundingBoxDecode import BoundingBoxDecode

class Proposal(nn.Cell):
    """
    Proposal subnet.

    Args:
        config (dict): Config.
        batch_size (int): Batchsize.
        num_classes (int) - Class number.
        use_sigmoid_cls (bool) - Select sigmoid or softmax function.
        target_means (tuple) - Means for encode function. Default: (.0, .0, .0, .0).
        target_stds (tuple) - Stds for encode function. Default: (1.0, 1.0, 1.0, 1.0).

    Returns:
        Tuple, tuple of output tensor,(proposal, mask).

    Examples:
        Proposal(config = config, batch_size = 1, num_classes = 81, use_sigmoid_cls = True, \
                 target_means=(.0, .0, .0, .0), target_stds=(1.0, 1.0, 1.0, 1.0))
    """
    def __init__(self,
                 config,
                 batch_size,
                 num_classes,
                 use_sigmoid_cls,
                 target_means=(.0, .0, .0, .0),
                 target_stds=(1.0, 1.0, 1.0, 1.0)
                 ):
        super(Proposal, self).__init__()
        cfg = config
        self.batch_size = batch_size
        self.num_classes = num_classes
        self.target_means = target_means
        self.target_stds = target_stds
        self.use_sigmoid_cls = config.use_sigmoid_cls

        if self.use_sigmoid_cls:
            self.cls_out_channels = 1
            self.activation = P.Sigmoid()
            self.reshape_shape = (-1, 1)
        else:
            self.cls_out_channels = num_classes
            self.activation = P.Softmax(axis=1)
            self.reshape_shape = (-1, 2)

        if self.cls_out_channels <= 0:
            raise ValueError('num_classes={} is too small'.format(num_classes))

        self.num_pre = cfg.rpn_proposal_nms_pre
        self.min_box_size = cfg.rpn_proposal_min_bbox_size
        self.nms_thr = cfg.rpn_proposal_nms_thr
        self.nms_post = cfg.rpn_proposal_nms_post
        self.nms_across_levels = cfg.rpn_proposal_nms_across_levels
        self.max_num = cfg.rpn_proposal_max_num

        # Op Define
        self.squeeze = P.Squeeze()
        self.reshape = P.Reshape()
        self.cast = P.Cast()

        self.feature_shapes = cfg.feature_shapes

        self.transpose_shape = (1, 2, 0)

        self.decode = BoundingBoxDecode()

        self.nms = P.NMSWithMask(self.nms_thr)
        self.concat_axis0 = P.Concat(axis=0)
        self.concat_axis1 = P.Concat(axis=1)
        self.split = P.Split(axis=1, output_num=5)
        self.min = P.Minimum()
        self.gatherND = P.GatherNd()
        self.slice = P.Slice()
        self.select = P.Select()
        self.greater = P.Greater()
        self.transpose = P.Transpose()
        self.tile = P.Tile()
        self.set_train_local(config, training=True)

        self.multi_10 = Tensor(10.0, mstype.float16)

    def set_train_local(self, config, training=False):
        """Set training flag."""
        self.training_local = training
        cfg = config
        self.topK_stage1 = ()
        self.topK_shape = ()
        total_max_topk_input = 0
        if not self.training_local:
            self.num_pre = cfg.rpn_nms_pre
            self.min_box_size = cfg.rpn_min_bbox_min_size
            self.nms_thr = cfg.rpn_nms_thr
            self.nms_post = cfg.rpn_nms_post
            self.max_num = cfg.rpn_max_num
        k_num = self.num_pre
        total_max_topk_input = k_num
        self.topK_stage1 = k_num
        self.topK_shape = (k_num, 1)

        self.topKv2 = P.TopK(sorted=True)
        self.topK_shape_stage2 = (self.max_num, 1)
        self.min_float_num = -65536.0
        self.topK_mask = Tensor(self.min_float_num * np.ones(total_max_topk_input, np.float16))
        self.shape = P.Shape()
        self.print = P.Print()

    def construct(self, rpn_cls_score_total, rpn_bbox_pred_total, anchor_list):
        proposals_tuple = ()
        masks_tuple = ()
        for img_id in range(self.batch_size):
            rpn_cls_score_i = self.squeeze(rpn_cls_score_total[img_id:img_id+1:1, ::, ::, ::])
            rpn_bbox_pred_i = self.squeeze(rpn_bbox_pred_total[img_id:img_id+1:1, ::, ::, ::])
            proposals, masks = self.get_bboxes_single(rpn_cls_score_i, rpn_bbox_pred_i, anchor_list)
            proposals_tuple += (proposals,)
            masks_tuple += (masks,)
        return proposals_tuple, masks_tuple

    def get_bboxes_single(self, cls_scores, bbox_preds, mlvl_anchors):
        """Get proposal boundingbox."""
        mlvl_proposals = ()
        mlvl_mask = ()

        rpn_cls_score = self.transpose(cls_scores, self.transpose_shape)
        rpn_bbox_pred = self.transpose(bbox_preds, self.transpose_shape)
        anchors = mlvl_anchors

        # (H, W, A*2)
        rpn_cls_score_shape = self.shape(rpn_cls_score)
        rpn_cls_score = self.reshape(rpn_cls_score, (rpn_cls_score_shape[0], \
            rpn_cls_score_shape[1], -1, self.cls_out_channels))
        rpn_cls_score = self.reshape(rpn_cls_score, self.reshape_shape)
        rpn_cls_score = self.activation(rpn_cls_score)
        if self.use_sigmoid_cls:
            rpn_cls_score_process = self.cast(self.squeeze(rpn_cls_score), mstype.float16)
        else:
            rpn_cls_score_process = self.cast(self.squeeze(rpn_cls_score[::, 1]), mstype.float16)

        rpn_bbox_pred_process = self.cast(self.reshape(rpn_bbox_pred, (-1, 4)), mstype.float16)

        scores_sorted, topk_inds = self.topKv2(rpn_cls_score_process, self.num_pre)

        topk_inds = self.reshape(topk_inds, self.topK_shape)

        bboxes_sorted = self.gatherND(rpn_bbox_pred_process, topk_inds)
        anchors_sorted = self.cast(self.gatherND(anchors, topk_inds), mstype.float16)

        proposals_decode = self.decode(anchors_sorted, bboxes_sorted)

        proposals_decode = self.concat_axis1((proposals_decode, self.reshape(scores_sorted, self.topK_shape)))
        proposals, _, mask_valid = self.nms(proposals_decode)

        mlvl_proposals = mlvl_proposals + (proposals,)
        mlvl_mask = mlvl_mask + (mask_valid,)

        proposals = self.concat_axis0(mlvl_proposals)
        masks = self.concat_axis0(mlvl_mask)

        _, _, _, _, scores = self.split(proposals)
        scores = self.squeeze(scores)
        topk_mask = self.cast(self.topK_mask, mstype.float16)
        scores_using = self.select(masks, scores, topk_mask)

        _, topk_inds = self.topKv2(scores_using, self.max_num)

        topk_inds = self.reshape(topk_inds, self.topK_shape_stage2)
        proposals = self.gatherND(proposals, topk_inds)
        masks = self.gatherND(masks, topk_inds)
        return proposals, masks
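
The masked top-k at the end of get_bboxes_single is worth spelling out: proposals rejected by NMS have their scores replaced with the large negative constant (min_float_num = -65536.0) so a plain sorted top-k pushes them to the end. A NumPy sketch of the same idea, illustrative only:

import numpy as np

scores = np.array([0.9, 0.1, 0.7, 0.3])
valid = np.array([True, False, True, True])   # NMS survival mask
masked = np.where(valid, scores, -65536.0)    # invalid entries sort last
topk = np.argsort(-masked)[:2]
print(topk)  # [0 2]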
@@ -0,0 +1,228 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""RPN for fasterRCNN"""
import numpy as np
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore.ops import operations as P
from mindspore import Tensor
from mindspore.ops import functional as F
from src.CTPN.bbox_assign_sample import BboxAssignSample

class RpnRegClsBlock(nn.Cell):
    """
    Rpn reg cls block for rpn layer

    Args:
        config(EasyDict) - Network construction config.
        in_channels (int) - Input channels of shared convolution.
        feat_channels (int) - Output channels of shared convolution.
        num_anchors (int) - The anchor number.
        cls_out_channels (int) - Output channels of classification convolution.

    Returns:
        Tensor, output tensor.
    """

    def __init__(self,
                 config,
                 in_channels,
                 feat_channels,
                 num_anchors,
                 cls_out_channels):
        super(RpnRegClsBlock, self).__init__()
        self.reshape = P.Reshape()
        self.shape = (-1, 2*config.hidden_size)
        self.lstm_fc = nn.Dense(2*config.hidden_size, 512).to_float(mstype.float16)
        self.rpn_cls = nn.Dense(in_channels=512, out_channels=num_anchors * cls_out_channels).to_float(mstype.float16)
        self.rpn_reg = nn.Dense(in_channels=512, out_channels=num_anchors * 4).to_float(mstype.float16)
        self.shape1 = (config.num_step, config.rnn_batch_size, -1)
        self.shape2 = (-1, config.batch_size, config.rnn_batch_size, config.num_step)
        self.transpose = P.Transpose()
        self.print = P.Print()
        self.dropout = nn.Dropout(0.8)

    def construct(self, x):
        x = self.reshape(x, self.shape)
        x = self.lstm_fc(x)
        x1 = self.rpn_cls(x)
        x1 = self.reshape(x1, self.shape1)
        x1 = self.transpose(x1, (2, 1, 0))
        x1 = self.reshape(x1, self.shape2)
        x1 = self.transpose(x1, (1, 0, 2, 3))
        x2 = self.rpn_reg(x)
        x2 = self.reshape(x2, self.shape1)
        x2 = self.transpose(x2, (2, 1, 0))
        x2 = self.reshape(x2, self.shape2)
        x2 = self.transpose(x2, (1, 0, 2, 3))
        return x1, x2
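
To make the reshuffling in construct easier to follow, here is an illustrative NumPy shape trace of the classification branch, using the config values from this change (num_step=60, rnn_batch_size=36, hidden_size=128, batch_size=1, num_anchors=14, cls_out_channels=2); it mirrors the reshapes only, not part of the diff:

import numpy as np

x = np.zeros((60 * 36, 2 * 128))     # flattened LSTM output: (num_step*rnn_batch, 2*hidden)
x1 = np.zeros((x.shape[0], 14 * 2))  # rpn_cls projection: (2160, 28)
x1 = x1.reshape(60, 36, -1)          # shape1 -> (num_step, rnn_batch_size, A*cls)
x1 = x1.transpose(2, 1, 0)           # -> (A*cls, rnn_batch_size, num_step)
x1 = x1.reshape(-1, 1, 36, 60)       # shape2 -> (A*cls, batch_size, H, W)
x1 = x1.transpose(1, 0, 2, 3)        # -> (batch_size, A*cls, H, W)
print(x1.shape)                      # (1, 28, 36, 60)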

class RPN(nn.Cell):
    """
    ROI proposal network.

    Args:
        config (dict) - Config.
        batch_size (int) - Batchsize.
        in_channels (int) - Input channels of shared convolution.
        feat_channels (int) - Output channels of shared convolution.
        num_anchors (int) - The anchor number.
        cls_out_channels (int) - Output channels of classification convolution.

    Returns:
        Tuple, tuple of output tensor.

    Examples:
        RPN(config=config, batch_size=2, in_channels=256, feat_channels=1024,
            num_anchors=3, cls_out_channels=512)
    """
    def __init__(self,
                 config,
                 batch_size,
                 in_channels,
                 feat_channels,
                 num_anchors,
                 cls_out_channels):
        super(RPN, self).__init__()
        cfg_rpn = config
        self.cfg = config
        self.num_bboxes = cfg_rpn.num_bboxes
        self.feature_anchor_shape = cfg_rpn.feature_shapes
        self.feature_anchor_shape = self.feature_anchor_shape[0] * \
            self.feature_anchor_shape[1] * num_anchors * batch_size
        self.num_anchors = num_anchors
        self.batch_size = batch_size
        self.test_batch_size = cfg_rpn.test_batch_size
        self.num_layers = 1
        self.real_ratio = Tensor(np.ones((1, 1)).astype(np.float16))
        self.use_sigmoid_cls = config.use_sigmoid_cls
        if config.use_sigmoid_cls:
            self.reshape_shape_cls = (-1,)
            self.loss_cls = P.SigmoidCrossEntropyWithLogits()
            cls_out_channels = 1
        else:
            self.reshape_shape_cls = (-1, cls_out_channels)
            self.loss_cls = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="none")
        self.rpn_convs_list = self._make_rpn_layer(self.num_layers, in_channels, feat_channels,\
            num_anchors, cls_out_channels)

        self.transpose = P.Transpose()
        self.reshape = P.Reshape()
        self.concat = P.Concat(axis=0)
        self.fill = P.Fill()
        self.placeh1 = Tensor(np.ones((1,)).astype(np.float16))

        self.trans_shape = (0, 2, 3, 1)

        self.reshape_shape_reg = (-1, 4)
        self.softmax = nn.Softmax()
        self.rpn_loss_reg_weight = Tensor(np.array(cfg_rpn.rpn_loss_reg_weight).astype(np.float16))
        self.rpn_loss_cls_weight = Tensor(np.array(cfg_rpn.rpn_loss_cls_weight).astype(np.float16))
        self.num_expected_total = Tensor(np.array(cfg_rpn.num_expected_neg * self.batch_size).astype(np.float16))
        self.num_bboxes = cfg_rpn.num_bboxes
        self.get_targets = BboxAssignSample(cfg_rpn, self.batch_size, self.num_bboxes, False)
        self.CheckValid = P.CheckValid()
        self.sum_loss = P.ReduceSum()
        self.loss_bbox = P.SmoothL1Loss(beta=1.0/9.0)
        self.squeeze = P.Squeeze()
        self.cast = P.Cast()
        self.tile = P.Tile()
        self.zeros_like = P.ZerosLike()
        self.loss = Tensor(np.zeros((1,)).astype(np.float16))
        self.clsloss = Tensor(np.zeros((1,)).astype(np.float16))
        self.regloss = Tensor(np.zeros((1,)).astype(np.float16))
        self.print = P.Print()

    def _make_rpn_layer(self, num_layers, in_channels, feat_channels, num_anchors, cls_out_channels):
        """
        make rpn layer for rpn proposal network

        Args:
            num_layers (int) - layer num.
            in_channels (int) - Input channels of shared convolution.
            feat_channels (int) - Output channels of shared convolution.
            num_anchors (int) - The anchor number.
            cls_out_channels (int) - Output channels of classification convolution.

        Returns:
            RpnRegClsBlock, the single reg/cls block used by this RPN.
        """
        rpn_layer = RpnRegClsBlock(self.cfg, in_channels, feat_channels, num_anchors, cls_out_channels)
        return rpn_layer

    def construct(self, inputs, img_metas, anchor_list, gt_bboxes, gt_labels, gt_valids):
        '''
        inputs(Tensor): Inputs tensor from lstm.
        img_metas(Tensor): Image shape.
        anchor_list(Tensor): Total anchor list.
        gt_bboxes(Tensor): Ground truth boxes.
        gt_labels(Tensor): Ground truth labels.
        gt_valids(Tensor): Whether ground truth is valid.
        '''
        rpn_cls_score_ori, rpn_bbox_pred_ori = self.rpn_convs_list(inputs)
        rpn_cls_score = self.transpose(rpn_cls_score_ori, self.trans_shape)
        rpn_cls_score = self.reshape(rpn_cls_score, self.reshape_shape_cls)
        rpn_bbox_pred = self.transpose(rpn_bbox_pred_ori, self.trans_shape)
        rpn_bbox_pred = self.reshape(rpn_bbox_pred, self.reshape_shape_reg)
        output = ()
        bbox_targets = ()
        bbox_weights = ()
        labels = ()
        label_weights = ()
        if self.training:
            for i in range(self.batch_size):
                valid_flag_list = self.cast(self.CheckValid(anchor_list, self.squeeze(img_metas[i:i + 1:1, ::])),\
                    mstype.int32)
                gt_bboxes_i = self.squeeze(gt_bboxes[i:i + 1:1, ::])
                gt_labels_i = self.squeeze(gt_labels[i:i + 1:1, ::])
                gt_valids_i = self.squeeze(gt_valids[i:i + 1:1, ::])
                bbox_target, bbox_weight, label, label_weight = self.get_targets(gt_bboxes_i,
                                                                                 gt_labels_i,
                                                                                 self.cast(valid_flag_list,
                                                                                           mstype.bool_),
                                                                                 anchor_list, gt_valids_i)
                bbox_weight = self.cast(bbox_weight, mstype.float16)
                label_weight = self.cast(label_weight, mstype.float16)
                bbox_targets += (bbox_target,)
                bbox_weights += (bbox_weight,)
                labels += (label,)
                label_weights += (label_weight,)
            bbox_target_with_batchsize = self.concat(bbox_targets)
            bbox_weight_with_batchsize = self.concat(bbox_weights)
            label_with_batchsize = self.concat(labels)
            label_weight_with_batchsize = self.concat(label_weights)

            bbox_target_ = F.stop_gradient(bbox_target_with_batchsize)
            bbox_weight_ = F.stop_gradient(bbox_weight_with_batchsize)
            label_ = F.stop_gradient(label_with_batchsize)
            label_weight_ = F.stop_gradient(label_weight_with_batchsize)
            rpn_cls_score = self.cast(rpn_cls_score, mstype.float32)
            if self.use_sigmoid_cls:
                label_ = self.cast(label_, mstype.float32)
            loss_cls = self.loss_cls(rpn_cls_score, label_)
            loss_cls = loss_cls * label_weight_
            loss_cls = self.sum_loss(loss_cls, (0,)) / self.num_expected_total
            rpn_bbox_pred = self.cast(rpn_bbox_pred, mstype.float32)
            bbox_target_ = self.cast(bbox_target_, mstype.float32)
            loss_reg = self.loss_bbox(rpn_bbox_pred, bbox_target_)
            bbox_weight_ = self.tile(self.reshape(bbox_weight_, (self.feature_anchor_shape, 1)), (1, 4))
            loss_reg = loss_reg * bbox_weight_
            loss_reg = self.sum_loss(loss_reg, (1,))
            loss_reg = self.sum_loss(loss_reg, (0,)) / self.num_expected_total
            loss_total = self.rpn_loss_cls_weight * loss_cls + self.rpn_loss_reg_weight * loss_reg
            output = (loss_total, rpn_cls_score_ori, rpn_bbox_pred_ori, loss_cls, loss_reg)
        else:
            output = (self.placeh1, rpn_cls_score_ori, rpn_bbox_pred_ori, self.placeh1, self.placeh1)
        return output
@@ -0,0 +1,177 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import mindspore.nn as nn
from mindspore.common.tensor import Tensor
from mindspore.ops import operations as P
import mindspore.common.dtype as mstype

def _weight_variable(shape, factor=0.01):
    '''weight initialize'''
    init_value = np.random.randn(*shape).astype(np.float32) * factor
    return Tensor(init_value)

def _BatchNorm2dInit(out_chls, momentum=0.1, affine=True, use_batch_statistics=False):
    """Batchnorm2D wrapper."""
    gamma_init = Tensor(np.array(np.ones(out_chls)).astype(np.float32))
    beta_init = Tensor(np.array(np.ones(out_chls) * 0).astype(np.float32))
    moving_mean_init = Tensor(np.array(np.ones(out_chls) * 0).astype(np.float32))
    moving_var_init = Tensor(np.array(np.ones(out_chls)).astype(np.float32))

    return nn.BatchNorm2d(out_chls, momentum=momentum, affine=affine, gamma_init=gamma_init,
                          beta_init=beta_init, moving_mean_init=moving_mean_init,
                          moving_var_init=moving_var_init, use_batch_statistics=use_batch_statistics)

def _conv(in_channels, out_channels, kernel_size=3, stride=1, padding=0, pad_mode='pad', weights_update=True):
    """Conv2D wrapper."""
    weights = 'ones'
    layers = []
    conv = nn.Conv2d(in_channels, out_channels,
                     kernel_size=kernel_size, stride=stride, padding=padding,
                     pad_mode=pad_mode, weight_init=weights, has_bias=False)
    if not weights_update:
        conv.weight.requires_grad = False
    layers += [conv]
    layers += [_BatchNorm2dInit(out_channels)]
    return nn.SequentialCell(layers)


def _fc(in_channels, out_channels):
    '''fully connected layer'''
    weight = _weight_variable((out_channels, in_channels))
    bias = _weight_variable((out_channels,))
    return nn.Dense(in_channels, out_channels, weight, bias)


class VGG16FeatureExtraction(nn.Cell):
    def __init__(self, weights_update=False):
        """
        VGG16 feature extraction

        Args:
            weights_update(bool): whether to update the weights of the first two conv blocks, default is False.
        """
        super(VGG16FeatureExtraction, self).__init__()
        self.relu = nn.ReLU()
        self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2, pad_mode="same")
        self.avg_pool = nn.AvgPool2d(kernel_size=2, stride=2)

        self.conv1_1 = _conv(in_channels=3, out_channels=64, kernel_size=3,\
            padding=1, weights_update=weights_update)
        self.conv1_2 = _conv(in_channels=64, out_channels=64, kernel_size=3,\
            padding=1, weights_update=weights_update)

        self.conv2_1 = _conv(in_channels=64, out_channels=128, kernel_size=3,\
            padding=1, weights_update=weights_update)
        self.conv2_2 = _conv(in_channels=128, out_channels=128, kernel_size=3,\
            padding=1, weights_update=weights_update)

        self.conv3_1 = _conv(in_channels=128, out_channels=256, kernel_size=3, padding=1)
        self.conv3_2 = _conv(in_channels=256, out_channels=256, kernel_size=3, padding=1)
        self.conv3_3 = _conv(in_channels=256, out_channels=256, kernel_size=3, padding=1)

        self.conv4_1 = _conv(in_channels=256, out_channels=512, kernel_size=3, padding=1)
        self.conv4_2 = _conv(in_channels=512, out_channels=512, kernel_size=3, padding=1)
        self.conv4_3 = _conv(in_channels=512, out_channels=512, kernel_size=3, padding=1)

        self.conv5_1 = _conv(in_channels=512, out_channels=512, kernel_size=3, padding=1)
        self.conv5_2 = _conv(in_channels=512, out_channels=512, kernel_size=3, padding=1)
        self.conv5_3 = _conv(in_channels=512, out_channels=512, kernel_size=3, padding=1)
        self.cast = P.Cast()

    def construct(self, x):
        """
        :param x: shape=(B, 3, 224, 224)
        :return: feature map, shape=(B, 512, 14, 14) for a 224x224 input
        """
        x = self.cast(x, mstype.float32)
        x = self.conv1_1(x)
        x = self.relu(x)
        x = self.conv1_2(x)
        x = self.relu(x)
        x = self.max_pool(x)

        x = self.conv2_1(x)
        x = self.relu(x)
        x = self.conv2_2(x)
        x = self.relu(x)
        x = self.max_pool(x)

        x = self.conv3_1(x)
        x = self.relu(x)
        x = self.conv3_2(x)
        x = self.relu(x)
        x = self.conv3_3(x)
        x = self.relu(x)
        x = self.max_pool(x)

        x = self.conv4_1(x)
        x = self.relu(x)
        x = self.conv4_2(x)
        x = self.relu(x)
        x = self.conv4_3(x)
        x = self.relu(x)
        x = self.max_pool(x)

        x = self.conv5_1(x)
        x = self.relu(x)
        x = self.conv5_2(x)
        x = self.relu(x)
        x = self.conv5_3(x)
        x = self.relu(x)
        return x

class VGG16Classfier(nn.Cell):
    def __init__(self):
        """VGG16 classifier structure"""
        super(VGG16Classfier, self).__init__()
        self.flatten = P.Flatten()
        self.relu = nn.ReLU()
        self.fc1 = _fc(in_channels=7*7*512, out_channels=4096)
        self.fc2 = _fc(in_channels=4096, out_channels=4096)
        self.batch_size = 32
        self.reshape = P.Reshape()

    def construct(self, x):
        """
        :param x: shape=(B, 512, 7, 7)
        :return: shape=(B, 4096)
        """
        x = self.reshape(x, (self.batch_size, 7*7*512))
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.relu(x)
        return x

class VGG16(nn.Cell):
    def __init__(self):
        """VGG16 construct for training backbone"""
        super(VGG16, self).__init__()
        self.feature_extraction = VGG16FeatureExtraction(weights_update=True)
        self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.classifier = VGG16Classfier()
        self.fc3 = _fc(in_channels=4096, out_channels=1000)

    def construct(self, x):
        """
        :param x: shape=(B, 3, 224, 224)
        :return: logits, shape=(B, 1000)
        """
        feature_maps = self.feature_extraction(x)
        x = self.max_pool(feature_maps)
        x = self.classifier(x)
        x = self.fc3(x)
        return x
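
One spatial consistency check worth noting, assuming the CTPN input size from src/config.py in this change: the four max-pools in the extractor divide each side by 16, so a 576x960 input yields a 36x60 conv5 feature map, which is exactly the feature_shapes value the anchor generator and RPN above rely on. Illustrative arithmetic only:

# img_height=576, img_width=960; each of the four 2x2 max-pools halves
# the spatial size, a factor of 16 overall.
print(576 // 16, 960 // 16)  # 36 60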
@@ -0,0 +1,133 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Network parameters."""
from easydict import EasyDict

pretrain_config = EasyDict({
    # LR
    "base_lr": 0.0009,
    "warmup_step": 30000,
    "warmup_ratio": 1/3.0,
    "total_epoch": 100,
})
finetune_config = EasyDict({
    # LR
    "base_lr": 0.0005,
    "warmup_step": 300,
    "warmup_ratio": 1/3.0,
    "total_epoch": 50,
})

# use for low case number
config = EasyDict({
    "img_width": 960,
    "img_height": 576,
    "keep_ratio": False,
    "flip_ratio": 0.0,
    "photo_ratio": 0.0,
    "expand_ratio": 1.0,

    # anchor
    "feature_shapes": (36, 60),
    "num_anchors": 14,
    "anchor_base": 16,
    "anchor_height": [2, 4, 7, 11, 16, 23, 33, 48, 68, 97, 139, 198, 283, 406],
    "anchor_width": [16],

    # rpn
    "rpn_in_channels": 256,
    "rpn_feat_channels": 512,
    "rpn_loss_cls_weight": 1.0,
    "rpn_loss_reg_weight": 3.0,
    "rpn_cls_out_channels": 2,

    # bbox_assign_sampler
    "neg_iou_thr": 0.5,
    "pos_iou_thr": 0.7,
    "min_pos_iou": 0.001,
    "num_bboxes": 30240,
    "num_gts": 256,
    "num_expected_neg": 512,
    "num_expected_pos": 256,

    # proposal
    "activate_num_classes": 2,
    "use_sigmoid_cls": False,

    # train proposal
    "rpn_proposal_nms_across_levels": False,
    "rpn_proposal_nms_pre": 2000,
    "rpn_proposal_nms_post": 1000,
    "rpn_proposal_max_num": 1000,
    "rpn_proposal_nms_thr": 0.7,
    "rpn_proposal_min_bbox_size": 8,

    # rnn structure
    "input_size": 512,
    "num_step": 60,
    "rnn_batch_size": 36,
    "hidden_size": 128,

    # training
    "warmup_mode": "linear",
    "batch_size": 1,
    "momentum": 0.9,
    "save_checkpoint": True,
    "save_checkpoint_epochs": 10,
    "keep_checkpoint_max": 5,
    "save_checkpoint_path": "./",
    "use_dropout": False,
    "loss_scale": 1,
    "weight_decay": 1e-4,

    # test proposal
    "rpn_nms_pre": 2000,
    "rpn_nms_post": 1000,
    "rpn_max_num": 1000,
    "rpn_nms_thr": 0.7,
    "rpn_min_bbox_min_size": 8,
    "test_iou_thr": 0.7,
    "test_max_per_img": 100,
    "test_batch_size": 1,
    "use_python_proposal": False,

    # text proposal connection
    "max_horizontal_gap": 60,
    "text_proposals_min_scores": 0.7,
    "text_proposals_nms_thresh": 0.2,
    "min_v_overlaps": 0.7,
    "min_size_sim": 0.7,
    "min_ratio": 0.5,
    "line_min_score": 0.9,
    "text_proposals_width": 16,
    "min_num_proposals": 2,

    # create dataset
    "coco_root": "",
    "coco_train_data_type": "",
    "cocotext_json": "",
    "icdar11_train_path": [],
    "icdar13_train_path": [],
    "icdar15_train_path": [],
    "icdar13_test_path": [],
    "flick_train_path": [],
    "svt_train_path": [],
    "pretrain_dataset_path": "",
    "finetune_dataset_path": "",
    "test_dataset_path": "",

    # training dataset
    "pretraining_dataset_file": "",
    "finetune_dataset_file": ""
})
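
Two internal consistency relations tie this config to the modules above; shown here as assertions (illustrative, not part of the diff):

from src.config import config

# num_bboxes equals the anchor count implied by the feature map:
assert config.feature_shapes[0] * config.feature_shapes[1] * config.num_anchors == config.num_bboxes  # 36*60*14 = 30240
# the BiLSTM unrolls one step per feature-map column and batches rows:
assert (config.num_step, config.rnn_batch_size) == config.feature_shapes[::-1]  # (60, 36)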
@@ -0,0 +1,61 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""convert icdar2015 dataset label"""
import os
import argparse

def init_args():
    parser = argparse.ArgumentParser('')
    parser.add_argument('-s', '--src_label_path', type=str, default='./',
                        help='Directory containing the icdar2015 train labels')
    parser.add_argument('-t', '--target_label_path', type=str, default='test.xml',
                        help='Directory where the converted icdar2015 labels are saved')
    return parser.parse_args()

def convert():
    args = init_args()
    anno_file = os.listdir(args.src_label_path)
    annos = {}
    # read
    for file in anno_file:
        gt = open(os.path.join(args.src_label_path, file), 'r', encoding='UTF-8-sig').read().splitlines()
        label_list = []
        label_name = os.path.basename(file)
        for each_label in gt:
            print(file)
            spt = each_label.split(',')
            print(spt)
            if "###" in spt[8]:
                continue
            x1 = min(int(spt[0]), int(spt[6]))
            y1 = min(int(spt[1]), int(spt[3]))
            x2 = max(int(spt[2]), int(spt[4]))
            y2 = max(int(spt[5]), int(spt[7]))
            label_list.append([x1, y1, x2, y2])
        annos[label_name] = label_list
    # write
    if not os.path.exists(args.target_label_path):
        os.makedirs(args.target_label_path)
    for label_file, pos in annos.items():
        tgt_anno_file = os.path.join(args.target_label_path, label_file)
        f = open(tgt_anno_file, 'w', encoding='UTF-8-sig')
        for tgt_label in pos:
            str_pos = str(tgt_label[0]) + ',' + str(tgt_label[1]) + ',' + str(tgt_label[2]) + ',' + str(tgt_label[3])
            f.write(str_pos)
            f.write("\n")
        f.close()

if __name__ == "__main__":
    convert()
|
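A typical invocation, assuming the script is saved as convert_icdar2015.py and the standard ICDAR2015 ground-truth layout (both paths are placeholders):

python convert_icdar2015.py -s ./icdar2015/ch4_training_localization_transcription_gt -t ./icdar2015/converted_label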
@ -0,0 +1,94 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""convert svt dataset label"""
import os
import argparse
from xml.etree import ElementTree as ET
import numpy as np


def init_args():
    parser = argparse.ArgumentParser('')
    parser.add_argument('-d', '--dataset_dir', type=str, default='./',
                        help='Directory containing images')
    parser.add_argument('-x', '--xml_file', type=str, default='test.xml',
                        help='Path of the svt annotation xml file')
    parser.add_argument('-o', '--location_dir', type=str, default='./location',
                        help='Directory where the converted location files are saved')
    return parser.parse_args()


def xml_to_dict(xml_file, save_file=False):
    tree = ET.parse(xml_file)
    root = tree.getroot()
    imgs_labels = []

    for ch in root:
        im_label = {}
        for ch01 in ch:
            if ch01.tag == "address":
                continue
            if ch01.tag == 'taggedRectangles':
                # multiple children
                rect_list = []
                for ch02 in ch01:
                    rect = {}
                    rect['location'] = ch02.attrib
                    rect['label'] = ch02[0].text
                    rect_list.append(rect)
                im_label['rect'] = rect_list
            else:
                im_label[ch01.tag] = ch01.text
        imgs_labels.append(im_label)

    if save_file:
        np.save("annotation_train.npy", imgs_labels)

    return imgs_labels


def convert():
    args = init_args()
    if not os.path.exists(args.dataset_dir):
        raise ValueError("dataset_dir :{} does not exist".format(args.dataset_dir))

    if not os.path.exists(args.xml_file):
        raise ValueError("xml_file :{} does not exist".format(args.xml_file))

    if not os.path.exists(args.location_dir):
        os.makedirs(args.location_dir)

    ims_labels_dict = xml_to_dict(args.xml_file, True)
    num_images = len(ims_labels_dict)
    print("Converting annotation, {} images in total".format(num_images))
    for i in range(num_images):
        img_label = ims_labels_dict[i]
        image_name = img_label['imageName']
        rects = img_label['rect']
        print("processing image: {}".format(image_name))
        location_file_name = os.path.join(args.location_dir, os.path.basename(image_name).replace(".jpg", ".txt"))
        with open(location_file_name, 'w') as f:
            for rect in rects:
                location = rect['location']
                h = int(location['height'])
                w = int(location['width'])
                x = int(location['x'])
                y = int(location['y'])
                # convert (x, y, w, h) to (x1, y1, x2, y2)
                pos = [x, y, x + w, y + h]
                str_pos = str(pos[0]) + "," + str(pos[1]) + "," + str(pos[2]) + "," + str(pos[3])
                f.write(str_pos)
                f.write("\n")

if __name__ == "__main__":
    convert()
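As with the ICDAR script, a hypothetical invocation, assuming the script is saved as convert_svt.py and that the SVT release ships train.xml alongside an img/ directory:

python convert_svt.py -d ./svt1/img -x ./svt1/train.xml -o ./svt1/location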
@ -0,0 +1,177 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Create MindRecord datasets for CTPN pretraining, finetuning and test."""
from __future__ import division
import os
import numpy as np
from PIL import Image
from mindspore.mindrecord import FileWriter
from src.config import config


def create_coco_label():
    """Create image label."""
    image_files = []
    image_anno_dict = {}
    coco_root = config.coco_root
    data_type = config.coco_train_data_type
    from src.coco_text import COCO_Text
    anno_json = config.cocotext_json
    ct = COCO_Text(anno_json)
    image_ids = ct.getImgIds(imgIds=ct.train,
                             catIds=[('legibility', 'legible')])
    for img_id in image_ids:
        image_info = ct.loadImgs(img_id)[0]
        # strip the dataset prefix (e.g. 'COCO_train2014_') from the file name
        file_name = image_info['file_name'][15:]
        anno_ids = ct.getAnnIds(imgIds=img_id)
        anno = ct.loadAnns(anno_ids)
        image_path = os.path.join(coco_root, data_type, file_name)
        annos = []
        im = Image.open(image_path)
        width, _ = im.size
        for label in anno:
            bbox = label["bbox"]
            bbox_width = int(bbox[2])
            # skip boxes narrower than 1/60 of the image width
            if 60 * bbox_width < width:
                continue
            x1, x2 = int(bbox[0]), int(bbox[0] + bbox[2])
            y1, y2 = int(bbox[1]), int(bbox[1] + bbox[3])
            annos.append([x1, y1, x2, y2] + [1])
        if annos:
            image_anno_dict[image_path] = np.array(annos)
            image_files.append(image_path)
    return image_files, image_anno_dict


def create_anno_dataset_label(train_img_dirs, train_txt_dirs):
    """Create image label for datasets with plain txt annotations."""
    image_files = []
    image_anno_dict = {}
    # read
    img_basenames = []
    for file in os.listdir(train_img_dirs):
        # filter gif files
        if 'gif' not in file:
            img_basenames.append(os.path.basename(file))
    img_names = []
    for item in img_basenames:
        temp1, _ = os.path.splitext(item)
        img_names.append((temp1, item))
    for img, img_basename in img_names:
        image_path = train_img_dirs + '/' + img_basename
        annos = []
        if len(img) == 6 and '_' not in img_basename:
            with open(train_txt_dirs + '/' + img + '.txt') as f:
                gt = f.read().splitlines()
            if img.isdigit() and int(img) > 1200:
                continue
            for img_each_label in gt:
                spt = img_each_label.replace(',', '').split(' ')
                if ' ' not in img_each_label:
                    spt = img_each_label.split(',')
                # convert (x, y, w, h) to (x1, y1, x2, y2) and append class label 1
                annos.append([spt[0], spt[1], str(int(spt[0]) + int(spt[2])), str(int(spt[1]) + int(spt[3]))] + [1])
            if annos:
                image_anno_dict[image_path] = np.array(annos)
                image_files.append(image_path)
    return image_files, image_anno_dict


def create_icdar_svt_label(train_img_dir, train_txt_dir, prefix):
    """Create image label for icdar and svt datasets."""
    image_files = []
    image_anno_dict = {}
    img_basenames = []
    for file_name in os.listdir(train_img_dir):
        # filter gif files
        if 'gif' not in file_name:
            img_basenames.append(os.path.basename(file_name))
    img_names = []
    for item in img_basenames:
        temp1, _ = os.path.splitext(item)
        img_names.append((temp1, item))
    for img, img_basename in img_names:
        image_path = train_img_dir + '/' + img_basename
        annos = []
        file_name = prefix + img + ".txt"
        file_path = os.path.join(train_txt_dir, file_name)
        with open(file_path, 'r', encoding='UTF-8-sig') as f:
            gt = f.read().splitlines()
        if not gt:
            continue
        for img_each_label in gt:
            spt = img_each_label.replace(',', '').split(' ')
            if ' ' not in img_each_label:
                spt = img_each_label.split(',')
            annos.append([spt[0], spt[1], spt[2], spt[3]] + [1])
        if annos:
            image_anno_dict[image_path] = np.array(annos)
            image_files.append(image_path)
    return image_files, image_anno_dict


def create_train_dataset(dataset_type):
    """Create mindrecord files for the given dataset_type."""
    image_files = []
    image_anno_dict = {}
    if dataset_type == "pretraining":
        # pretraining: coco, flick, icdar2013 train, icdar2015, svt
        coco_image_files, coco_anno_dict = create_coco_label()
        flick_image_files, flick_anno_dict = create_anno_dataset_label(config.flick_train_path[0],
                                                                       config.flick_train_path[1])
        icdar13_image_files, icdar13_anno_dict = create_icdar_svt_label(config.icdar13_train_path[0],
                                                                        config.icdar13_train_path[1], "gt_img_")
        icdar15_image_files, icdar15_anno_dict = create_icdar_svt_label(config.icdar15_train_path[0],
                                                                        config.icdar15_train_path[1], "gt_")
        svt_image_files, svt_anno_dict = create_icdar_svt_label(config.svt_train_path[0], config.svt_train_path[1], "")
        image_files = coco_image_files + flick_image_files + icdar13_image_files + icdar15_image_files + svt_image_files
        image_anno_dict = {**coco_anno_dict, **flick_anno_dict,
                           **icdar13_anno_dict, **icdar15_anno_dict, **svt_anno_dict}
        data_to_mindrecord_byte_image(image_files, image_anno_dict, config.pretrain_dataset_path,
                                      prefix="ctpn_pretrain.mindrecord", file_num=8)
    elif dataset_type == "finetune":
        # finetune: icdar2011, icdar2013 train, flick
        flick_image_files, flick_anno_dict = create_anno_dataset_label(config.flick_train_path[0],
                                                                       config.flick_train_path[1])
        icdar11_image_files, icdar11_anno_dict = create_icdar_svt_label(config.icdar11_train_path[0],
                                                                        config.icdar11_train_path[1], "gt_")
        icdar13_image_files, icdar13_anno_dict = create_icdar_svt_label(config.icdar13_train_path[0],
                                                                        config.icdar13_train_path[1], "gt_img_")
        image_files = flick_image_files + icdar11_image_files + icdar13_image_files
        image_anno_dict = {**flick_anno_dict, **icdar11_anno_dict, **icdar13_anno_dict}
        data_to_mindrecord_byte_image(image_files, image_anno_dict, config.finetune_dataset_path,
                                      prefix="ctpn_finetune.mindrecord", file_num=8)
    elif dataset_type == "test":
        # test: icdar2013 test
        icdar_test_image_files, icdar_test_anno_dict = create_icdar_svt_label(config.icdar13_test_path[0],
                                                                              config.icdar13_test_path[1], "")
        image_files = icdar_test_image_files
        image_anno_dict = icdar_test_anno_dict
        data_to_mindrecord_byte_image(image_files, image_anno_dict, config.test_dataset_path,
                                      prefix="ctpn_test.mindrecord", file_num=1)
    else:
        print("dataset_type should be pretraining, finetune or test")


def data_to_mindrecord_byte_image(image_files, image_anno_dict, dst_dir, prefix="ctpn_mlt.mindrecord", file_num=1):
    """Create MindRecord file."""
    mindrecord_path = os.path.join(dst_dir, prefix)
    writer = FileWriter(mindrecord_path, file_num)

    ctpn_json = {
        "image": {"type": "bytes"},
        "annotation": {"type": "int32", "shape": [-1, 5]},
    }
    writer.add_schema(ctpn_json, "ctpn_json")
    for image_name in image_files:
        with open(image_name, 'rb') as f:
            img = f.read()
        annos = np.array(image_anno_dict[image_name], dtype=np.int32)
        print("img name is {}, anno is {}".format(image_name, annos))
        row = {"image": img, "annotation": annos}
        writer.write_raw_data([row])
    writer.commit()


if __name__ == "__main__":
    create_train_dataset("pretraining")
    create_train_dataset("finetune")
    create_train_dataset("test")
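Once the MindRecord files are written, they can be read back with MindDataset for a quick inspection. A hedged sketch (file name assumes file_num=8 produced ctpn_pretrain.mindrecord0..7; column names match the schema above):

import mindspore.dataset as ds

data_set = ds.MindDataset(dataset_file="ctpn_pretrain.mindrecord0",
                          columns_list=["image", "annotation"])
for item in data_set.create_dict_iterator(output_numpy=True):
    print(item["annotation"].shape)  # (N, 5) boxes: x1, y1, x2, y2, label
    break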
@ -0,0 +1,148 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""CTPN network definition."""

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, Parameter
from mindspore.common import dtype as mstype
from mindspore.ops import operations as P
from src.CTPN.rpn import RPN
from src.CTPN.anchor_generator import AnchorGenerator
from src.CTPN.proposal_generator import Proposal
from src.CTPN.vgg16 import VGG16FeatureExtraction


class BiLSTM(nn.Cell):
    """
    Define a BiLSTM network which runs two LSTM layers, one over the input
    sequence and one over the reversed sequence, and concatenates their outputs.

    Args:
        config(EasyDict): Network parameters, providing input_size, hidden_size,
            num_step, batch_size and rnn_batch_size.
        is_training(bool): Whether the network is built for training. Default: True.
    """
    def __init__(self, config, is_training=True):
        super(BiLSTM, self).__init__()
        self.is_training = is_training
        self.batch_size = config.batch_size * config.rnn_batch_size
        print("batch size is {} ".format(self.batch_size))
        self.input_size = config.input_size
        self.hidden_size = config.hidden_size
        self.num_step = config.num_step
        self.reshape = P.Reshape()
        self.cast = P.Cast()
        k = (1 / self.hidden_size) ** 0.5
        self.rnn1 = P.DynamicRNN(forget_bias=0.0)
        self.rnn_bw = P.DynamicRNN(forget_bias=0.0)
        self.w1 = Parameter(np.random.uniform(-k, k, \
            (self.input_size + self.hidden_size, 4 * self.hidden_size)).astype(np.float32), name="w1")
        self.w1_bw = Parameter(np.random.uniform(-k, k, \
            (self.input_size + self.hidden_size, 4 * self.hidden_size)).astype(np.float32), name="w1_bw")

        self.b1 = Parameter(np.random.uniform(-k, k, (4 * self.hidden_size)).astype(np.float32), name="b1")
        self.b1_bw = Parameter(np.random.uniform(-k, k, (4 * self.hidden_size)).astype(np.float32), name="b1_bw")

        self.h1 = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float32))
        self.h1_bw = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float32))

        self.c1 = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float32))
        self.c1_bw = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float32))
        self.reverse_seq = P.ReverseV2(axis=[0])
        self.concat = P.Concat()
        self.transpose = P.Transpose()
        self.concat1 = P.Concat(axis=2)
        self.dropout = nn.Dropout(0.7)
        self.use_dropout = config.use_dropout

    def construct(self, x):
        if self.use_dropout:
            x = self.dropout(x)
        x = self.cast(x, mstype.float16)
        bw_x = self.reverse_seq(x)
        y1, _, _, _, _, _, _, _ = self.rnn1(x, self.w1, self.b1, None, self.h1, self.c1)
        y1_bw, _, _, _, _, _, _, _ = self.rnn_bw(bw_x, self.w1_bw, self.b1_bw, None, self.h1_bw, self.c1_bw)
        y1_bw = self.reverse_seq(y1_bw)
        output = self.concat1((y1, y1_bw))
        return output


class CTPN(nn.Cell):
    """
    Define the CTPN network: a VGG16 feature extractor followed by a BiLSTM
    over each feature-map row and an RPN head that predicts text proposals.

    Args:
        config(EasyDict): Network parameters.
        is_training(bool): Whether the network is built for training. Default: True.
    """
    def __init__(self, config, is_training=True):
        super(CTPN, self).__init__()
        self.config = config
        self.is_training = is_training
        self.num_step = config.num_step
        self.input_size = config.input_size
        self.batch_size = config.batch_size
        self.hidden_size = config.hidden_size
        self.vgg16_feature_extractor = VGG16FeatureExtraction()
        self.conv = nn.Conv2d(512, 512, kernel_size=3, padding=0, pad_mode='same')
        self.rnn = BiLSTM(self.config, is_training=self.is_training).to_float(mstype.float16)
        self.reshape = P.Reshape()
        self.transpose = P.Transpose()
        self.cast = P.Cast()

        # rpn block
        self.rpn_with_loss = RPN(config,
                                 self.batch_size,
                                 config.rpn_in_channels,
                                 config.rpn_feat_channels,
                                 config.num_anchors,
                                 config.rpn_cls_out_channels)
        self.anchor_generator = AnchorGenerator(config)
        self.featmap_size = config.feature_shapes
        self.anchor_list = self.get_anchors(self.featmap_size)
        self.proposal_generator_test = Proposal(config,
                                                config.test_batch_size,
                                                config.activate_num_classes,
                                                config.use_sigmoid_cls)
        self.proposal_generator_test.set_train_local(config, False)

    def construct(self, img_data, img_metas, gt_bboxes, gt_labels, gt_valids):
        # img_data: (batch_size, 3, 576, 960) with the default config
        x = self.vgg16_feature_extractor(img_data)
        x = self.conv(x)
        x = self.cast(x, mstype.float16)
        # feature map: (batch_size, 512, 36, 60)
        x = self.transpose(x, (0, 2, 1, 3))
        x = self.reshape(x, (-1, self.input_size, self.num_step))
        x = self.transpose(x, (2, 0, 1))
        # sequence input: (60, batch_size * 36, 512)
        x = self.rnn(x)
        # BiLSTM output: (60, batch_size * 36, 256)
        rpn_loss, cls_score, bbox_pred, rpn_cls_loss, rpn_reg_loss = self.rpn_with_loss(x,
                                                                                        img_metas,
                                                                                        self.anchor_list,
                                                                                        gt_bboxes,
                                                                                        gt_labels,
                                                                                        gt_valids)
        if self.training:
            return rpn_loss, cls_score, bbox_pred, rpn_cls_loss, rpn_reg_loss
        proposal, proposal_mask = self.proposal_generator_test(cls_score, bbox_pred, self.anchor_list)
        return proposal, proposal_mask

    def get_anchors(self, featmap_size):
        anchors = self.anchor_generator.grid_anchors(featmap_size)
        return Tensor(anchors, mstype.float16)
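The transpose/reshape chain in CTPN.construct converts the 2-D feature map into a per-row sequence for the BiLSTM. A standalone numpy sketch of the same index gymnastics, assuming the default 576x960 input (illustrative only):

import numpy as np

feat = np.zeros((1, 512, 36, 60), dtype=np.float32)  # (N, C, H, W)
x = feat.transpose(0, 2, 1, 3)                       # (N, H, C, W)
x = x.reshape(-1, 512, 60)                           # (N*H, C, W) = (36, 512, 60)
x = x.transpose(2, 0, 1)                             # (W, N*H, C) = (60, 36, 512)
# each of the 60 horizontal steps carries a batch of 36 feature-map rows,
# matching num_step=60 and rnn_batch_size=36 in src/config.py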
File diff suppressed because it is too large
@ -0,0 +1,39 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""lr generator for ctpn"""
import math


def linear_warmup_learning_rate(current_step, warmup_steps, base_lr, init_lr):
    """Linearly ramp the lr from init_lr to base_lr over warmup_steps."""
    lr_inc = (float(base_lr) - float(init_lr)) / float(warmup_steps)
    learning_rate = float(init_lr) + lr_inc * current_step
    return learning_rate


def a_cosine_learning_rate(current_step, base_lr, warmup_steps, decay_steps):
    """Cosine-decay the lr from base_lr towards zero after the warmup phase."""
    base = float(current_step - warmup_steps) / float(decay_steps)
    learning_rate = (1 + math.cos(base * math.pi)) / 2 * base_lr
    return learning_rate


def dynamic_lr(config, base_step):
    """dynamic learning rate generator"""
    base_lr = config.base_lr
    total_steps = int(base_step * config.total_epoch)
    warmup_steps = config.warmup_step
    lr = []
    for i in range(total_steps):
        if i < warmup_steps:
            lr.append(linear_warmup_learning_rate(i, warmup_steps, base_lr, base_lr * config.warmup_ratio))
        else:
            lr.append(a_cosine_learning_rate(i, base_lr, warmup_steps, total_steps))
    return lr
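A quick sketch of the schedule this produces for the pretraining config (the module name src.lr_schedule and the steps-per-epoch value are assumptions):

from src.config import pretrain_config
from src.lr_schedule import dynamic_lr  # assumed module name for this file

steps_per_epoch = 1000  # placeholder; depends on the dataset size
lr = dynamic_lr(pretrain_config, steps_per_epoch)
# 30000 warmup steps ramp from base_lr/3 to base_lr, then cosine decay to ~0
print(lr[0], lr[29999], lr[-1])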
@ -0,0 +1,153 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""CTPN training network wrapper."""

import time
import numpy as np
import mindspore.nn as nn
from mindspore.common.tensor import Tensor
from mindspore.ops import functional as F
from mindspore.ops import composite as C
from mindspore import ParameterTuple
from mindspore.train.callback import Callback
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer


time_stamp_init = False
time_stamp_first = 0

class LossCallBack(Callback):
    """
    Monitor the loss in training.

    If the loss is NAN or INF, training should be terminated.

    Note:
        If per_print_times is 0, the loss is not printed.

    Args:
        per_print_times (int): Print the loss every per_print_times steps. Default: 1.
        rank_id (int): Rank id of the device, used to name the loss log. Default: 0.
    """

    def __init__(self, per_print_times=1, rank_id=0):
        super(LossCallBack, self).__init__()
        if not isinstance(per_print_times, int) or per_print_times < 0:
            raise ValueError("print_step must be int and >= 0.")
        self._per_print_times = per_print_times
        self.count = 0
        self.rpn_loss_sum = 0
        self.rpn_cls_loss_sum = 0
        self.rpn_reg_loss_sum = 0
        self.rank_id = rank_id

        global time_stamp_init, time_stamp_first
        if not time_stamp_init:
            time_stamp_first = time.time()
            time_stamp_init = True

    def step_end(self, run_context):
        cb_params = run_context.original_args()
        rpn_loss = cb_params.net_outputs[0].asnumpy()
        rpn_cls_loss = cb_params.net_outputs[1].asnumpy()
        rpn_reg_loss = cb_params.net_outputs[2].asnumpy()

        self.count += 1
        self.rpn_loss_sum += float(rpn_loss)
        self.rpn_cls_loss_sum += float(rpn_cls_loss)
        self.rpn_reg_loss_sum += float(rpn_reg_loss)

        cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1

        if self.count >= 1:
            global time_stamp_first
            time_stamp_current = time.time()
            rpn_loss = self.rpn_loss_sum / self.count
            rpn_cls_loss = self.rpn_cls_loss_sum / self.count
            rpn_reg_loss = self.rpn_reg_loss_sum / self.count
            loss_file = open("./loss_{}.log".format(self.rank_id), "a+")
            loss_file.write("%lu epoch: %s step: %s, rpn_loss: %.5f, rpn_cls_loss: %.5f, rpn_reg_loss: %.5f" %
                            (time_stamp_current - time_stamp_first, cb_params.cur_epoch_num, cur_step_in_epoch,
                             rpn_loss, rpn_cls_loss, rpn_reg_loss))
            loss_file.write("\n")
            loss_file.close()


class LossNet(nn.Cell):
    """CTPN loss method"""
    def construct(self, x1, x2, x3):
        # only the total rpn loss is used for backpropagation
        return x1


class WithLossCell(nn.Cell):
    """
    Wrap the network with loss function to compute loss.

    Args:
        backbone (Cell): The target network to wrap.
        loss_fn (Cell): The loss function used to compute loss.
    """
    def __init__(self, backbone, loss_fn):
        super(WithLossCell, self).__init__(auto_prefix=False)
        self._backbone = backbone
        self._loss_fn = loss_fn

    def construct(self, x, img_shape, gt_bboxe, gt_label, gt_num):
        rpn_loss, _, _, rpn_cls_loss, rpn_reg_loss = self._backbone(x, img_shape, gt_bboxe, gt_label, gt_num)
        return self._loss_fn(rpn_loss, rpn_cls_loss, rpn_reg_loss)

    @property
    def backbone_network(self):
        """
        Get the backbone network.

        Returns:
            Cell, return backbone network.
        """
        return self._backbone


class TrainOneStepCell(nn.Cell):
    """
    Network training package class.

    Append an optimizer to the training network. After that, the construct
    function can be called to create the backward graph.

    Args:
        network (Cell): The training network.
        network_backbone (Cell): The forward network.
        optimizer (Cell): Optimizer for updating the weights.
        sens (Number): The adjust parameter. Default value is 1.0.
        reduce_flag (bool): The reduce flag. Default value is False.
        mean (bool): Whether to average gradients in allreduce. Default value is True.
        degree (int): Device number. Default value is None.
    """
    def __init__(self, network, network_backbone, optimizer, sens=1.0, reduce_flag=False, mean=True, degree=None):
        super(TrainOneStepCell, self).__init__(auto_prefix=False)
        self.network = network
        self.network.set_grad()
        self.backbone = network_backbone
        self.weights = ParameterTuple(network.trainable_params())
        self.optimizer = optimizer
        self.grad = C.GradOperation(get_by_list=True,
                                    sens_param=True)
        self.sens = Tensor((np.ones((1,)) * sens).astype(np.float32))
        self.reduce_flag = reduce_flag
        if reduce_flag:
            self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)

    def construct(self, x, img_shape, gt_bboxe, gt_label, gt_num):
        weights = self.weights
        rpn_loss, _, _, rpn_cls_loss, rpn_reg_loss = self.backbone(x, img_shape, gt_bboxe, gt_label, gt_num)
        grads = self.grad(self.network, weights)(x, img_shape, gt_bboxe, gt_label, gt_num, self.sens)
        if self.reduce_flag:
            grads = self.grad_reducer(grads)
        return F.depend(rpn_loss, self.optimizer(grads)), rpn_cls_loss, rpn_reg_loss
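A hedged sketch of how these wrappers compose in a training script (optimizer values are placeholders, not the repo's actual settings):

from mindspore.nn import Momentum
from src.config import config
from src.ctpn import CTPN

net = CTPN(config, is_training=True)
net_with_loss = WithLossCell(net, LossNet())
opt = Momentum(params=net.trainable_params(), learning_rate=0.001,
               momentum=config.momentum, weight_decay=config.weight_decay)
train_net = TrainOneStepCell(net_with_loss, net, opt, sens=config.loss_scale)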
@ -0,0 +1,65 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Connect text proposals into text lines."""
import numpy as np
from src.text_connector.utils import clip_boxes, fit_y
from src.text_connector.get_successions import get_successions


def connect_text_lines(text_proposals, scores, size):
    """
    Connect text lines.

    Args:
        text_proposals(numpy.array): Predicted text proposals.
        scores(numpy.array): Bbox prediction scores.
        size(numpy.array): Image size.
    Returns:
        text_recs(numpy.array): Text boxes after connection.
    """
    graph = get_successions(text_proposals, scores, size)
    text_lines = np.zeros((len(graph), 5), np.float32)
    for index, indices in enumerate(graph):
        text_line_boxes = text_proposals[list(indices)]
        x0 = np.min(text_line_boxes[:, 0])
        x1 = np.max(text_line_boxes[:, 2])

        offset = (text_line_boxes[0, 2] - text_line_boxes[0, 0]) * 0.5

        # fit the top and bottom edges of the line through the proposal corners
        lt_y, rt_y = fit_y(text_line_boxes[:, 0], text_line_boxes[:, 1], x0 + offset, x1 - offset)
        lb_y, rb_y = fit_y(text_line_boxes[:, 0], text_line_boxes[:, 3], x0 + offset, x1 - offset)

        # the score of a text line is the average score of the scores
        # of all text proposals contained in the text line
        score = scores[list(indices)].sum() / float(len(indices))

        text_lines[index, 0] = x0
        text_lines[index, 1] = min(lt_y, rt_y)
        text_lines[index, 2] = x1
        text_lines[index, 3] = max(lb_y, rb_y)
        text_lines[index, 4] = score

    text_lines = clip_boxes(text_lines, size)

    text_recs = np.zeros((len(text_lines), 9), np.float32)
    for index, line in enumerate(text_lines):
        xmin, ymin, xmax, ymax = line[0], line[1], line[2], line[3]
        text_recs[index, 0] = xmin
        text_recs[index, 1] = ymin
        text_recs[index, 2] = xmax
        text_recs[index, 3] = ymax
        text_recs[index, 4] = line[4]
    return text_recs
@ -0,0 +1,73 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Detect text boxes from text proposals."""
import numpy as np
from src.config import config
from src.text_connector.utils import nms
from src.text_connector.connect_text_lines import connect_text_lines


def filter_proposal(proposals, scores):
    """
    Filter text proposals by score threshold and NMS.

    Args:
        proposals(numpy.array): Text proposals.
        scores(numpy.array): Proposal scores.
    Returns:
        proposals(numpy.array): Text proposals after filtering.
        scores(numpy.array): Scores of the kept proposals.
    """
    inds = np.where(scores > config.text_proposals_min_scores)[0]
    keep_proposals = proposals[inds]
    keep_scores = scores[inds]
    sorted_inds = np.argsort(keep_scores.ravel())[::-1]
    keep_proposals, keep_scores = keep_proposals[sorted_inds], keep_scores[sorted_inds]
    nms_inds = nms(np.hstack((keep_proposals, keep_scores)), config.text_proposals_nms_thresh)
    keep_proposals, keep_scores = keep_proposals[nms_inds], keep_scores[nms_inds]
    return keep_proposals, keep_scores


def filter_boxes(boxes):
    """
    Filter text boxes by aspect ratio, score and width.

    Args:
        boxes(numpy.array): Text boxes.
    Returns:
        inds(numpy.array): Indices of the boxes to keep.
    """
    heights = np.zeros((len(boxes), 1), np.float32)
    widths = np.zeros((len(boxes), 1), np.float32)
    scores = np.zeros((len(boxes), 1), np.float32)
    for index, box in enumerate(boxes):
        widths[index] = abs(box[2] - box[0])
        heights[index] = abs(box[3] - box[1])
        scores[index] = abs(box[4])
    return np.where((widths / heights > config.min_ratio) & (scores > config.line_min_score) &\
                    (widths > (config.text_proposals_width * config.min_num_proposals)))[0]


def detect(text_proposals, scores, size):
    """
    Detect text boxes.

    Args:
        text_proposals(numpy.array): Predicted text proposals.
        scores(numpy.array): Bbox prediction scores.
        size(numpy.array): Image size.
    Returns:
        boxes(numpy.array): Text boxes after connection and filtering.
    """
    keep_proposals, keep_scores = filter_proposal(text_proposals, scores)
    connect_boxes = connect_text_lines(keep_proposals, keep_scores, size)
    boxes = connect_boxes[filter_boxes(connect_boxes)]
    return boxes
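A hedged end-to-end example of this post-processing entry point (the proposal values are made up; the exact grouping depends on src.text_connector.utils and get_successions):

import numpy as np
from src.text_connector.detector import detect

# two adjacent 16-px-wide proposals on one text line (made-up values)
proposals = np.array([[10, 20, 26, 40], [26, 20, 42, 41]], dtype=np.float32)
scores = np.array([[0.95], [0.92]], dtype=np.float32)
img_size = np.array([576, 960, 3])  # height, width, channels
boxes = detect(proposals, scores, img_size)
print(boxes)  # connected line boxes that pass the ratio/score/width filters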