!12121 Add CTPN network

From: @qujianwei
Reviewed-by: 
Signed-off-by:
pull/12121/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit f0a9cb7c20

File diff suppressed because it is too large Load Diff

@ -0,0 +1,118 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# less required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Evaluation for CTPN"""
import os
import argparse
import time
import numpy as np
from mindspore import context
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.common import set_seed
from src.ctpn import CTPN
from src.config import config
from src.dataset import create_ctpn_dataset
from src.text_connector.detector import detect
set_seed(1)
parser = argparse.ArgumentParser(description="CTPN evaluation")
parser.add_argument("--dataset_path", type=str, default="", help="Dataset path.")
parser.add_argument("--image_path", type=str, default="", help="Image path.")
parser.add_argument("--checkpoint_path", type=str, default="", help="Checkpoint file path.")
parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
args_opt = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id)
def ctpn_infer_test(dataset_path='', ckpt_path='', img_dir=''):
"""ctpn infer."""
print("ckpt path is {}".format(ckpt_path))
ds = create_ctpn_dataset(dataset_path, batch_size=config.test_batch_size, repeat_num=1, is_training=False)
config.batch_size = config.test_batch_size
total = ds.get_dataset_size()
print("*************total dataset size is {}".format(total))
net = CTPN(config, is_training=False)
param_dict = load_checkpoint(ckpt_path)
load_param_into_net(net, param_dict)
net.set_train(False)
eval_iter = 0
print("\n========================================\n")
print("Processing, please wait a moment.")
img_basenames = []
output_dir = os.path.join(os.getcwd(), "submit")
if not os.path.exists(output_dir):
os.mkdir(output_dir)
for file in os.listdir(img_dir):
img_basenames.append(os.path.basename(file))
for data in ds.create_dict_iterator():
img_data = data['image']
img_metas = data['image_shape']
gt_bboxes = data['box']
gt_labels = data['label']
gt_num = data['valid_num']
start = time.time()
# run net
output = net(img_data, img_metas, gt_bboxes, gt_labels, gt_num)
gt_bboxes = gt_bboxes.asnumpy()
gt_labels = gt_labels.asnumpy()
gt_num = gt_num.asnumpy().astype(bool)
end = time.time()
proposal = output[0]
proposal_mask = output[1]
print("start to draw pic")
for j in range(config.test_batch_size):
img = img_basenames[config.test_batch_size * eval_iter + j]
all_box_tmp = proposal[j].asnumpy()
all_mask_tmp = np.expand_dims(proposal_mask[j].asnumpy(), axis=1)
using_boxes_mask = all_box_tmp * all_mask_tmp
textsegs = using_boxes_mask[:, 0:4].astype(np.float32)
scores = using_boxes_mask[:, 4].astype(np.float32)
shape = img_metas.asnumpy()[0][:2].astype(np.int32)
bboxes = detect(textsegs, scores[:, np.newaxis], shape)
from PIL import Image, ImageDraw
im = Image.open(img_dir + '/' + img)
draw = ImageDraw.Draw(im)
image_h = img_metas.asnumpy()[j][2]
image_w = img_metas.asnumpy()[j][3]
gt_boxs = gt_bboxes[j][gt_num[j], :]
for gt_box in gt_boxs:
gt_x1 = gt_box[0] / image_w
gt_y1 = gt_box[1] / image_h
gt_x2 = gt_box[2] / image_w
gt_y2 = gt_box[3] / image_h
draw.line([(gt_x1, gt_y1), (gt_x1, gt_y2), (gt_x2, gt_y2), (gt_x2, gt_y1), (gt_x1, gt_y1)],\
fill='green', width=2)
file_name = "res_" + img.replace("jpg", "txt")
output_file = os.path.join(output_dir, file_name)
f = open(output_file, 'w')
for bbox in bboxes:
x1 = bbox[0] / image_w
y1 = bbox[1] / image_h
x2 = bbox[2] / image_w
y2 = bbox[3] / image_h
draw.line([(x1, y1), (x1, y2), (x2, y2), (x2, y1), (x1, y1)], fill='red', width=2)
str_tmp = str(int(x1)) + "," + str(int(y1)) + "," + str(int(x2)) + "," + str(int(y2))
f.write(str_tmp)
f.write("\n")
f.close()
im.save(img)
percent = round(eval_iter / total * 100, 2)
eval_iter = eval_iter + 1
print("Iter {} cost time {}".format(eval_iter, end - start))
print(' %s [%d/%d]' % (str(percent) + '%', eval_iter, total), end='\r')
if __name__ == '__main__':
ctpn_infer_test(args_opt.dataset_path, args_opt.checkpoint_path, img_dir=args_opt.image_path)

@ -0,0 +1,21 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
for submit_file in "submit"*.zip
do
echo "eval result for ${submit_file}"
python script.py g=gt.zip s=${submit_file} o=./
echo -e ".\n"
done

