|
|
|
@ -365,6 +365,7 @@ class YOLOV3DarkNet53(nn.Cell):
|
|
|
|
|
def __init__(self, is_training):
|
|
|
|
|
super(YOLOV3DarkNet53, self).__init__()
|
|
|
|
|
self.config = ConfigYOLOV3DarkNet53()
|
|
|
|
|
self.tenser_to_array = P.TupleToArray()
|
|
|
|
|
|
|
|
|
|
# YOLOv3 network
|
|
|
|
|
self.feature_map = YOLOv3(backbone=DarkNet(ResidualBlock, self.config.backbone_layers,
|
|
|
|
@ -379,7 +380,9 @@ class YOLOV3DarkNet53(nn.Cell):
|
|
|
|
|
self.detect_2 = DetectionBlock('m', is_training=is_training)
|
|
|
|
|
self.detect_3 = DetectionBlock('s', is_training=is_training)
|
|
|
|
|
|
|
|
|
|
def construct(self, x, input_shape):
|
|
|
|
|
def construct(self, x):
|
|
|
|
|
input_shape = F.shape(x)[2:4]
|
|
|
|
|
input_shape = F.cast(self.tenser_to_array(input_shape), ms.float32)
|
|
|
|
|
big_object_output, medium_object_output, small_object_output = self.feature_map(x)
|
|
|
|
|
output_big = self.detect_1(big_object_output, input_shape)
|
|
|
|
|
output_me = self.detect_2(medium_object_output, input_shape)
|
|
|
|
@ -394,12 +397,15 @@ class YoloWithLossCell(nn.Cell):
|
|
|
|
|
super(YoloWithLossCell, self).__init__()
|
|
|
|
|
self.yolo_network = network
|
|
|
|
|
self.config = ConfigYOLOV3DarkNet53()
|
|
|
|
|
self.tenser_to_array = P.TupleToArray()
|
|
|
|
|
self.loss_big = YoloLossBlock('l', self.config)
|
|
|
|
|
self.loss_me = YoloLossBlock('m', self.config)
|
|
|
|
|
self.loss_small = YoloLossBlock('s', self.config)
|
|
|
|
|
|
|
|
|
|
def construct(self, x, y_true_0, y_true_1, y_true_2, gt_0, gt_1, gt_2, input_shape):
|
|
|
|
|
yolo_out = self.yolo_network(x, input_shape)
|
|
|
|
|
def construct(self, x, y_true_0, y_true_1, y_true_2, gt_0, gt_1, gt_2):
|
|
|
|
|
input_shape = F.shape(x)[2:4]
|
|
|
|
|
input_shape = F.cast(self.tenser_to_array(input_shape), ms.float32)
|
|
|
|
|
yolo_out = self.yolo_network(x)
|
|
|
|
|
loss_l = self.loss_big(*yolo_out[0], y_true_0, gt_0, input_shape)
|
|
|
|
|
loss_m = self.loss_me(*yolo_out[1], y_true_1, gt_1, input_shape)
|
|
|
|
|
loss_s = self.loss_small(*yolo_out[2], y_true_2, gt_2, input_shape)
|
|
|
|
|