!9023 add maskrcnn_mobilenetv1 network

From: @wang_hua_2019
Reviewed-by: 
Signed-off-by:
pull/9023/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 542b1ab76c

File diff suppressed because it is too large Load Diff

@ -0,0 +1,134 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# less required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Evaluation for MaskRcnn"""
import os
import argparse
import time
import numpy as np
from pycocotools.coco import COCO
from mindspore import context, Tensor
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.common import set_seed
from src.maskrcnn_mobilenetv1.mask_rcnn_mobilenetv1 import Mask_Rcnn_Mobilenetv1
from src.config import config
from src.dataset import data_to_mindrecord_byte_image, create_maskrcnn_dataset
from src.util import coco_eval, bbox2result_1image, results2json, get_seg_masks
set_seed(1)
parser = argparse.ArgumentParser(description="MaskRcnn evaluation")
parser.add_argument("--dataset", type=str, default="coco", help="Dataset, default is coco.")
parser.add_argument("--ann_file", type=str, default="val.json", help="Ann file, default is val.json.")
parser.add_argument("--checkpoint_path", type=str, required=True, help="Checkpoint file path.")
parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
args_opt = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id)
def MaskRcnn_eval(dataset_path, ckpt_path, ann_file):
"""MaskRcnn evaluation."""
ds = create_maskrcnn_dataset(dataset_path, batch_size=config.test_batch_size, is_training=False)
net = Mask_Rcnn_Mobilenetv1(config)
param_dict = load_checkpoint(ckpt_path)
load_param_into_net(net, param_dict)
net.set_train(False)
eval_iter = 0
total = ds.get_dataset_size()
outputs = []
dataset_coco = COCO(ann_file)
print("\n========================================\n")
print("total images num: ", total)
print("Processing, please wait a moment.")
max_num = 128
for data in ds.create_dict_iterator(output_numpy=True, num_epochs=1):
img_data = data['image']
img_metas = data['image_shape']
gt_bboxes = data['box']
gt_labels = data['label']
gt_num = data['valid_num']
gt_mask = data["mask"]
start = time.time()
# run net
output = net(Tensor(img_data), Tensor(img_metas), Tensor(gt_bboxes), Tensor(gt_labels), Tensor(gt_num),
Tensor(gt_mask))
end = time.time()
print("Iter {} cost time {}".format(eval_iter, end - start))
# output
all_bbox = output[0]
all_label = output[1]
all_mask = output[2]
all_mask_fb = output[3]
for j in range(config.test_batch_size):
all_bbox_squee = np.squeeze(all_bbox.asnumpy()[j, :, :])
all_label_squee = np.squeeze(all_label.asnumpy()[j, :, :])
all_mask_squee = np.squeeze(all_mask.asnumpy()[j, :, :])
all_mask_fb_squee = np.squeeze(all_mask_fb.asnumpy()[j, :, :, :])
all_bboxes_tmp_mask = all_bbox_squee[all_mask_squee, :]
all_labels_tmp_mask = all_label_squee[all_mask_squee]
all_mask_fb_tmp_mask = all_mask_fb_squee[all_mask_squee, :, :]
if all_bboxes_tmp_mask.shape[0] > max_num:
inds = np.argsort(-all_bboxes_tmp_mask[:, -1])
inds = inds[:max_num]
all_bboxes_tmp_mask = all_bboxes_tmp_mask[inds]
all_labels_tmp_mask = all_labels_tmp_mask[inds]
all_mask_fb_tmp_mask = all_mask_fb_tmp_mask[inds]
bbox_results = bbox2result_1image(all_bboxes_tmp_mask, all_labels_tmp_mask, config.num_classes)
segm_results = get_seg_masks(all_mask_fb_tmp_mask, all_bboxes_tmp_mask, all_labels_tmp_mask, img_metas[j],
True, config.num_classes)
outputs.append((bbox_results, segm_results))
eval_iter = eval_iter + 1
eval_types = ["bbox", "segm"]
result_files = results2json(dataset_coco, outputs, "./results.pkl")
coco_eval(result_files, eval_types, dataset_coco, single_result=False)
if __name__ == '__main__':
prefix = "MaskRcnn_eval.mindrecord"
mindrecord_dir = config.mindrecord_dir
mindrecord_file = os.path.join(mindrecord_dir, prefix)
if not os.path.exists(mindrecord_file):
if not os.path.isdir(mindrecord_dir):
os.makedirs(mindrecord_dir)
if args_opt.dataset == "coco":
if os.path.isdir(config.coco_root):
print("Create Mindrecord.")
data_to_mindrecord_byte_image("coco", False, prefix, file_num=1)
print("Create Mindrecord Done, at {}".format(mindrecord_dir))
else:
print("coco_root not exits.")
else:
if os.path.isdir(config.IMAGE_DIR) and os.path.exists(config.ANNO_PATH):
print("Create Mindrecord.")
data_to_mindrecord_byte_image("other", False, prefix, file_num=1)
print("Create Mindrecord Done, at {}".format(mindrecord_dir))
else:
print("IMAGE_DIR or ANNO_PATH not exits.")
print("Start Eval!")
MaskRcnn_eval(mindrecord_file, args_opt.checkpoint_path, args_opt.ann_file)

