diff --git a/model_zoo/official/cv/faster_rcnn/eval.py b/model_zoo/official/cv/faster_rcnn/eval.py index 0316a82d58..7a778a34c4 100644 --- a/model_zoo/official/cv/faster_rcnn/eval.py +++ b/model_zoo/official/cv/faster_rcnn/eval.py @@ -19,6 +19,7 @@ import argparse import time import numpy as np from pycocotools.coco import COCO +import mindspore.common.dtype as mstype from mindspore import context from mindspore.train.serialization import load_checkpoint, load_param_into_net from mindspore.common import set_seed, Parameter @@ -51,7 +52,11 @@ def fasterrcnn_eval(dataset_path, ckpt_path, ann_file): tensor = value.asnumpy().astype(np.float32) param_dict[key] = Parameter(tensor, key) load_param_into_net(net, param_dict) + net.set_train(False) + device_type = "Ascend" if context.get_context("device_target") == "Ascend" else "Others" + if device_type == "Ascend": + net.to_float(mstype.float16) eval_iter = 0 total = ds.get_dataset_size() diff --git a/model_zoo/official/cv/faster_rcnn/src/FasterRcnn/faster_rcnn_r50.py b/model_zoo/official/cv/faster_rcnn/src/FasterRcnn/faster_rcnn_r50.py index 792cce4de6..765f13d6bc 100644 --- a/model_zoo/official/cv/faster_rcnn/src/FasterRcnn/faster_rcnn_r50.py +++ b/model_zoo/official/cv/faster_rcnn/src/FasterRcnn/faster_rcnn_r50.py @@ -16,6 +16,7 @@ import numpy as np import mindspore.nn as nn +from mindspore import context from mindspore.ops import operations as P from mindspore.common.tensor import Tensor import mindspore.common.dtype as mstype @@ -144,6 +145,7 @@ class Faster_Rcnn_Resnet50(nn.Cell): # Init tensor self.init_tensor(config) + self.device_type = "Ascend" if context.get_context("device_target") == "Ascend" else "Others" def roi_init(self, config): self.roi_align = SingleRoIExtractor(config, @@ -267,6 +269,8 @@ class Faster_Rcnn_Resnet50(nn.Cell): bboxes_all = self.concat(bboxes_tuple) else: bboxes_all = bboxes_tuple[0] + if self.device_type == "Ascend": + bboxes_all = self.cast(bboxes_all, mstype.float16) rois = self.concat_1((self.roi_align_index_test_tensor, bboxes_all)) rois = self.cast(rois, mstype.float32) diff --git a/model_zoo/official/cv/faster_rcnn/src/FasterRcnn/rcnn.py b/model_zoo/official/cv/faster_rcnn/src/FasterRcnn/rcnn.py index 75cacda893..4eade1f188 100644 --- a/model_zoo/official/cv/faster_rcnn/src/FasterRcnn/rcnn.py +++ b/model_zoo/official/cv/faster_rcnn/src/FasterRcnn/rcnn.py @@ -40,7 +40,7 @@ class DenseNoTranpose(nn.Cell): if self.device_type == "Ascend": x = self.cast(x, mstype.float16) weight = self.cast(self.weight, mstype.float16) - output = self.bias_add(self.cast(self.matmul(x, weight), mstype.float32), self.bias) + output = self.bias_add(self.matmul(x, weight), self.bias) else: output = self.bias_add(self.matmul(x, self.weight), self.bias) return output diff --git a/model_zoo/official/cv/faster_rcnn/src/FasterRcnn/rpn.py b/model_zoo/official/cv/faster_rcnn/src/FasterRcnn/rpn.py index 5a2bfc5c55..52d12a30cf 100644 --- a/model_zoo/official/cv/faster_rcnn/src/FasterRcnn/rpn.py +++ b/model_zoo/official/cv/faster_rcnn/src/FasterRcnn/rpn.py @@ -16,7 +16,7 @@ import numpy as np import mindspore.nn as nn import mindspore.common.dtype as mstype -from mindspore import Tensor +from mindspore import context, Tensor from mindspore.ops import operations as P from mindspore.ops import functional as F from mindspore.common.initializer import initializer @@ -102,6 +102,7 @@ class RPN(nn.Cell): cfg_rpn = config self.dtype = np.float32 self.ms_type = mstype.float32 + self.device_type = "Ascend" if context.get_context("device_target") == "Ascend" else "Others" self.num_bboxes = cfg_rpn.num_bboxes self.slice_index = () self.feature_anchor_shape = () @@ -180,9 +181,12 @@ class RPN(nn.Cell): bias_reg = initializer(0, shape=shp_bias_reg, dtype=self.ms_type).to_tensor() for i in range(num_layers): - rpn_layer.append(RpnRegClsBlock(in_channels, feat_channels, num_anchors, cls_out_channels, \ + rpn_reg_cls_block = RpnRegClsBlock(in_channels, feat_channels, num_anchors, cls_out_channels, \ weight_conv, bias_conv, weight_cls, \ - bias_cls, weight_reg, bias_reg)) + bias_cls, weight_reg, bias_reg) + if self.device_type == "Ascend": + rpn_reg_cls_block.to_float(mstype.float16) + rpn_layer.append(rpn_reg_cls_block) for i in range(1, num_layers): rpn_layer[i].rpn_conv.weight = rpn_layer[0].rpn_conv.weight @@ -250,6 +254,7 @@ class RPN(nn.Cell): mstype.bool_), anchor_using_list, gt_valids_i) + bbox_target = self.cast(bbox_target, self.ms_type) bbox_weight = self.cast(bbox_weight, self.ms_type) label = self.cast(label, self.ms_type) label_weight = self.cast(label_weight, self.ms_type) @@ -286,8 +291,8 @@ class RPN(nn.Cell): label_ = F.stop_gradient(label_with_batchsize) label_weight_ = F.stop_gradient(label_weight_with_batchsize) - cls_score_i = rpn_cls_score[i] - reg_score_i = rpn_bbox_pred[i] + cls_score_i = self.cast(rpn_cls_score[i], self.ms_type) + reg_score_i = self.cast(rpn_bbox_pred[i], self.ms_type) loss_cls = self.loss_cls(cls_score_i, label_) loss_cls_item = loss_cls * label_weight_ diff --git a/model_zoo/official/cv/faster_rcnn/train.py b/model_zoo/official/cv/faster_rcnn/train.py index 5eb9817c01..45729f282c 100644 --- a/model_zoo/official/cv/faster_rcnn/train.py +++ b/model_zoo/official/cv/faster_rcnn/train.py @@ -152,6 +152,10 @@ if __name__ == '__main__': param_dict[key] = Parameter(tensor, key) load_param_into_net(net, param_dict) + device_type = "Ascend" if context.get_context("device_target") == "Ascend" else "Others" + if device_type == "Ascend": + net.to_float(mstype.float16) + loss = LossNet() lr = Tensor(dynamic_lr(config, dataset_size), mstype.float32)