@ -0,0 +1,67 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# -ne 3 ]
then
echo "Usage: sh run_distribute_train_ascend.sh [RANK_TABLE_FILE] [TASK_TYPE] [PRETRAINED_PATH]"
exit 1
fi
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
PATH1=$(get_real_path $1)
echo $PATH1
if [ ! -f $PATH1 ]
then
echo "error: RANK_TABLE_FILE=$PATH1 is not a file"
exit 1
fi
TASK_TYPE=$2
PATH2=$(get_real_path $3)
echo $PATH2
if [ ! -f $PATH2 ]
then
echo "error: PRETRAINED_PATH=$PATH2 is not a file"
exit 1
fi
ulimit -u unlimited
export DEVICE_NUM=8
export RANK_SIZE=8
export RANK_TABLE_FILE=$PATH1
for((i=0; i<${DEVICE_NUM}; i++))
do
export DEVICE_ID=$i
export RANK_ID=$i
rm -rf ./train_parallel$i
mkdir ./train_parallel$i
cp ../*.py ./train_parallel$i
cp *.sh ./train_parallel$i
cp -r ../src ./train_parallel$i
cd ./train_parallel$i || exit
echo "start training for rank $RANK_ID, device $DEVICE_ID"
env > env.log
python train.py --device_id=$i --rank_id=$i --run_distribute=True --device_num=$DEVICE_NUM --task_type=$TASK_TYPE --pre_trained=$PATH2 &> log &
cd ..
done

@ -0,0 +1,80 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 3 ]
then
echo "Usage: sh run_eval_ascend.sh [IMAGE_PATH] [DATASET_PATH] [CHECKPOINT_PATH]"
exit 1
fi
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
IMAGE_PATH=$(get_real_path $1)
DATASET_PATH=$(get_real_path $2)
CHECKPOINT_PATH=$(get_real_path $3)
echo $IMAGE_PATH
echo $DATASET_PATH
echo $CHECKPOINT_PATH
if [ ! -d $IMAGE_PATH ]
then
echo "error: IMAGE_PATH=$PATH1 is not a path"
exit 1
fi
if [ ! -f $DATASET_PATH ]
then
echo "error: CHECKPOINT_PATH=$DATASET_PATH is not a path"
exit 1
fi
if [ ! -d $CHECKPOINT_PATH ]
then
echo "error: CHECKPOINT_PATH=$CHECKPOINT_PATH is not a directory"
exit 1
fi
ulimit -u unlimited
export DEVICE_NUM=1
export RANK_SIZE=$DEVICE_NUM
export DEVICE_ID=0
export RANK_ID=0
for file in "${CHECKPOINT_PATH}"/*.ckpt
do
if [ -d "eval" ];
then
rm -rf ./eval
fi
mkdir ./eval
cp ../*.py ./eval
cp *.sh ./eval
cp -r ../src ./eval
cd ./eval
env > env.log
CHECKPOINT_FILE_PATH=$file
echo "start eval for checkpoint file: ${CHECKPOINT_FILE_PATH}"
python eval.py --device_id=$DEVICE_ID --image_path=$IMAGE_PATH --dataset_path=$DATASET_PATH --checkpoint_path=$CHECKPOINT_FILE_PATH &> log
echo "end eval for checkpoint file: ${CHECKPOINT_FILE_PATH}"
cd ./submit
file_base_name=$(basename $file)
zip -r ../../submit_${file_base_name%.*}.zip *.txt
cd ../../
done

@ -0,0 +1,54 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# -ne 2 ]
then
echo "Usage: sh run_distribute_train_ascend.sh [TASK_TYPE] [PRETRAINED_PATH]"
exit 1
fi
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
TASK_TYPE=$1
PRETRAINED_PATH=$(get_real_path $2)
echo $PRETRAINED_PATH
if [ ! -f $PRETRAINED_PATH ]
then
echo "error: PRETRAINED_PATH=$PRETRAINED_PATH is not a file"
exit 1
fi
ulimit -u unlimited
export DEVICE_NUM=1
export DEVICE_ID=0
export RANK_ID=0
export RANK_SIZE=1
rm -rf ./train
mkdir ./train
cp ../*.py ./train
cp *.sh ./train
cp -r ../src ./train
cd ./train || exit
echo "start training for device $DEVICE_ID"
env > env.log
python train.py --device_id=$DEVICE_ID --task_type=$TASK_TYPE --pre_trained=$PRETRAINED_PATH &> log &
cd ..

@ -0,0 +1,55 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import mindspore.nn as nn
from mindspore.ops import operations as P
class BoundingBoxDecode(nn.Cell):
"""
BoundintBox Decoder.
Returns:
pred_box(Tensor): decoder bounding boxes.
"""
def __init__(self):
super(BoundingBoxDecode, self).__init__()
self.split = P.Split(axis=1, output_num=4)
self.ones = 1.0
self.half = 0.5
self.log = P.Log()
self.exp = P.Exp()
self.concat = P.Concat(axis=1)
def construct(self, bboxes, deltas):
"""
boxes(Tensor): boundingbox.
deltas(Tensor): delta between boundingboxs and anchors.
"""
x1, y1, x2, y2 = self.split(bboxes)
width = x2 - x1 + self.ones
height = y2 - y1 + self.ones
ctr_x = x1 + self.half * width
ctr_y = y1 + self.half * height
_, dy, _, dh = self.split(deltas)
pred_ctr_x = ctr_x
pred_ctr_y = dy * height + ctr_y
pred_w = width
pred_h = self.exp(dh) * height
x1 = pred_ctr_x - self.half * pred_w
y1 = pred_ctr_y - self.half * pred_h
x2 = pred_ctr_x + self.half * pred_w
y2 = pred_ctr_y + self.half * pred_h
pred_box = self.concat((x1, y1, x2, y2))
return pred_box

@ -0,0 +1,55 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import mindspore.nn as nn
from mindspore.ops import operations as P
class BoundingBoxEncode(nn.Cell):
"""
BoundintBox Decoder.
Returns:
pred_box(Tensor): decoder bounding boxes.
"""
def __init__(self):
super(BoundingBoxEncode, self).__init__()
self.split = P.Split(axis=1, output_num=4)
self.ones = 1.0
self.half = 0.5
self.log = P.Log()
self.concat = P.Concat(axis=1)
def construct(self, anchor_box, gt_box):
"""
boxes(Tensor): boundingbox.
deltas(Tensor): delta between boundingboxs and anchors.
"""
x1, y1, x2, y2 = self.split(anchor_box)
width = x2 - x1 + self.ones
height = y2 - y1 + self.ones
ctr_x = x1 + self.half * width
ctr_y = y1 + self.half * height
gt_x1, gt_y1, gt_x2, gt_y2 = self.split(gt_box)
gt_width = gt_x2 - gt_x1 + self.ones
gt_height = gt_y2 - gt_y1 + self.ones
ctr_gt_x = gt_x1 + self.half * gt_width
ctr_gt_y = gt_y1 + self.half * gt_height
target_dx = (ctr_gt_x - ctr_x) / width
target_dy = (ctr_gt_y - ctr_y) / height
dw = gt_width / width
dh = gt_height / height
target_dw = self.log(dw)
target_dh = self.log(dh)
deltas = self.concat((target_dx, target_dy, target_dw, target_dh))
return deltas

@ -0,0 +1,73 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""FasterRcnn anchor generator."""
import numpy as np
class AnchorGenerator():
"""Anchor generator for FasterRcnn."""
def __init__(self, config):
"""Anchor generator init method."""
self.base_size = config.anchor_base
self.num_anchor = config.num_anchors
self.anchor_height = config.anchor_height
self.anchor_width = config.anchor_width
self.size = self.gen_anchor_size()
self.base_anchors = self.gen_base_anchors()
def gen_base_anchors(self):
"""Generate a single anchor."""
base_anchor = np.array([0, 0, self.base_size - 1, self.base_size - 1], np.int32)
anchors = np.zeros((len(self.size), 4), np.int32)
index = 0
for h, w in self.size:
anchors[index] = self.scale_anchor(base_anchor, h, w)
index += 1
return anchors
def gen_anchor_size(self):
"""Generate a list of anchor size"""
size = []
for width in self.anchor_width:
for height in self.anchor_height:
size.append((height, width))
return size
def scale_anchor(self, anchor, h, w):
x_ctr = (anchor[0] + anchor[2]) * 0.5
y_ctr = (anchor[1] + anchor[3]) * 0.5
scaled_anchor = anchor.copy()
scaled_anchor[0] = x_ctr - w / 2 # xmin
scaled_anchor[2] = x_ctr + w / 2 # xmax
scaled_anchor[1] = y_ctr - h / 2 # ymin
scaled_anchor[3] = y_ctr + h / 2 # ymax
return scaled_anchor
def _meshgrid(self, x, y):
"""Generate grid."""
xx = np.repeat(x.reshape(1, len(x)), len(y), axis=0).reshape(-1)
yy = np.repeat(y, len(x))
return xx, yy
def grid_anchors(self, featmap_size, stride=16):
"""Generate anchor list."""
base_anchors = self.base_anchors
feat_h, feat_w = featmap_size
shift_x = np.arange(0, feat_w) * stride
shift_y = np.arange(0, feat_h) * stride
shift_xx, shift_yy = self._meshgrid(shift_x, shift_y)
shifts = np.stack([shift_xx, shift_yy, shift_xx, shift_yy], axis=-1)
shifts = shifts.astype(base_anchors.dtype)
all_anchors = base_anchors[None, :, :] + shifts[:, None, :]
all_anchors = all_anchors.reshape(-1, 4)
return all_anchors

@ -0,0 +1,152 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""FasterRcnn positive and negative sample screening for RPN."""
import numpy as np
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.common.tensor import Tensor
import mindspore.common.dtype as mstype
from src.CTPN.BoundingBoxEncode import BoundingBoxEncode
class BboxAssignSample(nn.Cell):
"""
Bbox assigner and sampler definition.
Args:
config (dict): Config.
batch_size (int): Batchsize.
num_bboxes (int): The anchor nums.
add_gt_as_proposals (bool): add gt bboxes as proposals flag.
Returns:
Tensor, output tensor.
bbox_targets: bbox location, (batch_size, num_bboxes, 4)
bbox_weights: bbox weights, (batch_size, num_bboxes, 1)
labels: label for every bboxes, (batch_size, num_bboxes, 1)
label_weights: label weight for every bboxes, (batch_size, num_bboxes, 1)
Examples:
BboxAssignSample(config, 2, 1024, True)
"""
def __init__(self, config, batch_size, num_bboxes, add_gt_as_proposals):
super(BboxAssignSample, self).__init__()
cfg = config
self.batch_size = batch_size
self.neg_iou_thr = Tensor(cfg.neg_iou_thr, mstype.float16)
self.pos_iou_thr = Tensor(cfg.pos_iou_thr, mstype.float16)
self.min_pos_iou = Tensor(cfg.min_pos_iou, mstype.float16)
self.zero_thr = Tensor(0.0, mstype.float16)
self.num_bboxes = num_bboxes
self.num_gts = cfg.num_gts
self.num_expected_pos = cfg.num_expected_pos
self.num_expected_neg = cfg.num_expected_neg
self.add_gt_as_proposals = add_gt_as_proposals
if self.add_gt_as_proposals:
self.label_inds = Tensor(np.arange(1, self.num_gts + 1))
self.concat = P.Concat(axis=0)
self.max_gt = P.ArgMaxWithValue(axis=0)
self.max_anchor = P.ArgMaxWithValue(axis=1)
self.sum_inds = P.ReduceSum()
self.iou = P.IOU()
self.greaterequal = P.GreaterEqual()
self.greater = P.Greater()
self.select = P.Select()
self.gatherND = P.GatherNd()
self.squeeze = P.Squeeze()
self.cast = P.Cast()
self.logicaland = P.LogicalAnd()
self.less = P.Less()
self.random_choice_with_mask_pos = P.RandomChoiceWithMask(self.num_expected_pos)
self.random_choice_with_mask_neg = P.RandomChoiceWithMask(self.num_expected_neg)
self.reshape = P.Reshape()
self.equal = P.Equal()
self.bounding_box_encode = BoundingBoxEncode()
self.scatterNdUpdate = P.ScatterNdUpdate()
self.scatterNd = P.ScatterNd()
self.logicalnot = P.LogicalNot()
self.tile = P.Tile()
self.zeros_like = P.ZerosLike()
self.assigned_gt_inds = Tensor(np.array(-1 * np.ones(num_bboxes), dtype=np.int32))
self.assigned_gt_zeros = Tensor(np.array(np.zeros(num_bboxes), dtype=np.int32))
self.assigned_gt_ones = Tensor(np.array(np.ones(num_bboxes), dtype=np.int32))
self.assigned_gt_ignores = Tensor(np.array(-1 * np.ones(num_bboxes), dtype=np.int32))
self.assigned_pos_ones = Tensor(np.array(np.ones(self.num_expected_pos), dtype=np.int32))
self.check_neg_mask = Tensor(np.array(np.ones(self.num_expected_neg - self.num_expected_pos), dtype=np.bool))
self.range_pos_size = Tensor(np.arange(self.num_expected_pos).astype(np.float16))
self.check_gt_one = Tensor(np.array(-1 * np.ones((self.num_gts, 4)), dtype=np.float16))
self.check_anchor_two = Tensor(np.array(-2 * np.ones((self.num_bboxes, 4)), dtype=np.float16))
self.print = P.Print()
def construct(self, gt_bboxes_i, gt_labels_i, valid_mask, bboxes, gt_valids):
gt_bboxes_i = self.select(self.cast(self.tile(self.reshape(self.cast(gt_valids, mstype.int32), \
(self.num_gts, 1)), (1, 4)), mstype.bool_), gt_bboxes_i, self.check_gt_one)
bboxes = self.select(self.cast(self.tile(self.reshape(self.cast(valid_mask, mstype.int32), \
(self.num_bboxes, 1)), (1, 4)), mstype.bool_), bboxes, self.check_anchor_two)
overlaps = self.iou(bboxes, gt_bboxes_i)
max_overlaps_w_gt_index, max_overlaps_w_gt = self.max_gt(overlaps)
_, max_overlaps_w_ac = self.max_anchor(overlaps)
neg_sample_iou_mask = self.logicaland(self.greaterequal(max_overlaps_w_gt, self.zero_thr), \
self.less(max_overlaps_w_gt, self.neg_iou_thr))
assigned_gt_inds2 = self.select(neg_sample_iou_mask, self.assigned_gt_zeros, self.assigned_gt_inds)
pos_sample_iou_mask = self.greaterequal(max_overlaps_w_gt, self.pos_iou_thr)
assigned_gt_inds3 = self.select(pos_sample_iou_mask, \
max_overlaps_w_gt_index + self.assigned_gt_ones, assigned_gt_inds2)
assigned_gt_inds4 = assigned_gt_inds3
for j in range(self.num_gts):
max_overlaps_w_ac_j = max_overlaps_w_ac[j:j+1:1]
overlaps_w_gt_j = self.squeeze(overlaps[j:j+1:1, ::])
pos_mask_j = self.logicaland(self.greaterequal(max_overlaps_w_ac_j, self.min_pos_iou), \
self.equal(overlaps_w_gt_j, max_overlaps_w_ac_j))
assigned_gt_inds4 = self.select(pos_mask_j, self.assigned_gt_ones + j, assigned_gt_inds4)
assigned_gt_inds5 = self.select(valid_mask, assigned_gt_inds4, self.assigned_gt_ignores)
pos_index, valid_pos_index = self.random_choice_with_mask_pos(self.greater(assigned_gt_inds5, 0))
pos_check_valid = self.cast(self.greater(assigned_gt_inds5, 0), mstype.float16)
pos_check_valid = self.sum_inds(pos_check_valid, -1)
valid_pos_index = self.less(self.range_pos_size, pos_check_valid)
pos_index = pos_index * self.reshape(self.cast(valid_pos_index, mstype.int32), (self.num_expected_pos, 1))
pos_assigned_gt_index = self.gatherND(assigned_gt_inds5, pos_index) - self.assigned_pos_ones
pos_assigned_gt_index = pos_assigned_gt_index * self.cast(valid_pos_index, mstype.int32)
pos_assigned_gt_index = self.reshape(pos_assigned_gt_index, (self.num_expected_pos, 1))
neg_index, valid_neg_index = self.random_choice_with_mask_neg(self.equal(assigned_gt_inds5, 0))
num_pos = self.cast(self.logicalnot(valid_pos_index), mstype.float16)
num_pos = self.sum_inds(num_pos, -1)
unvalid_pos_index = self.less(self.range_pos_size, num_pos)
valid_neg_index = self.logicaland(self.concat((self.check_neg_mask, unvalid_pos_index)), valid_neg_index)
pos_bboxes_ = self.gatherND(bboxes, pos_index)
pos_gt_bboxes_ = self.gatherND(gt_bboxes_i, pos_assigned_gt_index)
pos_gt_labels = self.gatherND(gt_labels_i, pos_assigned_gt_index)
pos_bbox_targets_ = self.bounding_box_encode(pos_bboxes_, pos_gt_bboxes_)
valid_pos_index = self.cast(valid_pos_index, mstype.int32)
valid_neg_index = self.cast(valid_neg_index, mstype.int32)
bbox_targets_total = self.scatterNd(pos_index, pos_bbox_targets_, (self.num_bboxes, 4))
bbox_weights_total = self.scatterNd(pos_index, valid_pos_index, (self.num_bboxes,))
labels_total = self.scatterNd(pos_index, pos_gt_labels, (self.num_bboxes,))
total_index = self.concat((pos_index, neg_index))
total_valid_index = self.concat((valid_pos_index, valid_neg_index))
label_weights_total = self.scatterNd(total_index, total_valid_index, (self.num_bboxes,))
return bbox_targets_total, self.cast(bbox_weights_total, mstype.bool_), \
labels_total, self.cast(label_weights_total, mstype.bool_)

@ -0,0 +1,190 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""FasterRcnn proposal generator."""
import numpy as np
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore.ops import operations as P
from mindspore import Tensor
from src.CTPN.BoundingBoxDecode import BoundingBoxDecode
class Proposal(nn.Cell):
"""
Proposal subnet.
Args:
config (dict): Config.
batch_size (int): Batchsize.
num_classes (int) - Class number.
use_sigmoid_cls (bool) - Select sigmoid or softmax function.
target_means (tuple) - Means for encode function. Default: (.0, .0, .0, .0).
target_stds (tuple) - Stds for encode function. Default: (1.0, 1.0, 1.0, 1.0).
Returns:
Tuple, tuple of output tensor,(proposal, mask).
Examples:
Proposal(config = config, batch_size = 1, num_classes = 81, use_sigmoid_cls = True, \
target_means=(.0, .0, .0, .0), target_stds=(1.0, 1.0, 1.0, 1.0))
"""
def __init__(self,
config,
batch_size,
num_classes,
use_sigmoid_cls,
target_means=(.0, .0, .0, .0),
target_stds=(1.0, 1.0, 1.0, 1.0)
):
super(Proposal, self).__init__()
cfg = config
self.batch_size = batch_size
self.num_classes = num_classes
self.target_means = target_means
self.target_stds = target_stds
self.use_sigmoid_cls = config.use_sigmoid_cls
if self.use_sigmoid_cls:
self.cls_out_channels = 1
self.activation = P.Sigmoid()
self.reshape_shape = (-1, 1)
else:
self.cls_out_channels = num_classes
self.activation = P.Softmax(axis=1)
self.reshape_shape = (-1, 2)
if self.cls_out_channels <= 0:
raise ValueError('num_classes={} is too small'.format(num_classes))
self.num_pre = cfg.rpn_proposal_nms_pre
self.min_box_size = cfg.rpn_proposal_min_bbox_size
self.nms_thr = cfg.rpn_proposal_nms_thr
self.nms_post = cfg.rpn_proposal_nms_post
self.nms_across_levels = cfg.rpn_proposal_nms_across_levels
self.max_num = cfg.rpn_proposal_max_num
# Op Define
self.squeeze = P.Squeeze()
self.reshape = P.Reshape()
self.cast = P.Cast()
self.feature_shapes = cfg.feature_shapes
self.transpose_shape = (1, 2, 0)
self.decode = BoundingBoxDecode()
self.nms = P.NMSWithMask(self.nms_thr)
self.concat_axis0 = P.Concat(axis=0)
self.concat_axis1 = P.Concat(axis=1)
self.split = P.Split(axis=1, output_num=5)
self.min = P.Minimum()
self.gatherND = P.GatherNd()
self.slice = P.Slice()
self.select = P.Select()
self.greater = P.Greater()
self.transpose = P.Transpose()
self.tile = P.Tile()
self.set_train_local(config, training=True)
self.multi_10 = Tensor(10.0, mstype.float16)
def set_train_local(self, config, training=False):
"""Set training flag."""
self.training_local = training
cfg = config
self.topK_stage1 = ()
self.topK_shape = ()
total_max_topk_input = 0
if not self.training_local:
self.num_pre = cfg.rpn_nms_pre
self.min_box_size = cfg.rpn_min_bbox_min_size
self.nms_thr = cfg.rpn_nms_thr
self.nms_post = cfg.rpn_nms_post
self.max_num = cfg.rpn_max_num
k_num = self.num_pre
total_max_topk_input = k_num
self.topK_stage1 = k_num
self.topK_shape = (k_num, 1)
self.topKv2 = P.TopK(sorted=True)
self.topK_shape_stage2 = (self.max_num, 1)
self.min_float_num = -65536.0
self.topK_mask = Tensor(self.min_float_num * np.ones(total_max_topk_input, np.float16))
self.shape = P.Shape()
self.print = P.Print()
def construct(self, rpn_cls_score_total, rpn_bbox_pred_total, anchor_list):
proposals_tuple = ()
masks_tuple = ()
for img_id in range(self.batch_size):
rpn_cls_score_i = self.squeeze(rpn_cls_score_total[img_id:img_id+1:1, ::, ::, ::])
rpn_bbox_pred_i = self.squeeze(rpn_bbox_pred_total[img_id:img_id+1:1, ::, ::, ::])
proposals, masks = self.get_bboxes_single(rpn_cls_score_i, rpn_bbox_pred_i, anchor_list)
proposals_tuple += (proposals,)
masks_tuple += (masks,)
return proposals_tuple, masks_tuple
def get_bboxes_single(self, cls_scores, bbox_preds, mlvl_anchors):
"""Get proposal boundingbox."""
mlvl_proposals = ()
mlvl_mask = ()
rpn_cls_score = self.transpose(cls_scores, self.transpose_shape)
rpn_bbox_pred = self.transpose(bbox_preds, self.transpose_shape)
anchors = mlvl_anchors
# (H, W, A*2)
rpn_cls_score_shape = self.shape(rpn_cls_score)
rpn_cls_score = self.reshape(rpn_cls_score, (rpn_cls_score_shape[0], \
rpn_cls_score_shape[1], -1, self.cls_out_channels))
rpn_cls_score = self.reshape(rpn_cls_score, self.reshape_shape)
rpn_cls_score = self.activation(rpn_cls_score)
if self.use_sigmoid_cls:
rpn_cls_score_process = self.cast(self.squeeze(rpn_cls_score), mstype.float16)
else:
rpn_cls_score_process = self.cast(self.squeeze(rpn_cls_score[::, 1]), mstype.float16)
rpn_bbox_pred_process = self.cast(self.reshape(rpn_bbox_pred, (-1, 4)), mstype.float16)
scores_sorted, topk_inds = self.topKv2(rpn_cls_score_process, self.num_pre)
topk_inds = self.reshape(topk_inds, self.topK_shape)
bboxes_sorted = self.gatherND(rpn_bbox_pred_process, topk_inds)
anchors_sorted = self.cast(self.gatherND(anchors, topk_inds), mstype.float16)
proposals_decode = self.decode(anchors_sorted, bboxes_sorted)
proposals_decode = self.concat_axis1((proposals_decode, self.reshape(scores_sorted, self.topK_shape)))
proposals, _, mask_valid = self.nms(proposals_decode)
mlvl_proposals = mlvl_proposals + (proposals,)
mlvl_mask = mlvl_mask + (mask_valid,)
proposals = self.concat_axis0(mlvl_proposals)
masks = self.concat_axis0(mlvl_mask)
_, _, _, _, scores = self.split(proposals)
scores = self.squeeze(scores)
topk_mask = self.cast(self.topK_mask, mstype.float16)
scores_using = self.select(masks, scores, topk_mask)
_, topk_inds = self.topKv2(scores_using, self.max_num)
topk_inds = self.reshape(topk_inds, self.topK_shape_stage2)
proposals = self.gatherND(proposals, topk_inds)
masks = self.gatherND(masks, topk_inds)
return proposals, masks

@ -0,0 +1,228 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""RPN for fasterRCNN"""
import numpy as np
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore.ops import operations as P
from mindspore import Tensor
from mindspore.ops import functional as F
from src.CTPN.bbox_assign_sample import BboxAssignSample
class RpnRegClsBlock(nn.Cell):
"""
Rpn reg cls block for rpn layer
Args:
config(EasyDict) - Network construction config.
in_channels (int) - Input channels of shared convolution.
feat_channels (int) - Output channels of shared convolution.
num_anchors (int) - The anchor number.
cls_out_channels (int) - Output channels of classification convolution.
Returns:
Tensor, output tensor.
"""
def __init__(self,
config,
in_channels,
feat_channels,
num_anchors,
cls_out_channels):
super(RpnRegClsBlock, self).__init__()
self.shape = P.Shape()
self.reshape = P.Reshape()
self.shape = (-1, 2*config.hidden_size)
self.lstm_fc = nn.Dense(2*config.hidden_size, 512).to_float(mstype.float16)
self.rpn_cls = nn.Dense(in_channels=512, out_channels=num_anchors * cls_out_channels).to_float(mstype.float16)
self.rpn_reg = nn.Dense(in_channels=512, out_channels=num_anchors * 4).to_float(mstype.float16)
self.shape1 = (config.num_step, config.rnn_batch_size, -1)
self.shape2 = (-1, config.batch_size, config.rnn_batch_size, config.num_step)
self.transpose = P.Transpose()
self.print = P.Print()
self.dropout = nn.Dropout(0.8)
def construct(self, x):
x = self.reshape(x, self.shape)
x = self.lstm_fc(x)
x1 = self.rpn_cls(x)
x1 = self.reshape(x1, self.shape1)
x1 = self.transpose(x1, (2, 1, 0))
x1 = self.reshape(x1, self.shape2)
x1 = self.transpose(x1, (1, 0, 2, 3))
x2 = self.rpn_reg(x)
x2 = self.reshape(x2, self.shape1)
x2 = self.transpose(x2, (2, 1, 0))
x2 = self.reshape(x2, self.shape2)
x2 = self.transpose(x2, (1, 0, 2, 3))
return x1, x2
class RPN(nn.Cell):
"""
ROI proposal network..
Args:
config (dict) - Config.
batch_size (int) - Batchsize.
in_channels (int) - Input channels of shared convolution.
feat_channels (int) - Output channels of shared convolution.
num_anchors (int) - The anchor number.
cls_out_channels (int) - Output channels of classification convolution.
Returns:
Tuple, tuple of output tensor.
Examples:
RPN(config=config, batch_size=2, in_channels=256, feat_channels=1024,
num_anchors=3, cls_out_channels=512)
"""
def __init__(self,
config,
batch_size,
in_channels,
feat_channels,
num_anchors,
cls_out_channels):
super(RPN, self).__init__()
cfg_rpn = config
self.cfg = config
self.num_bboxes = cfg_rpn.num_bboxes
self.feature_anchor_shape = cfg_rpn.feature_shapes
self.feature_anchor_shape = self.feature_anchor_shape[0] * \
self.feature_anchor_shape[1] * num_anchors * batch_size
self.num_anchors = num_anchors
self.batch_size = batch_size
self.test_batch_size = cfg_rpn.test_batch_size
self.num_layers = 1
self.real_ratio = Tensor(np.ones((1, 1)).astype(np.float16))
self.use_sigmoid_cls = config.use_sigmoid_cls
if config.use_sigmoid_cls:
self.reshape_shape_cls = (-1,)
self.loss_cls = P.SigmoidCrossEntropyWithLogits()
cls_out_channels = 1
else:
self.reshape_shape_cls = (-1, cls_out_channels)
self.loss_cls = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="none")
self.rpn_convs_list = self._make_rpn_layer(self.num_layers, in_channels, feat_channels,\
num_anchors, cls_out_channels)
self.transpose = P.Transpose()
self.reshape = P.Reshape()
self.concat = P.Concat(axis=0)
self.fill = P.Fill()
self.placeh1 = Tensor(np.ones((1,)).astype(np.float16))
self.trans_shape = (0, 2, 3, 1)
self.reshape_shape_reg = (-1, 4)
self.softmax = nn.Softmax()
self.rpn_loss_reg_weight = Tensor(np.array(cfg_rpn.rpn_loss_reg_weight).astype(np.float16))
self.rpn_loss_cls_weight = Tensor(np.array(cfg_rpn.rpn_loss_cls_weight).astype(np.float16))
self.num_expected_total = Tensor(np.array(cfg_rpn.num_expected_neg * self.batch_size).astype(np.float16))
self.num_bboxes = cfg_rpn.num_bboxes
self.get_targets = BboxAssignSample(cfg_rpn, self.batch_size, self.num_bboxes, False)
self.CheckValid = P.CheckValid()
self.sum_loss = P.ReduceSum()
self.loss_bbox = P.SmoothL1Loss(beta=1.0/9.0)
self.squeeze = P.Squeeze()
self.cast = P.Cast()
self.tile = P.Tile()
self.zeros_like = P.ZerosLike()
self.loss = Tensor(np.zeros((1,)).astype(np.float16))
self.clsloss = Tensor(np.zeros((1,)).astype(np.float16))
self.regloss = Tensor(np.zeros((1,)).astype(np.float16))
self.print = P.Print()
def _make_rpn_layer(self, num_layers, in_channels, feat_channels, num_anchors, cls_out_channels):
"""
make rpn layer for rpn proposal network
Args:
num_layers (int) - layer num.
in_channels (int) - Input channels of shared convolution.
feat_channels (int) - Output channels of shared convolution.
num_anchors (int) - The anchor number.
cls_out_channels (int) - Output channels of classification convolution.
Returns:
List, list of RpnRegClsBlock cells.
"""
rpn_layer = RpnRegClsBlock(self.cfg, in_channels, feat_channels, num_anchors, cls_out_channels)
return rpn_layer
def construct(self, inputs, img_metas, anchor_list, gt_bboxes, gt_labels, gt_valids):
'''
inputs(Tensor): Inputs tensor from lstm.
img_metas(Tensor): Image shape.
anchor_list(Tensor): Total anchor list.
gt_labels(Tensor): Ground truth labels.
gt_valids(Tensor): Whether ground truth is valid.
'''
rpn_cls_score_ori, rpn_bbox_pred_ori = self.rpn_convs_list(inputs)
rpn_cls_score = self.transpose(rpn_cls_score_ori, self.trans_shape)
rpn_cls_score = self.reshape(rpn_cls_score, self.reshape_shape_cls)
rpn_bbox_pred = self.transpose(rpn_bbox_pred_ori, self.trans_shape)
rpn_bbox_pred = self.reshape(rpn_bbox_pred, self.reshape_shape_reg)
output = ()
bbox_targets = ()
bbox_weights = ()
labels = ()
label_weights = ()
if self.training:
for i in range(self.batch_size):
valid_flag_list = self.cast(self.CheckValid(anchor_list, self.squeeze(img_metas[i:i + 1:1, ::])),\
mstype.int32)
gt_bboxes_i = self.squeeze(gt_bboxes[i:i + 1:1, ::])
gt_labels_i = self.squeeze(gt_labels[i:i + 1:1, ::])
gt_valids_i = self.squeeze(gt_valids[i:i + 1:1, ::])
bbox_target, bbox_weight, label, label_weight = self.get_targets(gt_bboxes_i,
gt_labels_i,
self.cast(valid_flag_list,
mstype.bool_),
anchor_list, gt_valids_i)
bbox_weight = self.cast(bbox_weight, mstype.float16)
label_weight = self.cast(label_weight, mstype.float16)
bbox_targets += (bbox_target,)
bbox_weights += (bbox_weight,)
labels += (label,)
label_weights += (label_weight,)
bbox_target_with_batchsize = self.concat(bbox_targets)
bbox_weight_with_batchsize = self.concat(bbox_weights)
label_with_batchsize = self.concat(labels)
label_weight_with_batchsize = self.concat(label_weights)
bbox_target_ = F.stop_gradient(bbox_target_with_batchsize)
bbox_weight_ = F.stop_gradient(bbox_weight_with_batchsize)
label_ = F.stop_gradient(label_with_batchsize)
label_weight_ = F.stop_gradient(label_weight_with_batchsize)
rpn_cls_score = self.cast(rpn_cls_score, mstype.float32)
if self.use_sigmoid_cls:
label_ = self.cast(label_, mstype.float32)
loss_cls = self.loss_cls(rpn_cls_score, label_)
loss_cls = loss_cls * label_weight_
loss_cls = self.sum_loss(loss_cls, (0,)) / self.num_expected_total
rpn_bbox_pred = self.cast(rpn_bbox_pred, mstype.float32)
bbox_target_ = self.cast(bbox_target_, mstype.float32)
loss_reg = self.loss_bbox(rpn_bbox_pred, bbox_target_)
bbox_weight_ = self.tile(self.reshape(bbox_weight_, (self.feature_anchor_shape, 1)), (1, 4))
loss_reg = loss_reg * bbox_weight_
loss_reg = self.sum_loss(loss_reg, (1,))
loss_reg = self.sum_loss(loss_reg, (0,)) / self.num_expected_total
loss_total = self.rpn_loss_cls_weight * loss_cls + self.rpn_loss_reg_weight * loss_reg
output = (loss_total, rpn_cls_score_ori, rpn_bbox_pred_ori, loss_cls, loss_reg)
else:
output = (self.placeh1, rpn_cls_score_ori, rpn_bbox_pred_ori, self.placeh1, self.placeh1)
return output

@ -0,0 +1,177 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import mindspore.nn as nn
from mindspore.common.tensor import Tensor
from mindspore.ops import operations as P
import mindspore.common.dtype as mstype
def _weight_variable(shape, factor=0.01):
''''weight initialize'''
init_value = np.random.randn(*shape).astype(np.float32) * factor
return Tensor(init_value)
def _BatchNorm2dInit(out_chls, momentum=0.1, affine=True, use_batch_statistics=False):
"""Batchnorm2D wrapper."""
gamma_init = Tensor(np.array(np.ones(out_chls)).astype(np.float32))
beta_init = Tensor(np.array(np.ones(out_chls) * 0).astype(np.float32))
moving_mean_init = Tensor(np.array(np.ones(out_chls) * 0).astype(np.float32))
moving_var_init = Tensor(np.array(np.ones(out_chls)).astype(np.float32))
return nn.BatchNorm2d(out_chls, momentum=momentum, affine=affine, gamma_init=gamma_init,
beta_init=beta_init, moving_mean_init=moving_mean_init,
moving_var_init=moving_var_init, use_batch_statistics=use_batch_statistics)
def _conv(in_channels, out_channels, kernel_size=3, stride=1, padding=0, pad_mode='pad', weights_update=True):
"""Conv2D wrapper."""
weights = 'ones'
layers = []
conv = nn.Conv2d(in_channels, out_channels,
kernel_size=kernel_size, stride=stride, padding=padding,
pad_mode=pad_mode, weight_init=weights, has_bias=False)
if not weights_update:
conv.weight.requires_grad = False
layers += [conv]
layers += [_BatchNorm2dInit(out_channels)]
return nn.SequentialCell(layers)
def _fc(in_channels, out_channels):
'''full connection layer'''
weight = _weight_variable((out_channels, in_channels))
bias = _weight_variable((out_channels,))
return nn.Dense(in_channels, out_channels, weight, bias)
class VGG16FeatureExtraction(nn.Cell):
def __init__(self, weights_update=False):
"""
VGG16 feature extraction
Args:
weights_updata(bool): whether update weights for top two layers, default is False.
"""
super(VGG16FeatureExtraction, self).__init__()
self.relu = nn.ReLU()
self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2, pad_mode="same")
self.avg_pool = nn.AvgPool2d(kernel_size=2, stride=2)
self.conv1_1 = _conv(in_channels=3, out_channels=64, kernel_size=3,\
padding=1, weights_update=weights_update)
self.conv1_2 = _conv(in_channels=64, out_channels=64, kernel_size=3,\
padding=1, weights_update=weights_update)
self.conv2_1 = _conv(in_channels=64, out_channels=128, kernel_size=3,\
padding=1, weights_update=weights_update)
self.conv2_2 = _conv(in_channels=128, out_channels=128, kernel_size=3,\
padding=1, weights_update=weights_update)
self.conv3_1 = _conv(in_channels=128, out_channels=256, kernel_size=3, padding=1)
self.conv3_2 = _conv(in_channels=256, out_channels=256, kernel_size=3, padding=1)
self.conv3_3 = _conv(in_channels=256, out_channels=256, kernel_size=3, padding=1)
self.conv4_1 = _conv(in_channels=256, out_channels=512, kernel_size=3, padding=1)
self.conv4_2 = _conv(in_channels=512, out_channels=512, kernel_size=3, padding=1)
self.conv4_3 = _conv(in_channels=512, out_channels=512, kernel_size=3, padding=1)
self.conv5_1 = _conv(in_channels=512, out_channels=512, kernel_size=3, padding=1)
self.conv5_2 = _conv(in_channels=512, out_channels=512, kernel_size=3, padding=1)
self.conv5_3 = _conv(in_channels=512, out_channels=512, kernel_size=3, padding=1)
self.cast = P.Cast()
def construct(self, x):
"""
:param x: shape=(B, 3, 224, 224)
:return:
"""
x = self.cast(x, mstype.float32)
x = self.conv1_1(x)
x = self.relu(x)
x = self.conv1_2(x)
x = self.relu(x)
x = self.max_pool(x)
x = self.conv2_1(x)
x = self.relu(x)
x = self.conv2_2(x)
x = self.relu(x)
x = self.max_pool(x)
x = self.conv3_1(x)
x = self.relu(x)
x = self.conv3_2(x)
x = self.relu(x)
x = self.conv3_3(x)
x = self.relu(x)
x = self.max_pool(x)
x = self.conv4_1(x)
x = self.relu(x)
x = self.conv4_2(x)
x = self.relu(x)
x = self.conv4_3(x)
x = self.relu(x)
x = self.max_pool(x)
x = self.conv5_1(x)
x = self.relu(x)
x = self.conv5_2(x)
x = self.relu(x)
x = self.conv5_3(x)
x = self.relu(x)
return x
class VGG16Classfier(nn.Cell):
def __init__(self):
"""VGG16 classfier structure"""
super(VGG16Classfier, self).__init__()
self.flatten = P.Flatten()
self.relu = nn.ReLU()
self.fc1 = _fc(in_channels=7*7*512, out_channels=4096)
self.fc2 = _fc(in_channels=4096, out_channels=4096)
self.batch_size = 32
self.reshape = P.Reshape()
def construct(self, x):
"""
:param x: shape=(B, 512, 7, 7)
:return:
"""
x = self.reshape(x, (self.batch_size, 7*7*512))
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
x = self.relu(x)
return x
class VGG16(nn.Cell):
def __init__(self):
"""VGG16 construct for training backbone"""
super(VGG16, self).__init__()
self.feature_extraction = VGG16FeatureExtraction(weights_update=True)
self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2)
self.classifier = VGG16Classfier()
self.fc3 = _fc(in_channels=4096, out_channels=1000)
def construct(self, x):
"""
:param x: shape=(B, 3, 224, 224)
:return: logits, shape=(B, 1000)
"""
feature_maps = self.feature_extraction(x)
x = self.max_pool(feature_maps)
x = self.classifier(x)
x = self.fc3(x)
return x

@ -0,0 +1,133 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Network parameters."""
from easydict import EasyDict
pretrain_config = EasyDict({
# LR
"base_lr": 0.0009,
"warmup_step": 30000,
"warmup_ratio": 1/3.0,
"total_epoch": 100,
})
finetune_config = EasyDict({
# LR
"base_lr": 0.0005,
"warmup_step": 300,
"warmup_ratio": 1/3.0,
"total_epoch": 50,
})
# use for low case number
config = EasyDict({
"img_width": 960,
"img_height": 576,
"keep_ratio": False,
"flip_ratio": 0.0,
"photo_ratio": 0.0,
"expand_ratio": 1.0,
# anchor
"feature_shapes": (36, 60),
"num_anchors": 14,
"anchor_base": 16,
"anchor_height": [2, 4, 7, 11, 16, 23, 33, 48, 68, 97, 139, 198, 283, 406],
"anchor_width": [16],
# rpn
"rpn_in_channels": 256,
"rpn_feat_channels": 512,
"rpn_loss_cls_weight": 1.0,
"rpn_loss_reg_weight": 3.0,
"rpn_cls_out_channels": 2,
# bbox_assign_sampler
"neg_iou_thr": 0.5,
"pos_iou_thr": 0.7,
"min_pos_iou": 0.001,
"num_bboxes": 30240,
"num_gts": 256,
"num_expected_neg": 512,
"num_expected_pos": 256,
#proposal
"activate_num_classes": 2,
"use_sigmoid_cls": False,
# train proposal
"rpn_proposal_nms_across_levels": False,
"rpn_proposal_nms_pre": 2000,
"rpn_proposal_nms_post": 1000,
"rpn_proposal_max_num": 1000,
"rpn_proposal_nms_thr": 0.7,
"rpn_proposal_min_bbox_size": 8,
# rnn structure
"input_size": 512,
"num_step": 60,
"rnn_batch_size": 36,
"hidden_size": 128,
# training
"warmup_mode": "linear",
"batch_size": 1,
"momentum": 0.9,
"save_checkpoint": True,
"save_checkpoint_epochs": 10,
"keep_checkpoint_max": 5,
"save_checkpoint_path": "./",
"use_dropout": False,
"loss_scale": 1,
"weight_decay": 1e-4,
# test proposal
"rpn_nms_pre": 2000,
"rpn_nms_post": 1000,
"rpn_max_num": 1000,
"rpn_nms_thr": 0.7,
"rpn_min_bbox_min_size": 8,
"test_iou_thr": 0.7,
"test_max_per_img": 100,
"test_batch_size": 1,
"use_python_proposal": False,
# text proposal connection
"max_horizontal_gap": 60,
"text_proposals_min_scores": 0.7,
"text_proposals_nms_thresh": 0.2,
"min_v_overlaps": 0.7,
"min_size_sim": 0.7,
"min_ratio": 0.5,
"line_min_score": 0.9,
"text_proposals_width": 16,
"min_num_proposals": 2,
# create dataset
"coco_root": "",
"coco_train_data_type": "",
"cocotext_json": "",
"icdar11_train_path": [],
"icdar13_train_path": [],
"icdar15_train_path": [],
"icdar13_test_path": [],
"flick_train_path": [],
"svt_train_path": [],
"pretrain_dataset_path": "",
"finetune_dataset_path": "",
"test_dataset_path": "",
# training dataset
"pretraining_dataset_file": "",
"finetune_dataset_file": ""
})

@ -0,0 +1,61 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""convert icdar2015 dataset label"""
import os
import argparse
def init_args():
parser = argparse.ArgumentParser('')
parser.add_argument('-s', '--src_label_path', type=str, default='./',
help='Directory containing icdar2015 train label')
parser.add_argument('-t', '--target_label_path', type=str, default='test.xml',
help='Directory where save the icdar2015 label after convert')
return parser.parse_args()
def convert():
args = init_args()
anno_file = os.listdir(args.src_label_path)
annos = {}
# read
for file in anno_file:
gt = open(os.path.join(args.src_label_path, file), 'r', encoding='UTF-8-sig').read().splitlines()
label_list = []
label_name = os.path.basename(file)
for each_label in gt:
print(file)
spt = each_label.split(',')
print(spt)
if "###" in spt[8]:
continue
else:
x1 = min(int(spt[0]), int(spt[6]))
y1 = min(int(spt[1]), int(spt[3]))
x2 = max(int(spt[2]), int(spt[4]))
y2 = max(int(spt[5]), int(spt[7]))
label_list.append([x1, y1, x2, y2])
annos[label_name] = label_list
# write
if not os.path.exists(args.target_label_path):
os.makedirs(args.target_label_path)
for label_file, pos in annos.items():
tgt_anno_file = os.path.join(args.target_label_path, label_file)
f = open(tgt_anno_file, 'w', encoding='UTF-8-sig')
for tgt_label in pos:
str_pos = str(tgt_label[0]) + ',' + str(tgt_label[1]) + ',' + str(tgt_label[2]) + ',' + str(tgt_label[3])
f.write(str_pos)
f.write("\n")
f.close()
if __name__ == "__main__":
convert()

@ -0,0 +1,94 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""convert svt dataset label"""
import os
import argparse
from xml.etree import ElementTree as ET
import numpy as np
def init_args():
parser = argparse.ArgumentParser('')
parser.add_argument('-d', '--dataset_dir', type=str, default='./',
help='Directory containing images')
parser.add_argument('-x', '--xml_file', type=str, default='test.xml',
help='Directory where character dictionaries for the dataset were stored')
parser.add_argument('-o', '--location_dir', type=str, default='./location',
help='Directory where ord map dictionaries for the dataset were stored')
return parser.parse_args()
def xml_to_dict(xml_file, save_file=False):
tree = ET.parse(xml_file)
root = tree.getroot()
imgs_labels = []
for ch in root:
im_label = {}
for ch01 in ch:
if ch01.tag in "address":
continue
elif ch01.tag in 'taggedRectangles':
# multiple children
rect_list = []
for ch02 in ch01:
rect = {}
rect['location'] = ch02.attrib
rect['label'] = ch02[0].text
rect_list.append(rect)
im_label['rect'] = rect_list
else:
im_label[ch01.tag] = ch01.text
imgs_labels.append(im_label)
if save_file:
np.save("annotation_train.npy", imgs_labels)
return imgs_labels
def convert():
args = init_args()
if not os.path.exists(args.dataset_dir):
raise ValueError("dataset_dir :{} does not exist".format(args.dataset_dir))
if not os.path.exists(args.xml_file):
raise ValueError("xml_file :{} does not exist".format(args.xml_file))
if not os.path.exists(args.location_dir):
os.makedirs(args.location_dir)
ims_labels_dict = xml_to_dict(args.xml_file, True)
num_images = len(ims_labels_dict)
print("Converting annotation, {} images in total ".format(num_images))
for i in range(num_images):
img_label = ims_labels_dict[i]
image_name = img_label['imageName']
rects = img_label['rect']
print("processing image: {}".format(image_name))
location_file_name = os.path.join(args.location_dir, os.path.basename(image_name).replace(".jpg", ".txt"))
f = open(location_file_name, 'w')
for j, rect in enumerate(rects):
rect = rects[j]
location = rect['location']
h = int(location['height'])
w = int(location['width'])
x = int(location['x'])
y = int(location['y'])
pos = [x, y, x+w, y+h]
str_pos = str(pos[0]) + "," + str(pos[1]) + "," + str(pos[2]) + "," + str(pos[3])
f.write(str_pos)
f.write("\n")
f.close()
if __name__ == "__main__":
convert()

@ -0,0 +1,177 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from __future__ import division
import os
import numpy as np
from PIL import Image
from mindspore.mindrecord import FileWriter
from src.config import config
def create_coco_label():
"""Create image label."""
image_files = []
image_anno_dict = {}
coco_root = config.coco_root
data_type = config.coco_train_data_type
from src.coco_text import COCO_Text
anno_json = config.cocotext_json
ct = COCO_Text(anno_json)
image_ids = ct.getImgIds(imgIds=ct.train,
catIds=[('legibility', 'legible')])
for img_id in image_ids:
image_info = ct.loadImgs(img_id)[0]
file_name = image_info['file_name'][15:]
anno_ids = ct.getAnnIds(imgIds=img_id)
anno = ct.loadAnns(anno_ids)
image_path = os.path.join(coco_root, data_type, file_name)
annos = []
im = Image.open(image_path)
width, _ = im.size
for label in anno:
bbox = label["bbox"]
bbox_width = int(bbox[2])
if 60 * bbox_width < width:
continue
x1, x2 = int(bbox[0]), int(bbox[0] + bbox[2])
y1, y2 = int(bbox[1]), int(bbox[1] + bbox[3])
annos.append([x1, y1, x2, y2] + [1])
if annos:
image_anno_dict[image_path] = np.array(annos)
image_files.append(image_path)
return image_files, image_anno_dict
def create_anno_dataset_label(train_img_dirs, train_txt_dirs):
image_files = []
image_anno_dict = {}
# read
img_basenames = []
for file in os.listdir(train_img_dirs):
# Filter git file.
if 'gif' not in file:
img_basenames.append(os.path.basename(file))
img_names = []
for item in img_basenames:
temp1, _ = os.path.splitext(item)
img_names.append((temp1, item))
for img, img_basename in img_names:
image_path = train_img_dirs + '/' + img_basename
annos = []
if len(img) == 6 and '_' not in img_basename:
gt = open(train_txt_dirs + '/' + img + '.txt').read().splitlines()
if img.isdigit() and int(img) > 1200:
continue
for img_each_label in gt:
spt = img_each_label.replace(',', '').split(' ')
if ' ' not in img_each_label:
spt = img_each_label.split(',')
annos.append([spt[0], spt[1], str(int(spt[0]) + int(spt[2])), str(int(spt[1]) + int(spt[3]))] + [1])
if annos:
image_anno_dict[image_path] = np.array(annos)
image_files.append(image_path)
return image_files, image_anno_dict
def create_icdar_svt_label(train_img_dir, train_txt_dir, prefix):
image_files = []
image_anno_dict = {}
img_basenames = []
for file_name in os.listdir(train_img_dir):
if 'gif' not in file_name:
img_basenames.append(os.path.basename(file_name))
img_names = []
for item in img_basenames:
temp1, _ = os.path.splitext(item)
img_names.append((temp1, item))
for img, img_basename in img_names:
image_path = train_img_dir + '/' + img_basename
annos = []
file_name = prefix + img + ".txt"
file_path = os.path.join(train_txt_dir, file_name)
gt = open(file_path, 'r', encoding='UTF-8-sig').read().splitlines()
if not gt:
continue
for img_each_label in gt:
spt = img_each_label.replace(',', '').split(' ')
if ' ' not in img_each_label:
spt = img_each_label.split(',')
annos.append([spt[0], spt[1], spt[2], spt[3]] + [1])
if annos:
image_anno_dict[image_path] = np.array(annos)
image_files.append(image_path)
return image_files, image_anno_dict
def create_train_dataset(dataset_type):
image_files = []
image_anno_dict = {}
if dataset_type == "pretraining":
# pretrianing: coco, flick, icdar2013 train, icdar2015, svt
coco_image_files, coco_anno_dict = create_coco_label()
flick_image_files, flick_anno_dict = create_anno_dataset_label(config.flick_train_path[0],
config.flick_train_path[1])
icdar13_image_files, icdar13_anno_dict = create_icdar_svt_label(config.icdar13_train_path[0],
config.icdar13_train_path[1], "gt_img_")
icdar15_image_files, icdar15_anno_dict = create_icdar_svt_label(config.icdar15_train_path[0],
config.icdar15_train_path[1], "gt_")
svt_image_files, svt_anno_dict = create_icdar_svt_label(config.svt_train_path[0], config.svt_train_path[1], "")
image_files = coco_image_files + flick_image_files + icdar13_image_files + icdar15_image_files + svt_image_files
image_anno_dict = {**coco_anno_dict, **flick_anno_dict, \
**icdar13_anno_dict, **icdar15_anno_dict, **svt_anno_dict}
data_to_mindrecord_byte_image(image_files, image_anno_dict, config.pretrain_dataset_path, \
prefix="ctpn_pretrain.mindrecord", file_num=8)
elif dataset_type == "finetune":
# finetune: icdar2011, icdar2013 train, flick
flick_image_files, flick_anno_dict = create_anno_dataset_label(config.flick_train_path[0],
config.flick_train_path[1])
icdar11_image_files, icdar11_anno_dict = create_icdar_svt_label(config.icdar11_train_path[0],
config.icdar11_train_path[1], "gt_")
icdar13_image_files, icdar13_anno_dict = create_icdar_svt_label(config.icdar13_train_path[0],
config.icdar13_train_path[1], "gt_img_")
image_files = flick_image_files + icdar11_image_files + icdar13_image_files
image_anno_dict = {**flick_anno_dict, **icdar11_anno_dict, **icdar13_anno_dict}
data_to_mindrecord_byte_image(image_files, image_anno_dict, config.finetune_dataset_path, \
prefix="ctpn_finetune.mindrecord", file_num=8)
elif dataset_type == "test":
# test: icdar2013 test
icdar_test_image_files, icdar_test_anno_dict = create_icdar_svt_label(config.icdar13_test_path[0],\
config.icdar13_test_path[1], "")
image_files = icdar_test_image_files
image_anno_dict = icdar_test_anno_dict
data_to_mindrecord_byte_image(image_files, image_anno_dict, config.test_dataset_path, \
prefix="ctpn_test.mindrecord", file_num=1)
else:
print("dataset_type should be pretraining, finetune, test")
def data_to_mindrecord_byte_image(image_files, image_anno_dict, dst_dir, prefix="cptn_mlt.mindrecord", file_num=1):
"""Create MindRecord file."""
mindrecord_path = os.path.join(dst_dir, prefix)
writer = FileWriter(mindrecord_path, file_num)
ctpn_json = {
"image": {"type": "bytes"},
"annotation": {"type": "int32", "shape": [-1, 5]},
}
writer.add_schema(ctpn_json, "ctpn_json")
for image_name in image_files:
with open(image_name, 'rb') as f:
img = f.read()
annos = np.array(image_anno_dict[image_name], dtype=np.int32)
print("img name is {}, anno is {}".format(image_name, annos))
row = {"image": img, "annotation": annos}
writer.write_raw_data([row])
writer.commit()
if __name__ == "__main__":
create_train_dataset("pretraining")
create_train_dataset("finetune")
create_train_dataset("test")

@ -0,0 +1,148 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""CPTN network definition."""
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, Parameter
from mindspore.common import dtype as mstype
from mindspore.ops import operations as P
from src.CTPN.rpn import RPN
from src.CTPN.anchor_generator import AnchorGenerator
from src.CTPN.proposal_generator import Proposal
from src.CTPN.vgg16 import VGG16FeatureExtraction
class BiLSTM(nn.Cell):
"""
Define a BiLSTM network which contains two LSTM layers
Args:
input_size(int): Size of time sequence. Usually, the input_size is equal to three times of image height for
captcha images.
batch_size(int): batch size of input data, default is 64
hidden_size(int): the hidden size in LSTM layers, default is 512
"""
def __init__(self, config, is_training=True):
super(BiLSTM, self).__init__()
self.is_training = is_training
self.batch_size = config.batch_size * config.rnn_batch_size
print("batch size is {} ".format(self.batch_size))
self.input_size = config.input_size
self.hidden_size = config.hidden_size
self.num_step = config.num_step
self.reshape = P.Reshape()
self.cast = P.Cast()
k = (1 / self.hidden_size) ** 0.5
self.rnn1 = P.DynamicRNN(forget_bias=0.0)
self.rnn_bw = P.DynamicRNN(forget_bias=0.0)
self.w1 = Parameter(np.random.uniform(-k, k, \
(self.input_size + self.hidden_size, 4 * self.hidden_size)).astype(np.float32), name="w1")
self.w1_bw = Parameter(np.random.uniform(-k, k, \
(self.input_size + self.hidden_size, 4 * self.hidden_size)).astype(np.float32), name="w1_bw")
self.b1 = Parameter(np.random.uniform(-k, k, (4 * self.hidden_size)).astype(np.float32), name="b1")
self.b1_bw = Parameter(np.random.uniform(-k, k, (4 * self.hidden_size)).astype(np.float32), name="b1_bw")
self.h1 = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float32))
self.h1_bw = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float32))
self.c1 = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float32))
self.c1_bw = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float32))
self.reverse_seq = P.ReverseV2(axis=[0])
self.concat = P.Concat()
self.transpose = P.Transpose()
self.concat1 = P.Concat(axis=2)
self.dropout = nn.Dropout(0.7)
self.use_dropout = config.use_dropout
self.reshape = P.Reshape()
self.transpose = P.Transpose()
def construct(self, x):
if self.use_dropout:
x = self.dropout(x)
x = self.cast(x, mstype.float16)
bw_x = self.reverse_seq(x)
y1, _, _, _, _, _, _, _ = self.rnn1(x, self.w1, self.b1, None, self.h1, self.c1)
y1_bw, _, _, _, _, _, _, _ = self.rnn_bw(bw_x, self.w1_bw, self.b1_bw, None, self.h1_bw, self.c1_bw)
y1_bw = self.reverse_seq(y1_bw)
output = self.concat1((y1, y1_bw))
return output
class CTPN(nn.Cell):
"""
Define CTPN network
Args:
input_size(int): Size of time sequence. Usually, the input_size is equal to three times of image height for
captcha images.
batch_size(int): batch size of input data, default is 64
hidden_size(int): the hidden size in LSTM layers, default is 512
"""
def __init__(self, config, is_training=True):
super(CTPN, self).__init__()
self.config = config
self.is_training = is_training
self.num_step = config.num_step
self.input_size = config.input_size
self.batch_size = config.batch_size
self.hidden_size = config.hidden_size
self.vgg16_feature_extractor = VGG16FeatureExtraction()
self.conv = nn.Conv2d(512, 512, kernel_size=3, padding=0, pad_mode='same')
self.rnn = BiLSTM(self.config, is_training=self.is_training).to_float(mstype.float16)
self.reshape = P.Reshape()
self.transpose = P.Transpose()
self.cast = P.Cast()
# rpn block
self.rpn_with_loss = RPN(config,
self.batch_size,
config.rpn_in_channels,
config.rpn_feat_channels,
config.num_anchors,
config.rpn_cls_out_channels)
self.anchor_generator = AnchorGenerator(config)
self.featmap_size = config.feature_shapes
self.anchor_list = self.get_anchors(self.featmap_size)
self.proposal_generator_test = Proposal(config,
config.test_batch_size,
config.activate_num_classes,
config.use_sigmoid_cls)
self.proposal_generator_test.set_train_local(config, False)
def construct(self, img_data, img_metas, gt_bboxes, gt_labels, gt_valids):
# (1,3,600,900)
x = self.vgg16_feature_extractor(img_data)
x = self.conv(x)
x = self.cast(x, mstype.float16)
# (1, 512, 38, 57)
x = self.transpose(x, (0, 2, 1, 3))
x = self.reshape(x, (-1, self.input_size, self.num_step))
x = self.transpose(x, (2, 0, 1))
# (57, 38, 512)
x = self.rnn(x)
# (57, 38, 256)
#x = self.cast(x, mstype.float32)
rpn_loss, cls_score, bbox_pred, rpn_cls_loss, rpn_reg_loss = self.rpn_with_loss(x,
img_metas,
self.anchor_list,
gt_bboxes,
gt_labels,
gt_valids)
if self.training:
return rpn_loss, cls_score, bbox_pred, rpn_cls_loss, rpn_reg_loss
proposal, proposal_mask = self.proposal_generator_test(cls_score, bbox_pred, self.anchor_list)
return proposal, proposal_mask
def get_anchors(self, featmap_size):
anchors = self.anchor_generator.grid_anchors(featmap_size)
return Tensor(anchors, mstype.float16)

File diff suppressed because it is too large Load Diff

@ -0,0 +1,39 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""lr generator for deeptext"""
import math
def linear_warmup_learning_rate(current_step, warmup_steps, base_lr, init_lr):
lr_inc = (float(base_lr) - float(init_lr)) / float(warmup_steps)
learning_rate = float(init_lr) + lr_inc * current_step
return learning_rate
def a_cosine_learning_rate(current_step, base_lr, warmup_steps, decay_steps):
base = float(current_step - warmup_steps) / float(decay_steps)
learning_rate = (1 + math.cos(base * math.pi)) / 2 * base_lr
return learning_rate
def dynamic_lr(config, base_step):
"""dynamic learning rate generator"""
base_lr = config.base_lr
total_steps = int(base_step * config.total_epoch)
warmup_steps = config.warmup_step
lr = []
for i in range(total_steps):
if i < warmup_steps:
lr.append(linear_warmup_learning_rate(i, warmup_steps, base_lr, base_lr * config.warmup_ratio))
else:
lr.append(a_cosine_learning_rate(i, base_lr, warmup_steps, total_steps))
return lr

@ -0,0 +1,153 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""FasterRcnn training network wrapper."""
import time
import numpy as np
import mindspore.nn as nn
from mindspore.common.tensor import Tensor
from mindspore.ops import functional as F
from mindspore.ops import composite as C
from mindspore import ParameterTuple
from mindspore.train.callback import Callback
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
time_stamp_init = False
time_stamp_first = 0
class LossCallBack(Callback):
"""
Monitor the loss in training.
If the loss is NAN or INF terminating training.
Note:
If per_print_times is 0 do not print loss.
Args:
per_print_times (int): Print loss every times. Default: 1.
"""
def __init__(self, per_print_times=1, rank_id=0):
super(LossCallBack, self).__init__()
if not isinstance(per_print_times, int) or per_print_times < 0:
raise ValueError("print_step must be int and >= 0.")
self._per_print_times = per_print_times
self.count = 0
self.rpn_loss_sum = 0
self.rpn_cls_loss_sum = 0
self.rpn_reg_loss_sum = 0
self.rank_id = rank_id
global time_stamp_init, time_stamp_first
if not time_stamp_init:
time_stamp_first = time.time()
time_stamp_init = True
def step_end(self, run_context):
cb_params = run_context.original_args()
rpn_loss = cb_params.net_outputs[0].asnumpy()
rpn_cls_loss = cb_params.net_outputs[1].asnumpy()
rpn_reg_loss = cb_params.net_outputs[2].asnumpy()
self.count += 1
self.rpn_loss_sum += float(rpn_loss)
self.rpn_cls_loss_sum += float(rpn_cls_loss)
self.rpn_reg_loss_sum += float(rpn_reg_loss)
cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1
if self.count >= 1:
global time_stamp_first
time_stamp_current = time.time()
rpn_loss = self.rpn_loss_sum / self.count
rpn_cls_loss = self.rpn_cls_loss_sum / self.count
rpn_reg_loss = self.rpn_reg_loss_sum / self.count
loss_file = open("./loss_{}.log".format(self.rank_id), "a+")
loss_file.write("%lu epoch: %s step: %s ,rpn_loss: %.5f, rpn_cls_loss: %.5f, rpn_reg_loss: %.5f"%
(time_stamp_current - time_stamp_first, cb_params.cur_epoch_num, cur_step_in_epoch,
rpn_loss, rpn_cls_loss, rpn_reg_loss))
loss_file.write("\n")
loss_file.close()
class LossNet(nn.Cell):
"""FasterRcnn loss method"""
def construct(self, x1, x2, x3):
return x1
class WithLossCell(nn.Cell):
"""
Wrap the network with loss function to compute loss.
Args:
backbone (Cell): The target network to wrap.
loss_fn (Cell): The loss function used to compute loss.
"""
def __init__(self, backbone, loss_fn):
super(WithLossCell, self).__init__(auto_prefix=False)
self._backbone = backbone
self._loss_fn = loss_fn
def construct(self, x, img_shape, gt_bboxe, gt_label, gt_num):
rpn_loss, _, _, rpn_cls_loss, rpn_reg_loss = self._backbone(x, img_shape, gt_bboxe, gt_label, gt_num)
return self._loss_fn(rpn_loss, rpn_cls_loss, rpn_reg_loss)
@property
def backbone_network(self):
"""
Get the backbone network.
Returns:
Cell, return backbone network.
"""
return self._backbone
class TrainOneStepCell(nn.Cell):
"""
Network training package class.
Append an optimizer to the training network after that the construct function
can be called to create the backward graph.
Args:
network (Cell): The training network.
network_backbone (Cell): The forward network.
optimizer (Cell): Optimizer for updating the weights.
sens (Number): The adjust parameter. Default value is 1.0.
reduce_flag (bool): The reduce flag. Default value is False.
mean (bool): Allreduce method. Default value is False.
degree (int): Device number. Default value is None.
"""
def __init__(self, network, network_backbone, optimizer, sens=1.0, reduce_flag=False, mean=True, degree=None):
super(TrainOneStepCell, self).__init__(auto_prefix=False)
self.network = network
self.network.set_grad()
self.backbone = network_backbone
self.weights = ParameterTuple(network.trainable_params())
self.optimizer = optimizer
self.grad = C.GradOperation(get_by_list=True,
sens_param=True)
self.sens = Tensor((np.ones((1,)) * sens).astype(np.float32))
self.reduce_flag = reduce_flag
if reduce_flag:
self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
def construct(self, x, img_shape, gt_bboxe, gt_label, gt_num):
weights = self.weights
rpn_loss, _, _, rpn_cls_loss, rpn_reg_loss = self.backbone(x, img_shape, gt_bboxe, gt_label, gt_num)
grads = self.grad(self.network, weights)(x, img_shape, gt_bboxe, gt_label, gt_num, self.sens)
if self.reduce_flag:
grads = self.grad_reducer(grads)
return F.depend(rpn_loss, self.optimizer(grads)), rpn_cls_loss, rpn_reg_loss

@ -0,0 +1,65 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================import numpy as np
import numpy as np
from src.text_connector.utils import clip_boxes, fit_y
from src.text_connector.get_successions import get_successions
def connect_text_lines(text_proposals, scores, size):
"""
Connect text lines
Args:
text_proposals(numpy.array): Predict text proposals.
scores(numpy.array): Bbox predicts scores.
size(numpy.array): Image size.
Returns:
text_recs(numpy.array): Text boxes after connect.
"""
graph = get_successions(text_proposals, scores, size)
text_lines = np.zeros((len(graph), 5), np.float32)
for index, indices in enumerate(graph):
text_line_boxes = text_proposals[list(indices)]
x0 = np.min(text_line_boxes[:, 0])
x1 = np.max(text_line_boxes[:, 2])
offset = (text_line_boxes[0, 2] - text_line_boxes[0, 0]) * 0.5
lt_y, rt_y = fit_y(text_line_boxes[:, 0], text_line_boxes[:, 1], x0 + offset, x1 - offset)
lb_y, rb_y = fit_y(text_line_boxes[:, 0], text_line_boxes[:, 3], x0 + offset, x1 - offset)
# the score of a text line is the average score of the scores
# of all text proposals contained in the text line
score = scores[list(indices)].sum() / float(len(indices))
text_lines[index, 0] = x0
text_lines[index, 1] = min(lt_y, rt_y)
text_lines[index, 2] = x1
text_lines[index, 3] = max(lb_y, rb_y)
text_lines[index, 4] = score
text_lines = clip_boxes(text_lines, size)
text_recs = np.zeros((len(text_lines), 9), np.float)
index = 0
for line in text_lines:
xmin, ymin, xmax, ymax = line[0], line[1], line[2], line[3]
text_recs[index, 0] = xmin
text_recs[index, 1] = ymin
text_recs[index, 2] = xmax
text_recs[index, 3] = ymax
text_recs[index, 4] = line[4]
index = index + 1
return text_recs

@ -0,0 +1,73 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
from src.config import config
from src.text_connector.utils import nms
from src.text_connector.connect_text_lines import connect_text_lines
def filter_proposal(proposals, scores):
"""
Filter text proposals
Args:
proposals(numpy.array): Text proposals.
Returns:
proposals(numpy.array): Text proposals after filter.
"""
inds = np.where(scores > config.text_proposals_min_scores)[0]
keep_proposals = proposals[inds]
keep_scores = scores[inds]
sorted_inds = np.argsort(keep_scores.ravel())[::-1]
keep_proposals, keep_scores = keep_proposals[sorted_inds], keep_scores[sorted_inds]
nms_inds = nms(np.hstack((keep_proposals, keep_scores)), config.text_proposals_nms_thresh)
keep_proposals, keep_scores = keep_proposals[nms_inds], keep_scores[nms_inds]
return keep_proposals, keep_scores
def filter_boxes(boxes):
"""
Filter text boxes
Args:
boxes(numpy.array): Text boxes.
Returns:
boxes(numpy.array): Text boxes after filter.
"""
heights = np.zeros((len(boxes), 1), np.float)
widths = np.zeros((len(boxes), 1), np.float)
scores = np.zeros((len(boxes), 1), np.float)
index = 0
for box in boxes:
widths[index] = abs(box[2] - box[0])
heights[index] = abs(box[3] - box[1])
scores[index] = abs(box[4])
index += 1
return np.where((widths / heights > config.min_ratio) & (scores > config.line_min_score) &\
(widths > (config.text_proposals_width * config.min_num_proposals)))[0]
def detect(text_proposals, scores, size):
"""
Detect text boxes
Args:
text_proposals(numpy.array): Predict text proposals.
scores(numpy.array): Bbox predicts scores.
size(numpy.array): Image size.
Returns:
boxes(numpy.array): Text boxes after connect.
"""
keep_proposals, keep_scores = filter_proposal(text_proposals, scores)
connect_boxes = connect_text_lines(keep_proposals, keep_scores, size)
boxes = connect_boxes[filter_boxes(connect_boxes)]
return boxes

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save