Use a numerically stable Sigmoid Cross Entropy implementation. test=develop

revert-15296-async_double_buffered_py_reader
dengkaipeng 6 years ago
parent 245b1f0579
commit 192d293854

@ -99,6 +99,10 @@ class Yolov3LossOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<std::vector<int>>("anchors",
"The anchor width and height, "
"it will be parsed pair by pair.");
// Input resolution attribute of the YOLOv3 network. The default is 416, one
// of the sizes named in the description below; the previous default (406)
// was not among them and looks like a typo.
AddAttr<int>("input_size",
             "The input size of YOLOv3 net, "
             "generally this is set as 320, 416 or 608.")
    .SetDefault(416);
AddAttr<float>("ignore_thresh",
"The ignore threshold to ignore confidence loss.");
AddAttr<float>("loss_weight_xy", "The weight of x, y location loss.")

File diff suppressed because it is too large Load Diff

@ -415,6 +415,7 @@ def yolov3_loss(x,
anchors,
class_num,
ignore_thresh,
input_size,
loss_weight_xy=None,
loss_weight_wh=None,
loss_weight_conf_target=None,
@ -436,6 +437,7 @@ def yolov3_loss(x,
anchors (list|tuple): ${anchors_comment}
class_num (int): ${class_num_comment}
ignore_thresh (float): ${ignore_thresh_comment}
input_size (int): ${input_size_comment}
loss_weight_xy (float|None): ${loss_weight_xy_comment}
loss_weight_wh (float|None): ${loss_weight_wh_comment}
loss_weight_conf_target (float|None): ${loss_weight_conf_target_comment}
@ -490,6 +492,7 @@ def yolov3_loss(x,
"anchors": anchors,
"class_num": class_num,
"ignore_thresh": ignore_thresh,
"input_size": input_size,
}
if loss_weight_xy is not None and isinstance(loss_weight_xy, float):

@ -464,7 +464,7 @@ class TestYoloDetection(unittest.TestCase):
gtbox = layers.data(name='gtbox', shape=[10, 4], dtype='float32')
gtlabel = layers.data(name='gtlabel', shape=[10], dtype='int32')
loss = layers.yolov3_loss(x, gtbox, gtlabel, [10, 13, 30, 13], 10,
0.5)
0.7, 416)
self.assertIsNotNone(loss)

@ -16,31 +16,22 @@ from __future__ import division
import unittest
import numpy as np
from scipy.special import logit
from scipy.special import expit
from op_test import OpTest
from paddle.fluid import core
def sigmoid(x):
    """Return the logistic function ``1 / (1 + exp(-x))`` of ``x``.

    The computation goes through ``np.exp``, so NumPy arrays are handled
    element-wise and plain scalars are accepted as well.
    """
    negated = -1.0 * x
    denominator = 1.0 + np.exp(negated)
    return 1.0 / denominator
def mse(x, y, weight, num):
    """Return the weighted sum of squared errors divided by ``num``.

    Each squared difference ``(y - x)**2`` is scaled by the matching entry of
    ``weight`` before everything is summed; ``num`` is the normalizer supplied
    by the caller.
    """
    squared_error = (y - x)**2
    weighted_total = (squared_error * weight).sum()
    return weighted_total / num
def mse(x, y, num):
    """Return the sum of squared differences between ``y`` and ``x`` over ``num``.

    ``num`` is the normalizer chosen by the caller; it is not derived from the
    size of the inputs.
    """
    squared_error = (y - x)**2
    total = squared_error.sum()
    return total / num
def bce(x, y, mask):
    """Return the mean binary cross entropy over the masked-in elements.

    All three inputs are flattened. Only positions whose mask value is
    greater than zero contribute, and the accumulated log-likelihood is
    negated and divided by the number of such positions.

    NOTE(review): ``x`` is fed straight into ``np.log``, so it is presumably
    expected to hold probabilities strictly inside (0, 1) -- confirm with the
    caller. If no mask entry is positive the final division is by zero.
    """
    flat_x = x.reshape(-1)
    flat_y = y.reshape(-1)
    flat_mask = mask.reshape(-1)

    log_likelihood = 0.0
    selected = 0
    # Accumulate sequentially so the floating-point summation order is kept.
    for idx, prob in enumerate(flat_x):
        if not flat_mask[idx] > 0:
            continue
        target = flat_y[idx]
        log_likelihood += target * np.log(prob) + (1 - target) * np.log(1 - prob)
        selected += 1
    return log_likelihood / (-1.0 * selected)
def sce(x, label, weight, num):
    """Return the weighted sigmoid cross entropy of logits ``x`` over ``num``.

    Args:
        x: Raw logits (values before the sigmoid).
        label: Target values, combined element-wise with ``x``.
        weight: Element-wise weight applied to each loss term before summing.
        num: Normalizer the weighted sum is divided by.

    Returns:
        ``sum(weight * loss) / num`` where ``loss`` is the element-wise
        sigmoid cross entropy between ``x`` and ``label``.
    """
    # Numerically stable form of
    #     -label * log(sigmoid(x)) - (1 - label) * log(1 - sigmoid(x)).
    # Taking log() of an explicitly computed sigmoid saturates to log(0) for
    # large |x| and yields inf/nan; the rewrite below never evaluates log(0)
    # for finite logits.
    loss = np.maximum(x, 0.0) - x * label + np.log1p(np.exp(-np.abs(x)))
    return (loss * weight).sum() / num
def box_iou(box1, box2):
@ -66,11 +57,12 @@ def box_iou(box1, box2):
return inter_area / (b1_area + b2_area + inter_area)
def build_target(gtboxs, gtlabel, attrs, grid_size):
n, b, _ = gtboxs.shape
def build_target(gtboxes, gtlabel, attrs, grid_size):
n, b, _ = gtboxes.shape
ignore_thresh = attrs["ignore_thresh"]
anchors = attrs["anchors"]
class_num = attrs["class_num"]
input_size = attrs["input_size"]
an_num = len(anchors) // 2
obj_mask = np.zeros((n, an_num, grid_size, grid_size)).astype('float32')
noobj_mask = np.ones((n, an_num, grid_size, grid_size)).astype('float32')
@ -78,20 +70,21 @@ def build_target(gtboxs, gtlabel, attrs, grid_size):
ty = np.zeros((n, an_num, grid_size, grid_size)).astype('float32')
tw = np.zeros((n, an_num, grid_size, grid_size)).astype('float32')
th = np.zeros((n, an_num, grid_size, grid_size)).astype('float32')
tweight = np.zeros((n, an_num, grid_size, grid_size)).astype('float32')
tconf = np.zeros((n, an_num, grid_size, grid_size)).astype('float32')
tcls = np.zeros(
(n, an_num, grid_size, grid_size, class_num)).astype('float32')
for i in range(n):
for j in range(b):
if gtboxs[i, j, :].sum() == 0:
if gtboxes[i, j, :].sum() == 0:
continue
gt_label = gtlabel[i, j]
gx = gtboxs[i, j, 0] * grid_size
gy = gtboxs[i, j, 1] * grid_size
gw = gtboxs[i, j, 2] * grid_size
gh = gtboxs[i, j, 3] * grid_size
gx = gtboxes[i, j, 0] * grid_size
gy = gtboxes[i, j, 1] * grid_size
gw = gtboxes[i, j, 2] * input_size
gh = gtboxes[i, j, 3] * input_size
gi = int(gx)
gj = int(gy)
@ -115,10 +108,12 @@ def build_target(gtboxs, gtlabel, attrs, grid_size):
best_an_index])
th[i, best_an_index, gj, gi] = np.log(
gh / anchors[2 * best_an_index + 1])
tweight[i, best_an_index, gj, gi] = 2.0 - gtboxes[
i, j, 2] * gtboxes[i, j, 3]
tconf[i, best_an_index, gj, gi] = 1
tcls[i, best_an_index, gj, gi, gt_label] = 1
return (tx, ty, tw, th, tconf, tcls, obj_mask, noobj_mask)
return (tx, ty, tw, th, tweight, tconf, tcls, obj_mask, noobj_mask)
def YoloV3Loss(x, gtbox, gtlabel, attrs):
@ -126,27 +121,28 @@ def YoloV3Loss(x, gtbox, gtlabel, attrs):
an_num = len(attrs['anchors']) // 2
class_num = attrs["class_num"]
x = x.reshape((n, an_num, 5 + class_num, h, w)).transpose((0, 1, 3, 4, 2))
pred_x = sigmoid(x[:, :, :, :, 0])
pred_y = sigmoid(x[:, :, :, :, 1])
pred_x = x[:, :, :, :, 0]
pred_y = x[:, :, :, :, 1]
pred_w = x[:, :, :, :, 2]
pred_h = x[:, :, :, :, 3]
pred_conf = sigmoid(x[:, :, :, :, 4])
pred_cls = sigmoid(x[:, :, :, :, 5:])
pred_conf = x[:, :, :, :, 4]
pred_cls = x[:, :, :, :, 5:]
tx, ty, tw, th, tconf, tcls, obj_mask, noobj_mask = build_target(
tx, ty, tw, th, tweight, tconf, tcls, obj_mask, noobj_mask = build_target(
gtbox, gtlabel, attrs, x.shape[2])
obj_weight = obj_mask * tweight
obj_mask_expand = np.tile(
np.expand_dims(obj_mask, 4), (1, 1, 1, 1, int(attrs['class_num'])))
loss_x = mse(pred_x * obj_mask, tx * obj_mask, obj_mask.sum())
loss_y = mse(pred_y * obj_mask, ty * obj_mask, obj_mask.sum())
loss_w = mse(pred_w * obj_mask, tw * obj_mask, obj_mask.sum())
loss_h = mse(pred_h * obj_mask, th * obj_mask, obj_mask.sum())
loss_conf_target = bce(pred_conf * obj_mask, tconf * obj_mask, obj_mask)
loss_conf_notarget = bce(pred_conf * noobj_mask, tconf * noobj_mask,
noobj_mask)
loss_class = bce(pred_cls * obj_mask_expand, tcls * obj_mask_expand,
obj_mask_expand)
box_f = an_num * h * w
class_f = an_num * h * w * class_num
loss_x = sce(pred_x, tx, obj_weight, box_f)
loss_y = sce(pred_y, ty, obj_weight, box_f)
loss_w = mse(pred_w, tw, obj_weight, box_f)
loss_h = mse(pred_h, th, obj_weight, box_f)
loss_conf_target = sce(pred_conf, tconf, obj_mask, box_f)
loss_conf_notarget = sce(pred_conf, tconf, noobj_mask, box_f)
loss_class = sce(pred_cls, tcls, obj_mask_expand, class_f)
return attrs['loss_weight_xy'] * (loss_x + loss_y) \
+ attrs['loss_weight_wh'] * (loss_w + loss_h) \
@ -164,7 +160,7 @@ class TestYolov3LossOp(OpTest):
self.loss_weight_class = 1.0
self.initTestCase()
self.op_type = 'yolov3_loss'
x = np.random.random(size=self.x_shape).astype('float32')
x = logit(np.random.uniform(0, 1, self.x_shape).astype('float32'))
gtbox = np.random.random(size=self.gtbox_shape).astype('float32')
gtlabel = np.random.randint(0, self.class_num,
self.gtbox_shape[:2]).astype('int32')
@ -173,6 +169,7 @@ class TestYolov3LossOp(OpTest):
"anchors": self.anchors,
"class_num": self.class_num,
"ignore_thresh": self.ignore_thresh,
"input_size": self.input_size,
"loss_weight_xy": self.loss_weight_xy,
"loss_weight_wh": self.loss_weight_wh,
"loss_weight_conf_target": self.loss_weight_conf_target,
@ -196,18 +193,19 @@ class TestYolov3LossOp(OpTest):
place, ['X'],
'Loss',
no_grad_set=set(["GTBox", "GTLabel"]),
max_relative_error=0.06)
max_relative_error=0.3)
def initTestCase(self):
self.anchors = [10, 13, 12, 12]
self.class_num = 10
self.ignore_thresh = 0.5
self.ignore_thresh = 0.7
self.input_size = 416
self.x_shape = (5, len(self.anchors) // 2 * (5 + self.class_num), 7, 7)
self.gtbox_shape = (5, 10, 4)
self.loss_weight_xy = 2.5
self.loss_weight_xy = 1.4
self.loss_weight_wh = 0.8
self.loss_weight_conf_target = 1.5
self.loss_weight_conf_notarget = 0.5
self.loss_weight_conf_target = 1.1
self.loss_weight_conf_notarget = 0.9
self.loss_weight_class = 1.2

Loading…
Cancel
Save