!11399 Add FasterRCNN training and evaluation on GPU

From: @dessyang
Reviewed-by: 
Signed-off-by:
pull/11399/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 37f00fdaab

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -21,7 +21,7 @@ import numpy as np
from pycocotools.coco import COCO
from mindspore import context
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.common import set_seed
from mindspore.common import set_seed, Parameter
from src.FasterRcnn.faster_rcnn_r50 import Faster_Rcnn_Resnet50
from src.config import config
@ -34,16 +34,22 @@ parser = argparse.ArgumentParser(description="FasterRcnn evaluation")
parser.add_argument("--dataset", type=str, default="coco", help="Dataset, default is coco.")
parser.add_argument("--ann_file", type=str, default="val.json", help="Ann file, default is val.json.")
parser.add_argument("--checkpoint_path", type=str, required=True, help="Checkpoint file path.")
parser.add_argument("--device_target", type=str, default="Ascend",
help="device where the code will be implemented, default is Ascend")
parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
args_opt = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id)
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id)
def FasterRcnn_eval(dataset_path, ckpt_path, ann_file):
"""FasterRcnn evaluation."""
ds = create_fasterrcnn_dataset(dataset_path, batch_size=config.test_batch_size, is_training=False)
net = Faster_Rcnn_Resnet50(config)
param_dict = load_checkpoint(ckpt_path)
if args_opt.device_target == "GPU":
for key, value in param_dict.items():
tensor = value.asnumpy().astype(np.float32)
param_dict[key] = Parameter(tensor, key)
load_param_into_net(net, param_dict)
net.set_train(False)

