From 192d293854b93d86bbb27ed37af199dd6e4ee1c6 Mon Sep 17 00:00:00 2001 From: dengkaipeng Date: Thu, 6 Dec 2018 19:53:41 +0800 Subject: [PATCH 01/24] use stable Sigmoid Cross Entropy implement. test=develop --- paddle/fluid/operators/yolov3_loss_op.cc | 4 + paddle/fluid/operators/yolov3_loss_op.h | 283 ++++++++++-------- python/paddle/fluid/layers/detection.py | 3 + python/paddle/fluid/tests/test_detection.py | 2 +- .../tests/unittests/test_yolov3_loss_op.py | 90 +++--- 5 files changed, 208 insertions(+), 174 deletions(-) diff --git a/paddle/fluid/operators/yolov3_loss_op.cc b/paddle/fluid/operators/yolov3_loss_op.cc index 60508f7ab8..66d618de59 100644 --- a/paddle/fluid/operators/yolov3_loss_op.cc +++ b/paddle/fluid/operators/yolov3_loss_op.cc @@ -99,6 +99,10 @@ class Yolov3LossOpMaker : public framework::OpProtoAndCheckerMaker { AddAttr>("anchors", "The anchor width and height, " "it will be parsed pair by pair."); + AddAttr("input_size", + "The input size of YOLOv3 net, " + "generally this is set as 320, 416 or 608.") + .SetDefault(406); AddAttr("ignore_thresh", "The ignore threshold to ignore confidence loss."); AddAttr("loss_weight_xy", "The weight of x, y location loss.") diff --git a/paddle/fluid/operators/yolov3_loss_op.h b/paddle/fluid/operators/yolov3_loss_op.h index 0bb285722d..fac06b4204 100644 --- a/paddle/fluid/operators/yolov3_loss_op.h +++ b/paddle/fluid/operators/yolov3_loss_op.h @@ -33,87 +33,91 @@ static inline bool isZero(T x) { } template -static inline T sigmoid(T x) { - return 1.0 / (exp(-1.0 * x) + 1.0); -} +static inline T CalcMSEWithWeight(const Tensor& x, const Tensor& y, + const Tensor& weight, const T mf) { + int numel = static_cast(x.numel()); + const T* x_data = x.data(); + const T* y_data = y.data(); + const T* weight_data = weight.data(); -template -static inline T CalcMaskPointNum(const Tensor& mask) { - auto mask_t = EigenVector::Flatten(mask); - T count = 0.0; - for (int i = 0; i < mask_t.dimensions()[0]; i++) { - if (mask_t(i)) { - count += 1.0; - } + T error_sum = 0.0; + for (int i = 0; i < numel; i++) { + T xi = x_data[i]; + T yi = y_data[i]; + T weighti = weight_data[i]; + error_sum += pow(yi - xi, 2) * weighti; } - return count; + + return error_sum / mf; } template -static inline T CalcMSEWithMask(const Tensor& x, const Tensor& y, - const Tensor& mask) { - auto x_t = EigenVector::Flatten(x); - auto y_t = EigenVector::Flatten(y); - auto mask_t = EigenVector::Flatten(mask); - - T error_sum = 0.0; - T points = 0.0; - for (int i = 0; i < x_t.dimensions()[0]; i++) { - if (mask_t(i)) { - error_sum += pow(x_t(i) - y_t(i), 2); - points += 1; - } +static void CalcMSEGradWithWeight(Tensor* grad, const Tensor& x, + const Tensor& y, const Tensor& weight, + const T mf) { + int numel = static_cast(grad->numel()); + T* grad_data = grad->data(); + const T* x_data = x.data(); + const T* y_data = y.data(); + const T* weight_data = weight.data(); + + for (int i = 0; i < numel; i++) { + grad_data[i] = 2.0 * weight_data[i] * (x_data[i] - y_data[i]) / mf; } - return (error_sum / points); } template -static void CalcMSEGradWithMask(Tensor* grad, const Tensor& x, const Tensor& y, - const Tensor& mask, T mf) { - auto grad_t = EigenVector::Flatten(*grad).setConstant(0.0); - auto x_t = EigenVector::Flatten(x); - auto y_t = EigenVector::Flatten(y); - auto mask_t = EigenVector::Flatten(mask); - - for (int i = 0; i < x_t.dimensions()[0]; i++) { - if (mask_t(i)) { - grad_t(i) = 2.0 * (x_t(i) - y_t(i)) / mf; - } +struct SigmoidCrossEntropyForward { + T operator()(const T& x, const T& label) const { + T term1 = (x > 0) ? x : 0; + T term2 = x * label; + T term3 = std::log(static_cast(1.0) + std::exp(-(std::abs(x)))); + return term1 - term2 + term3; } -} +}; template -static inline T CalcBCEWithMask(const Tensor& x, const Tensor& y, - const Tensor& mask) { - auto x_t = EigenVector::Flatten(x); - auto y_t = EigenVector::Flatten(y); - auto mask_t = EigenVector::Flatten(mask); +struct SigmoidCrossEntropyBackward { + T operator()(const T& x, const T& label) const { + T sigmoid_x = + static_cast(1.0) / (static_cast(1.0) + std::exp(-1.0 * x)); + return sigmoid_x - label; + } +}; - T error_sum = 0.0; - T points = 0.0; - for (int i = 0; i < x_t.dimensions()[0]; i++) { - if (mask_t(i)) { - error_sum += - -1.0 * (y_t(i) * log(x_t(i)) + (1.0 - y_t(i)) * log(1.0 - x_t(i))); - points += 1; - } +template +static inline T CalcSCEWithWeight(const Tensor& x, const Tensor& labels, + const Tensor& weight, const T mf) { + int numel = x.numel(); + const T* x_data = x.data(); + const T* labels_data = labels.data(); + const T* weight_data = weight.data(); + + T loss = 0.0; + for (int i = 0; i < numel; i++) { + T xi = x_data[i]; + T labeli = labels_data[i]; + T weighti = weight_data[i]; + loss += ((xi > 0.0 ? xi : 0.0) - xi * labeli + + std::log(1.0 + std::exp(-1.0 * std::abs(xi)))) * + weighti; } - return (error_sum / points); + return loss / mf; } template -static inline void CalcBCEGradWithMask(Tensor* grad, const Tensor& x, - const Tensor& y, const Tensor& mask, - T mf) { - auto grad_t = EigenVector::Flatten(*grad).setConstant(0.0); - auto x_t = EigenVector::Flatten(x); - auto y_t = EigenVector::Flatten(y); - auto mask_t = EigenVector::Flatten(mask); - - for (int i = 0; i < x_t.dimensions()[0]; i++) { - if (mask_t(i)) { - grad_t(i) = ((1.0 - y_t(i)) / (1.0 - x_t(i)) - y_t(i) / x_t(i)) / mf; - } +static inline void CalcSCEGradWithWeight(Tensor* grad, const Tensor& x, + const Tensor& labels, + const Tensor& weight, const T mf) { + int numel = grad->numel(); + T* grad_data = grad->data(); + const T* x_data = x.data(); + const T* labels_data = labels.data(); + const T* weight_data = weight.data(); + + for (int i = 0; i < numel; i++) { + grad_data[i] = (1.0 / (1.0 + std::exp(-1.0 * x_data[i])) - labels_data[i]) * + weight_data[i] / mf; } } @@ -139,21 +143,20 @@ static void CalcPredResult(const Tensor& input, Tensor* pred_conf, for (int an_idx = 0; an_idx < anchor_num; an_idx++) { for (int j = 0; j < h; j++) { for (int k = 0; k < w; k++) { - pred_x_t(i, an_idx, j, k) = - sigmoid(input_t(i, box_attr_num * an_idx, j, k)); + pred_x_t(i, an_idx, j, k) = input_t(i, box_attr_num * an_idx, j, k); pred_y_t(i, an_idx, j, k) = - sigmoid(input_t(i, box_attr_num * an_idx + 1, j, k)); + input_t(i, box_attr_num * an_idx + 1, j, k); pred_w_t(i, an_idx, j, k) = input_t(i, box_attr_num * an_idx + 2, j, k); pred_h_t(i, an_idx, j, k) = input_t(i, box_attr_num * an_idx + 3, j, k); pred_conf_t(i, an_idx, j, k) = - sigmoid(input_t(i, box_attr_num * an_idx + 4, j, k)); + input_t(i, box_attr_num * an_idx + 4, j, k); for (int c = 0; c < class_num; c++) { pred_class_t(i, an_idx, j, k, c) = - sigmoid(input_t(i, box_attr_num * an_idx + 5 + c, j, k)); + input_t(i, box_attr_num * an_idx + 5 + c, j, k); } } } @@ -188,21 +191,22 @@ static T CalcBoxIoU(std::vector box1, std::vector box2) { template static void PreProcessGTBox(const Tensor& gt_box, const Tensor& gt_label, const float ignore_thresh, std::vector anchors, - const int grid_size, Tensor* obj_mask, - Tensor* noobj_mask, Tensor* tx, Tensor* ty, - Tensor* tw, Tensor* th, Tensor* tconf, - Tensor* tclass) { + const int input_size, const int grid_size, + Tensor* obj_mask, Tensor* noobj_mask, Tensor* tx, + Tensor* ty, Tensor* tw, Tensor* th, Tensor* tweight, + Tensor* tconf, Tensor* tclass) { const int n = gt_box.dims()[0]; const int b = gt_box.dims()[1]; const int anchor_num = anchors.size() / 2; auto gt_box_t = EigenTensor::From(gt_box); auto gt_label_t = EigenTensor::From(gt_label); - auto obj_mask_t = EigenTensor::From(*obj_mask).setConstant(0); - auto noobj_mask_t = EigenTensor::From(*noobj_mask).setConstant(1); + auto obj_mask_t = EigenTensor::From(*obj_mask).setConstant(0); + auto noobj_mask_t = EigenTensor::From(*noobj_mask).setConstant(1); auto tx_t = EigenTensor::From(*tx).setConstant(0.0); auto ty_t = EigenTensor::From(*ty).setConstant(0.0); auto tw_t = EigenTensor::From(*tw).setConstant(0.0); auto th_t = EigenTensor::From(*th).setConstant(0.0); + auto tweight_t = EigenTensor::From(*tweight).setConstant(0.0); auto tconf_t = EigenTensor::From(*tconf).setConstant(0.0); auto tclass_t = EigenTensor::From(*tclass).setConstant(0.0); @@ -216,8 +220,8 @@ static void PreProcessGTBox(const Tensor& gt_box, const Tensor& gt_label, int cur_label = gt_label_t(i, j); T gx = gt_box_t(i, j, 0) * grid_size; T gy = gt_box_t(i, j, 1) * grid_size; - T gw = gt_box_t(i, j, 2) * grid_size; - T gh = gt_box_t(i, j, 3) * grid_size; + T gw = gt_box_t(i, j, 2) * input_size; + T gh = gt_box_t(i, j, 3) * input_size; int gi = static_cast(gx); int gj = static_cast(gy); @@ -234,15 +238,17 @@ static void PreProcessGTBox(const Tensor& gt_box, const Tensor& gt_label, best_an_index = an_idx; } if (iou > ignore_thresh) { - noobj_mask_t(i, an_idx, gj, gi) = 0; + noobj_mask_t(i, an_idx, gj, gi) = static_cast(0.0); } } - obj_mask_t(i, best_an_index, gj, gi) = 1; - noobj_mask_t(i, best_an_index, gj, gi) = 0; + obj_mask_t(i, best_an_index, gj, gi) = static_cast(1.0); + noobj_mask_t(i, best_an_index, gj, gi) = static_cast(0.0); tx_t(i, best_an_index, gj, gi) = gx - gi; ty_t(i, best_an_index, gj, gi) = gy - gj; tw_t(i, best_an_index, gj, gi) = log(gw / anchors[2 * best_an_index]); th_t(i, best_an_index, gj, gi) = log(gh / anchors[2 * best_an_index + 1]); + tweight_t(i, best_an_index, gj, gi) = + 2.0 - gt_box_t(i, j, 2) * gt_box_t(i, j, 3); tclass_t(i, best_an_index, gj, gi, cur_label) = 1; tconf_t(i, best_an_index, gj, gi) = 1; } @@ -295,27 +301,22 @@ static void AddAllGradToInputGrad( for (int k = 0; k < h; k++) { for (int l = 0; l < w; l++) { grad_t(i, j * attr_num, k, l) = - grad_x_t(i, j, k, l) * pred_x_t(i, j, k, l) * - (1.0 - pred_x_t(i, j, k, l)) * loss * loss_weight_xy; + grad_x_t(i, j, k, l) * loss * loss_weight_xy; grad_t(i, j * attr_num + 1, k, l) = - grad_y_t(i, j, k, l) * pred_y_t(i, j, k, l) * - (1.0 - pred_y_t(i, j, k, l)) * loss * loss_weight_xy; + grad_y_t(i, j, k, l) * loss * loss_weight_xy; grad_t(i, j * attr_num + 2, k, l) = grad_w_t(i, j, k, l) * loss * loss_weight_wh; grad_t(i, j * attr_num + 3, k, l) = grad_h_t(i, j, k, l) * loss * loss_weight_wh; grad_t(i, j * attr_num + 4, k, l) = - grad_conf_target_t(i, j, k, l) * pred_conf_t(i, j, k, l) * - (1.0 - pred_conf_t(i, j, k, l)) * loss * loss_weight_conf_target; + grad_conf_target_t(i, j, k, l) * loss * loss_weight_conf_target; grad_t(i, j * attr_num + 4, k, l) += - grad_conf_notarget_t(i, j, k, l) * pred_conf_t(i, j, k, l) * - (1.0 - pred_conf_t(i, j, k, l)) * loss * + grad_conf_notarget_t(i, j, k, l) * loss * loss_weight_conf_notarget; for (int c = 0; c < class_num; c++) { grad_t(i, j * attr_num + 5 + c, k, l) = - grad_class_t(i, j, k, l, c) * pred_class_t(i, j, k, l, c) * - (1.0 - pred_class_t(i, j, k, l, c)) * loss * loss_weight_class; + grad_class_t(i, j, k, l, c) * loss * loss_weight_class; } } } @@ -333,6 +334,7 @@ class Yolov3LossKernel : public framework::OpKernel { auto* loss = ctx.Output("Loss"); auto anchors = ctx.Attr>("anchors"); int class_num = ctx.Attr("class_num"); + int input_size = ctx.Attr("input_size"); float ignore_thresh = ctx.Attr("ignore_thresh"); float loss_weight_xy = ctx.Attr("loss_weight_xy"); float loss_weight_wh = ctx.Attr("loss_weight_wh"); @@ -358,30 +360,46 @@ class Yolov3LossKernel : public framework::OpKernel { &pred_w, &pred_h, an_num, class_num); Tensor obj_mask, noobj_mask; - Tensor tx, ty, tw, th, tconf, tclass; - obj_mask.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - noobj_mask.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + Tensor tx, ty, tw, th, tweight, tconf, tclass; + obj_mask.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + noobj_mask.mutable_data({n, an_num, h, w}, ctx.GetPlace()); tx.mutable_data({n, an_num, h, w}, ctx.GetPlace()); ty.mutable_data({n, an_num, h, w}, ctx.GetPlace()); tw.mutable_data({n, an_num, h, w}, ctx.GetPlace()); th.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + tweight.mutable_data({n, an_num, h, w}, ctx.GetPlace()); tconf.mutable_data({n, an_num, h, w}, ctx.GetPlace()); tclass.mutable_data({n, an_num, h, w, class_num}, ctx.GetPlace()); - PreProcessGTBox(*gt_box, *gt_label, ignore_thresh, anchors, h, &obj_mask, - &noobj_mask, &tx, &ty, &tw, &th, &tconf, &tclass); + PreProcessGTBox(*gt_box, *gt_label, ignore_thresh, anchors, input_size, + h, &obj_mask, &noobj_mask, &tx, &ty, &tw, &th, &tweight, + &tconf, &tclass); + + Tensor obj_weight; + obj_weight.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + auto obj_weight_t = EigenTensor::From(obj_weight); + auto obj_mask_t = EigenTensor::From(obj_mask); + auto tweight_t = EigenTensor::From(tweight); + obj_weight_t = obj_mask_t * tweight_t; Tensor obj_mask_expand; - obj_mask_expand.mutable_data({n, an_num, h, w, class_num}, - ctx.GetPlace()); - ExpandObjMaskByClassNum(&obj_mask_expand, obj_mask); - - T loss_x = CalcMSEWithMask(pred_x, tx, obj_mask); - T loss_y = CalcMSEWithMask(pred_y, ty, obj_mask); - T loss_w = CalcMSEWithMask(pred_w, tw, obj_mask); - T loss_h = CalcMSEWithMask(pred_h, th, obj_mask); - T loss_conf_target = CalcBCEWithMask(pred_conf, tconf, obj_mask); - T loss_conf_notarget = CalcBCEWithMask(pred_conf, tconf, noobj_mask); - T loss_class = CalcBCEWithMask(pred_class, tclass, obj_mask_expand); + obj_mask_expand.mutable_data({n, an_num, h, w, class_num}, + ctx.GetPlace()); + auto obj_mask_expand_t = EigenTensor::From(obj_mask_expand); + obj_mask_expand_t = obj_mask_t.reshape(Array5(n, an_num, h, w, 1)) + .broadcast(Array5(1, 1, 1, 1, class_num)); + + T box_f = static_cast(an_num * h * w); + T class_f = static_cast(an_num * h * w * class_num); + T loss_x = CalcSCEWithWeight(pred_x, tx, obj_weight, box_f); + T loss_y = CalcSCEWithWeight(pred_y, ty, obj_weight, box_f); + T loss_w = CalcMSEWithWeight(pred_w, tw, obj_weight, box_f); + T loss_h = CalcMSEWithWeight(pred_h, th, obj_weight, box_f); + T loss_conf_target = + CalcSCEWithWeight(pred_conf, tconf, obj_mask, box_f); + T loss_conf_notarget = + CalcSCEWithWeight(pred_conf, tconf, noobj_mask, box_f); + T loss_class = + CalcSCEWithWeight(pred_class, tclass, obj_mask_expand, class_f); auto* loss_data = loss->mutable_data({1}, ctx.GetPlace()); loss_data[0] = loss_weight_xy * (loss_x + loss_y) + @@ -405,6 +423,7 @@ class Yolov3LossGradKernel : public framework::OpKernel { auto* input_grad = ctx.Output(framework::GradVarName("X")); auto* output_grad = ctx.Input(framework::GradVarName("Loss")); const T loss = output_grad->data()[0]; + int input_size = ctx.Attr("input_size"); float loss_weight_xy = ctx.Attr("loss_weight_xy"); float loss_weight_wh = ctx.Attr("loss_weight_wh"); float loss_weight_conf_target = ctx.Attr("loss_weight_conf_target"); @@ -430,22 +449,33 @@ class Yolov3LossGradKernel : public framework::OpKernel { &pred_w, &pred_h, an_num, class_num); Tensor obj_mask, noobj_mask; - Tensor tx, ty, tw, th, tconf, tclass; - obj_mask.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - noobj_mask.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + Tensor tx, ty, tw, th, tweight, tconf, tclass; + obj_mask.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + noobj_mask.mutable_data({n, an_num, h, w}, ctx.GetPlace()); tx.mutable_data({n, an_num, h, w}, ctx.GetPlace()); ty.mutable_data({n, an_num, h, w}, ctx.GetPlace()); tw.mutable_data({n, an_num, h, w}, ctx.GetPlace()); th.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + tweight.mutable_data({n, an_num, h, w}, ctx.GetPlace()); tconf.mutable_data({n, an_num, h, w}, ctx.GetPlace()); tclass.mutable_data({n, an_num, h, w, class_num}, ctx.GetPlace()); - PreProcessGTBox(*gt_box, *gt_label, ignore_thresh, anchors, h, &obj_mask, - &noobj_mask, &tx, &ty, &tw, &th, &tconf, &tclass); + PreProcessGTBox(*gt_box, *gt_label, ignore_thresh, anchors, input_size, + h, &obj_mask, &noobj_mask, &tx, &ty, &tw, &th, &tweight, + &tconf, &tclass); + + Tensor obj_weight; + obj_weight.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + auto obj_weight_t = EigenTensor::From(obj_weight); + auto obj_mask_t = EigenTensor::From(obj_mask); + auto tweight_t = EigenTensor::From(tweight); + obj_weight_t = obj_mask_t * tweight_t; Tensor obj_mask_expand; - obj_mask_expand.mutable_data({n, an_num, h, w, class_num}, - ctx.GetPlace()); - ExpandObjMaskByClassNum(&obj_mask_expand, obj_mask); + obj_mask_expand.mutable_data({n, an_num, h, w, class_num}, + ctx.GetPlace()); + auto obj_mask_expand_t = EigenTensor::From(obj_mask_expand); + obj_mask_expand_t = obj_mask_t.reshape(Array5(n, an_num, h, w, 1)) + .broadcast(Array5(1, 1, 1, 1, class_num)); Tensor grad_x, grad_y, grad_w, grad_h; Tensor grad_conf_target, grad_conf_notarget, grad_class; @@ -456,19 +486,18 @@ class Yolov3LossGradKernel : public framework::OpKernel { grad_conf_target.mutable_data({n, an_num, h, w}, ctx.GetPlace()); grad_conf_notarget.mutable_data({n, an_num, h, w}, ctx.GetPlace()); grad_class.mutable_data({n, an_num, h, w, class_num}, ctx.GetPlace()); - T obj_mf = CalcMaskPointNum(obj_mask); - T noobj_mf = CalcMaskPointNum(noobj_mask); - T obj_expand_mf = CalcMaskPointNum(obj_mask_expand); - CalcMSEGradWithMask(&grad_x, pred_x, tx, obj_mask, obj_mf); - CalcMSEGradWithMask(&grad_y, pred_y, ty, obj_mask, obj_mf); - CalcMSEGradWithMask(&grad_w, pred_w, tw, obj_mask, obj_mf); - CalcMSEGradWithMask(&grad_h, pred_h, th, obj_mask, obj_mf); - CalcBCEGradWithMask(&grad_conf_target, pred_conf, tconf, obj_mask, - obj_mf); - CalcBCEGradWithMask(&grad_conf_notarget, pred_conf, tconf, noobj_mask, - noobj_mf); - CalcBCEGradWithMask(&grad_class, pred_class, tclass, obj_mask_expand, - obj_expand_mf); + T box_f = static_cast(an_num * h * w); + T class_f = static_cast(an_num * h * w * class_num); + CalcSCEGradWithWeight(&grad_x, pred_x, tx, obj_weight, box_f); + CalcSCEGradWithWeight(&grad_y, pred_y, ty, obj_weight, box_f); + CalcMSEGradWithWeight(&grad_w, pred_w, tw, obj_weight, box_f); + CalcMSEGradWithWeight(&grad_h, pred_h, th, obj_weight, box_f); + CalcSCEGradWithWeight(&grad_conf_target, pred_conf, tconf, obj_mask, + box_f); + CalcSCEGradWithWeight(&grad_conf_notarget, pred_conf, tconf, noobj_mask, + box_f); + CalcSCEGradWithWeight(&grad_class, pred_class, tclass, obj_mask_expand, + class_f); input_grad->mutable_data({n, c, h, w}, ctx.GetPlace()); AddAllGradToInputGrad( diff --git a/python/paddle/fluid/layers/detection.py b/python/paddle/fluid/layers/detection.py index 7cf575d253..5fb4588e0b 100644 --- a/python/paddle/fluid/layers/detection.py +++ b/python/paddle/fluid/layers/detection.py @@ -415,6 +415,7 @@ def yolov3_loss(x, anchors, class_num, ignore_thresh, + input_size, loss_weight_xy=None, loss_weight_wh=None, loss_weight_conf_target=None, @@ -436,6 +437,7 @@ def yolov3_loss(x, anchors (list|tuple): ${anchors_comment} class_num (int): ${class_num_comment} ignore_thresh (float): ${ignore_thresh_comment} + input_size (int): ${input_size_comment} loss_weight_xy (float|None): ${loss_weight_xy_comment} loss_weight_wh (float|None): ${loss_weight_wh_comment} loss_weight_conf_target (float|None): ${loss_weight_conf_target_comment} @@ -490,6 +492,7 @@ def yolov3_loss(x, "anchors": anchors, "class_num": class_num, "ignore_thresh": ignore_thresh, + "input_size": input_size, } if loss_weight_xy is not None and isinstance(loss_weight_xy, float): diff --git a/python/paddle/fluid/tests/test_detection.py b/python/paddle/fluid/tests/test_detection.py index 8723d9842a..7d75562900 100644 --- a/python/paddle/fluid/tests/test_detection.py +++ b/python/paddle/fluid/tests/test_detection.py @@ -464,7 +464,7 @@ class TestYoloDetection(unittest.TestCase): gtbox = layers.data(name='gtbox', shape=[10, 4], dtype='float32') gtlabel = layers.data(name='gtlabel', shape=[10], dtype='int32') loss = layers.yolov3_loss(x, gtbox, gtlabel, [10, 13, 30, 13], 10, - 0.5) + 0.7, 416) self.assertIsNotNone(loss) diff --git a/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py b/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py index 544fe4b4f8..07e7155bbf 100644 --- a/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py +++ b/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py @@ -16,31 +16,22 @@ from __future__ import division import unittest import numpy as np +from scipy.special import logit +from scipy.special import expit from op_test import OpTest from paddle.fluid import core -def sigmoid(x): - return 1.0 / (1.0 + np.exp(-1.0 * x)) +def mse(x, y, weight, num): + return ((y - x)**2 * weight).sum() / num -def mse(x, y, num): - return ((y - x)**2).sum() / num - - -def bce(x, y, mask): - x = x.reshape((-1)) - y = y.reshape((-1)) - mask = mask.reshape((-1)) - - error_sum = 0.0 - count = 0 - for i in range(x.shape[0]): - if mask[i] > 0: - error_sum += y[i] * np.log(x[i]) + (1 - y[i]) * np.log(1 - x[i]) - count += 1 - return error_sum / (-1.0 * count) +def sce(x, label, weight, num): + sigmoid_x = expit(x) + term1 = label * np.log(sigmoid_x) + term2 = (1.0 - label) * np.log(1.0 - sigmoid_x) + return ((-term1 - term2) * weight).sum() / num def box_iou(box1, box2): @@ -66,11 +57,12 @@ def box_iou(box1, box2): return inter_area / (b1_area + b2_area + inter_area) -def build_target(gtboxs, gtlabel, attrs, grid_size): - n, b, _ = gtboxs.shape +def build_target(gtboxes, gtlabel, attrs, grid_size): + n, b, _ = gtboxes.shape ignore_thresh = attrs["ignore_thresh"] anchors = attrs["anchors"] class_num = attrs["class_num"] + input_size = attrs["input_size"] an_num = len(anchors) // 2 obj_mask = np.zeros((n, an_num, grid_size, grid_size)).astype('float32') noobj_mask = np.ones((n, an_num, grid_size, grid_size)).astype('float32') @@ -78,20 +70,21 @@ def build_target(gtboxs, gtlabel, attrs, grid_size): ty = np.zeros((n, an_num, grid_size, grid_size)).astype('float32') tw = np.zeros((n, an_num, grid_size, grid_size)).astype('float32') th = np.zeros((n, an_num, grid_size, grid_size)).astype('float32') + tweight = np.zeros((n, an_num, grid_size, grid_size)).astype('float32') tconf = np.zeros((n, an_num, grid_size, grid_size)).astype('float32') tcls = np.zeros( (n, an_num, grid_size, grid_size, class_num)).astype('float32') for i in range(n): for j in range(b): - if gtboxs[i, j, :].sum() == 0: + if gtboxes[i, j, :].sum() == 0: continue gt_label = gtlabel[i, j] - gx = gtboxs[i, j, 0] * grid_size - gy = gtboxs[i, j, 1] * grid_size - gw = gtboxs[i, j, 2] * grid_size - gh = gtboxs[i, j, 3] * grid_size + gx = gtboxes[i, j, 0] * grid_size + gy = gtboxes[i, j, 1] * grid_size + gw = gtboxes[i, j, 2] * input_size + gh = gtboxes[i, j, 3] * input_size gi = int(gx) gj = int(gy) @@ -115,10 +108,12 @@ def build_target(gtboxs, gtlabel, attrs, grid_size): best_an_index]) th[i, best_an_index, gj, gi] = np.log( gh / anchors[2 * best_an_index + 1]) + tweight[i, best_an_index, gj, gi] = 2.0 - gtboxes[ + i, j, 2] * gtboxes[i, j, 3] tconf[i, best_an_index, gj, gi] = 1 tcls[i, best_an_index, gj, gi, gt_label] = 1 - return (tx, ty, tw, th, tconf, tcls, obj_mask, noobj_mask) + return (tx, ty, tw, th, tweight, tconf, tcls, obj_mask, noobj_mask) def YoloV3Loss(x, gtbox, gtlabel, attrs): @@ -126,27 +121,28 @@ def YoloV3Loss(x, gtbox, gtlabel, attrs): an_num = len(attrs['anchors']) // 2 class_num = attrs["class_num"] x = x.reshape((n, an_num, 5 + class_num, h, w)).transpose((0, 1, 3, 4, 2)) - pred_x = sigmoid(x[:, :, :, :, 0]) - pred_y = sigmoid(x[:, :, :, :, 1]) + pred_x = x[:, :, :, :, 0] + pred_y = x[:, :, :, :, 1] pred_w = x[:, :, :, :, 2] pred_h = x[:, :, :, :, 3] - pred_conf = sigmoid(x[:, :, :, :, 4]) - pred_cls = sigmoid(x[:, :, :, :, 5:]) + pred_conf = x[:, :, :, :, 4] + pred_cls = x[:, :, :, :, 5:] - tx, ty, tw, th, tconf, tcls, obj_mask, noobj_mask = build_target( + tx, ty, tw, th, tweight, tconf, tcls, obj_mask, noobj_mask = build_target( gtbox, gtlabel, attrs, x.shape[2]) + obj_weight = obj_mask * tweight obj_mask_expand = np.tile( np.expand_dims(obj_mask, 4), (1, 1, 1, 1, int(attrs['class_num']))) - loss_x = mse(pred_x * obj_mask, tx * obj_mask, obj_mask.sum()) - loss_y = mse(pred_y * obj_mask, ty * obj_mask, obj_mask.sum()) - loss_w = mse(pred_w * obj_mask, tw * obj_mask, obj_mask.sum()) - loss_h = mse(pred_h * obj_mask, th * obj_mask, obj_mask.sum()) - loss_conf_target = bce(pred_conf * obj_mask, tconf * obj_mask, obj_mask) - loss_conf_notarget = bce(pred_conf * noobj_mask, tconf * noobj_mask, - noobj_mask) - loss_class = bce(pred_cls * obj_mask_expand, tcls * obj_mask_expand, - obj_mask_expand) + box_f = an_num * h * w + class_f = an_num * h * w * class_num + loss_x = sce(pred_x, tx, obj_weight, box_f) + loss_y = sce(pred_y, ty, obj_weight, box_f) + loss_w = mse(pred_w, tw, obj_weight, box_f) + loss_h = mse(pred_h, th, obj_weight, box_f) + loss_conf_target = sce(pred_conf, tconf, obj_mask, box_f) + loss_conf_notarget = sce(pred_conf, tconf, noobj_mask, box_f) + loss_class = sce(pred_cls, tcls, obj_mask_expand, class_f) return attrs['loss_weight_xy'] * (loss_x + loss_y) \ + attrs['loss_weight_wh'] * (loss_w + loss_h) \ @@ -164,7 +160,7 @@ class TestYolov3LossOp(OpTest): self.loss_weight_class = 1.0 self.initTestCase() self.op_type = 'yolov3_loss' - x = np.random.random(size=self.x_shape).astype('float32') + x = logit(np.random.uniform(0, 1, self.x_shape).astype('float32')) gtbox = np.random.random(size=self.gtbox_shape).astype('float32') gtlabel = np.random.randint(0, self.class_num, self.gtbox_shape[:2]).astype('int32') @@ -173,6 +169,7 @@ class TestYolov3LossOp(OpTest): "anchors": self.anchors, "class_num": self.class_num, "ignore_thresh": self.ignore_thresh, + "input_size": self.input_size, "loss_weight_xy": self.loss_weight_xy, "loss_weight_wh": self.loss_weight_wh, "loss_weight_conf_target": self.loss_weight_conf_target, @@ -196,18 +193,19 @@ class TestYolov3LossOp(OpTest): place, ['X'], 'Loss', no_grad_set=set(["GTBox", "GTLabel"]), - max_relative_error=0.06) + max_relative_error=0.3) def initTestCase(self): self.anchors = [10, 13, 12, 12] self.class_num = 10 - self.ignore_thresh = 0.5 + self.ignore_thresh = 0.7 + self.input_size = 416 self.x_shape = (5, len(self.anchors) // 2 * (5 + self.class_num), 7, 7) self.gtbox_shape = (5, 10, 4) - self.loss_weight_xy = 2.5 + self.loss_weight_xy = 1.4 self.loss_weight_wh = 0.8 - self.loss_weight_conf_target = 1.5 - self.loss_weight_conf_notarget = 0.5 + self.loss_weight_conf_target = 1.1 + self.loss_weight_conf_notarget = 0.9 self.loss_weight_class = 1.2 From 3841983aa01dbb633e1d40b84f046ddfbf41beb8 Mon Sep 17 00:00:00 2001 From: dengkaipeng Date: Fri, 7 Dec 2018 11:44:50 +0800 Subject: [PATCH 02/24] fix division error in mean process. test=develop --- paddle/fluid/operators/yolov3_loss_op.cc | 4 +- paddle/fluid/operators/yolov3_loss_op.h | 263 ++++++++---------- .../paddle/fluid/tests/unittests/op_test.py | 2 + .../tests/unittests/test_yolov3_loss_op.py | 69 +++-- 4 files changed, 166 insertions(+), 172 deletions(-) diff --git a/paddle/fluid/operators/yolov3_loss_op.cc b/paddle/fluid/operators/yolov3_loss_op.cc index 66d618de59..c76767dfdd 100644 --- a/paddle/fluid/operators/yolov3_loss_op.cc +++ b/paddle/fluid/operators/yolov3_loss_op.cc @@ -57,7 +57,7 @@ class Yolov3LossOp : public framework::OperatorWithKernel { PADDLE_ENFORCE_GT(class_num, 0, "Attr(class_num) should be an integer greater then 0."); - std::vector dim_out({1}); + std::vector dim_out({dim_x[0]}); ctx->SetOutputDim("Loss", framework::make_ddim(dim_out)); } @@ -93,7 +93,7 @@ class Yolov3LossOpMaker : public framework::OpProtoAndCheckerMaker { "box class id."); AddOutput("Loss", "The output yolov3 loss tensor, " - "This is a 1-D tensor with shape of [1]"); + "This is a 1-D tensor with shape of [N]"); AddAttr("class_num", "The number of classes to predict."); AddAttr>("anchors", diff --git a/paddle/fluid/operators/yolov3_loss_op.h b/paddle/fluid/operators/yolov3_loss_op.h index fac06b4204..837ea15601 100644 --- a/paddle/fluid/operators/yolov3_loss_op.h +++ b/paddle/fluid/operators/yolov3_loss_op.h @@ -33,99 +33,102 @@ static inline bool isZero(T x) { } template -static inline T CalcMSEWithWeight(const Tensor& x, const Tensor& y, - const Tensor& weight, const T mf) { - int numel = static_cast(x.numel()); +static inline void CalcMSEWithWeight(const Tensor& x, const Tensor& y, + const Tensor& weight, const T loss_weight, + T* loss) { + int n = x.dims()[0]; + int stride = x.numel() / n; const T* x_data = x.data(); const T* y_data = y.data(); const T* weight_data = weight.data(); - T error_sum = 0.0; - for (int i = 0; i < numel; i++) { - T xi = x_data[i]; - T yi = y_data[i]; - T weighti = weight_data[i]; - error_sum += pow(yi - xi, 2) * weighti; + for (int i = 0; i < n; i++) { + for (int j = 0; j < stride; j++) { + loss[i] += pow(y_data[j] - x_data[j], 2) * weight_data[j] * loss_weight; + } + x_data += stride; + y_data += stride; + weight_data += stride; } - - return error_sum / mf; } template -static void CalcMSEGradWithWeight(Tensor* grad, const Tensor& x, - const Tensor& y, const Tensor& weight, - const T mf) { - int numel = static_cast(grad->numel()); +static void CalcMSEGradWithWeight(const T* loss_grad, Tensor* grad, + const Tensor& x, const Tensor& y, + const Tensor& weight) { + int n = x.dims()[0]; + int stride = x.numel() / n; T* grad_data = grad->data(); const T* x_data = x.data(); const T* y_data = y.data(); const T* weight_data = weight.data(); - for (int i = 0; i < numel; i++) { - grad_data[i] = 2.0 * weight_data[i] * (x_data[i] - y_data[i]) / mf; + for (int i = 0; i < n; i++) { + for (int j = 0; j < stride; j++) { + grad_data[j] = + 2.0 * weight_data[j] * (x_data[j] - y_data[j]) * loss_grad[i]; + } + grad_data += stride; + x_data += stride; + y_data += stride; + weight_data += stride; } } template -struct SigmoidCrossEntropyForward { - T operator()(const T& x, const T& label) const { - T term1 = (x > 0) ? x : 0; - T term2 = x * label; - T term3 = std::log(static_cast(1.0) + std::exp(-(std::abs(x)))); - return term1 - term2 + term3; - } -}; - -template -struct SigmoidCrossEntropyBackward { - T operator()(const T& x, const T& label) const { - T sigmoid_x = - static_cast(1.0) / (static_cast(1.0) + std::exp(-1.0 * x)); - return sigmoid_x - label; - } -}; - -template -static inline T CalcSCEWithWeight(const Tensor& x, const Tensor& labels, - const Tensor& weight, const T mf) { - int numel = x.numel(); +static inline void CalcSCEWithWeight(const Tensor& x, const Tensor& label, + const Tensor& weight, const T loss_weight, + T* loss) { + int n = x.dims()[0]; + int stride = x.numel() / n; const T* x_data = x.data(); - const T* labels_data = labels.data(); + const T* label_data = label.data(); const T* weight_data = weight.data(); - T loss = 0.0; - for (int i = 0; i < numel; i++) { - T xi = x_data[i]; - T labeli = labels_data[i]; - T weighti = weight_data[i]; - loss += ((xi > 0.0 ? xi : 0.0) - xi * labeli + - std::log(1.0 + std::exp(-1.0 * std::abs(xi)))) * - weighti; + for (int i = 0; i < n; i++) { + for (int j = 0; j < stride; j++) { + T term1 = (x_data[j] > 0) ? x_data[j] : 0; + T term2 = x_data[j] * label_data[j]; + T term3 = std::log(1.0 + std::exp(-std::abs(x_data[j]))); + loss[i] += (term1 - term2 + term3) * weight_data[j] * loss_weight; + } + x_data += stride; + label_data += stride; + weight_data += stride; } - return loss / mf; } template -static inline void CalcSCEGradWithWeight(Tensor* grad, const Tensor& x, - const Tensor& labels, - const Tensor& weight, const T mf) { - int numel = grad->numel(); +static inline void CalcSCEGradWithWeight(const T* loss_grad, Tensor* grad, + const Tensor& x, const Tensor& label, + const Tensor& weight) { + int n = x.dims()[0]; + int stride = x.numel() / n; T* grad_data = grad->data(); const T* x_data = x.data(); - const T* labels_data = labels.data(); + const T* label_data = label.data(); const T* weight_data = weight.data(); - for (int i = 0; i < numel; i++) { - grad_data[i] = (1.0 / (1.0 + std::exp(-1.0 * x_data[i])) - labels_data[i]) * - weight_data[i] / mf; + // LOG(ERROR) << "SCE grad start"; + for (int i = 0; i < n; i++) { + for (int j = 0; j < stride; j++) { + grad_data[j] = (1.0 / (1.0 + std::exp(-x_data[j])) - label_data[j]) * + weight_data[j] * loss_grad[i]; + // if (j == 18) LOG(ERROR) << x_data[j] << " " << label_data[j] << " " << + // weight_data[j] << " " << loss_grad[i]; + } + grad_data += stride; + x_data += stride; + label_data += stride; + weight_data += stride; } } template -static void CalcPredResult(const Tensor& input, Tensor* pred_conf, - Tensor* pred_class, Tensor* pred_x, Tensor* pred_y, - Tensor* pred_w, Tensor* pred_h, const int anchor_num, - const int class_num) { +static void SplitPredResult(const Tensor& input, Tensor* pred_conf, + Tensor* pred_class, Tensor* pred_x, Tensor* pred_y, + Tensor* pred_w, Tensor* pred_h, + const int anchor_num, const int class_num) { const int n = input.dims()[0]; const int h = input.dims()[2]; const int w = input.dims()[3]; @@ -255,39 +258,20 @@ static void PreProcessGTBox(const Tensor& gt_box, const Tensor& gt_label, } } -static void ExpandObjMaskByClassNum(Tensor* obj_mask_expand, - const Tensor& obj_mask) { - const int n = obj_mask_expand->dims()[0]; - const int an_num = obj_mask_expand->dims()[1]; - const int h = obj_mask_expand->dims()[2]; - const int w = obj_mask_expand->dims()[3]; - const int class_num = obj_mask_expand->dims()[4]; - auto obj_mask_expand_t = EigenTensor::From(*obj_mask_expand); - auto obj_mask_t = EigenTensor::From(obj_mask); - - obj_mask_expand_t = obj_mask_t.reshape(Array5(n, an_num, h, w, 1)) - .broadcast(Array5(1, 1, 1, 1, class_num)); -} - template static void AddAllGradToInputGrad( - Tensor* grad, T loss, const Tensor& pred_x, const Tensor& pred_y, - const Tensor& pred_conf, const Tensor& pred_class, const Tensor& grad_x, - const Tensor& grad_y, const Tensor& grad_w, const Tensor& grad_h, - const Tensor& grad_conf_target, const Tensor& grad_conf_notarget, - const Tensor& grad_class, const int class_num, const float loss_weight_xy, - const float loss_weight_wh, const float loss_weight_conf_target, - const float loss_weight_conf_notarget, const float loss_weight_class) { - const int n = pred_x.dims()[0]; - const int an_num = pred_x.dims()[1]; - const int h = pred_x.dims()[2]; - const int w = pred_x.dims()[3]; + Tensor* grad, const Tensor& grad_x, const Tensor& grad_y, + const Tensor& grad_w, const Tensor& grad_h, const Tensor& grad_conf_target, + const Tensor& grad_conf_notarget, const Tensor& grad_class, + const int class_num, const float loss_weight_xy, const float loss_weight_wh, + const float loss_weight_conf_target, const float loss_weight_conf_notarget, + const float loss_weight_class) { + const int n = grad_x.dims()[0]; + const int an_num = grad_x.dims()[1]; + const int h = grad_x.dims()[2]; + const int w = grad_x.dims()[3]; const int attr_num = class_num + 5; auto grad_t = EigenTensor::From(*grad).setConstant(0.0); - auto pred_x_t = EigenTensor::From(pred_x); - auto pred_y_t = EigenTensor::From(pred_y); - auto pred_conf_t = EigenTensor::From(pred_conf); - auto pred_class_t = EigenTensor::From(pred_class); auto grad_x_t = EigenTensor::From(grad_x); auto grad_y_t = EigenTensor::From(grad_y); auto grad_w_t = EigenTensor::From(grad_w); @@ -300,23 +284,21 @@ static void AddAllGradToInputGrad( for (int j = 0; j < an_num; j++) { for (int k = 0; k < h; k++) { for (int l = 0; l < w; l++) { - grad_t(i, j * attr_num, k, l) = - grad_x_t(i, j, k, l) * loss * loss_weight_xy; + grad_t(i, j * attr_num, k, l) = grad_x_t(i, j, k, l) * loss_weight_xy; grad_t(i, j * attr_num + 1, k, l) = - grad_y_t(i, j, k, l) * loss * loss_weight_xy; + grad_y_t(i, j, k, l) * loss_weight_xy; grad_t(i, j * attr_num + 2, k, l) = - grad_w_t(i, j, k, l) * loss * loss_weight_wh; + grad_w_t(i, j, k, l) * loss_weight_wh; grad_t(i, j * attr_num + 3, k, l) = - grad_h_t(i, j, k, l) * loss * loss_weight_wh; + grad_h_t(i, j, k, l) * loss_weight_wh; grad_t(i, j * attr_num + 4, k, l) = - grad_conf_target_t(i, j, k, l) * loss * loss_weight_conf_target; + grad_conf_target_t(i, j, k, l) * loss_weight_conf_target; grad_t(i, j * attr_num + 4, k, l) += - grad_conf_notarget_t(i, j, k, l) * loss * - loss_weight_conf_notarget; + grad_conf_notarget_t(i, j, k, l) * loss_weight_conf_notarget; for (int c = 0; c < class_num; c++) { grad_t(i, j * attr_num + 5 + c, k, l) = - grad_class_t(i, j, k, l, c) * loss * loss_weight_class; + grad_class_t(i, j, k, l, c) * loss_weight_class; } } } @@ -356,8 +338,8 @@ class Yolov3LossKernel : public framework::OpKernel { pred_h.mutable_data({n, an_num, h, w}, ctx.GetPlace()); pred_conf.mutable_data({n, an_num, h, w}, ctx.GetPlace()); pred_class.mutable_data({n, an_num, h, w, class_num}, ctx.GetPlace()); - CalcPredResult(*input, &pred_conf, &pred_class, &pred_x, &pred_y, - &pred_w, &pred_h, an_num, class_num); + SplitPredResult(*input, &pred_conf, &pred_class, &pred_x, &pred_y, + &pred_w, &pred_h, an_num, class_num); Tensor obj_mask, noobj_mask; Tensor tx, ty, tw, th, tweight, tconf, tclass; @@ -388,25 +370,24 @@ class Yolov3LossKernel : public framework::OpKernel { obj_mask_expand_t = obj_mask_t.reshape(Array5(n, an_num, h, w, 1)) .broadcast(Array5(1, 1, 1, 1, class_num)); - T box_f = static_cast(an_num * h * w); - T class_f = static_cast(an_num * h * w * class_num); - T loss_x = CalcSCEWithWeight(pred_x, tx, obj_weight, box_f); - T loss_y = CalcSCEWithWeight(pred_y, ty, obj_weight, box_f); - T loss_w = CalcMSEWithWeight(pred_w, tw, obj_weight, box_f); - T loss_h = CalcMSEWithWeight(pred_h, th, obj_weight, box_f); - T loss_conf_target = - CalcSCEWithWeight(pred_conf, tconf, obj_mask, box_f); - T loss_conf_notarget = - CalcSCEWithWeight(pred_conf, tconf, noobj_mask, box_f); - T loss_class = - CalcSCEWithWeight(pred_class, tclass, obj_mask_expand, class_f); - - auto* loss_data = loss->mutable_data({1}, ctx.GetPlace()); - loss_data[0] = loss_weight_xy * (loss_x + loss_y) + - loss_weight_wh * (loss_w + loss_h) + - loss_weight_conf_target * loss_conf_target + - loss_weight_conf_notarget * loss_conf_notarget + - loss_weight_class * loss_class; + T* loss_data = loss->mutable_data({n}, ctx.GetPlace()); + memset(loss_data, 0, n * sizeof(T)); + CalcSCEWithWeight(pred_x, tx, obj_weight, loss_weight_xy, loss_data); + CalcSCEWithWeight(pred_y, ty, obj_weight, loss_weight_xy, loss_data); + CalcMSEWithWeight(pred_w, tw, obj_weight, loss_weight_wh, loss_data); + CalcMSEWithWeight(pred_h, th, obj_weight, loss_weight_wh, loss_data); + CalcSCEWithWeight(pred_conf, tconf, obj_mask, loss_weight_conf_target, + loss_data); + CalcSCEWithWeight(pred_conf, tconf, noobj_mask, + loss_weight_conf_notarget, loss_data); + CalcSCEWithWeight(pred_class, tclass, obj_mask_expand, loss_weight_class, + loss_data); + + // loss_data[0] = (loss_weight_xy * (loss_x + loss_y) + + // loss_weight_wh * (loss_w + loss_h) + + // loss_weight_conf_target * loss_conf_target + + // loss_weight_conf_notarget * loss_conf_notarget + + // loss_weight_class * loss_class) / n; } }; @@ -421,8 +402,8 @@ class Yolov3LossGradKernel : public framework::OpKernel { int class_num = ctx.Attr("class_num"); float ignore_thresh = ctx.Attr("ignore_thresh"); auto* input_grad = ctx.Output(framework::GradVarName("X")); - auto* output_grad = ctx.Input(framework::GradVarName("Loss")); - const T loss = output_grad->data()[0]; + auto* loss_grad = ctx.Input(framework::GradVarName("Loss")); + const T* loss_grad_data = loss_grad->data(); int input_size = ctx.Attr("input_size"); float loss_weight_xy = ctx.Attr("loss_weight_xy"); float loss_weight_wh = ctx.Attr("loss_weight_wh"); @@ -445,8 +426,8 @@ class Yolov3LossGradKernel : public framework::OpKernel { pred_h.mutable_data({n, an_num, h, w}, ctx.GetPlace()); pred_conf.mutable_data({n, an_num, h, w}, ctx.GetPlace()); pred_class.mutable_data({n, an_num, h, w, class_num}, ctx.GetPlace()); - CalcPredResult(*input, &pred_conf, &pred_class, &pred_x, &pred_y, - &pred_w, &pred_h, an_num, class_num); + SplitPredResult(*input, &pred_conf, &pred_class, &pred_x, &pred_y, + &pred_w, &pred_h, an_num, class_num); Tensor obj_mask, noobj_mask; Tensor tx, ty, tw, th, tweight, tconf, tclass; @@ -470,6 +451,8 @@ class Yolov3LossGradKernel : public framework::OpKernel { auto tweight_t = EigenTensor::From(tweight); obj_weight_t = obj_mask_t * tweight_t; + // LOG(ERROR) << obj_mask_t; + Tensor obj_mask_expand; obj_mask_expand.mutable_data({n, an_num, h, w, class_num}, ctx.GetPlace()); @@ -486,25 +469,23 @@ class Yolov3LossGradKernel : public framework::OpKernel { grad_conf_target.mutable_data({n, an_num, h, w}, ctx.GetPlace()); grad_conf_notarget.mutable_data({n, an_num, h, w}, ctx.GetPlace()); grad_class.mutable_data({n, an_num, h, w, class_num}, ctx.GetPlace()); - T box_f = static_cast(an_num * h * w); - T class_f = static_cast(an_num * h * w * class_num); - CalcSCEGradWithWeight(&grad_x, pred_x, tx, obj_weight, box_f); - CalcSCEGradWithWeight(&grad_y, pred_y, ty, obj_weight, box_f); - CalcMSEGradWithWeight(&grad_w, pred_w, tw, obj_weight, box_f); - CalcMSEGradWithWeight(&grad_h, pred_h, th, obj_weight, box_f); - CalcSCEGradWithWeight(&grad_conf_target, pred_conf, tconf, obj_mask, - box_f); - CalcSCEGradWithWeight(&grad_conf_notarget, pred_conf, tconf, noobj_mask, - box_f); - CalcSCEGradWithWeight(&grad_class, pred_class, tclass, obj_mask_expand, - class_f); + CalcSCEGradWithWeight(loss_grad_data, &grad_x, pred_x, tx, obj_weight); + CalcSCEGradWithWeight(loss_grad_data, &grad_y, pred_y, ty, obj_weight); + CalcMSEGradWithWeight(loss_grad_data, &grad_w, pred_w, tw, obj_weight); + CalcMSEGradWithWeight(loss_grad_data, &grad_h, pred_h, th, obj_weight); + CalcSCEGradWithWeight(loss_grad_data, &grad_conf_target, pred_conf, + tconf, obj_mask); + CalcSCEGradWithWeight(loss_grad_data, &grad_conf_notarget, pred_conf, + tconf, noobj_mask); + CalcSCEGradWithWeight(loss_grad_data, &grad_class, pred_class, tclass, + obj_mask_expand); input_grad->mutable_data({n, c, h, w}, ctx.GetPlace()); - AddAllGradToInputGrad( - input_grad, loss, pred_x, pred_y, pred_conf, pred_class, grad_x, grad_y, - grad_w, grad_h, grad_conf_target, grad_conf_notarget, grad_class, - class_num, loss_weight_xy, loss_weight_wh, loss_weight_conf_target, - loss_weight_conf_notarget, loss_weight_class); + AddAllGradToInputGrad(input_grad, grad_x, grad_y, grad_w, grad_h, + grad_conf_target, grad_conf_notarget, grad_class, + class_num, loss_weight_xy, loss_weight_wh, + loss_weight_conf_target, loss_weight_conf_notarget, + loss_weight_class); } }; diff --git a/python/paddle/fluid/tests/unittests/op_test.py b/python/paddle/fluid/tests/unittests/op_test.py index 0fe836683b..9cf398f18f 100644 --- a/python/paddle/fluid/tests/unittests/op_test.py +++ b/python/paddle/fluid/tests/unittests/op_test.py @@ -470,6 +470,8 @@ class OpTest(unittest.TestCase): ] analytic_grads = self._get_gradient(inputs_to_check, place, output_names, no_grad_set) + # print(numeric_grads[0][0, 4, :, :]) + # print(analytic_grads[0][0, 4, :, :]) self._assert_is_close(numeric_grads, analytic_grads, inputs_to_check, max_relative_error, diff --git a/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py b/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py index 07e7155bbf..26367f213b 100644 --- a/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py +++ b/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py @@ -23,15 +23,23 @@ from op_test import OpTest from paddle.fluid import core -def mse(x, y, weight, num): - return ((y - x)**2 * weight).sum() / num - - -def sce(x, label, weight, num): +def mse(x, y, weight): + n = x.shape[0] + x = x.reshape((n, -1)) + y = y.reshape((n, -1)) + weight = weight.reshape((n, -1)) + return ((y - x)**2 * weight).sum(axis=1) + + +def sce(x, label, weight): + n = x.shape[0] + x = x.reshape((n, -1)) + label = label.reshape((n, -1)) + weight = weight.reshape((n, -1)) sigmoid_x = expit(x) term1 = label * np.log(sigmoid_x) term2 = (1.0 - label) * np.log(1.0 - sigmoid_x) - return ((-term1 - term2) * weight).sum() / num + return ((-term1 - term2) * weight).sum(axis=1) def box_iou(box1, box2): @@ -131,18 +139,24 @@ def YoloV3Loss(x, gtbox, gtlabel, attrs): tx, ty, tw, th, tweight, tconf, tcls, obj_mask, noobj_mask = build_target( gtbox, gtlabel, attrs, x.shape[2]) + # print("obj_mask: ", obj_mask[0, 0, :, :]) + # print("noobj_mask: ", noobj_mask[0, 0, :, :]) obj_weight = obj_mask * tweight obj_mask_expand = np.tile( np.expand_dims(obj_mask, 4), (1, 1, 1, 1, int(attrs['class_num']))) - box_f = an_num * h * w - class_f = an_num * h * w * class_num - loss_x = sce(pred_x, tx, obj_weight, box_f) - loss_y = sce(pred_y, ty, obj_weight, box_f) - loss_w = mse(pred_w, tw, obj_weight, box_f) - loss_h = mse(pred_h, th, obj_weight, box_f) - loss_conf_target = sce(pred_conf, tconf, obj_mask, box_f) - loss_conf_notarget = sce(pred_conf, tconf, noobj_mask, box_f) - loss_class = sce(pred_cls, tcls, obj_mask_expand, class_f) + loss_x = sce(pred_x, tx, obj_weight) + loss_y = sce(pred_y, ty, obj_weight) + loss_w = mse(pred_w, tw, obj_weight) + loss_h = mse(pred_h, th, obj_weight) + loss_conf_target = sce(pred_conf, tconf, obj_mask) + loss_conf_notarget = sce(pred_conf, tconf, noobj_mask) + loss_class = sce(pred_cls, tcls, obj_mask_expand) + + # print("loss_xy: ", loss_x + loss_y) + # print("loss_wh: ", loss_w + loss_h) + # print("loss_conf_target: ", loss_conf_target) + # print("loss_conf_notarget: ", loss_conf_notarget) + # print("loss_class: ", loss_class) return attrs['loss_weight_xy'] * (loss_x + loss_y) \ + attrs['loss_weight_wh'] * (loss_w + loss_h) \ @@ -178,10 +192,7 @@ class TestYolov3LossOp(OpTest): } self.inputs = {'X': x, 'GTBox': gtbox, 'GTLabel': gtlabel} - self.outputs = { - 'Loss': np.array( - [YoloV3Loss(x, gtbox, gtlabel, self.attrs)]).astype('float32') - } + self.outputs = {'Loss': YoloV3Loss(x, gtbox, gtlabel, self.attrs)} def test_check_output(self): place = core.CPUPlace() @@ -193,20 +204,20 @@ class TestYolov3LossOp(OpTest): place, ['X'], 'Loss', no_grad_set=set(["GTBox", "GTLabel"]), - max_relative_error=0.3) + max_relative_error=0.31) def initTestCase(self): - self.anchors = [10, 13, 12, 12] - self.class_num = 10 - self.ignore_thresh = 0.7 + self.anchors = [12, 12] + self.class_num = 5 + self.ignore_thresh = 0.3 self.input_size = 416 - self.x_shape = (5, len(self.anchors) // 2 * (5 + self.class_num), 7, 7) - self.gtbox_shape = (5, 10, 4) - self.loss_weight_xy = 1.4 + self.x_shape = (3, len(self.anchors) // 2 * (5 + self.class_num), 5, 5) + self.gtbox_shape = (3, 5, 4) + self.loss_weight_xy = 1.2 self.loss_weight_wh = 0.8 - self.loss_weight_conf_target = 1.1 - self.loss_weight_conf_notarget = 0.9 - self.loss_weight_class = 1.2 + self.loss_weight_conf_target = 2.0 + self.loss_weight_conf_notarget = 1.0 + self.loss_weight_class = 1.5 if __name__ == "__main__": From c0fa8d2eec4d6986c4b224a9183207160ea44107 Mon Sep 17 00:00:00 2001 From: dengkaipeng Date: Mon, 10 Dec 2018 20:14:57 +0800 Subject: [PATCH 03/24] use L1Loss for w, h. test=develop --- paddle/fluid/operators/yolov3_loss_op.h | 53 +++++++++++++++++-- .../tests/unittests/test_yolov3_loss_op.py | 12 ++++- 2 files changed, 59 insertions(+), 6 deletions(-) diff --git a/paddle/fluid/operators/yolov3_loss_op.h b/paddle/fluid/operators/yolov3_loss_op.h index 837ea15601..4661747261 100644 --- a/paddle/fluid/operators/yolov3_loss_op.h +++ b/paddle/fluid/operators/yolov3_loss_op.h @@ -32,6 +32,49 @@ static inline bool isZero(T x) { return fabs(x) < 1e-6; } +template +static inline void CalcL1LossWithWeight(const Tensor& x, const Tensor& y, + const Tensor& weight, + const T loss_weight, T* loss) { + int n = x.dims()[0]; + int stride = x.numel() / n; + const T* x_data = x.data(); + const T* y_data = y.data(); + const T* weight_data = weight.data(); + + for (int i = 0; i < n; i++) { + for (int j = 0; j < stride; j++) { + loss[i] += fabs(y_data[j] - x_data[j]) * weight_data[j] * loss_weight; + } + x_data += stride; + y_data += stride; + weight_data += stride; + } +} + +template +static void CalcL1LossGradWithWeight(const T* loss_grad, Tensor* grad, + const Tensor& x, const Tensor& y, + const Tensor& weight) { + int n = x.dims()[0]; + int stride = x.numel() / n; + T* grad_data = grad->data(); + const T* x_data = x.data(); + const T* y_data = y.data(); + const T* weight_data = weight.data(); + + for (int i = 0; i < n; i++) { + for (int j = 0; j < stride; j++) { + grad_data[j] = weight_data[j] * loss_grad[i]; + if (x_data[j] < y_data[j]) grad_data[j] *= -1.0; + } + grad_data += stride; + x_data += stride; + y_data += stride; + weight_data += stride; + } +} + template static inline void CalcMSEWithWeight(const Tensor& x, const Tensor& y, const Tensor& weight, const T loss_weight, @@ -374,8 +417,8 @@ class Yolov3LossKernel : public framework::OpKernel { memset(loss_data, 0, n * sizeof(T)); CalcSCEWithWeight(pred_x, tx, obj_weight, loss_weight_xy, loss_data); CalcSCEWithWeight(pred_y, ty, obj_weight, loss_weight_xy, loss_data); - CalcMSEWithWeight(pred_w, tw, obj_weight, loss_weight_wh, loss_data); - CalcMSEWithWeight(pred_h, th, obj_weight, loss_weight_wh, loss_data); + CalcL1LossWithWeight(pred_w, tw, obj_weight, loss_weight_wh, loss_data); + CalcL1LossWithWeight(pred_h, th, obj_weight, loss_weight_wh, loss_data); CalcSCEWithWeight(pred_conf, tconf, obj_mask, loss_weight_conf_target, loss_data); CalcSCEWithWeight(pred_conf, tconf, noobj_mask, @@ -471,8 +514,10 @@ class Yolov3LossGradKernel : public framework::OpKernel { grad_class.mutable_data({n, an_num, h, w, class_num}, ctx.GetPlace()); CalcSCEGradWithWeight(loss_grad_data, &grad_x, pred_x, tx, obj_weight); CalcSCEGradWithWeight(loss_grad_data, &grad_y, pred_y, ty, obj_weight); - CalcMSEGradWithWeight(loss_grad_data, &grad_w, pred_w, tw, obj_weight); - CalcMSEGradWithWeight(loss_grad_data, &grad_h, pred_h, th, obj_weight); + CalcL1LossGradWithWeight(loss_grad_data, &grad_w, pred_w, tw, + obj_weight); + CalcL1LossGradWithWeight(loss_grad_data, &grad_h, pred_h, th, + obj_weight); CalcSCEGradWithWeight(loss_grad_data, &grad_conf_target, pred_conf, tconf, obj_mask); CalcSCEGradWithWeight(loss_grad_data, &grad_conf_notarget, pred_conf, diff --git a/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py b/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py index 26367f213b..e218031286 100644 --- a/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py +++ b/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py @@ -23,6 +23,14 @@ from op_test import OpTest from paddle.fluid import core +def l1loss(x, y, weight): + n = x.shape[0] + x = x.reshape((n, -1)) + y = y.reshape((n, -1)) + weight = weight.reshape((n, -1)) + return (np.abs(y - x) * weight).sum(axis=1) + + def mse(x, y, weight): n = x.shape[0] x = x.reshape((n, -1)) @@ -146,8 +154,8 @@ def YoloV3Loss(x, gtbox, gtlabel, attrs): np.expand_dims(obj_mask, 4), (1, 1, 1, 1, int(attrs['class_num']))) loss_x = sce(pred_x, tx, obj_weight) loss_y = sce(pred_y, ty, obj_weight) - loss_w = mse(pred_w, tw, obj_weight) - loss_h = mse(pred_h, th, obj_weight) + loss_w = l1loss(pred_w, tw, obj_weight) + loss_h = l1loss(pred_h, th, obj_weight) loss_conf_target = sce(pred_conf, tconf, obj_mask) loss_conf_notarget = sce(pred_conf, tconf, noobj_mask) loss_class = sce(pred_cls, tcls, obj_mask_expand) From 2fbfef2ec9683ac18903ca8cf7cb69c5389ba3ba Mon Sep 17 00:00:00 2001 From: dengkaipeng Date: Thu, 13 Dec 2018 19:15:52 +0800 Subject: [PATCH 04/24] fix no box expression. test=develop --- paddle/fluid/operators/yolov3_loss_op.h | 14 +------------- 1 file changed, 1 insertion(+), 13 deletions(-) diff --git a/paddle/fluid/operators/yolov3_loss_op.h b/paddle/fluid/operators/yolov3_loss_op.h index 4661747261..d0064a8190 100644 --- a/paddle/fluid/operators/yolov3_loss_op.h +++ b/paddle/fluid/operators/yolov3_loss_op.h @@ -152,13 +152,10 @@ static inline void CalcSCEGradWithWeight(const T* loss_grad, Tensor* grad, const T* label_data = label.data(); const T* weight_data = weight.data(); - // LOG(ERROR) << "SCE grad start"; for (int i = 0; i < n; i++) { for (int j = 0; j < stride; j++) { grad_data[j] = (1.0 / (1.0 + std::exp(-x_data[j])) - label_data[j]) * weight_data[j] * loss_grad[i]; - // if (j == 18) LOG(ERROR) << x_data[j] << " " << label_data[j] << " " << - // weight_data[j] << " " << loss_grad[i]; } grad_data += stride; x_data += stride; @@ -258,8 +255,7 @@ static void PreProcessGTBox(const Tensor& gt_box, const Tensor& gt_label, for (int i = 0; i < n; i++) { for (int j = 0; j < b; j++) { - if (isZero(gt_box_t(i, j, 0)) && isZero(gt_box_t(i, j, 1)) && - isZero(gt_box_t(i, j, 2)) && isZero(gt_box_t(i, j, 3))) { + if (isZero(gt_box_t(i, j, 2)) && isZero(gt_box_t(i, j, 3))) { continue; } @@ -425,12 +421,6 @@ class Yolov3LossKernel : public framework::OpKernel { loss_weight_conf_notarget, loss_data); CalcSCEWithWeight(pred_class, tclass, obj_mask_expand, loss_weight_class, loss_data); - - // loss_data[0] = (loss_weight_xy * (loss_x + loss_y) + - // loss_weight_wh * (loss_w + loss_h) + - // loss_weight_conf_target * loss_conf_target + - // loss_weight_conf_notarget * loss_conf_notarget + - // loss_weight_class * loss_class) / n; } }; @@ -494,8 +484,6 @@ class Yolov3LossGradKernel : public framework::OpKernel { auto tweight_t = EigenTensor::From(tweight); obj_weight_t = obj_mask_t * tweight_t; - // LOG(ERROR) << obj_mask_t; - Tensor obj_mask_expand; obj_mask_expand.mutable_data({n, an_num, h, w, class_num}, ctx.GetPlace()); From 0c4acc83050fb83860884ea02ac241a5ddd6800e Mon Sep 17 00:00:00 2001 From: dengkaipeng Date: Sun, 16 Dec 2018 17:50:41 +0800 Subject: [PATCH 05/24] imporve yolo loss implement. test=develop --- paddle/fluid/operators/yolov3_loss_op.cc | 17 +- paddle/fluid/operators/yolov3_loss_op.h | 432 ++++++++++-------- python/paddle/fluid/layers/detection.py | 34 +- .../paddle/fluid/tests/unittests/op_test.py | 2 - .../tests/unittests/test_yolov3_loss_op.py | 49 +- 5 files changed, 267 insertions(+), 267 deletions(-) diff --git a/paddle/fluid/operators/yolov3_loss_op.cc b/paddle/fluid/operators/yolov3_loss_op.cc index c76767dfdd..3bd0db8b59 100644 --- a/paddle/fluid/operators/yolov3_loss_op.cc +++ b/paddle/fluid/operators/yolov3_loss_op.cc @@ -34,11 +34,12 @@ class Yolov3LossOp : public framework::OperatorWithKernel { auto dim_gtbox = ctx->GetInputDim("GTBox"); auto dim_gtlabel = ctx->GetInputDim("GTLabel"); auto anchors = ctx->Attrs().Get>("anchors"); + int anchor_num = anchors.size() / 2; auto class_num = ctx->Attrs().Get("class_num"); PADDLE_ENFORCE_EQ(dim_x.size(), 4, "Input(X) should be a 4-D tensor."); PADDLE_ENFORCE_EQ(dim_x[2], dim_x[3], "Input(X) dim[3] and dim[4] should be euqal."); - PADDLE_ENFORCE_EQ(dim_x[1], anchors.size() / 2 * (5 + class_num), + PADDLE_ENFORCE_EQ(dim_x[1], anchor_num * (5 + class_num), "Input(X) dim[1] should be equal to (anchor_number * (5 " "+ class_num))."); PADDLE_ENFORCE_EQ(dim_gtbox.size(), 3, @@ -105,20 +106,6 @@ class Yolov3LossOpMaker : public framework::OpProtoAndCheckerMaker { .SetDefault(406); AddAttr("ignore_thresh", "The ignore threshold to ignore confidence loss."); - AddAttr("loss_weight_xy", "The weight of x, y location loss.") - .SetDefault(1.0); - AddAttr("loss_weight_wh", "The weight of w, h location loss.") - .SetDefault(1.0); - AddAttr( - "loss_weight_conf_target", - "The weight of confidence score loss in locations with target object.") - .SetDefault(1.0); - AddAttr("loss_weight_conf_notarget", - "The weight of confidence score loss in locations without " - "target object.") - .SetDefault(1.0); - AddAttr("loss_weight_class", "The weight of classification loss.") - .SetDefault(1.0); AddComment(R"DOC( This operator generate yolov3 loss by given predict result and ground truth boxes. diff --git a/paddle/fluid/operators/yolov3_loss_op.h b/paddle/fluid/operators/yolov3_loss_op.h index d0064a8190..5de5b4efc7 100644 --- a/paddle/fluid/operators/yolov3_loss_op.h +++ b/paddle/fluid/operators/yolov3_loss_op.h @@ -164,48 +164,50 @@ static inline void CalcSCEGradWithWeight(const T* loss_grad, Tensor* grad, } } -template -static void SplitPredResult(const Tensor& input, Tensor* pred_conf, - Tensor* pred_class, Tensor* pred_x, Tensor* pred_y, - Tensor* pred_w, Tensor* pred_h, - const int anchor_num, const int class_num) { - const int n = input.dims()[0]; - const int h = input.dims()[2]; - const int w = input.dims()[3]; - const int box_attr_num = 5 + class_num; - - auto input_t = EigenTensor::From(input); - auto pred_conf_t = EigenTensor::From(*pred_conf); - auto pred_class_t = EigenTensor::From(*pred_class); - auto pred_x_t = EigenTensor::From(*pred_x); - auto pred_y_t = EigenTensor::From(*pred_y); - auto pred_w_t = EigenTensor::From(*pred_w); - auto pred_h_t = EigenTensor::From(*pred_h); - - for (int i = 0; i < n; i++) { - for (int an_idx = 0; an_idx < anchor_num; an_idx++) { - for (int j = 0; j < h; j++) { - for (int k = 0; k < w; k++) { - pred_x_t(i, an_idx, j, k) = input_t(i, box_attr_num * an_idx, j, k); - pred_y_t(i, an_idx, j, k) = - input_t(i, box_attr_num * an_idx + 1, j, k); - pred_w_t(i, an_idx, j, k) = - input_t(i, box_attr_num * an_idx + 2, j, k); - pred_h_t(i, an_idx, j, k) = - input_t(i, box_attr_num * an_idx + 3, j, k); - - pred_conf_t(i, an_idx, j, k) = - input_t(i, box_attr_num * an_idx + 4, j, k); - - for (int c = 0; c < class_num; c++) { - pred_class_t(i, an_idx, j, k, c) = - input_t(i, box_attr_num * an_idx + 5 + c, j, k); - } - } - } - } - } -} +// template +// static void SplitPredResult(const Tensor& input, Tensor* pred_conf, +// Tensor* pred_class, Tensor* pred_x, Tensor* +// pred_y, +// Tensor* pred_w, Tensor* pred_h, +// const int anchor_num, const int class_num) { +// const int n = input.dims()[0]; +// const int h = input.dims()[2]; +// const int w = input.dims()[3]; +// const int box_attr_num = 5 + class_num; +// +// auto input_t = EigenTensor::From(input); +// auto pred_conf_t = EigenTensor::From(*pred_conf); +// auto pred_class_t = EigenTensor::From(*pred_class); +// auto pred_x_t = EigenTensor::From(*pred_x); +// auto pred_y_t = EigenTensor::From(*pred_y); +// auto pred_w_t = EigenTensor::From(*pred_w); +// auto pred_h_t = EigenTensor::From(*pred_h); +// +// for (int i = 0; i < n; i++) { +// for (int an_idx = 0; an_idx < anchor_num; an_idx++) { +// for (int j = 0; j < h; j++) { +// for (int k = 0; k < w; k++) { +// pred_x_t(i, an_idx, j, k) = input_t(i, box_attr_num * an_idx, j, +// k); +// pred_y_t(i, an_idx, j, k) = +// input_t(i, box_attr_num * an_idx + 1, j, k); +// pred_w_t(i, an_idx, j, k) = +// input_t(i, box_attr_num * an_idx + 2, j, k); +// pred_h_t(i, an_idx, j, k) = +// input_t(i, box_attr_num * an_idx + 3, j, k); +// +// pred_conf_t(i, an_idx, j, k) = +// input_t(i, box_attr_num * an_idx + 4, j, k); +// +// for (int c = 0; c < class_num; c++) { +// pred_class_t(i, an_idx, j, k, c) = +// input_t(i, box_attr_num * an_idx + 5 + c, j, k); +// } +// } +// } +// } +// } +// } template static T CalcBoxIoU(std::vector box1, std::vector box2) { @@ -235,7 +237,7 @@ template static void PreProcessGTBox(const Tensor& gt_box, const Tensor& gt_label, const float ignore_thresh, std::vector anchors, const int input_size, const int grid_size, - Tensor* obj_mask, Tensor* noobj_mask, Tensor* tx, + Tensor* conf_mask, Tensor* obj_mask, Tensor* tx, Tensor* ty, Tensor* tw, Tensor* th, Tensor* tweight, Tensor* tconf, Tensor* tclass) { const int n = gt_box.dims()[0]; @@ -243,8 +245,8 @@ static void PreProcessGTBox(const Tensor& gt_box, const Tensor& gt_label, const int anchor_num = anchors.size() / 2; auto gt_box_t = EigenTensor::From(gt_box); auto gt_label_t = EigenTensor::From(gt_label); - auto obj_mask_t = EigenTensor::From(*obj_mask).setConstant(0); - auto noobj_mask_t = EigenTensor::From(*noobj_mask).setConstant(1); + auto conf_mask_t = EigenTensor::From(*conf_mask).setConstant(1.0); + auto obj_mask_t = EigenTensor::From(*obj_mask).setConstant(0.0); auto tx_t = EigenTensor::From(*tx).setConstant(0.0); auto ty_t = EigenTensor::From(*ty).setConstant(0.0); auto tw_t = EigenTensor::From(*tw).setConstant(0.0); @@ -280,11 +282,11 @@ static void PreProcessGTBox(const Tensor& gt_box, const Tensor& gt_label, best_an_index = an_idx; } if (iou > ignore_thresh) { - noobj_mask_t(i, an_idx, gj, gi) = static_cast(0.0); + conf_mask_t(i, an_idx, gj, gi) = static_cast(0.0); } } + conf_mask_t(i, best_an_index, gj, gi) = static_cast(1.0); obj_mask_t(i, best_an_index, gj, gi) = static_cast(1.0); - noobj_mask_t(i, best_an_index, gj, gi) = static_cast(0.0); tx_t(i, best_an_index, gj, gi) = gx - gi; ty_t(i, best_an_index, gj, gi) = gy - gj; tw_t(i, best_an_index, gj, gi) = log(gw / anchors[2 * best_an_index]); @@ -298,53 +300,194 @@ static void PreProcessGTBox(const Tensor& gt_box, const Tensor& gt_label, } template -static void AddAllGradToInputGrad( - Tensor* grad, const Tensor& grad_x, const Tensor& grad_y, - const Tensor& grad_w, const Tensor& grad_h, const Tensor& grad_conf_target, - const Tensor& grad_conf_notarget, const Tensor& grad_class, - const int class_num, const float loss_weight_xy, const float loss_weight_wh, - const float loss_weight_conf_target, const float loss_weight_conf_notarget, - const float loss_weight_class) { - const int n = grad_x.dims()[0]; - const int an_num = grad_x.dims()[1]; - const int h = grad_x.dims()[2]; - const int w = grad_x.dims()[3]; - const int attr_num = class_num + 5; - auto grad_t = EigenTensor::From(*grad).setConstant(0.0); - auto grad_x_t = EigenTensor::From(grad_x); - auto grad_y_t = EigenTensor::From(grad_y); - auto grad_w_t = EigenTensor::From(grad_w); - auto grad_h_t = EigenTensor::From(grad_h); - auto grad_conf_target_t = EigenTensor::From(grad_conf_target); - auto grad_conf_notarget_t = EigenTensor::From(grad_conf_notarget); - auto grad_class_t = EigenTensor::From(grad_class); +static T SCE(T x, T label) { + return (x > 0 ? x : 0.0) - x * label + std::log(1.0 + std::exp(-std::abs(x))); +} + +template +static T L1Loss(T x, T y) { + return std::abs(y - x); +} + +template +static T SCEGrad(T x, T label) { + return 1.0 / (1.0 + std::exp(-x)) - label; +} + +template +static T L1LossGrad(T x, T y) { + return x > y ? 1.0 : -1.0; +} + +template +static void CalcSCE(T* loss_data, const T* input, const T* target, + const T* weight, const T* mask, const int n, + const int an_num, const int grid_num, const int class_num, + const int num) { + for (int i = 0; i < n; i++) { + for (int j = 0; j < an_num; j++) { + for (int k = 0; k < grid_num; k++) { + int sub_idx = k * num; + for (int l = 0; l < num; l++) { + loss_data[i] += SCE(input[l * grid_num + k], target[sub_idx + l]) * + weight[k] * mask[k]; + } + } + input += (class_num + 5) * grid_num; + target += grid_num * num; + weight += grid_num; + mask += grid_num; + } + } +} +template +static void CalcSCEGrad(T* input_grad, const T* loss_grad, const T* input, + const T* target, const T* weight, const T* mask, + const int n, const int an_num, const int grid_num, + const int class_num, const int num) { for (int i = 0; i < n; i++) { for (int j = 0; j < an_num; j++) { - for (int k = 0; k < h; k++) { - for (int l = 0; l < w; l++) { - grad_t(i, j * attr_num, k, l) = grad_x_t(i, j, k, l) * loss_weight_xy; - grad_t(i, j * attr_num + 1, k, l) = - grad_y_t(i, j, k, l) * loss_weight_xy; - grad_t(i, j * attr_num + 2, k, l) = - grad_w_t(i, j, k, l) * loss_weight_wh; - grad_t(i, j * attr_num + 3, k, l) = - grad_h_t(i, j, k, l) * loss_weight_wh; - grad_t(i, j * attr_num + 4, k, l) = - grad_conf_target_t(i, j, k, l) * loss_weight_conf_target; - grad_t(i, j * attr_num + 4, k, l) += - grad_conf_notarget_t(i, j, k, l) * loss_weight_conf_notarget; - - for (int c = 0; c < class_num; c++) { - grad_t(i, j * attr_num + 5 + c, k, l) = - grad_class_t(i, j, k, l, c) * loss_weight_class; - } + for (int k = 0; k < grid_num; k++) { + int sub_idx = k * num; + for (int l = 0; l < num; l++) { + input_grad[l * grid_num + k] = + SCEGrad(input[l * grid_num + k], target[sub_idx + l]) * + weight[k] * mask[k] * loss_grad[i]; } } + input_grad += (class_num + 5) * grid_num; + input += (class_num + 5) * grid_num; + target += grid_num * num; + weight += grid_num; + mask += grid_num; + } + } +} + +template +static void CalcL1Loss(T* loss_data, const T* input, const T* target, + const T* weight, const T* mask, const int n, + const int an_num, const int grid_num, + const int class_num) { + for (int i = 0; i < n; i++) { + for (int j = 0; j < an_num; j++) { + for (int k = 0; k < grid_num; k++) { + loss_data[i] += L1Loss(input[k], target[k]) * weight[k] * mask[k]; + } + input += (class_num + 5) * grid_num; + target += grid_num; + weight += grid_num; + mask += grid_num; + } + } +} + +template +static void CalcL1LossGrad(T* input_grad, const T* loss_grad, const T* input, + const T* target, const T* weight, const T* mask, + const int n, const int an_num, const int grid_num, + const int class_num) { + for (int i = 0; i < n; i++) { + for (int j = 0; j < an_num; j++) { + for (int k = 0; k < grid_num; k++) { + input_grad[k] = L1LossGrad(input[k], target[k]) * weight[k] * + mask[k] * loss_grad[i]; + } + input_grad += (class_num + 5) * grid_num; + input += (class_num + 5) * grid_num; + target += grid_num; + weight += grid_num; + mask += grid_num; } } } +template +static void CalcYolov3Loss(T* loss_data, const Tensor& input, const Tensor& tx, + const Tensor& ty, const Tensor& tw, const Tensor& th, + const Tensor& tweight, const Tensor& tconf, + const Tensor& tclass, const Tensor& conf_mask, + const Tensor& obj_mask) { + const T* input_data = input.data(); + const T* tx_data = tx.data(); + const T* ty_data = ty.data(); + const T* tw_data = tw.data(); + const T* th_data = th.data(); + const T* tweight_data = tweight.data(); + const T* tconf_data = tconf.data(); + const T* tclass_data = tclass.data(); + const T* conf_mask_data = conf_mask.data(); + const T* obj_mask_data = obj_mask.data(); + + const int n = tclass.dims()[0]; + const int an_num = tclass.dims()[1]; + const int h = tclass.dims()[2]; + const int w = tclass.dims()[3]; + const int class_num = tclass.dims()[4]; + const int grid_num = h * w; + + CalcSCE(loss_data, input_data, tx_data, tweight_data, obj_mask_data, n, + an_num, grid_num, class_num, 1); + CalcSCE(loss_data, input_data + grid_num, ty_data, tweight_data, + obj_mask_data, n, an_num, grid_num, class_num, 1); + CalcL1Loss(loss_data, input_data + 2 * grid_num, tw_data, tweight_data, + obj_mask_data, n, an_num, grid_num, class_num); + CalcL1Loss(loss_data, input_data + 3 * grid_num, th_data, tweight_data, + obj_mask_data, n, an_num, grid_num, class_num); + CalcSCE(loss_data, input_data + 4 * grid_num, tconf_data, conf_mask_data, + conf_mask_data, n, an_num, grid_num, class_num, 1); + CalcSCE(loss_data, input_data + 5 * grid_num, tclass_data, obj_mask_data, + obj_mask_data, n, an_num, grid_num, class_num, class_num); +} + +template +static void CalcYolov3LossGrad(T* input_grad_data, const Tensor& loss_grad, + const Tensor& input, const Tensor& tx, + const Tensor& ty, const Tensor& tw, + const Tensor& th, const Tensor& tweight, + const Tensor& tconf, const Tensor& tclass, + const Tensor& conf_mask, + const Tensor& obj_mask) { + const T* loss_grad_data = loss_grad.data(); + const T* input_data = input.data(); + const T* tx_data = tx.data(); + const T* ty_data = ty.data(); + const T* tw_data = tw.data(); + const T* th_data = th.data(); + const T* tweight_data = tweight.data(); + const T* tconf_data = tconf.data(); + const T* tclass_data = tclass.data(); + const T* conf_mask_data = conf_mask.data(); + const T* obj_mask_data = obj_mask.data(); + + const int n = tclass.dims()[0]; + const int an_num = tclass.dims()[1]; + const int h = tclass.dims()[2]; + const int w = tclass.dims()[3]; + const int class_num = tclass.dims()[4]; + const int grid_num = h * w; + + CalcSCEGrad(input_grad_data, loss_grad_data, input_data, tx_data, + tweight_data, obj_mask_data, n, an_num, grid_num, class_num, + 1); + CalcSCEGrad(input_grad_data + grid_num, loss_grad_data, + input_data + grid_num, ty_data, tweight_data, obj_mask_data, n, + an_num, grid_num, class_num, 1); + CalcL1LossGrad(input_grad_data + 2 * grid_num, loss_grad_data, + input_data + 2 * grid_num, tw_data, tweight_data, + obj_mask_data, n, an_num, grid_num, class_num); + CalcL1LossGrad(input_grad_data + 3 * grid_num, loss_grad_data, + input_data + 3 * grid_num, th_data, tweight_data, + obj_mask_data, n, an_num, grid_num, class_num); + CalcSCEGrad(input_grad_data + 4 * grid_num, loss_grad_data, + input_data + 4 * grid_num, tconf_data, conf_mask_data, + conf_mask_data, n, an_num, grid_num, class_num, 1); + CalcSCEGrad(input_grad_data + 5 * grid_num, loss_grad_data, + input_data + 5 * grid_num, tclass_data, obj_mask_data, + obj_mask_data, n, an_num, grid_num, class_num, class_num); +} + template class Yolov3LossKernel : public framework::OpKernel { public: @@ -357,33 +500,16 @@ class Yolov3LossKernel : public framework::OpKernel { int class_num = ctx.Attr("class_num"); int input_size = ctx.Attr("input_size"); float ignore_thresh = ctx.Attr("ignore_thresh"); - float loss_weight_xy = ctx.Attr("loss_weight_xy"); - float loss_weight_wh = ctx.Attr("loss_weight_wh"); - float loss_weight_conf_target = ctx.Attr("loss_weight_conf_target"); - float loss_weight_conf_notarget = - ctx.Attr("loss_weight_conf_notarget"); - float loss_weight_class = ctx.Attr("loss_weight_class"); const int n = input->dims()[0]; const int h = input->dims()[2]; const int w = input->dims()[3]; const int an_num = anchors.size() / 2; - Tensor pred_x, pred_y, pred_w, pred_h; - Tensor pred_conf, pred_class; - pred_x.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - pred_y.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - pred_w.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - pred_h.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - pred_conf.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - pred_class.mutable_data({n, an_num, h, w, class_num}, ctx.GetPlace()); - SplitPredResult(*input, &pred_conf, &pred_class, &pred_x, &pred_y, - &pred_w, &pred_h, an_num, class_num); - - Tensor obj_mask, noobj_mask; + Tensor conf_mask, obj_mask; Tensor tx, ty, tw, th, tweight, tconf, tclass; + conf_mask.mutable_data({n, an_num, h, w}, ctx.GetPlace()); obj_mask.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - noobj_mask.mutable_data({n, an_num, h, w}, ctx.GetPlace()); tx.mutable_data({n, an_num, h, w}, ctx.GetPlace()); ty.mutable_data({n, an_num, h, w}, ctx.GetPlace()); tw.mutable_data({n, an_num, h, w}, ctx.GetPlace()); @@ -392,35 +518,13 @@ class Yolov3LossKernel : public framework::OpKernel { tconf.mutable_data({n, an_num, h, w}, ctx.GetPlace()); tclass.mutable_data({n, an_num, h, w, class_num}, ctx.GetPlace()); PreProcessGTBox(*gt_box, *gt_label, ignore_thresh, anchors, input_size, - h, &obj_mask, &noobj_mask, &tx, &ty, &tw, &th, &tweight, + h, &conf_mask, &obj_mask, &tx, &ty, &tw, &th, &tweight, &tconf, &tclass); - Tensor obj_weight; - obj_weight.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - auto obj_weight_t = EigenTensor::From(obj_weight); - auto obj_mask_t = EigenTensor::From(obj_mask); - auto tweight_t = EigenTensor::From(tweight); - obj_weight_t = obj_mask_t * tweight_t; - - Tensor obj_mask_expand; - obj_mask_expand.mutable_data({n, an_num, h, w, class_num}, - ctx.GetPlace()); - auto obj_mask_expand_t = EigenTensor::From(obj_mask_expand); - obj_mask_expand_t = obj_mask_t.reshape(Array5(n, an_num, h, w, 1)) - .broadcast(Array5(1, 1, 1, 1, class_num)); - T* loss_data = loss->mutable_data({n}, ctx.GetPlace()); memset(loss_data, 0, n * sizeof(T)); - CalcSCEWithWeight(pred_x, tx, obj_weight, loss_weight_xy, loss_data); - CalcSCEWithWeight(pred_y, ty, obj_weight, loss_weight_xy, loss_data); - CalcL1LossWithWeight(pred_w, tw, obj_weight, loss_weight_wh, loss_data); - CalcL1LossWithWeight(pred_h, th, obj_weight, loss_weight_wh, loss_data); - CalcSCEWithWeight(pred_conf, tconf, obj_mask, loss_weight_conf_target, - loss_data); - CalcSCEWithWeight(pred_conf, tconf, noobj_mask, - loss_weight_conf_notarget, loss_data); - CalcSCEWithWeight(pred_class, tclass, obj_mask_expand, loss_weight_class, - loss_data); + CalcYolov3Loss(loss_data, *input, tx, ty, tw, th, tweight, tconf, tclass, + conf_mask, obj_mask); } }; @@ -436,14 +540,7 @@ class Yolov3LossGradKernel : public framework::OpKernel { float ignore_thresh = ctx.Attr("ignore_thresh"); auto* input_grad = ctx.Output(framework::GradVarName("X")); auto* loss_grad = ctx.Input(framework::GradVarName("Loss")); - const T* loss_grad_data = loss_grad->data(); int input_size = ctx.Attr("input_size"); - float loss_weight_xy = ctx.Attr("loss_weight_xy"); - float loss_weight_wh = ctx.Attr("loss_weight_wh"); - float loss_weight_conf_target = ctx.Attr("loss_weight_conf_target"); - float loss_weight_conf_notarget = - ctx.Attr("loss_weight_conf_notarget"); - float loss_weight_class = ctx.Attr("loss_weight_class"); const int n = input->dims()[0]; const int c = input->dims()[1]; @@ -451,21 +548,10 @@ class Yolov3LossGradKernel : public framework::OpKernel { const int w = input->dims()[3]; const int an_num = anchors.size() / 2; - Tensor pred_x, pred_y, pred_w, pred_h; - Tensor pred_conf, pred_class; - pred_x.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - pred_y.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - pred_w.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - pred_h.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - pred_conf.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - pred_class.mutable_data({n, an_num, h, w, class_num}, ctx.GetPlace()); - SplitPredResult(*input, &pred_conf, &pred_class, &pred_x, &pred_y, - &pred_w, &pred_h, an_num, class_num); - - Tensor obj_mask, noobj_mask; + Tensor conf_mask, obj_mask; Tensor tx, ty, tw, th, tweight, tconf, tclass; + conf_mask.mutable_data({n, an_num, h, w}, ctx.GetPlace()); obj_mask.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - noobj_mask.mutable_data({n, an_num, h, w}, ctx.GetPlace()); tx.mutable_data({n, an_num, h, w}, ctx.GetPlace()); ty.mutable_data({n, an_num, h, w}, ctx.GetPlace()); tw.mutable_data({n, an_num, h, w}, ctx.GetPlace()); @@ -474,51 +560,13 @@ class Yolov3LossGradKernel : public framework::OpKernel { tconf.mutable_data({n, an_num, h, w}, ctx.GetPlace()); tclass.mutable_data({n, an_num, h, w, class_num}, ctx.GetPlace()); PreProcessGTBox(*gt_box, *gt_label, ignore_thresh, anchors, input_size, - h, &obj_mask, &noobj_mask, &tx, &ty, &tw, &th, &tweight, + h, &conf_mask, &obj_mask, &tx, &ty, &tw, &th, &tweight, &tconf, &tclass); - Tensor obj_weight; - obj_weight.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - auto obj_weight_t = EigenTensor::From(obj_weight); - auto obj_mask_t = EigenTensor::From(obj_mask); - auto tweight_t = EigenTensor::From(tweight); - obj_weight_t = obj_mask_t * tweight_t; - - Tensor obj_mask_expand; - obj_mask_expand.mutable_data({n, an_num, h, w, class_num}, - ctx.GetPlace()); - auto obj_mask_expand_t = EigenTensor::From(obj_mask_expand); - obj_mask_expand_t = obj_mask_t.reshape(Array5(n, an_num, h, w, 1)) - .broadcast(Array5(1, 1, 1, 1, class_num)); - - Tensor grad_x, grad_y, grad_w, grad_h; - Tensor grad_conf_target, grad_conf_notarget, grad_class; - grad_x.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - grad_y.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - grad_w.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - grad_h.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - grad_conf_target.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - grad_conf_notarget.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - grad_class.mutable_data({n, an_num, h, w, class_num}, ctx.GetPlace()); - CalcSCEGradWithWeight(loss_grad_data, &grad_x, pred_x, tx, obj_weight); - CalcSCEGradWithWeight(loss_grad_data, &grad_y, pred_y, ty, obj_weight); - CalcL1LossGradWithWeight(loss_grad_data, &grad_w, pred_w, tw, - obj_weight); - CalcL1LossGradWithWeight(loss_grad_data, &grad_h, pred_h, th, - obj_weight); - CalcSCEGradWithWeight(loss_grad_data, &grad_conf_target, pred_conf, - tconf, obj_mask); - CalcSCEGradWithWeight(loss_grad_data, &grad_conf_notarget, pred_conf, - tconf, noobj_mask); - CalcSCEGradWithWeight(loss_grad_data, &grad_class, pred_class, tclass, - obj_mask_expand); - - input_grad->mutable_data({n, c, h, w}, ctx.GetPlace()); - AddAllGradToInputGrad(input_grad, grad_x, grad_y, grad_w, grad_h, - grad_conf_target, grad_conf_notarget, grad_class, - class_num, loss_weight_xy, loss_weight_wh, - loss_weight_conf_target, loss_weight_conf_notarget, - loss_weight_class); + T* input_grad_data = + input_grad->mutable_data({n, c, h, w}, ctx.GetPlace()); + CalcYolov3LossGrad(input_grad_data, *loss_grad, *input, tx, ty, tw, th, + tweight, tconf, tclass, conf_mask, obj_mask); } }; diff --git a/python/paddle/fluid/layers/detection.py b/python/paddle/fluid/layers/detection.py index 5fb4588e0b..caa9b1c3d4 100644 --- a/python/paddle/fluid/layers/detection.py +++ b/python/paddle/fluid/layers/detection.py @@ -416,11 +416,6 @@ def yolov3_loss(x, class_num, ignore_thresh, input_size, - loss_weight_xy=None, - loss_weight_wh=None, - loss_weight_conf_target=None, - loss_weight_conf_notarget=None, - loss_weight_class=None, name=None): """ ${comment} @@ -438,11 +433,6 @@ def yolov3_loss(x, class_num (int): ${class_num_comment} ignore_thresh (float): ${ignore_thresh_comment} input_size (int): ${input_size_comment} - loss_weight_xy (float|None): ${loss_weight_xy_comment} - loss_weight_wh (float|None): ${loss_weight_wh_comment} - loss_weight_conf_target (float|None): ${loss_weight_conf_target_comment} - loss_weight_conf_notarget (float|None): ${loss_weight_conf_notarget_comment} - loss_weight_class (float|None): ${loss_weight_class_comment} name (string): the name of yolov3 loss Returns: @@ -495,18 +485,18 @@ def yolov3_loss(x, "input_size": input_size, } - if loss_weight_xy is not None and isinstance(loss_weight_xy, float): - self.attrs['loss_weight_xy'] = loss_weight_xy - if loss_weight_wh is not None and isinstance(loss_weight_wh, float): - self.attrs['loss_weight_wh'] = loss_weight_wh - if loss_weight_conf_target is not None and isinstance( - loss_weight_conf_target, float): - self.attrs['loss_weight_conf_target'] = loss_weight_conf_target - if loss_weight_conf_notarget is not None and isinstance( - loss_weight_conf_notarget, float): - self.attrs['loss_weight_conf_notarget'] = loss_weight_conf_notarget - if loss_weight_class is not None and isinstance(loss_weight_class, float): - self.attrs['loss_weight_class'] = loss_weight_class + # if loss_weight_xy is not None and isinstance(loss_weight_xy, float): + # self.attrs['loss_weight_xy'] = loss_weight_xy + # if loss_weight_wh is not None and isinstance(loss_weight_wh, float): + # self.attrs['loss_weight_wh'] = loss_weight_wh + # if loss_weight_conf_target is not None and isinstance( + # loss_weight_conf_target, float): + # self.attrs['loss_weight_conf_target'] = loss_weight_conf_target + # if loss_weight_conf_notarget is not None and isinstance( + # loss_weight_conf_notarget, float): + # self.attrs['loss_weight_conf_notarget'] = loss_weight_conf_notarget + # if loss_weight_class is not None and isinstance(loss_weight_class, float): + # self.attrs['loss_weight_class'] = loss_weight_class helper.append_op( type='yolov3_loss', diff --git a/python/paddle/fluid/tests/unittests/op_test.py b/python/paddle/fluid/tests/unittests/op_test.py index 9cf398f18f..0fe836683b 100644 --- a/python/paddle/fluid/tests/unittests/op_test.py +++ b/python/paddle/fluid/tests/unittests/op_test.py @@ -470,8 +470,6 @@ class OpTest(unittest.TestCase): ] analytic_grads = self._get_gradient(inputs_to_check, place, output_names, no_grad_set) - # print(numeric_grads[0][0, 4, :, :]) - # print(analytic_grads[0][0, 4, :, :]) self._assert_is_close(numeric_grads, analytic_grads, inputs_to_check, max_relative_error, diff --git a/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py b/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py index e218031286..cf7e2c5289 100644 --- a/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py +++ b/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py @@ -80,8 +80,8 @@ def build_target(gtboxes, gtlabel, attrs, grid_size): class_num = attrs["class_num"] input_size = attrs["input_size"] an_num = len(anchors) // 2 + conf_mask = np.ones((n, an_num, grid_size, grid_size)).astype('float32') obj_mask = np.zeros((n, an_num, grid_size, grid_size)).astype('float32') - noobj_mask = np.ones((n, an_num, grid_size, grid_size)).astype('float32') tx = np.zeros((n, an_num, grid_size, grid_size)).astype('float32') ty = np.zeros((n, an_num, grid_size, grid_size)).astype('float32') tw = np.zeros((n, an_num, grid_size, grid_size)).astype('float32') @@ -114,10 +114,10 @@ def build_target(gtboxes, gtlabel, attrs, grid_size): max_iou = iou best_an_index = k if iou > ignore_thresh: - noobj_mask[i, best_an_index, gj, gi] = 0 + conf_mask[i, best_an_index, gj, gi] = 0 + conf_mask[i, best_an_index, gj, gi] = 1 obj_mask[i, best_an_index, gj, gi] = 1 - noobj_mask[i, best_an_index, gj, gi] = 0 tx[i, best_an_index, gj, gi] = gx - gi ty[i, best_an_index, gj, gi] = gy - gj tw[i, best_an_index, gj, gi] = np.log(gw / anchors[2 * @@ -129,7 +129,7 @@ def build_target(gtboxes, gtlabel, attrs, grid_size): tconf[i, best_an_index, gj, gi] = 1 tcls[i, best_an_index, gj, gi, gt_label] = 1 - return (tx, ty, tw, th, tweight, tconf, tcls, obj_mask, noobj_mask) + return (tx, ty, tw, th, tweight, tconf, tcls, conf_mask, obj_mask) def YoloV3Loss(x, gtbox, gtlabel, attrs): @@ -144,11 +144,9 @@ def YoloV3Loss(x, gtbox, gtlabel, attrs): pred_conf = x[:, :, :, :, 4] pred_cls = x[:, :, :, :, 5:] - tx, ty, tw, th, tweight, tconf, tcls, obj_mask, noobj_mask = build_target( + tx, ty, tw, th, tweight, tconf, tcls, conf_mask, obj_mask = build_target( gtbox, gtlabel, attrs, x.shape[2]) - # print("obj_mask: ", obj_mask[0, 0, :, :]) - # print("noobj_mask: ", noobj_mask[0, 0, :, :]) obj_weight = obj_mask * tweight obj_mask_expand = np.tile( np.expand_dims(obj_mask, 4), (1, 1, 1, 1, int(attrs['class_num']))) @@ -156,30 +154,19 @@ def YoloV3Loss(x, gtbox, gtlabel, attrs): loss_y = sce(pred_y, ty, obj_weight) loss_w = l1loss(pred_w, tw, obj_weight) loss_h = l1loss(pred_h, th, obj_weight) - loss_conf_target = sce(pred_conf, tconf, obj_mask) - loss_conf_notarget = sce(pred_conf, tconf, noobj_mask) + loss_obj = sce(pred_conf, tconf, conf_mask) loss_class = sce(pred_cls, tcls, obj_mask_expand) - # print("loss_xy: ", loss_x + loss_y) - # print("loss_wh: ", loss_w + loss_h) - # print("loss_conf_target: ", loss_conf_target) - # print("loss_conf_notarget: ", loss_conf_notarget) - # print("loss_class: ", loss_class) + # print("python loss_xy: ", loss_x + loss_y) + # print("python loss_wh: ", loss_w + loss_h) + # print("python loss_obj: ", loss_obj) + # print("python loss_class: ", loss_class) - return attrs['loss_weight_xy'] * (loss_x + loss_y) \ - + attrs['loss_weight_wh'] * (loss_w + loss_h) \ - + attrs['loss_weight_conf_target'] * loss_conf_target \ - + attrs['loss_weight_conf_notarget'] * loss_conf_notarget \ - + attrs['loss_weight_class'] * loss_class + return loss_x + loss_y + loss_w + loss_h + loss_obj + loss_class class TestYolov3LossOp(OpTest): def setUp(self): - self.loss_weight_xy = 1.0 - self.loss_weight_wh = 1.0 - self.loss_weight_conf_target = 1.0 - self.loss_weight_conf_notarget = 1.0 - self.loss_weight_class = 1.0 self.initTestCase() self.op_type = 'yolov3_loss' x = logit(np.random.uniform(0, 1, self.x_shape).astype('float32')) @@ -192,11 +179,6 @@ class TestYolov3LossOp(OpTest): "class_num": self.class_num, "ignore_thresh": self.ignore_thresh, "input_size": self.input_size, - "loss_weight_xy": self.loss_weight_xy, - "loss_weight_wh": self.loss_weight_wh, - "loss_weight_conf_target": self.loss_weight_conf_target, - "loss_weight_conf_notarget": self.loss_weight_conf_notarget, - "loss_weight_class": self.loss_weight_class, } self.inputs = {'X': x, 'GTBox': gtbox, 'GTLabel': gtlabel} @@ -215,17 +197,12 @@ class TestYolov3LossOp(OpTest): max_relative_error=0.31) def initTestCase(self): - self.anchors = [12, 12] + self.anchors = [12, 12, 11, 13] self.class_num = 5 - self.ignore_thresh = 0.3 + self.ignore_thresh = 0.5 self.input_size = 416 self.x_shape = (3, len(self.anchors) // 2 * (5 + self.class_num), 5, 5) self.gtbox_shape = (3, 5, 4) - self.loss_weight_xy = 1.2 - self.loss_weight_wh = 0.8 - self.loss_weight_conf_target = 2.0 - self.loss_weight_conf_notarget = 1.0 - self.loss_weight_class = 1.5 if __name__ == "__main__": From 577a92d99203a67042f2b7fd6db25ecae09a1938 Mon Sep 17 00:00:00 2001 From: dengkaipeng Date: Mon, 17 Dec 2018 11:45:16 +0800 Subject: [PATCH 06/24] use typename DeviceContext. test=develop --- paddle/fluid/operators/yolov3_loss_op.cc | 12 +- paddle/fluid/operators/yolov3_loss_op.h | 301 ++++++------------ .../tests/unittests/test_yolov3_loss_op.py | 6 +- 3 files changed, 103 insertions(+), 216 deletions(-) diff --git a/paddle/fluid/operators/yolov3_loss_op.cc b/paddle/fluid/operators/yolov3_loss_op.cc index 3bd0db8b59..495a8f6c01 100644 --- a/paddle/fluid/operators/yolov3_loss_op.cc +++ b/paddle/fluid/operators/yolov3_loss_op.cc @@ -204,7 +204,11 @@ namespace ops = paddle::operators; REGISTER_OPERATOR(yolov3_loss, ops::Yolov3LossOp, ops::Yolov3LossOpMaker, ops::Yolov3LossGradMaker); REGISTER_OPERATOR(yolov3_loss_grad, ops::Yolov3LossOpGrad); -REGISTER_OP_CPU_KERNEL(yolov3_loss, ops::Yolov3LossKernel, - ops::Yolov3LossKernel); -REGISTER_OP_CPU_KERNEL(yolov3_loss_grad, ops::Yolov3LossGradKernel, - ops::Yolov3LossGradKernel); +REGISTER_OP_CPU_KERNEL( + yolov3_loss, + ops::Yolov3LossKernel, + ops::Yolov3LossKernel); +REGISTER_OP_CPU_KERNEL( + yolov3_loss_grad, + ops::Yolov3LossGradKernel, + ops::Yolov3LossGradKernel); diff --git a/paddle/fluid/operators/yolov3_loss_op.h b/paddle/fluid/operators/yolov3_loss_op.h index 5de5b4efc7..f086e89a99 100644 --- a/paddle/fluid/operators/yolov3_loss_op.h +++ b/paddle/fluid/operators/yolov3_loss_op.h @@ -13,6 +13,7 @@ #include #include #include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/math/math_function.h" namespace paddle { namespace operators { @@ -32,183 +33,6 @@ static inline bool isZero(T x) { return fabs(x) < 1e-6; } -template -static inline void CalcL1LossWithWeight(const Tensor& x, const Tensor& y, - const Tensor& weight, - const T loss_weight, T* loss) { - int n = x.dims()[0]; - int stride = x.numel() / n; - const T* x_data = x.data(); - const T* y_data = y.data(); - const T* weight_data = weight.data(); - - for (int i = 0; i < n; i++) { - for (int j = 0; j < stride; j++) { - loss[i] += fabs(y_data[j] - x_data[j]) * weight_data[j] * loss_weight; - } - x_data += stride; - y_data += stride; - weight_data += stride; - } -} - -template -static void CalcL1LossGradWithWeight(const T* loss_grad, Tensor* grad, - const Tensor& x, const Tensor& y, - const Tensor& weight) { - int n = x.dims()[0]; - int stride = x.numel() / n; - T* grad_data = grad->data(); - const T* x_data = x.data(); - const T* y_data = y.data(); - const T* weight_data = weight.data(); - - for (int i = 0; i < n; i++) { - for (int j = 0; j < stride; j++) { - grad_data[j] = weight_data[j] * loss_grad[i]; - if (x_data[j] < y_data[j]) grad_data[j] *= -1.0; - } - grad_data += stride; - x_data += stride; - y_data += stride; - weight_data += stride; - } -} - -template -static inline void CalcMSEWithWeight(const Tensor& x, const Tensor& y, - const Tensor& weight, const T loss_weight, - T* loss) { - int n = x.dims()[0]; - int stride = x.numel() / n; - const T* x_data = x.data(); - const T* y_data = y.data(); - const T* weight_data = weight.data(); - - for (int i = 0; i < n; i++) { - for (int j = 0; j < stride; j++) { - loss[i] += pow(y_data[j] - x_data[j], 2) * weight_data[j] * loss_weight; - } - x_data += stride; - y_data += stride; - weight_data += stride; - } -} - -template -static void CalcMSEGradWithWeight(const T* loss_grad, Tensor* grad, - const Tensor& x, const Tensor& y, - const Tensor& weight) { - int n = x.dims()[0]; - int stride = x.numel() / n; - T* grad_data = grad->data(); - const T* x_data = x.data(); - const T* y_data = y.data(); - const T* weight_data = weight.data(); - - for (int i = 0; i < n; i++) { - for (int j = 0; j < stride; j++) { - grad_data[j] = - 2.0 * weight_data[j] * (x_data[j] - y_data[j]) * loss_grad[i]; - } - grad_data += stride; - x_data += stride; - y_data += stride; - weight_data += stride; - } -} - -template -static inline void CalcSCEWithWeight(const Tensor& x, const Tensor& label, - const Tensor& weight, const T loss_weight, - T* loss) { - int n = x.dims()[0]; - int stride = x.numel() / n; - const T* x_data = x.data(); - const T* label_data = label.data(); - const T* weight_data = weight.data(); - - for (int i = 0; i < n; i++) { - for (int j = 0; j < stride; j++) { - T term1 = (x_data[j] > 0) ? x_data[j] : 0; - T term2 = x_data[j] * label_data[j]; - T term3 = std::log(1.0 + std::exp(-std::abs(x_data[j]))); - loss[i] += (term1 - term2 + term3) * weight_data[j] * loss_weight; - } - x_data += stride; - label_data += stride; - weight_data += stride; - } -} - -template -static inline void CalcSCEGradWithWeight(const T* loss_grad, Tensor* grad, - const Tensor& x, const Tensor& label, - const Tensor& weight) { - int n = x.dims()[0]; - int stride = x.numel() / n; - T* grad_data = grad->data(); - const T* x_data = x.data(); - const T* label_data = label.data(); - const T* weight_data = weight.data(); - - for (int i = 0; i < n; i++) { - for (int j = 0; j < stride; j++) { - grad_data[j] = (1.0 / (1.0 + std::exp(-x_data[j])) - label_data[j]) * - weight_data[j] * loss_grad[i]; - } - grad_data += stride; - x_data += stride; - label_data += stride; - weight_data += stride; - } -} - -// template -// static void SplitPredResult(const Tensor& input, Tensor* pred_conf, -// Tensor* pred_class, Tensor* pred_x, Tensor* -// pred_y, -// Tensor* pred_w, Tensor* pred_h, -// const int anchor_num, const int class_num) { -// const int n = input.dims()[0]; -// const int h = input.dims()[2]; -// const int w = input.dims()[3]; -// const int box_attr_num = 5 + class_num; -// -// auto input_t = EigenTensor::From(input); -// auto pred_conf_t = EigenTensor::From(*pred_conf); -// auto pred_class_t = EigenTensor::From(*pred_class); -// auto pred_x_t = EigenTensor::From(*pred_x); -// auto pred_y_t = EigenTensor::From(*pred_y); -// auto pred_w_t = EigenTensor::From(*pred_w); -// auto pred_h_t = EigenTensor::From(*pred_h); -// -// for (int i = 0; i < n; i++) { -// for (int an_idx = 0; an_idx < anchor_num; an_idx++) { -// for (int j = 0; j < h; j++) { -// for (int k = 0; k < w; k++) { -// pred_x_t(i, an_idx, j, k) = input_t(i, box_attr_num * an_idx, j, -// k); -// pred_y_t(i, an_idx, j, k) = -// input_t(i, box_attr_num * an_idx + 1, j, k); -// pred_w_t(i, an_idx, j, k) = -// input_t(i, box_attr_num * an_idx + 2, j, k); -// pred_h_t(i, an_idx, j, k) = -// input_t(i, box_attr_num * an_idx + 3, j, k); -// -// pred_conf_t(i, an_idx, j, k) = -// input_t(i, box_attr_num * an_idx + 4, j, k); -// -// for (int c = 0; c < class_num; c++) { -// pred_class_t(i, an_idx, j, k, c) = -// input_t(i, box_attr_num * an_idx + 5 + c, j, k); -// } -// } -// } -// } -// } -// } - template static T CalcBoxIoU(std::vector box1, std::vector box2) { T b1_x1 = box1[0] - box1[2] / 2; @@ -242,30 +66,36 @@ static void PreProcessGTBox(const Tensor& gt_box, const Tensor& gt_label, Tensor* tconf, Tensor* tclass) { const int n = gt_box.dims()[0]; const int b = gt_box.dims()[1]; - const int anchor_num = anchors.size() / 2; - auto gt_box_t = EigenTensor::From(gt_box); - auto gt_label_t = EigenTensor::From(gt_label); - auto conf_mask_t = EigenTensor::From(*conf_mask).setConstant(1.0); - auto obj_mask_t = EigenTensor::From(*obj_mask).setConstant(0.0); - auto tx_t = EigenTensor::From(*tx).setConstant(0.0); - auto ty_t = EigenTensor::From(*ty).setConstant(0.0); - auto tw_t = EigenTensor::From(*tw).setConstant(0.0); - auto th_t = EigenTensor::From(*th).setConstant(0.0); - auto tweight_t = EigenTensor::From(*tweight).setConstant(0.0); - auto tconf_t = EigenTensor::From(*tconf).setConstant(0.0); - auto tclass_t = EigenTensor::From(*tclass).setConstant(0.0); + const int an_num = anchors.size() / 2; + const int h = tclass->dims()[2]; + const int w = tclass->dims()[3]; + const int class_num = tclass->dims()[4]; + + const T* gt_box_data = gt_box.data(); + const int* gt_label_data = gt_label.data(); + T* conf_mask_data = conf_mask->data(); + T* obj_mask_data = obj_mask->data(); + T* tx_data = tx->data(); + T* ty_data = ty->data(); + T* tw_data = tw->data(); + T* th_data = th->data(); + T* tweight_data = tweight->data(); + T* tconf_data = tconf->data(); + T* tclass_data = tclass->data(); for (int i = 0; i < n; i++) { for (int j = 0; j < b; j++) { - if (isZero(gt_box_t(i, j, 2)) && isZero(gt_box_t(i, j, 3))) { + int box_idx = (i * b + j) * 4; + if (isZero(gt_box_data[box_idx + 2]) && + isZero(gt_box_data[box_idx + 3])) { continue; } - int cur_label = gt_label_t(i, j); - T gx = gt_box_t(i, j, 0) * grid_size; - T gy = gt_box_t(i, j, 1) * grid_size; - T gw = gt_box_t(i, j, 2) * input_size; - T gh = gt_box_t(i, j, 3) * input_size; + int cur_label = gt_label_data[i * b + j]; + T gx = gt_box_data[box_idx] * grid_size; + T gy = gt_box_data[box_idx + 1] * grid_size; + T gw = gt_box_data[box_idx + 2] * input_size; + T gh = gt_box_data[box_idx + 3] * input_size; int gi = static_cast(gx); int gj = static_cast(gy); @@ -273,7 +103,7 @@ static void PreProcessGTBox(const Tensor& gt_box, const Tensor& gt_label, T iou; int best_an_index = -1; std::vector gt_box_shape({0, 0, gw, gh}); - for (int an_idx = 0; an_idx < anchor_num; an_idx++) { + for (int an_idx = 0; an_idx < an_num; an_idx++) { std::vector anchor_shape({0, 0, static_cast(anchors[2 * an_idx]), static_cast(anchors[2 * an_idx + 1])}); iou = CalcBoxIoU(gt_box_shape, anchor_shape); @@ -282,19 +112,22 @@ static void PreProcessGTBox(const Tensor& gt_box, const Tensor& gt_label, best_an_index = an_idx; } if (iou > ignore_thresh) { - conf_mask_t(i, an_idx, gj, gi) = static_cast(0.0); + int conf_idx = ((i * an_num + an_idx) * h + gj) * w + gi; + conf_mask_data[conf_idx] = static_cast(0.0); } } - conf_mask_t(i, best_an_index, gj, gi) = static_cast(1.0); - obj_mask_t(i, best_an_index, gj, gi) = static_cast(1.0); - tx_t(i, best_an_index, gj, gi) = gx - gi; - ty_t(i, best_an_index, gj, gi) = gy - gj; - tw_t(i, best_an_index, gj, gi) = log(gw / anchors[2 * best_an_index]); - th_t(i, best_an_index, gj, gi) = log(gh / anchors[2 * best_an_index + 1]); - tweight_t(i, best_an_index, gj, gi) = - 2.0 - gt_box_t(i, j, 2) * gt_box_t(i, j, 3); - tclass_t(i, best_an_index, gj, gi, cur_label) = 1; - tconf_t(i, best_an_index, gj, gi) = 1; + + int obj_idx = ((i * an_num + best_an_index) * h + gj) * w + gi; + conf_mask_data[obj_idx] = static_cast(1.0); + obj_mask_data[obj_idx] = static_cast(1.0); + tx_data[obj_idx] = gx - gi; + ty_data[obj_idx] = gy - gj; + tw_data[obj_idx] = log(gw / anchors[2 * best_an_index]); + th_data[obj_idx] = log(gh / anchors[2 * best_an_index + 1]); + tweight_data[obj_idx] = + 2.0 - gt_box_data[box_idx + 2] * gt_box_data[box_idx + 3]; + tconf_data[obj_idx] = static_cast(1.0); + tclass_data[obj_idx * class_num + cur_label] = static_cast(1.0); } } } @@ -427,18 +260,26 @@ static void CalcYolov3Loss(T* loss_data, const Tensor& input, const Tensor& tx, const int class_num = tclass.dims()[4]; const int grid_num = h * w; + // T l = 0.0; CalcSCE(loss_data, input_data, tx_data, tweight_data, obj_mask_data, n, an_num, grid_num, class_num, 1); CalcSCE(loss_data, input_data + grid_num, ty_data, tweight_data, obj_mask_data, n, an_num, grid_num, class_num, 1); + // LOG(ERROR) << "C++ xy: " << loss_data[0] - l; + // l = loss_data[0]; CalcL1Loss(loss_data, input_data + 2 * grid_num, tw_data, tweight_data, obj_mask_data, n, an_num, grid_num, class_num); CalcL1Loss(loss_data, input_data + 3 * grid_num, th_data, tweight_data, obj_mask_data, n, an_num, grid_num, class_num); + // LOG(ERROR) << "C++ wh: " << loss_data[0] - l; + // l = loss_data[0]; CalcSCE(loss_data, input_data + 4 * grid_num, tconf_data, conf_mask_data, conf_mask_data, n, an_num, grid_num, class_num, 1); + // LOG(ERROR) << "C++ conf: " << loss_data[0] - l; + // l = loss_data[0]; CalcSCE(loss_data, input_data + 5 * grid_num, tclass_data, obj_mask_data, obj_mask_data, n, an_num, grid_num, class_num, class_num); + // LOG(ERROR) << "C++ class: " << loss_data[0] - l; } template @@ -488,7 +329,7 @@ static void CalcYolov3LossGrad(T* input_grad_data, const Tensor& loss_grad, obj_mask_data, n, an_num, grid_num, class_num, class_num); } -template +template class Yolov3LossKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { @@ -517,6 +358,27 @@ class Yolov3LossKernel : public framework::OpKernel { tweight.mutable_data({n, an_num, h, w}, ctx.GetPlace()); tconf.mutable_data({n, an_num, h, w}, ctx.GetPlace()); tclass.mutable_data({n, an_num, h, w, class_num}, ctx.GetPlace()); + + math::SetConstant constant; + constant(ctx.template device_context(), &conf_mask, + static_cast(1.0)); + constant(ctx.template device_context(), &obj_mask, + static_cast(0.0)); + constant(ctx.template device_context(), &tx, + static_cast(0.0)); + constant(ctx.template device_context(), &ty, + static_cast(0.0)); + constant(ctx.template device_context(), &tw, + static_cast(0.0)); + constant(ctx.template device_context(), &th, + static_cast(0.0)); + constant(ctx.template device_context(), &tweight, + static_cast(0.0)); + constant(ctx.template device_context(), &tconf, + static_cast(0.0)); + constant(ctx.template device_context(), &tclass, + static_cast(0.0)); + PreProcessGTBox(*gt_box, *gt_label, ignore_thresh, anchors, input_size, h, &conf_mask, &obj_mask, &tx, &ty, &tw, &th, &tweight, &tconf, &tclass); @@ -528,7 +390,7 @@ class Yolov3LossKernel : public framework::OpKernel { } }; -template +template class Yolov3LossGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { @@ -559,6 +421,27 @@ class Yolov3LossGradKernel : public framework::OpKernel { tweight.mutable_data({n, an_num, h, w}, ctx.GetPlace()); tconf.mutable_data({n, an_num, h, w}, ctx.GetPlace()); tclass.mutable_data({n, an_num, h, w, class_num}, ctx.GetPlace()); + + math::SetConstant constant; + constant(ctx.template device_context(), &conf_mask, + static_cast(1.0)); + constant(ctx.template device_context(), &obj_mask, + static_cast(0.0)); + constant(ctx.template device_context(), &tx, + static_cast(0.0)); + constant(ctx.template device_context(), &ty, + static_cast(0.0)); + constant(ctx.template device_context(), &tw, + static_cast(0.0)); + constant(ctx.template device_context(), &th, + static_cast(0.0)); + constant(ctx.template device_context(), &tweight, + static_cast(0.0)); + constant(ctx.template device_context(), &tconf, + static_cast(0.0)); + constant(ctx.template device_context(), &tclass, + static_cast(0.0)); + PreProcessGTBox(*gt_box, *gt_label, ignore_thresh, anchors, input_size, h, &conf_mask, &obj_mask, &tx, &ty, &tw, &th, &tweight, &tconf, &tclass); diff --git a/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py b/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py index cf7e2c5289..862e77e663 100644 --- a/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py +++ b/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py @@ -197,12 +197,12 @@ class TestYolov3LossOp(OpTest): max_relative_error=0.31) def initTestCase(self): - self.anchors = [12, 12, 11, 13] + self.anchors = [12, 12] self.class_num = 5 self.ignore_thresh = 0.5 self.input_size = 416 - self.x_shape = (3, len(self.anchors) // 2 * (5 + self.class_num), 5, 5) - self.gtbox_shape = (3, 5, 4) + self.x_shape = (1, len(self.anchors) // 2 * (5 + self.class_num), 3, 3) + self.gtbox_shape = (1, 5, 4) if __name__ == "__main__": From db8ff57a61cbeec30b61111850b3e768661e8de8 Mon Sep 17 00:00:00 2001 From: dengkaipeng Date: Mon, 17 Dec 2018 14:43:06 +0800 Subject: [PATCH 07/24] remove useless code and update doc. test=develop --- paddle/fluid/operators/yolov3_loss_op.cc | 32 +++++----- paddle/fluid/operators/yolov3_loss_op.h | 64 ++++++++----------- python/paddle/fluid/layers/detection.py | 13 ---- .../tests/unittests/test_yolov3_loss_op.py | 5 -- 4 files changed, 45 insertions(+), 69 deletions(-) diff --git a/paddle/fluid/operators/yolov3_loss_op.cc b/paddle/fluid/operators/yolov3_loss_op.cc index 495a8f6c01..aa4ba3b62e 100644 --- a/paddle/fluid/operators/yolov3_loss_op.cc +++ b/paddle/fluid/operators/yolov3_loss_op.cc @@ -138,17 +138,23 @@ class Yolov3LossOpMaker : public framework::OpProtoAndCheckerMaker { thresh, the confidence score loss of this anchor box will be ignored. Therefore, the yolov3 loss consist of three major parts, box location loss, - confidence score loss, and classification loss. The MSE loss is used for - box location, and binary cross entropy loss is used for confidence score - loss and classification loss. + confidence score loss, and classification loss. The L1 loss is used for + box coordinates (w, h), and sigmoid cross entropy loss is used for box + coordinates (x, y), confidence score loss and classification loss. + + In order to trade off box coordinate losses between big boxes and small + boxes, box coordinate losses will be mutiplied by scale weight, which is + calculated as follow. + + $$ + weight_{box} = 2.0 - t_w * t_h + $$ Final loss will be represented as follow. $$ - loss = \loss_weight_{xy} * loss_{xy} + \loss_weight_{wh} * loss_{wh} - + \loss_weight_{conf_target} * loss_{conf_target} - + \loss_weight_{conf_notarget} * loss_{conf_notarget} - + \loss_weight_{class} * loss_{class} + loss = (loss_{xy} + loss_{wh}) * weight_{box} + + loss_{conf} + loss_{class} $$ )DOC"); } @@ -204,11 +210,7 @@ namespace ops = paddle::operators; REGISTER_OPERATOR(yolov3_loss, ops::Yolov3LossOp, ops::Yolov3LossOpMaker, ops::Yolov3LossGradMaker); REGISTER_OPERATOR(yolov3_loss_grad, ops::Yolov3LossOpGrad); -REGISTER_OP_CPU_KERNEL( - yolov3_loss, - ops::Yolov3LossKernel, - ops::Yolov3LossKernel); -REGISTER_OP_CPU_KERNEL( - yolov3_loss_grad, - ops::Yolov3LossGradKernel, - ops::Yolov3LossGradKernel); +REGISTER_OP_CPU_KERNEL(yolov3_loss, ops::Yolov3LossKernel, + ops::Yolov3LossKernel); +REGISTER_OP_CPU_KERNEL(yolov3_loss_grad, ops::Yolov3LossGradKernel, + ops::Yolov3LossGradKernel); diff --git a/paddle/fluid/operators/yolov3_loss_op.h b/paddle/fluid/operators/yolov3_loss_op.h index f086e89a99..e32cd30967 100644 --- a/paddle/fluid/operators/yolov3_loss_op.h +++ b/paddle/fluid/operators/yolov3_loss_op.h @@ -260,26 +260,18 @@ static void CalcYolov3Loss(T* loss_data, const Tensor& input, const Tensor& tx, const int class_num = tclass.dims()[4]; const int grid_num = h * w; - // T l = 0.0; CalcSCE(loss_data, input_data, tx_data, tweight_data, obj_mask_data, n, an_num, grid_num, class_num, 1); CalcSCE(loss_data, input_data + grid_num, ty_data, tweight_data, obj_mask_data, n, an_num, grid_num, class_num, 1); - // LOG(ERROR) << "C++ xy: " << loss_data[0] - l; - // l = loss_data[0]; CalcL1Loss(loss_data, input_data + 2 * grid_num, tw_data, tweight_data, obj_mask_data, n, an_num, grid_num, class_num); CalcL1Loss(loss_data, input_data + 3 * grid_num, th_data, tweight_data, obj_mask_data, n, an_num, grid_num, class_num); - // LOG(ERROR) << "C++ wh: " << loss_data[0] - l; - // l = loss_data[0]; CalcSCE(loss_data, input_data + 4 * grid_num, tconf_data, conf_mask_data, conf_mask_data, n, an_num, grid_num, class_num, 1); - // LOG(ERROR) << "C++ conf: " << loss_data[0] - l; - // l = loss_data[0]; CalcSCE(loss_data, input_data + 5 * grid_num, tclass_data, obj_mask_data, obj_mask_data, n, an_num, grid_num, class_num, class_num); - // LOG(ERROR) << "C++ class: " << loss_data[0] - l; } template @@ -329,7 +321,7 @@ static void CalcYolov3LossGrad(T* input_grad_data, const Tensor& loss_grad, obj_mask_data, n, an_num, grid_num, class_num, class_num); } -template +template class Yolov3LossKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { @@ -359,24 +351,24 @@ class Yolov3LossKernel : public framework::OpKernel { tconf.mutable_data({n, an_num, h, w}, ctx.GetPlace()); tclass.mutable_data({n, an_num, h, w, class_num}, ctx.GetPlace()); - math::SetConstant constant; - constant(ctx.template device_context(), &conf_mask, - static_cast(1.0)); - constant(ctx.template device_context(), &obj_mask, - static_cast(0.0)); - constant(ctx.template device_context(), &tx, - static_cast(0.0)); - constant(ctx.template device_context(), &ty, + math::SetConstant constant; + constant(ctx.template device_context(), + &conf_mask, static_cast(1.0)); + constant(ctx.template device_context(), + &obj_mask, static_cast(0.0)); + constant(ctx.template device_context(), &tx, static_cast(0.0)); - constant(ctx.template device_context(), &tw, + constant(ctx.template device_context(), &ty, static_cast(0.0)); - constant(ctx.template device_context(), &th, + constant(ctx.template device_context(), &tw, static_cast(0.0)); - constant(ctx.template device_context(), &tweight, + constant(ctx.template device_context(), &th, static_cast(0.0)); - constant(ctx.template device_context(), &tconf, + constant(ctx.template device_context(), + &tweight, static_cast(0.0)); + constant(ctx.template device_context(), &tconf, static_cast(0.0)); - constant(ctx.template device_context(), &tclass, + constant(ctx.template device_context(), &tclass, static_cast(0.0)); PreProcessGTBox(*gt_box, *gt_label, ignore_thresh, anchors, input_size, @@ -390,7 +382,7 @@ class Yolov3LossKernel : public framework::OpKernel { } }; -template +template class Yolov3LossGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { @@ -422,24 +414,24 @@ class Yolov3LossGradKernel : public framework::OpKernel { tconf.mutable_data({n, an_num, h, w}, ctx.GetPlace()); tclass.mutable_data({n, an_num, h, w, class_num}, ctx.GetPlace()); - math::SetConstant constant; - constant(ctx.template device_context(), &conf_mask, - static_cast(1.0)); - constant(ctx.template device_context(), &obj_mask, - static_cast(0.0)); - constant(ctx.template device_context(), &tx, - static_cast(0.0)); - constant(ctx.template device_context(), &ty, + math::SetConstant constant; + constant(ctx.template device_context(), + &conf_mask, static_cast(1.0)); + constant(ctx.template device_context(), + &obj_mask, static_cast(0.0)); + constant(ctx.template device_context(), &tx, static_cast(0.0)); - constant(ctx.template device_context(), &tw, + constant(ctx.template device_context(), &ty, static_cast(0.0)); - constant(ctx.template device_context(), &th, + constant(ctx.template device_context(), &tw, static_cast(0.0)); - constant(ctx.template device_context(), &tweight, + constant(ctx.template device_context(), &th, static_cast(0.0)); - constant(ctx.template device_context(), &tconf, + constant(ctx.template device_context(), + &tweight, static_cast(0.0)); + constant(ctx.template device_context(), &tconf, static_cast(0.0)); - constant(ctx.template device_context(), &tclass, + constant(ctx.template device_context(), &tclass, static_cast(0.0)); PreProcessGTBox(*gt_box, *gt_label, ignore_thresh, anchors, input_size, diff --git a/python/paddle/fluid/layers/detection.py b/python/paddle/fluid/layers/detection.py index caa9b1c3d4..92823af1e0 100644 --- a/python/paddle/fluid/layers/detection.py +++ b/python/paddle/fluid/layers/detection.py @@ -485,19 +485,6 @@ def yolov3_loss(x, "input_size": input_size, } - # if loss_weight_xy is not None and isinstance(loss_weight_xy, float): - # self.attrs['loss_weight_xy'] = loss_weight_xy - # if loss_weight_wh is not None and isinstance(loss_weight_wh, float): - # self.attrs['loss_weight_wh'] = loss_weight_wh - # if loss_weight_conf_target is not None and isinstance( - # loss_weight_conf_target, float): - # self.attrs['loss_weight_conf_target'] = loss_weight_conf_target - # if loss_weight_conf_notarget is not None and isinstance( - # loss_weight_conf_notarget, float): - # self.attrs['loss_weight_conf_notarget'] = loss_weight_conf_notarget - # if loss_weight_class is not None and isinstance(loss_weight_class, float): - # self.attrs['loss_weight_class'] = loss_weight_class - helper.append_op( type='yolov3_loss', inputs={"X": x, diff --git a/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py b/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py index 862e77e663..e52047b0ad 100644 --- a/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py +++ b/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py @@ -157,11 +157,6 @@ def YoloV3Loss(x, gtbox, gtlabel, attrs): loss_obj = sce(pred_conf, tconf, conf_mask) loss_class = sce(pred_cls, tcls, obj_mask_expand) - # print("python loss_xy: ", loss_x + loss_y) - # print("python loss_wh: ", loss_w + loss_h) - # print("python loss_obj: ", loss_obj) - # print("python loss_class: ", loss_class) - return loss_x + loss_y + loss_w + loss_h + loss_obj + loss_class From bd6deb1a8bc0b39cde425117b6c6048f4a945a7f Mon Sep 17 00:00:00 2001 From: dengkaipeng Date: Mon, 17 Dec 2018 15:09:56 +0800 Subject: [PATCH 08/24] fix API.spec change. test=develop --- paddle/fluid/API.spec | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index 4acccd0899..f293b0d30e 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -324,7 +324,7 @@ paddle.fluid.layers.generate_mask_labels ArgSpec(args=['im_info', 'gt_classes', paddle.fluid.layers.iou_similarity ArgSpec(args=['x', 'y', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.box_coder ArgSpec(args=['prior_box', 'prior_box_var', 'target_box', 'code_type', 'box_normalized', 'name'], varargs=None, keywords=None, defaults=('encode_center_size', True, None)) paddle.fluid.layers.polygon_box_transform ArgSpec(args=['input', 'name'], varargs=None, keywords=None, defaults=(None,)) -paddle.fluid.layers.yolov3_loss ArgSpec(args=['x', 'gtbox', 'gtlabel', 'anchors', 'class_num', 'ignore_thresh', 'loss_weight_xy', 'loss_weight_wh', 'loss_weight_conf_target', 'loss_weight_conf_notarget', 'loss_weight_class', 'name'], varargs=None, keywords=None, defaults=(None, None, None, None, None, None)) +paddle.fluid.layers.yolov3_loss ArgSpec(args=['x', 'gtbox', 'gtlabel', 'anchors', 'class_num', 'ignore_thresh', 'input_size', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.multiclass_nms ArgSpec(args=['bboxes', 'scores', 'score_threshold', 'nms_top_k', 'keep_top_k', 'nms_threshold', 'normalized', 'nms_eta', 'background_label', 'name'], varargs=None, keywords=None, defaults=(0.3, True, 1.0, 0, None)) paddle.fluid.layers.accuracy ArgSpec(args=['input', 'label', 'k', 'correct', 'total'], varargs=None, keywords=None, defaults=(1, None, None)) paddle.fluid.layers.auc ArgSpec(args=['input', 'label', 'curve', 'num_thresholds', 'topk', 'slide_steps'], varargs=None, keywords=None, defaults=('ROC', 4095, 1, 1)) From e7e4f084e51a3f3a91a32b9eb03bff71963f9e45 Mon Sep 17 00:00:00 2001 From: dengkaipeng Date: Thu, 20 Dec 2018 21:34:05 +0800 Subject: [PATCH 09/24] ignore pred overlap gt > 0.7. test=develop --- paddle/fluid/operators/yolov3_loss_op.cc | 35 +- paddle/fluid/operators/yolov3_loss_op.h | 556 +++++++++++++++--- python/paddle/fluid/layers/detection.py | 14 +- python/paddle/fluid/tests/test_detection.py | 4 +- .../tests/unittests/test_yolov3_loss_op.py | 184 +++++- 5 files changed, 668 insertions(+), 125 deletions(-) diff --git a/paddle/fluid/operators/yolov3_loss_op.cc b/paddle/fluid/operators/yolov3_loss_op.cc index aa4ba3b62e..8c46e341d6 100644 --- a/paddle/fluid/operators/yolov3_loss_op.cc +++ b/paddle/fluid/operators/yolov3_loss_op.cc @@ -35,13 +35,16 @@ class Yolov3LossOp : public framework::OperatorWithKernel { auto dim_gtlabel = ctx->GetInputDim("GTLabel"); auto anchors = ctx->Attrs().Get>("anchors"); int anchor_num = anchors.size() / 2; + auto anchor_mask = ctx->Attrs().Get>("anchor_mask"); + int mask_num = anchor_mask.size(); auto class_num = ctx->Attrs().Get("class_num"); PADDLE_ENFORCE_EQ(dim_x.size(), 4, "Input(X) should be a 4-D tensor."); PADDLE_ENFORCE_EQ(dim_x[2], dim_x[3], "Input(X) dim[3] and dim[4] should be euqal."); - PADDLE_ENFORCE_EQ(dim_x[1], anchor_num * (5 + class_num), - "Input(X) dim[1] should be equal to (anchor_number * (5 " - "+ class_num))."); + PADDLE_ENFORCE_EQ( + dim_x[1], mask_num * (5 + class_num), + "Input(X) dim[1] should be equal to (anchor_mask_number * (5 " + "+ class_num))."); PADDLE_ENFORCE_EQ(dim_gtbox.size(), 3, "Input(GTBox) should be a 3-D tensor"); PADDLE_ENFORCE_EQ(dim_gtbox[2], 4, "Input(GTBox) dim[2] should be 5"); @@ -55,6 +58,11 @@ class Yolov3LossOp : public framework::OperatorWithKernel { "Attr(anchors) length should be greater then 0."); PADDLE_ENFORCE_EQ(anchors.size() % 2, 0, "Attr(anchors) length should be even integer."); + for (size_t i = 0; i < anchor_mask.size(); i++) { + PADDLE_ENFORCE_LT( + anchor_mask[i], anchor_num, + "Attr(anchor_mask) should not crossover Attr(anchors)."); + } PADDLE_ENFORCE_GT(class_num, 0, "Attr(class_num) should be an integer greater then 0."); @@ -74,7 +82,7 @@ class Yolov3LossOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { AddInput("X", - "The input tensor of YOLO v3 loss operator, " + "The input tensor of YOLOv3 loss operator, " "This is a 4-D tensor with shape of [N, C, H, W]." "H and W should be same, and the second dimention(C) stores" "box locations, confidence score and classification one-hot" @@ -99,13 +107,20 @@ class Yolov3LossOpMaker : public framework::OpProtoAndCheckerMaker { AddAttr("class_num", "The number of classes to predict."); AddAttr>("anchors", "The anchor width and height, " - "it will be parsed pair by pair."); - AddAttr("input_size", - "The input size of YOLOv3 net, " - "generally this is set as 320, 416 or 608.") - .SetDefault(406); + "it will be parsed pair by pair.") + .SetDefault(std::vector{}); + AddAttr>("anchor_mask", + "The mask index of anchors used in " + "current YOLOv3 loss calculation.") + .SetDefault(std::vector{}); + AddAttr("downsample", + "The downsample ratio from network input to YOLOv3 loss " + "input, so 32, 16, 8 should be set for the first, second, " + "and thrid YOLOv3 loss operators.") + .SetDefault(32); AddAttr("ignore_thresh", - "The ignore threshold to ignore confidence loss."); + "The ignore threshold to ignore confidence loss.") + .SetDefault(0.7); AddComment(R"DOC( This operator generate yolov3 loss by given predict result and ground truth boxes. diff --git a/paddle/fluid/operators/yolov3_loss_op.h b/paddle/fluid/operators/yolov3_loss_op.h index e32cd30967..9254a6cf6f 100644 --- a/paddle/fluid/operators/yolov3_loss_op.h +++ b/paddle/fluid/operators/yolov3_loss_op.h @@ -321,6 +321,182 @@ static void CalcYolov3LossGrad(T* input_grad_data, const Tensor& loss_grad, obj_mask_data, n, an_num, grid_num, class_num, class_num); } +static int mask_index(std::vector mask, int val) { + for (int i = 0; i < mask.size(); i++) { + if (mask[i] == val) { + return i; + } + } + return -1; +} + +template +struct Box { + float x, y, w, h; +}; + +template +static inline T sigmoid(T x) { + return 1.0 / (1.0 + std::exp(-x)); +} + +template +static inline void sigmoid_arrray(T* arr, int len) { + for (int i = 0; i < len; i++) { + arr[i] = sigmoid(arr[i]); + } +} + +template +static inline Box get_yolo_box(const T* x, std::vector anchors, int i, + int j, int an_idx, int grid_size, + int input_size, int index, int stride) { + Box b; + b.x = (i + sigmoid(x[index])) / grid_size; + b.y = (j + sigmoid(x[index + stride])) / grid_size; + b.w = std::exp(x[index + 2 * stride]) * anchors[2 * an_idx] / input_size; + b.h = std::exp(x[index + 3 * stride]) * anchors[2 * an_idx + 1] / input_size; + return b; +} + +template +static inline Box get_gt_box(const T* gt, int batch, int max_boxes, + int idx) { + Box b; + b.x = gt[(batch * max_boxes + idx) * 4]; + b.y = gt[(batch * max_boxes + idx) * 4 + 1]; + b.w = gt[(batch * max_boxes + idx) * 4 + 2]; + b.h = gt[(batch * max_boxes + idx) * 4 + 3]; + return b; +} + +template +static inline T overlap(T c1, T w1, T c2, T w2) { + T l1 = c1 - w1 / 2.0; + T l2 = c2 - w2 / 2.0; + T left = l1 > l2 ? l1 : l2; + T r1 = c1 + w1 / 2.0; + T r2 = c2 + w2 / 2.0; + T right = r1 < r2 ? r1 : r2; + return right - left; +} + +template +static inline T box_iou(Box b1, Box b2) { + T w = overlap(b1.x, b1.w, b2.x, b2.w); + T h = overlap(b1.y, b1.h, b2.y, b2.h); + T inter_area = (w < 0 || h < 0) ? 0.0 : w * h; + T union_area = b1.w * b1.h + b2.w * b2.h - inter_area; + return inter_area / union_area; +} + +static inline int entry_index(int batch, int an_idx, int hw_idx, int an_num, + int an_stride, int stride, int entry) { + return (batch * an_num + an_idx) * an_stride + entry * stride + hw_idx; +} + +template +static void CalcBoxLocationLoss(T* loss, const T* input, Box gt, + std::vector anchors, int an_idx, + int box_idx, int gi, int gj, int grid_size, + int input_size, int stride) { + T tx = gt.x * grid_size - gi; + T ty = gt.y * grid_size - gj; + T tw = std::log(gt.w * input_size / anchors[2 * an_idx]); + T th = std::log(gt.h * input_size / anchors[2 * an_idx + 1]); + + T scale = 2.0 - gt.w * gt.h; + loss[0] += SCE(input[box_idx], tx) * scale; + loss[0] += SCE(input[box_idx + stride], ty) * scale; + loss[0] += L1Loss(input[box_idx + 2 * stride], tw) * scale; + loss[0] += L1Loss(input[box_idx + 3 * stride], th) * scale; +} + +template +static void CalcBoxLocationLossGrad(T* input_grad, const T loss, const T* input, + Box gt, std::vector anchors, + int an_idx, int box_idx, int gi, int gj, + int grid_size, int input_size, int stride) { + T tx = gt.x * grid_size - gi; + T ty = gt.y * grid_size - gj; + T tw = std::log(gt.w * input_size / anchors[2 * an_idx]); + T th = std::log(gt.h * input_size / anchors[2 * an_idx + 1]); + + T scale = 2.0 - gt.w * gt.h; + input_grad[box_idx] = SCEGrad(input[box_idx], tx) * scale * loss; + input_grad[box_idx + stride] = + SCEGrad(input[box_idx + stride], ty) * scale * loss; + input_grad[box_idx + 2 * stride] = + L1LossGrad(input[box_idx + 2 * stride], tw) * scale * loss; + input_grad[box_idx + 3 * stride] = + L1LossGrad(input[box_idx + 3 * stride], th) * scale * loss; +} + +template +static inline void CalcLabelLoss(T* loss, const T* input, const int index, + const int label, const int class_num, + const int stride) { + for (int i = 0; i < class_num; i++) { + loss[0] += SCE(input[index + i * stride], (i == label) ? 1.0 : 0.0); + } +} + +template +static inline void CalcLabelLossGrad(T* input_grad, const T loss, + const T* input, const int index, + const int label, const int class_num, + const int stride) { + for (int i = 0; i < class_num; i++) { + input_grad[index + i * stride] = + SCEGrad(input[index + i * stride], (i == label) ? 1.0 : 0.0) * loss; + } +} + +template +static inline void CalcObjnessLoss(T* loss, const T* input, const int* objness, + const int n, const int an_num, const int h, + const int w, const int stride, + const int an_stride) { + for (int i = 0; i < n; i++) { + for (int j = 0; j < an_num; j++) { + for (int k = 0; k < h; k++) { + for (int l = 0; l < w; l++) { + int obj = objness[k * w + l]; + if (obj >= 0) { + loss[i] += SCE(input[k * w + l], static_cast(obj)); + } + } + } + objness += stride; + input += an_stride; + } + } +} + +template +static inline void CalcObjnessLossGrad(T* input_grad, const T* loss, + const T* input, const int* objness, + const int n, const int an_num, + const int h, const int w, + const int stride, const int an_stride) { + for (int i = 0; i < n; i++) { + for (int j = 0; j < an_num; j++) { + for (int k = 0; k < h; k++) { + for (int l = 0; l < w; l++) { + int obj = objness[k * w + l]; + if (obj >= 0) { + input_grad[k * w + l] = + SCEGrad(input[k * w + l], static_cast(obj)) * loss[i]; + } + } + } + objness += stride; + input += an_stride; + input_grad += an_stride; + } + } +} + template class Yolov3LossKernel : public framework::OpKernel { public: @@ -330,55 +506,158 @@ class Yolov3LossKernel : public framework::OpKernel { auto* gt_label = ctx.Input("GTLabel"); auto* loss = ctx.Output("Loss"); auto anchors = ctx.Attr>("anchors"); + auto anchor_mask = ctx.Attr>("anchor_mask"); int class_num = ctx.Attr("class_num"); - int input_size = ctx.Attr("input_size"); float ignore_thresh = ctx.Attr("ignore_thresh"); + int downsample = ctx.Attr("downsample"); const int n = input->dims()[0]; const int h = input->dims()[2]; const int w = input->dims()[3]; const int an_num = anchors.size() / 2; + const int mask_num = anchor_mask.size(); + const int b = gt_box->dims()[1]; + int input_size = downsample * h; - Tensor conf_mask, obj_mask; - Tensor tx, ty, tw, th, tweight, tconf, tclass; - conf_mask.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - obj_mask.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - tx.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - ty.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - tw.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - th.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - tweight.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - tconf.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - tclass.mutable_data({n, an_num, h, w, class_num}, ctx.GetPlace()); - - math::SetConstant constant; - constant(ctx.template device_context(), - &conf_mask, static_cast(1.0)); - constant(ctx.template device_context(), - &obj_mask, static_cast(0.0)); - constant(ctx.template device_context(), &tx, - static_cast(0.0)); - constant(ctx.template device_context(), &ty, - static_cast(0.0)); - constant(ctx.template device_context(), &tw, - static_cast(0.0)); - constant(ctx.template device_context(), &th, - static_cast(0.0)); - constant(ctx.template device_context(), - &tweight, static_cast(0.0)); - constant(ctx.template device_context(), &tconf, - static_cast(0.0)); - constant(ctx.template device_context(), &tclass, - static_cast(0.0)); - - PreProcessGTBox(*gt_box, *gt_label, ignore_thresh, anchors, input_size, - h, &conf_mask, &obj_mask, &tx, &ty, &tw, &th, &tweight, - &tconf, &tclass); - + const T* input_data = input->data(); + const T* gt_box_data = gt_box->data(); + const int* gt_label_data = gt_label->data(); T* loss_data = loss->mutable_data({n}, ctx.GetPlace()); - memset(loss_data, 0, n * sizeof(T)); - CalcYolov3Loss(loss_data, *input, tx, ty, tw, th, tweight, tconf, tclass, - conf_mask, obj_mask); + memset(loss_data, 0, n * sizeof(int)); + + Tensor objness; + int* objness_data = + objness.mutable_data({n, mask_num, h, w}, ctx.GetPlace()); + memset(objness_data, 0, objness.numel() * sizeof(int)); + + const int stride = h * w; + const int an_stride = (class_num + 5) * stride; + + for (int i = 0; i < n; i++) { + for (int j = 0; j < mask_num; j++) { + for (int k = 0; k < h; k++) { + for (int l = 0; l < w; l++) { + int box_idx = + entry_index(i, j, k * w + l, mask_num, an_stride, stride, 0); + Box pred = + get_yolo_box(input_data, anchors, l, k, anchor_mask[j], h, + input_size, box_idx, stride); + T best_iou = 0; + // int best_t = 0; + for (int t = 0; t < b; t++) { + if (isZero(gt_box_data[i * b * 4 + t * 4]) && + isZero(gt_box_data[i * b * 4 + t * 4 + 1])) { + continue; + } + Box gt = get_gt_box(gt_box_data, i, b, t); + T iou = box_iou(pred, gt); + if (iou > best_iou) { + best_iou = iou; + // best_t = t; + } + } + + if (best_iou > ignore_thresh) { + int obj_idx = (i * mask_num + j) * stride + k * w + l; + objness_data[obj_idx] = -1; + } + } + } + } + for (int t = 0; t < b; t++) { + if (isZero(gt_box_data[i * b * 4 + t * 4]) && + isZero(gt_box_data[i * b * 4 + t * 4 + 1])) { + continue; + } + Box gt = get_gt_box(gt_box_data, i, b, t); + int gi = static_cast(gt.x * w); + int gj = static_cast(gt.y * h); + Box gt_shift = gt; + gt_shift.x = 0.0; + gt_shift.y = 0.0; + T best_iou = 0.0; + int best_n = 0; + for (int an_idx = 0; an_idx < an_num; an_idx++) { + Box an_box; + an_box.x = 0.0; + an_box.y = 0.0; + an_box.w = anchors[2 * an_idx] / static_cast(input_size); + an_box.h = anchors[2 * an_idx + 1] / static_cast(input_size); + float iou = box_iou(an_box, gt_shift); + // TO DO: iou > 0.5 ? + if (iou > best_iou) { + best_iou = iou; + best_n = an_idx; + } + } + + int mask_idx = mask_index(anchor_mask, best_n); + if (mask_idx >= 0) { + int box_idx = entry_index(i, mask_idx, gj * w + gi, mask_num, + an_stride, stride, 0); + CalcBoxLocationLoss(loss_data + i, input_data, gt, anchors, best_n, + box_idx, gi, gj, h, input_size, stride); + + int obj_idx = (i * mask_num + mask_idx) * stride + gj * w + gi; + objness_data[obj_idx] = 1; + + int label = gt_label_data[i * b + t]; + int label_idx = entry_index(i, mask_idx, gj * w + gi, mask_num, + an_stride, stride, 5); + CalcLabelLoss(loss_data + i, input_data, label_idx, label, + class_num, stride); + } + } + } + + CalcObjnessLoss(loss_data, input_data + 4 * stride, objness_data, n, + mask_num, h, w, stride, an_stride); + + // Tensor conf_mask, obj_mask; + // Tensor tx, ty, tw, th, tweight, tconf, tclass; + // conf_mask.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + // obj_mask.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + // tx.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + // ty.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + // tw.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + // th.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + // tweight.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + // tconf.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + // tclass.mutable_data({n, an_num, h, w, class_num}, ctx.GetPlace()); + // + // math::SetConstant constant; + // constant(ctx.template device_context(), + // &conf_mask, static_cast(1.0)); + // constant(ctx.template device_context(), + // &obj_mask, static_cast(0.0)); + // constant(ctx.template device_context(), &tx, + // static_cast(0.0)); + // constant(ctx.template device_context(), &ty, + // static_cast(0.0)); + // constant(ctx.template device_context(), &tw, + // static_cast(0.0)); + // constant(ctx.template device_context(), &th, + // static_cast(0.0)); + // constant(ctx.template device_context(), + // &tweight, static_cast(0.0)); + // constant(ctx.template device_context(), + // &tconf, + // static_cast(0.0)); + // constant(ctx.template device_context(), + // &tclass, + // static_cast(0.0)); + // + // PreProcessGTBox(*gt_box, *gt_label, ignore_thresh, anchors, + // input_size, + // h, &conf_mask, &obj_mask, &tx, &ty, &tw, &th, + // &tweight, + // &tconf, &tclass); + // + // T* loss_data = loss->mutable_data({n}, ctx.GetPlace()); + // memset(loss_data, 0, n * sizeof(T)); + // CalcYolov3Loss(loss_data, *input, tx, ty, tw, th, tweight, tconf, + // tclass, + // conf_mask, obj_mask); } }; @@ -389,59 +668,172 @@ class Yolov3LossGradKernel : public framework::OpKernel { auto* input = ctx.Input("X"); auto* gt_box = ctx.Input("GTBox"); auto* gt_label = ctx.Input("GTLabel"); + auto* input_grad = ctx.Output(framework::GradVarName("X")); + auto* loss_grad = ctx.Input(framework::GradVarName("Loss")); auto anchors = ctx.Attr>("anchors"); + auto anchor_mask = ctx.Attr>("anchor_mask"); int class_num = ctx.Attr("class_num"); float ignore_thresh = ctx.Attr("ignore_thresh"); - auto* input_grad = ctx.Output(framework::GradVarName("X")); - auto* loss_grad = ctx.Input(framework::GradVarName("Loss")); - int input_size = ctx.Attr("input_size"); + int downsample = ctx.Attr("downsample"); const int n = input->dims()[0]; const int c = input->dims()[1]; const int h = input->dims()[2]; const int w = input->dims()[3]; const int an_num = anchors.size() / 2; - - Tensor conf_mask, obj_mask; - Tensor tx, ty, tw, th, tweight, tconf, tclass; - conf_mask.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - obj_mask.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - tx.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - ty.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - tw.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - th.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - tweight.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - tconf.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - tclass.mutable_data({n, an_num, h, w, class_num}, ctx.GetPlace()); - - math::SetConstant constant; - constant(ctx.template device_context(), - &conf_mask, static_cast(1.0)); - constant(ctx.template device_context(), - &obj_mask, static_cast(0.0)); - constant(ctx.template device_context(), &tx, - static_cast(0.0)); - constant(ctx.template device_context(), &ty, - static_cast(0.0)); - constant(ctx.template device_context(), &tw, - static_cast(0.0)); - constant(ctx.template device_context(), &th, - static_cast(0.0)); - constant(ctx.template device_context(), - &tweight, static_cast(0.0)); - constant(ctx.template device_context(), &tconf, - static_cast(0.0)); - constant(ctx.template device_context(), &tclass, - static_cast(0.0)); - - PreProcessGTBox(*gt_box, *gt_label, ignore_thresh, anchors, input_size, - h, &conf_mask, &obj_mask, &tx, &ty, &tw, &th, &tweight, - &tconf, &tclass); - + const int mask_num = anchor_mask.size(); + const int b = gt_box->dims()[1]; + int input_size = downsample * h; + + const T* input_data = input->data(); + const T* gt_box_data = gt_box->data(); + const int* gt_label_data = gt_label->data(); + const T* loss_grad_data = loss_grad->data(); T* input_grad_data = input_grad->mutable_data({n, c, h, w}, ctx.GetPlace()); - CalcYolov3LossGrad(input_grad_data, *loss_grad, *input, tx, ty, tw, th, - tweight, tconf, tclass, conf_mask, obj_mask); + memset(input_grad_data, 0, input_grad->numel() * sizeof(T)); + + Tensor objness; + int* objness_data = + objness.mutable_data({n, mask_num, h, w}, ctx.GetPlace()); + memset(objness_data, 0, objness.numel() * sizeof(int)); + + const int stride = h * w; + const int an_stride = (class_num + 5) * stride; + + for (int i = 0; i < n; i++) { + for (int j = 0; j < mask_num; j++) { + for (int k = 0; k < h; k++) { + for (int l = 0; l < w; l++) { + int box_idx = + entry_index(i, j, k * w + l, mask_num, an_stride, stride, 0); + Box pred = + get_yolo_box(input_data, anchors, l, k, anchor_mask[j], h, + input_size, box_idx, stride); + T best_iou = 0; + // int best_t = 0; + for (int t = 0; t < b; t++) { + if (isZero(gt_box_data[i * b * 4 + t * 4]) && + isZero(gt_box_data[i * b * 4 + t * 4 + 1])) { + continue; + } + Box gt = get_gt_box(gt_box_data, i, b, t); + T iou = box_iou(pred, gt); + if (iou > best_iou) { + best_iou = iou; + // best_t = t; + } + } + + if (best_iou > ignore_thresh) { + int obj_idx = (i * mask_num + j) * stride + k * w + l; + objness_data[obj_idx] = -1; + } + } + } + } + for (int t = 0; t < b; t++) { + if (isZero(gt_box_data[i * b * 4 + t * 4]) && + isZero(gt_box_data[i * b * 4 + t * 4 + 1])) { + continue; + } + Box gt = get_gt_box(gt_box_data, i, b, t); + int gi = static_cast(gt.x * w); + int gj = static_cast(gt.y * h); + Box gt_shift = gt; + gt_shift.x = 0.0; + gt_shift.y = 0.0; + T best_iou = 0.0; + int best_n = 0; + for (int an_idx = 0; an_idx < an_num; an_idx++) { + Box an_box; + an_box.x = 0.0; + an_box.y = 0.0; + an_box.w = anchors[2 * an_idx] / static_cast(input_size); + an_box.h = anchors[2 * an_idx + 1] / static_cast(input_size); + float iou = box_iou(an_box, gt_shift); + // TO DO: iou > 0.5 ? + if (iou > best_iou) { + best_iou = iou; + best_n = an_idx; + } + } + + int mask_idx = mask_index(anchor_mask, best_n); + if (mask_idx >= 0) { + int box_idx = entry_index(i, mask_idx, gj * w + gi, mask_num, + an_stride, stride, 0); + CalcBoxLocationLossGrad(input_grad_data, loss_grad_data[i], + input_data, gt, anchors, best_n, box_idx, + gi, gj, h, input_size, stride); + + int obj_idx = (i * mask_num + mask_idx) * stride + gj * w + gi; + objness_data[obj_idx] = 1; + + int label = gt_label_data[i * b + t]; + int label_idx = entry_index(i, mask_idx, gj * w + gi, mask_num, + an_stride, stride, 5); + CalcLabelLossGrad(input_grad_data, loss_grad_data[i], input_data, + label_idx, label, class_num, stride); + } + } + } + + CalcObjnessLossGrad(input_grad_data + 4 * stride, loss_grad_data, + input_data + 4 * stride, objness_data, n, mask_num, + h, w, stride, an_stride); + + // const int n = input->dims()[0]; + // const int c = input->dims()[1]; + // const int h = input->dims()[2]; + // const int w = input->dims()[3]; + // const int an_num = anchors.size() / 2; + // + // Tensor conf_mask, obj_mask; + // Tensor tx, ty, tw, th, tweight, tconf, tclass; + // conf_mask.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + // obj_mask.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + // tx.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + // ty.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + // tw.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + // th.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + // tweight.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + // tconf.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + // tclass.mutable_data({n, an_num, h, w, class_num}, ctx.GetPlace()); + // + // math::SetConstant constant; + // constant(ctx.template device_context(), + // &conf_mask, static_cast(1.0)); + // constant(ctx.template device_context(), + // &obj_mask, static_cast(0.0)); + // constant(ctx.template device_context(), &tx, + // static_cast(0.0)); + // constant(ctx.template device_context(), &ty, + // static_cast(0.0)); + // constant(ctx.template device_context(), &tw, + // static_cast(0.0)); + // constant(ctx.template device_context(), &th, + // static_cast(0.0)); + // constant(ctx.template device_context(), + // &tweight, static_cast(0.0)); + // constant(ctx.template device_context(), + // &tconf, + // static_cast(0.0)); + // constant(ctx.template device_context(), + // &tclass, + // static_cast(0.0)); + // + // PreProcessGTBox(*gt_box, *gt_label, ignore_thresh, anchors, + // input_size, + // h, &conf_mask, &obj_mask, &tx, &ty, &tw, &th, + // &tweight, + // &tconf, &tclass); + // + // T* input_grad_data = + // input_grad->mutable_data({n, c, h, w}, ctx.GetPlace()); + // CalcYolov3LossGrad(input_grad_data, *loss_grad, *input, tx, ty, tw, + // th, + // tweight, tconf, tclass, conf_mask, obj_mask); } }; diff --git a/python/paddle/fluid/layers/detection.py b/python/paddle/fluid/layers/detection.py index 92823af1e0..542162b7f4 100644 --- a/python/paddle/fluid/layers/detection.py +++ b/python/paddle/fluid/layers/detection.py @@ -413,9 +413,10 @@ def yolov3_loss(x, gtbox, gtlabel, anchors, + anchor_mask, class_num, ignore_thresh, - input_size, + downsample, name=None): """ ${comment} @@ -430,9 +431,10 @@ def yolov3_loss(x, gtlabel (Variable): class id of ground truth boxes, shoud be ins shape of [N, B]. anchors (list|tuple): ${anchors_comment} + anchor_mask (list|tuple): ${anchor_mask_comment} class_num (int): ${class_num_comment} ignore_thresh (float): ${ignore_thresh_comment} - input_size (int): ${input_size_comment} + downsample (int): ${downsample_comment} name (string): the name of yolov3 loss Returns: @@ -452,7 +454,8 @@ def yolov3_loss(x, x = fluid.layers.data(name='x', shape=[255, 13, 13], dtype='float32') gtbox = fluid.layers.data(name='gtbox', shape=[6, 5], dtype='float32') gtlabel = fluid.layers.data(name='gtlabel', shape=[6, 1], dtype='int32') - anchors = [10, 13, 16, 30, 33, 23] + anchors = [10, 13, 16, 30, 33, 23, 30, 61, 62, 45, 59, 119, 116, 90, 156, 198, 373, 326] + anchors = [0, 1, 2] loss = fluid.layers.yolov3_loss(x=x, gtbox=gtbox, class_num=80 anchors=anchors, ignore_thresh=0.5) """ @@ -466,6 +469,8 @@ def yolov3_loss(x, raise TypeError("Input gtlabel of yolov3_loss must be Variable") if not isinstance(anchors, list) and not isinstance(anchors, tuple): raise TypeError("Attr anchors of yolov3_loss must be list or tuple") + if not isinstance(anchor_mask, list) and not isinstance(anchor_mask, tuple): + raise TypeError("Attr anchor_mask of yolov3_loss must be list or tuple") if not isinstance(class_num, int): raise TypeError("Attr class_num of yolov3_loss must be an integer") if not isinstance(ignore_thresh, float): @@ -480,9 +485,10 @@ def yolov3_loss(x, attrs = { "anchors": anchors, + "anchor_mask": anchor_mask, "class_num": class_num, "ignore_thresh": ignore_thresh, - "input_size": input_size, + "downsample": downsample, } helper.append_op( diff --git a/python/paddle/fluid/tests/test_detection.py b/python/paddle/fluid/tests/test_detection.py index 7d75562900..e11205d2bf 100644 --- a/python/paddle/fluid/tests/test_detection.py +++ b/python/paddle/fluid/tests/test_detection.py @@ -463,8 +463,8 @@ class TestYoloDetection(unittest.TestCase): x = layers.data(name='x', shape=[30, 7, 7], dtype='float32') gtbox = layers.data(name='gtbox', shape=[10, 4], dtype='float32') gtlabel = layers.data(name='gtlabel', shape=[10], dtype='int32') - loss = layers.yolov3_loss(x, gtbox, gtlabel, [10, 13, 30, 13], 10, - 0.7, 416) + loss = layers.yolov3_loss(x, gtbox, gtlabel, [10, 13, 30, 13], + [0, 1], 10, 0.7, 32) self.assertIsNotNone(loss) diff --git a/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py b/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py index e52047b0ad..3cada49647 100644 --- a/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py +++ b/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py @@ -22,32 +22,42 @@ from op_test import OpTest from paddle.fluid import core - -def l1loss(x, y, weight): - n = x.shape[0] - x = x.reshape((n, -1)) - y = y.reshape((n, -1)) - weight = weight.reshape((n, -1)) - return (np.abs(y - x) * weight).sum(axis=1) +# def l1loss(x, y, weight): +# n = x.shape[0] +# x = x.reshape((n, -1)) +# y = y.reshape((n, -1)) +# weight = weight.reshape((n, -1)) +# return (np.abs(y - x) * weight).sum(axis=1) +# +# +# def mse(x, y, weight): +# n = x.shape[0] +# x = x.reshape((n, -1)) +# y = y.reshape((n, -1)) +# weight = weight.reshape((n, -1)) +# return ((y - x)**2 * weight).sum(axis=1) +# +# +# def sce(x, label, weight): +# n = x.shape[0] +# x = x.reshape((n, -1)) +# label = label.reshape((n, -1)) +# weight = weight.reshape((n, -1)) +# sigmoid_x = expit(x) +# term1 = label * np.log(sigmoid_x) +# term2 = (1.0 - label) * np.log(1.0 - sigmoid_x) +# return ((-term1 - term2) * weight).sum(axis=1) -def mse(x, y, weight): - n = x.shape[0] - x = x.reshape((n, -1)) - y = y.reshape((n, -1)) - weight = weight.reshape((n, -1)) - return ((y - x)**2 * weight).sum(axis=1) +def l1loss(x, y): + return abs(x - y) -def sce(x, label, weight): - n = x.shape[0] - x = x.reshape((n, -1)) - label = label.reshape((n, -1)) - weight = weight.reshape((n, -1)) +def sce(x, label): sigmoid_x = expit(x) term1 = label * np.log(sigmoid_x) term2 = (1.0 - label) * np.log(1.0 - sigmoid_x) - return ((-term1 - term2) * weight).sum(axis=1) + return -term1 - term2 def box_iou(box1, box2): @@ -160,6 +170,121 @@ def YoloV3Loss(x, gtbox, gtlabel, attrs): return loss_x + loss_y + loss_w + loss_h + loss_obj + loss_class +def sigmoid(x): + return 1.0 / (1.0 + np.exp(-1.0 * x)) + + +def batch_xywh_box_iou(box1, box2): + b1_left = box1[:, :, 0] - box1[:, :, 2] / 2 + b1_right = box1[:, :, 0] + box1[:, :, 2] / 2 + b1_top = box1[:, :, 1] - box1[:, :, 3] / 2 + b1_bottom = box1[:, :, 1] + box1[:, :, 3] / 2 + + b2_left = box2[:, :, 0] - box2[:, :, 2] / 2 + b2_right = box2[:, :, 0] + box2[:, :, 2] / 2 + b2_top = box2[:, :, 1] - box2[:, :, 3] / 2 + b2_bottom = box2[:, :, 1] + box2[:, :, 3] / 2 + + left = np.maximum(b1_left[:, :, np.newaxis], b2_left[:, np.newaxis, :]) + right = np.minimum(b1_right[:, :, np.newaxis], b2_right[:, np.newaxis, :]) + top = np.maximum(b1_top[:, :, np.newaxis], b2_top[:, np.newaxis, :]) + bottom = np.minimum(b1_bottom[:, :, np.newaxis], + b2_bottom[:, np.newaxis, :]) + + inter_w = np.clip(right - left, 0., 1.) + inter_h = np.clip(bottom - top, 0., 1.) + inter_area = inter_w * inter_h + + b1_area = (b1_right - b1_left) * (b1_bottom - b1_top) + b2_area = (b2_right - b2_left) * (b2_bottom - b2_top) + union = b1_area[:, :, np.newaxis] + b2_area[:, np.newaxis, :] - inter_area + + return inter_area / union + + +def YOLOv3Loss(x, gtbox, gtlabel, attrs): + n, c, h, w = x.shape + b = gtbox.shape[1] + anchors = attrs['anchors'] + an_num = len(anchors) // 2 + anchor_mask = attrs['anchor_mask'] + mask_num = len(anchor_mask) + class_num = attrs["class_num"] + ignore_thresh = attrs['ignore_thresh'] + downsample = attrs['downsample'] + input_size = downsample * h + x = x.reshape((n, mask_num, 5 + class_num, h, w)).transpose((0, 1, 3, 4, 2)) + loss = np.zeros((n)).astype('float32') + + pred_box = x[:, :, :, :, :4].copy() + grid_x = np.tile(np.arange(w).reshape((1, w)), (h, 1)) + grid_y = np.tile(np.arange(h).reshape((h, 1)), (1, w)) + pred_box[:, :, :, :, 0] = (grid_x + sigmoid(pred_box[:, :, :, :, 0])) / w + pred_box[:, :, :, :, 1] = (grid_y + sigmoid(pred_box[:, :, :, :, 1])) / h + + mask_anchors = [] + for m in anchor_mask: + mask_anchors.append((anchors[2 * m], anchors[2 * m + 1])) + anchors_s = np.array( + [(an_w / input_size, an_h / input_size) for an_w, an_h in mask_anchors]) + anchor_w = anchors_s[:, 0:1].reshape((1, mask_num, 1, 1)) + anchor_h = anchors_s[:, 1:2].reshape((1, mask_num, 1, 1)) + pred_box[:, :, :, :, 2] = np.exp(pred_box[:, :, :, :, 2]) * anchor_w + pred_box[:, :, :, :, 3] = np.exp(pred_box[:, :, :, :, 3]) * anchor_h + + pred_box = pred_box.reshape((n, -1, 4)) + pred_obj = x[:, :, :, :, 4].reshape((n, -1)) + objness = np.zeros(pred_box.shape[:2]) + ious = batch_xywh_box_iou(pred_box, gtbox) + ious_max = np.max(ious, axis=-1) + objness = np.where(ious_max > ignore_thresh, -np.ones_like(objness), + objness) + + gtbox_shift = gtbox.copy() + gtbox_shift[:, :, 0] = 0 + gtbox_shift[:, :, 1] = 0 + + anchors = [(anchors[2 * i], anchors[2 * i + 1]) for i in range(0, an_num)] + anchors_s = np.array( + [(an_w / input_size, an_h / input_size) for an_w, an_h in anchors]) + anchor_boxes = np.concatenate( + [np.zeros_like(anchors_s), anchors_s], axis=-1) + anchor_boxes = np.tile(anchor_boxes[np.newaxis, :, :], (n, 1, 1)) + ious = batch_xywh_box_iou(gtbox_shift, anchor_boxes) + iou_matches = np.argmax(ious, axis=-1) + for i in range(n): + for j in range(b): + if gtbox[i, j, 2:].sum() == 0: + continue + if iou_matches[i, j] not in anchor_mask: + continue + an_idx = anchor_mask.index(iou_matches[i, j]) + gi = int(gtbox[i, j, 0] * w) + gj = int(gtbox[i, j, 1] * h) + + tx = gtbox[i, j, 0] * w - gi + ty = gtbox[i, j, 1] * w - gj + tw = np.log(gtbox[i, j, 2] * input_size / mask_anchors[an_idx][0]) + th = np.log(gtbox[i, j, 3] * input_size / mask_anchors[an_idx][1]) + scale = 2.0 - gtbox[i, j, 2] * gtbox[i, j, 3] + loss[i] += sce(x[i, an_idx, gj, gi, 0], tx) * scale + loss[i] += sce(x[i, an_idx, gj, gi, 1], ty) * scale + loss[i] += l1loss(x[i, an_idx, gj, gi, 2], tw) * scale + loss[i] += l1loss(x[i, an_idx, gj, gi, 3], th) * scale + + objness[i, an_idx * h * w + gj * w + gi] = 1 + + for label_idx in range(class_num): + loss[i] += sce(x[i, an_idx, gj, gi, 5 + label_idx], + int(label_idx == gtlabel[i, j])) + + for j in range(mask_num * h * w): + if objness[i, j] >= 0: + loss[i] += sce(pred_obj[i, j], objness[i, j]) + + return loss + + class TestYolov3LossOp(OpTest): def setUp(self): self.initTestCase() @@ -171,13 +296,14 @@ class TestYolov3LossOp(OpTest): self.attrs = { "anchors": self.anchors, + "anchor_mask": self.anchor_mask, "class_num": self.class_num, "ignore_thresh": self.ignore_thresh, - "input_size": self.input_size, + "downsample": self.downsample, } self.inputs = {'X': x, 'GTBox': gtbox, 'GTLabel': gtlabel} - self.outputs = {'Loss': YoloV3Loss(x, gtbox, gtlabel, self.attrs)} + self.outputs = {'Loss': YOLOv3Loss(x, gtbox, gtlabel, self.attrs)} def test_check_output(self): place = core.CPUPlace() @@ -189,15 +315,19 @@ class TestYolov3LossOp(OpTest): place, ['X'], 'Loss', no_grad_set=set(["GTBox", "GTLabel"]), - max_relative_error=0.31) + max_relative_error=0.15) def initTestCase(self): - self.anchors = [12, 12] + self.anchors = [ + 10, 13, 16, 30, 33, 23, 30, 61, 62, 45, 59, 119, 116, 90, 156, 198, + 373, 326 + ] + self.anchor_mask = [0, 1, 2] self.class_num = 5 - self.ignore_thresh = 0.5 - self.input_size = 416 - self.x_shape = (1, len(self.anchors) // 2 * (5 + self.class_num), 3, 3) - self.gtbox_shape = (1, 5, 4) + self.ignore_thresh = 0.7 + self.downsample = 32 + self.x_shape = (3, len(self.anchor_mask) * (5 + self.class_num), 5, 5) + self.gtbox_shape = (3, 10, 4) if __name__ == "__main__": From 6c5a5d078920d7be79e5346e5cc6870b1b6b3aa3 Mon Sep 17 00:00:00 2001 From: dengkaipeng Date: Fri, 21 Dec 2018 12:13:57 +0800 Subject: [PATCH 10/24] format code. test=develop --- paddle/fluid/API.spec | 2 +- paddle/fluid/operators/yolov3_loss_op.h | 472 ++---------------- .../tests/unittests/test_yolov3_loss_op.py | 148 +----- 3 files changed, 53 insertions(+), 569 deletions(-) diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index f293b0d30e..6c6ac9c7ea 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -324,7 +324,7 @@ paddle.fluid.layers.generate_mask_labels ArgSpec(args=['im_info', 'gt_classes', paddle.fluid.layers.iou_similarity ArgSpec(args=['x', 'y', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.box_coder ArgSpec(args=['prior_box', 'prior_box_var', 'target_box', 'code_type', 'box_normalized', 'name'], varargs=None, keywords=None, defaults=('encode_center_size', True, None)) paddle.fluid.layers.polygon_box_transform ArgSpec(args=['input', 'name'], varargs=None, keywords=None, defaults=(None,)) -paddle.fluid.layers.yolov3_loss ArgSpec(args=['x', 'gtbox', 'gtlabel', 'anchors', 'class_num', 'ignore_thresh', 'input_size', 'name'], varargs=None, keywords=None, defaults=(None,)) +paddle.fluid.layers.yolov3_loss ArgSpec(args=['x', 'gtbox', 'gtlabel', 'anchors', 'anchor_mask', 'class_num', 'ignore_thresh', 'downsample', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.multiclass_nms ArgSpec(args=['bboxes', 'scores', 'score_threshold', 'nms_top_k', 'keep_top_k', 'nms_threshold', 'normalized', 'nms_eta', 'background_label', 'name'], varargs=None, keywords=None, defaults=(0.3, True, 1.0, 0, None)) paddle.fluid.layers.accuracy ArgSpec(args=['input', 'label', 'k', 'correct', 'total'], varargs=None, keywords=None, defaults=(1, None, None)) paddle.fluid.layers.auc ArgSpec(args=['input', 'label', 'curve', 'num_thresholds', 'topk', 'slide_steps'], varargs=None, keywords=None, defaults=('ROC', 4095, 1, 1)) diff --git a/paddle/fluid/operators/yolov3_loss_op.h b/paddle/fluid/operators/yolov3_loss_op.h index 9254a6cf6f..12499befca 100644 --- a/paddle/fluid/operators/yolov3_loss_op.h +++ b/paddle/fluid/operators/yolov3_loss_op.h @@ -26,110 +26,9 @@ template using EigenVector = framework::EigenVector; -using Array5 = Eigen::DSizes; - -template -static inline bool isZero(T x) { - return fabs(x) < 1e-6; -} - template -static T CalcBoxIoU(std::vector box1, std::vector box2) { - T b1_x1 = box1[0] - box1[2] / 2; - T b1_x2 = box1[0] + box1[2] / 2; - T b1_y1 = box1[1] - box1[3] / 2; - T b1_y2 = box1[1] + box1[3] / 2; - T b2_x1 = box2[0] - box2[2] / 2; - T b2_x2 = box2[0] + box2[2] / 2; - T b2_y1 = box2[1] - box2[3] / 2; - T b2_y2 = box2[1] + box2[3] / 2; - - T b1_area = (b1_x2 - b1_x1) * (b1_y2 - b1_y1); - T b2_area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1); - - T inter_rect_x1 = std::max(b1_x1, b2_x1); - T inter_rect_y1 = std::max(b1_y1, b2_y1); - T inter_rect_x2 = std::min(b1_x2, b2_x2); - T inter_rect_y2 = std::min(b1_y2, b2_y2); - T inter_area = std::max(inter_rect_x2 - inter_rect_x1, static_cast(0.0)) * - std::max(inter_rect_y2 - inter_rect_y1, static_cast(0.0)); - - return inter_area / (b1_area + b2_area - inter_area); -} - -template -static void PreProcessGTBox(const Tensor& gt_box, const Tensor& gt_label, - const float ignore_thresh, std::vector anchors, - const int input_size, const int grid_size, - Tensor* conf_mask, Tensor* obj_mask, Tensor* tx, - Tensor* ty, Tensor* tw, Tensor* th, Tensor* tweight, - Tensor* tconf, Tensor* tclass) { - const int n = gt_box.dims()[0]; - const int b = gt_box.dims()[1]; - const int an_num = anchors.size() / 2; - const int h = tclass->dims()[2]; - const int w = tclass->dims()[3]; - const int class_num = tclass->dims()[4]; - - const T* gt_box_data = gt_box.data(); - const int* gt_label_data = gt_label.data(); - T* conf_mask_data = conf_mask->data(); - T* obj_mask_data = obj_mask->data(); - T* tx_data = tx->data(); - T* ty_data = ty->data(); - T* tw_data = tw->data(); - T* th_data = th->data(); - T* tweight_data = tweight->data(); - T* tconf_data = tconf->data(); - T* tclass_data = tclass->data(); - - for (int i = 0; i < n; i++) { - for (int j = 0; j < b; j++) { - int box_idx = (i * b + j) * 4; - if (isZero(gt_box_data[box_idx + 2]) && - isZero(gt_box_data[box_idx + 3])) { - continue; - } - - int cur_label = gt_label_data[i * b + j]; - T gx = gt_box_data[box_idx] * grid_size; - T gy = gt_box_data[box_idx + 1] * grid_size; - T gw = gt_box_data[box_idx + 2] * input_size; - T gh = gt_box_data[box_idx + 3] * input_size; - int gi = static_cast(gx); - int gj = static_cast(gy); - - T max_iou = static_cast(0); - T iou; - int best_an_index = -1; - std::vector gt_box_shape({0, 0, gw, gh}); - for (int an_idx = 0; an_idx < an_num; an_idx++) { - std::vector anchor_shape({0, 0, static_cast(anchors[2 * an_idx]), - static_cast(anchors[2 * an_idx + 1])}); - iou = CalcBoxIoU(gt_box_shape, anchor_shape); - if (iou > max_iou) { - max_iou = iou; - best_an_index = an_idx; - } - if (iou > ignore_thresh) { - int conf_idx = ((i * an_num + an_idx) * h + gj) * w + gi; - conf_mask_data[conf_idx] = static_cast(0.0); - } - } - - int obj_idx = ((i * an_num + best_an_index) * h + gj) * w + gi; - conf_mask_data[obj_idx] = static_cast(1.0); - obj_mask_data[obj_idx] = static_cast(1.0); - tx_data[obj_idx] = gx - gi; - ty_data[obj_idx] = gy - gj; - tw_data[obj_idx] = log(gw / anchors[2 * best_an_index]); - th_data[obj_idx] = log(gh / anchors[2 * best_an_index + 1]); - tweight_data[obj_idx] = - 2.0 - gt_box_data[box_idx + 2] * gt_box_data[box_idx + 3]; - tconf_data[obj_idx] = static_cast(1.0); - tclass_data[obj_idx * class_num + cur_label] = static_cast(1.0); - } - } +static inline bool LessEqualZero(T x) { + return x < 1e-6; } template @@ -152,177 +51,8 @@ static T L1LossGrad(T x, T y) { return x > y ? 1.0 : -1.0; } -template -static void CalcSCE(T* loss_data, const T* input, const T* target, - const T* weight, const T* mask, const int n, - const int an_num, const int grid_num, const int class_num, - const int num) { - for (int i = 0; i < n; i++) { - for (int j = 0; j < an_num; j++) { - for (int k = 0; k < grid_num; k++) { - int sub_idx = k * num; - for (int l = 0; l < num; l++) { - loss_data[i] += SCE(input[l * grid_num + k], target[sub_idx + l]) * - weight[k] * mask[k]; - } - } - input += (class_num + 5) * grid_num; - target += grid_num * num; - weight += grid_num; - mask += grid_num; - } - } -} - -template -static void CalcSCEGrad(T* input_grad, const T* loss_grad, const T* input, - const T* target, const T* weight, const T* mask, - const int n, const int an_num, const int grid_num, - const int class_num, const int num) { - for (int i = 0; i < n; i++) { - for (int j = 0; j < an_num; j++) { - for (int k = 0; k < grid_num; k++) { - int sub_idx = k * num; - for (int l = 0; l < num; l++) { - input_grad[l * grid_num + k] = - SCEGrad(input[l * grid_num + k], target[sub_idx + l]) * - weight[k] * mask[k] * loss_grad[i]; - } - } - input_grad += (class_num + 5) * grid_num; - input += (class_num + 5) * grid_num; - target += grid_num * num; - weight += grid_num; - mask += grid_num; - } - } -} - -template -static void CalcL1Loss(T* loss_data, const T* input, const T* target, - const T* weight, const T* mask, const int n, - const int an_num, const int grid_num, - const int class_num) { - for (int i = 0; i < n; i++) { - for (int j = 0; j < an_num; j++) { - for (int k = 0; k < grid_num; k++) { - loss_data[i] += L1Loss(input[k], target[k]) * weight[k] * mask[k]; - } - input += (class_num + 5) * grid_num; - target += grid_num; - weight += grid_num; - mask += grid_num; - } - } -} - -template -static void CalcL1LossGrad(T* input_grad, const T* loss_grad, const T* input, - const T* target, const T* weight, const T* mask, - const int n, const int an_num, const int grid_num, - const int class_num) { - for (int i = 0; i < n; i++) { - for (int j = 0; j < an_num; j++) { - for (int k = 0; k < grid_num; k++) { - input_grad[k] = L1LossGrad(input[k], target[k]) * weight[k] * - mask[k] * loss_grad[i]; - } - input_grad += (class_num + 5) * grid_num; - input += (class_num + 5) * grid_num; - target += grid_num; - weight += grid_num; - mask += grid_num; - } - } -} - -template -static void CalcYolov3Loss(T* loss_data, const Tensor& input, const Tensor& tx, - const Tensor& ty, const Tensor& tw, const Tensor& th, - const Tensor& tweight, const Tensor& tconf, - const Tensor& tclass, const Tensor& conf_mask, - const Tensor& obj_mask) { - const T* input_data = input.data(); - const T* tx_data = tx.data(); - const T* ty_data = ty.data(); - const T* tw_data = tw.data(); - const T* th_data = th.data(); - const T* tweight_data = tweight.data(); - const T* tconf_data = tconf.data(); - const T* tclass_data = tclass.data(); - const T* conf_mask_data = conf_mask.data(); - const T* obj_mask_data = obj_mask.data(); - - const int n = tclass.dims()[0]; - const int an_num = tclass.dims()[1]; - const int h = tclass.dims()[2]; - const int w = tclass.dims()[3]; - const int class_num = tclass.dims()[4]; - const int grid_num = h * w; - - CalcSCE(loss_data, input_data, tx_data, tweight_data, obj_mask_data, n, - an_num, grid_num, class_num, 1); - CalcSCE(loss_data, input_data + grid_num, ty_data, tweight_data, - obj_mask_data, n, an_num, grid_num, class_num, 1); - CalcL1Loss(loss_data, input_data + 2 * grid_num, tw_data, tweight_data, - obj_mask_data, n, an_num, grid_num, class_num); - CalcL1Loss(loss_data, input_data + 3 * grid_num, th_data, tweight_data, - obj_mask_data, n, an_num, grid_num, class_num); - CalcSCE(loss_data, input_data + 4 * grid_num, tconf_data, conf_mask_data, - conf_mask_data, n, an_num, grid_num, class_num, 1); - CalcSCE(loss_data, input_data + 5 * grid_num, tclass_data, obj_mask_data, - obj_mask_data, n, an_num, grid_num, class_num, class_num); -} - -template -static void CalcYolov3LossGrad(T* input_grad_data, const Tensor& loss_grad, - const Tensor& input, const Tensor& tx, - const Tensor& ty, const Tensor& tw, - const Tensor& th, const Tensor& tweight, - const Tensor& tconf, const Tensor& tclass, - const Tensor& conf_mask, - const Tensor& obj_mask) { - const T* loss_grad_data = loss_grad.data(); - const T* input_data = input.data(); - const T* tx_data = tx.data(); - const T* ty_data = ty.data(); - const T* tw_data = tw.data(); - const T* th_data = th.data(); - const T* tweight_data = tweight.data(); - const T* tconf_data = tconf.data(); - const T* tclass_data = tclass.data(); - const T* conf_mask_data = conf_mask.data(); - const T* obj_mask_data = obj_mask.data(); - - const int n = tclass.dims()[0]; - const int an_num = tclass.dims()[1]; - const int h = tclass.dims()[2]; - const int w = tclass.dims()[3]; - const int class_num = tclass.dims()[4]; - const int grid_num = h * w; - - CalcSCEGrad(input_grad_data, loss_grad_data, input_data, tx_data, - tweight_data, obj_mask_data, n, an_num, grid_num, class_num, - 1); - CalcSCEGrad(input_grad_data + grid_num, loss_grad_data, - input_data + grid_num, ty_data, tweight_data, obj_mask_data, n, - an_num, grid_num, class_num, 1); - CalcL1LossGrad(input_grad_data + 2 * grid_num, loss_grad_data, - input_data + 2 * grid_num, tw_data, tweight_data, - obj_mask_data, n, an_num, grid_num, class_num); - CalcL1LossGrad(input_grad_data + 3 * grid_num, loss_grad_data, - input_data + 3 * grid_num, th_data, tweight_data, - obj_mask_data, n, an_num, grid_num, class_num); - CalcSCEGrad(input_grad_data + 4 * grid_num, loss_grad_data, - input_data + 4 * grid_num, tconf_data, conf_mask_data, - conf_mask_data, n, an_num, grid_num, class_num, 1); - CalcSCEGrad(input_grad_data + 5 * grid_num, loss_grad_data, - input_data + 5 * grid_num, tclass_data, obj_mask_data, - obj_mask_data, n, an_num, grid_num, class_num, class_num); -} - -static int mask_index(std::vector mask, int val) { - for (int i = 0; i < mask.size(); i++) { +static int GetMaskIndex(std::vector mask, int val) { + for (size_t i = 0; i < mask.size(); i++) { if (mask[i] == val) { return i; } @@ -341,16 +71,9 @@ static inline T sigmoid(T x) { } template -static inline void sigmoid_arrray(T* arr, int len) { - for (int i = 0; i < len; i++) { - arr[i] = sigmoid(arr[i]); - } -} - -template -static inline Box get_yolo_box(const T* x, std::vector anchors, int i, - int j, int an_idx, int grid_size, - int input_size, int index, int stride) { +static inline Box GetYoloBox(const T* x, std::vector anchors, int i, + int j, int an_idx, int grid_size, + int input_size, int index, int stride) { Box b; b.x = (i + sigmoid(x[index])) / grid_size; b.y = (j + sigmoid(x[index + stride])) / grid_size; @@ -360,8 +83,7 @@ static inline Box get_yolo_box(const T* x, std::vector anchors, int i, } template -static inline Box get_gt_box(const T* gt, int batch, int max_boxes, - int idx) { +static inline Box GetGtBox(const T* gt, int batch, int max_boxes, int idx) { Box b; b.x = gt[(batch * max_boxes + idx) * 4]; b.y = gt[(batch * max_boxes + idx) * 4 + 1]; @@ -371,7 +93,7 @@ static inline Box get_gt_box(const T* gt, int batch, int max_boxes, } template -static inline T overlap(T c1, T w1, T c2, T w2) { +static inline T BoxOverlap(T c1, T w1, T c2, T w2) { T l1 = c1 - w1 / 2.0; T l2 = c2 - w2 / 2.0; T left = l1 > l2 ? l1 : l2; @@ -382,16 +104,16 @@ static inline T overlap(T c1, T w1, T c2, T w2) { } template -static inline T box_iou(Box b1, Box b2) { - T w = overlap(b1.x, b1.w, b2.x, b2.w); - T h = overlap(b1.y, b1.h, b2.y, b2.h); +static inline T CalcBoxIoU(Box b1, Box b2) { + T w = BoxOverlap(b1.x, b1.w, b2.x, b2.w); + T h = BoxOverlap(b1.y, b1.h, b2.y, b2.h); T inter_area = (w < 0 || h < 0) ? 0.0 : w * h; T union_area = b1.w * b1.h + b2.w * b2.h - inter_area; return inter_area / union_area; } -static inline int entry_index(int batch, int an_idx, int hw_idx, int an_num, - int an_stride, int stride, int entry) { +static inline int GetEntryIndex(int batch, int an_idx, int hw_idx, int an_num, + int an_stride, int stride, int entry) { return (batch * an_num + an_idx) * an_stride + entry * stride + hw_idx; } @@ -523,7 +245,7 @@ class Yolov3LossKernel : public framework::OpKernel { const T* gt_box_data = gt_box->data(); const int* gt_label_data = gt_label->data(); T* loss_data = loss->mutable_data({n}, ctx.GetPlace()); - memset(loss_data, 0, n * sizeof(int)); + memset(loss_data, 0, loss->numel() * sizeof(T)); Tensor objness; int* objness_data = @@ -538,22 +260,18 @@ class Yolov3LossKernel : public framework::OpKernel { for (int k = 0; k < h; k++) { for (int l = 0; l < w; l++) { int box_idx = - entry_index(i, j, k * w + l, mask_num, an_stride, stride, 0); - Box pred = - get_yolo_box(input_data, anchors, l, k, anchor_mask[j], h, - input_size, box_idx, stride); + GetEntryIndex(i, j, k * w + l, mask_num, an_stride, stride, 0); + Box pred = GetYoloBox(input_data, anchors, l, k, anchor_mask[j], + h, input_size, box_idx, stride); T best_iou = 0; - // int best_t = 0; for (int t = 0; t < b; t++) { - if (isZero(gt_box_data[i * b * 4 + t * 4]) && - isZero(gt_box_data[i * b * 4 + t * 4 + 1])) { + Box gt = GetGtBox(gt_box_data, i, b, t); + if (LessEqualZero(gt.w) || LessEqualZero(gt.h)) { continue; } - Box gt = get_gt_box(gt_box_data, i, b, t); - T iou = box_iou(pred, gt); + T iou = CalcBoxIoU(pred, gt); if (iou > best_iou) { best_iou = iou; - // best_t = t; } } @@ -565,11 +283,10 @@ class Yolov3LossKernel : public framework::OpKernel { } } for (int t = 0; t < b; t++) { - if (isZero(gt_box_data[i * b * 4 + t * 4]) && - isZero(gt_box_data[i * b * 4 + t * 4 + 1])) { + Box gt = GetGtBox(gt_box_data, i, b, t); + if (LessEqualZero(gt.w) || LessEqualZero(gt.h)) { continue; } - Box gt = get_gt_box(gt_box_data, i, b, t); int gi = static_cast(gt.x * w); int gj = static_cast(gt.y * h); Box gt_shift = gt; @@ -583,7 +300,7 @@ class Yolov3LossKernel : public framework::OpKernel { an_box.y = 0.0; an_box.w = anchors[2 * an_idx] / static_cast(input_size); an_box.h = anchors[2 * an_idx + 1] / static_cast(input_size); - float iou = box_iou(an_box, gt_shift); + float iou = CalcBoxIoU(an_box, gt_shift); // TO DO: iou > 0.5 ? if (iou > best_iou) { best_iou = iou; @@ -591,10 +308,10 @@ class Yolov3LossKernel : public framework::OpKernel { } } - int mask_idx = mask_index(anchor_mask, best_n); + int mask_idx = GetMaskIndex(anchor_mask, best_n); if (mask_idx >= 0) { - int box_idx = entry_index(i, mask_idx, gj * w + gi, mask_num, - an_stride, stride, 0); + int box_idx = GetEntryIndex(i, mask_idx, gj * w + gi, mask_num, + an_stride, stride, 0); CalcBoxLocationLoss(loss_data + i, input_data, gt, anchors, best_n, box_idx, gi, gj, h, input_size, stride); @@ -602,8 +319,8 @@ class Yolov3LossKernel : public framework::OpKernel { objness_data[obj_idx] = 1; int label = gt_label_data[i * b + t]; - int label_idx = entry_index(i, mask_idx, gj * w + gi, mask_num, - an_stride, stride, 5); + int label_idx = GetEntryIndex(i, mask_idx, gj * w + gi, mask_num, + an_stride, stride, 5); CalcLabelLoss(loss_data + i, input_data, label_idx, label, class_num, stride); } @@ -612,52 +329,6 @@ class Yolov3LossKernel : public framework::OpKernel { CalcObjnessLoss(loss_data, input_data + 4 * stride, objness_data, n, mask_num, h, w, stride, an_stride); - - // Tensor conf_mask, obj_mask; - // Tensor tx, ty, tw, th, tweight, tconf, tclass; - // conf_mask.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - // obj_mask.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - // tx.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - // ty.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - // tw.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - // th.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - // tweight.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - // tconf.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - // tclass.mutable_data({n, an_num, h, w, class_num}, ctx.GetPlace()); - // - // math::SetConstant constant; - // constant(ctx.template device_context(), - // &conf_mask, static_cast(1.0)); - // constant(ctx.template device_context(), - // &obj_mask, static_cast(0.0)); - // constant(ctx.template device_context(), &tx, - // static_cast(0.0)); - // constant(ctx.template device_context(), &ty, - // static_cast(0.0)); - // constant(ctx.template device_context(), &tw, - // static_cast(0.0)); - // constant(ctx.template device_context(), &th, - // static_cast(0.0)); - // constant(ctx.template device_context(), - // &tweight, static_cast(0.0)); - // constant(ctx.template device_context(), - // &tconf, - // static_cast(0.0)); - // constant(ctx.template device_context(), - // &tclass, - // static_cast(0.0)); - // - // PreProcessGTBox(*gt_box, *gt_label, ignore_thresh, anchors, - // input_size, - // h, &conf_mask, &obj_mask, &tx, &ty, &tw, &th, - // &tweight, - // &tconf, &tclass); - // - // T* loss_data = loss->mutable_data({n}, ctx.GetPlace()); - // memset(loss_data, 0, n * sizeof(T)); - // CalcYolov3Loss(loss_data, *input, tx, ty, tw, th, tweight, tconf, - // tclass, - // conf_mask, obj_mask); } }; @@ -706,22 +377,18 @@ class Yolov3LossGradKernel : public framework::OpKernel { for (int k = 0; k < h; k++) { for (int l = 0; l < w; l++) { int box_idx = - entry_index(i, j, k * w + l, mask_num, an_stride, stride, 0); - Box pred = - get_yolo_box(input_data, anchors, l, k, anchor_mask[j], h, - input_size, box_idx, stride); + GetEntryIndex(i, j, k * w + l, mask_num, an_stride, stride, 0); + Box pred = GetYoloBox(input_data, anchors, l, k, anchor_mask[j], + h, input_size, box_idx, stride); T best_iou = 0; - // int best_t = 0; for (int t = 0; t < b; t++) { - if (isZero(gt_box_data[i * b * 4 + t * 4]) && - isZero(gt_box_data[i * b * 4 + t * 4 + 1])) { + Box gt = GetGtBox(gt_box_data, i, b, t); + if (LessEqualZero(gt.w) || LessEqualZero(gt.h)) { continue; } - Box gt = get_gt_box(gt_box_data, i, b, t); - T iou = box_iou(pred, gt); + T iou = CalcBoxIoU(pred, gt); if (iou > best_iou) { best_iou = iou; - // best_t = t; } } @@ -733,11 +400,10 @@ class Yolov3LossGradKernel : public framework::OpKernel { } } for (int t = 0; t < b; t++) { - if (isZero(gt_box_data[i * b * 4 + t * 4]) && - isZero(gt_box_data[i * b * 4 + t * 4 + 1])) { + Box gt = GetGtBox(gt_box_data, i, b, t); + if (LessEqualZero(gt.w) || LessEqualZero(gt.h)) { continue; } - Box gt = get_gt_box(gt_box_data, i, b, t); int gi = static_cast(gt.x * w); int gj = static_cast(gt.y * h); Box gt_shift = gt; @@ -751,7 +417,7 @@ class Yolov3LossGradKernel : public framework::OpKernel { an_box.y = 0.0; an_box.w = anchors[2 * an_idx] / static_cast(input_size); an_box.h = anchors[2 * an_idx + 1] / static_cast(input_size); - float iou = box_iou(an_box, gt_shift); + float iou = CalcBoxIoU(an_box, gt_shift); // TO DO: iou > 0.5 ? if (iou > best_iou) { best_iou = iou; @@ -759,10 +425,10 @@ class Yolov3LossGradKernel : public framework::OpKernel { } } - int mask_idx = mask_index(anchor_mask, best_n); + int mask_idx = GetMaskIndex(anchor_mask, best_n); if (mask_idx >= 0) { - int box_idx = entry_index(i, mask_idx, gj * w + gi, mask_num, - an_stride, stride, 0); + int box_idx = GetEntryIndex(i, mask_idx, gj * w + gi, mask_num, + an_stride, stride, 0); CalcBoxLocationLossGrad(input_grad_data, loss_grad_data[i], input_data, gt, anchors, best_n, box_idx, gi, gj, h, input_size, stride); @@ -771,8 +437,8 @@ class Yolov3LossGradKernel : public framework::OpKernel { objness_data[obj_idx] = 1; int label = gt_label_data[i * b + t]; - int label_idx = entry_index(i, mask_idx, gj * w + gi, mask_num, - an_stride, stride, 5); + int label_idx = GetEntryIndex(i, mask_idx, gj * w + gi, mask_num, + an_stride, stride, 5); CalcLabelLossGrad(input_grad_data, loss_grad_data[i], input_data, label_idx, label, class_num, stride); } @@ -782,58 +448,6 @@ class Yolov3LossGradKernel : public framework::OpKernel { CalcObjnessLossGrad(input_grad_data + 4 * stride, loss_grad_data, input_data + 4 * stride, objness_data, n, mask_num, h, w, stride, an_stride); - - // const int n = input->dims()[0]; - // const int c = input->dims()[1]; - // const int h = input->dims()[2]; - // const int w = input->dims()[3]; - // const int an_num = anchors.size() / 2; - // - // Tensor conf_mask, obj_mask; - // Tensor tx, ty, tw, th, tweight, tconf, tclass; - // conf_mask.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - // obj_mask.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - // tx.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - // ty.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - // tw.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - // th.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - // tweight.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - // tconf.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - // tclass.mutable_data({n, an_num, h, w, class_num}, ctx.GetPlace()); - // - // math::SetConstant constant; - // constant(ctx.template device_context(), - // &conf_mask, static_cast(1.0)); - // constant(ctx.template device_context(), - // &obj_mask, static_cast(0.0)); - // constant(ctx.template device_context(), &tx, - // static_cast(0.0)); - // constant(ctx.template device_context(), &ty, - // static_cast(0.0)); - // constant(ctx.template device_context(), &tw, - // static_cast(0.0)); - // constant(ctx.template device_context(), &th, - // static_cast(0.0)); - // constant(ctx.template device_context(), - // &tweight, static_cast(0.0)); - // constant(ctx.template device_context(), - // &tconf, - // static_cast(0.0)); - // constant(ctx.template device_context(), - // &tclass, - // static_cast(0.0)); - // - // PreProcessGTBox(*gt_box, *gt_label, ignore_thresh, anchors, - // input_size, - // h, &conf_mask, &obj_mask, &tx, &ty, &tw, &th, - // &tweight, - // &tconf, &tclass); - // - // T* input_grad_data = - // input_grad->mutable_data({n, c, h, w}, ctx.GetPlace()); - // CalcYolov3LossGrad(input_grad_data, *loss_grad, *input, tx, ty, tw, - // th, - // tweight, tconf, tclass, conf_mask, obj_mask); } }; diff --git a/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py b/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py index 3cada49647..188acea2b9 100644 --- a/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py +++ b/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py @@ -22,32 +22,6 @@ from op_test import OpTest from paddle.fluid import core -# def l1loss(x, y, weight): -# n = x.shape[0] -# x = x.reshape((n, -1)) -# y = y.reshape((n, -1)) -# weight = weight.reshape((n, -1)) -# return (np.abs(y - x) * weight).sum(axis=1) -# -# -# def mse(x, y, weight): -# n = x.shape[0] -# x = x.reshape((n, -1)) -# y = y.reshape((n, -1)) -# weight = weight.reshape((n, -1)) -# return ((y - x)**2 * weight).sum(axis=1) -# -# -# def sce(x, label, weight): -# n = x.shape[0] -# x = x.reshape((n, -1)) -# label = label.reshape((n, -1)) -# weight = weight.reshape((n, -1)) -# sigmoid_x = expit(x) -# term1 = label * np.log(sigmoid_x) -# term2 = (1.0 - label) * np.log(1.0 - sigmoid_x) -# return ((-term1 - term2) * weight).sum(axis=1) - def l1loss(x, y): return abs(x - y) @@ -60,116 +34,6 @@ def sce(x, label): return -term1 - term2 -def box_iou(box1, box2): - b1_x1 = box1[0] - box1[2] / 2 - b1_x2 = box1[0] + box1[2] / 2 - b1_y1 = box1[1] - box1[3] / 2 - b1_y2 = box1[1] + box1[3] / 2 - b2_x1 = box2[0] - box2[2] / 2 - b2_x2 = box2[0] + box2[2] / 2 - b2_y1 = box2[1] - box2[3] / 2 - b2_y2 = box2[1] + box2[3] / 2 - - b1_area = (b1_x2 - b1_x1) * (b1_y2 - b1_y1) - b2_area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1) - - inter_rect_x1 = max(b1_x1, b2_x1) - inter_rect_y1 = max(b1_y1, b2_y1) - inter_rect_x2 = min(b1_x2, b2_x2) - inter_rect_y2 = min(b1_y2, b2_y2) - inter_area = max(inter_rect_x2 - inter_rect_x1, 0) * max( - inter_rect_y2 - inter_rect_y1, 0) - - return inter_area / (b1_area + b2_area + inter_area) - - -def build_target(gtboxes, gtlabel, attrs, grid_size): - n, b, _ = gtboxes.shape - ignore_thresh = attrs["ignore_thresh"] - anchors = attrs["anchors"] - class_num = attrs["class_num"] - input_size = attrs["input_size"] - an_num = len(anchors) // 2 - conf_mask = np.ones((n, an_num, grid_size, grid_size)).astype('float32') - obj_mask = np.zeros((n, an_num, grid_size, grid_size)).astype('float32') - tx = np.zeros((n, an_num, grid_size, grid_size)).astype('float32') - ty = np.zeros((n, an_num, grid_size, grid_size)).astype('float32') - tw = np.zeros((n, an_num, grid_size, grid_size)).astype('float32') - th = np.zeros((n, an_num, grid_size, grid_size)).astype('float32') - tweight = np.zeros((n, an_num, grid_size, grid_size)).astype('float32') - tconf = np.zeros((n, an_num, grid_size, grid_size)).astype('float32') - tcls = np.zeros( - (n, an_num, grid_size, grid_size, class_num)).astype('float32') - - for i in range(n): - for j in range(b): - if gtboxes[i, j, :].sum() == 0: - continue - - gt_label = gtlabel[i, j] - gx = gtboxes[i, j, 0] * grid_size - gy = gtboxes[i, j, 1] * grid_size - gw = gtboxes[i, j, 2] * input_size - gh = gtboxes[i, j, 3] * input_size - - gi = int(gx) - gj = int(gy) - - gtbox = [0, 0, gw, gh] - max_iou = 0 - for k in range(an_num): - anchor_box = [0, 0, anchors[2 * k], anchors[2 * k + 1]] - iou = box_iou(gtbox, anchor_box) - if iou > max_iou: - max_iou = iou - best_an_index = k - if iou > ignore_thresh: - conf_mask[i, best_an_index, gj, gi] = 0 - - conf_mask[i, best_an_index, gj, gi] = 1 - obj_mask[i, best_an_index, gj, gi] = 1 - tx[i, best_an_index, gj, gi] = gx - gi - ty[i, best_an_index, gj, gi] = gy - gj - tw[i, best_an_index, gj, gi] = np.log(gw / anchors[2 * - best_an_index]) - th[i, best_an_index, gj, gi] = np.log( - gh / anchors[2 * best_an_index + 1]) - tweight[i, best_an_index, gj, gi] = 2.0 - gtboxes[ - i, j, 2] * gtboxes[i, j, 3] - tconf[i, best_an_index, gj, gi] = 1 - tcls[i, best_an_index, gj, gi, gt_label] = 1 - - return (tx, ty, tw, th, tweight, tconf, tcls, conf_mask, obj_mask) - - -def YoloV3Loss(x, gtbox, gtlabel, attrs): - n, c, h, w = x.shape - an_num = len(attrs['anchors']) // 2 - class_num = attrs["class_num"] - x = x.reshape((n, an_num, 5 + class_num, h, w)).transpose((0, 1, 3, 4, 2)) - pred_x = x[:, :, :, :, 0] - pred_y = x[:, :, :, :, 1] - pred_w = x[:, :, :, :, 2] - pred_h = x[:, :, :, :, 3] - pred_conf = x[:, :, :, :, 4] - pred_cls = x[:, :, :, :, 5:] - - tx, ty, tw, th, tweight, tconf, tcls, conf_mask, obj_mask = build_target( - gtbox, gtlabel, attrs, x.shape[2]) - - obj_weight = obj_mask * tweight - obj_mask_expand = np.tile( - np.expand_dims(obj_mask, 4), (1, 1, 1, 1, int(attrs['class_num']))) - loss_x = sce(pred_x, tx, obj_weight) - loss_y = sce(pred_y, ty, obj_weight) - loss_w = l1loss(pred_w, tw, obj_weight) - loss_h = l1loss(pred_h, th, obj_weight) - loss_obj = sce(pred_conf, tconf, conf_mask) - loss_class = sce(pred_cls, tcls, obj_mask_expand) - - return loss_x + loss_y + loss_w + loss_h + loss_obj + loss_class - - def sigmoid(x): return 1.0 / (1.0 + np.exp(-1.0 * x)) @@ -291,8 +155,10 @@ class TestYolov3LossOp(OpTest): self.op_type = 'yolov3_loss' x = logit(np.random.uniform(0, 1, self.x_shape).astype('float32')) gtbox = np.random.random(size=self.gtbox_shape).astype('float32') - gtlabel = np.random.randint(0, self.class_num, - self.gtbox_shape[:2]).astype('int32') + gtlabel = np.random.randint(0, self.class_num, self.gtbox_shape[:2]) + gtmask = np.random.randint(0, 2, self.gtbox_shape[:2]) + gtbox = gtbox * gtmask[:, :, np.newaxis] + gtlabel = gtlabel * gtmask self.attrs = { "anchors": self.anchors, @@ -302,7 +168,11 @@ class TestYolov3LossOp(OpTest): "downsample": self.downsample, } - self.inputs = {'X': x, 'GTBox': gtbox, 'GTLabel': gtlabel} + self.inputs = { + 'X': x, + 'GTBox': gtbox.astype('float32'), + 'GTLabel': gtlabel.astype('int32') + } self.outputs = {'Loss': YOLOv3Loss(x, gtbox, gtlabel, self.attrs)} def test_check_output(self): From 32d533c2cd9aa6dcd0d3cbe9b9685f97d378337e Mon Sep 17 00:00:00 2001 From: dengkaipeng Date: Fri, 28 Dec 2018 17:49:02 +0800 Subject: [PATCH 11/24] cache obj_mask and gt_match_mask. test=develop --- paddle/fluid/operators/yolov3_loss_op.cc | 23 ++++ paddle/fluid/operators/yolov3_loss_op.h | 110 +++++------------- python/paddle/fluid/layers/detection.py | 9 +- .../tests/unittests/test_yolov3_loss_op.py | 16 ++- 4 files changed, 76 insertions(+), 82 deletions(-) diff --git a/paddle/fluid/operators/yolov3_loss_op.cc b/paddle/fluid/operators/yolov3_loss_op.cc index 8c46e341d6..5b777f0448 100644 --- a/paddle/fluid/operators/yolov3_loss_op.cc +++ b/paddle/fluid/operators/yolov3_loss_op.cc @@ -29,6 +29,11 @@ class Yolov3LossOp : public framework::OperatorWithKernel { "Input(GTLabel) of Yolov3LossOp should not be null."); PADDLE_ENFORCE(ctx->HasOutput("Loss"), "Output(Loss) of Yolov3LossOp should not be null."); + PADDLE_ENFORCE( + ctx->HasOutput("ObjectnessMask"), + "Output(ObjectnessMask) of Yolov3LossOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("GTMatchMask"), + "Output(GTMatchMask) of Yolov3LossOp should not be null."); auto dim_x = ctx->GetInputDim("X"); auto dim_gtbox = ctx->GetInputDim("GTBox"); @@ -68,6 +73,12 @@ class Yolov3LossOp : public framework::OperatorWithKernel { std::vector dim_out({dim_x[0]}); ctx->SetOutputDim("Loss", framework::make_ddim(dim_out)); + + std::vector dim_obj_mask({dim_x[0], mask_num, dim_x[2], dim_x[3]}); + ctx->SetOutputDim("ObjectnessMask", framework::make_ddim(dim_obj_mask)); + + std::vector dim_gt_match_mask({dim_gtbox[0], dim_gtbox[1]}); + ctx->SetOutputDim("GTMatchMask", framework::make_ddim(dim_gt_match_mask)); } protected: @@ -103,6 +114,16 @@ class Yolov3LossOpMaker : public framework::OpProtoAndCheckerMaker { AddOutput("Loss", "The output yolov3 loss tensor, " "This is a 1-D tensor with shape of [N]"); + AddOutput("ObjectnessMask", + "This is an intermediate tensor with shape of [N, M, H, W], " + "M is the number of anchor masks. This parameter caches the " + "mask for calculate objectness loss in gradient kernel.") + .AsIntermediate(); + AddOutput("GTMatchMask", + "This is an intermediate tensor with shape if [N, B], " + "B is the max box number of GT boxes. This parameter caches " + "matched mask index of each GT boxes for gradient calculate.") + .AsIntermediate(); AddAttr("class_num", "The number of classes to predict."); AddAttr>("anchors", @@ -208,6 +229,8 @@ class Yolov3LossGradMaker : public framework::SingleGradOpDescMaker { op->SetInput("GTBox", Input("GTBox")); op->SetInput("GTLabel", Input("GTLabel")); op->SetInput(framework::GradVarName("Loss"), OutputGrad("Loss")); + op->SetInput("ObjectnessMask", Output("ObjectnessMask")); + op->SetInput("GTMatchMask", Output("GTMatchMask")); op->SetAttrMap(Attrs()); diff --git a/paddle/fluid/operators/yolov3_loss_op.h b/paddle/fluid/operators/yolov3_loss_op.h index 12499befca..85d93cf96f 100644 --- a/paddle/fluid/operators/yolov3_loss_op.h +++ b/paddle/fluid/operators/yolov3_loss_op.h @@ -227,6 +227,8 @@ class Yolov3LossKernel : public framework::OpKernel { auto* gt_box = ctx.Input("GTBox"); auto* gt_label = ctx.Input("GTLabel"); auto* loss = ctx.Output("Loss"); + auto* objness_mask = ctx.Output("ObjectnessMask"); + auto* gt_match_mask = ctx.Output("GTMatchMask"); auto anchors = ctx.Attr>("anchors"); auto anchor_mask = ctx.Attr>("anchor_mask"); int class_num = ctx.Attr("class_num"); @@ -241,19 +243,19 @@ class Yolov3LossKernel : public framework::OpKernel { const int b = gt_box->dims()[1]; int input_size = downsample * h; + const int stride = h * w; + const int an_stride = (class_num + 5) * stride; + const T* input_data = input->data(); const T* gt_box_data = gt_box->data(); const int* gt_label_data = gt_label->data(); T* loss_data = loss->mutable_data({n}, ctx.GetPlace()); memset(loss_data, 0, loss->numel() * sizeof(T)); - - Tensor objness; - int* objness_data = - objness.mutable_data({n, mask_num, h, w}, ctx.GetPlace()); - memset(objness_data, 0, objness.numel() * sizeof(int)); - - const int stride = h * w; - const int an_stride = (class_num + 5) * stride; + int* obj_mask_data = + objness_mask->mutable_data({n, mask_num, h, w}, ctx.GetPlace()); + memset(obj_mask_data, 0, objness_mask->numel() * sizeof(int)); + int* gt_match_mask_data = + gt_match_mask->mutable_data({n, b}, ctx.GetPlace()); for (int i = 0; i < n; i++) { for (int j = 0; j < mask_num; j++) { @@ -277,7 +279,7 @@ class Yolov3LossKernel : public framework::OpKernel { if (best_iou > ignore_thresh) { int obj_idx = (i * mask_num + j) * stride + k * w + l; - objness_data[obj_idx] = -1; + obj_mask_data[obj_idx] = -1; } } } @@ -285,6 +287,7 @@ class Yolov3LossKernel : public framework::OpKernel { for (int t = 0; t < b; t++) { Box gt = GetGtBox(gt_box_data, i, b, t); if (LessEqualZero(gt.w) || LessEqualZero(gt.h)) { + gt_match_mask_data[i * b + t] = -1; continue; } int gi = static_cast(gt.x * w); @@ -309,6 +312,7 @@ class Yolov3LossKernel : public framework::OpKernel { } int mask_idx = GetMaskIndex(anchor_mask, best_n); + gt_match_mask_data[i * b + t] = mask_idx; if (mask_idx >= 0) { int box_idx = GetEntryIndex(i, mask_idx, gj * w + gi, mask_num, an_stride, stride, 0); @@ -316,7 +320,7 @@ class Yolov3LossKernel : public framework::OpKernel { box_idx, gi, gj, h, input_size, stride); int obj_idx = (i * mask_num + mask_idx) * stride + gj * w + gi; - objness_data[obj_idx] = 1; + obj_mask_data[obj_idx] = 1; int label = gt_label_data[i * b + t]; int label_idx = GetEntryIndex(i, mask_idx, gj * w + gi, mask_num, @@ -327,7 +331,7 @@ class Yolov3LossKernel : public framework::OpKernel { } } - CalcObjnessLoss(loss_data, input_data + 4 * stride, objness_data, n, + CalcObjnessLoss(loss_data, input_data + 4 * stride, obj_mask_data, n, mask_num, h, w, stride, an_stride); } }; @@ -341,64 +345,35 @@ class Yolov3LossGradKernel : public framework::OpKernel { auto* gt_label = ctx.Input("GTLabel"); auto* input_grad = ctx.Output(framework::GradVarName("X")); auto* loss_grad = ctx.Input(framework::GradVarName("Loss")); + auto* objness_mask = ctx.Input("ObjectnessMask"); + auto* gt_match_mask = ctx.Input("GTMatchMask"); auto anchors = ctx.Attr>("anchors"); auto anchor_mask = ctx.Attr>("anchor_mask"); int class_num = ctx.Attr("class_num"); - float ignore_thresh = ctx.Attr("ignore_thresh"); int downsample = ctx.Attr("downsample"); - const int n = input->dims()[0]; - const int c = input->dims()[1]; - const int h = input->dims()[2]; - const int w = input->dims()[3]; - const int an_num = anchors.size() / 2; + const int n = input_grad->dims()[0]; + const int c = input_grad->dims()[1]; + const int h = input_grad->dims()[2]; + const int w = input_grad->dims()[3]; const int mask_num = anchor_mask.size(); - const int b = gt_box->dims()[1]; + const int b = gt_match_mask->dims()[1]; int input_size = downsample * h; + const int stride = h * w; + const int an_stride = (class_num + 5) * stride; + const T* input_data = input->data(); const T* gt_box_data = gt_box->data(); const int* gt_label_data = gt_label->data(); const T* loss_grad_data = loss_grad->data(); + const int* obj_mask_data = objness_mask->data(); + const int* gt_match_mask_data = gt_match_mask->data(); T* input_grad_data = input_grad->mutable_data({n, c, h, w}, ctx.GetPlace()); memset(input_grad_data, 0, input_grad->numel() * sizeof(T)); - Tensor objness; - int* objness_data = - objness.mutable_data({n, mask_num, h, w}, ctx.GetPlace()); - memset(objness_data, 0, objness.numel() * sizeof(int)); - - const int stride = h * w; - const int an_stride = (class_num + 5) * stride; - for (int i = 0; i < n; i++) { - for (int j = 0; j < mask_num; j++) { - for (int k = 0; k < h; k++) { - for (int l = 0; l < w; l++) { - int box_idx = - GetEntryIndex(i, j, k * w + l, mask_num, an_stride, stride, 0); - Box pred = GetYoloBox(input_data, anchors, l, k, anchor_mask[j], - h, input_size, box_idx, stride); - T best_iou = 0; - for (int t = 0; t < b; t++) { - Box gt = GetGtBox(gt_box_data, i, b, t); - if (LessEqualZero(gt.w) || LessEqualZero(gt.h)) { - continue; - } - T iou = CalcBoxIoU(pred, gt); - if (iou > best_iou) { - best_iou = iou; - } - } - - if (best_iou > ignore_thresh) { - int obj_idx = (i * mask_num + j) * stride + k * w + l; - objness_data[obj_idx] = -1; - } - } - } - } for (int t = 0; t < b; t++) { Box gt = GetGtBox(gt_box_data, i, b, t); if (LessEqualZero(gt.w) || LessEqualZero(gt.h)) { @@ -406,35 +381,14 @@ class Yolov3LossGradKernel : public framework::OpKernel { } int gi = static_cast(gt.x * w); int gj = static_cast(gt.y * h); - Box gt_shift = gt; - gt_shift.x = 0.0; - gt_shift.y = 0.0; - T best_iou = 0.0; - int best_n = 0; - for (int an_idx = 0; an_idx < an_num; an_idx++) { - Box an_box; - an_box.x = 0.0; - an_box.y = 0.0; - an_box.w = anchors[2 * an_idx] / static_cast(input_size); - an_box.h = anchors[2 * an_idx + 1] / static_cast(input_size); - float iou = CalcBoxIoU(an_box, gt_shift); - // TO DO: iou > 0.5 ? - if (iou > best_iou) { - best_iou = iou; - best_n = an_idx; - } - } - int mask_idx = GetMaskIndex(anchor_mask, best_n); + int mask_idx = gt_match_mask_data[i * b + t]; if (mask_idx >= 0) { int box_idx = GetEntryIndex(i, mask_idx, gj * w + gi, mask_num, an_stride, stride, 0); - CalcBoxLocationLossGrad(input_grad_data, loss_grad_data[i], - input_data, gt, anchors, best_n, box_idx, - gi, gj, h, input_size, stride); - - int obj_idx = (i * mask_num + mask_idx) * stride + gj * w + gi; - objness_data[obj_idx] = 1; + CalcBoxLocationLossGrad( + input_grad_data, loss_grad_data[i], input_data, gt, anchors, + anchor_mask[mask_idx], box_idx, gi, gj, h, input_size, stride); int label = gt_label_data[i * b + t]; int label_idx = GetEntryIndex(i, mask_idx, gj * w + gi, mask_num, @@ -446,7 +400,7 @@ class Yolov3LossGradKernel : public framework::OpKernel { } CalcObjnessLossGrad(input_grad_data + 4 * stride, loss_grad_data, - input_data + 4 * stride, objness_data, n, mask_num, + input_data + 4 * stride, obj_mask_data, n, mask_num, h, w, stride, an_stride); } }; diff --git a/python/paddle/fluid/layers/detection.py b/python/paddle/fluid/layers/detection.py index 542162b7f4..90d112aa01 100644 --- a/python/paddle/fluid/layers/detection.py +++ b/python/paddle/fluid/layers/detection.py @@ -483,6 +483,9 @@ def yolov3_loss(x, loss = helper.create_variable( name=name, dtype=x.dtype, persistable=False) + objectness_mask = helper.create_variable_for_type_inference(dtype='int32') + gt_match_mask = helper.create_variable_for_type_inference(dtype='int32') + attrs = { "anchors": anchors, "anchor_mask": anchor_mask, @@ -496,7 +499,11 @@ def yolov3_loss(x, inputs={"X": x, "GTBox": gtbox, "GTLabel": gtlabel}, - outputs={'Loss': loss}, + outputs={ + 'Loss': loss, + 'ObjectnessMask': objectness_mask, + 'GTMatchMask': gt_match_mask + }, attrs=attrs) return loss diff --git a/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py b/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py index 188acea2b9..904bee00c1 100644 --- a/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py +++ b/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py @@ -116,13 +116,17 @@ def YOLOv3Loss(x, gtbox, gtlabel, attrs): anchor_boxes = np.tile(anchor_boxes[np.newaxis, :, :], (n, 1, 1)) ious = batch_xywh_box_iou(gtbox_shift, anchor_boxes) iou_matches = np.argmax(ious, axis=-1) + gt_matches = iou_matches.copy() for i in range(n): for j in range(b): if gtbox[i, j, 2:].sum() == 0: + gt_matches[i, j] = -1 continue if iou_matches[i, j] not in anchor_mask: + gt_matches[i, j] = -1 continue an_idx = anchor_mask.index(iou_matches[i, j]) + gt_matches[i, j] = an_idx gi = int(gtbox[i, j, 0] * w) gj = int(gtbox[i, j, 1] * h) @@ -146,7 +150,8 @@ def YOLOv3Loss(x, gtbox, gtlabel, attrs): if objness[i, j] >= 0: loss[i] += sce(pred_obj[i, j], objness[i, j]) - return loss + return (loss, objness.reshape((n, mask_num, h, w)).astype('int32'), \ + gt_matches.astype('int32')) class TestYolov3LossOp(OpTest): @@ -173,11 +178,16 @@ class TestYolov3LossOp(OpTest): 'GTBox': gtbox.astype('float32'), 'GTLabel': gtlabel.astype('int32') } - self.outputs = {'Loss': YOLOv3Loss(x, gtbox, gtlabel, self.attrs)} + loss, objness, gt_matches = YOLOv3Loss(x, gtbox, gtlabel, self.attrs) + self.outputs = { + 'Loss': loss, + 'ObjectnessMask': objness, + "GTMatchMask": gt_matches + } def test_check_output(self): place = core.CPUPlace() - self.check_output_with_place(place, atol=1e-3) + self.check_output_with_place(place, atol=2e-3) def test_check_grad_ignore_gtbox(self): place = core.CPUPlace() From cc01db6029c84b5e059d355b95dd73d18894594f Mon Sep 17 00:00:00 2001 From: dengkaipeng Date: Fri, 28 Dec 2018 20:06:52 +0800 Subject: [PATCH 12/24] calc valid gt before loss calc. test=develop --- paddle/fluid/operators/yolov3_loss_op.h | 41 ++++++++++++++++++++++--- 1 file changed, 36 insertions(+), 5 deletions(-) diff --git a/paddle/fluid/operators/yolov3_loss_op.h b/paddle/fluid/operators/yolov3_loss_op.h index 85d93cf96f..301e2f4033 100644 --- a/paddle/fluid/operators/yolov3_loss_op.h +++ b/paddle/fluid/operators/yolov3_loss_op.h @@ -219,6 +219,22 @@ static inline void CalcObjnessLossGrad(T* input_grad, const T* loss, } } +template +static void inline GtValid(bool* valid, const T* gtbox, const int n, + const int b) { + for (int i = 0; i < n; i++) { + for (int j = 0; j < b; j++) { + if (LessEqualZero(gtbox[j * 4 + 2]) || LessEqualZero(gtbox[j * 4 + 3])) { + valid[j] = false; + } else { + valid[j] = true; + } + } + valid += b; + gtbox += b * 4; + } +} + template class Yolov3LossKernel : public framework::OpKernel { public: @@ -257,20 +273,28 @@ class Yolov3LossKernel : public framework::OpKernel { int* gt_match_mask_data = gt_match_mask->mutable_data({n, b}, ctx.GetPlace()); + // calc valid gt box mask, avoid calc duplicately in following code + Tensor gt_valid_mask; + bool* gt_valid_mask_data = + gt_valid_mask.mutable_data({n, b}, ctx.GetPlace()); + GtValid(gt_valid_mask_data, gt_box_data, n, b); + for (int i = 0; i < n; i++) { for (int j = 0; j < mask_num; j++) { for (int k = 0; k < h; k++) { for (int l = 0; l < w; l++) { + // each predict box find a best match gt box, if overlap is bigger + // then ignore_thresh, ignore the objectness loss. int box_idx = GetEntryIndex(i, j, k * w + l, mask_num, an_stride, stride, 0); Box pred = GetYoloBox(input_data, anchors, l, k, anchor_mask[j], h, input_size, box_idx, stride); T best_iou = 0; for (int t = 0; t < b; t++) { - Box gt = GetGtBox(gt_box_data, i, b, t); - if (LessEqualZero(gt.w) || LessEqualZero(gt.h)) { + if (!gt_valid_mask_data[i * b + t]) { continue; } + Box gt = GetGtBox(gt_box_data, i, b, t); T iou = CalcBoxIoU(pred, gt); if (iou > best_iou) { best_iou = iou; @@ -281,15 +305,18 @@ class Yolov3LossKernel : public framework::OpKernel { int obj_idx = (i * mask_num + j) * stride + k * w + l; obj_mask_data[obj_idx] = -1; } + // TODO(dengkaipeng): all losses should be calculated if best IoU + // is bigger then truth thresh should be calculated here, but + // currently, truth thresh is an unreachable value as 1.0. } } } for (int t = 0; t < b; t++) { - Box gt = GetGtBox(gt_box_data, i, b, t); - if (LessEqualZero(gt.w) || LessEqualZero(gt.h)) { + if (!gt_valid_mask_data[i * b + t]) { gt_match_mask_data[i * b + t] = -1; continue; } + Box gt = GetGtBox(gt_box_data, i, b, t); int gi = static_cast(gt.x * w); int gj = static_cast(gt.y * h); Box gt_shift = gt; @@ -297,6 +324,9 @@ class Yolov3LossKernel : public framework::OpKernel { gt_shift.y = 0.0; T best_iou = 0.0; int best_n = 0; + // each gt box find a best match anchor box as positive sample, + // for positive sample, all losses should be calculated, and for + // other samples, only objectness loss is required. for (int an_idx = 0; an_idx < an_num; an_idx++) { Box an_box; an_box.x = 0.0; @@ -304,7 +334,8 @@ class Yolov3LossKernel : public framework::OpKernel { an_box.w = anchors[2 * an_idx] / static_cast(input_size); an_box.h = anchors[2 * an_idx + 1] / static_cast(input_size); float iou = CalcBoxIoU(an_box, gt_shift); - // TO DO: iou > 0.5 ? + // TODO(dengkaipeng): In paper, objectness loss is ignore when + // best IoU > 0.5, but darknet code didn't implement this. if (iou > best_iou) { best_iou = iou; best_n = an_idx; From 3c08f620c248c506116dbb5a58224de9743bb048 Mon Sep 17 00:00:00 2001 From: dengkaipeng Date: Thu, 3 Jan 2019 11:16:29 +0800 Subject: [PATCH 13/24] add label smooth. test=develop --- paddle/fluid/operators/yolov3_loss_op.h | 19 ++++++++++--------- .../tests/unittests/test_yolov3_loss_op.py | 6 +++++- 2 files changed, 15 insertions(+), 10 deletions(-) diff --git a/paddle/fluid/operators/yolov3_loss_op.h b/paddle/fluid/operators/yolov3_loss_op.h index 301e2f4033..34119b1a02 100644 --- a/paddle/fluid/operators/yolov3_loss_op.h +++ b/paddle/fluid/operators/yolov3_loss_op.h @@ -159,7 +159,9 @@ static inline void CalcLabelLoss(T* loss, const T* input, const int index, const int label, const int class_num, const int stride) { for (int i = 0; i < class_num; i++) { - loss[0] += SCE(input[index + i * stride], (i == label) ? 1.0 : 0.0); + T pred = input[index + i * stride] < -0.5 ? input[index + i * stride] + : 1.0 / class_num; + loss[0] += SCE(pred, (i == label) ? 1.0 : 0.0); } } @@ -169,8 +171,10 @@ static inline void CalcLabelLossGrad(T* input_grad, const T loss, const int label, const int class_num, const int stride) { for (int i = 0; i < class_num; i++) { + T pred = input[index + i * stride] < -0.5 ? input[index + i * stride] + : 1.0 / class_num; input_grad[index + i * stride] = - SCEGrad(input[index + i * stride], (i == label) ? 1.0 : 0.0) * loss; + SCEGrad(pred, (i == label) ? 1.0 : 0.0) * loss; } } @@ -406,15 +410,12 @@ class Yolov3LossGradKernel : public framework::OpKernel { for (int i = 0; i < n; i++) { for (int t = 0; t < b; t++) { - Box gt = GetGtBox(gt_box_data, i, b, t); - if (LessEqualZero(gt.w) || LessEqualZero(gt.h)) { - continue; - } - int gi = static_cast(gt.x * w); - int gj = static_cast(gt.y * h); - int mask_idx = gt_match_mask_data[i * b + t]; if (mask_idx >= 0) { + Box gt = GetGtBox(gt_box_data, i, b, t); + int gi = static_cast(gt.x * w); + int gj = static_cast(gt.y * h); + int box_idx = GetEntryIndex(i, mask_idx, gj * w + gi, mask_num, an_stride, stride, 0); CalcBoxLocationLossGrad( diff --git a/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py b/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py index 904bee00c1..27fb92c589 100644 --- a/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py +++ b/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py @@ -86,6 +86,10 @@ def YOLOv3Loss(x, gtbox, gtlabel, attrs): pred_box[:, :, :, :, 0] = (grid_x + sigmoid(pred_box[:, :, :, :, 0])) / w pred_box[:, :, :, :, 1] = (grid_y + sigmoid(pred_box[:, :, :, :, 1])) / h + x[:, :, :, :, 5:] = np.where(x[:, :, :, :, 5:] < -0.5, x[:, :, :, :, 5:], + np.ones_like(x[:, :, :, :, 5:]) * 1.0 / + class_num) + mask_anchors = [] for m in anchor_mask: mask_anchors.append((anchors[2 * m], anchors[2 * m + 1])) @@ -207,7 +211,7 @@ class TestYolov3LossOp(OpTest): self.ignore_thresh = 0.7 self.downsample = 32 self.x_shape = (3, len(self.anchor_mask) * (5 + self.class_num), 5, 5) - self.gtbox_shape = (3, 10, 4) + self.gtbox_shape = (3, 5, 4) if __name__ == "__main__": From 8218e30176c6bdaccd11cd0141c6f47878233b54 Mon Sep 17 00:00:00 2001 From: dengkaipeng Date: Fri, 4 Jan 2019 11:40:08 +0800 Subject: [PATCH 14/24] add gtscore. test=develop --- paddle/fluid/API.spec | 2 +- paddle/fluid/operators/yolov3_loss_op.cc | 20 +++++++++++++++-- paddle/fluid/operators/yolov3_loss_op.h | 22 ++++++++++++------- python/paddle/fluid/layers/detection.py | 17 ++++++++++---- .../tests/unittests/test_yolov3_loss_op.py | 19 +++++++++------- 5 files changed, 57 insertions(+), 23 deletions(-) diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index 6c6ac9c7ea..bf0916a076 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -324,7 +324,7 @@ paddle.fluid.layers.generate_mask_labels ArgSpec(args=['im_info', 'gt_classes', paddle.fluid.layers.iou_similarity ArgSpec(args=['x', 'y', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.box_coder ArgSpec(args=['prior_box', 'prior_box_var', 'target_box', 'code_type', 'box_normalized', 'name'], varargs=None, keywords=None, defaults=('encode_center_size', True, None)) paddle.fluid.layers.polygon_box_transform ArgSpec(args=['input', 'name'], varargs=None, keywords=None, defaults=(None,)) -paddle.fluid.layers.yolov3_loss ArgSpec(args=['x', 'gtbox', 'gtlabel', 'anchors', 'anchor_mask', 'class_num', 'ignore_thresh', 'downsample', 'name'], varargs=None, keywords=None, defaults=(None,)) +paddle.fluid.layers.yolov3_loss ArgSpec(args=['x', 'gtbox', 'gtlabel', 'gtscore', 'anchors', 'anchor_mask', 'class_num', 'ignore_thresh', 'downsample', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.multiclass_nms ArgSpec(args=['bboxes', 'scores', 'score_threshold', 'nms_top_k', 'keep_top_k', 'nms_threshold', 'normalized', 'nms_eta', 'background_label', 'name'], varargs=None, keywords=None, defaults=(0.3, True, 1.0, 0, None)) paddle.fluid.layers.accuracy ArgSpec(args=['input', 'label', 'k', 'correct', 'total'], varargs=None, keywords=None, defaults=(1, None, None)) paddle.fluid.layers.auc ArgSpec(args=['input', 'label', 'curve', 'num_thresholds', 'topk', 'slide_steps'], varargs=None, keywords=None, defaults=('ROC', 4095, 1, 1)) diff --git a/paddle/fluid/operators/yolov3_loss_op.cc b/paddle/fluid/operators/yolov3_loss_op.cc index 5b777f0448..c146035f9d 100644 --- a/paddle/fluid/operators/yolov3_loss_op.cc +++ b/paddle/fluid/operators/yolov3_loss_op.cc @@ -27,6 +27,8 @@ class Yolov3LossOp : public framework::OperatorWithKernel { "Input(GTBox) of Yolov3LossOp should not be null."); PADDLE_ENFORCE(ctx->HasInput("GTLabel"), "Input(GTLabel) of Yolov3LossOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("GTScore"), + "Input(GTScore) of Yolov3LossOp should not be null."); PADDLE_ENFORCE(ctx->HasOutput("Loss"), "Output(Loss) of Yolov3LossOp should not be null."); PADDLE_ENFORCE( @@ -38,6 +40,7 @@ class Yolov3LossOp : public framework::OperatorWithKernel { auto dim_x = ctx->GetInputDim("X"); auto dim_gtbox = ctx->GetInputDim("GTBox"); auto dim_gtlabel = ctx->GetInputDim("GTLabel"); + auto dim_gtscore = ctx->GetInputDim("GTScore"); auto anchors = ctx->Attrs().Get>("anchors"); int anchor_num = anchors.size() / 2; auto anchor_mask = ctx->Attrs().Get>("anchor_mask"); @@ -54,11 +57,17 @@ class Yolov3LossOp : public framework::OperatorWithKernel { "Input(GTBox) should be a 3-D tensor"); PADDLE_ENFORCE_EQ(dim_gtbox[2], 4, "Input(GTBox) dim[2] should be 5"); PADDLE_ENFORCE_EQ(dim_gtlabel.size(), 2, - "Input(GTBox) should be a 2-D tensor"); + "Input(GTLabel) should be a 2-D tensor"); PADDLE_ENFORCE_EQ(dim_gtlabel[0], dim_gtbox[0], "Input(GTBox) and Input(GTLabel) dim[0] should be same"); PADDLE_ENFORCE_EQ(dim_gtlabel[1], dim_gtbox[1], "Input(GTBox) and Input(GTLabel) dim[1] should be same"); + PADDLE_ENFORCE_EQ(dim_gtscore.size(), 2, + "Input(GTScore) should be a 2-D tensor"); + PADDLE_ENFORCE_EQ(dim_gtscore[0], dim_gtbox[0], + "Input(GTBox) and Input(GTScore) dim[0] should be same"); + PADDLE_ENFORCE_EQ(dim_gtscore[1], dim_gtbox[1], + "Input(GTBox) and Input(GTScore) dim[1] should be same"); PADDLE_ENFORCE_GT(anchors.size(), 0, "Attr(anchors) length should be greater then 0."); PADDLE_ENFORCE_EQ(anchors.size() % 2, 0, @@ -109,8 +118,13 @@ class Yolov3LossOpMaker : public framework::OpProtoAndCheckerMaker { AddInput("GTLabel", "The input tensor of ground truth label, " "This is a 2-D tensor with shape of [N, max_box_num], " - "and each element shoudl be an integer to indicate the " + "and each element should be an integer to indicate the " "box class id."); + AddInput("GTScore", + "The score of GTLabel, This is a 2-D tensor in same shape " + "GTLabel, and score values should in range (0, 1). This " + "input is for GTLabel score can be not 1.0 in image mixup " + "augmentation."); AddOutput("Loss", "The output yolov3 loss tensor, " "This is a 1-D tensor with shape of [N]"); @@ -228,6 +242,7 @@ class Yolov3LossGradMaker : public framework::SingleGradOpDescMaker { op->SetInput("X", Input("X")); op->SetInput("GTBox", Input("GTBox")); op->SetInput("GTLabel", Input("GTLabel")); + op->SetInput("GTScore", Input("GTScore")); op->SetInput(framework::GradVarName("Loss"), OutputGrad("Loss")); op->SetInput("ObjectnessMask", Output("ObjectnessMask")); op->SetInput("GTMatchMask", Output("GTMatchMask")); @@ -237,6 +252,7 @@ class Yolov3LossGradMaker : public framework::SingleGradOpDescMaker { op->SetOutput(framework::GradVarName("X"), InputGrad("X")); op->SetOutput(framework::GradVarName("GTBox"), {}); op->SetOutput(framework::GradVarName("GTLabel"), {}); + op->SetOutput(framework::GradVarName("GTScore"), {}); return std::unique_ptr(op); } }; diff --git a/paddle/fluid/operators/yolov3_loss_op.h b/paddle/fluid/operators/yolov3_loss_op.h index 34119b1a02..c4095b8ca5 100644 --- a/paddle/fluid/operators/yolov3_loss_op.h +++ b/paddle/fluid/operators/yolov3_loss_op.h @@ -156,25 +156,25 @@ static void CalcBoxLocationLossGrad(T* input_grad, const T loss, const T* input, template static inline void CalcLabelLoss(T* loss, const T* input, const int index, - const int label, const int class_num, - const int stride) { + const int label, const T score, + const int class_num, const int stride) { for (int i = 0; i < class_num; i++) { T pred = input[index + i * stride] < -0.5 ? input[index + i * stride] : 1.0 / class_num; - loss[0] += SCE(pred, (i == label) ? 1.0 : 0.0); + loss[0] += SCE(pred, (i == label) ? score : 0.0); } } template static inline void CalcLabelLossGrad(T* input_grad, const T loss, const T* input, const int index, - const int label, const int class_num, - const int stride) { + const int label, const T score, + const int class_num, const int stride) { for (int i = 0; i < class_num; i++) { T pred = input[index + i * stride] < -0.5 ? input[index + i * stride] : 1.0 / class_num; input_grad[index + i * stride] = - SCEGrad(pred, (i == label) ? 1.0 : 0.0) * loss; + SCEGrad(pred, (i == label) ? score : 0.0) * loss; } } @@ -246,6 +246,7 @@ class Yolov3LossKernel : public framework::OpKernel { auto* input = ctx.Input("X"); auto* gt_box = ctx.Input("GTBox"); auto* gt_label = ctx.Input("GTLabel"); + auto* gt_score = ctx.Input("GTScore"); auto* loss = ctx.Output("Loss"); auto* objness_mask = ctx.Output("ObjectnessMask"); auto* gt_match_mask = ctx.Output("GTMatchMask"); @@ -269,6 +270,7 @@ class Yolov3LossKernel : public framework::OpKernel { const T* input_data = input->data(); const T* gt_box_data = gt_box->data(); const int* gt_label_data = gt_label->data(); + const T* gt_score_data = gt_score->data(); T* loss_data = loss->mutable_data({n}, ctx.GetPlace()); memset(loss_data, 0, loss->numel() * sizeof(T)); int* obj_mask_data = @@ -358,9 +360,10 @@ class Yolov3LossKernel : public framework::OpKernel { obj_mask_data[obj_idx] = 1; int label = gt_label_data[i * b + t]; + T score = gt_score_data[i * b + t]; int label_idx = GetEntryIndex(i, mask_idx, gj * w + gi, mask_num, an_stride, stride, 5); - CalcLabelLoss(loss_data + i, input_data, label_idx, label, + CalcLabelLoss(loss_data + i, input_data, label_idx, label, score, class_num, stride); } } @@ -378,6 +381,7 @@ class Yolov3LossGradKernel : public framework::OpKernel { auto* input = ctx.Input("X"); auto* gt_box = ctx.Input("GTBox"); auto* gt_label = ctx.Input("GTLabel"); + auto* gt_score = ctx.Input("GTScore"); auto* input_grad = ctx.Output(framework::GradVarName("X")); auto* loss_grad = ctx.Input(framework::GradVarName("Loss")); auto* objness_mask = ctx.Input("ObjectnessMask"); @@ -401,6 +405,7 @@ class Yolov3LossGradKernel : public framework::OpKernel { const T* input_data = input->data(); const T* gt_box_data = gt_box->data(); const int* gt_label_data = gt_label->data(); + const T* gt_score_data = gt_score->data(); const T* loss_grad_data = loss_grad->data(); const int* obj_mask_data = objness_mask->data(); const int* gt_match_mask_data = gt_match_mask->data(); @@ -423,10 +428,11 @@ class Yolov3LossGradKernel : public framework::OpKernel { anchor_mask[mask_idx], box_idx, gi, gj, h, input_size, stride); int label = gt_label_data[i * b + t]; + T score = gt_score_data[i * b + t]; int label_idx = GetEntryIndex(i, mask_idx, gj * w + gi, mask_num, an_stride, stride, 5); CalcLabelLossGrad(input_grad_data, loss_grad_data[i], input_data, - label_idx, label, class_num, stride); + label_idx, label, score, class_num, stride); } } } diff --git a/python/paddle/fluid/layers/detection.py b/python/paddle/fluid/layers/detection.py index 90d112aa01..10573cc4c6 100644 --- a/python/paddle/fluid/layers/detection.py +++ b/python/paddle/fluid/layers/detection.py @@ -412,6 +412,7 @@ def polygon_box_transform(input, name=None): def yolov3_loss(x, gtbox, gtlabel, + gtscore, anchors, anchor_mask, class_num, @@ -428,8 +429,10 @@ def yolov3_loss(x, and x, y, w, h should be relative value of input image. N is the batch number and B is the max box number in an image. - gtlabel (Variable): class id of ground truth boxes, shoud be ins shape + gtlabel (Variable): class id of ground truth boxes, shoud be in shape of [N, B]. + gtscore (Variable): score of gtlabel, should be in same shape with gtlabel + and score value in range (0, 1). anchors (list|tuple): ${anchors_comment} anchor_mask (list|tuple): ${anchor_mask_comment} class_num (int): ${class_num_comment} @@ -444,6 +447,7 @@ def yolov3_loss(x, TypeError: Input x of yolov3_loss must be Variable TypeError: Input gtbox of yolov3_loss must be Variable" TypeError: Input gtlabel of yolov3_loss must be Variable" + TypeError: Input gtscore of yolov3_loss must be Variable" TypeError: Attr anchors of yolov3_loss must be list or tuple TypeError: Attr class_num of yolov3_loss must be an integer TypeError: Attr ignore_thresh of yolov3_loss must be a float number @@ -467,6 +471,8 @@ def yolov3_loss(x, raise TypeError("Input gtbox of yolov3_loss must be Variable") if not isinstance(gtlabel, Variable): raise TypeError("Input gtlabel of yolov3_loss must be Variable") + if not isinstance(gtscore, Variable): + raise TypeError("Input gtscore of yolov3_loss must be Variable") if not isinstance(anchors, list) and not isinstance(anchors, tuple): raise TypeError("Attr anchors of yolov3_loss must be list or tuple") if not isinstance(anchor_mask, list) and not isinstance(anchor_mask, tuple): @@ -496,9 +502,12 @@ def yolov3_loss(x, helper.append_op( type='yolov3_loss', - inputs={"X": x, - "GTBox": gtbox, - "GTLabel": gtlabel}, + inputs={ + "X": x, + "GTBox": gtbox, + "GTLabel": gtlabel, + "GTScore": gtscore + }, outputs={ 'Loss': loss, 'ObjectnessMask': objectness_mask, diff --git a/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py b/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py index 27fb92c589..c65570d7c1 100644 --- a/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py +++ b/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py @@ -66,7 +66,7 @@ def batch_xywh_box_iou(box1, box2): return inter_area / union -def YOLOv3Loss(x, gtbox, gtlabel, attrs): +def YOLOv3Loss(x, gtbox, gtlabel, gtscore, attrs): n, c, h, w = x.shape b = gtbox.shape[1] anchors = attrs['anchors'] @@ -148,7 +148,7 @@ def YOLOv3Loss(x, gtbox, gtlabel, attrs): for label_idx in range(class_num): loss[i] += sce(x[i, an_idx, gj, gi, 5 + label_idx], - int(label_idx == gtlabel[i, j])) + int(label_idx == gtlabel[i, j]) * gtscore[i, j]) for j in range(mask_num * h * w): if objness[i, j] >= 0: @@ -165,6 +165,7 @@ class TestYolov3LossOp(OpTest): x = logit(np.random.uniform(0, 1, self.x_shape).astype('float32')) gtbox = np.random.random(size=self.gtbox_shape).astype('float32') gtlabel = np.random.randint(0, self.class_num, self.gtbox_shape[:2]) + gtscore = np.random.random(self.gtbox_shape[:2]).astype('float32') gtmask = np.random.randint(0, 2, self.gtbox_shape[:2]) gtbox = gtbox * gtmask[:, :, np.newaxis] gtlabel = gtlabel * gtmask @@ -180,9 +181,11 @@ class TestYolov3LossOp(OpTest): self.inputs = { 'X': x, 'GTBox': gtbox.astype('float32'), - 'GTLabel': gtlabel.astype('int32') + 'GTLabel': gtlabel.astype('int32'), + 'GTScore': gtscore.astype('float32') } - loss, objness, gt_matches = YOLOv3Loss(x, gtbox, gtlabel, self.attrs) + loss, objness, gt_matches = YOLOv3Loss(x, gtbox, gtlabel, gtscore, + self.attrs) self.outputs = { 'Loss': loss, 'ObjectnessMask': objness, @@ -198,8 +201,8 @@ class TestYolov3LossOp(OpTest): self.check_grad_with_place( place, ['X'], 'Loss', - no_grad_set=set(["GTBox", "GTLabel"]), - max_relative_error=0.15) + no_grad_set=set(["GTBox", "GTLabel", "GTScore"]), + max_relative_error=0.2) def initTestCase(self): self.anchors = [ @@ -207,11 +210,11 @@ class TestYolov3LossOp(OpTest): 373, 326 ] self.anchor_mask = [0, 1, 2] - self.class_num = 5 + self.class_num = 10 self.ignore_thresh = 0.7 self.downsample = 32 self.x_shape = (3, len(self.anchor_mask) * (5 + self.class_num), 5, 5) - self.gtbox_shape = (3, 5, 4) + self.gtbox_shape = (3, 10, 4) if __name__ == "__main__": From 2b89f590559bc76d6f821789edee42cf56a68582 Mon Sep 17 00:00:00 2001 From: tink2123 Date: Thu, 10 Jan 2019 06:57:28 +0000 Subject: [PATCH 15/24] add attr use_label_smooth test=develop --- paddle/fluid/API.spec | 2 +- paddle/fluid/operators/yolov3_loss_op.cc | 3 ++ paddle/fluid/operators/yolov3_loss_op.h | 46 +++++++++++++------ python/paddle/fluid/layers/detection.py | 6 +++ .../tests/unittests/test_yolov3_loss_op.py | 8 ++++ 5 files changed, 51 insertions(+), 14 deletions(-) diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index bf0916a076..d773c2518c 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -324,7 +324,7 @@ paddle.fluid.layers.generate_mask_labels ArgSpec(args=['im_info', 'gt_classes', paddle.fluid.layers.iou_similarity ArgSpec(args=['x', 'y', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.box_coder ArgSpec(args=['prior_box', 'prior_box_var', 'target_box', 'code_type', 'box_normalized', 'name'], varargs=None, keywords=None, defaults=('encode_center_size', True, None)) paddle.fluid.layers.polygon_box_transform ArgSpec(args=['input', 'name'], varargs=None, keywords=None, defaults=(None,)) -paddle.fluid.layers.yolov3_loss ArgSpec(args=['x', 'gtbox', 'gtlabel', 'gtscore', 'anchors', 'anchor_mask', 'class_num', 'ignore_thresh', 'downsample', 'name'], varargs=None, keywords=None, defaults=(None,)) +paddle.fluid.layers.yolov3_loss ArgSpec(args=['x', 'gtbox', 'gtlabel', 'gtscore', 'anchors', 'anchor_mask', 'class_num', 'ignore_thresh', 'downsample', 'label_smooth', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.multiclass_nms ArgSpec(args=['bboxes', 'scores', 'score_threshold', 'nms_top_k', 'keep_top_k', 'nms_threshold', 'normalized', 'nms_eta', 'background_label', 'name'], varargs=None, keywords=None, defaults=(0.3, True, 1.0, 0, None)) paddle.fluid.layers.accuracy ArgSpec(args=['input', 'label', 'k', 'correct', 'total'], varargs=None, keywords=None, defaults=(1, None, None)) paddle.fluid.layers.auc ArgSpec(args=['input', 'label', 'curve', 'num_thresholds', 'topk', 'slide_steps'], varargs=None, keywords=None, defaults=('ROC', 4095, 1, 1)) diff --git a/paddle/fluid/operators/yolov3_loss_op.cc b/paddle/fluid/operators/yolov3_loss_op.cc index c146035f9d..0c5426728b 100644 --- a/paddle/fluid/operators/yolov3_loss_op.cc +++ b/paddle/fluid/operators/yolov3_loss_op.cc @@ -46,6 +46,7 @@ class Yolov3LossOp : public framework::OperatorWithKernel { auto anchor_mask = ctx->Attrs().Get>("anchor_mask"); int mask_num = anchor_mask.size(); auto class_num = ctx->Attrs().Get("class_num"); + PADDLE_ENFORCE_EQ(dim_x.size(), 4, "Input(X) should be a 4-D tensor."); PADDLE_ENFORCE_EQ(dim_x[2], dim_x[3], "Input(X) dim[3] and dim[4] should be euqal."); @@ -156,6 +157,8 @@ class Yolov3LossOpMaker : public framework::OpProtoAndCheckerMaker { AddAttr("ignore_thresh", "The ignore threshold to ignore confidence loss.") .SetDefault(0.7); + AddAttr("use_label_smooth", "bool,default True", "use label smooth") + .SetDefault(true); AddComment(R"DOC( This operator generate yolov3 loss by given predict result and ground truth boxes. diff --git a/paddle/fluid/operators/yolov3_loss_op.h b/paddle/fluid/operators/yolov3_loss_op.h index c4095b8ca5..f601651f06 100644 --- a/paddle/fluid/operators/yolov3_loss_op.h +++ b/paddle/fluid/operators/yolov3_loss_op.h @@ -157,11 +157,19 @@ static void CalcBoxLocationLossGrad(T* input_grad, const T loss, const T* input, template static inline void CalcLabelLoss(T* loss, const T* input, const int index, const int label, const T score, - const int class_num, const int stride) { - for (int i = 0; i < class_num; i++) { - T pred = input[index + i * stride] < -0.5 ? input[index + i * stride] - : 1.0 / class_num; - loss[0] += SCE(pred, (i == label) ? score : 0.0); + const int class_num, const int stride, + const bool use_label_smooth) { + if (use_label_smooth) { + for (int i = 0; i < class_num; i++) { + T pred = input[index + i * stride] < -0.5 ? input[index + i * stride] + : 1.0 / class_num; + loss[0] += SCE(pred, (i == label) ? score : 0.0); + } + } else { + for (int i = 0; i < class_num; i++) { + T pred = input[index + i * stride]; + loss[0] += SCE(pred, (i == label) ? score : 0.0); + } } } @@ -169,12 +177,21 @@ template static inline void CalcLabelLossGrad(T* input_grad, const T loss, const T* input, const int index, const int label, const T score, - const int class_num, const int stride) { - for (int i = 0; i < class_num; i++) { - T pred = input[index + i * stride] < -0.5 ? input[index + i * stride] - : 1.0 / class_num; - input_grad[index + i * stride] = - SCEGrad(pred, (i == label) ? score : 0.0) * loss; + const int class_num, const int stride, + const bool use_label_smooth) { + if (use_label_smooth) { + for (int i = 0; i < class_num; i++) { + T pred = input[index + i * stride] < -0.5 ? input[index + i * stride] + : 1.0 / class_num; + input_grad[index + i * stride] = + SCEGrad(pred, (i == label) ? score : 0.0) * loss; + } + } else { + for (int i = 0; i < class_num; i++) { + T pred = input[index + i * stride]; + input_grad[index + i * stride] = + SCEGrad(pred, (i == label) ? score : 0.0) * loss; + } } } @@ -255,6 +272,7 @@ class Yolov3LossKernel : public framework::OpKernel { int class_num = ctx.Attr("class_num"); float ignore_thresh = ctx.Attr("ignore_thresh"); int downsample = ctx.Attr("downsample"); + bool use_label_smooth = ctx.Attr("use_label_smooth"); const int n = input->dims()[0]; const int h = input->dims()[2]; @@ -364,7 +382,7 @@ class Yolov3LossKernel : public framework::OpKernel { int label_idx = GetEntryIndex(i, mask_idx, gj * w + gi, mask_num, an_stride, stride, 5); CalcLabelLoss(loss_data + i, input_data, label_idx, label, score, - class_num, stride); + class_num, stride, use_label_smooth); } } } @@ -390,6 +408,7 @@ class Yolov3LossGradKernel : public framework::OpKernel { auto anchor_mask = ctx.Attr>("anchor_mask"); int class_num = ctx.Attr("class_num"); int downsample = ctx.Attr("downsample"); + bool use_label_smooth = ctx.Attr("use_label_smooth"); const int n = input_grad->dims()[0]; const int c = input_grad->dims()[1]; @@ -432,7 +451,8 @@ class Yolov3LossGradKernel : public framework::OpKernel { int label_idx = GetEntryIndex(i, mask_idx, gj * w + gi, mask_num, an_stride, stride, 5); CalcLabelLossGrad(input_grad_data, loss_grad_data[i], input_data, - label_idx, label, score, class_num, stride); + label_idx, label, score, class_num, stride, + use_label_smooth); } } } diff --git a/python/paddle/fluid/layers/detection.py b/python/paddle/fluid/layers/detection.py index 10573cc4c6..e984576ffe 100644 --- a/python/paddle/fluid/layers/detection.py +++ b/python/paddle/fluid/layers/detection.py @@ -418,6 +418,7 @@ def yolov3_loss(x, class_num, ignore_thresh, downsample, + use_label_smooth=True, name=None): """ ${comment} @@ -438,6 +439,7 @@ def yolov3_loss(x, class_num (int): ${class_num_comment} ignore_thresh (float): ${ignore_thresh_comment} downsample (int): ${downsample_comment} + use_label_smooth(bool): ${use_label_smooth_comment} name (string): the name of yolov3 loss Returns: @@ -451,6 +453,7 @@ def yolov3_loss(x, TypeError: Attr anchors of yolov3_loss must be list or tuple TypeError: Attr class_num of yolov3_loss must be an integer TypeError: Attr ignore_thresh of yolov3_loss must be a float number + TypeError: Attr use_label_smooth of yolov3_loss must be a bool value Examples: .. code-block:: python @@ -479,6 +482,8 @@ def yolov3_loss(x, raise TypeError("Attr anchor_mask of yolov3_loss must be list or tuple") if not isinstance(class_num, int): raise TypeError("Attr class_num of yolov3_loss must be an integer") + if not isinstance(class_num, int): + raise TypeError("Attr ues_label_smooth of yolov3 must be a bool value") if not isinstance(ignore_thresh, float): raise TypeError( "Attr ignore_thresh of yolov3_loss must be a float number") @@ -498,6 +503,7 @@ def yolov3_loss(x, "class_num": class_num, "ignore_thresh": ignore_thresh, "downsample": downsample, + "use_label_smooth": use_label_smooth } helper.append_op( diff --git a/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py b/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py index c65570d7c1..1746a1da1d 100644 --- a/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py +++ b/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py @@ -76,6 +76,7 @@ def YOLOv3Loss(x, gtbox, gtlabel, gtscore, attrs): class_num = attrs["class_num"] ignore_thresh = attrs['ignore_thresh'] downsample = attrs['downsample'] + #use_label_smooth = attrs['use_label_smooth'] input_size = downsample * h x = x.reshape((n, mask_num, 5 + class_num, h, w)).transpose((0, 1, 3, 4, 2)) loss = np.zeros((n)).astype('float32') @@ -176,6 +177,7 @@ class TestYolov3LossOp(OpTest): "class_num": self.class_num, "ignore_thresh": self.ignore_thresh, "downsample": self.downsample, + "use_label_smooth": self.use_label_smooth, } self.inputs = { @@ -215,6 +217,12 @@ class TestYolov3LossOp(OpTest): self.downsample = 32 self.x_shape = (3, len(self.anchor_mask) * (5 + self.class_num), 5, 5) self.gtbox_shape = (3, 10, 4) + self.use_label_smooth = True + + +class TestYolov3LossWithLabelSmooth(TestYolov3LossOp): + def set_label_smooth(self): + self.use_label_smooth = True if __name__ == "__main__": From 20200e126d0bfcc9e98e278764768f38ff1831e8 Mon Sep 17 00:00:00 2001 From: tink2123 Date: Thu, 10 Jan 2019 07:15:35 +0000 Subject: [PATCH 16/24] fix some typo test=develop --- python/paddle/fluid/layers/detection.py | 2 +- python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/python/paddle/fluid/layers/detection.py b/python/paddle/fluid/layers/detection.py index e984576ffe..febfc8e127 100644 --- a/python/paddle/fluid/layers/detection.py +++ b/python/paddle/fluid/layers/detection.py @@ -482,7 +482,7 @@ def yolov3_loss(x, raise TypeError("Attr anchor_mask of yolov3_loss must be list or tuple") if not isinstance(class_num, int): raise TypeError("Attr class_num of yolov3_loss must be an integer") - if not isinstance(class_num, int): + if not isinstance(use_label_smooth, int): raise TypeError("Attr ues_label_smooth of yolov3 must be a bool value") if not isinstance(ignore_thresh, float): raise TypeError( diff --git a/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py b/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py index 1746a1da1d..79c953bbd1 100644 --- a/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py +++ b/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py @@ -76,7 +76,7 @@ def YOLOv3Loss(x, gtbox, gtlabel, gtscore, attrs): class_num = attrs["class_num"] ignore_thresh = attrs['ignore_thresh'] downsample = attrs['downsample'] - #use_label_smooth = attrs['use_label_smooth'] + use_label_smooth = attrs['use_label_smooth'] input_size = downsample * h x = x.reshape((n, mask_num, 5 + class_num, h, w)).transpose((0, 1, 3, 4, 2)) loss = np.zeros((n)).astype('float32') From c945ffa7f8949277e1053c430918147d9e908303 Mon Sep 17 00:00:00 2001 From: dengkaipeng Date: Mon, 14 Jan 2019 21:16:06 +0800 Subject: [PATCH 17/24] fix label_smooth and mixup score --- paddle/fluid/operators/yolov3_loss_op.h | 98 +++++++++---------- .../tests/unittests/test_yolov3_loss_op.py | 17 ++-- 2 files changed, 55 insertions(+), 60 deletions(-) diff --git a/paddle/fluid/operators/yolov3_loss_op.h b/paddle/fluid/operators/yolov3_loss_op.h index f601651f06..5cb48b7cdf 100644 --- a/paddle/fluid/operators/yolov3_loss_op.h +++ b/paddle/fluid/operators/yolov3_loss_op.h @@ -156,47 +156,29 @@ static void CalcBoxLocationLossGrad(T* input_grad, const T loss, const T* input, template static inline void CalcLabelLoss(T* loss, const T* input, const int index, - const int label, const T score, - const int class_num, const int stride, - const bool use_label_smooth) { - if (use_label_smooth) { - for (int i = 0; i < class_num; i++) { - T pred = input[index + i * stride] < -0.5 ? input[index + i * stride] - : 1.0 / class_num; - loss[0] += SCE(pred, (i == label) ? score : 0.0); - } - } else { - for (int i = 0; i < class_num; i++) { - T pred = input[index + i * stride]; - loss[0] += SCE(pred, (i == label) ? score : 0.0); - } + const int label, const int class_num, + const int stride, const T pos, const T neg) { + for (int i = 0; i < class_num; i++) { + T pred = input[index + i * stride]; + loss[0] += SCE(pred, (i == label) ? pos : neg); } } template static inline void CalcLabelLossGrad(T* input_grad, const T loss, const T* input, const int index, - const int label, const T score, - const int class_num, const int stride, - const bool use_label_smooth) { - if (use_label_smooth) { - for (int i = 0; i < class_num; i++) { - T pred = input[index + i * stride] < -0.5 ? input[index + i * stride] - : 1.0 / class_num; - input_grad[index + i * stride] = - SCEGrad(pred, (i == label) ? score : 0.0) * loss; - } - } else { - for (int i = 0; i < class_num; i++) { - T pred = input[index + i * stride]; - input_grad[index + i * stride] = - SCEGrad(pred, (i == label) ? score : 0.0) * loss; - } + const int label, const int class_num, + const int stride, const T pos, + const T neg) { + for (int i = 0; i < class_num; i++) { + T pred = input[index + i * stride]; + input_grad[index + i * stride] = + SCEGrad(pred, (i == label) ? pos : neg) * loss; } } template -static inline void CalcObjnessLoss(T* loss, const T* input, const int* objness, +static inline void CalcObjnessLoss(T* loss, const T* input, const T* objness, const int n, const int an_num, const int h, const int w, const int stride, const int an_stride) { @@ -204,9 +186,9 @@ static inline void CalcObjnessLoss(T* loss, const T* input, const int* objness, for (int j = 0; j < an_num; j++) { for (int k = 0; k < h; k++) { for (int l = 0; l < w; l++) { - int obj = objness[k * w + l]; - if (obj >= 0) { - loss[i] += SCE(input[k * w + l], static_cast(obj)); + T obj = objness[k * w + l]; + if (obj > -0.5) { + loss[i] += SCE(input[k * w + l], obj); } } } @@ -218,7 +200,7 @@ static inline void CalcObjnessLoss(T* loss, const T* input, const int* objness, template static inline void CalcObjnessLossGrad(T* input_grad, const T* loss, - const T* input, const int* objness, + const T* input, const T* objness, const int n, const int an_num, const int h, const int w, const int stride, const int an_stride) { @@ -226,10 +208,9 @@ static inline void CalcObjnessLossGrad(T* input_grad, const T* loss, for (int j = 0; j < an_num; j++) { for (int k = 0; k < h; k++) { for (int l = 0; l < w; l++) { - int obj = objness[k * w + l]; - if (obj >= 0) { - input_grad[k * w + l] = - SCEGrad(input[k * w + l], static_cast(obj)) * loss[i]; + T obj = objness[k * w + l]; + if (obj > -0.5) { + input_grad[k * w + l] = SCEGrad(input[k * w + l], obj) * loss[i]; } } } @@ -285,15 +266,22 @@ class Yolov3LossKernel : public framework::OpKernel { const int stride = h * w; const int an_stride = (class_num + 5) * stride; + T label_pos = 1.0; + T label_neg = 0.0; + if (use_label_smooth) { + label_pos = 1.0 - 1.0 / static_cast(class_num); + label_neg = 1.0 / static_cast(class_num); + } + const T* input_data = input->data(); const T* gt_box_data = gt_box->data(); const int* gt_label_data = gt_label->data(); const T* gt_score_data = gt_score->data(); T* loss_data = loss->mutable_data({n}, ctx.GetPlace()); memset(loss_data, 0, loss->numel() * sizeof(T)); - int* obj_mask_data = - objness_mask->mutable_data({n, mask_num, h, w}, ctx.GetPlace()); - memset(obj_mask_data, 0, objness_mask->numel() * sizeof(int)); + T* obj_mask_data = + objness_mask->mutable_data({n, mask_num, h, w}, ctx.GetPlace()); + memset(obj_mask_data, 0, objness_mask->numel() * sizeof(T)); int* gt_match_mask_data = gt_match_mask->mutable_data({n, b}, ctx.GetPlace()); @@ -327,7 +315,7 @@ class Yolov3LossKernel : public framework::OpKernel { if (best_iou > ignore_thresh) { int obj_idx = (i * mask_num + j) * stride + k * w + l; - obj_mask_data[obj_idx] = -1; + obj_mask_data[obj_idx] = static_cast(-1.0); } // TODO(dengkaipeng): all losses should be calculated if best IoU // is bigger then truth thresh should be calculated here, but @@ -374,15 +362,15 @@ class Yolov3LossKernel : public framework::OpKernel { CalcBoxLocationLoss(loss_data + i, input_data, gt, anchors, best_n, box_idx, gi, gj, h, input_size, stride); + T score = gt_score_data[i * b + t]; int obj_idx = (i * mask_num + mask_idx) * stride + gj * w + gi; - obj_mask_data[obj_idx] = 1; + obj_mask_data[obj_idx] = score; int label = gt_label_data[i * b + t]; - T score = gt_score_data[i * b + t]; int label_idx = GetEntryIndex(i, mask_idx, gj * w + gi, mask_num, an_stride, stride, 5); - CalcLabelLoss(loss_data + i, input_data, label_idx, label, score, - class_num, stride, use_label_smooth); + CalcLabelLoss(loss_data + i, input_data, label_idx, label, + class_num, stride, label_pos, label_neg); } } } @@ -399,7 +387,6 @@ class Yolov3LossGradKernel : public framework::OpKernel { auto* input = ctx.Input("X"); auto* gt_box = ctx.Input("GTBox"); auto* gt_label = ctx.Input("GTLabel"); - auto* gt_score = ctx.Input("GTScore"); auto* input_grad = ctx.Output(framework::GradVarName("X")); auto* loss_grad = ctx.Input(framework::GradVarName("Loss")); auto* objness_mask = ctx.Input("ObjectnessMask"); @@ -421,12 +408,18 @@ class Yolov3LossGradKernel : public framework::OpKernel { const int stride = h * w; const int an_stride = (class_num + 5) * stride; + T label_pos = 1.0; + T label_neg = 0.0; + if (use_label_smooth) { + label_pos = 1.0 - 1.0 / static_cast(class_num); + label_neg = 1.0 / static_cast(class_num); + } + const T* input_data = input->data(); const T* gt_box_data = gt_box->data(); const int* gt_label_data = gt_label->data(); - const T* gt_score_data = gt_score->data(); const T* loss_grad_data = loss_grad->data(); - const int* obj_mask_data = objness_mask->data(); + const T* obj_mask_data = objness_mask->data(); const int* gt_match_mask_data = gt_match_mask->data(); T* input_grad_data = input_grad->mutable_data({n, c, h, w}, ctx.GetPlace()); @@ -447,12 +440,11 @@ class Yolov3LossGradKernel : public framework::OpKernel { anchor_mask[mask_idx], box_idx, gi, gj, h, input_size, stride); int label = gt_label_data[i * b + t]; - T score = gt_score_data[i * b + t]; int label_idx = GetEntryIndex(i, mask_idx, gj * w + gi, mask_num, an_stride, stride, 5); CalcLabelLossGrad(input_grad_data, loss_grad_data[i], input_data, - label_idx, label, score, class_num, stride, - use_label_smooth); + label_idx, label, class_num, stride, label_pos, + label_neg); } } } diff --git a/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py b/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py index 79c953bbd1..426a64f7a2 100644 --- a/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py +++ b/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py @@ -81,6 +81,9 @@ def YOLOv3Loss(x, gtbox, gtlabel, gtscore, attrs): x = x.reshape((n, mask_num, 5 + class_num, h, w)).transpose((0, 1, 3, 4, 2)) loss = np.zeros((n)).astype('float32') + label_pos = 1.0 - 1.0 / class_num if use_label_smooth else 1.0 + label_neg = 1.0 / class_num if use_label_smooth else 0.0 + pred_box = x[:, :, :, :, :4].copy() grid_x = np.tile(np.arange(w).reshape((1, w)), (h, 1)) grid_y = np.tile(np.arange(h).reshape((h, 1)), (1, w)) @@ -103,7 +106,7 @@ def YOLOv3Loss(x, gtbox, gtlabel, gtscore, attrs): pred_box = pred_box.reshape((n, -1, 4)) pred_obj = x[:, :, :, :, 4].reshape((n, -1)) - objness = np.zeros(pred_box.shape[:2]) + objness = np.zeros(pred_box.shape[:2]).astype('float32') ious = batch_xywh_box_iou(pred_box, gtbox) ious_max = np.max(ious, axis=-1) objness = np.where(ious_max > ignore_thresh, -np.ones_like(objness), @@ -145,17 +148,17 @@ def YOLOv3Loss(x, gtbox, gtlabel, gtscore, attrs): loss[i] += l1loss(x[i, an_idx, gj, gi, 2], tw) * scale loss[i] += l1loss(x[i, an_idx, gj, gi, 3], th) * scale - objness[i, an_idx * h * w + gj * w + gi] = 1 + objness[i, an_idx * h * w + gj * w + gi] = gtscore[i, j] for label_idx in range(class_num): - loss[i] += sce(x[i, an_idx, gj, gi, 5 + label_idx], - int(label_idx == gtlabel[i, j]) * gtscore[i, j]) + loss[i] += sce(x[i, an_idx, gj, gi, 5 + label_idx], label_pos + if label_idx == gtlabel[i, j] else label_neg) for j in range(mask_num * h * w): if objness[i, j] >= 0: loss[i] += sce(pred_obj[i, j], objness[i, j]) - return (loss, objness.reshape((n, mask_num, h, w)).astype('int32'), \ + return (loss, objness.reshape((n, mask_num, h, w)).astype('float32'), \ gt_matches.astype('int32')) @@ -220,9 +223,9 @@ class TestYolov3LossOp(OpTest): self.use_label_smooth = True -class TestYolov3LossWithLabelSmooth(TestYolov3LossOp): +class TestYolov3LossWithoutLabelSmooth(TestYolov3LossOp): def set_label_smooth(self): - self.use_label_smooth = True + self.use_label_smooth = False if __name__ == "__main__": From af124dcdf6891390202fffb7c30daf70aa3c8659 Mon Sep 17 00:00:00 2001 From: dengkaipeng Date: Mon, 14 Jan 2019 21:30:25 +0800 Subject: [PATCH 18/24] fix API error --- paddle/fluid/API.spec | 2 +- paddle/fluid/operators/yolov3_loss_op.h | 55 ++++++++++++------- python/paddle/fluid/layers/detection.py | 2 +- .../tests/unittests/test_yolov3_loss_op.py | 11 ++-- 4 files changed, 43 insertions(+), 27 deletions(-) diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index d773c2518c..e71e494f9d 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -324,7 +324,7 @@ paddle.fluid.layers.generate_mask_labels ArgSpec(args=['im_info', 'gt_classes', paddle.fluid.layers.iou_similarity ArgSpec(args=['x', 'y', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.box_coder ArgSpec(args=['prior_box', 'prior_box_var', 'target_box', 'code_type', 'box_normalized', 'name'], varargs=None, keywords=None, defaults=('encode_center_size', True, None)) paddle.fluid.layers.polygon_box_transform ArgSpec(args=['input', 'name'], varargs=None, keywords=None, defaults=(None,)) -paddle.fluid.layers.yolov3_loss ArgSpec(args=['x', 'gtbox', 'gtlabel', 'gtscore', 'anchors', 'anchor_mask', 'class_num', 'ignore_thresh', 'downsample', 'label_smooth', 'name'], varargs=None, keywords=None, defaults=(None,)) +paddle.fluid.layers.yolov3_loss ArgSpec(args=['x', 'gtbox', 'gtlabel', 'gtscore', 'anchors', 'anchor_mask', 'class_num', 'ignore_thresh', 'downsample', 'use_label_smooth', 'name'], varargs=None, keywords=None, defaults=(True, None,)) paddle.fluid.layers.multiclass_nms ArgSpec(args=['bboxes', 'scores', 'score_threshold', 'nms_top_k', 'keep_top_k', 'nms_threshold', 'normalized', 'nms_eta', 'background_label', 'name'], varargs=None, keywords=None, defaults=(0.3, True, 1.0, 0, None)) paddle.fluid.layers.accuracy ArgSpec(args=['input', 'label', 'k', 'correct', 'total'], varargs=None, keywords=None, defaults=(1, None, None)) paddle.fluid.layers.auc ArgSpec(args=['input', 'label', 'curve', 'num_thresholds', 'topk', 'slide_steps'], varargs=None, keywords=None, defaults=('ROC', 4095, 1, 1)) diff --git a/paddle/fluid/operators/yolov3_loss_op.h b/paddle/fluid/operators/yolov3_loss_op.h index 5cb48b7cdf..de01a01a4f 100644 --- a/paddle/fluid/operators/yolov3_loss_op.h +++ b/paddle/fluid/operators/yolov3_loss_op.h @@ -121,13 +121,13 @@ template static void CalcBoxLocationLoss(T* loss, const T* input, Box gt, std::vector anchors, int an_idx, int box_idx, int gi, int gj, int grid_size, - int input_size, int stride) { + int input_size, int stride, T score) { T tx = gt.x * grid_size - gi; T ty = gt.y * grid_size - gj; T tw = std::log(gt.w * input_size / anchors[2 * an_idx]); T th = std::log(gt.h * input_size / anchors[2 * an_idx + 1]); - T scale = 2.0 - gt.w * gt.h; + T scale = (2.0 - gt.w * gt.h) * score; loss[0] += SCE(input[box_idx], tx) * scale; loss[0] += SCE(input[box_idx + stride], ty) * scale; loss[0] += L1Loss(input[box_idx + 2 * stride], tw) * scale; @@ -138,13 +138,14 @@ template static void CalcBoxLocationLossGrad(T* input_grad, const T loss, const T* input, Box gt, std::vector anchors, int an_idx, int box_idx, int gi, int gj, - int grid_size, int input_size, int stride) { + int grid_size, int input_size, int stride, + T score) { T tx = gt.x * grid_size - gi; T ty = gt.y * grid_size - gj; T tw = std::log(gt.w * input_size / anchors[2 * an_idx]); T th = std::log(gt.h * input_size / anchors[2 * an_idx + 1]); - T scale = 2.0 - gt.w * gt.h; + T scale = (2.0 - gt.w * gt.h) * score; input_grad[box_idx] = SCEGrad(input[box_idx], tx) * scale * loss; input_grad[box_idx + stride] = SCEGrad(input[box_idx + stride], ty) * scale * loss; @@ -157,10 +158,11 @@ static void CalcBoxLocationLossGrad(T* input_grad, const T loss, const T* input, template static inline void CalcLabelLoss(T* loss, const T* input, const int index, const int label, const int class_num, - const int stride, const T pos, const T neg) { + const int stride, const T pos, const T neg, + T score) { for (int i = 0; i < class_num; i++) { T pred = input[index + i * stride]; - loss[0] += SCE(pred, (i == label) ? pos : neg); + loss[0] += SCE(pred, (i == label) ? pos : neg) * score; } } @@ -168,12 +170,12 @@ template static inline void CalcLabelLossGrad(T* input_grad, const T loss, const T* input, const int index, const int label, const int class_num, - const int stride, const T pos, - const T neg) { + const int stride, const T pos, const T neg, + T score) { for (int i = 0; i < class_num; i++) { T pred = input[index + i * stride]; input_grad[index + i * stride] = - SCEGrad(pred, (i == label) ? pos : neg) * loss; + SCEGrad(pred, (i == label) ? pos : neg) * score * loss; } } @@ -187,8 +189,12 @@ static inline void CalcObjnessLoss(T* loss, const T* input, const T* objness, for (int k = 0; k < h; k++) { for (int l = 0; l < w; l++) { T obj = objness[k * w + l]; - if (obj > -0.5) { - loss[i] += SCE(input[k * w + l], obj); + if (obj > 1e-5) { + // positive sample: obj = mixup score + loss[i] += SCE(input[k * w + l], 1.0) * obj; + } else if (obj > -0.5) { + // negetive sample: obj = 0 + loss[i] += SCE(input[k * w + l], 0.0); } } } @@ -209,8 +215,11 @@ static inline void CalcObjnessLossGrad(T* input_grad, const T* loss, for (int k = 0; k < h; k++) { for (int l = 0; l < w; l++) { T obj = objness[k * w + l]; - if (obj > -0.5) { - input_grad[k * w + l] = SCEGrad(input[k * w + l], obj) * loss[i]; + if (obj > 1e-5) { + input_grad[k * w + l] = + SCEGrad(input[k * w + l], 1.0) * obj * loss[i]; + } else if (obj > -0.5) { + input_grad[k * w + l] = SCEGrad(input[k * w + l], 0.0) * loss[i]; } } } @@ -315,7 +324,7 @@ class Yolov3LossKernel : public framework::OpKernel { if (best_iou > ignore_thresh) { int obj_idx = (i * mask_num + j) * stride + k * w + l; - obj_mask_data[obj_idx] = static_cast(-1.0); + obj_mask_data[obj_idx] = static_cast(-1); } // TODO(dengkaipeng): all losses should be calculated if best IoU // is bigger then truth thresh should be calculated here, but @@ -357,12 +366,12 @@ class Yolov3LossKernel : public framework::OpKernel { int mask_idx = GetMaskIndex(anchor_mask, best_n); gt_match_mask_data[i * b + t] = mask_idx; if (mask_idx >= 0) { + T score = gt_score_data[i * b + t]; int box_idx = GetEntryIndex(i, mask_idx, gj * w + gi, mask_num, an_stride, stride, 0); CalcBoxLocationLoss(loss_data + i, input_data, gt, anchors, best_n, - box_idx, gi, gj, h, input_size, stride); + box_idx, gi, gj, h, input_size, stride, score); - T score = gt_score_data[i * b + t]; int obj_idx = (i * mask_num + mask_idx) * stride + gj * w + gi; obj_mask_data[obj_idx] = score; @@ -370,7 +379,7 @@ class Yolov3LossKernel : public framework::OpKernel { int label_idx = GetEntryIndex(i, mask_idx, gj * w + gi, mask_num, an_stride, stride, 5); CalcLabelLoss(loss_data + i, input_data, label_idx, label, - class_num, stride, label_pos, label_neg); + class_num, stride, label_pos, label_neg, score); } } } @@ -387,6 +396,7 @@ class Yolov3LossGradKernel : public framework::OpKernel { auto* input = ctx.Input("X"); auto* gt_box = ctx.Input("GTBox"); auto* gt_label = ctx.Input("GTLabel"); + auto* gt_score = ctx.Input("GTScore"); auto* input_grad = ctx.Output(framework::GradVarName("X")); auto* loss_grad = ctx.Input(framework::GradVarName("Loss")); auto* objness_mask = ctx.Input("ObjectnessMask"); @@ -418,6 +428,7 @@ class Yolov3LossGradKernel : public framework::OpKernel { const T* input_data = input->data(); const T* gt_box_data = gt_box->data(); const int* gt_label_data = gt_label->data(); + const T* gt_score_data = gt_score->data(); const T* loss_grad_data = loss_grad->data(); const T* obj_mask_data = objness_mask->data(); const int* gt_match_mask_data = gt_match_mask->data(); @@ -429,22 +440,24 @@ class Yolov3LossGradKernel : public framework::OpKernel { for (int t = 0; t < b; t++) { int mask_idx = gt_match_mask_data[i * b + t]; if (mask_idx >= 0) { + T score = gt_score_data[i * b + t]; Box gt = GetGtBox(gt_box_data, i, b, t); int gi = static_cast(gt.x * w); int gj = static_cast(gt.y * h); int box_idx = GetEntryIndex(i, mask_idx, gj * w + gi, mask_num, an_stride, stride, 0); - CalcBoxLocationLossGrad( - input_grad_data, loss_grad_data[i], input_data, gt, anchors, - anchor_mask[mask_idx], box_idx, gi, gj, h, input_size, stride); + CalcBoxLocationLossGrad(input_grad_data, loss_grad_data[i], + input_data, gt, anchors, + anchor_mask[mask_idx], box_idx, gi, gj, h, + input_size, stride, score); int label = gt_label_data[i * b + t]; int label_idx = GetEntryIndex(i, mask_idx, gj * w + gi, mask_num, an_stride, stride, 5); CalcLabelLossGrad(input_grad_data, loss_grad_data[i], input_data, label_idx, label, class_num, stride, label_pos, - label_neg); + label_neg, score); } } } diff --git a/python/paddle/fluid/layers/detection.py b/python/paddle/fluid/layers/detection.py index febfc8e127..07df601697 100644 --- a/python/paddle/fluid/layers/detection.py +++ b/python/paddle/fluid/layers/detection.py @@ -482,7 +482,7 @@ def yolov3_loss(x, raise TypeError("Attr anchor_mask of yolov3_loss must be list or tuple") if not isinstance(class_num, int): raise TypeError("Attr class_num of yolov3_loss must be an integer") - if not isinstance(use_label_smooth, int): + if not isinstance(use_label_smooth, bool): raise TypeError("Attr ues_label_smooth of yolov3 must be a bool value") if not isinstance(ignore_thresh, float): raise TypeError( diff --git a/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py b/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py index 426a64f7a2..ff76b76366 100644 --- a/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py +++ b/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py @@ -142,7 +142,7 @@ def YOLOv3Loss(x, gtbox, gtlabel, gtscore, attrs): ty = gtbox[i, j, 1] * w - gj tw = np.log(gtbox[i, j, 2] * input_size / mask_anchors[an_idx][0]) th = np.log(gtbox[i, j, 3] * input_size / mask_anchors[an_idx][1]) - scale = 2.0 - gtbox[i, j, 2] * gtbox[i, j, 3] + scale = (2.0 - gtbox[i, j, 2] * gtbox[i, j, 3]) * gtscore[i, j] loss[i] += sce(x[i, an_idx, gj, gi, 0], tx) * scale loss[i] += sce(x[i, an_idx, gj, gi, 1], ty) * scale loss[i] += l1loss(x[i, an_idx, gj, gi, 2], tw) * scale @@ -152,11 +152,14 @@ def YOLOv3Loss(x, gtbox, gtlabel, gtscore, attrs): for label_idx in range(class_num): loss[i] += sce(x[i, an_idx, gj, gi, 5 + label_idx], label_pos - if label_idx == gtlabel[i, j] else label_neg) + if label_idx == gtlabel[i, j] else + label_neg) * gtscore[i, j] for j in range(mask_num * h * w): - if objness[i, j] >= 0: - loss[i] += sce(pred_obj[i, j], objness[i, j]) + if objness[i, j] > 0: + loss[i] += sce(pred_obj[i, j], 1.0) * objness[i, j] + elif objness[i, j] == 0: + loss[i] += sce(pred_obj[i, j], 0.0) return (loss, objness.reshape((n, mask_num, h, w)).astype('float32'), \ gt_matches.astype('int32')) From 042fecefab41a61fdf5f83913b96a039f75b15c5 Mon Sep 17 00:00:00 2001 From: dengkaipeng Date: Mon, 21 Jan 2019 15:04:26 +0800 Subject: [PATCH 19/24] use L2Loss. test=develop --- paddle/fluid/operators/yolov3_loss_op.h | 18 ++++++++++--- .../tests/unittests/test_yolov3_loss_op.py | 25 ++++++++++--------- 2 files changed, 27 insertions(+), 16 deletions(-) diff --git a/paddle/fluid/operators/yolov3_loss_op.h b/paddle/fluid/operators/yolov3_loss_op.h index de01a01a4f..2131289860 100644 --- a/paddle/fluid/operators/yolov3_loss_op.h +++ b/paddle/fluid/operators/yolov3_loss_op.h @@ -41,6 +41,11 @@ static T L1Loss(T x, T y) { return std::abs(y - x); } +template +static T L2Loss(T x, T y) { + return 0.5 * (y - x) * (y - x); +} + template static T SCEGrad(T x, T label) { return 1.0 / (1.0 + std::exp(-x)) - label; @@ -51,6 +56,11 @@ static T L1LossGrad(T x, T y) { return x > y ? 1.0 : -1.0; } +template +static T L2LossGrad(T x, T y) { + return x - y; +} + static int GetMaskIndex(std::vector mask, int val) { for (size_t i = 0; i < mask.size(); i++) { if (mask[i] == val) { @@ -130,8 +140,8 @@ static void CalcBoxLocationLoss(T* loss, const T* input, Box gt, T scale = (2.0 - gt.w * gt.h) * score; loss[0] += SCE(input[box_idx], tx) * scale; loss[0] += SCE(input[box_idx + stride], ty) * scale; - loss[0] += L1Loss(input[box_idx + 2 * stride], tw) * scale; - loss[0] += L1Loss(input[box_idx + 3 * stride], th) * scale; + loss[0] += L2Loss(input[box_idx + 2 * stride], tw) * scale; + loss[0] += L2Loss(input[box_idx + 3 * stride], th) * scale; } template @@ -150,9 +160,9 @@ static void CalcBoxLocationLossGrad(T* input_grad, const T loss, const T* input, input_grad[box_idx + stride] = SCEGrad(input[box_idx + stride], ty) * scale * loss; input_grad[box_idx + 2 * stride] = - L1LossGrad(input[box_idx + 2 * stride], tw) * scale * loss; + L2LossGrad(input[box_idx + 2 * stride], tw) * scale * loss; input_grad[box_idx + 3 * stride] = - L1LossGrad(input[box_idx + 3 * stride], th) * scale * loss; + L2LossGrad(input[box_idx + 3 * stride], th) * scale * loss; } template diff --git a/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py b/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py index ff76b76366..0e17eb3130 100644 --- a/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py +++ b/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py @@ -27,6 +27,10 @@ def l1loss(x, y): return abs(x - y) +def l2loss(x, y): + return 0.5 * (y - x) * (y - x) + + def sce(x, label): sigmoid_x = expit(x) term1 = label * np.log(sigmoid_x) @@ -145,8 +149,8 @@ def YOLOv3Loss(x, gtbox, gtlabel, gtscore, attrs): scale = (2.0 - gtbox[i, j, 2] * gtbox[i, j, 3]) * gtscore[i, j] loss[i] += sce(x[i, an_idx, gj, gi, 0], tx) * scale loss[i] += sce(x[i, an_idx, gj, gi, 1], ty) * scale - loss[i] += l1loss(x[i, an_idx, gj, gi, 2], tw) * scale - loss[i] += l1loss(x[i, an_idx, gj, gi, 3], th) * scale + loss[i] += l2loss(x[i, an_idx, gj, gi, 2], tw) * scale + loss[i] += l2loss(x[i, an_idx, gj, gi, 3], th) * scale objness[i, an_idx * h * w + gj * w + gi] = gtscore[i, j] @@ -202,7 +206,7 @@ class TestYolov3LossOp(OpTest): def test_check_output(self): place = core.CPUPlace() - self.check_output_with_place(place, atol=2e-3) + self.check_output_with_place(place, atol=1e-3) def test_check_grad_ignore_gtbox(self): place = core.CPUPlace() @@ -210,19 +214,16 @@ class TestYolov3LossOp(OpTest): place, ['X'], 'Loss', no_grad_set=set(["GTBox", "GTLabel", "GTScore"]), - max_relative_error=0.2) + max_relative_error=0.3) def initTestCase(self): - self.anchors = [ - 10, 13, 16, 30, 33, 23, 30, 61, 62, 45, 59, 119, 116, 90, 156, 198, - 373, 326 - ] - self.anchor_mask = [0, 1, 2] - self.class_num = 10 - self.ignore_thresh = 0.7 + self.anchors = [10, 13, 16, 30, 33, 23] + self.anchor_mask = [1, 2] + self.class_num = 5 + self.ignore_thresh = 0.5 self.downsample = 32 self.x_shape = (3, len(self.anchor_mask) * (5 + self.class_num), 5, 5) - self.gtbox_shape = (3, 10, 4) + self.gtbox_shape = (3, 5, 4) self.use_label_smooth = True From 577424e5ecc47446ee0796794004acf5a5852b19 Mon Sep 17 00:00:00 2001 From: dengkaipeng Date: Mon, 28 Jan 2019 16:53:15 +0800 Subject: [PATCH 20/24] use darknet loss and trick --- paddle/fluid/API.spec | 2 +- paddle/fluid/operators/yolov3_loss_op.cc | 18 ----- paddle/fluid/operators/yolov3_loss_op.h | 72 +++++-------------- python/paddle/fluid/layers/detection.py | 13 ---- .../tests/unittests/test_yolov3_loss_op.py | 35 +++------ 5 files changed, 26 insertions(+), 114 deletions(-) diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index e71e494f9d..6c6ac9c7ea 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -324,7 +324,7 @@ paddle.fluid.layers.generate_mask_labels ArgSpec(args=['im_info', 'gt_classes', paddle.fluid.layers.iou_similarity ArgSpec(args=['x', 'y', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.box_coder ArgSpec(args=['prior_box', 'prior_box_var', 'target_box', 'code_type', 'box_normalized', 'name'], varargs=None, keywords=None, defaults=('encode_center_size', True, None)) paddle.fluid.layers.polygon_box_transform ArgSpec(args=['input', 'name'], varargs=None, keywords=None, defaults=(None,)) -paddle.fluid.layers.yolov3_loss ArgSpec(args=['x', 'gtbox', 'gtlabel', 'gtscore', 'anchors', 'anchor_mask', 'class_num', 'ignore_thresh', 'downsample', 'use_label_smooth', 'name'], varargs=None, keywords=None, defaults=(True, None,)) +paddle.fluid.layers.yolov3_loss ArgSpec(args=['x', 'gtbox', 'gtlabel', 'anchors', 'anchor_mask', 'class_num', 'ignore_thresh', 'downsample', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.multiclass_nms ArgSpec(args=['bboxes', 'scores', 'score_threshold', 'nms_top_k', 'keep_top_k', 'nms_threshold', 'normalized', 'nms_eta', 'background_label', 'name'], varargs=None, keywords=None, defaults=(0.3, True, 1.0, 0, None)) paddle.fluid.layers.accuracy ArgSpec(args=['input', 'label', 'k', 'correct', 'total'], varargs=None, keywords=None, defaults=(1, None, None)) paddle.fluid.layers.auc ArgSpec(args=['input', 'label', 'curve', 'num_thresholds', 'topk', 'slide_steps'], varargs=None, keywords=None, defaults=('ROC', 4095, 1, 1)) diff --git a/paddle/fluid/operators/yolov3_loss_op.cc b/paddle/fluid/operators/yolov3_loss_op.cc index 0c5426728b..46374db49a 100644 --- a/paddle/fluid/operators/yolov3_loss_op.cc +++ b/paddle/fluid/operators/yolov3_loss_op.cc @@ -27,8 +27,6 @@ class Yolov3LossOp : public framework::OperatorWithKernel { "Input(GTBox) of Yolov3LossOp should not be null."); PADDLE_ENFORCE(ctx->HasInput("GTLabel"), "Input(GTLabel) of Yolov3LossOp should not be null."); - PADDLE_ENFORCE(ctx->HasInput("GTScore"), - "Input(GTScore) of Yolov3LossOp should not be null."); PADDLE_ENFORCE(ctx->HasOutput("Loss"), "Output(Loss) of Yolov3LossOp should not be null."); PADDLE_ENFORCE( @@ -40,7 +38,6 @@ class Yolov3LossOp : public framework::OperatorWithKernel { auto dim_x = ctx->GetInputDim("X"); auto dim_gtbox = ctx->GetInputDim("GTBox"); auto dim_gtlabel = ctx->GetInputDim("GTLabel"); - auto dim_gtscore = ctx->GetInputDim("GTScore"); auto anchors = ctx->Attrs().Get>("anchors"); int anchor_num = anchors.size() / 2; auto anchor_mask = ctx->Attrs().Get>("anchor_mask"); @@ -63,12 +60,6 @@ class Yolov3LossOp : public framework::OperatorWithKernel { "Input(GTBox) and Input(GTLabel) dim[0] should be same"); PADDLE_ENFORCE_EQ(dim_gtlabel[1], dim_gtbox[1], "Input(GTBox) and Input(GTLabel) dim[1] should be same"); - PADDLE_ENFORCE_EQ(dim_gtscore.size(), 2, - "Input(GTScore) should be a 2-D tensor"); - PADDLE_ENFORCE_EQ(dim_gtscore[0], dim_gtbox[0], - "Input(GTBox) and Input(GTScore) dim[0] should be same"); - PADDLE_ENFORCE_EQ(dim_gtscore[1], dim_gtbox[1], - "Input(GTBox) and Input(GTScore) dim[1] should be same"); PADDLE_ENFORCE_GT(anchors.size(), 0, "Attr(anchors) length should be greater then 0."); PADDLE_ENFORCE_EQ(anchors.size() % 2, 0, @@ -121,11 +112,6 @@ class Yolov3LossOpMaker : public framework::OpProtoAndCheckerMaker { "This is a 2-D tensor with shape of [N, max_box_num], " "and each element should be an integer to indicate the " "box class id."); - AddInput("GTScore", - "The score of GTLabel, This is a 2-D tensor in same shape " - "GTLabel, and score values should in range (0, 1). This " - "input is for GTLabel score can be not 1.0 in image mixup " - "augmentation."); AddOutput("Loss", "The output yolov3 loss tensor, " "This is a 1-D tensor with shape of [N]"); @@ -157,8 +143,6 @@ class Yolov3LossOpMaker : public framework::OpProtoAndCheckerMaker { AddAttr("ignore_thresh", "The ignore threshold to ignore confidence loss.") .SetDefault(0.7); - AddAttr("use_label_smooth", "bool,default True", "use label smooth") - .SetDefault(true); AddComment(R"DOC( This operator generate yolov3 loss by given predict result and ground truth boxes. @@ -245,7 +229,6 @@ class Yolov3LossGradMaker : public framework::SingleGradOpDescMaker { op->SetInput("X", Input("X")); op->SetInput("GTBox", Input("GTBox")); op->SetInput("GTLabel", Input("GTLabel")); - op->SetInput("GTScore", Input("GTScore")); op->SetInput(framework::GradVarName("Loss"), OutputGrad("Loss")); op->SetInput("ObjectnessMask", Output("ObjectnessMask")); op->SetInput("GTMatchMask", Output("GTMatchMask")); @@ -255,7 +238,6 @@ class Yolov3LossGradMaker : public framework::SingleGradOpDescMaker { op->SetOutput(framework::GradVarName("X"), InputGrad("X")); op->SetOutput(framework::GradVarName("GTBox"), {}); op->SetOutput(framework::GradVarName("GTLabel"), {}); - op->SetOutput(framework::GradVarName("GTScore"), {}); return std::unique_ptr(op); } }; diff --git a/paddle/fluid/operators/yolov3_loss_op.h b/paddle/fluid/operators/yolov3_loss_op.h index 2131289860..5c9851232d 100644 --- a/paddle/fluid/operators/yolov3_loss_op.h +++ b/paddle/fluid/operators/yolov3_loss_op.h @@ -36,11 +36,6 @@ static T SCE(T x, T label) { return (x > 0 ? x : 0.0) - x * label + std::log(1.0 + std::exp(-std::abs(x))); } -template -static T L1Loss(T x, T y) { - return std::abs(y - x); -} - template static T L2Loss(T x, T y) { return 0.5 * (y - x) * (y - x); @@ -51,11 +46,6 @@ static T SCEGrad(T x, T label) { return 1.0 / (1.0 + std::exp(-x)) - label; } -template -static T L1LossGrad(T x, T y) { - return x > y ? 1.0 : -1.0; -} - template static T L2LossGrad(T x, T y) { return x - y; @@ -131,13 +121,13 @@ template static void CalcBoxLocationLoss(T* loss, const T* input, Box gt, std::vector anchors, int an_idx, int box_idx, int gi, int gj, int grid_size, - int input_size, int stride, T score) { + int input_size, int stride) { T tx = gt.x * grid_size - gi; T ty = gt.y * grid_size - gj; T tw = std::log(gt.w * input_size / anchors[2 * an_idx]); T th = std::log(gt.h * input_size / anchors[2 * an_idx + 1]); - T scale = (2.0 - gt.w * gt.h) * score; + T scale = (2.0 - gt.w * gt.h); loss[0] += SCE(input[box_idx], tx) * scale; loss[0] += SCE(input[box_idx + stride], ty) * scale; loss[0] += L2Loss(input[box_idx + 2 * stride], tw) * scale; @@ -148,14 +138,13 @@ template static void CalcBoxLocationLossGrad(T* input_grad, const T loss, const T* input, Box gt, std::vector anchors, int an_idx, int box_idx, int gi, int gj, - int grid_size, int input_size, int stride, - T score) { + int grid_size, int input_size, int stride) { T tx = gt.x * grid_size - gi; T ty = gt.y * grid_size - gj; T tw = std::log(gt.w * input_size / anchors[2 * an_idx]); T th = std::log(gt.h * input_size / anchors[2 * an_idx + 1]); - T scale = (2.0 - gt.w * gt.h) * score; + T scale = (2.0 - gt.w * gt.h); input_grad[box_idx] = SCEGrad(input[box_idx], tx) * scale * loss; input_grad[box_idx + stride] = SCEGrad(input[box_idx + stride], ty) * scale * loss; @@ -168,11 +157,10 @@ static void CalcBoxLocationLossGrad(T* input_grad, const T loss, const T* input, template static inline void CalcLabelLoss(T* loss, const T* input, const int index, const int label, const int class_num, - const int stride, const T pos, const T neg, - T score) { + const int stride) { for (int i = 0; i < class_num; i++) { T pred = input[index + i * stride]; - loss[0] += SCE(pred, (i == label) ? pos : neg) * score; + loss[0] += SCE(pred, (i == label) ? 1.0 : 0.0); } } @@ -180,12 +168,11 @@ template static inline void CalcLabelLossGrad(T* input_grad, const T loss, const T* input, const int index, const int label, const int class_num, - const int stride, const T pos, const T neg, - T score) { + const int stride) { for (int i = 0; i < class_num; i++) { T pred = input[index + i * stride]; input_grad[index + i * stride] = - SCEGrad(pred, (i == label) ? pos : neg) * score * loss; + SCEGrad(pred, (i == label) ? 1.0 : 0.0) * loss; } } @@ -201,7 +188,7 @@ static inline void CalcObjnessLoss(T* loss, const T* input, const T* objness, T obj = objness[k * w + l]; if (obj > 1e-5) { // positive sample: obj = mixup score - loss[i] += SCE(input[k * w + l], 1.0) * obj; + loss[i] += SCE(input[k * w + l], 1.0); } else if (obj > -0.5) { // negetive sample: obj = 0 loss[i] += SCE(input[k * w + l], 0.0); @@ -226,8 +213,7 @@ static inline void CalcObjnessLossGrad(T* input_grad, const T* loss, for (int l = 0; l < w; l++) { T obj = objness[k * w + l]; if (obj > 1e-5) { - input_grad[k * w + l] = - SCEGrad(input[k * w + l], 1.0) * obj * loss[i]; + input_grad[k * w + l] = SCEGrad(input[k * w + l], 1.0) * loss[i]; } else if (obj > -0.5) { input_grad[k * w + l] = SCEGrad(input[k * w + l], 0.0) * loss[i]; } @@ -263,7 +249,6 @@ class Yolov3LossKernel : public framework::OpKernel { auto* input = ctx.Input("X"); auto* gt_box = ctx.Input("GTBox"); auto* gt_label = ctx.Input("GTLabel"); - auto* gt_score = ctx.Input("GTScore"); auto* loss = ctx.Output("Loss"); auto* objness_mask = ctx.Output("ObjectnessMask"); auto* gt_match_mask = ctx.Output("GTMatchMask"); @@ -272,7 +257,6 @@ class Yolov3LossKernel : public framework::OpKernel { int class_num = ctx.Attr("class_num"); float ignore_thresh = ctx.Attr("ignore_thresh"); int downsample = ctx.Attr("downsample"); - bool use_label_smooth = ctx.Attr("use_label_smooth"); const int n = input->dims()[0]; const int h = input->dims()[2]; @@ -285,17 +269,9 @@ class Yolov3LossKernel : public framework::OpKernel { const int stride = h * w; const int an_stride = (class_num + 5) * stride; - T label_pos = 1.0; - T label_neg = 0.0; - if (use_label_smooth) { - label_pos = 1.0 - 1.0 / static_cast(class_num); - label_neg = 1.0 / static_cast(class_num); - } - const T* input_data = input->data(); const T* gt_box_data = gt_box->data(); const int* gt_label_data = gt_label->data(); - const T* gt_score_data = gt_score->data(); T* loss_data = loss->mutable_data({n}, ctx.GetPlace()); memset(loss_data, 0, loss->numel() * sizeof(T)); T* obj_mask_data = @@ -376,20 +352,19 @@ class Yolov3LossKernel : public framework::OpKernel { int mask_idx = GetMaskIndex(anchor_mask, best_n); gt_match_mask_data[i * b + t] = mask_idx; if (mask_idx >= 0) { - T score = gt_score_data[i * b + t]; int box_idx = GetEntryIndex(i, mask_idx, gj * w + gi, mask_num, an_stride, stride, 0); CalcBoxLocationLoss(loss_data + i, input_data, gt, anchors, best_n, - box_idx, gi, gj, h, input_size, stride, score); + box_idx, gi, gj, h, input_size, stride); int obj_idx = (i * mask_num + mask_idx) * stride + gj * w + gi; - obj_mask_data[obj_idx] = score; + obj_mask_data[obj_idx] = 1.0; int label = gt_label_data[i * b + t]; int label_idx = GetEntryIndex(i, mask_idx, gj * w + gi, mask_num, an_stride, stride, 5); CalcLabelLoss(loss_data + i, input_data, label_idx, label, - class_num, stride, label_pos, label_neg, score); + class_num, stride); } } } @@ -406,7 +381,6 @@ class Yolov3LossGradKernel : public framework::OpKernel { auto* input = ctx.Input("X"); auto* gt_box = ctx.Input("GTBox"); auto* gt_label = ctx.Input("GTLabel"); - auto* gt_score = ctx.Input("GTScore"); auto* input_grad = ctx.Output(framework::GradVarName("X")); auto* loss_grad = ctx.Input(framework::GradVarName("Loss")); auto* objness_mask = ctx.Input("ObjectnessMask"); @@ -415,7 +389,6 @@ class Yolov3LossGradKernel : public framework::OpKernel { auto anchor_mask = ctx.Attr>("anchor_mask"); int class_num = ctx.Attr("class_num"); int downsample = ctx.Attr("downsample"); - bool use_label_smooth = ctx.Attr("use_label_smooth"); const int n = input_grad->dims()[0]; const int c = input_grad->dims()[1]; @@ -428,17 +401,9 @@ class Yolov3LossGradKernel : public framework::OpKernel { const int stride = h * w; const int an_stride = (class_num + 5) * stride; - T label_pos = 1.0; - T label_neg = 0.0; - if (use_label_smooth) { - label_pos = 1.0 - 1.0 / static_cast(class_num); - label_neg = 1.0 / static_cast(class_num); - } - const T* input_data = input->data(); const T* gt_box_data = gt_box->data(); const int* gt_label_data = gt_label->data(); - const T* gt_score_data = gt_score->data(); const T* loss_grad_data = loss_grad->data(); const T* obj_mask_data = objness_mask->data(); const int* gt_match_mask_data = gt_match_mask->data(); @@ -450,24 +415,21 @@ class Yolov3LossGradKernel : public framework::OpKernel { for (int t = 0; t < b; t++) { int mask_idx = gt_match_mask_data[i * b + t]; if (mask_idx >= 0) { - T score = gt_score_data[i * b + t]; Box gt = GetGtBox(gt_box_data, i, b, t); int gi = static_cast(gt.x * w); int gj = static_cast(gt.y * h); int box_idx = GetEntryIndex(i, mask_idx, gj * w + gi, mask_num, an_stride, stride, 0); - CalcBoxLocationLossGrad(input_grad_data, loss_grad_data[i], - input_data, gt, anchors, - anchor_mask[mask_idx], box_idx, gi, gj, h, - input_size, stride, score); + CalcBoxLocationLossGrad( + input_grad_data, loss_grad_data[i], input_data, gt, anchors, + anchor_mask[mask_idx], box_idx, gi, gj, h, input_size, stride); int label = gt_label_data[i * b + t]; int label_idx = GetEntryIndex(i, mask_idx, gj * w + gi, mask_num, an_stride, stride, 5); CalcLabelLossGrad(input_grad_data, loss_grad_data[i], input_data, - label_idx, label, class_num, stride, label_pos, - label_neg, score); + label_idx, label, class_num, stride); } } } diff --git a/python/paddle/fluid/layers/detection.py b/python/paddle/fluid/layers/detection.py index 07df601697..ea130bb279 100644 --- a/python/paddle/fluid/layers/detection.py +++ b/python/paddle/fluid/layers/detection.py @@ -412,13 +412,11 @@ def polygon_box_transform(input, name=None): def yolov3_loss(x, gtbox, gtlabel, - gtscore, anchors, anchor_mask, class_num, ignore_thresh, downsample, - use_label_smooth=True, name=None): """ ${comment} @@ -432,14 +430,11 @@ def yolov3_loss(x, an image. gtlabel (Variable): class id of ground truth boxes, shoud be in shape of [N, B]. - gtscore (Variable): score of gtlabel, should be in same shape with gtlabel - and score value in range (0, 1). anchors (list|tuple): ${anchors_comment} anchor_mask (list|tuple): ${anchor_mask_comment} class_num (int): ${class_num_comment} ignore_thresh (float): ${ignore_thresh_comment} downsample (int): ${downsample_comment} - use_label_smooth(bool): ${use_label_smooth_comment} name (string): the name of yolov3 loss Returns: @@ -449,11 +444,9 @@ def yolov3_loss(x, TypeError: Input x of yolov3_loss must be Variable TypeError: Input gtbox of yolov3_loss must be Variable" TypeError: Input gtlabel of yolov3_loss must be Variable" - TypeError: Input gtscore of yolov3_loss must be Variable" TypeError: Attr anchors of yolov3_loss must be list or tuple TypeError: Attr class_num of yolov3_loss must be an integer TypeError: Attr ignore_thresh of yolov3_loss must be a float number - TypeError: Attr use_label_smooth of yolov3_loss must be a bool value Examples: .. code-block:: python @@ -474,16 +467,12 @@ def yolov3_loss(x, raise TypeError("Input gtbox of yolov3_loss must be Variable") if not isinstance(gtlabel, Variable): raise TypeError("Input gtlabel of yolov3_loss must be Variable") - if not isinstance(gtscore, Variable): - raise TypeError("Input gtscore of yolov3_loss must be Variable") if not isinstance(anchors, list) and not isinstance(anchors, tuple): raise TypeError("Attr anchors of yolov3_loss must be list or tuple") if not isinstance(anchor_mask, list) and not isinstance(anchor_mask, tuple): raise TypeError("Attr anchor_mask of yolov3_loss must be list or tuple") if not isinstance(class_num, int): raise TypeError("Attr class_num of yolov3_loss must be an integer") - if not isinstance(use_label_smooth, bool): - raise TypeError("Attr ues_label_smooth of yolov3 must be a bool value") if not isinstance(ignore_thresh, float): raise TypeError( "Attr ignore_thresh of yolov3_loss must be a float number") @@ -503,7 +492,6 @@ def yolov3_loss(x, "class_num": class_num, "ignore_thresh": ignore_thresh, "downsample": downsample, - "use_label_smooth": use_label_smooth } helper.append_op( @@ -512,7 +500,6 @@ def yolov3_loss(x, "X": x, "GTBox": gtbox, "GTLabel": gtlabel, - "GTScore": gtscore }, outputs={ 'Loss': loss, diff --git a/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py b/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py index 0e17eb3130..020c113923 100644 --- a/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py +++ b/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py @@ -23,10 +23,6 @@ from op_test import OpTest from paddle.fluid import core -def l1loss(x, y): - return abs(x - y) - - def l2loss(x, y): return 0.5 * (y - x) * (y - x) @@ -70,7 +66,7 @@ def batch_xywh_box_iou(box1, box2): return inter_area / union -def YOLOv3Loss(x, gtbox, gtlabel, gtscore, attrs): +def YOLOv3Loss(x, gtbox, gtlabel, attrs): n, c, h, w = x.shape b = gtbox.shape[1] anchors = attrs['anchors'] @@ -80,14 +76,10 @@ def YOLOv3Loss(x, gtbox, gtlabel, gtscore, attrs): class_num = attrs["class_num"] ignore_thresh = attrs['ignore_thresh'] downsample = attrs['downsample'] - use_label_smooth = attrs['use_label_smooth'] input_size = downsample * h x = x.reshape((n, mask_num, 5 + class_num, h, w)).transpose((0, 1, 3, 4, 2)) loss = np.zeros((n)).astype('float32') - label_pos = 1.0 - 1.0 / class_num if use_label_smooth else 1.0 - label_neg = 1.0 / class_num if use_label_smooth else 0.0 - pred_box = x[:, :, :, :, :4].copy() grid_x = np.tile(np.arange(w).reshape((1, w)), (h, 1)) grid_y = np.tile(np.arange(h).reshape((h, 1)), (1, w)) @@ -146,22 +138,21 @@ def YOLOv3Loss(x, gtbox, gtlabel, gtscore, attrs): ty = gtbox[i, j, 1] * w - gj tw = np.log(gtbox[i, j, 2] * input_size / mask_anchors[an_idx][0]) th = np.log(gtbox[i, j, 3] * input_size / mask_anchors[an_idx][1]) - scale = (2.0 - gtbox[i, j, 2] * gtbox[i, j, 3]) * gtscore[i, j] + scale = (2.0 - gtbox[i, j, 2] * gtbox[i, j, 3]) loss[i] += sce(x[i, an_idx, gj, gi, 0], tx) * scale loss[i] += sce(x[i, an_idx, gj, gi, 1], ty) * scale loss[i] += l2loss(x[i, an_idx, gj, gi, 2], tw) * scale loss[i] += l2loss(x[i, an_idx, gj, gi, 3], th) * scale - objness[i, an_idx * h * w + gj * w + gi] = gtscore[i, j] + objness[i, an_idx * h * w + gj * w + gi] = 1.0 for label_idx in range(class_num): - loss[i] += sce(x[i, an_idx, gj, gi, 5 + label_idx], label_pos - if label_idx == gtlabel[i, j] else - label_neg) * gtscore[i, j] + loss[i] += sce(x[i, an_idx, gj, gi, 5 + label_idx], + float(label_idx == gtlabel[i, j])) for j in range(mask_num * h * w): if objness[i, j] > 0: - loss[i] += sce(pred_obj[i, j], 1.0) * objness[i, j] + loss[i] += sce(pred_obj[i, j], 1.0) elif objness[i, j] == 0: loss[i] += sce(pred_obj[i, j], 0.0) @@ -176,7 +167,6 @@ class TestYolov3LossOp(OpTest): x = logit(np.random.uniform(0, 1, self.x_shape).astype('float32')) gtbox = np.random.random(size=self.gtbox_shape).astype('float32') gtlabel = np.random.randint(0, self.class_num, self.gtbox_shape[:2]) - gtscore = np.random.random(self.gtbox_shape[:2]).astype('float32') gtmask = np.random.randint(0, 2, self.gtbox_shape[:2]) gtbox = gtbox * gtmask[:, :, np.newaxis] gtlabel = gtlabel * gtmask @@ -187,17 +177,14 @@ class TestYolov3LossOp(OpTest): "class_num": self.class_num, "ignore_thresh": self.ignore_thresh, "downsample": self.downsample, - "use_label_smooth": self.use_label_smooth, } self.inputs = { 'X': x, 'GTBox': gtbox.astype('float32'), 'GTLabel': gtlabel.astype('int32'), - 'GTScore': gtscore.astype('float32') } - loss, objness, gt_matches = YOLOv3Loss(x, gtbox, gtlabel, gtscore, - self.attrs) + loss, objness, gt_matches = YOLOv3Loss(x, gtbox, gtlabel, self.attrs) self.outputs = { 'Loss': loss, 'ObjectnessMask': objness, @@ -213,7 +200,7 @@ class TestYolov3LossOp(OpTest): self.check_grad_with_place( place, ['X'], 'Loss', - no_grad_set=set(["GTBox", "GTLabel", "GTScore"]), + no_grad_set=set(["GTBox", "GTLabel"]), max_relative_error=0.3) def initTestCase(self): @@ -224,12 +211,6 @@ class TestYolov3LossOp(OpTest): self.downsample = 32 self.x_shape = (3, len(self.anchor_mask) * (5 + self.class_num), 5, 5) self.gtbox_shape = (3, 5, 4) - self.use_label_smooth = True - - -class TestYolov3LossWithoutLabelSmooth(TestYolov3LossOp): - def set_label_smooth(self): - self.use_label_smooth = False if __name__ == "__main__": From 56e21c558e37395ead098d588902464cb09c206a Mon Sep 17 00:00:00 2001 From: dengkaipeng Date: Mon, 28 Jan 2019 17:10:47 +0800 Subject: [PATCH 21/24] add comments and docs. test=develop --- paddle/fluid/operators/yolov3_loss_op.cc | 7 ++++++- paddle/fluid/operators/yolov3_loss_op.h | 10 +++++----- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/paddle/fluid/operators/yolov3_loss_op.cc b/paddle/fluid/operators/yolov3_loss_op.cc index 46374db49a..0d13d8fff4 100644 --- a/paddle/fluid/operators/yolov3_loss_op.cc +++ b/paddle/fluid/operators/yolov3_loss_op.cc @@ -98,7 +98,7 @@ class Yolov3LossOpMaker : public framework::OpProtoAndCheckerMaker { "This is a 4-D tensor with shape of [N, C, H, W]." "H and W should be same, and the second dimention(C) stores" "box locations, confidence score and classification one-hot" - "key of each anchor box"); + "keys of each anchor box"); AddInput("GTBox", "The input tensor of ground truth boxes, " "This is a 3-D tensor with shape of [N, max_box_num, 5], " @@ -179,6 +179,11 @@ class Yolov3LossOpMaker : public framework::OpProtoAndCheckerMaker { box coordinates (w, h), and sigmoid cross entropy loss is used for box coordinates (x, y), confidence score loss and classification loss. + Each groud truth box find a best matching anchor box in all anchors, + prediction of this anchor box will incur all three parts of losses, and + prediction of anchor boxes with no GT box matched will only incur objectness + loss. + In order to trade off box coordinate losses between big boxes and small boxes, box coordinate losses will be mutiplied by scale weight, which is calculated as follow. diff --git a/paddle/fluid/operators/yolov3_loss_op.h b/paddle/fluid/operators/yolov3_loss_op.h index 5c9851232d..fce8195668 100644 --- a/paddle/fluid/operators/yolov3_loss_op.h +++ b/paddle/fluid/operators/yolov3_loss_op.h @@ -308,13 +308,15 @@ class Yolov3LossKernel : public framework::OpKernel { } } + // If best IoU is greater then ignore_thresh, + // ignore the objectness loss. if (best_iou > ignore_thresh) { int obj_idx = (i * mask_num + j) * stride + k * w + l; obj_mask_data[obj_idx] = static_cast(-1); } - // TODO(dengkaipeng): all losses should be calculated if best IoU - // is bigger then truth thresh should be calculated here, but - // currently, truth thresh is an unreachable value as 1.0. + // all losses should be calculated if best IoU + // is bigger then truth thresh, but currently, + // truth thresh is an unreachable value as 1.0. } } } @@ -341,8 +343,6 @@ class Yolov3LossKernel : public framework::OpKernel { an_box.w = anchors[2 * an_idx] / static_cast(input_size); an_box.h = anchors[2 * an_idx + 1] / static_cast(input_size); float iou = CalcBoxIoU(an_box, gt_shift); - // TODO(dengkaipeng): In paper, objectness loss is ignore when - // best IoU > 0.5, but darknet code didn't implement this. if (iou > best_iou) { best_iou = iou; best_n = an_idx; From ae0b0d5f9362b11fb78355d9d56b7f9ff1cc9c6b Mon Sep 17 00:00:00 2001 From: dengkaipeng Date: Mon, 28 Jan 2019 22:58:46 +0800 Subject: [PATCH 22/24] fix doc. test=develop --- paddle/fluid/operators/yolov3_loss_op.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/operators/yolov3_loss_op.cc b/paddle/fluid/operators/yolov3_loss_op.cc index 0d13d8fff4..30f0c08463 100644 --- a/paddle/fluid/operators/yolov3_loss_op.cc +++ b/paddle/fluid/operators/yolov3_loss_op.cc @@ -121,7 +121,7 @@ class Yolov3LossOpMaker : public framework::OpProtoAndCheckerMaker { "mask for calculate objectness loss in gradient kernel.") .AsIntermediate(); AddOutput("GTMatchMask", - "This is an intermediate tensor with shape if [N, B], " + "This is an intermediate tensor with shape of [N, B], " "B is the max box number of GT boxes. This parameter caches " "matched mask index of each GT boxes for gradient calculate.") .AsIntermediate(); @@ -175,7 +175,7 @@ class Yolov3LossOpMaker : public framework::OpProtoAndCheckerMaker { thresh, the confidence score loss of this anchor box will be ignored. Therefore, the yolov3 loss consist of three major parts, box location loss, - confidence score loss, and classification loss. The L1 loss is used for + confidence score loss, and classification loss. The L2 loss is used for box coordinates (w, h), and sigmoid cross entropy loss is used for box coordinates (x, y), confidence score loss and classification loss. From 733bb82ec0d7ba4bbe9f0ed2aa5c36bc81829fa0 Mon Sep 17 00:00:00 2001 From: dengkaipeng Date: Tue, 29 Jan 2019 14:38:47 +0800 Subject: [PATCH 23/24] downsample -> downsample_ratio. test=develop --- paddle/fluid/API.spec | 2 +- paddle/fluid/operators/yolov3_loss_op.cc | 2 +- paddle/fluid/operators/yolov3_loss_op.h | 41 +++++++++++++----------- python/paddle/fluid/layers/detection.py | 10 +++--- 4 files changed, 29 insertions(+), 26 deletions(-) diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index 6c6ac9c7ea..5fdab448cb 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -324,7 +324,7 @@ paddle.fluid.layers.generate_mask_labels ArgSpec(args=['im_info', 'gt_classes', paddle.fluid.layers.iou_similarity ArgSpec(args=['x', 'y', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.box_coder ArgSpec(args=['prior_box', 'prior_box_var', 'target_box', 'code_type', 'box_normalized', 'name'], varargs=None, keywords=None, defaults=('encode_center_size', True, None)) paddle.fluid.layers.polygon_box_transform ArgSpec(args=['input', 'name'], varargs=None, keywords=None, defaults=(None,)) -paddle.fluid.layers.yolov3_loss ArgSpec(args=['x', 'gtbox', 'gtlabel', 'anchors', 'anchor_mask', 'class_num', 'ignore_thresh', 'downsample', 'name'], varargs=None, keywords=None, defaults=(None,)) +paddle.fluid.layers.yolov3_loss ArgSpec(args=['x', 'gtbox', 'gtlabel', 'anchors', 'anchor_mask', 'class_num', 'ignore_thresh', 'downsample_ratio', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.multiclass_nms ArgSpec(args=['bboxes', 'scores', 'score_threshold', 'nms_top_k', 'keep_top_k', 'nms_threshold', 'normalized', 'nms_eta', 'background_label', 'name'], varargs=None, keywords=None, defaults=(0.3, True, 1.0, 0, None)) paddle.fluid.layers.accuracy ArgSpec(args=['input', 'label', 'k', 'correct', 'total'], varargs=None, keywords=None, defaults=(1, None, None)) paddle.fluid.layers.auc ArgSpec(args=['input', 'label', 'curve', 'num_thresholds', 'topk', 'slide_steps'], varargs=None, keywords=None, defaults=('ROC', 4095, 1, 1)) diff --git a/paddle/fluid/operators/yolov3_loss_op.cc b/paddle/fluid/operators/yolov3_loss_op.cc index 30f0c08463..81fd87b4ac 100644 --- a/paddle/fluid/operators/yolov3_loss_op.cc +++ b/paddle/fluid/operators/yolov3_loss_op.cc @@ -135,7 +135,7 @@ class Yolov3LossOpMaker : public framework::OpProtoAndCheckerMaker { "The mask index of anchors used in " "current YOLOv3 loss calculation.") .SetDefault(std::vector{}); - AddAttr("downsample", + AddAttr("downsample_ratio", "The downsample ratio from network input to YOLOv3 loss " "input, so 32, 16, 8 should be set for the first, second, " "and thrid YOLOv3 loss operators.") diff --git a/paddle/fluid/operators/yolov3_loss_op.h b/paddle/fluid/operators/yolov3_loss_op.h index fce8195668..8407d4e6e8 100644 --- a/paddle/fluid/operators/yolov3_loss_op.h +++ b/paddle/fluid/operators/yolov3_loss_op.h @@ -32,7 +32,7 @@ static inline bool LessEqualZero(T x) { } template -static T SCE(T x, T label) { +static T SigmoidCrossEntropy(T x, T label) { return (x > 0 ? x : 0.0) - x * label + std::log(1.0 + std::exp(-std::abs(x))); } @@ -42,7 +42,7 @@ static T L2Loss(T x, T y) { } template -static T SCEGrad(T x, T label) { +static T SigmoidCrossEntropyGrad(T x, T label) { return 1.0 / (1.0 + std::exp(-x)) - label; } @@ -62,7 +62,7 @@ static int GetMaskIndex(std::vector mask, int val) { template struct Box { - float x, y, w, h; + T x, y, w, h; }; template @@ -128,8 +128,8 @@ static void CalcBoxLocationLoss(T* loss, const T* input, Box gt, T th = std::log(gt.h * input_size / anchors[2 * an_idx + 1]); T scale = (2.0 - gt.w * gt.h); - loss[0] += SCE(input[box_idx], tx) * scale; - loss[0] += SCE(input[box_idx + stride], ty) * scale; + loss[0] += SigmoidCrossEntropy(input[box_idx], tx) * scale; + loss[0] += SigmoidCrossEntropy(input[box_idx + stride], ty) * scale; loss[0] += L2Loss(input[box_idx + 2 * stride], tw) * scale; loss[0] += L2Loss(input[box_idx + 3 * stride], th) * scale; } @@ -145,9 +145,10 @@ static void CalcBoxLocationLossGrad(T* input_grad, const T loss, const T* input, T th = std::log(gt.h * input_size / anchors[2 * an_idx + 1]); T scale = (2.0 - gt.w * gt.h); - input_grad[box_idx] = SCEGrad(input[box_idx], tx) * scale * loss; + input_grad[box_idx] = + SigmoidCrossEntropyGrad(input[box_idx], tx) * scale * loss; input_grad[box_idx + stride] = - SCEGrad(input[box_idx + stride], ty) * scale * loss; + SigmoidCrossEntropyGrad(input[box_idx + stride], ty) * scale * loss; input_grad[box_idx + 2 * stride] = L2LossGrad(input[box_idx + 2 * stride], tw) * scale * loss; input_grad[box_idx + 3 * stride] = @@ -160,7 +161,7 @@ static inline void CalcLabelLoss(T* loss, const T* input, const int index, const int stride) { for (int i = 0; i < class_num; i++) { T pred = input[index + i * stride]; - loss[0] += SCE(pred, (i == label) ? 1.0 : 0.0); + loss[0] += SigmoidCrossEntropy(pred, (i == label) ? 1.0 : 0.0); } } @@ -172,7 +173,7 @@ static inline void CalcLabelLossGrad(T* input_grad, const T loss, for (int i = 0; i < class_num; i++) { T pred = input[index + i * stride]; input_grad[index + i * stride] = - SCEGrad(pred, (i == label) ? 1.0 : 0.0) * loss; + SigmoidCrossEntropyGrad(pred, (i == label) ? 1.0 : 0.0) * loss; } } @@ -187,11 +188,11 @@ static inline void CalcObjnessLoss(T* loss, const T* input, const T* objness, for (int l = 0; l < w; l++) { T obj = objness[k * w + l]; if (obj > 1e-5) { - // positive sample: obj = mixup score - loss[i] += SCE(input[k * w + l], 1.0); + // positive sample: obj = 1 + loss[i] += SigmoidCrossEntropy(input[k * w + l], 1.0); } else if (obj > -0.5) { // negetive sample: obj = 0 - loss[i] += SCE(input[k * w + l], 0.0); + loss[i] += SigmoidCrossEntropy(input[k * w + l], 0.0); } } } @@ -213,9 +214,11 @@ static inline void CalcObjnessLossGrad(T* input_grad, const T* loss, for (int l = 0; l < w; l++) { T obj = objness[k * w + l]; if (obj > 1e-5) { - input_grad[k * w + l] = SCEGrad(input[k * w + l], 1.0) * loss[i]; + input_grad[k * w + l] = + SigmoidCrossEntropyGrad(input[k * w + l], 1.0) * loss[i]; } else if (obj > -0.5) { - input_grad[k * w + l] = SCEGrad(input[k * w + l], 0.0) * loss[i]; + input_grad[k * w + l] = + SigmoidCrossEntropyGrad(input[k * w + l], 0.0) * loss[i]; } } } @@ -256,7 +259,7 @@ class Yolov3LossKernel : public framework::OpKernel { auto anchor_mask = ctx.Attr>("anchor_mask"); int class_num = ctx.Attr("class_num"); float ignore_thresh = ctx.Attr("ignore_thresh"); - int downsample = ctx.Attr("downsample"); + int downsample_ratio = ctx.Attr("downsample_ratio"); const int n = input->dims()[0]; const int h = input->dims()[2]; @@ -264,7 +267,7 @@ class Yolov3LossKernel : public framework::OpKernel { const int an_num = anchors.size() / 2; const int mask_num = anchor_mask.size(); const int b = gt_box->dims()[1]; - int input_size = downsample * h; + int input_size = downsample_ratio * h; const int stride = h * w; const int an_stride = (class_num + 5) * stride; @@ -308,7 +311,7 @@ class Yolov3LossKernel : public framework::OpKernel { } } - // If best IoU is greater then ignore_thresh, + // If best IoU is bigger then ignore_thresh, // ignore the objectness loss. if (best_iou > ignore_thresh) { int obj_idx = (i * mask_num + j) * stride + k * w + l; @@ -388,7 +391,7 @@ class Yolov3LossGradKernel : public framework::OpKernel { auto anchors = ctx.Attr>("anchors"); auto anchor_mask = ctx.Attr>("anchor_mask"); int class_num = ctx.Attr("class_num"); - int downsample = ctx.Attr("downsample"); + int downsample_ratio = ctx.Attr("downsample_ratio"); const int n = input_grad->dims()[0]; const int c = input_grad->dims()[1]; @@ -396,7 +399,7 @@ class Yolov3LossGradKernel : public framework::OpKernel { const int w = input_grad->dims()[3]; const int mask_num = anchor_mask.size(); const int b = gt_match_mask->dims()[1]; - int input_size = downsample * h; + int input_size = downsample_ratio * h; const int stride = h * w; const int an_stride = (class_num + 5) * stride; diff --git a/python/paddle/fluid/layers/detection.py b/python/paddle/fluid/layers/detection.py index ea130bb279..486503c871 100644 --- a/python/paddle/fluid/layers/detection.py +++ b/python/paddle/fluid/layers/detection.py @@ -416,7 +416,7 @@ def yolov3_loss(x, anchor_mask, class_num, ignore_thresh, - downsample, + downsample_ratio, name=None): """ ${comment} @@ -434,7 +434,7 @@ def yolov3_loss(x, anchor_mask (list|tuple): ${anchor_mask_comment} class_num (int): ${class_num_comment} ignore_thresh (float): ${ignore_thresh_comment} - downsample (int): ${downsample_comment} + downsample_ratio (int): ${downsample_ratio_comment} name (string): the name of yolov3 loss Returns: @@ -456,8 +456,8 @@ def yolov3_loss(x, gtlabel = fluid.layers.data(name='gtlabel', shape=[6, 1], dtype='int32') anchors = [10, 13, 16, 30, 33, 23, 30, 61, 62, 45, 59, 119, 116, 90, 156, 198, 373, 326] anchors = [0, 1, 2] - loss = fluid.layers.yolov3_loss(x=x, gtbox=gtbox, class_num=80 - anchors=anchors, ignore_thresh=0.5) + loss = fluid.layers.yolov3_loss(x=x, gtbox=gtbox, class_num=80, anchors=anchors, + ignore_thresh=0.5, downsample_ratio=32) """ helper = LayerHelper('yolov3_loss', **locals()) @@ -491,7 +491,7 @@ def yolov3_loss(x, "anchor_mask": anchor_mask, "class_num": class_num, "ignore_thresh": ignore_thresh, - "downsample": downsample, + "downsample_ratio": downsample_ratio, } helper.append_op( From 23d34d1f7e553bdcf4ac1d270f9e828f8cf99baf Mon Sep 17 00:00:00 2001 From: dengkaipeng Date: Tue, 29 Jan 2019 16:15:38 +0800 Subject: [PATCH 24/24] move yolov3_loss to detection. test=develop --- paddle/fluid/operators/detection/CMakeLists.txt | 1 + paddle/fluid/operators/{ => detection}/yolov3_loss_op.cc | 2 +- paddle/fluid/operators/{ => detection}/yolov3_loss_op.h | 0 3 files changed, 2 insertions(+), 1 deletion(-) rename paddle/fluid/operators/{ => detection}/yolov3_loss_op.cc (99%) rename paddle/fluid/operators/{ => detection}/yolov3_loss_op.h (100%) diff --git a/paddle/fluid/operators/detection/CMakeLists.txt b/paddle/fluid/operators/detection/CMakeLists.txt index d3a61dc367..cace42bc1b 100644 --- a/paddle/fluid/operators/detection/CMakeLists.txt +++ b/paddle/fluid/operators/detection/CMakeLists.txt @@ -31,6 +31,7 @@ detection_library(polygon_box_transform_op SRCS polygon_box_transform_op.cc polygon_box_transform_op.cu) detection_library(rpn_target_assign_op SRCS rpn_target_assign_op.cc) detection_library(generate_proposal_labels_op SRCS generate_proposal_labels_op.cc) +detection_library(yolov3_loss_op SRCS yolov3_loss_op.cc) if(WITH_GPU) detection_library(generate_proposals_op SRCS generate_proposals_op.cc generate_proposals_op.cu DEPS memory cub) diff --git a/paddle/fluid/operators/yolov3_loss_op.cc b/paddle/fluid/operators/detection/yolov3_loss_op.cc similarity index 99% rename from paddle/fluid/operators/yolov3_loss_op.cc rename to paddle/fluid/operators/detection/yolov3_loss_op.cc index 81fd87b4ac..2a69ad4b53 100644 --- a/paddle/fluid/operators/yolov3_loss_op.cc +++ b/paddle/fluid/operators/detection/yolov3_loss_op.cc @@ -9,7 +9,7 @@ See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/yolov3_loss_op.h" +#include "paddle/fluid/operators/detection/yolov3_loss_op.h" #include "paddle/fluid/framework/op_registry.h" namespace paddle { diff --git a/paddle/fluid/operators/yolov3_loss_op.h b/paddle/fluid/operators/detection/yolov3_loss_op.h similarity index 100% rename from paddle/fluid/operators/yolov3_loss_op.h rename to paddle/fluid/operators/detection/yolov3_loss_op.h