ignore pred overlap gt > 0.7. test=develop

revert-15296-async_double_buffered_py_reader
dengkaipeng 7 years ago
parent bd6deb1a8b
commit e7e4f084e5

@ -35,12 +35,15 @@ class Yolov3LossOp : public framework::OperatorWithKernel {
auto dim_gtlabel = ctx->GetInputDim("GTLabel"); auto dim_gtlabel = ctx->GetInputDim("GTLabel");
auto anchors = ctx->Attrs().Get<std::vector<int>>("anchors"); auto anchors = ctx->Attrs().Get<std::vector<int>>("anchors");
int anchor_num = anchors.size() / 2; int anchor_num = anchors.size() / 2;
auto anchor_mask = ctx->Attrs().Get<std::vector<int>>("anchor_mask");
int mask_num = anchor_mask.size();
auto class_num = ctx->Attrs().Get<int>("class_num"); auto class_num = ctx->Attrs().Get<int>("class_num");
PADDLE_ENFORCE_EQ(dim_x.size(), 4, "Input(X) should be a 4-D tensor."); PADDLE_ENFORCE_EQ(dim_x.size(), 4, "Input(X) should be a 4-D tensor.");
PADDLE_ENFORCE_EQ(dim_x[2], dim_x[3], PADDLE_ENFORCE_EQ(dim_x[2], dim_x[3],
"Input(X) dim[3] and dim[4] should be euqal."); "Input(X) dim[3] and dim[4] should be euqal.");
PADDLE_ENFORCE_EQ(dim_x[1], anchor_num * (5 + class_num), PADDLE_ENFORCE_EQ(
"Input(X) dim[1] should be equal to (anchor_number * (5 " dim_x[1], mask_num * (5 + class_num),
"Input(X) dim[1] should be equal to (anchor_mask_number * (5 "
"+ class_num))."); "+ class_num)).");
PADDLE_ENFORCE_EQ(dim_gtbox.size(), 3, PADDLE_ENFORCE_EQ(dim_gtbox.size(), 3,
"Input(GTBox) should be a 3-D tensor"); "Input(GTBox) should be a 3-D tensor");
@ -55,6 +58,11 @@ class Yolov3LossOp : public framework::OperatorWithKernel {
"Attr(anchors) length should be greater then 0."); "Attr(anchors) length should be greater then 0.");
PADDLE_ENFORCE_EQ(anchors.size() % 2, 0, PADDLE_ENFORCE_EQ(anchors.size() % 2, 0,
"Attr(anchors) length should be even integer."); "Attr(anchors) length should be even integer.");
for (size_t i = 0; i < anchor_mask.size(); i++) {
PADDLE_ENFORCE_LT(
anchor_mask[i], anchor_num,
"Attr(anchor_mask) should not crossover Attr(anchors).");
}
PADDLE_ENFORCE_GT(class_num, 0, PADDLE_ENFORCE_GT(class_num, 0,
"Attr(class_num) should be an integer greater then 0."); "Attr(class_num) should be an integer greater then 0.");
@ -74,7 +82,7 @@ class Yolov3LossOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
void Make() override { void Make() override {
AddInput("X", AddInput("X",
"The input tensor of YOLO v3 loss operator, " "The input tensor of YOLOv3 loss operator, "
"This is a 4-D tensor with shape of [N, C, H, W]." "This is a 4-D tensor with shape of [N, C, H, W]."
"H and W should be same, and the second dimention(C) stores" "H and W should be same, and the second dimention(C) stores"
"box locations, confidence score and classification one-hot" "box locations, confidence score and classification one-hot"
@ -99,13 +107,20 @@ class Yolov3LossOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<int>("class_num", "The number of classes to predict."); AddAttr<int>("class_num", "The number of classes to predict.");
AddAttr<std::vector<int>>("anchors", AddAttr<std::vector<int>>("anchors",
"The anchor width and height, " "The anchor width and height, "
"it will be parsed pair by pair."); "it will be parsed pair by pair.")
AddAttr<int>("input_size", .SetDefault(std::vector<int>{});
"The input size of YOLOv3 net, " AddAttr<std::vector<int>>("anchor_mask",
"generally this is set as 320, 416 or 608.") "The mask index of anchors used in "
.SetDefault(406); "current YOLOv3 loss calculation.")
.SetDefault(std::vector<int>{});
AddAttr<int>("downsample",
"The downsample ratio from network input to YOLOv3 loss "
"input, so 32, 16, 8 should be set for the first, second, "
"and thrid YOLOv3 loss operators.")
.SetDefault(32);
AddAttr<float>("ignore_thresh", AddAttr<float>("ignore_thresh",
"The ignore threshold to ignore confidence loss."); "The ignore threshold to ignore confidence loss.")
.SetDefault(0.7);
AddComment(R"DOC( AddComment(R"DOC(
This operator generate yolov3 loss by given predict result and ground This operator generate yolov3 loss by given predict result and ground
truth boxes. truth boxes.

File diff suppressed because it is too large Load Diff

@ -413,9 +413,10 @@ def yolov3_loss(x,
gtbox, gtbox,
gtlabel, gtlabel,
anchors, anchors,
anchor_mask,
class_num, class_num,
ignore_thresh, ignore_thresh,
input_size, downsample,
name=None): name=None):
""" """
${comment} ${comment}
@ -430,9 +431,10 @@ def yolov3_loss(x,
gtlabel (Variable): class id of ground truth boxes, shoud be ins shape gtlabel (Variable): class id of ground truth boxes, shoud be ins shape
of [N, B]. of [N, B].
anchors (list|tuple): ${anchors_comment} anchors (list|tuple): ${anchors_comment}
anchor_mask (list|tuple): ${anchor_mask_comment}
class_num (int): ${class_num_comment} class_num (int): ${class_num_comment}
ignore_thresh (float): ${ignore_thresh_comment} ignore_thresh (float): ${ignore_thresh_comment}
input_size (int): ${input_size_comment} downsample (int): ${downsample_comment}
name (string): the name of yolov3 loss name (string): the name of yolov3 loss
Returns: Returns:
@ -452,7 +454,8 @@ def yolov3_loss(x,
x = fluid.layers.data(name='x', shape=[255, 13, 13], dtype='float32') x = fluid.layers.data(name='x', shape=[255, 13, 13], dtype='float32')
gtbox = fluid.layers.data(name='gtbox', shape=[6, 5], dtype='float32') gtbox = fluid.layers.data(name='gtbox', shape=[6, 5], dtype='float32')
gtlabel = fluid.layers.data(name='gtlabel', shape=[6, 1], dtype='int32') gtlabel = fluid.layers.data(name='gtlabel', shape=[6, 1], dtype='int32')
anchors = [10, 13, 16, 30, 33, 23] anchors = [10, 13, 16, 30, 33, 23, 30, 61, 62, 45, 59, 119, 116, 90, 156, 198, 373, 326]
anchors = [0, 1, 2]
loss = fluid.layers.yolov3_loss(x=x, gtbox=gtbox, class_num=80 loss = fluid.layers.yolov3_loss(x=x, gtbox=gtbox, class_num=80
anchors=anchors, ignore_thresh=0.5) anchors=anchors, ignore_thresh=0.5)
""" """
@ -466,6 +469,8 @@ def yolov3_loss(x,
raise TypeError("Input gtlabel of yolov3_loss must be Variable") raise TypeError("Input gtlabel of yolov3_loss must be Variable")
if not isinstance(anchors, list) and not isinstance(anchors, tuple): if not isinstance(anchors, list) and not isinstance(anchors, tuple):
raise TypeError("Attr anchors of yolov3_loss must be list or tuple") raise TypeError("Attr anchors of yolov3_loss must be list or tuple")
if not isinstance(anchor_mask, list) and not isinstance(anchor_mask, tuple):
raise TypeError("Attr anchor_mask of yolov3_loss must be list or tuple")
if not isinstance(class_num, int): if not isinstance(class_num, int):
raise TypeError("Attr class_num of yolov3_loss must be an integer") raise TypeError("Attr class_num of yolov3_loss must be an integer")
if not isinstance(ignore_thresh, float): if not isinstance(ignore_thresh, float):
@ -480,9 +485,10 @@ def yolov3_loss(x,
attrs = { attrs = {
"anchors": anchors, "anchors": anchors,
"anchor_mask": anchor_mask,
"class_num": class_num, "class_num": class_num,
"ignore_thresh": ignore_thresh, "ignore_thresh": ignore_thresh,
"input_size": input_size, "downsample": downsample,
} }
helper.append_op( helper.append_op(

@ -463,8 +463,8 @@ class TestYoloDetection(unittest.TestCase):
x = layers.data(name='x', shape=[30, 7, 7], dtype='float32') x = layers.data(name='x', shape=[30, 7, 7], dtype='float32')
gtbox = layers.data(name='gtbox', shape=[10, 4], dtype='float32') gtbox = layers.data(name='gtbox', shape=[10, 4], dtype='float32')
gtlabel = layers.data(name='gtlabel', shape=[10], dtype='int32') gtlabel = layers.data(name='gtlabel', shape=[10], dtype='int32')
loss = layers.yolov3_loss(x, gtbox, gtlabel, [10, 13, 30, 13], 10, loss = layers.yolov3_loss(x, gtbox, gtlabel, [10, 13, 30, 13],
0.7, 416) [0, 1], 10, 0.7, 32)
self.assertIsNotNone(loss) self.assertIsNotNone(loss)

@ -22,32 +22,42 @@ from op_test import OpTest
from paddle.fluid import core from paddle.fluid import core
# def l1loss(x, y, weight):
def l1loss(x, y, weight): # n = x.shape[0]
n = x.shape[0] # x = x.reshape((n, -1))
x = x.reshape((n, -1)) # y = y.reshape((n, -1))
y = y.reshape((n, -1)) # weight = weight.reshape((n, -1))
weight = weight.reshape((n, -1)) # return (np.abs(y - x) * weight).sum(axis=1)
return (np.abs(y - x) * weight).sum(axis=1) #
#
# def mse(x, y, weight):
# n = x.shape[0]
# x = x.reshape((n, -1))
# y = y.reshape((n, -1))
# weight = weight.reshape((n, -1))
# return ((y - x)**2 * weight).sum(axis=1)
#
#
# def sce(x, label, weight):
# n = x.shape[0]
# x = x.reshape((n, -1))
# label = label.reshape((n, -1))
# weight = weight.reshape((n, -1))
# sigmoid_x = expit(x)
# term1 = label * np.log(sigmoid_x)
# term2 = (1.0 - label) * np.log(1.0 - sigmoid_x)
# return ((-term1 - term2) * weight).sum(axis=1)
def mse(x, y, weight): def l1loss(x, y):
n = x.shape[0] return abs(x - y)
x = x.reshape((n, -1))
y = y.reshape((n, -1))
weight = weight.reshape((n, -1))
return ((y - x)**2 * weight).sum(axis=1)
def sce(x, label, weight): def sce(x, label):
n = x.shape[0]
x = x.reshape((n, -1))
label = label.reshape((n, -1))
weight = weight.reshape((n, -1))
sigmoid_x = expit(x) sigmoid_x = expit(x)
term1 = label * np.log(sigmoid_x) term1 = label * np.log(sigmoid_x)
term2 = (1.0 - label) * np.log(1.0 - sigmoid_x) term2 = (1.0 - label) * np.log(1.0 - sigmoid_x)
return ((-term1 - term2) * weight).sum(axis=1) return -term1 - term2
def box_iou(box1, box2): def box_iou(box1, box2):
@ -160,6 +170,121 @@ def YoloV3Loss(x, gtbox, gtlabel, attrs):
return loss_x + loss_y + loss_w + loss_h + loss_obj + loss_class return loss_x + loss_y + loss_w + loss_h + loss_obj + loss_class
def sigmoid(x):
return 1.0 / (1.0 + np.exp(-1.0 * x))
def batch_xywh_box_iou(box1, box2):
b1_left = box1[:, :, 0] - box1[:, :, 2] / 2
b1_right = box1[:, :, 0] + box1[:, :, 2] / 2
b1_top = box1[:, :, 1] - box1[:, :, 3] / 2
b1_bottom = box1[:, :, 1] + box1[:, :, 3] / 2
b2_left = box2[:, :, 0] - box2[:, :, 2] / 2
b2_right = box2[:, :, 0] + box2[:, :, 2] / 2
b2_top = box2[:, :, 1] - box2[:, :, 3] / 2
b2_bottom = box2[:, :, 1] + box2[:, :, 3] / 2
left = np.maximum(b1_left[:, :, np.newaxis], b2_left[:, np.newaxis, :])
right = np.minimum(b1_right[:, :, np.newaxis], b2_right[:, np.newaxis, :])
top = np.maximum(b1_top[:, :, np.newaxis], b2_top[:, np.newaxis, :])
bottom = np.minimum(b1_bottom[:, :, np.newaxis],
b2_bottom[:, np.newaxis, :])
inter_w = np.clip(right - left, 0., 1.)
inter_h = np.clip(bottom - top, 0., 1.)
inter_area = inter_w * inter_h
b1_area = (b1_right - b1_left) * (b1_bottom - b1_top)
b2_area = (b2_right - b2_left) * (b2_bottom - b2_top)
union = b1_area[:, :, np.newaxis] + b2_area[:, np.newaxis, :] - inter_area
return inter_area / union
def YOLOv3Loss(x, gtbox, gtlabel, attrs):
n, c, h, w = x.shape
b = gtbox.shape[1]
anchors = attrs['anchors']
an_num = len(anchors) // 2
anchor_mask = attrs['anchor_mask']
mask_num = len(anchor_mask)
class_num = attrs["class_num"]
ignore_thresh = attrs['ignore_thresh']
downsample = attrs['downsample']
input_size = downsample * h
x = x.reshape((n, mask_num, 5 + class_num, h, w)).transpose((0, 1, 3, 4, 2))
loss = np.zeros((n)).astype('float32')
pred_box = x[:, :, :, :, :4].copy()
grid_x = np.tile(np.arange(w).reshape((1, w)), (h, 1))
grid_y = np.tile(np.arange(h).reshape((h, 1)), (1, w))
pred_box[:, :, :, :, 0] = (grid_x + sigmoid(pred_box[:, :, :, :, 0])) / w
pred_box[:, :, :, :, 1] = (grid_y + sigmoid(pred_box[:, :, :, :, 1])) / h
mask_anchors = []
for m in anchor_mask:
mask_anchors.append((anchors[2 * m], anchors[2 * m + 1]))
anchors_s = np.array(
[(an_w / input_size, an_h / input_size) for an_w, an_h in mask_anchors])
anchor_w = anchors_s[:, 0:1].reshape((1, mask_num, 1, 1))
anchor_h = anchors_s[:, 1:2].reshape((1, mask_num, 1, 1))
pred_box[:, :, :, :, 2] = np.exp(pred_box[:, :, :, :, 2]) * anchor_w
pred_box[:, :, :, :, 3] = np.exp(pred_box[:, :, :, :, 3]) * anchor_h
pred_box = pred_box.reshape((n, -1, 4))
pred_obj = x[:, :, :, :, 4].reshape((n, -1))
objness = np.zeros(pred_box.shape[:2])
ious = batch_xywh_box_iou(pred_box, gtbox)
ious_max = np.max(ious, axis=-1)
objness = np.where(ious_max > ignore_thresh, -np.ones_like(objness),
objness)
gtbox_shift = gtbox.copy()
gtbox_shift[:, :, 0] = 0
gtbox_shift[:, :, 1] = 0
anchors = [(anchors[2 * i], anchors[2 * i + 1]) for i in range(0, an_num)]
anchors_s = np.array(
[(an_w / input_size, an_h / input_size) for an_w, an_h in anchors])
anchor_boxes = np.concatenate(
[np.zeros_like(anchors_s), anchors_s], axis=-1)
anchor_boxes = np.tile(anchor_boxes[np.newaxis, :, :], (n, 1, 1))
ious = batch_xywh_box_iou(gtbox_shift, anchor_boxes)
iou_matches = np.argmax(ious, axis=-1)
for i in range(n):
for j in range(b):
if gtbox[i, j, 2:].sum() == 0:
continue
if iou_matches[i, j] not in anchor_mask:
continue
an_idx = anchor_mask.index(iou_matches[i, j])
gi = int(gtbox[i, j, 0] * w)
gj = int(gtbox[i, j, 1] * h)
tx = gtbox[i, j, 0] * w - gi
ty = gtbox[i, j, 1] * w - gj
tw = np.log(gtbox[i, j, 2] * input_size / mask_anchors[an_idx][0])
th = np.log(gtbox[i, j, 3] * input_size / mask_anchors[an_idx][1])
scale = 2.0 - gtbox[i, j, 2] * gtbox[i, j, 3]
loss[i] += sce(x[i, an_idx, gj, gi, 0], tx) * scale
loss[i] += sce(x[i, an_idx, gj, gi, 1], ty) * scale
loss[i] += l1loss(x[i, an_idx, gj, gi, 2], tw) * scale
loss[i] += l1loss(x[i, an_idx, gj, gi, 3], th) * scale
objness[i, an_idx * h * w + gj * w + gi] = 1
for label_idx in range(class_num):
loss[i] += sce(x[i, an_idx, gj, gi, 5 + label_idx],
int(label_idx == gtlabel[i, j]))
for j in range(mask_num * h * w):
if objness[i, j] >= 0:
loss[i] += sce(pred_obj[i, j], objness[i, j])
return loss
class TestYolov3LossOp(OpTest): class TestYolov3LossOp(OpTest):
def setUp(self): def setUp(self):
self.initTestCase() self.initTestCase()
@ -171,13 +296,14 @@ class TestYolov3LossOp(OpTest):
self.attrs = { self.attrs = {
"anchors": self.anchors, "anchors": self.anchors,
"anchor_mask": self.anchor_mask,
"class_num": self.class_num, "class_num": self.class_num,
"ignore_thresh": self.ignore_thresh, "ignore_thresh": self.ignore_thresh,
"input_size": self.input_size, "downsample": self.downsample,
} }
self.inputs = {'X': x, 'GTBox': gtbox, 'GTLabel': gtlabel} self.inputs = {'X': x, 'GTBox': gtbox, 'GTLabel': gtlabel}
self.outputs = {'Loss': YoloV3Loss(x, gtbox, gtlabel, self.attrs)} self.outputs = {'Loss': YOLOv3Loss(x, gtbox, gtlabel, self.attrs)}
def test_check_output(self): def test_check_output(self):
place = core.CPUPlace() place = core.CPUPlace()
@ -189,15 +315,19 @@ class TestYolov3LossOp(OpTest):
place, ['X'], place, ['X'],
'Loss', 'Loss',
no_grad_set=set(["GTBox", "GTLabel"]), no_grad_set=set(["GTBox", "GTLabel"]),
max_relative_error=0.31) max_relative_error=0.15)
def initTestCase(self): def initTestCase(self):
self.anchors = [12, 12] self.anchors = [
10, 13, 16, 30, 33, 23, 30, 61, 62, 45, 59, 119, 116, 90, 156, 198,
373, 326
]
self.anchor_mask = [0, 1, 2]
self.class_num = 5 self.class_num = 5
self.ignore_thresh = 0.5 self.ignore_thresh = 0.7
self.input_size = 416 self.downsample = 32
self.x_shape = (1, len(self.anchors) // 2 * (5 + self.class_num), 3, 3) self.x_shape = (3, len(self.anchor_mask) * (5 + self.class_num), 5, 5)
self.gtbox_shape = (1, 5, 4) self.gtbox_shape = (3, 10, 4)
if __name__ == "__main__": if __name__ == "__main__":

Loading…
Cancel
Save