@ -0,0 +1,22 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""hub config."""
from src.maskrcnn_mobilenetv1.mask_rcnn_mobilenetv1 import Mask_Rcnn_Mobilenetv1
from src.config import config
def create_network(name, *args, **kwargs):
if name == "maskrcnn_mobilenetv1":
return Mask_Rcnn_Mobilenetv1(config=config)
raise NotImplementedError(f"{name} is not implemented in the repo")

@ -0,0 +1,73 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 2 ]
then
echo "Usage: sh run_train.sh [RANK_TABLE_FILE] [PRETRAINED_PATH]"
exit 1
fi
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
PATH1=$(get_real_path $1)
PATH2=$2
echo $PATH1
echo $PATH2
if [ ! -f $PATH1 ]
then
echo "error: RANK_TABLE_FILE=$PATH1 is not a file"
exit 1
fi
ulimit -u unlimited
export HCCL_CONNECT_TIMEOUT=600
export DEVICE_NUM=8
export RANK_SIZE=8
export RANK_TABLE_FILE=$PATH1
echo 3 > /proc/sys/vm/drop_caches
cpus=`cat /proc/cpuinfo| grep "processor"| wc -l`
avg=`expr $cpus \/ $RANK_SIZE`
gap=`expr $avg \- 1`
for((i=0; i<${DEVICE_NUM}; i++))
do
start=`expr $i \* $avg`
end=`expr $start \+ $gap`
cmdopt=$start"-"$end
export DEVICE_ID=$i
export RANK_ID=$i
rm -rf ./train_parallel$i
mkdir ./train_parallel$i
cp ../*.py ./train_parallel$i
cp *.sh ./train_parallel$i
cp -r ../src ./train_parallel$i
cd ./train_parallel$i || exit
echo "start training for rank $RANK_ID, device $DEVICE_ID"
env > env.log
taskset -c $cmdopt python train.py --do_train=True --device_id=$i --rank_id=$i --run_distribute=True --device_num=$DEVICE_NUM \
--pre_trained=$PATH2 &> log &
cd ..
done

@ -0,0 +1,65 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 2 ]
then
echo "Usage: sh run_eval.sh [ANN_FILE] [CHECKPOINT_PATH]"
exit 1
fi
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
PATH1=$(get_real_path $1)
PATH2=$(get_real_path $2)
echo $PATH1
echo $PATH2
if [ ! -f $PATH1 ]
then
echo "error: ANN_FILE=$PATH1 is not a file"
exit 1
fi
if [ ! -f $PATH2 ]
then
echo "error: CHECKPOINT_PATH=$PATH2 is not a file"
exit 1
fi
ulimit -u unlimited
export DEVICE_NUM=1
export RANK_SIZE=$DEVICE_NUM
export DEVICE_ID=0
export RANK_ID=0
if [ -d "eval" ];
then
rm -rf ./eval
fi
mkdir ./eval
cp ../*.py ./eval
cp *.sh ./eval
cp -r ../src ./eval
cd ./eval || exit
env > env.log
echo "start eval for device $DEVICE_ID"
python eval.py --device_id=$DEVICE_ID --ann_file=$PATH1 --checkpoint_path=$PATH2 &> log &
cd ..

@ -0,0 +1,57 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 1 ]
then
echo "Usage: sh run_standalone_train.sh [PRETRAINED_PATH]"
exit 1
fi
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
PATH1=$(get_real_path $1)
echo $PATH1
if [ ! -f $PATH1 ]
then
echo "error: PRETRAINED_PATH=$PATH1 is not a file"
exit 1
fi
ulimit -u unlimited
export DEVICE_NUM=1
export DEVICE_ID=0
export RANK_ID=0
export RANK_SIZE=1
if [ -d "train" ];
then
rm -rf ./train
fi
mkdir ./train
cp ../*.py ./train
cp *.sh ./train
cp -r ../src ./train
cd ./train || exit
echo "start training for device $DEVICE_ID"
env > env.log
python train.py --do_train=True --device_id=$DEVICE_ID --pre_trained=$PATH1 &> log &
cd ..

@ -0,0 +1,159 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#" :===========================================================================
"""
network config setting, will be used in train.py and eval.py
"""
from easydict import EasyDict as ed
config = ed({
"img_width": 1280,
"img_height": 768,
"keep_ratio": True,
"flip_ratio": 0.5,
"photo_ratio": 0.5,
"expand_ratio": 1.0,
"max_instance_count": 128,
"mask_shape": (28, 28),
# anchor
"feature_shapes": [(192, 320), (96, 160), (48, 80), (24, 40), (12, 20)],
"anchor_scales": [8],
"anchor_ratios": [0.5, 1.0, 2.0],
"anchor_strides": [4, 8, 16, 32, 64],
"num_anchors": 3,
# fpn
"fpn_in_channels": [128, 256, 512, 1024],
"fpn_out_channels": 256,
"fpn_num_outs": 5,
# rpn
"rpn_in_channels": 256,
"rpn_feat_channels": 256,
"rpn_loss_cls_weight": 1.0,
"rpn_loss_reg_weight": 1.0,
"rpn_cls_out_channels": 1,
"rpn_target_means": [0., 0., 0., 0.],
"rpn_target_stds": [1.0, 1.0, 1.0, 1.0],
# bbox_assign_sampler
"neg_iou_thr": 0.3,
"pos_iou_thr": 0.7,
"min_pos_iou": 0.3,
"num_bboxes": 245520,
"num_gts": 128,
"num_expected_neg": 256,
"num_expected_pos": 128,
# proposal
"activate_num_classes": 2,
"use_sigmoid_cls": True,
# roi_align
"roi_layer": dict(type='RoIAlign', out_size=7, mask_out_size=14, sample_num=2),
"roi_align_out_channels": 256,
"roi_align_featmap_strides": [4, 8, 16, 32],
"roi_align_finest_scale": 56,
"roi_sample_num": 640,
# bbox_assign_sampler_stage2
"neg_iou_thr_stage2": 0.5,
"pos_iou_thr_stage2": 0.5,
"min_pos_iou_stage2": 0.5,
"num_bboxes_stage2": 2000,
"num_expected_pos_stage2": 128,
"num_expected_neg_stage2": 512,
"num_expected_total_stage2": 512,
# rcnn
"rcnn_num_layers": 2,
"rcnn_in_channels": 256,
"rcnn_fc_out_channels": 1024,
"rcnn_mask_out_channels": 256,
"rcnn_loss_cls_weight": 1,
"rcnn_loss_reg_weight": 1,
"rcnn_loss_mask_fb_weight": 1,
"rcnn_target_means": [0., 0., 0., 0.],
"rcnn_target_stds": [0.1, 0.1, 0.2, 0.2],
# train proposal
"rpn_proposal_nms_across_levels": False,
"rpn_proposal_nms_pre": 2000,
"rpn_proposal_nms_post": 2000,
"rpn_proposal_max_num": 2000,
"rpn_proposal_nms_thr": 0.7,
"rpn_proposal_min_bbox_size": 0,
# test proposal
"rpn_nms_across_levels": False,
"rpn_nms_pre": 1000,
"rpn_nms_post": 1000,
"rpn_max_num": 1000,
"rpn_nms_thr": 0.7,
"rpn_min_bbox_min_size": 0,
"test_score_thr": 0.05,
"test_iou_thr": 0.5,
"test_max_per_img": 100,
"test_batch_size": 2,
"rpn_head_loss_type": "CrossEntropyLoss",
"rpn_head_use_sigmoid": True,
"rpn_head_weight": 1.0,
"mask_thr_binary": 0.5,
# LR
"base_lr": 0.04,
"base_step": 58633,
"total_epoch": 13,
"warmup_step": 500,
"warmup_mode": "linear",
"warmup_ratio": 1/3.0,
"sgd_momentum": 0.9,
# train
"batch_size": 2,
"loss_scale": 1,
"momentum": 0.91,
"weight_decay": 1e-4,
"pretrain_epoch_size": 0,
"epoch_size": 12,
"save_checkpoint": True,
"save_checkpoint_epochs": 1,
"keep_checkpoint_max": 12,
"save_checkpoint_path": "./",
"mindrecord_dir": "/home/mask_rcnn/MindRecord_COCO2017_Train",
"coco_root": "/home/mask_rcnn/coco2017/",
"train_data_type": "train2017",
"val_data_type": "val2017",
"instance_set": "annotations/instances_{}.json",
"coco_classes": ('background', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
'train', 'truck', 'boat', 'traffic light', 'fire hydrant',
'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog',
'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra',
'giraffe', 'backpack', 'umbrella', 'handbag', 'tie',
'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
'kite', 'baseball bat', 'baseball glove', 'skateboard',
'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup',
'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed',
'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote',
'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink',
'refrigerator', 'book', 'clock', 'vase', 'scissors',
'teddy bear', 'hair drier', 'toothbrush'),
"num_classes": 81
})

File diff suppressed because it is too large Load Diff

@ -0,0 +1,42 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""lr generator for maskrcnn_mobilenetv1"""
import math
def linear_warmup_learning_rate(current_step, warmup_steps, base_lr, init_lr):
lr_inc = (float(base_lr) - float(init_lr)) / float(warmup_steps)
learning_rate = float(init_lr) + lr_inc * current_step
return learning_rate
def a_cosine_learning_rate(current_step, base_lr, warmup_steps, decay_steps):
base = float(current_step - warmup_steps) / float(decay_steps)
learning_rate = (1 + math.cos(base * math.pi)) / 2 * base_lr
return learning_rate
def dynamic_lr(config, rank_size=1, start_steps=0):
"""dynamic learning rate generator"""
base_lr = config.base_lr
base_step = (config.base_step // rank_size) + rank_size
total_steps = int(base_step * config.total_epoch)
warmup_steps = int(config.warmup_step)
lr = []
for i in range(total_steps):
if i < warmup_steps:
lr.append(linear_warmup_learning_rate(i, warmup_steps, base_lr, base_lr * config.warmup_ratio))
else:
lr.append(a_cosine_learning_rate(i, base_lr, warmup_steps, total_steps))
learning_rate = lr[start_steps:]
return learning_rate

@ -0,0 +1,32 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""MaskRcnn Init."""
from .mobilenetv1 import MobileNetV1, MobileNetV1_FeatureSelector
from .bbox_assign_sample import BboxAssignSample
from .bbox_assign_sample_stage2 import BboxAssignSampleForRcnn
from .fpn_neck import FeatPyramidNeck
from .proposal_generator import Proposal
from .rcnn_cls import RcnnCls
from .rcnn_mask import RcnnMask
from .rpn import RPN
from .roi_align import SingleRoIExtractor
from .anchor_generator import AnchorGenerator
__all__ = [
"MobileNetV1", "MobileNetV1_FeatureSelector", "BboxAssignSample", "BboxAssignSampleForRcnn",
"FeatPyramidNeck", "Proposal", "RcnnCls", "RcnnMask",
"RPN", "SingleRoIExtractor", "AnchorGenerator"
]

@ -0,0 +1,84 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""MaskRcnn anchor generator."""
import numpy as np
class AnchorGenerator():
"""Anchor generator for MasKRcnn."""
def __init__(self, base_size, scales, ratios, scale_major=True, ctr=None):
"""Anchor generator init method."""
self.base_size = base_size
self.scales = np.array(scales)
self.ratios = np.array(ratios)
self.scale_major = scale_major
self.ctr = ctr
self.base_anchors = self.gen_base_anchors()
def gen_base_anchors(self):
"""Generate a single anchor."""
w = self.base_size
h = self.base_size
if self.ctr is None:
x_ctr = 0.5 * (w - 1)
y_ctr = 0.5 * (h - 1)
else:
x_ctr, y_ctr = self.ctr
h_ratios = np.sqrt(self.ratios)
w_ratios = 1 / h_ratios
if self.scale_major:
ws = (w * w_ratios[:, None] * self.scales[None, :]).reshape(-1)
hs = (h * h_ratios[:, None] * self.scales[None, :]).reshape(-1)
else:
ws = (w * self.scales[:, None] * w_ratios[None, :]).reshape(-1)
hs = (h * self.scales[:, None] * h_ratios[None, :]).reshape(-1)
base_anchors = np.stack(
[
x_ctr - 0.5 * (ws - 1), y_ctr - 0.5 * (hs - 1),
x_ctr + 0.5 * (ws - 1), y_ctr + 0.5 * (hs - 1)
],
axis=-1).round()
return base_anchors
def _meshgrid(self, x, y, row_major=True):
"""Generate grid."""
xx = np.repeat(x.reshape(1, len(x)), len(y), axis=0).reshape(-1)
yy = np.repeat(y, len(x))
if row_major:
return xx, yy
return yy, xx
def grid_anchors(self, featmap_size, stride=16):
"""Generate anchor list."""
base_anchors = self.base_anchors
feat_h, feat_w = featmap_size
shift_x = np.arange(0, feat_w) * stride
shift_y = np.arange(0, feat_h) * stride
shift_xx, shift_yy = self._meshgrid(shift_x, shift_y)
shifts = np.stack([shift_xx, shift_yy, shift_xx, shift_yy], axis=-1)
shifts = shifts.astype(base_anchors.dtype)
# first feat_w elements correspond to the first row of shifts
# add A anchors (1, A, 4) to K shifts (K, 1, 4) to get
# shifted anchors (K, A, 4), reshape to (K*A, 4)
all_anchors = base_anchors[None, :, :] + shifts[:, None, :]
all_anchors = all_anchors.reshape(-1, 4)
return all_anchors

@ -0,0 +1,164 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""MaskRcnn positive and negative sample screening for RPN."""
import numpy as np
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.common.tensor import Tensor
import mindspore.common.dtype as mstype
class BboxAssignSample(nn.Cell):
"""
Bbox assigner and sampler defination.
Args:
config (dict): Config.
batch_size (int): Batchsize.
num_bboxes (int): The anchor nums.
add_gt_as_proposals (bool): add gt bboxes as proposals flag.
Returns:
Tensor, output tensor.
bbox_targets: bbox location, (batch_size, num_bboxes, 4)
bbox_weights: bbox weights, (batch_size, num_bboxes, 1)
labels: label for every bboxes, (batch_size, num_bboxes, 1)
label_weights: label weight for every bboxes, (batch_size, num_bboxes, 1)
Examples:
BboxAssignSample(config, 2, 1024, True)
"""
def __init__(self, config, batch_size, num_bboxes, add_gt_as_proposals):
super(BboxAssignSample, self).__init__()
cfg = config
self.batch_size = batch_size
self.neg_iou_thr = Tensor(cfg.neg_iou_thr, mstype.float16)
self.pos_iou_thr = Tensor(cfg.pos_iou_thr, mstype.float16)
self.min_pos_iou = Tensor(cfg.min_pos_iou, mstype.float16)
self.zero_thr = Tensor(0.0, mstype.float16)
self.num_bboxes = num_bboxes
self.num_gts = cfg.num_gts
self.num_expected_pos = cfg.num_expected_pos
self.num_expected_neg = cfg.num_expected_neg
self.add_gt_as_proposals = add_gt_as_proposals
if self.add_gt_as_proposals:
self.label_inds = Tensor(np.arange(1, self.num_gts + 1))
self.concat = P.Concat(axis=0)
self.max_gt = P.ArgMaxWithValue(axis=0)
self.max_anchor = P.ArgMaxWithValue(axis=1)
self.sum_inds = P.ReduceSum()
self.iou = P.IOU()
self.greaterequal = P.GreaterEqual()
self.greater = P.Greater()
self.select = P.Select()
self.gatherND = P.GatherNd()
self.squeeze = P.Squeeze()
self.cast = P.Cast()
self.logicaland = P.LogicalAnd()
self.less = P.Less()
self.random_choice_with_mask_pos = P.RandomChoiceWithMask(self.num_expected_pos)
self.random_choice_with_mask_neg = P.RandomChoiceWithMask(self.num_expected_neg)
self.reshape = P.Reshape()
self.equal = P.Equal()
self.bounding_box_encode = P.BoundingBoxEncode(means=(0.0, 0.0, 0.0, 0.0), stds=(1.0, 1.0, 1.0, 1.0))
self.scatterNdUpdate = P.ScatterNdUpdate()
self.scatterNd = P.ScatterNd()
self.logicalnot = P.LogicalNot()
self.tile = P.Tile()
self.zeros_like = P.ZerosLike()
self.assigned_gt_inds = Tensor(np.array(-1 * np.ones(num_bboxes), dtype=np.int32))
self.assigned_gt_zeros = Tensor(np.array(np.zeros(num_bboxes), dtype=np.int32))
self.assigned_gt_ones = Tensor(np.array(np.ones(num_bboxes), dtype=np.int32))
self.assigned_gt_ignores = Tensor(np.array(-1 * np.ones(num_bboxes), dtype=np.int32))
self.assigned_pos_ones = Tensor(np.array(np.ones(self.num_expected_pos), dtype=np.int32))
self.check_neg_mask = Tensor(np.array(np.ones(self.num_expected_neg - self.num_expected_pos), dtype=np.bool))
self.range_pos_size = Tensor(np.arange(self.num_expected_pos).astype(np.float16))
self.check_gt_one = Tensor(np.array(-1 * np.ones((self.num_gts, 4)), dtype=np.float16))
self.check_anchor_two = Tensor(np.array(-2 * np.ones((self.num_bboxes, 4)), dtype=np.float16))
def construct(self, gt_bboxes_i, gt_labels_i, valid_mask, bboxes, gt_valids):
gt_bboxes_i = self.select(self.cast(self.tile(self.reshape(self.cast(gt_valids, mstype.int32), \
(self.num_gts, 1)), (1, 4)), mstype.bool_), gt_bboxes_i, self.check_gt_one)
bboxes = self.select(self.cast(self.tile(self.reshape(self.cast(valid_mask, mstype.int32), \
(self.num_bboxes, 1)), (1, 4)), mstype.bool_), bboxes, self.check_anchor_two)
overlaps = self.iou(bboxes, gt_bboxes_i)
max_overlaps_w_gt_index, max_overlaps_w_gt = self.max_gt(overlaps)
_, max_overlaps_w_ac = self.max_anchor(overlaps)
neg_sample_iou_mask = self.logicaland(self.greaterequal(max_overlaps_w_gt, self.zero_thr), \
self.less(max_overlaps_w_gt, self.neg_iou_thr))
assigned_gt_inds2 = self.select(neg_sample_iou_mask, self.assigned_gt_zeros, self.assigned_gt_inds)
pos_sample_iou_mask = self.greaterequal(max_overlaps_w_gt, self.pos_iou_thr)
assigned_gt_inds3 = self.select(pos_sample_iou_mask, \
max_overlaps_w_gt_index + self.assigned_gt_ones, assigned_gt_inds2)
assigned_gt_inds4 = assigned_gt_inds3
for j in range(self.num_gts):
max_overlaps_w_ac_j = max_overlaps_w_ac[j:j+1:1]
overlaps_w_gt_j = self.squeeze(overlaps[j:j+1:1, ::])
pos_mask_j = self.logicaland(self.greaterequal(max_overlaps_w_ac_j, self.min_pos_iou), \
self.equal(overlaps_w_gt_j, max_overlaps_w_ac_j))
assigned_gt_inds4 = self.select(pos_mask_j, self.assigned_gt_ones + j, assigned_gt_inds4)
assigned_gt_inds5 = self.select(valid_mask, assigned_gt_inds4, self.assigned_gt_ignores)
pos_index, valid_pos_index = self.random_choice_with_mask_pos(self.greater(assigned_gt_inds5, 0))
pos_check_valid = self.cast(self.greater(assigned_gt_inds5, 0), mstype.float16)
pos_check_valid = self.sum_inds(pos_check_valid, -1)
valid_pos_index = self.less(self.range_pos_size, pos_check_valid)
pos_index = pos_index * self.reshape(self.cast(valid_pos_index, mstype.int32), (self.num_expected_pos, 1))
pos_assigned_gt_index = self.gatherND(assigned_gt_inds5, pos_index) - self.assigned_pos_ones
pos_assigned_gt_index = pos_assigned_gt_index * self.cast(valid_pos_index, mstype.int32)
pos_assigned_gt_index = self.reshape(pos_assigned_gt_index, (self.num_expected_pos, 1))
neg_index, valid_neg_index = self.random_choice_with_mask_neg(self.equal(assigned_gt_inds5, 0))
num_pos = self.cast(self.logicalnot(valid_pos_index), mstype.float16)
num_pos = self.sum_inds(num_pos, -1)
unvalid_pos_index = self.less(self.range_pos_size, num_pos)
valid_neg_index = self.logicaland(self.concat((self.check_neg_mask, unvalid_pos_index)), valid_neg_index)
pos_bboxes_ = self.gatherND(bboxes, pos_index)
pos_gt_bboxes_ = self.gatherND(gt_bboxes_i, pos_assigned_gt_index)
pos_gt_labels = self.gatherND(gt_labels_i, pos_assigned_gt_index)
pos_bbox_targets_ = self.bounding_box_encode(pos_bboxes_, pos_gt_bboxes_)
valid_pos_index = self.cast(valid_pos_index, mstype.int32)
valid_neg_index = self.cast(valid_neg_index, mstype.int32)
bbox_targets_total = self.scatterNd(pos_index, pos_bbox_targets_, (self.num_bboxes, 4))
bbox_weights_total = self.scatterNd(pos_index, valid_pos_index, (self.num_bboxes,))
labels_total = self.scatterNd(pos_index, pos_gt_labels, (self.num_bboxes,))
total_index = self.concat((pos_index, neg_index))
total_valid_index = self.concat((valid_pos_index, valid_neg_index))
label_weights_total = self.scatterNd(total_index, total_valid_index, (self.num_bboxes,))
return bbox_targets_total, self.cast(bbox_weights_total, mstype.bool_), \
labels_total, self.cast(label_weights_total, mstype.bool_)

@ -0,0 +1,221 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""MaskRcnn tpositive and negative sample screening for Rcnn."""
import numpy as np
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore.ops import operations as P
from mindspore.common.tensor import Tensor
class BboxAssignSampleForRcnn(nn.Cell):
"""
Bbox assigner and sampler defination.
Args:
config (dict): Config.
batch_size (int): Batchsize.
num_bboxes (int): The anchor nums.
add_gt_as_proposals (bool): add gt bboxes as proposals flag.
Returns:
Tensor, multiple output tensors.
Examples:
BboxAssignSampleForRcnn(config, 2, 1024, True)
"""
def __init__(self, config, batch_size, num_bboxes, add_gt_as_proposals):
super(BboxAssignSampleForRcnn, self).__init__()
cfg = config
self.batch_size = batch_size
self.neg_iou_thr = cfg.neg_iou_thr_stage2
self.pos_iou_thr = cfg.pos_iou_thr_stage2
self.min_pos_iou = cfg.min_pos_iou_stage2
self.num_gts = cfg.num_gts
self.num_bboxes = num_bboxes
self.num_expected_pos = cfg.num_expected_pos_stage2
self.num_expected_neg = cfg.num_expected_neg_stage2
self.num_expected_total = cfg.num_expected_total_stage2
self.add_gt_as_proposals = add_gt_as_proposals
self.label_inds = Tensor(np.arange(1, self.num_gts + 1).astype(np.int32))
self.add_gt_as_proposals_valid = Tensor(np.array(self.add_gt_as_proposals * np.ones(self.num_gts),
dtype=np.int32))
self.concat = P.Concat(axis=0)
self.max_gt = P.ArgMaxWithValue(axis=0)
self.max_anchor = P.ArgMaxWithValue(axis=1)
self.sum_inds = P.ReduceSum()
self.iou = P.IOU()
self.greaterequal = P.GreaterEqual()
self.greater = P.Greater()
self.select = P.Select()
self.gatherND = P.GatherNd()
self.squeeze = P.Squeeze()
self.cast = P.Cast()
self.logicaland = P.LogicalAnd()
self.less = P.Less()
self.random_choice_with_mask_pos = P.RandomChoiceWithMask(self.num_expected_pos)
self.random_choice_with_mask_neg = P.RandomChoiceWithMask(self.num_expected_neg)
self.reshape = P.Reshape()
self.equal = P.Equal()
self.bounding_box_encode = P.BoundingBoxEncode(means=(0.0, 0.0, 0.0, 0.0), stds=(0.1, 0.1, 0.2, 0.2))
self.concat_axis1 = P.Concat(axis=1)
self.logicalnot = P.LogicalNot()
self.tile = P.Tile()
# Check
self.check_gt_one = Tensor(np.array(-1 * np.ones((self.num_gts, 4)), dtype=np.float16))
self.check_anchor_two = Tensor(np.array(-2 * np.ones((self.num_bboxes, 4)), dtype=np.float16))
# Init tensor
self.assigned_gt_inds = Tensor(np.array(-1 * np.ones(num_bboxes), dtype=np.int32))
self.assigned_gt_zeros = Tensor(np.array(np.zeros(num_bboxes), dtype=np.int32))
self.assigned_gt_ones = Tensor(np.array(np.ones(num_bboxes), dtype=np.int32))
self.assigned_gt_ignores = Tensor(np.array(-1 * np.ones(num_bboxes), dtype=np.int32))
self.assigned_pos_ones = Tensor(np.array(np.ones(self.num_expected_pos), dtype=np.int32))
self.gt_ignores = Tensor(np.array(-1 * np.ones(self.num_gts), dtype=np.int32))
self.range_pos_size = Tensor(np.arange(self.num_expected_pos).astype(np.float16))
self.check_neg_mask = Tensor(np.array(np.ones(self.num_expected_neg - self.num_expected_pos), dtype=np.bool))
self.bboxs_neg_mask = Tensor(np.zeros((self.num_expected_neg, 4), dtype=np.float16))
self.labels_neg_mask = Tensor(np.array(np.zeros(self.num_expected_neg), dtype=np.uint8))
self.reshape_shape_pos = (self.num_expected_pos, 1)
self.reshape_shape_neg = (self.num_expected_neg, 1)
self.scalar_zero = Tensor(0.0, dtype=mstype.float16)
self.scalar_neg_iou_thr = Tensor(self.neg_iou_thr, dtype=mstype.float16)
self.scalar_pos_iou_thr = Tensor(self.pos_iou_thr, dtype=mstype.float16)
self.scalar_min_pos_iou = Tensor(self.min_pos_iou, dtype=mstype.float16)
self.expand_dims = P.ExpandDims()
self.split = P.Split(axis=1, output_num=4)
self.concat_last_axis = P.Concat(axis=-1)
self.round = P.Round()
self.image_h_w = Tensor([cfg.img_height, cfg.img_width, cfg.img_height, cfg.img_width], dtype=mstype.float16)
self.range = nn.Range(start=0, limit=cfg.num_expected_pos_stage2)
self.crop_and_resize = P.CropAndResize(method="bilinear_v2")
self.mask_shape = (cfg.mask_shape[0], cfg.mask_shape[1])
self.squeeze_mask_last = P.Squeeze(axis=-1)
def construct(self, gt_bboxes_i, gt_labels_i, valid_mask, bboxes, gt_valids, gt_masks_i):
gt_bboxes_i = self.select(self.cast(self.tile(self.reshape(self.cast(gt_valids, mstype.int32), \
(self.num_gts, 1)), (1, 4)), mstype.bool_), \
gt_bboxes_i, self.check_gt_one)
bboxes = self.select(self.cast(self.tile(self.reshape(self.cast(valid_mask, mstype.int32), \
(self.num_bboxes, 1)), (1, 4)), mstype.bool_), \
bboxes, self.check_anchor_two)
overlaps = self.iou(bboxes, gt_bboxes_i)
max_overlaps_w_gt_index, max_overlaps_w_gt = self.max_gt(overlaps)
_, max_overlaps_w_ac = self.max_anchor(overlaps)
neg_sample_iou_mask = self.logicaland(self.greaterequal(max_overlaps_w_gt,
self.scalar_zero),
self.less(max_overlaps_w_gt,
self.scalar_neg_iou_thr))
assigned_gt_inds2 = self.select(neg_sample_iou_mask, self.assigned_gt_zeros, self.assigned_gt_inds)
pos_sample_iou_mask = self.greaterequal(max_overlaps_w_gt, self.scalar_pos_iou_thr)
assigned_gt_inds3 = self.select(pos_sample_iou_mask, \
max_overlaps_w_gt_index + self.assigned_gt_ones, assigned_gt_inds2)
for j in range(self.num_gts):
max_overlaps_w_ac_j = max_overlaps_w_ac[j:j+1:1]
overlaps_w_ac_j = overlaps[j:j+1:1, ::]
temp1 = self.greaterequal(max_overlaps_w_ac_j, self.scalar_min_pos_iou)
temp2 = self.squeeze(self.equal(overlaps_w_ac_j, max_overlaps_w_ac_j))
pos_mask_j = self.logicaland(temp1, temp2)
assigned_gt_inds3 = self.select(pos_mask_j, (j+1)*self.assigned_gt_ones, assigned_gt_inds3)
assigned_gt_inds5 = self.select(valid_mask, assigned_gt_inds3, self.assigned_gt_ignores)
bboxes = self.concat((gt_bboxes_i, bboxes))
label_inds_valid = self.select(gt_valids, self.label_inds, self.gt_ignores)
label_inds_valid = label_inds_valid * self.add_gt_as_proposals_valid
assigned_gt_inds5 = self.concat((label_inds_valid, assigned_gt_inds5))
# Get pos index
pos_index, valid_pos_index = self.random_choice_with_mask_pos(self.greater(assigned_gt_inds5, 0))
pos_check_valid = self.cast(self.greater(assigned_gt_inds5, 0), mstype.float16)
pos_check_valid = self.sum_inds(pos_check_valid, -1)
valid_pos_index = self.less(self.range_pos_size, pos_check_valid)
pos_index = pos_index * self.reshape(self.cast(valid_pos_index, mstype.int32), (self.num_expected_pos, 1))
num_pos = self.sum_inds(self.cast(self.logicalnot(valid_pos_index), mstype.float16), -1)
valid_pos_index = self.cast(valid_pos_index, mstype.int32)
pos_index = self.reshape(pos_index, self.reshape_shape_pos)
valid_pos_index = self.reshape(valid_pos_index, self.reshape_shape_pos)
pos_index = pos_index * valid_pos_index
pos_assigned_gt_index = self.gatherND(assigned_gt_inds5, pos_index) - self.assigned_pos_ones
pos_assigned_gt_index = self.reshape(pos_assigned_gt_index, self.reshape_shape_pos)
pos_assigned_gt_index = pos_assigned_gt_index * valid_pos_index
pos_gt_labels = self.gatherND(gt_labels_i, pos_assigned_gt_index)
# Get neg index
neg_index, valid_neg_index = self.random_choice_with_mask_neg(self.equal(assigned_gt_inds5, 0))
unvalid_pos_index = self.less(self.range_pos_size, num_pos)
valid_neg_index = self.logicaland(self.concat((self.check_neg_mask, unvalid_pos_index)), valid_neg_index)
neg_index = self.reshape(neg_index, self.reshape_shape_neg)
valid_neg_index = self.cast(valid_neg_index, mstype.int32)
valid_neg_index = self.reshape(valid_neg_index, self.reshape_shape_neg)
neg_index = neg_index * valid_neg_index
pos_bboxes_ = self.gatherND(bboxes, pos_index)
neg_bboxes_ = self.gatherND(bboxes, neg_index)
pos_assigned_gt_index = self.reshape(pos_assigned_gt_index, self.reshape_shape_pos)
pos_gt_bboxes_ = self.gatherND(gt_bboxes_i, pos_assigned_gt_index)
pos_bbox_targets_ = self.bounding_box_encode(pos_bboxes_, pos_gt_bboxes_)
# assign positive ROIs to gt masks
# Pick the right front and background mask for each ROI
roi_pos_masks_fb = self.gatherND(gt_masks_i, pos_assigned_gt_index)
pos_masks_fb = self.cast(roi_pos_masks_fb, mstype.float32)
# compute mask targets
x1, y1, x2, y2 = self.split(pos_bboxes_)
boxes = self.concat_last_axis((y1, x1, y2, x2))
# normalized box coordinate
boxes = boxes / self.image_h_w
box_ids = self.range()
pos_masks_fb = self.expand_dims(pos_masks_fb, -1)
boxes = self.cast(boxes, mstype.float32)
pos_masks_fb = self.crop_and_resize(pos_masks_fb, boxes, box_ids, self.mask_shape)
# Remove the extra dimension from masks.
pos_masks_fb = self.squeeze_mask_last(pos_masks_fb)
# convert gt masks targets be 0 or 1 to use with binary cross entropy loss.
pos_masks_fb = self.round(pos_masks_fb)
pos_masks_fb = self.cast(pos_masks_fb, mstype.float16)
total_bboxes = self.concat((pos_bboxes_, neg_bboxes_))
total_deltas = self.concat((pos_bbox_targets_, self.bboxs_neg_mask))
total_labels = self.concat((pos_gt_labels, self.labels_neg_mask))
valid_pos_index = self.reshape(valid_pos_index, self.reshape_shape_pos)
valid_neg_index = self.reshape(valid_neg_index, self.reshape_shape_neg)
total_mask = self.concat((valid_pos_index, valid_neg_index))
return total_bboxes, total_deltas, total_labels, total_mask, pos_bboxes_, pos_masks_fb, \
pos_gt_labels, valid_pos_index

@ -0,0 +1,111 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""MaskRcnn feature pyramid network."""
import numpy as np
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.common.tensor import Tensor
from mindspore.common import dtype as mstype
from mindspore.common.initializer import initializer
def bias_init_zeros(shape):
"""Bias init method."""
return Tensor(np.array(np.zeros(shape).astype(np.float32)).astype(np.float16))
def _conv(in_channels, out_channels, kernel_size=3, stride=1, padding=0, pad_mode='pad'):
"""Conv2D wrapper."""
shape = (out_channels, in_channels, kernel_size, kernel_size)
weights = initializer("XavierUniform", shape=shape, dtype=mstype.float16).to_tensor()
shape_bias = (out_channels,)
biass = bias_init_zeros(shape_bias)
return nn.Conv2d(in_channels, out_channels,
kernel_size=kernel_size, stride=stride, padding=padding,
pad_mode=pad_mode, weight_init=weights, has_bias=True, bias_init=biass)
class FeatPyramidNeck(nn.Cell):
"""
Feature pyramid network cell, usually uses as network neck.
Applies the convolution on multiple, input feature maps
and output feature map with same channel size. if required num of
output larger then num of inputs, add extra maxpooling for further
downsampling;
Args:
in_channels (tuple) - Channel size of input feature maps.
out_channels (int) - Channel size output.
num_outs (int) - Num of output features.
Returns:
Tuple, with tensors of same channel size.
Examples:
neck = FeatPyramidNeck([100,200,300], 50, 4)
input_data = (normal(0,0.1,(1,c,1280//(4*2**i), 768//(4*2**i)),
dtype=np.float32) \
for i, c in enumerate(config.fpn_in_channels))
x = neck(input_data)
"""
def __init__(self,
in_channels,
out_channels,
num_outs):
super(FeatPyramidNeck, self).__init__()
self.num_outs = num_outs
self.in_channels = in_channels
self.fpn_layer = len(self.in_channels)
assert not self.num_outs < len(in_channels)
self.lateral_convs_list_ = []
self.fpn_convs_ = []
for _, channel in enumerate(in_channels):
l_conv = _conv(channel, out_channels, kernel_size=1, stride=1, padding=0, pad_mode='valid')
fpn_conv = _conv(out_channels, out_channels, kernel_size=3, stride=1, padding=0, pad_mode='same')
self.lateral_convs_list_.append(l_conv)
self.fpn_convs_.append(fpn_conv)
self.lateral_convs_list = nn.layer.CellList(self.lateral_convs_list_)
self.fpn_convs_list = nn.layer.CellList(self.fpn_convs_)
self.interpolate1 = P.ResizeBilinear((48, 80))
self.interpolate2 = P.ResizeBilinear((96, 160))
self.interpolate3 = P.ResizeBilinear((192, 320))
self.cast = P.Cast()
self.maxpool = P.MaxPool(ksize=1, strides=2, padding="same")
def construct(self, inputs):
x = ()
for i in range(self.fpn_layer):
x += (self.lateral_convs_list[i](inputs[i]),)
y = (x[3],)
y = y + (x[2] + self.cast(self.interpolate1(y[self.fpn_layer - 4]), mstype.float16),)
y = y + (x[1] + self.cast(self.interpolate2(y[self.fpn_layer - 3]), mstype.float16),)
y = y + (x[0] + self.cast(self.interpolate3(y[self.fpn_layer - 2]), mstype.float16),)
z = ()
for i in range(self.fpn_layer - 1, -1, -1):
z = z + (y[i],)
outs = ()
for i in range(self.fpn_layer):
outs = outs + (self.fpn_convs_list[i](z[i]),)
for i in range(self.num_outs - self.fpn_layer):
outs = outs + (self.maxpool(outs[3]),)
return outs

@ -0,0 +1,117 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""MobilenetV1 backbone."""
import mindspore.nn as nn
from mindspore.ops import operations as P
def conv_bn_relu(in_channel, out_channel, kernel_size, stride, depthwise, activation='relu6'):
output = []
output.append(nn.Conv2d(in_channel, out_channel, kernel_size, stride, pad_mode='same',
group=1 if not depthwise else in_channel))
output.append(nn.BatchNorm2d(out_channel))
if activation:
output.append(nn.get_activation(activation))
return nn.SequentialCell(output)
class MobileNetV1(nn.Cell):
"""
MobileNet V1 backbone
"""
def __init__(self, class_num=1001, features_only=False):
super(MobileNetV1, self).__init__()
self.features_only = features_only
cnn = [
conv_bn_relu(3, 32, 3, 2, False),
conv_bn_relu(32, 32, 3, 1, True),
conv_bn_relu(32, 64, 1, 1, False),
conv_bn_relu(64, 64, 3, 2, True),
conv_bn_relu(64, 128, 1, 1, False),
conv_bn_relu(128, 128, 3, 1, True),
conv_bn_relu(128, 128, 1, 1, False),
conv_bn_relu(128, 128, 3, 2, True),
conv_bn_relu(128, 256, 1, 1, False),
conv_bn_relu(256, 256, 3, 1, True),
conv_bn_relu(256, 256, 1, 1, False),
conv_bn_relu(256, 256, 3, 2, True),
conv_bn_relu(256, 512, 1, 1, False),
conv_bn_relu(512, 512, 3, 1, True),
conv_bn_relu(512, 512, 1, 1, False),
conv_bn_relu(512, 512, 3, 1, True),
conv_bn_relu(512, 512, 1, 1, False),
conv_bn_relu(512, 512, 3, 1, True),
conv_bn_relu(512, 512, 1, 1, False),
conv_bn_relu(512, 512, 3, 1, True),
conv_bn_relu(512, 512, 1, 1, False),
conv_bn_relu(512, 512, 3, 1, True),
conv_bn_relu(512, 512, 1, 1, False),
conv_bn_relu(512, 512, 3, 2, True),
conv_bn_relu(512, 1024, 1, 1, False),
conv_bn_relu(1024, 1024, 3, 1, True),
conv_bn_relu(1024, 1024, 1, 1, False),
]
if self.features_only:
self.network = nn.CellList(cnn)
else:
self.network = nn.SequentialCell(cnn)
self.fc = nn.Dense(1024, class_num)
def construct(self, x):
output = x
if self.features_only:
features = ()
for block in self.network:
output = block(output)
features = features + (output,)
return features
output = self.network(x)
output = P.ReduceMean()(output, (2, 3))
output = self.fc(output)
return output
class FeatureSelector(nn.Cell):
"""
Select specific layers from en entire feature list
"""
def __init__(self, feature_idxes):
super(FeatureSelector, self).__init__()
self.feature_idxes = feature_idxes
def construct(self, feature_list):
selected = ()
for i in self.feature_idxes:
selected = selected + (feature_list[i],)
return selected
class MobileNetV1_FeatureSelector(nn.Cell):
"""
Mobilenet v1 feature selector
"""
def __init__(self, class_num=1001, features_only=False):
super(MobileNetV1_FeatureSelector, self).__init__()
self.mobilenet_v1 = MobileNetV1(class_num=class_num, features_only=features_only)
self.selector = FeatureSelector([6, 10, 22, 26])
def construct(self, x):
features = self.mobilenet_v1(x)
features = self.selector(features)
return features

@ -0,0 +1,195 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""MaskRcnn proposal generator."""
import numpy as np
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore.ops import operations as P
from mindspore import Tensor
class Proposal(nn.Cell):
"""
Proposal subnet.
Args:
config (dict): Config.
batch_size (int): Batchsize.
num_classes (int) - Class number.
use_sigmoid_cls (bool) - Select sigmoid or softmax function.
target_means (tuple) - Means for encode function. Default: (.0, .0, .0, .0).
target_stds (tuple) - Stds for encode function. Default: (1.0, 1.0, 1.0, 1.0).
Returns:
Tuple, tuple of output tensor,(proposal, mask).
Examples:
Proposal(config = config, batch_size = 1, num_classes = 81, use_sigmoid_cls = True, \
target_means=(.0, .0, .0, .0), target_stds=(1.0, 1.0, 1.0, 1.0))
"""
def __init__(self,
config,
batch_size,
num_classes,
use_sigmoid_cls,
target_means=(.0, .0, .0, .0),
target_stds=(1.0, 1.0, 1.0, 1.0)
):
super(Proposal, self).__init__()
cfg = config
self.batch_size = batch_size
self.num_classes = num_classes
self.target_means = target_means
self.target_stds = target_stds
self.use_sigmoid_cls = use_sigmoid_cls
if self.use_sigmoid_cls:
self.cls_out_channels = num_classes - 1
self.activation = P.Sigmoid()
self.reshape_shape = (-1, 1)
else:
self.cls_out_channels = num_classes
self.activation = P.Softmax(axis=1)
self.reshape_shape = (-1, 2)
if self.cls_out_channels <= 0:
raise ValueError('num_classes={} is too small'.format(num_classes))
self.num_pre = cfg.rpn_proposal_nms_pre
self.min_box_size = cfg.rpn_proposal_min_bbox_size
self.nms_thr = cfg.rpn_proposal_nms_thr
self.nms_post = cfg.rpn_proposal_nms_post
self.nms_across_levels = cfg.rpn_proposal_nms_across_levels
self.max_num = cfg.rpn_proposal_max_num
self.num_levels = cfg.fpn_num_outs
# Op Define
self.squeeze = P.Squeeze()
self.reshape = P.Reshape()
self.cast = P.Cast()
self.feature_shapes = cfg.feature_shapes
self.transpose_shape = (1, 2, 0)
self.decode = P.BoundingBoxDecode(max_shape=(cfg.img_height, cfg.img_width), \
means=self.target_means, \
stds=self.target_stds)
self.nms = P.NMSWithMask(self.nms_thr)
self.concat_axis0 = P.Concat(axis=0)
self.concat_axis1 = P.Concat(axis=1)
self.split = P.Split(axis=1, output_num=5)
self.min = P.Minimum()
self.gatherND = P.GatherNd()
self.slice = P.Slice()
self.select = P.Select()
self.greater = P.Greater()
self.transpose = P.Transpose()
self.tile = P.Tile()
self.set_train_local(config, training=True)
self.multi_10 = Tensor(10.0, mstype.float16)
def set_train_local(self, config, training=True):
"""Set training flag."""
self.training_local = training
cfg = config
self.topK_stage1 = ()
self.topK_shape = ()
total_max_topk_input = 0
if not self.training_local:
self.num_pre = cfg.rpn_nms_pre
self.min_box_size = cfg.rpn_min_bbox_min_size
self.nms_thr = cfg.rpn_nms_thr
self.nms_post = cfg.rpn_nms_post
self.nms_across_levels = cfg.rpn_nms_across_levels
self.max_num = cfg.rpn_max_num
for shp in self.feature_shapes:
k_num = min(self.num_pre, (shp[0] * shp[1] * 3))
total_max_topk_input += k_num
self.topK_stage1 += (k_num,)
self.topK_shape += ((k_num, 1),)
self.topKv2 = P.TopK(sorted=True)
self.topK_shape_stage2 = (self.max_num, 1)
self.min_float_num = -65536.0
self.topK_mask = Tensor(self.min_float_num * np.ones(total_max_topk_input, np.float16))
def construct(self, rpn_cls_score_total, rpn_bbox_pred_total, anchor_list):
proposals_tuple = ()
masks_tuple = ()
for img_id in range(self.batch_size):
cls_score_list = ()
bbox_pred_list = ()
for i in range(self.num_levels):
rpn_cls_score_i = self.squeeze(rpn_cls_score_total[i][img_id:img_id+1:1, ::, ::, ::])
rpn_bbox_pred_i = self.squeeze(rpn_bbox_pred_total[i][img_id:img_id+1:1, ::, ::, ::])
cls_score_list = cls_score_list + (rpn_cls_score_i,)
bbox_pred_list = bbox_pred_list + (rpn_bbox_pred_i,)
proposals, masks = self.get_bboxes_single(cls_score_list, bbox_pred_list, anchor_list)
proposals_tuple += (proposals,)
masks_tuple += (masks,)
return proposals_tuple, masks_tuple
def get_bboxes_single(self, cls_scores, bbox_preds, mlvl_anchors):
"""Get proposal boundingbox."""
mlvl_proposals = ()
mlvl_mask = ()
for idx in range(self.num_levels):
rpn_cls_score = self.transpose(cls_scores[idx], self.transpose_shape)
rpn_bbox_pred = self.transpose(bbox_preds[idx], self.transpose_shape)
anchors = mlvl_anchors[idx]
rpn_cls_score = self.reshape(rpn_cls_score, self.reshape_shape)
rpn_cls_score = self.activation(rpn_cls_score)
rpn_cls_score_process = self.cast(self.squeeze(rpn_cls_score[::, 0::]), mstype.float16)
rpn_bbox_pred_process = self.cast(self.reshape(rpn_bbox_pred, (-1, 4)), mstype.float16)
scores_sorted, topk_inds = self.topKv2(rpn_cls_score_process, self.topK_stage1[idx])
topk_inds = self.reshape(topk_inds, self.topK_shape[idx])
bboxes_sorted = self.gatherND(rpn_bbox_pred_process, topk_inds)
anchors_sorted = self.cast(self.gatherND(anchors, topk_inds), mstype.float16)
proposals_decode = self.decode(anchors_sorted, bboxes_sorted)
proposals_decode = self.concat_axis1((proposals_decode, self.reshape(scores_sorted, self.topK_shape[idx])))
proposals, _, mask_valid = self.nms(proposals_decode)
mlvl_proposals = mlvl_proposals + (proposals,)
mlvl_mask = mlvl_mask + (mask_valid,)
proposals = self.concat_axis0(mlvl_proposals)
masks = self.concat_axis0(mlvl_mask)
_, _, _, _, scores = self.split(proposals)
scores = self.squeeze(scores)
topk_mask = self.cast(self.topK_mask, mstype.float16)
scores_using = self.select(masks, scores, topk_mask)
_, topk_inds = self.topKv2(scores_using, self.max_num)
topk_inds = self.reshape(topk_inds, self.topK_shape_stage2)
proposals = self.gatherND(proposals, topk_inds)
masks = self.gatherND(masks, topk_inds)
return proposals, masks

@ -0,0 +1,178 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""MaskRcnn Rcnn classification and box regression network."""
import numpy as np
import mindspore.common.dtype as mstype
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.common.tensor import Tensor
from mindspore.common.initializer import initializer
from mindspore.common.parameter import Parameter
class DenseNoTranpose(nn.Cell):
"""Dense method"""
def __init__(self, input_channels, output_channels, weight_init):
super(DenseNoTranpose, self).__init__()
self.weight = Parameter(initializer(weight_init, [input_channels, output_channels], mstype.float16),
name="weight")
self.bias = Parameter(initializer("zeros", [output_channels], mstype.float16).to_tensor(), name="bias")
self.matmul = P.MatMul(transpose_b=False)
self.bias_add = P.BiasAdd()
def construct(self, x):
output = self.bias_add(self.matmul(x, self.weight), self.bias)
return output
class FpnCls(nn.Cell):
"""dense layer of classification and box head"""
def __init__(self, input_channels, output_channels, num_classes, pool_size):
super(FpnCls, self).__init__()
representation_size = input_channels * pool_size * pool_size
shape_0 = (output_channels, representation_size)
weights_0 = initializer("XavierUniform", shape=shape_0[::-1], dtype=mstype.float16).to_tensor()
shape_1 = (output_channels, output_channels)
weights_1 = initializer("XavierUniform", shape=shape_1[::-1], dtype=mstype.float16).to_tensor()
self.shared_fc_0 = DenseNoTranpose(representation_size, output_channels, weights_0)
self.shared_fc_1 = DenseNoTranpose(output_channels, output_channels, weights_1)
cls_weight = initializer('Normal', shape=[num_classes, output_channels][::-1],
dtype=mstype.float16).to_tensor()
reg_weight = initializer('Normal', shape=[num_classes * 4, output_channels][::-1],
dtype=mstype.float16).to_tensor()
self.cls_scores = DenseNoTranpose(output_channels, num_classes, cls_weight)
self.reg_scores = DenseNoTranpose(output_channels, num_classes * 4, reg_weight)
self.relu = P.ReLU()
self.flatten = P.Flatten()
def construct(self, x):
# two share fc layer
x = self.flatten(x)
x = self.relu(self.shared_fc_0(x))
x = self.relu(self.shared_fc_1(x))
# classifier head
cls_scores = self.cls_scores(x)
# bbox head
reg_scores = self.reg_scores(x)
return cls_scores, reg_scores
class RcnnCls(nn.Cell):
"""
Rcnn for classification and box regression subnet.
Args:
config (dict) - Config.
batch_size (int) - Batchsize.
num_classes (int) - Class number.
target_means (list) - Means for encode function. Default: (.0, .0, .0, .0]).
target_stds (list) - Stds for encode function. Default: (0.1, 0.1, 0.2, 0.2).
Returns:
Tuple, tuple of output tensor.
Examples:
RcnnCls(config=config, representation_size = 1024, batch_size=2, num_classes = 81, \
target_means=(0., 0., 0., 0.), target_stds=(0.1, 0.1, 0.2, 0.2))
"""
def __init__(self,
config,
batch_size,
num_classes,
target_means=(0., 0., 0., 0.),
target_stds=(0.1, 0.1, 0.2, 0.2)
):
super(RcnnCls, self).__init__()
cfg = config
self.rcnn_loss_cls_weight = Tensor(np.array(cfg.rcnn_loss_cls_weight).astype(np.float16))
self.rcnn_loss_reg_weight = Tensor(np.array(cfg.rcnn_loss_reg_weight).astype(np.float16))
self.rcnn_fc_out_channels = cfg.rcnn_fc_out_channels
self.target_means = target_means
self.target_stds = target_stds
self.num_classes = num_classes
self.in_channels = cfg.rcnn_in_channels
self.train_batch_size = batch_size
self.test_batch_size = cfg.test_batch_size
self.fpn_cls = FpnCls(self.in_channels, self.rcnn_fc_out_channels, self.num_classes, cfg.roi_layer["out_size"])
self.relu = P.ReLU()
self.logicaland = P.LogicalAnd()
self.loss_cls = P.SoftmaxCrossEntropyWithLogits()
self.loss_bbox = P.SmoothL1Loss(beta=1.0)
self.loss_mask = P.SigmoidCrossEntropyWithLogits()
self.reshape = P.Reshape()
self.onehot = P.OneHot()
self.greater = P.Greater()
self.cast = P.Cast()
self.sum_loss = P.ReduceSum()
self.tile = P.Tile()
self.expandims = P.ExpandDims()
self.gather = P.GatherNd()
self.argmax = P.ArgMaxWithValue(axis=1)
self.on_value = Tensor(1.0, mstype.float32)
self.off_value = Tensor(0.0, mstype.float32)
self.value = Tensor(1.0, mstype.float16)
self.num_bboxes = (cfg.num_expected_pos_stage2 + cfg.num_expected_neg_stage2) * batch_size
rmv_first = np.ones((self.num_bboxes, self.num_classes))
rmv_first[:, 0] = np.zeros((self.num_bboxes,))
self.rmv_first_tensor = Tensor(rmv_first.astype(np.float16))
self.num_bboxes_test = cfg.rpn_max_num * cfg.test_batch_size
def construct(self, featuremap, bbox_targets, labels, mask):
x_cls, x_reg = self.fpn_cls(featuremap)
if self.training:
bbox_weights = self.cast(self.logicaland(self.greater(labels, 0), mask), mstype.int32) * labels
labels = self.cast(self.onehot(labels, self.num_classes, self.on_value, self.off_value), mstype.float16)
bbox_targets = self.tile(self.expandims(bbox_targets, 1), (1, self.num_classes, 1))
loss_cls, loss_reg = self.loss(x_cls, x_reg,
bbox_targets, bbox_weights,
labels,
mask)
out = (loss_cls, loss_reg)
else:
out = (x_cls, x_reg)
return out
def loss(self, cls_score, bbox_pred, bbox_targets, bbox_weights, labels, weights):
"""Loss method."""
# loss_cls
loss_cls, _ = self.loss_cls(cls_score, labels)
weights = self.cast(weights, mstype.float16)
loss_cls = loss_cls * weights
loss_cls = self.sum_loss(loss_cls, (0,)) / self.sum_loss(weights, (0,))
# loss_reg
bbox_weights = self.cast(self.onehot(bbox_weights, self.num_classes, self.on_value, self.off_value),
mstype.float16)
bbox_weights = bbox_weights * self.rmv_first_tensor # * self.rmv_first_tensor exclude background
pos_bbox_pred = self.reshape(bbox_pred, (self.num_bboxes, -1, 4))
loss_reg = self.loss_bbox(pos_bbox_pred, bbox_targets)
loss_reg = self.sum_loss(loss_reg, (2,))
loss_reg = loss_reg * bbox_weights
loss_reg = loss_reg / self.sum_loss(weights, (0,))
loss_reg = self.sum_loss(loss_reg, (0, 1))
return loss_cls, loss_reg

@ -0,0 +1,168 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""MaskRcnn Rcnn for mask network."""
import numpy as np
import mindspore.common.dtype as mstype
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.common.tensor import Tensor
from mindspore.common.initializer import initializer
def _conv(in_channels, out_channels, kernel_size=1, stride=1, padding=0, pad_mode='pad'):
"""Conv2D wrapper."""
shape = (out_channels, in_channels, kernel_size, kernel_size)
weights = initializer("XavierUniform", shape=shape, dtype=mstype.float16).to_tensor()
shape_bias = (out_channels,)
bias = Tensor(np.array(np.zeros(shape_bias)).astype(np.float16))
return nn.Conv2d(in_channels, out_channels,
kernel_size=kernel_size, stride=stride, padding=padding,
pad_mode=pad_mode, weight_init=weights, has_bias=True, bias_init=bias)
def _convTanspose(in_channels, out_channels, kernel_size=1, stride=1, padding=0, pad_mode='pad'):
"""ConvTranspose wrapper."""
shape = (out_channels, in_channels, kernel_size, kernel_size)
weights = initializer("XavierUniform", shape=shape, dtype=mstype.float16).to_tensor()
shape_bias = (out_channels,)
bias = Tensor(np.array(np.zeros(shape_bias)).astype(np.float16))
return nn.Conv2dTranspose(in_channels, out_channels,
kernel_size=kernel_size, stride=stride, padding=padding,
pad_mode=pad_mode, weight_init=weights, has_bias=True, bias_init=bias)
class FpnMask(nn.Cell):
"""conv layers of mask head"""
def __init__(self, input_channels, output_channels, num_classes):
super(FpnMask, self).__init__()
self.mask_conv1 = _conv(input_channels, output_channels, kernel_size=3, pad_mode="same")
self.mask_relu1 = P.ReLU()
self.mask_conv2 = _conv(output_channels, output_channels, kernel_size=3, pad_mode="same")
self.mask_relu2 = P.ReLU()
self.mask_conv3 = _conv(output_channels, output_channels, kernel_size=3, pad_mode="same")
self.mask_relu3 = P.ReLU()
self.mask_conv4 = _conv(output_channels, output_channels, kernel_size=3, pad_mode="same")
self.mask_relu4 = P.ReLU()
self.mask_deconv5 = _convTanspose(output_channels, output_channels, kernel_size=2, stride=2, pad_mode="valid")
self.mask_relu5 = P.ReLU()
self.mask_conv6 = _conv(output_channels, num_classes, kernel_size=1, stride=1, pad_mode="valid")
def construct(self, x):
x = self.mask_conv1(x)
x = self.mask_relu1(x)
x = self.mask_conv2(x)
x = self.mask_relu2(x)
x = self.mask_conv3(x)
x = self.mask_relu3(x)
x = self.mask_conv4(x)
x = self.mask_relu4(x)
x = self.mask_deconv5(x)
x = self.mask_relu5(x)
x = self.mask_conv6(x)
return x
class RcnnMask(nn.Cell):
"""
Rcnn for mask subnet.
Args:
config (dict) - Config.
batch_size (int) - Batchsize.
num_classes (int) - Class number.
target_means (list) - Means for encode function. Default: (.0, .0, .0, .0]).
target_stds (list) - Stds for encode function. Default: (0.1, 0.1, 0.2, 0.2).
Returns:
Tuple, tuple of output tensor.
Examples:
RcnnMask(config=config, representation_size = 1024, batch_size=2, num_classes = 81, \
target_means=(0., 0., 0., 0.), target_stds=(0.1, 0.1, 0.2, 0.2))
"""
def __init__(self,
config,
batch_size,
num_classes,
target_means=(0., 0., 0., 0.),
target_stds=(0.1, 0.1, 0.2, 0.2)
):
super(RcnnMask, self).__init__()
cfg = config
self.rcnn_loss_mask_fb_weight = Tensor(np.array(cfg.rcnn_loss_mask_fb_weight).astype(np.float16))
self.rcnn_mask_out_channels = cfg.rcnn_mask_out_channels
self.target_means = target_means
self.target_stds = target_stds
self.num_classes = num_classes
self.in_channels = cfg.rcnn_in_channels
self.fpn_mask = FpnMask(self.in_channels, self.rcnn_mask_out_channels, self.num_classes)
self.logicaland = P.LogicalAnd()
self.loss_mask = P.SigmoidCrossEntropyWithLogits()
self.onehot = P.OneHot()
self.greater = P.Greater()
self.cast = P.Cast()
self.sum_loss = P.ReduceSum()
self.tile = P.Tile()
self.expandims = P.ExpandDims()
self.on_value = Tensor(1.0, mstype.float32)
self.off_value = Tensor(0.0, mstype.float32)
self.num_bboxes = cfg.num_expected_pos_stage2 * batch_size
rmv_first = np.ones((self.num_bboxes, self.num_classes))
rmv_first[:, 0] = np.zeros((self.num_bboxes,))
self.rmv_first_tensor = Tensor(rmv_first.astype(np.float16))
self.mean_loss = P.ReduceMean()
def construct(self, mask_featuremap, labels=None, mask=None, mask_fb_targets=None):
x_mask_fb = self.fpn_mask(mask_featuremap)
if self.training:
bbox_weights = self.cast(self.logicaland(self.greater(labels, 0), mask), mstype.int32) * labels
mask_fb_targets = self.tile(self.expandims(mask_fb_targets, 1), (1, self.num_classes, 1, 1))
loss_mask_fb = self.loss(x_mask_fb, bbox_weights, mask, mask_fb_targets)
out = loss_mask_fb
else:
out = x_mask_fb
return out
def loss(self, masks_fb_pred, bbox_weights, weights, masks_fb_targets):
"""Loss method."""
weights = self.cast(weights, mstype.float16)
bbox_weights = self.cast(self.onehot(bbox_weights, self.num_classes, self.on_value, self.off_value),
mstype.float16)
bbox_weights = bbox_weights * self.rmv_first_tensor # * self.rmv_first_tensor exclude background
# loss_mask_fb
masks_fb_targets = self.cast(masks_fb_targets, mstype.float16)
loss_mask_fb = self.loss_mask(masks_fb_pred, masks_fb_targets)
loss_mask_fb = self.mean_loss(loss_mask_fb, (2, 3))
loss_mask_fb = loss_mask_fb * bbox_weights
loss_mask_fb = loss_mask_fb / self.sum_loss(weights, (0,))
loss_mask_fb = self.sum_loss(loss_mask_fb, (0, 1))
return loss_mask_fb

@ -0,0 +1,186 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""MaskRcnn ROIAlign module."""
import numpy as np
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.nn import layer as L
from mindspore.common.tensor import Tensor
class ROIAlign(nn.Cell):
"""
Extract RoI features from mulitple feature map.
Args:
out_size_h (int) - RoI height.
out_size_w (int) - RoI width.
spatial_scale (int) - RoI spatial scale.
sample_num (int) - RoI sample number.
roi_align_mode (int)- RoI align mode
"""
def __init__(self,
out_size_h,
out_size_w,
spatial_scale,
sample_num=0,
roi_align_mode=1):
super(ROIAlign, self).__init__()
self.out_size = (out_size_h, out_size_w)
self.spatial_scale = float(spatial_scale)
self.sample_num = int(sample_num)
self.align_op = P.ROIAlign(self.out_size[0], self.out_size[1],
self.spatial_scale, self.sample_num, roi_align_mode)
def construct(self, features, rois):
return self.align_op(features, rois)
def __repr__(self):
format_str = self.__class__.__name__
format_str += '(out_size={}, spatial_scale={}, sample_num={}'.format(
self.out_size, self.spatial_scale, self.sample_num)
return format_str
class SingleRoIExtractor(nn.Cell):
"""
Extract RoI features from a single level feature map.
If there are mulitple input feature levels, each RoI is mapped to a level
according to its scale.
Args:
config (dict): Config
roi_layer (dict): Specify RoI layer type and arguments.
out_channels (int): Output channels of RoI layers.
featmap_strides (int): Strides of input feature maps.
batch_size (int) Batchsize.
finest_scale (int): Scale threshold of mapping to level 0.
mask (bool): Specify ROIAlign for cls or mask branch
"""
def __init__(self,
config,
roi_layer,
out_channels,
featmap_strides,
batch_size=1,
finest_scale=56,
mask=False):
super(SingleRoIExtractor, self).__init__()
cfg = config
self.train_batch_size = batch_size
self.out_channels = out_channels
self.featmap_strides = featmap_strides
self.num_levels = len(self.featmap_strides)
self.out_size = roi_layer['mask_out_size'] if mask else roi_layer['out_size']
self.mask = mask
self.sample_num = roi_layer['sample_num']
self.roi_layers = self.build_roi_layers(self.featmap_strides)
self.roi_layers = L.CellList(self.roi_layers)
self.sqrt = P.Sqrt()
self.log = P.Log()
self.finest_scale_ = finest_scale
self.clamp = C.clip_by_value
self.cast = P.Cast()
self.equal = P.Equal()
self.select = P.Select()
_mode_16 = False
self.dtype = np.float16 if _mode_16 else np.float32
self.ms_dtype = mstype.float16 if _mode_16 else mstype.float32
self.set_train_local(cfg, training=True)
def set_train_local(self, config, training=True):
"""Set training flag."""
self.training_local = training
cfg = config
# Init tensor
roi_sample_num = cfg.num_expected_pos_stage2 if self.mask else cfg.roi_sample_num
self.batch_size = roi_sample_num if self.training_local else cfg.rpn_max_num
self.batch_size = self.train_batch_size*self.batch_size \
if self.training_local else cfg.test_batch_size*self.batch_size
self.ones = Tensor(np.array(np.ones((self.batch_size, 1)), dtype=self.dtype))
finest_scale = np.array(np.ones((self.batch_size, 1)), dtype=self.dtype) * self.finest_scale_
self.finest_scale = Tensor(finest_scale)
self.epslion = Tensor(np.array(np.ones((self.batch_size, 1)), dtype=self.dtype)*self.dtype(1e-6))
self.zeros = Tensor(np.array(np.zeros((self.batch_size, 1)), dtype=np.int32))
self.max_levels = Tensor(np.array(np.ones((self.batch_size, 1)), dtype=np.int32)*(self.num_levels-1))
self.twos = Tensor(np.array(np.ones((self.batch_size, 1)), dtype=self.dtype) * 2)
self.res_ = Tensor(np.array(np.zeros((self.batch_size, self.out_channels,
self.out_size, self.out_size)), dtype=self.dtype))
def num_inputs(self):
return len(self.featmap_strides)
def init_weights(self):
pass
def log2(self, value):
return self.log(value) / self.log(self.twos)
def build_roi_layers(self, featmap_strides):
roi_layers = []
for s in featmap_strides:
layer_cls = ROIAlign(self.out_size, self.out_size,
spatial_scale=1 / s,
sample_num=self.sample_num,
roi_align_mode=0)
roi_layers.append(layer_cls)
return roi_layers
def _c_map_roi_levels(self, rois):
"""Map rois to corresponding feature levels by scales.
- scale < finest_scale * 2: level 0
- finest_scale * 2 <= scale < finest_scale * 4: level 1
- finest_scale * 4 <= scale < finest_scale * 8: level 2
- scale >= finest_scale * 8: level 3
Args:
rois (Tensor): Input RoIs, shape (k, 5).
num_levels (int): Total level number.
Returns:
Tensor: Level index (0-based) of each RoI, shape (k, )
"""
scale = self.sqrt(rois[::, 3:4:1] - rois[::, 1:2:1] + self.ones) * \
self.sqrt(rois[::, 4:5:1] - rois[::, 2:3:1] + self.ones)
target_lvls = self.log2(scale / self.finest_scale + self.epslion)
target_lvls = P.Floor()(target_lvls)
target_lvls = self.cast(target_lvls, mstype.int32)
target_lvls = self.clamp(target_lvls, self.zeros, self.max_levels)
return target_lvls
def construct(self, rois, feat1, feat2, feat3, feat4):
feats = (feat1, feat2, feat3, feat4)
res = self.res_
target_lvls = self._c_map_roi_levels(rois)
for i in range(self.num_levels):
mask = self.equal(target_lvls, P.ScalarToArray()(i))
mask = P.Reshape()(mask, (-1, 1, 1, 1))
roi_feats_t = self.roi_layers[i](feats[i], rois)
mask = self.cast(P.Tile()(self.cast(mask, mstype.int32), (1, 256, self.out_size, self.out_size)),
mstype.bool_)
res = self.select(mask, roi_feats_t, res)
return res

@ -0,0 +1,221 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""MaskRcnn training network wrapper."""
import time
import numpy as np
import mindspore.nn as nn
from mindspore.common.tensor import Tensor
from mindspore.ops import functional as F
from mindspore.ops import composite as C
from mindspore import ParameterTuple
from mindspore.train.callback import Callback
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
time_stamp_init = False
time_stamp_first = 0
GRADIENT_CLIP_TYPE = 1
GRADIENT_CLIP_VALUE = 1.0
clip_grad = C.MultitypeFuncGraph("clip_grad")
@clip_grad.register("Number", "Number", "Tensor")
def _clip_grad(clip_type, clip_value, grad):
"""
Clip gradients.
Inputs:
clip_type (int): The way to clip, 0 for 'value', 1 for 'norm'.
clip_value (float): Specifies how much to clip.
grad (tuple[Tensor]): Gradients.
Outputs:
tuple[Tensor], clipped gradients.
"""
if clip_type not in (0, 1):
return grad
dt = F.dtype(grad)
if clip_type == 0:
new_grad = C.clip_by_value(grad, F.cast(F.tuple_to_array((-clip_value,)), dt),
F.cast(F.tuple_to_array((clip_value,)), dt))
else:
new_grad = nn.ClipByNorm()(grad, F.cast(F.tuple_to_array((clip_value,)), dt))
return F.cast(new_grad, dt)
class LossCallBack(Callback):
"""
Monitor the loss in training.
If the loss is NAN or INF terminating training.
Note:
If per_print_times is 0 do not print loss.
Args:
per_print_times (int): Print loss every times. Default: 1.
"""
def __init__(self, per_print_times=1, rank_id=0):
super(LossCallBack, self).__init__()
if not isinstance(per_print_times, int) or per_print_times < 0:
raise ValueError("print_step must be int and >= 0.")
self._per_print_times = per_print_times
self.count = 0
self.rpn_loss_sum = 0
self.rcnn_loss_sum = 0
self.rpn_cls_loss_sum = 0
self.rpn_reg_loss_sum = 0
self.rcnn_cls_loss_sum = 0
self.rcnn_reg_loss_sum = 0
self.rcnn_mask_loss_sum = 0
self.rank_id = rank_id
global time_stamp_init, time_stamp_first
if not time_stamp_init:
time_stamp_first = time.time()
time_stamp_init = True
def step_end(self, run_context):
cb_params = run_context.original_args()
rpn_loss = cb_params.net_outputs[0].asnumpy()
rcnn_loss = cb_params.net_outputs[1].asnumpy()
rpn_cls_loss = cb_params.net_outputs[2].asnumpy()
rpn_reg_loss = cb_params.net_outputs[3].asnumpy()
rcnn_cls_loss = cb_params.net_outputs[4].asnumpy()
rcnn_reg_loss = cb_params.net_outputs[5].asnumpy()
rcnn_mask_loss = cb_params.net_outputs[6].asnumpy()
self.count += 1
self.rpn_loss_sum += float(rpn_loss)
self.rcnn_loss_sum += float(rcnn_loss)
self.rpn_cls_loss_sum += float(rpn_cls_loss)
self.rpn_reg_loss_sum += float(rpn_reg_loss)
self.rcnn_cls_loss_sum += float(rcnn_cls_loss)
self.rcnn_reg_loss_sum += float(rcnn_reg_loss)
self.rcnn_mask_loss_sum += float(rcnn_mask_loss)
cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1
if self.count >= 1:
global time_stamp_first
time_stamp_current = time.time()
rpn_loss = self.rpn_loss_sum/self.count
rcnn_loss = self.rcnn_loss_sum/self.count
rpn_cls_loss = self.rpn_cls_loss_sum/self.count
rpn_reg_loss = self.rpn_reg_loss_sum/self.count
rcnn_cls_loss = self.rcnn_cls_loss_sum/self.count
rcnn_reg_loss = self.rcnn_reg_loss_sum/self.count
rcnn_mask_loss = self.rcnn_mask_loss_sum/self.count
total_loss = rpn_loss + rcnn_loss
loss_file = open("./loss_{}.log".format(self.rank_id), "a+")
loss_file.write("%lu epoch: %s step: %s ,rpn_loss: %.5f, rcnn_loss: %.5f, rpn_cls_loss: %.5f, "
"rpn_reg_loss: %.5f, rcnn_cls_loss: %.5f, rcnn_reg_loss: %.5f, rcnn_mask_loss: %.5f, "
"total_loss: %.5f" %
(time_stamp_current - time_stamp_first, cb_params.cur_epoch_num, cur_step_in_epoch,
rpn_loss, rcnn_loss, rpn_cls_loss, rpn_reg_loss,
rcnn_cls_loss, rcnn_reg_loss, rcnn_mask_loss, total_loss))
loss_file.write("\n")
loss_file.close()
self.count = 0
self.rpn_loss_sum = 0
self.rcnn_loss_sum = 0
self.rpn_cls_loss_sum = 0
self.rpn_reg_loss_sum = 0
self.rcnn_cls_loss_sum = 0
self.rcnn_reg_loss_sum = 0
self.rcnn_mask_loss_sum = 0
class LossNet(nn.Cell):
"""MaskRcnn loss method"""
def construct(self, x1, x2, x3, x4, x5, x6, x7):
return x1 + x2
class WithLossCell(nn.Cell):
"""
Wrap the network with loss function to compute loss.
Args:
backbone (Cell): The target network to wrap.
loss_fn (Cell): The loss function used to compute loss.
"""
def __init__(self, backbone, loss_fn):
super(WithLossCell, self).__init__(auto_prefix=False)
self._backbone = backbone
self._loss_fn = loss_fn
def construct(self, x, img_shape, gt_bboxe, gt_label, gt_num, gt_mask):
loss1, loss2, loss3, loss4, loss5, loss6, loss7 = self._backbone(x, img_shape, gt_bboxe, gt_label,
gt_num, gt_mask)
return self._loss_fn(loss1, loss2, loss3, loss4, loss5, loss6, loss7)
@property
def backbone_network(self):
"""
Get the backbone network.
Returns:
Cell, return backbone network.
"""
return self._backbone
class TrainOneStepCell(nn.Cell):
"""
Network training package class.
Append an optimizer to the training network after that the construct function
can be called to create the backward graph.
Args:
network (Cell): The training network.
network_backbone (Cell): The forward network.
optimizer (Cell): Optimizer for updating the weights.
sens (Number): The adjust parameter. Default value is 1.0.
reduce_flag (bool): The reduce flag. Default value is False.
mean (bool): Allreduce method. Default value is False.
degree (int): Device number. Default value is None.
"""
def __init__(self, network, network_backbone, optimizer, sens=1.0, reduce_flag=False, mean=True, degree=None):
super(TrainOneStepCell, self).__init__(auto_prefix=False)
self.network = network
self.network.set_grad()
self.backbone = network_backbone
self.weights = ParameterTuple(network.trainable_params())
self.optimizer = optimizer
self.grad = C.GradOperation(get_by_list=True,
sens_param=True)
self.sens = Tensor((np.ones((1,)) * sens).astype(np.float16))
self.reduce_flag = reduce_flag
self.hyper_map = C.HyperMap()
if reduce_flag:
self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
def construct(self, x, img_shape, gt_bboxe, gt_label, gt_num, gt_mask):
weights = self.weights
loss1, loss2, loss3, loss4, loss5, loss6, loss7 = self.backbone(x, img_shape, gt_bboxe, gt_label,
gt_num, gt_mask)
grads = self.grad(self.network, weights)(x, img_shape, gt_bboxe, gt_label, gt_num, gt_mask, self.sens)
if self.reduce_flag:
grads = self.grad_reducer(grads)
grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
return F.depend(loss1, self.optimizer(grads)), loss2, loss3, loss4, loss5, loss6, loss7

File diff suppressed because it is too large Load Diff

@ -0,0 +1,142 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# less required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""train MaskRcnn and get checkpoint files."""
import os
import time
import argparse
import ast
import mindspore.common.dtype as mstype
from mindspore import context, Tensor
from mindspore.communication.management import init
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, TimeMonitor
from mindspore.train import Model
from mindspore.context import ParallelMode
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.nn import Momentum
from mindspore.common import set_seed
from src.maskrcnn_mobilenetv1.mask_rcnn_mobilenetv1 import Mask_Rcnn_Mobilenetv1
from src.network_define import LossCallBack, WithLossCell, TrainOneStepCell, LossNet
from src.config import config
from src.dataset import data_to_mindrecord_byte_image, create_maskrcnn_dataset
from src.lr_schedule import dynamic_lr
set_seed(1)
parser = argparse.ArgumentParser(description="MaskRcnn training")
parser.add_argument("--only_create_dataset", type=ast.literal_eval, default=False, help="If set it true, only create "
"Mindrecord, default is false.")
parser.add_argument("--run_distribute", type=ast.literal_eval, default=False, help="Run distribute, default is false.")
parser.add_argument("--do_train", type=ast.literal_eval, default=True, help="Do train or not, default is true.")
parser.add_argument("--do_eval", type=ast.literal_eval, default=False, help="Do eval or not, default is false.")
parser.add_argument("--dataset", type=str, default="coco", help="Dataset, default is coco.")
parser.add_argument("--pre_trained", type=str, default="", help="Pretrain file path.")
parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
parser.add_argument("--device_num", type=int, default=1, help="Use device nums, default is 1.")
parser.add_argument("--rank_id", type=int, default=0, help="Rank id, default is 0.")
args_opt = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id)
if __name__ == '__main__':
print("Start train for maskrcnn_mobilenetv1!")
if not args_opt.do_eval and args_opt.run_distribute:
rank = args_opt.rank_id
device_num = args_opt.device_num
context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
gradients_mean=True)
init()
else:
rank = 0
device_num = 1
print("Start create dataset!")
# It will generate mindrecord file in args_opt.mindrecord_dir,
# and the file name is MaskRcnn.mindrecord0, 1, ... file_num.
prefix = "MaskRcnn.mindrecord"
mindrecord_dir = config.mindrecord_dir
mindrecord_file = os.path.join(mindrecord_dir, prefix + "0")
if rank == 0 and not os.path.exists(mindrecord_file):
if not os.path.isdir(mindrecord_dir):
os.makedirs(mindrecord_dir)
if args_opt.dataset == "coco":
if os.path.isdir(config.coco_root):
print("Create Mindrecord.")
data_to_mindrecord_byte_image("coco", True, prefix)
print("Create Mindrecord Done, at {}".format(mindrecord_dir))
else:
raise Exception("coco_root not exits.")
else:
if os.path.isdir(config.IMAGE_DIR) and os.path.exists(config.ANNO_PATH):
print("Create Mindrecord.")
data_to_mindrecord_byte_image("other", True, prefix)
print("Create Mindrecord Done, at {}".format(mindrecord_dir))
else:
raise Exception("IMAGE_DIR or ANNO_PATH not exits.")
while not os.path.exists(mindrecord_file+".db"):
time.sleep(5)
if not args_opt.only_create_dataset:
loss_scale = float(config.loss_scale)
# When create MindDataset, using the fitst mindrecord file, such as MaskRcnn.mindrecord0.
dataset = create_maskrcnn_dataset(mindrecord_file, batch_size=config.batch_size,
device_num=device_num, rank_id=rank)
dataset_size = dataset.get_dataset_size()
print("total images num: ", dataset_size)
print("Create dataset done!")
net = Mask_Rcnn_Mobilenetv1(config=config)
net = net.set_train()
load_path = args_opt.pre_trained
if load_path != "":
param_dict = load_checkpoint(load_path)
if config.pretrain_epoch_size == 0:
for item in list(param_dict.keys()):
if not (item.startswith('backbone') or item.startswith('rcnn_mask')):
param_dict.pop(item)
load_param_into_net(net, param_dict)
loss = LossNet()
lr = Tensor(dynamic_lr(config, rank_size=device_num, start_steps=config.pretrain_epoch_size * dataset_size),
mstype.float32)
opt = Momentum(params=net.trainable_params(), learning_rate=lr, momentum=config.momentum,
weight_decay=config.weight_decay, loss_scale=config.loss_scale)
net_with_loss = WithLossCell(net, loss)
if args_opt.run_distribute:
net = TrainOneStepCell(net_with_loss, net, opt, sens=config.loss_scale, reduce_flag=True,
mean=True, degree=device_num)
else:
net = TrainOneStepCell(net_with_loss, net, opt, sens=config.loss_scale)
time_cb = TimeMonitor(data_size=dataset_size)
loss_cb = LossCallBack(rank_id=rank)
cb = [time_cb, loss_cb]
if config.save_checkpoint:
ckptconfig = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_epochs * dataset_size,
keep_checkpoint_max=config.keep_checkpoint_max)
save_checkpoint_path = os.path.join(config.save_checkpoint_path, 'ckpt_' + str(rank) + '/')
ckpoint_cb = ModelCheckpoint(prefix='mask_rcnn', directory=save_checkpoint_path, config=ckptconfig)
cb += [ckpoint_cb]
model = Model(net)
model.train(config.epoch_size, dataset, callbacks=cb)
Loading…
Cancel
Save