@ -0,0 +1,44 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "=============================================================================================================="
echo "Please run the script as: "
echo "sh run_distribute_train_gpu.sh DEVICE_NUM PRETRAINED_PATH"
echo "for example: sh run_distribute_train_gpu.sh 8 /path/pretrain.ckpt"
echo "It is better to use absolute path."
echo "=============================================================================================================="
if [ $# != 2 ]
then
echo "Usage: sh run_distribute_train_gpu.sh [DEVICE_NUM] [PRETRAINED_PATH]"
exit 1
fi
rm -rf run_distribute_train
mkdir run_distribute_train
cp -rf ../src/ ../train.py ./run_distribute_train
cd run_distribute_train || exit
export RANK_SIZE=$1
PRETRAINED_PATH=$2
echo "start training on $RANK_SIZE devices"
mpirun -n $RANK_SIZE \
python train.py \
--run_distribute=True \
--device_target="GPU" \
--device_num=$RANK_SIZE \
--pre_trained=$PRETRAINED_PATH > log 2>&1 &

@ -0,0 +1,64 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 2 ]
then
echo "Usage: sh run_eval_gpu.sh [VALIDATION_JSON_FILE] [CHECKPOINT_PATH]"
exit 1
fi
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
PATH1=$(get_real_path $1)
PATH2=$(get_real_path $2)
echo $PATH1
echo $PATH2
if [ ! -f $PATH1 ]
then
echo "error: ANN_FILE=$PATH1 is not a file"
exit 1
fi
if [ ! -f $PATH2 ]
then
echo "error: CHECKPOINT_PATH=$PATH2 is not a file"
exit 1
fi
export DEVICE_NUM=1
export RANK_SIZE=$DEVICE_NUM
export DEVICE_ID=0
export RANK_ID=0
if [ -d "eval" ];
then
rm -rf ./eval
fi
mkdir ./eval
cp ../*.py ./eval
cp *.sh ./eval
cp -r ../src ./eval
cd ./eval || exit
env > env.log
echo "start eval for device $DEVICE_ID"
python eval.py --device_target="GPU" --device_id=$DEVICE_ID --ann_file=$PATH1 --checkpoint_path=$PATH2 &> log &
cd ..

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -19,11 +19,12 @@ import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.common.tensor import Tensor
import mindspore.common.dtype as mstype
from mindspore import context
class BboxAssignSample(nn.Cell):
"""
Bbox assigner and sampler defination.
Bbox assigner and sampler definition.
Args:
config (dict): Config.
@ -45,12 +46,15 @@ class BboxAssignSample(nn.Cell):
def __init__(self, config, batch_size, num_bboxes, add_gt_as_proposals):
super(BboxAssignSample, self).__init__()
cfg = config
_mode_16 = bool(context.get_context("device_target") == "Ascend")
self.dtype = np.float16 if _mode_16 else np.float32
self.ms_type = mstype.float16 if _mode_16 else mstype.float32
self.batch_size = batch_size
self.neg_iou_thr = Tensor(cfg.neg_iou_thr, mstype.float16)
self.pos_iou_thr = Tensor(cfg.pos_iou_thr, mstype.float16)
self.min_pos_iou = Tensor(cfg.min_pos_iou, mstype.float16)
self.zero_thr = Tensor(0.0, mstype.float16)
self.neg_iou_thr = Tensor(cfg.neg_iou_thr, self.ms_type)
self.pos_iou_thr = Tensor(cfg.pos_iou_thr, self.ms_type)
self.min_pos_iou = Tensor(cfg.min_pos_iou, self.ms_type)
self.zero_thr = Tensor(0.0, self.ms_type)
self.num_bboxes = num_bboxes
self.num_gts = cfg.num_gts
@ -92,9 +96,9 @@ class BboxAssignSample(nn.Cell):
self.assigned_pos_ones = Tensor(np.array(np.ones(self.num_expected_pos), dtype=np.int32))
self.check_neg_mask = Tensor(np.array(np.ones(self.num_expected_neg - self.num_expected_pos), dtype=np.bool))
self.range_pos_size = Tensor(np.arange(self.num_expected_pos).astype(np.float16))
self.check_gt_one = Tensor(np.array(-1 * np.ones((self.num_gts, 4)), dtype=np.float16))
self.check_anchor_two = Tensor(np.array(-2 * np.ones((self.num_bboxes, 4)), dtype=np.float16))
self.range_pos_size = Tensor(np.arange(self.num_expected_pos).astype(self.dtype))
self.check_gt_one = Tensor(np.array(-1 * np.ones((self.num_gts, 4)), dtype=self.dtype))
self.check_anchor_two = Tensor(np.array(-2 * np.ones((self.num_bboxes, 4)), dtype=self.dtype))
def construct(self, gt_bboxes_i, gt_labels_i, valid_mask, bboxes, gt_valids):
@ -129,7 +133,7 @@ class BboxAssignSample(nn.Cell):
pos_index, valid_pos_index = self.random_choice_with_mask_pos(self.greater(assigned_gt_inds5, 0))
pos_check_valid = self.cast(self.greater(assigned_gt_inds5, 0), mstype.float16)
pos_check_valid = self.cast(self.greater(assigned_gt_inds5, 0), self.ms_type)
pos_check_valid = self.sum_inds(pos_check_valid, -1)
valid_pos_index = self.less(self.range_pos_size, pos_check_valid)
pos_index = pos_index * self.reshape(self.cast(valid_pos_index, mstype.int32), (self.num_expected_pos, 1))
@ -140,7 +144,7 @@ class BboxAssignSample(nn.Cell):
neg_index, valid_neg_index = self.random_choice_with_mask_neg(self.equal(assigned_gt_inds5, 0))
num_pos = self.cast(self.logicalnot(valid_pos_index), mstype.float16)
num_pos = self.cast(self.logicalnot(valid_pos_index), self.ms_type)
num_pos = self.sum_inds(num_pos, -1)
unvalid_pos_index = self.less(self.range_pos_size, num_pos)
valid_neg_index = self.logicaland(self.concat((self.check_neg_mask, unvalid_pos_index)), valid_neg_index)

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -19,11 +19,12 @@ import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore.ops import operations as P
from mindspore.common.tensor import Tensor
from mindspore import context
class BboxAssignSampleForRcnn(nn.Cell):
"""
Bbox assigner and sampler defination.
Bbox assigner and sampler definition.
Args:
config (dict): Config.
@ -45,6 +46,9 @@ class BboxAssignSampleForRcnn(nn.Cell):
def __init__(self, config, batch_size, num_bboxes, add_gt_as_proposals):
super(BboxAssignSampleForRcnn, self).__init__()
cfg = config
_mode_16 = bool(context.get_context("device_target") == "Ascend")
self.dtype = np.float16 if _mode_16 else np.float32
self.ms_type = mstype.float16 if _mode_16 else mstype.float32
self.batch_size = batch_size
self.neg_iou_thr = cfg.neg_iou_thr_stage2
self.pos_iou_thr = cfg.pos_iou_thr_stage2
@ -83,8 +87,8 @@ class BboxAssignSampleForRcnn(nn.Cell):
self.tile = P.Tile()
# Check
self.check_gt_one = Tensor(np.array(-1 * np.ones((self.num_gts, 4)), dtype=np.float16))
self.check_anchor_two = Tensor(np.array(-2 * np.ones((self.num_bboxes, 4)), dtype=np.float16))
self.check_gt_one = Tensor(np.array(-1 * np.ones((self.num_gts, 4)), dtype=self.dtype))
self.check_anchor_two = Tensor(np.array(-2 * np.ones((self.num_bboxes, 4)), dtype=self.dtype))
# Init tensor
self.assigned_gt_inds = Tensor(np.array(-1 * np.ones(num_bboxes), dtype=np.int32))
@ -94,18 +98,18 @@ class BboxAssignSampleForRcnn(nn.Cell):
self.assigned_pos_ones = Tensor(np.array(np.ones(self.num_expected_pos), dtype=np.int32))
self.gt_ignores = Tensor(np.array(-1 * np.ones(self.num_gts), dtype=np.int32))
self.range_pos_size = Tensor(np.arange(self.num_expected_pos).astype(np.float16))
self.range_pos_size = Tensor(np.arange(self.num_expected_pos).astype(self.dtype))
self.check_neg_mask = Tensor(np.array(np.ones(self.num_expected_neg - self.num_expected_pos), dtype=np.bool))
self.bboxs_neg_mask = Tensor(np.zeros((self.num_expected_neg, 4), dtype=np.float16))
self.bboxs_neg_mask = Tensor(np.zeros((self.num_expected_neg, 4), dtype=self.dtype))
self.labels_neg_mask = Tensor(np.array(np.zeros(self.num_expected_neg), dtype=np.uint8))
self.reshape_shape_pos = (self.num_expected_pos, 1)
self.reshape_shape_neg = (self.num_expected_neg, 1)
self.scalar_zero = Tensor(0.0, dtype=mstype.float16)
self.scalar_neg_iou_thr = Tensor(self.neg_iou_thr, dtype=mstype.float16)
self.scalar_pos_iou_thr = Tensor(self.pos_iou_thr, dtype=mstype.float16)
self.scalar_min_pos_iou = Tensor(self.min_pos_iou, dtype=mstype.float16)
self.scalar_zero = Tensor(0.0, dtype=self.ms_type)
self.scalar_neg_iou_thr = Tensor(self.neg_iou_thr, dtype=self.ms_type)
self.scalar_pos_iou_thr = Tensor(self.pos_iou_thr, dtype=self.ms_type)
self.scalar_min_pos_iou = Tensor(self.min_pos_iou, dtype=self.ms_type)
def construct(self, gt_bboxes_i, gt_labels_i, valid_mask, bboxes, gt_valids):
gt_bboxes_i = self.select(self.cast(self.tile(self.reshape(self.cast(gt_valids, mstype.int32), \
@ -149,12 +153,12 @@ class BboxAssignSampleForRcnn(nn.Cell):
# Get pos index
pos_index, valid_pos_index = self.random_choice_with_mask_pos(self.greater(assigned_gt_inds5, 0))
pos_check_valid = self.cast(self.greater(assigned_gt_inds5, 0), mstype.float16)
pos_check_valid = self.cast(self.greater(assigned_gt_inds5, 0), self.ms_type)
pos_check_valid = self.sum_inds(pos_check_valid, -1)
valid_pos_index = self.less(self.range_pos_size, pos_check_valid)
pos_index = pos_index * self.reshape(self.cast(valid_pos_index, mstype.int32), (self.num_expected_pos, 1))
num_pos = self.sum_inds(self.cast(self.logicalnot(valid_pos_index), mstype.float16), -1)
num_pos = self.sum_inds(self.cast(self.logicalnot(valid_pos_index), self.ms_type), -1)
valid_pos_index = self.cast(valid_pos_index, mstype.int32)
pos_index = self.reshape(pos_index, self.reshape_shape_pos)
valid_pos_index = self.reshape(valid_pos_index, self.reshape_shape_pos)

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -20,6 +20,7 @@ from mindspore.ops import operations as P
from mindspore.common.tensor import Tensor
import mindspore.common.dtype as mstype
from mindspore.ops import functional as F
from mindspore import context
from .resnet50 import ResNetFea, ResidualBlockUsing
from .bbox_assign_sample_stage2 import BboxAssignSampleForRcnn
from .fpn_neck import FeatPyramidNeck
@ -50,6 +51,9 @@ class Faster_Rcnn_Resnet50(nn.Cell):
"""
def __init__(self, config):
super(Faster_Rcnn_Resnet50, self).__init__()
_mode_16 = bool(context.get_context("device_target") == "Ascend")
self.dtype = np.float16 if _mode_16 else np.float32
self.ms_type = mstype.float16 if _mode_16 else mstype.float32
self.train_batch_size = config.batch_size
self.num_classes = config.num_classes
self.anchor_scales = config.anchor_scales
@ -157,7 +161,7 @@ class Faster_Rcnn_Resnet50(nn.Cell):
self.rpn_max_num = config.rpn_max_num
self.zeros_for_nms = Tensor(np.zeros((self.rpn_max_num, 3)).astype(np.float16))
self.zeros_for_nms = Tensor(np.zeros((self.rpn_max_num, 3)).astype(self.dtype))
self.ones_mask = np.ones((self.rpn_max_num, 1)).astype(np.bool)
self.zeros_mask = np.zeros((self.rpn_max_num, 1)).astype(np.bool)
self.bbox_mask = Tensor(np.concatenate((self.ones_mask, self.zeros_mask,
@ -165,10 +169,10 @@ class Faster_Rcnn_Resnet50(nn.Cell):
self.nms_pad_mask = Tensor(np.concatenate((self.ones_mask, self.ones_mask,
self.ones_mask, self.ones_mask, self.zeros_mask), axis=1))
self.test_score_thresh = Tensor(np.ones((self.rpn_max_num, 1)).astype(np.float16) * config.test_score_thr)
self.test_score_zeros = Tensor(np.ones((self.rpn_max_num, 1)).astype(np.float16) * 0)
self.test_box_zeros = Tensor(np.ones((self.rpn_max_num, 4)).astype(np.float16) * -1)
self.test_iou_thr = Tensor(np.ones((self.rpn_max_num, 1)).astype(np.float16) * config.test_iou_thr)
self.test_score_thresh = Tensor(np.ones((self.rpn_max_num, 1)).astype(self.dtype) * config.test_score_thr)
self.test_score_zeros = Tensor(np.ones((self.rpn_max_num, 1)).astype(self.dtype) * 0)
self.test_box_zeros = Tensor(np.ones((self.rpn_max_num, 4)).astype(self.dtype) * -1)
self.test_iou_thr = Tensor(np.ones((self.rpn_max_num, 1)).astype(self.dtype) * config.test_iou_thr)
self.test_max_per_img = config.test_max_per_img
self.nms_test = P.NMSWithMask(config.test_iou_thr)
self.softmax = P.Softmax(axis=1)
@ -183,9 +187,9 @@ class Faster_Rcnn_Resnet50(nn.Cell):
# Init tensor
roi_align_index = [np.array(np.ones((config.num_expected_pos_stage2 + config.num_expected_neg_stage2, 1)) * i,
dtype=np.float16) for i in range(self.train_batch_size)]
dtype=self.dtype) for i in range(self.train_batch_size)]
roi_align_index_test = [np.array(np.ones((config.rpn_max_num, 1)) * i, dtype=np.float16) \
roi_align_index_test = [np.array(np.ones((config.rpn_max_num, 1)) * i, dtype=self.dtype) \
for i in range(self.test_batch_size)]
self.roi_align_index_tensor = Tensor(np.concatenate(roi_align_index))
@ -276,7 +280,7 @@ class Faster_Rcnn_Resnet50(nn.Cell):
self.cast(x[3], mstype.float32))
roi_feats = self.cast(roi_feats, mstype.float16)
roi_feats = self.cast(roi_feats, self.ms_type)
rcnn_masks = self.concat(mask_tuple)
rcnn_masks = F.stop_gradient(rcnn_masks)
rcnn_mask_squeeze = self.squeeze(self.cast(rcnn_masks, mstype.bool_))
@ -420,7 +424,7 @@ class Faster_Rcnn_Resnet50(nn.Cell):
for i in range(num_levels):
anchors = self.anchor_generators[i].grid_anchors(
featmap_sizes[i], self.anchor_strides[i])
multi_level_anchors += (Tensor(anchors.astype(np.float16)),)
multi_level_anchors += (Tensor(anchors.astype(self.dtype)),)
return multi_level_anchors

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -22,16 +22,20 @@ from mindspore.common.tensor import Tensor
from mindspore.common import dtype as mstype
from mindspore.common.initializer import initializer
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
def bias_init_zeros(shape):
"""Bias init method."""
return Tensor(np.array(np.zeros(shape).astype(np.float32)).astype(np.float16))
if context.get_context("device_target") == "Ascend":
return Tensor(np.array(np.zeros(shape).astype(np.float32)).astype(np.float16))
return Tensor(np.array(np.zeros(shape).astype(np.float32)))
def _conv(in_channels, out_channels, kernel_size=3, stride=1, padding=0, pad_mode='pad'):
"""Conv2D wrapper."""
shape = (out_channels, in_channels, kernel_size, kernel_size)
weights = initializer("XavierUniform", shape=shape, dtype=mstype.float16)
if context.get_context("device_target") == "Ascend":
weights = initializer("XavierUniform", shape=shape, dtype=mstype.float16).to_tensor()
else:
weights = initializer("XavierUniform", shape=shape, dtype=mstype.float32).to_tensor()
shape_bias = (out_channels,)
biass = bias_init_zeros(shape_bias)
return nn.Conv2d(in_channels, out_channels,

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -22,9 +22,6 @@ from mindspore import Tensor
from mindspore import context
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
class Proposal(nn.Cell):
"""
Proposal subnet.
@ -106,7 +103,11 @@ class Proposal(nn.Cell):
self.tile = P.Tile()
self.set_train_local(config, training=True)
self.multi_10 = Tensor(10.0, mstype.float16)
_mode_16 = bool(context.get_context("device_target") == "Ascend")
self.dtype = np.float16 if _mode_16 else np.float32
self.ms_type = mstype.float16 if _mode_16 else mstype.float32
self.multi_10 = Tensor(10.0, self.ms_type)
def set_train_local(self, config, training=True):
"""Set training flag."""
@ -133,7 +134,10 @@ class Proposal(nn.Cell):
self.topKv2 = P.TopK(sorted=True)
self.topK_shape_stage2 = (self.max_num, 1)
self.min_float_num = -65536.0
self.topK_mask = Tensor(self.min_float_num * np.ones(total_max_topk_input, np.float16))
if context.get_context("device_target") == "Ascend":
self.topK_mask = Tensor(self.min_float_num * np.ones(total_max_topk_input, np.float16))
else:
self.topK_mask = Tensor(self.min_float_num * np.ones(total_max_topk_input, np.float32))
def construct(self, rpn_cls_score_total, rpn_bbox_pred_total, anchor_list):
proposals_tuple = ()
@ -164,16 +168,16 @@ class Proposal(nn.Cell):
rpn_cls_score = self.reshape(rpn_cls_score, self.reshape_shape)
rpn_cls_score = self.activation(rpn_cls_score)
rpn_cls_score_process = self.cast(self.squeeze(rpn_cls_score[::, 0::]), mstype.float16)
rpn_cls_score_process = self.cast(self.squeeze(rpn_cls_score[::, 0::]), self.ms_type)
rpn_bbox_pred_process = self.cast(self.reshape(rpn_bbox_pred, (-1, 4)), mstype.float16)
rpn_bbox_pred_process = self.cast(self.reshape(rpn_bbox_pred, (-1, 4)), self.ms_type)
scores_sorted, topk_inds = self.topKv2(rpn_cls_score_process, self.topK_stage1[idx])
topk_inds = self.reshape(topk_inds, self.topK_shape[idx])
bboxes_sorted = self.gatherND(rpn_bbox_pred_process, topk_inds)
anchors_sorted = self.cast(self.gatherND(anchors, topk_inds), mstype.float16)
anchors_sorted = self.cast(self.gatherND(anchors, topk_inds), self.ms_type)
proposals_decode = self.decode(anchors_sorted, bboxes_sorted)
@ -188,7 +192,7 @@ class Proposal(nn.Cell):
_, _, _, _, scores = self.split(proposals)
scores = self.squeeze(scores)
topk_mask = self.cast(self.topK_mask, mstype.float16)
topk_mask = self.cast(self.topK_mask, self.ms_type)
scores_using = self.select(masks, scores, topk_mask)
_, topk_inds = self.topKv2(scores_using, self.max_num)

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -21,15 +21,19 @@ from mindspore.ops import operations as P
from mindspore.common.tensor import Tensor
from mindspore.common.initializer import initializer
from mindspore.common.parameter import Parameter
from mindspore import context
class DenseNoTranpose(nn.Cell):
"""Dense method"""
def __init__(self, input_channels, output_channels, weight_init):
super(DenseNoTranpose, self).__init__()
self.weight = Parameter(initializer(weight_init, [input_channels, output_channels], mstype.float16))
self.bias = Parameter(initializer("zeros", [output_channels], mstype.float16))
if context.get_context("device_target") == "Ascend":
self.weight = Parameter(initializer(weight_init, [input_channels, output_channels], mstype.float16))
self.bias = Parameter(initializer("zeros", [output_channels], mstype.float16))
else:
self.weight = Parameter(initializer(weight_init, [input_channels, output_channels], mstype.float32))
self.bias = Parameter(initializer("zeros", [output_channels], mstype.float32))
self.matmul = P.MatMul(transpose_b=False)
self.bias_add = P.BiasAdd()
@ -68,8 +72,11 @@ class Rcnn(nn.Cell):
):
super(Rcnn, self).__init__()
cfg = config
self.rcnn_loss_cls_weight = Tensor(np.array(cfg.rcnn_loss_cls_weight).astype(np.float16))
self.rcnn_loss_reg_weight = Tensor(np.array(cfg.rcnn_loss_reg_weight).astype(np.float16))
_mode_16 = bool(context.get_context("device_target") == "Ascend")
self.dtype = np.float16 if _mode_16 else np.float32
self.ms_type = mstype.float16 if _mode_16 else mstype.float32
self.rcnn_loss_cls_weight = Tensor(np.array(cfg.rcnn_loss_cls_weight).astype(self.dtype))
self.rcnn_loss_reg_weight = Tensor(np.array(cfg.rcnn_loss_reg_weight).astype(self.dtype))
self.rcnn_fc_out_channels = cfg.rcnn_fc_out_channels
self.target_means = target_means
self.target_stds = target_stds
@ -79,16 +86,16 @@ class Rcnn(nn.Cell):
self.test_batch_size = cfg.test_batch_size
shape_0 = (self.rcnn_fc_out_channels, representation_size)
weights_0 = initializer("XavierUniform", shape=shape_0[::-1], dtype=mstype.float16)
weights_0 = initializer("XavierUniform", shape=shape_0[::-1], dtype=self.ms_type).to_tensor()
shape_1 = (self.rcnn_fc_out_channels, self.rcnn_fc_out_channels)
weights_1 = initializer("XavierUniform", shape=shape_1[::-1], dtype=mstype.float16)
weights_1 = initializer("XavierUniform", shape=shape_1[::-1], dtype=self.ms_type).to_tensor()
self.shared_fc_0 = DenseNoTranpose(representation_size, self.rcnn_fc_out_channels, weights_0)
self.shared_fc_1 = DenseNoTranpose(self.rcnn_fc_out_channels, self.rcnn_fc_out_channels, weights_1)
cls_weight = initializer('Normal', shape=[num_classes, self.rcnn_fc_out_channels][::-1],
dtype=mstype.float16)
dtype=self.ms_type).to_tensor()
reg_weight = initializer('Normal', shape=[num_classes * 4, self.rcnn_fc_out_channels][::-1],
dtype=mstype.float16)
dtype=self.ms_type).to_tensor()
self.cls_scores = DenseNoTranpose(self.rcnn_fc_out_channels, num_classes, cls_weight)
self.reg_scores = DenseNoTranpose(self.rcnn_fc_out_channels, num_classes * 4, reg_weight)
@ -110,13 +117,13 @@ class Rcnn(nn.Cell):
self.on_value = Tensor(1.0, mstype.float32)
self.off_value = Tensor(0.0, mstype.float32)
self.value = Tensor(1.0, mstype.float16)
self.value = Tensor(1.0, self.ms_type)
self.num_bboxes = (cfg.num_expected_pos_stage2 + cfg.num_expected_neg_stage2) * batch_size
rmv_first = np.ones((self.num_bboxes, self.num_classes))
rmv_first[:, 0] = np.zeros((self.num_bboxes,))
self.rmv_first_tensor = Tensor(rmv_first.astype(np.float16))
self.rmv_first_tensor = Tensor(rmv_first.astype(self.dtype))
self.num_bboxes_test = cfg.rpn_max_num * cfg.test_batch_size
@ -134,7 +141,7 @@ class Rcnn(nn.Cell):
if self.training:
bbox_weights = self.cast(self.logicaland(self.greater(labels, 0), mask), mstype.int32) * labels
labels = self.cast(self.onehot(labels, self.num_classes, self.on_value, self.off_value), mstype.float16)
labels = self.cast(self.onehot(labels, self.num_classes, self.on_value, self.off_value), self.ms_type)
bbox_targets = self.tile(self.expandims(bbox_targets, 1), (1, self.num_classes, 1))
loss, loss_cls, loss_reg, loss_print = self.loss(x_cls, x_reg, bbox_targets, bbox_weights, labels, mask)
@ -149,12 +156,12 @@ class Rcnn(nn.Cell):
loss_print = ()
loss_cls, _ = self.loss_cls(cls_score, labels)
weights = self.cast(weights, mstype.float16)
weights = self.cast(weights, self.ms_type)
loss_cls = loss_cls * weights
loss_cls = self.sum_loss(loss_cls, (0,)) / self.sum_loss(weights, (0,))
bbox_weights = self.cast(self.onehot(bbox_weights, self.num_classes, self.on_value, self.off_value),
mstype.float16)
self.ms_type)
bbox_weights = bbox_weights * self.rmv_first_tensor
pos_bbox_pred = self.reshape(bbox_pred, (self.num_bboxes, -1, 4))

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -22,12 +22,11 @@ from mindspore.ops import functional as F
from mindspore import context
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
def weight_init_ones(shape):
"""Weight init."""
return Tensor(np.array(np.ones(shape).astype(np.float32) * 0.01).astype(np.float16))
if context.get_context("device_target") == "Ascend":
return Tensor(np.array(np.ones(shape).astype(np.float32) * 0.01).astype(np.float16))
return Tensor(np.array(np.ones(shape).astype(np.float32) * 0.01))
def _conv(in_channels, out_channels, kernel_size=3, stride=1, padding=0, pad_mode='pad'):
@ -41,11 +40,12 @@ def _conv(in_channels, out_channels, kernel_size=3, stride=1, padding=0, pad_mod
def _BatchNorm2dInit(out_chls, momentum=0.1, affine=True, use_batch_statistics=True):
"""Batchnorm2D wrapper."""
gamma_init = Tensor(np.array(np.ones(out_chls)).astype(np.float16))
beta_init = Tensor(np.array(np.ones(out_chls) * 0).astype(np.float16))
moving_mean_init = Tensor(np.array(np.ones(out_chls) * 0).astype(np.float16))
moving_var_init = Tensor(np.array(np.ones(out_chls)).astype(np.float16))
_mode_16 = bool(context.get_context("device_target") == "Ascend")
dtype = np.float16 if _mode_16 else np.float32
gamma_init = Tensor(np.array(np.ones(out_chls)).astype(dtype))
beta_init = Tensor(np.array(np.ones(out_chls) * 0).astype(dtype))
moving_mean_init = Tensor(np.array(np.ones(out_chls) * 0).astype(dtype))
moving_var_init = Tensor(np.array(np.ones(out_chls)).astype(dtype))
return nn.BatchNorm2d(out_chls, momentum=momentum, affine=affine, gamma_init=gamma_init,
beta_init=beta_init, moving_mean_init=moving_mean_init,
moving_var_init=moving_var_init, use_batch_statistics=use_batch_statistics)

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -17,7 +17,7 @@ import numpy as np
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore.ops import operations as P
from mindspore import Tensor
from mindspore import Tensor, context
from mindspore.ops import functional as F
from mindspore.common.initializer import initializer
from .bbox_assign_sample import BboxAssignSample
@ -100,6 +100,9 @@ class RPN(nn.Cell):
cls_out_channels):
super(RPN, self).__init__()
cfg_rpn = config
_mode_16 = bool(context.get_context("device_target") == "Ascend")
self.dtype = np.float16 if _mode_16 else np.float32
self.ms_type = mstype.float16 if _mode_16 else mstype.float32
self.num_bboxes = cfg_rpn.num_bboxes
self.slice_index = ()
self.feature_anchor_shape = ()
@ -114,7 +117,7 @@ class RPN(nn.Cell):
self.batch_size = batch_size
self.test_batch_size = cfg_rpn.test_batch_size
self.num_layers = 5
self.real_ratio = Tensor(np.ones((1, 1)).astype(np.float16))
self.real_ratio = Tensor(np.ones((1, 1)).astype(self.dtype))
self.rpn_convs_list = nn.layer.CellList(self._make_rpn_layer(self.num_layers, in_channels, feat_channels,
num_anchors, cls_out_channels))
@ -123,15 +126,15 @@ class RPN(nn.Cell):
self.reshape = P.Reshape()
self.concat = P.Concat(axis=0)
self.fill = P.Fill()
self.placeh1 = Tensor(np.ones((1,)).astype(np.float16))
self.placeh1 = Tensor(np.ones((1,)).astype(self.dtype))
self.trans_shape = (0, 2, 3, 1)
self.reshape_shape_reg = (-1, 4)
self.reshape_shape_cls = (-1,)
self.rpn_loss_reg_weight = Tensor(np.array(cfg_rpn.rpn_loss_reg_weight).astype(np.float16))
self.rpn_loss_cls_weight = Tensor(np.array(cfg_rpn.rpn_loss_cls_weight).astype(np.float16))
self.num_expected_total = Tensor(np.array(cfg_rpn.num_expected_neg * self.batch_size).astype(np.float16))
self.rpn_loss_reg_weight = Tensor(np.array(cfg_rpn.rpn_loss_reg_weight).astype(self.dtype))
self.rpn_loss_cls_weight = Tensor(np.array(cfg_rpn.rpn_loss_cls_weight).astype(self.dtype))
self.num_expected_total = Tensor(np.array(cfg_rpn.num_expected_neg * self.batch_size).astype(self.dtype))
self.num_bboxes = cfg_rpn.num_bboxes
self.get_targets = BboxAssignSample(cfg_rpn, self.batch_size, self.num_bboxes, False)
self.CheckValid = P.CheckValid()
@ -142,9 +145,9 @@ class RPN(nn.Cell):
self.cast = P.Cast()
self.tile = P.Tile()
self.zeros_like = P.ZerosLike()
self.loss = Tensor(np.zeros((1,)).astype(np.float16))
self.clsloss = Tensor(np.zeros((1,)).astype(np.float16))
self.regloss = Tensor(np.zeros((1,)).astype(np.float16))
self.loss = Tensor(np.zeros((1,)).astype(self.dtype))
self.clsloss = Tensor(np.zeros((1,)).astype(self.dtype))
self.regloss = Tensor(np.zeros((1,)).astype(self.dtype))
def _make_rpn_layer(self, num_layers, in_channels, feat_channels, num_anchors, cls_out_channels):
"""
@ -164,18 +167,18 @@ class RPN(nn.Cell):
shp_weight_conv = (feat_channels, in_channels, 3, 3)
shp_bias_conv = (feat_channels,)
weight_conv = initializer('Normal', shape=shp_weight_conv, dtype=mstype.float16)
bias_conv = initializer(0, shape=shp_bias_conv, dtype=mstype.float16)
weight_conv = initializer('Normal', shape=shp_weight_conv, dtype=self.ms_type).to_tensor()
bias_conv = initializer(0, shape=shp_bias_conv, dtype=self.ms_type).to_tensor()
shp_weight_cls = (num_anchors * cls_out_channels, feat_channels, 1, 1)
shp_bias_cls = (num_anchors * cls_out_channels,)
weight_cls = initializer('Normal', shape=shp_weight_cls, dtype=mstype.float16)
bias_cls = initializer(0, shape=shp_bias_cls, dtype=mstype.float16)
weight_cls = initializer('Normal', shape=shp_weight_cls, dtype=self.ms_type).to_tensor()
bias_cls = initializer(0, shape=shp_bias_cls, dtype=self.ms_type).to_tensor()
shp_weight_reg = (num_anchors * 4, feat_channels, 1, 1)
shp_bias_reg = (num_anchors * 4,)
weight_reg = initializer('Normal', shape=shp_weight_reg, dtype=mstype.float16)
bias_reg = initializer(0, shape=shp_bias_reg, dtype=mstype.float16)
weight_reg = initializer('Normal', shape=shp_weight_reg, dtype=self.ms_type).to_tensor()
bias_reg = initializer(0, shape=shp_bias_reg, dtype=self.ms_type).to_tensor()
for i in range(num_layers):
rpn_layer.append(RpnRegClsBlock(in_channels, feat_channels, num_anchors, cls_out_channels, \
@ -248,9 +251,9 @@ class RPN(nn.Cell):
mstype.bool_),
anchor_using_list, gt_valids_i)
bbox_weight = self.cast(bbox_weight, mstype.float16)
label = self.cast(label, mstype.float16)
label_weight = self.cast(label_weight, mstype.float16)
bbox_weight = self.cast(bbox_weight, self.ms_type)
label = self.cast(label, self.ms_type)
label_weight = self.cast(label_weight, self.ms_type)
for j in range(self.num_layers):
begin = self.slice_index[j]

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -113,8 +113,6 @@ config = ed({
# LR
"base_lr": 0.02,
"base_step": 58633,
"total_epoch": 13,
"warmup_step": 500,
"warmup_ratio": 1/3.0,
"sgd_step": [8, 11],

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -21,6 +21,7 @@ import numpy as np
from numpy import random
import mmcv
from mindspore import context
import mindspore.dataset as de
import mindspore.dataset.vision.c_transforms as C
from mindspore.mindrecord import FileWriter
@ -213,7 +214,7 @@ def impad_to_multiple_column(img, img_shape, gt_bboxes, gt_label, gt_num):
def imnormalize_column(img, img_shape, gt_bboxes, gt_label, gt_num):
"""imnormalize operation for image"""
img_data = mmcv.imnormalize(img, [123.675, 116.28, 103.53], [58.395, 57.12, 57.375], True)
img_data = mmcv.imnormalize(img, np.array([123.675, 116.28, 103.53]), np.array([58.395, 57.12, 57.375]), True)
img_data = img_data.astype(np.float32)
return (img_data, img_shape, gt_bboxes, gt_label, gt_num)
@ -232,9 +233,14 @@ def flip_column(img, img_shape, gt_bboxes, gt_label, gt_num):
def transpose_column(img, img_shape, gt_bboxes, gt_label, gt_num):
"""transpose operation for image"""
img_data = img.transpose(2, 0, 1).copy()
img_data = img_data.astype(np.float16)
img_shape = img_shape.astype(np.float16)
gt_bboxes = gt_bboxes.astype(np.float16)
if context.get_context("device_target") == "Ascend":
img_data = img_data.astype(np.float16)
img_shape = img_shape.astype(np.float16)
gt_bboxes = gt_bboxes.astype(np.float16)
else:
img_data = img_data.astype(np.float32)
img_shape = img_shape.astype(np.float32)
gt_bboxes = gt_bboxes.astype(np.float32)
gt_label = gt_label.astype(np.int32)
gt_num = gt_num.astype(np.bool)

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -25,12 +25,10 @@ def a_cosine_learning_rate(current_step, base_lr, warmup_steps, decay_steps):
learning_rate = (1 + math.cos(base * math.pi)) / 2 * base_lr
return learning_rate
def dynamic_lr(config, rank_size=1):
def dynamic_lr(config, steps_per_epoch):
"""dynamic learning rate generator"""
base_lr = config.base_lr
base_step = (config.base_step // rank_size) + rank_size
total_steps = int(base_step * config.total_epoch)
total_steps = steps_per_epoch * config.epoch_size
warmup_steps = int(config.warmup_step)
lr = []
for i in range(total_steps):

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -20,7 +20,7 @@ import mindspore.nn as nn
from mindspore.common.tensor import Tensor
from mindspore.ops import functional as F
from mindspore.ops import composite as C
from mindspore import ParameterTuple
from mindspore import ParameterTuple, context
from mindspore.train.callback import Callback
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
@ -167,7 +167,10 @@ class TrainOneStepCell(nn.Cell):
self.optimizer = optimizer
self.grad = C.GradOperation(get_by_list=True,
sens_param=True)
self.sens = Tensor((np.ones((1,)) * sens).astype(np.float16))
if context.get_context("device_target") == "Ascend":
self.sens = Tensor((np.ones((1,)) * sens).astype(np.float16))
else:
self.sens = Tensor((np.ones((1,)) * sens).astype(np.float32))
self.reduce_flag = reduce_flag
if reduce_flag:
self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -19,10 +19,11 @@ import os
import time
import argparse
import ast
import numpy as np
import mindspore.common.dtype as mstype
from mindspore import context, Tensor
from mindspore.communication.management import init
from mindspore import context, Tensor, Parameter
from mindspore.communication.management import init, get_rank, get_group_size
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, TimeMonitor
from mindspore.train import Model
from mindspore.context import ParallelMode
@ -42,20 +43,30 @@ parser = argparse.ArgumentParser(description="FasterRcnn training")
parser.add_argument("--run_distribute", type=ast.literal_eval, default=False, help="Run distribute, default: false.")
parser.add_argument("--dataset", type=str, default="coco", help="Dataset name, default: coco.")
parser.add_argument("--pre_trained", type=str, default="", help="Pretrained file path.")
parser.add_argument("--device_target", type=str, default="Ascend",
help="device where the code will be implemented, default is Ascend")
parser.add_argument("--device_id", type=int, default=0, help="Device id, default: 0.")
parser.add_argument("--device_num", type=int, default=1, help="Use device nums, default: 1.")
parser.add_argument("--rank_id", type=int, default=0, help="Rank id, default: 0.")
args_opt = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id)
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id)
if __name__ == '__main__':
if args_opt.run_distribute:
rank = args_opt.rank_id
device_num = args_opt.device_num
context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
gradients_mean=True)
init()
if args_opt.device_target == "Ascend":
rank = args_opt.rank_id
device_num = args_opt.device_num
context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
gradients_mean=True)
init()
else:
init("nccl")
context.reset_auto_parallel_context()
rank = get_rank()
device_num = get_group_size()
context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
gradients_mean=True)
else:
rank = 0
device_num = 1
@ -116,10 +127,14 @@ if __name__ == '__main__':
for item in list(param_dict.keys()):
if not item.startswith('backbone'):
param_dict.pop(item)
if args_opt.device_target == "GPU":
for key, value in param_dict.items():
tensor = value.asnumpy().astype(np.float32)
param_dict[key] = Parameter(tensor, key)
load_param_into_net(net, param_dict)
loss = LossNet()
lr = Tensor(dynamic_lr(config, rank_size=device_num), mstype.float32)
lr = Tensor(dynamic_lr(config, dataset_size), mstype.float32)
opt = SGD(params=net.trainable_params(), learning_rate=lr, momentum=config.momentum,
weight_decay=config.weight_decay, loss_scale=config.loss_scale)
@ -141,4 +156,4 @@ if __name__ == '__main__':
cb += [ckpoint_cb]
model = Model(net)
model.train(config.epoch_size, dataset, callbacks=cb)
model.train(config.epoch_size, dataset, callbacks=cb, dataset_sink_mode=False)

Loading…
Cancel
Save