From 5d0b568ecb58d479619c5a2295d65b7f677d4648 Mon Sep 17 00:00:00 2001 From: dengkaipeng Date: Tue, 6 Nov 2018 18:42:19 +0800 Subject: [PATCH 1/8] Add YOLOv3 loss operator. test=develop --- paddle/fluid/operators/yolov3_loss_op.cc | 130 +++++++++ paddle/fluid/operators/yolov3_loss_op.cu | 23 ++ paddle/fluid/operators/yolov3_loss_op.h | 340 +++++++++++++++++++++++ 3 files changed, 493 insertions(+) create mode 100644 paddle/fluid/operators/yolov3_loss_op.cc create mode 100644 paddle/fluid/operators/yolov3_loss_op.cu create mode 100644 paddle/fluid/operators/yolov3_loss_op.h diff --git a/paddle/fluid/operators/yolov3_loss_op.cc b/paddle/fluid/operators/yolov3_loss_op.cc new file mode 100644 index 0000000000..b4c6a185e2 --- /dev/null +++ b/paddle/fluid/operators/yolov3_loss_op.cc @@ -0,0 +1,130 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/fluid/operators/yolov3_loss_op.h" +#include "paddle/fluid/framework/op_registry.h" + +namespace paddle { +namespace operators { + +using framework::Tensor; + +class Yolov3LossOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), + "Input(X) of Yolov3LossOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("GTBox"), + "Input(GTBox) of Yolov3LossOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output(Out) of Yolov3LossOp should not be null."); + + // PADDLE_ENFORCE(ctx->HasAttr("img_height"), + // "Attr(img_height) of Yolov3LossOp should not be null. "); + // PADDLE_ENFORCE(ctx->HasAttr("anchors"), + // "Attr(anchor) of Yolov3LossOp should not be null.") + // PADDLE_ENFORCE(ctx->HasAttr("class_num"), + // "Attr(class_num) of Yolov3LossOp should not be null."); + // PADDLE_ENFORCE(ctx->HasAttr( + // "ignore_thresh", + // "Attr(ignore_thresh) of Yolov3LossOp should not be null.")); + + auto dim_x = ctx->GetInputDim("X"); + auto dim_gt = ctx->GetInputDim("GTBox"); + auto img_height = ctx->Attrs().Get("img_height"); + auto anchors = ctx->Attrs().Get>("anchors"); + auto box_num = ctx->Attrs().Get("box_num"); + auto class_num = ctx->Attrs().Get("class_num"); + PADDLE_ENFORCE_GT(img_height, 0, + "Attr(img_height) value should be greater then 0"); + PADDLE_ENFORCE_GT(anchors.size(), 0, + "Attr(anchors) length should be greater then 0."); + PADDLE_ENFORCE_EQ(anchors.size() % 2, 0, + "Attr(anchors) length should be even integer."); + PADDLE_ENFORCE_GT(box_num, 0, + "Attr(box_num) should be an integer greater then 0."); + PADDLE_ENFORCE_GT(class_num, 0, + "Attr(class_num) should be an integer greater then 0."); + PADDLE_ENFORCE_EQ(dim_x[1], anchors.size() / 2 * (5 + class_num), + "Input(X) dim[1] should be equal to (anchor_number * (5 " + "+ class_num))."); + PADDLE_ENFORCE_EQ(dim_gt.size(), 3, "Input(GTBox) should be a 3-D tensor"); + PADDLE_ENFORCE_EQ(dim_gt[2], 5, "Input(GTBox) dim[2] should be 5"); + + std::vector dim_out({dim_x[0], 1}); + ctx->SetOutputDim("Out", framework::make_ddim(dim_out)); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + framework::ToDataType(ctx.Input("X")->type()), ctx.GetPlace()); + } +}; + +class Yolov3LossOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", + "The input tensor of bilinear interpolation, " + "This is a 4-D tensor with shape of [N, C, H, W]"); + AddOutput("Out", + "The output yolo loss tensor, " + "This is a 2-D tensor with shape of [N, 1]"); + + AddAttr("box_num", "The number of boxes generated in each grid."); + AddAttr("class_num", "The number of classes to predict."); + AddComment(R"DOC( + This operator generate yolov3 loss by given predict result and ground + truth boxes. + )DOC"); + } +}; + +class Yolov3LossOpGrad : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null"); + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), + "Input(Out@GRAD) should not be null"); + auto dim_x = ctx->GetInputDim("X"); + if (ctx->HasOutput(framework::GradVarName("X"))) { + ctx->SetOutputDim(framework::GradVarName("X"), dim_x); + } + } + + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + framework::ToDataType(ctx.Input("X")->type()), ctx.GetPlace()); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(yolov3_loss, ops::Yolov3LossOp, ops::Yolov3LossOpMaker, + paddle::framework::DefaultGradOpDescMaker); +REGISTER_OPERATOR(yolov3_loss_grad, ops::Yolov3LossOpGrad); +REGISTER_OP_CPU_KERNEL( + yolov3_loss, + ops::Yolov3LossKernel); +REGISTER_OP_CPU_KERNEL( + yolov3_loss_grad, + ops::Yolov3LossGradKernel); diff --git a/paddle/fluid/operators/yolov3_loss_op.cu b/paddle/fluid/operators/yolov3_loss_op.cu new file mode 100644 index 0000000000..48f997456a --- /dev/null +++ b/paddle/fluid/operators/yolov3_loss_op.cu @@ -0,0 +1,23 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#define EIGEN_USE_GPU + +#include "paddle/fluid/operators/yolov3_loss_op.h" +#include "paddle/fluid/platform/cuda_primitives.h" + +namespace ops = paddle::operators; +REGISTER_OP_CUDA_KERNEL( + yolov3_loss, + ops::Yolov3LossOpKernel); +REGISTER_OP_CUDA_KERNEL( + yolov3_loss_grad, + ops::Yolov3LossGradOpKernel); diff --git a/paddle/fluid/operators/yolov3_loss_op.h b/paddle/fluid/operators/yolov3_loss_op.h new file mode 100644 index 0000000000..7950390567 --- /dev/null +++ b/paddle/fluid/operators/yolov3_loss_op.h @@ -0,0 +1,340 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#pragma once +#include +#include +#include "paddle/fluid/framework/op_registry.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +template +using EigenTensor = framework::EigenTensor; +template +using EigenVector = framework::EigenVector; + +using Array2 = Eigen::DSizes; +using Array4 = Eigen::DSizes; + +template +static inline bool isZero(T x) { + return abs(x) < 1e-6; +} + +template +static inline T sigmod(T x) { + return 1.0 / (exp(-1.0 * x) + 1.0); +} + +template +static inline T CalcMSEWithMask(const Tensor& x, const Tensor& y, + const Tensor& mask) { + auto x_t = EigenVector::Flatten(x); + auto y_t = EigenVector::Flatten(y); + auto mask_t = EigenVector::Flatten(mask); + auto result = ((x_t - y_t) * mask_t).pow(2).sum().eval(); + return result(0); +} + +template +static inline T CalcBCEWithMask(const Tensor& x, const Tensor& y, + const Tensor& mask) { + auto x_t = EigenVector::Flatten(x); + auto y_t = EigenVector::Flatten(y); + auto mask_t = EigenVector::Flatten(mask); + + auto result = + ((y_t * (x_t.log()) + (1.0 - y_t) * ((1.0 - x_t).log())) * mask_t) + .sum() + .eval(); + return result; +} + +template +static inline T CalcCEWithMask(const Tensor& x, const Tensor& y, + const Tensor& mask) { + auto x_t = EigenVector::Flatten(x); + auto y_t = EigenVector::Flatten(y); + auto mask_t = EigenVector::Flatten(mask); +} + +template +static void CalcPredResult(const Tensor& input, Tensor* pred_boxes, + Tensor* pred_confs, Tensor* pred_classes, + Tensor* pred_x, Tensor* pred_y, Tensor* pred_w, + Tensor* pred_h, std::vector anchors, + const int class_num, const int stride) { + const int n = input.dims()[0]; + const int c = input.dims()[1]; + const int h = input.dims()[2]; + const int w = input.dims()[3]; + const int anchor_num = anchors.size() / 2; + const int box_attr_num = 5 + class_num; + + auto input_t = EigenTensor::From(input); + auto pred_boxes_t = EigenTensor::From(*pred_boxes); + auto pred_confs_t = EigenTensor::From(*pred_confs); + auto pred_classes_t = EigenTensor::From(*pred_classes); + auto pred_x_t = EigenTensor::From(*pred_x); + auto pred_y_t = EigenTensor::From(*pred_y); + auto pred_w_t = EigenTensor::From(*pred_w); + auto pred_h_t = EigenTensor::From(*pred_h); + + for (int i = 0; i < n; i++) { + for (int an_idx = 0; an_idx < anchor_num; an_idx++) { + float an_w = anchors[an_idx * 2] / stride; + float an_h = anchors[an_idx * 2 + 1] / stride; + + for (int j = 0; j < h; j++) { + for (int k = 0; k < w; k++) { + pred_x_t(i, an_idx, j, k) = + sigmod(input_t(i, box_attr_num * an_idx, j, k)); + pred_y_t(i, an_idx, j, k) = + sigmod(input_t(i, box_attr_num * an_idx + 1, j, k)); + pred_w_t(i, an_idx, j, k) = + sigmod(input_t(i, box_attr_num * an_idx + 2, j, k)); + pred_h_t(i, an_idx, j, k) = + sigmod(input_t(i, box_attr_num * an_idx + 3, j, k)); + + pred_boxes_t(i, an_idx, j, k, 0) = pred_x_t(i, an_idx, j, k) + k; + pred_boxes_t(i, an_idx, j, k, 1) = pred_y_t(i, an_idx, j, k) + j; + pred_boxes_t(i, an_idx, j, k, 2) = + exp(pred_w_t(i, an_idx, j, k)) * an_w; + pred_boxes_t(i, an_idx, j, k, 3) = + exp(pred_h_t(i, an_idx, j, k)) * an_h; + + pred_confs_t(i, an_idx, j, k) = + sigmod(input_t(i, box_attr_num * an_idx + 4, j, k)); + + for (int c = 0; c < class_num; c++) { + pred_classes_t(i, an_idx, j, k, c) = + sigmod(input_t(i, box_attr_num * an_idx + 5 + c, j, k)); + } + } + } + } + } +} + +template +static T CalcBoxIoU(std::vector box1, std::vector box2, + bool center_mode) { + T b1_x1, b1_x2, b1_y1, b1_y2; + T b2_x1, b2_x2, b2_y1, b2_y2; + if (center_mode) { + b1_x1 = box1[0] - box1[2] / 2; + b1_x2 = box1[0] + box1[2] / 2; + b1_y1 = box1[1] - box1[3] / 2; + b1_y2 = box1[1] + box1[3] / 2; + b2_x1 = box2[0] - box2[2] / 2; + b2_x2 = box2[0] + box2[2] / 2; + b2_y1 = box2[1] - box2[3] / 2; + b2_y2 = box2[1] + box2[3] / 2; + } else { + b1_x1 = box1[0]; + b1_x2 = box1[1]; + b1_y1 = box1[2]; + b1_y2 = box1[3]; + b2_x1 = box2[0]; + b2_x2 = box2[0]; + b2_y1 = box2[1]; + b2_y2 = box2[1]; + } + T b1_area = (b1_x2 - b1_x1 + 1.0) * (b1_y2 - b1_y1 + 1.0); + T b2_area = (b2_x2 - b2_x1 + 1.0) * (b2_y2 - b2_y1 + 1.0); + + T inter_rect_x1 = std::max(b1_x1, b2_x1); + T inter_rect_y1 = std::max(b1_y1, b2_y1); + T inter_rect_x2 = std::min(b1_x2, b2_x2); + T inter_rect_y2 = std::min(b1_y2, b2_y2); + T inter_area = std::max(inter_rect_x2 - inter_rect_x1 + 1.0, 0.0) * + std::max(inter_rect_y2 - inter_rect_y1 + 1.0, 0.0); + + return inter_area / (b1_area + b2_area - inter_area + 1e-16); +} + +template +static inline int GetPredLabel(const Tensor& pred_classes, int n, + int best_an_index, int gj, int gi) { + auto pred_classes_t = EigenTensor::From(pred_classes); + T score = 0.0; + int label = -1; + for (int i = 0; i < pred_classes.dims()[4]; i++) { + if (pred_classes_t(n, best_an_index, gj, gi, i) > score) { + score = pred_classes_t(n, best_an_index, gj, gi, i); + label = i; + } + } + return label; +} + +template +static void CalcPredBoxWithGTBox( + const Tensor& pred_boxes, const Tensor& pred_confs, + const Tensor& pred_classes, const Tensor& gt_boxes, + std::vector anchors, const float ignore_thresh, const int img_height, + int* gt_num, int* correct_num, Tensor* mask_true, Tensor* mask_false, + Tensor* tx, Tensor* ty, Tensor* tw, Tensor* th, Tensor* tconf, + Tensor* tclass) { + const int n = gt_boxes.dims()[0]; + const int b = gt_boxes.dims()[1]; + const int grid_size = pred_boxes.dims()[1]; + const int anchor_num = anchors.size() / 2; + auto pred_boxes_t = EigenTensor::From(pred_boxes); + auto pred_confs_t = EigenTensor::From(pred_confs); + auto pred_classes_t = EigenTensor::From(pred_classes); + auto gt_boxes_t = EigenTensor::From(gt_boxes); + auto mask_true_t = EigenTensor::From(*mask_true).setConstant(0.0); + auto mask_false_t = EigenTensor::From(*mask_false).setConstant(1.0); + auto tx_t = EigenTensor::From(*tx).setConstant(0.0); + auto ty_t = EigenTensor::From(*ty).setConstant(0.0); + auto tw_t = EigenTensor::From(*tw).setConstant(0.0); + auto th_t = EigenTensor::From(*th).setConstant(0.0); + auto tconf_t = EigenTensor::From(*tconf).setConstant(0.0); + auto tclass_t = EigenTensor::From(*tclass).setConstant(0.0); + + *gt_num = 0; + *correct_num = 0; + for (int i = 0; i < n; i++) { + for (int j = 0; j < b; j++) { + if (isZero(gt_boxes_t(i, j, 0)) && isZero(gt_boxes_t(i, j, 1)) && + isZero(gt_boxes_t(i, j, 2)) && isZero(gt_boxes_t(i, j, 3))) { + continue; + } + + *(gt_num)++; + int gt_label = gt_boxes_t(i, j, 0); + T gx = gt_boxes_t(i, j, 1); + T gy = gt_boxes_t(i, j, 2); + T gw = gt_boxes_t(i, j, 3); + T gh = gt_boxes_t(i, j, 4); + int gi = static_cast(gx); + int gj = static_cast(gy); + + T max_iou = static_cast(-1); + T iou; + int best_an_index = -1; + std::vector gt_box({0, 0, gw, gh}); + for (int an_idx = 0; an_idx < anchor_num; an_idx++) { + std::vector anchor_shape({0, 0, static_cast(anchors[2 * an_idx]), + static_cast(anchors[2 * an_idx + 1])}); + iou = CalcBoxIoU(gt_box, anchor_shape, false); + if (iou > max_iou) { + max_iou = iou; + best_an_index = an_idx; + } + if (iou > ignore_thresh) { + mask_false_t(b, an_idx, gj, gi) = 0; + } + } + mask_true_t(b, best_an_index, gj, gi) = 1; + mask_false_t(b, best_an_index, gj, gi) = 1; + tx_t(i, best_an_index, gj, gi) = gx - gi; + ty_t(i, best_an_index, gj, gi) = gy - gj; + tw_t(i, best_an_index, gj, gi) = + log(gw / anchors[2 * best_an_index] + 1e-16); + th_t(i, best_an_index, gj, gi) = + log(gh / anchors[2 * best_an_index + 1] + 1e-16); + tclass_t(b, best_an_index, gj, gi, gt_label) = 1; + tconf_t(b, best_an_index, gj, gi) = 1; + + std::vector pred_box({ + pred_boxes_t(i, best_an_index, gj, gi, 0), + pred_boxes_t(i, best_an_index, gj, gi, 1), + pred_boxes_t(i, best_an_index, gj, gi, 2), + pred_boxes_t(i, best_an_index, gj, gi, 3), + }); + gt_box[0] = gx; + gt_box[1] = gy; + iou = CalcBoxIoU(gt_box, pred_box, true); + int pred_label = GetPredLabel(pred_classes, i, best_an_index, gj, gi); + T score = pred_confs_t(i, best_an_index, gj, gi); + if (iou > 0.5 && pred_label == gt_label && score > 0.5) { + (*correct_num)++; + } + } + } + mask_false_t = mask_true_t - mask_false_t; +} + +template +class Yolov3LossKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* input = ctx.Input("X"); + auto* gt_boxes = ctx.Input("GTBox"); + auto* output = ctx.Output("Out"); + int img_height = ctx.Attr("img_height"); + auto anchors = ctx.Attr>("anchors"); + int class_num = ctx.Attr("class_num"); + float ignore_thresh = ctx.Attr("ignore_thresh"); + + const int n = input->dims()[0]; + const int c = input->dims()[1]; + const int h = input->dims()[2]; + const int w = input->dims()[3]; + const int an_num = anchors.size() / 2; + const float stride = static_cast(img_height) / h; + + Tensor pred_x, pred_y, pred_w, pred_h; + Tensor pred_boxes, pred_confs, pred_classes; + pred_x.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + pred_y.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + pred_w.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + pred_h.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + pred_boxes.mutable_data({n, an_num, h, w, 4}, ctx.GetPlace()); + pred_confs.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + pred_classes.mutable_data({n, an_num, h, w, class_num}, ctx.GetPlace()); + CalcPredResult(*input, &pred_boxes, &pred_confs, &pred_classes, &pred_x, + &pred_y, &pred_w, &pred_h, anchors, class_num, stride); + + Tensor mask_true, mask_false; + Tensor tx, ty, tw, th, tconf, tclass; + mask_true.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + mask_false.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + tx.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + ty.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + tw.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + th.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + tconf.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + tclass.mutable_data({n, an_num, h, w, class_num}, ctx.GetPlace()); + int gt_num = 0; + int correct_num = 0; + CalcPredBoxWithGTBox(pred_boxes, pred_confs, pred_classes, *gt_boxes, + anchors, ignore_thresh, img_height, >_num, + &correct_num, &mask_true, &mask_false, &tx, &ty, + &tw, &th, &tconf, &tclass); + + T loss_x = CalcMSEWithMask(pred_x, tx, mask_true); + T loss_y = CalcMSEWithMask(pred_y, ty, mask_true); + T loss_w = CalcMSEWithMask(pred_w, tw, mask_true); + T loss_h = CalcMSEWithMask(pred_h, th, mask_true); + T loss_conf_true = CalcBCEWithMask(pred_confs, tconf, mask_true); + T loss_conf_false = CalcBCEWithMask(pred_confs, tconf, mask_false); + // T loss_class = CalcCEWithMask() + } +}; + +template +class Yolov3LossGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* d_input_t = ctx.Output(framework::GradVarName("X")); + auto* d_output_t = ctx.Input(framework::GradVarName("Out")); + } +}; + +} // namespace operators +} // namespace paddle From 77c1328fa749c900c7e12bd6b9d70e84b91d5f49 Mon Sep 17 00:00:00 2001 From: dengkaipeng Date: Sat, 10 Nov 2018 23:32:11 +0800 Subject: [PATCH 2/8] add CPU kernel forward --- paddle/fluid/operators/yolov3_loss_op.cc | 60 ++++--- paddle/fluid/operators/yolov3_loss_op.h | 215 ++++++++++------------- 2 files changed, 127 insertions(+), 148 deletions(-) diff --git a/paddle/fluid/operators/yolov3_loss_op.cc b/paddle/fluid/operators/yolov3_loss_op.cc index b4c6a185e2..9ed7e13dc7 100644 --- a/paddle/fluid/operators/yolov3_loss_op.cc +++ b/paddle/fluid/operators/yolov3_loss_op.cc @@ -27,18 +27,8 @@ class Yolov3LossOp : public framework::OperatorWithKernel { "Input(X) of Yolov3LossOp should not be null."); PADDLE_ENFORCE(ctx->HasInput("GTBox"), "Input(GTBox) of Yolov3LossOp should not be null."); - PADDLE_ENFORCE(ctx->HasOutput("Out"), - "Output(Out) of Yolov3LossOp should not be null."); - - // PADDLE_ENFORCE(ctx->HasAttr("img_height"), - // "Attr(img_height) of Yolov3LossOp should not be null. "); - // PADDLE_ENFORCE(ctx->HasAttr("anchors"), - // "Attr(anchor) of Yolov3LossOp should not be null.") - // PADDLE_ENFORCE(ctx->HasAttr("class_num"), - // "Attr(class_num) of Yolov3LossOp should not be null."); - // PADDLE_ENFORCE(ctx->HasAttr( - // "ignore_thresh", - // "Attr(ignore_thresh) of Yolov3LossOp should not be null.")); + PADDLE_ENFORCE(ctx->HasOutput("Loss"), + "Output(Loss) of Yolov3LossOp should not be null."); auto dim_x = ctx->GetInputDim("X"); auto dim_gt = ctx->GetInputDim("GTBox"); @@ -46,6 +36,14 @@ class Yolov3LossOp : public framework::OperatorWithKernel { auto anchors = ctx->Attrs().Get>("anchors"); auto box_num = ctx->Attrs().Get("box_num"); auto class_num = ctx->Attrs().Get("class_num"); + PADDLE_ENFORCE_EQ(dim_x.size(), 4, "Input(X) should be a 4-D tensor."); + PADDLE_ENFORCE_EQ(dim_x[2], dim_x[3], + "Input(X) dim[3] and dim[4] should be euqal."); + PADDLE_ENFORCE_EQ(dim_x[1], anchors.size() / 2 * (5 + class_num), + "Input(X) dim[1] should be equal to (anchor_number * (5 " + "+ class_num))."); + PADDLE_ENFORCE_EQ(dim_gt.size(), 3, "Input(GTBox) should be a 3-D tensor"); + PADDLE_ENFORCE_EQ(dim_gt[2], 5, "Input(GTBox) dim[2] should be 5"); PADDLE_ENFORCE_GT(img_height, 0, "Attr(img_height) value should be greater then 0"); PADDLE_ENFORCE_GT(anchors.size(), 0, @@ -56,14 +54,9 @@ class Yolov3LossOp : public framework::OperatorWithKernel { "Attr(box_num) should be an integer greater then 0."); PADDLE_ENFORCE_GT(class_num, 0, "Attr(class_num) should be an integer greater then 0."); - PADDLE_ENFORCE_EQ(dim_x[1], anchors.size() / 2 * (5 + class_num), - "Input(X) dim[1] should be equal to (anchor_number * (5 " - "+ class_num))."); - PADDLE_ENFORCE_EQ(dim_gt.size(), 3, "Input(GTBox) should be a 3-D tensor"); - PADDLE_ENFORCE_EQ(dim_gt[2], 5, "Input(GTBox) dim[2] should be 5"); - std::vector dim_out({dim_x[0], 1}); - ctx->SetOutputDim("Out", framework::make_ddim(dim_out)); + std::vector dim_out({1}); + ctx->SetOutputDim("Loss", framework::make_ddim(dim_out)); } protected: @@ -80,12 +73,31 @@ class Yolov3LossOpMaker : public framework::OpProtoAndCheckerMaker { AddInput("X", "The input tensor of bilinear interpolation, " "This is a 4-D tensor with shape of [N, C, H, W]"); - AddOutput("Out", - "The output yolo loss tensor, " - "This is a 2-D tensor with shape of [N, 1]"); + AddInput( + "GTBox", + "The input tensor of ground truth boxes, " + "This is a 3-D tensor with shape of [N, max_box_num, 5 + class_num], " + "max_box_num is the max number of boxes in each image, " + "class_num is the number of classes in data set. " + "In the third dimention, stores x, y, w, h, confidence, classes " + "one-hot key. " + "x, y is the center cordinate of boxes and w, h is the width and " + "height, " + "and all of them should be divided by input image height to scale to " + "[0, 1]."); + AddOutput("Loss", + "The output yolov3 loss tensor, " + "This is a 1-D tensor with shape of [1]"); AddAttr("box_num", "The number of boxes generated in each grid."); AddAttr("class_num", "The number of classes to predict."); + AddAttr>("anchors", + "The anchor width and height, " + "it will be parsed pair by pair."); + AddAttr("img_height", + "The input image height after crop of yolov3 network."); + AddAttr("ignore_thresh", + "The ignore threshold to ignore confidence loss."); AddComment(R"DOC( This operator generate yolov3 loss by given predict result and ground truth boxes. @@ -100,8 +112,8 @@ class Yolov3LossOpGrad : public framework::OperatorWithKernel { protected: void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null"); - PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), - "Input(Out@GRAD) should not be null"); + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Loss")), + "Input(Loss@GRAD) should not be null"); auto dim_x = ctx->GetInputDim("X"); if (ctx->HasOutput(framework::GradVarName("X"))) { ctx->SetOutputDim(framework::GradVarName("X"), dim_x); diff --git a/paddle/fluid/operators/yolov3_loss_op.h b/paddle/fluid/operators/yolov3_loss_op.h index 7950390567..a796a57809 100644 --- a/paddle/fluid/operators/yolov3_loss_op.h +++ b/paddle/fluid/operators/yolov3_loss_op.h @@ -44,8 +44,16 @@ static inline T CalcMSEWithMask(const Tensor& x, const Tensor& y, auto x_t = EigenVector::Flatten(x); auto y_t = EigenVector::Flatten(y); auto mask_t = EigenVector::Flatten(mask); - auto result = ((x_t - y_t) * mask_t).pow(2).sum().eval(); - return result(0); + + T error_sum = 0.0; + T points = 0.0; + for (int i = 0; i < x_t.dimensions()[0]; i++) { + if (mask_t(i)) { + error_sum += pow(x_t(i) - y_t(i), 2); + points += 1; + } + } + return (error_sum / points); } template @@ -55,27 +63,24 @@ static inline T CalcBCEWithMask(const Tensor& x, const Tensor& y, auto y_t = EigenVector::Flatten(y); auto mask_t = EigenVector::Flatten(mask); - auto result = - ((y_t * (x_t.log()) + (1.0 - y_t) * ((1.0 - x_t).log())) * mask_t) - .sum() - .eval(); - return result; -} - -template -static inline T CalcCEWithMask(const Tensor& x, const Tensor& y, - const Tensor& mask) { - auto x_t = EigenVector::Flatten(x); - auto y_t = EigenVector::Flatten(y); - auto mask_t = EigenVector::Flatten(mask); + T error_sum = 0.0; + T points = 0.0; + for (int i = 0; i < x_t.dimensions()[0]; i++) { + if (mask_t(i)) { + error_sum += + -1.0 * (y_t(i) * log(x_t(i)) + (1.0 - y_t(i)) * log(1.0 - x_t(i))); + points += 1; + } + } + return (error_sum / points); } template -static void CalcPredResult(const Tensor& input, Tensor* pred_boxes, - Tensor* pred_confs, Tensor* pred_classes, - Tensor* pred_x, Tensor* pred_y, Tensor* pred_w, - Tensor* pred_h, std::vector anchors, - const int class_num, const int stride) { +static void CalcPredResult(const Tensor& input, Tensor* pred_confs, + Tensor* pred_classes, Tensor* pred_x, Tensor* pred_y, + Tensor* pred_w, Tensor* pred_h, + std::vector anchors, const int class_num, + const int stride) { const int n = input.dims()[0]; const int c = input.dims()[1]; const int h = input.dims()[2]; @@ -84,7 +89,7 @@ static void CalcPredResult(const Tensor& input, Tensor* pred_boxes, const int box_attr_num = 5 + class_num; auto input_t = EigenTensor::From(input); - auto pred_boxes_t = EigenTensor::From(*pred_boxes); + // auto pred_boxes_t = EigenTensor::From(*pred_boxes); auto pred_confs_t = EigenTensor::From(*pred_confs); auto pred_classes_t = EigenTensor::From(*pred_classes); auto pred_x_t = EigenTensor::From(*pred_x); @@ -104,16 +109,16 @@ static void CalcPredResult(const Tensor& input, Tensor* pred_boxes, pred_y_t(i, an_idx, j, k) = sigmod(input_t(i, box_attr_num * an_idx + 1, j, k)); pred_w_t(i, an_idx, j, k) = - sigmod(input_t(i, box_attr_num * an_idx + 2, j, k)); + input_t(i, box_attr_num * an_idx + 2, j, k); pred_h_t(i, an_idx, j, k) = - sigmod(input_t(i, box_attr_num * an_idx + 3, j, k)); + input_t(i, box_attr_num * an_idx + 3, j, k); - pred_boxes_t(i, an_idx, j, k, 0) = pred_x_t(i, an_idx, j, k) + k; - pred_boxes_t(i, an_idx, j, k, 1) = pred_y_t(i, an_idx, j, k) + j; - pred_boxes_t(i, an_idx, j, k, 2) = - exp(pred_w_t(i, an_idx, j, k)) * an_w; - pred_boxes_t(i, an_idx, j, k, 3) = - exp(pred_h_t(i, an_idx, j, k)) * an_h; + // pred_boxes_t(i, an_idx, j, k, 0) = pred_x_t(i, an_idx, j, k) + k; + // pred_boxes_t(i, an_idx, j, k, 1) = pred_y_t(i, an_idx, j, k) + j; + // pred_boxes_t(i, an_idx, j, k, 2) = + // exp(pred_w_t(i, an_idx, j, k)) * an_w; + // pred_boxes_t(i, an_idx, j, k, 3) = + // exp(pred_h_t(i, an_idx, j, k)) * an_h; pred_confs_t(i, an_idx, j, k) = sigmod(input_t(i, box_attr_num * an_idx + 4, j, k)); @@ -129,40 +134,27 @@ static void CalcPredResult(const Tensor& input, Tensor* pred_boxes, } template -static T CalcBoxIoU(std::vector box1, std::vector box2, - bool center_mode) { - T b1_x1, b1_x2, b1_y1, b1_y2; - T b2_x1, b2_x2, b2_y1, b2_y2; - if (center_mode) { - b1_x1 = box1[0] - box1[2] / 2; - b1_x2 = box1[0] + box1[2] / 2; - b1_y1 = box1[1] - box1[3] / 2; - b1_y2 = box1[1] + box1[3] / 2; - b2_x1 = box2[0] - box2[2] / 2; - b2_x2 = box2[0] + box2[2] / 2; - b2_y1 = box2[1] - box2[3] / 2; - b2_y2 = box2[1] + box2[3] / 2; - } else { - b1_x1 = box1[0]; - b1_x2 = box1[1]; - b1_y1 = box1[2]; - b1_y2 = box1[3]; - b2_x1 = box2[0]; - b2_x2 = box2[0]; - b2_y1 = box2[1]; - b2_y2 = box2[1]; - } - T b1_area = (b1_x2 - b1_x1 + 1.0) * (b1_y2 - b1_y1 + 1.0); - T b2_area = (b2_x2 - b2_x1 + 1.0) * (b2_y2 - b2_y1 + 1.0); +static T CalcBoxIoU(std::vector box1, std::vector box2) { + T b1_x1 = box1[0] - box1[2] / 2; + T b1_x2 = box1[0] + box1[2] / 2; + T b1_y1 = box1[1] - box1[3] / 2; + T b1_y2 = box1[1] + box1[3] / 2; + T b2_x1 = box2[0] - box2[2] / 2; + T b2_x2 = box2[0] + box2[2] / 2; + T b2_y1 = box2[1] - box2[3] / 2; + T b2_y2 = box2[1] + box2[3] / 2; + + T b1_area = (b1_x2 - b1_x1) * (b1_y2 - b1_y1); + T b2_area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1); T inter_rect_x1 = std::max(b1_x1, b2_x1); T inter_rect_y1 = std::max(b1_y1, b2_y1); T inter_rect_x2 = std::min(b1_x2, b2_x2); T inter_rect_y2 = std::min(b1_y2, b2_y2); - T inter_area = std::max(inter_rect_x2 - inter_rect_x1 + 1.0, 0.0) * - std::max(inter_rect_y2 - inter_rect_y1 + 1.0, 0.0); + T inter_area = std::max(inter_rect_x2 - inter_rect_x1, static_cast(0.0)) * + std::max(inter_rect_y2 - inter_rect_y1, static_cast(0.0)); - return inter_area / (b1_area + b2_area - inter_area + 1e-16); + return inter_area / (b1_area + b2_area - inter_area); } template @@ -181,23 +173,18 @@ static inline int GetPredLabel(const Tensor& pred_classes, int n, } template -static void CalcPredBoxWithGTBox( - const Tensor& pred_boxes, const Tensor& pred_confs, - const Tensor& pred_classes, const Tensor& gt_boxes, - std::vector anchors, const float ignore_thresh, const int img_height, - int* gt_num, int* correct_num, Tensor* mask_true, Tensor* mask_false, - Tensor* tx, Tensor* ty, Tensor* tw, Tensor* th, Tensor* tconf, - Tensor* tclass) { +static void PrePorcessGTBox(const Tensor& gt_boxes, const float ignore_thresh, + std::vector anchors, const int img_height, + const int grid_size, Tensor* obj_mask, + Tensor* noobj_mask, Tensor* tx, Tensor* ty, + Tensor* tw, Tensor* th, Tensor* tconf, + Tensor* tclass) { const int n = gt_boxes.dims()[0]; const int b = gt_boxes.dims()[1]; - const int grid_size = pred_boxes.dims()[1]; const int anchor_num = anchors.size() / 2; - auto pred_boxes_t = EigenTensor::From(pred_boxes); - auto pred_confs_t = EigenTensor::From(pred_confs); - auto pred_classes_t = EigenTensor::From(pred_classes); auto gt_boxes_t = EigenTensor::From(gt_boxes); - auto mask_true_t = EigenTensor::From(*mask_true).setConstant(0.0); - auto mask_false_t = EigenTensor::From(*mask_false).setConstant(1.0); + auto obj_mask_t = EigenTensor::From(*obj_mask).setConstant(0); + auto noobj_mask_t = EigenTensor::From(*noobj_mask).setConstant(1); auto tx_t = EigenTensor::From(*tx).setConstant(0.0); auto ty_t = EigenTensor::From(*ty).setConstant(0.0); auto tw_t = EigenTensor::From(*tw).setConstant(0.0); @@ -205,8 +192,6 @@ static void CalcPredBoxWithGTBox( auto tconf_t = EigenTensor::From(*tconf).setConstant(0.0); auto tclass_t = EigenTensor::From(*tclass).setConstant(0.0); - *gt_num = 0; - *correct_num = 0; for (int i = 0; i < n; i++) { for (int j = 0; j < b; j++) { if (isZero(gt_boxes_t(i, j, 0)) && isZero(gt_boxes_t(i, j, 1)) && @@ -214,12 +199,11 @@ static void CalcPredBoxWithGTBox( continue; } - *(gt_num)++; int gt_label = gt_boxes_t(i, j, 0); - T gx = gt_boxes_t(i, j, 1); - T gy = gt_boxes_t(i, j, 2); - T gw = gt_boxes_t(i, j, 3); - T gh = gt_boxes_t(i, j, 4); + T gx = gt_boxes_t(i, j, 1) * grid_size; + T gy = gt_boxes_t(i, j, 2) * grid_size; + T gw = gt_boxes_t(i, j, 3) * grid_size; + T gh = gt_boxes_t(i, j, 4) * grid_size; int gi = static_cast(gx); int gj = static_cast(gy); @@ -230,43 +214,26 @@ static void CalcPredBoxWithGTBox( for (int an_idx = 0; an_idx < anchor_num; an_idx++) { std::vector anchor_shape({0, 0, static_cast(anchors[2 * an_idx]), static_cast(anchors[2 * an_idx + 1])}); - iou = CalcBoxIoU(gt_box, anchor_shape, false); + iou = CalcBoxIoU(gt_box, anchor_shape); if (iou > max_iou) { max_iou = iou; best_an_index = an_idx; } if (iou > ignore_thresh) { - mask_false_t(b, an_idx, gj, gi) = 0; + noobj_mask_t(b, an_idx, gj, gi) = 0; } } - mask_true_t(b, best_an_index, gj, gi) = 1; - mask_false_t(b, best_an_index, gj, gi) = 1; + obj_mask_t(b, best_an_index, gj, gi) = 1; + noobj_mask_t(b, best_an_index, gj, gi) = 1; tx_t(i, best_an_index, gj, gi) = gx - gi; ty_t(i, best_an_index, gj, gi) = gy - gj; - tw_t(i, best_an_index, gj, gi) = - log(gw / anchors[2 * best_an_index] + 1e-16); - th_t(i, best_an_index, gj, gi) = - log(gh / anchors[2 * best_an_index + 1] + 1e-16); + tw_t(i, best_an_index, gj, gi) = log(gw / anchors[2 * best_an_index]); + th_t(i, best_an_index, gj, gi) = log(gh / anchors[2 * best_an_index + 1]); tclass_t(b, best_an_index, gj, gi, gt_label) = 1; tconf_t(b, best_an_index, gj, gi) = 1; - - std::vector pred_box({ - pred_boxes_t(i, best_an_index, gj, gi, 0), - pred_boxes_t(i, best_an_index, gj, gi, 1), - pred_boxes_t(i, best_an_index, gj, gi, 2), - pred_boxes_t(i, best_an_index, gj, gi, 3), - }); - gt_box[0] = gx; - gt_box[1] = gy; - iou = CalcBoxIoU(gt_box, pred_box, true); - int pred_label = GetPredLabel(pred_classes, i, best_an_index, gj, gi); - T score = pred_confs_t(i, best_an_index, gj, gi); - if (iou > 0.5 && pred_label == gt_label && score > 0.5) { - (*correct_num)++; - } } } - mask_false_t = mask_true_t - mask_false_t; + noobj_mask_t = noobj_mask_t - obj_mask_t; } template @@ -275,7 +242,7 @@ class Yolov3LossKernel : public framework::OpKernel { void Compute(const framework::ExecutionContext& ctx) const override { auto* input = ctx.Input("X"); auto* gt_boxes = ctx.Input("GTBox"); - auto* output = ctx.Output("Out"); + auto* loss = ctx.Output("Loss"); int img_height = ctx.Attr("img_height"); auto anchors = ctx.Attr>("anchors"); int class_num = ctx.Attr("class_num"); @@ -286,44 +253,44 @@ class Yolov3LossKernel : public framework::OpKernel { const int h = input->dims()[2]; const int w = input->dims()[3]; const int an_num = anchors.size() / 2; - const float stride = static_cast(img_height) / h; + const T stride = static_cast(img_height) / h; Tensor pred_x, pred_y, pred_w, pred_h; - Tensor pred_boxes, pred_confs, pred_classes; + Tensor pred_confs, pred_classes; pred_x.mutable_data({n, an_num, h, w}, ctx.GetPlace()); pred_y.mutable_data({n, an_num, h, w}, ctx.GetPlace()); pred_w.mutable_data({n, an_num, h, w}, ctx.GetPlace()); pred_h.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - pred_boxes.mutable_data({n, an_num, h, w, 4}, ctx.GetPlace()); pred_confs.mutable_data({n, an_num, h, w}, ctx.GetPlace()); pred_classes.mutable_data({n, an_num, h, w, class_num}, ctx.GetPlace()); - CalcPredResult(*input, &pred_boxes, &pred_confs, &pred_classes, &pred_x, - &pred_y, &pred_w, &pred_h, anchors, class_num, stride); + CalcPredResult(*input, &pred_confs, &pred_classes, &pred_x, &pred_y, + &pred_w, &pred_h, anchors, class_num, stride); - Tensor mask_true, mask_false; + Tensor obj_mask, noobj_mask; Tensor tx, ty, tw, th, tconf, tclass; - mask_true.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - mask_false.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + obj_mask.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + noobj_mask.mutable_data({n, an_num, h, w}, ctx.GetPlace()); tx.mutable_data({n, an_num, h, w}, ctx.GetPlace()); ty.mutable_data({n, an_num, h, w}, ctx.GetPlace()); tw.mutable_data({n, an_num, h, w}, ctx.GetPlace()); th.mutable_data({n, an_num, h, w}, ctx.GetPlace()); tconf.mutable_data({n, an_num, h, w}, ctx.GetPlace()); tclass.mutable_data({n, an_num, h, w, class_num}, ctx.GetPlace()); - int gt_num = 0; - int correct_num = 0; - CalcPredBoxWithGTBox(pred_boxes, pred_confs, pred_classes, *gt_boxes, - anchors, ignore_thresh, img_height, >_num, - &correct_num, &mask_true, &mask_false, &tx, &ty, - &tw, &th, &tconf, &tclass); - - T loss_x = CalcMSEWithMask(pred_x, tx, mask_true); - T loss_y = CalcMSEWithMask(pred_y, ty, mask_true); - T loss_w = CalcMSEWithMask(pred_w, tw, mask_true); - T loss_h = CalcMSEWithMask(pred_h, th, mask_true); - T loss_conf_true = CalcBCEWithMask(pred_confs, tconf, mask_true); - T loss_conf_false = CalcBCEWithMask(pred_confs, tconf, mask_false); - // T loss_class = CalcCEWithMask() + PrePorcessGTBox(*gt_boxes, ignore_thresh, anchors, img_height, h, + &obj_mask, &noobj_mask, &tx, &ty, &tw, &th, &tconf, + &tclass); + + T loss_x = CalcMSEWithMask(pred_x, tx, obj_mask); + T loss_y = CalcMSEWithMask(pred_y, ty, obj_mask); + T loss_w = CalcMSEWithMask(pred_w, tw, obj_mask); + T loss_h = CalcMSEWithMask(pred_h, th, obj_mask); + T loss_conf_true = CalcBCEWithMask(pred_confs, tconf, obj_mask); + T loss_conf_false = CalcBCEWithMask(pred_confs, tconf, noobj_mask); + T loss_class = CalcBCEWithMask(pred_classes, tclass, obj_mask); + + auto* loss_data = loss->mutable_data({1}, ctx.GetPlace()); + loss_data[0] = loss_x + loss_y + loss_w + loss_h + loss_conf_true + + loss_conf_false + loss_class; } }; From 36c46152e140adab7e74eaeee9dbeccb65fc5633 Mon Sep 17 00:00:00 2001 From: dengkaipeng Date: Sun, 11 Nov 2018 23:52:36 +0800 Subject: [PATCH 3/8] Add unittest for yolov3_loss. test=develop --- paddle/fluid/operators/yolov3_loss_op.cc | 25 +-- paddle/fluid/operators/yolov3_loss_op.h | 67 +++--- python/paddle/fluid/layers/nn.py | 28 +++ .../tests/unittests/test_yolov3_loss_op.py | 194 ++++++++++++++++++ 4 files changed, 273 insertions(+), 41 deletions(-) create mode 100644 python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py diff --git a/paddle/fluid/operators/yolov3_loss_op.cc b/paddle/fluid/operators/yolov3_loss_op.cc index 9ed7e13dc7..7369ce31e8 100644 --- a/paddle/fluid/operators/yolov3_loss_op.cc +++ b/paddle/fluid/operators/yolov3_loss_op.cc @@ -34,7 +34,6 @@ class Yolov3LossOp : public framework::OperatorWithKernel { auto dim_gt = ctx->GetInputDim("GTBox"); auto img_height = ctx->Attrs().Get("img_height"); auto anchors = ctx->Attrs().Get>("anchors"); - auto box_num = ctx->Attrs().Get("box_num"); auto class_num = ctx->Attrs().Get("class_num"); PADDLE_ENFORCE_EQ(dim_x.size(), 4, "Input(X) should be a 4-D tensor."); PADDLE_ENFORCE_EQ(dim_x[2], dim_x[3], @@ -50,8 +49,6 @@ class Yolov3LossOp : public framework::OperatorWithKernel { "Attr(anchors) length should be greater then 0."); PADDLE_ENFORCE_EQ(anchors.size() % 2, 0, "Attr(anchors) length should be even integer."); - PADDLE_ENFORCE_GT(box_num, 0, - "Attr(box_num) should be an integer greater then 0."); PADDLE_ENFORCE_GT(class_num, 0, "Attr(class_num) should be an integer greater then 0."); @@ -73,23 +70,19 @@ class Yolov3LossOpMaker : public framework::OpProtoAndCheckerMaker { AddInput("X", "The input tensor of bilinear interpolation, " "This is a 4-D tensor with shape of [N, C, H, W]"); - AddInput( - "GTBox", - "The input tensor of ground truth boxes, " - "This is a 3-D tensor with shape of [N, max_box_num, 5 + class_num], " - "max_box_num is the max number of boxes in each image, " - "class_num is the number of classes in data set. " - "In the third dimention, stores x, y, w, h, confidence, classes " - "one-hot key. " - "x, y is the center cordinate of boxes and w, h is the width and " - "height, " - "and all of them should be divided by input image height to scale to " - "[0, 1]."); + AddInput("GTBox", + "The input tensor of ground truth boxes, " + "This is a 3-D tensor with shape of [N, max_box_num, 5], " + "max_box_num is the max number of boxes in each image, " + "In the third dimention, stores label, x, y, w, h, " + "label is an integer to specify box class, x, y is the " + "center cordinate of boxes and w, h is the width and height" + "and x, y, w, h should be divided by input image height to " + "scale to [0, 1]."); AddOutput("Loss", "The output yolov3 loss tensor, " "This is a 1-D tensor with shape of [1]"); - AddAttr("box_num", "The number of boxes generated in each grid."); AddAttr("class_num", "The number of classes to predict."); AddAttr>("anchors", "The anchor width and height, " diff --git a/paddle/fluid/operators/yolov3_loss_op.h b/paddle/fluid/operators/yolov3_loss_op.h index a796a57809..426e0688ab 100644 --- a/paddle/fluid/operators/yolov3_loss_op.h +++ b/paddle/fluid/operators/yolov3_loss_op.h @@ -25,8 +25,7 @@ template using EigenVector = framework::EigenVector; -using Array2 = Eigen::DSizes; -using Array4 = Eigen::DSizes; +using Array5 = Eigen::DSizes; template static inline bool isZero(T x) { @@ -43,7 +42,7 @@ static inline T CalcMSEWithMask(const Tensor& x, const Tensor& y, const Tensor& mask) { auto x_t = EigenVector::Flatten(x); auto y_t = EigenVector::Flatten(y); - auto mask_t = EigenVector::Flatten(mask); + auto mask_t = EigenVector::Flatten(mask); T error_sum = 0.0; T points = 0.0; @@ -61,7 +60,7 @@ static inline T CalcBCEWithMask(const Tensor& x, const Tensor& y, const Tensor& mask) { auto x_t = EigenVector::Flatten(x); auto y_t = EigenVector::Flatten(y); - auto mask_t = EigenVector::Flatten(mask); + auto mask_t = EigenVector::Flatten(mask); T error_sum = 0.0; T points = 0.0; @@ -89,7 +88,6 @@ static void CalcPredResult(const Tensor& input, Tensor* pred_confs, const int box_attr_num = 5 + class_num; auto input_t = EigenTensor::From(input); - // auto pred_boxes_t = EigenTensor::From(*pred_boxes); auto pred_confs_t = EigenTensor::From(*pred_confs); auto pred_classes_t = EigenTensor::From(*pred_classes); auto pred_x_t = EigenTensor::From(*pred_x); @@ -113,13 +111,6 @@ static void CalcPredResult(const Tensor& input, Tensor* pred_confs, pred_h_t(i, an_idx, j, k) = input_t(i, box_attr_num * an_idx + 3, j, k); - // pred_boxes_t(i, an_idx, j, k, 0) = pred_x_t(i, an_idx, j, k) + k; - // pred_boxes_t(i, an_idx, j, k, 1) = pred_y_t(i, an_idx, j, k) + j; - // pred_boxes_t(i, an_idx, j, k, 2) = - // exp(pred_w_t(i, an_idx, j, k)) * an_w; - // pred_boxes_t(i, an_idx, j, k, 3) = - // exp(pred_h_t(i, an_idx, j, k)) * an_h; - pred_confs_t(i, an_idx, j, k) = sigmod(input_t(i, box_attr_num * an_idx + 4, j, k)); @@ -199,7 +190,7 @@ static void PrePorcessGTBox(const Tensor& gt_boxes, const float ignore_thresh, continue; } - int gt_label = gt_boxes_t(i, j, 0); + int gt_label = static_cast(gt_boxes_t(i, j, 0)); T gx = gt_boxes_t(i, j, 1) * grid_size; T gy = gt_boxes_t(i, j, 2) * grid_size; T gw = gt_boxes_t(i, j, 3) * grid_size; @@ -207,7 +198,7 @@ static void PrePorcessGTBox(const Tensor& gt_boxes, const float ignore_thresh, int gi = static_cast(gx); int gj = static_cast(gy); - T max_iou = static_cast(-1); + T max_iou = static_cast(0); T iou; int best_an_index = -1; std::vector gt_box({0, 0, gw, gh}); @@ -220,20 +211,33 @@ static void PrePorcessGTBox(const Tensor& gt_boxes, const float ignore_thresh, best_an_index = an_idx; } if (iou > ignore_thresh) { - noobj_mask_t(b, an_idx, gj, gi) = 0; + noobj_mask_t(i, an_idx, gj, gi) = 0; } } - obj_mask_t(b, best_an_index, gj, gi) = 1; - noobj_mask_t(b, best_an_index, gj, gi) = 1; + obj_mask_t(i, best_an_index, gj, gi) = 1; + noobj_mask_t(i, best_an_index, gj, gi) = 0; tx_t(i, best_an_index, gj, gi) = gx - gi; ty_t(i, best_an_index, gj, gi) = gy - gj; tw_t(i, best_an_index, gj, gi) = log(gw / anchors[2 * best_an_index]); th_t(i, best_an_index, gj, gi) = log(gh / anchors[2 * best_an_index + 1]); - tclass_t(b, best_an_index, gj, gi, gt_label) = 1; - tconf_t(b, best_an_index, gj, gi) = 1; + tclass_t(i, best_an_index, gj, gi, gt_label) = 1; + tconf_t(i, best_an_index, gj, gi) = 1; } } - noobj_mask_t = noobj_mask_t - obj_mask_t; +} + +static void ExpandObjMaskByClassNum(Tensor* obj_mask_expand, + const Tensor& obj_mask) { + const int n = obj_mask_expand->dims()[0]; + const int an_num = obj_mask_expand->dims()[1]; + const int h = obj_mask_expand->dims()[2]; + const int w = obj_mask_expand->dims()[3]; + const int class_num = obj_mask_expand->dims()[4]; + auto obj_mask_expand_t = EigenTensor::From(*obj_mask_expand); + auto obj_mask_t = EigenTensor::From(obj_mask); + + obj_mask_expand_t = obj_mask_t.reshape(Array5(n, an_num, h, w, 1)) + .broadcast(Array5(1, 1, 1, 1, class_num)); } template @@ -280,17 +284,30 @@ class Yolov3LossKernel : public framework::OpKernel { &obj_mask, &noobj_mask, &tx, &ty, &tw, &th, &tconf, &tclass); + Tensor obj_mask_expand; + obj_mask_expand.mutable_data({n, an_num, h, w, class_num}, + ctx.GetPlace()); + ExpandObjMaskByClassNum(&obj_mask_expand, obj_mask); + T loss_x = CalcMSEWithMask(pred_x, tx, obj_mask); T loss_y = CalcMSEWithMask(pred_y, ty, obj_mask); T loss_w = CalcMSEWithMask(pred_w, tw, obj_mask); T loss_h = CalcMSEWithMask(pred_h, th, obj_mask); - T loss_conf_true = CalcBCEWithMask(pred_confs, tconf, obj_mask); - T loss_conf_false = CalcBCEWithMask(pred_confs, tconf, noobj_mask); - T loss_class = CalcBCEWithMask(pred_classes, tclass, obj_mask); + T loss_conf_obj = CalcBCEWithMask(pred_confs, tconf, obj_mask); + T loss_conf_noobj = CalcBCEWithMask(pred_confs, tconf, noobj_mask); + T loss_class = CalcBCEWithMask(pred_classes, tclass, obj_mask_expand); + + // LOG(ERROR) << "loss_x: " << loss_x; + // LOG(ERROR) << "loss_y: " << loss_y; + // LOG(ERROR) << "loss_w: " << loss_w; + // LOG(ERROR) << "loss_h: " << loss_h; + // LOG(ERROR) << "loss_conf_obj: " << loss_conf_obj; + // LOG(ERROR) << "loss_conf_noobj: " << loss_conf_noobj; + // LOG(ERROR) << "loss_class: " << loss_class; auto* loss_data = loss->mutable_data({1}, ctx.GetPlace()); - loss_data[0] = loss_x + loss_y + loss_w + loss_h + loss_conf_true + - loss_conf_false + loss_class; + loss_data[0] = loss_x + loss_y + loss_w + loss_h + loss_conf_obj + + loss_conf_noobj + loss_class; } }; diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index d3623464e9..1ee7198f29 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -164,6 +164,7 @@ __all__ = [ 'hash', 'grid_sampler', 'log_loss', + 'yolov3_loss', 'add_position_encoding', 'bilinear_tensor_product', ] @@ -8243,6 +8244,33 @@ def log_loss(input, label, epsilon=1e-4, name=None): return loss +def yolov3_loss(x, gtbox, img_height, anchors, ignore_thresh, name=None): + """ + **YOLOv3 Loss Layer** + + This layer + """ + helper = LayerHelper('yolov3_loss', **locals()) + + if name is None: + loss = helper.create_variable_for_type_inference(dtype=x.dtype) + else: + loss = helper.create_variable( + name=name, dtype=x.dtype, persistable=False) + + helper.append_op( + type='yolov3_loss', + inputs={'X': x, + "GTBox": gtbox}, + outputs={'Loss': loss}, + attrs={ + "img_height": img_height, + "anchors": anchors, + "ignore_thresh": ignore_thresh, + }) + return loss + + def add_position_encoding(input, alpha, beta, name=None): """ **Add Position Encoding Layer** diff --git a/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py b/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py new file mode 100644 index 0000000000..f5b15efb27 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py @@ -0,0 +1,194 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import numpy as np +from op_test import OpTest + + +def sigmoid(x): + return 1.0 / (1.0 + np.exp(-1.0 * x)) + + +def mse(x, y, num): + return ((y - x)**2).sum() / num + + +def bce(x, y, mask): + x = x.reshape((-1)) + y = y.reshape((-1)) + mask = mask.reshape((-1)) + + error_sum = 0.0 + count = 0 + for i in range(x.shape[0]): + if mask[i] > 0: + error_sum += y[i] * np.log(x[i]) + (1 - y[i]) * np.log(1 - x[i]) + count += 1 + return error_sum / (-1.0 * count) + + +def box_iou(box1, box2): + b1_x1 = box1[0] - box1[2] / 2 + b1_x2 = box1[0] + box1[2] / 2 + b1_y1 = box1[1] - box1[3] / 2 + b1_y2 = box1[1] + box1[3] / 2 + b2_x1 = box2[0] - box2[2] / 2 + b2_x2 = box2[0] + box2[2] / 2 + b2_y1 = box2[1] - box2[3] / 2 + b2_y2 = box2[1] + box2[3] / 2 + + b1_area = (b1_x2 - b1_x1) * (b1_y2 - b1_y1) + b2_area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1) + + inter_rect_x1 = max(b1_x1, b2_x1) + inter_rect_y1 = max(b1_y1, b2_y1) + inter_rect_x2 = min(b1_x2, b2_x2) + inter_rect_y2 = min(b1_y2, b2_y2) + inter_area = max(inter_rect_x2 - inter_rect_x1, 0) * max( + inter_rect_y2 - inter_rect_y1, 0) + + return inter_area / (b1_area + b2_area + inter_area) + + +def build_target(gtboxs, attrs, grid_size): + n, b, _ = gtboxs.shape + ignore_thresh = attrs["ignore_thresh"] + img_height = attrs["img_height"] + anchors = attrs["anchors"] + class_num = attrs["class_num"] + an_num = len(anchors) / 2 + obj_mask = np.zeros((n, an_num, grid_size, grid_size)).astype('float32') + noobj_mask = np.ones((n, an_num, grid_size, grid_size)).astype('float32') + tx = np.zeros((n, an_num, grid_size, grid_size)).astype('float32') + ty = np.zeros((n, an_num, grid_size, grid_size)).astype('float32') + tw = np.zeros((n, an_num, grid_size, grid_size)).astype('float32') + th = np.zeros((n, an_num, grid_size, grid_size)).astype('float32') + tconf = np.zeros((n, an_num, grid_size, grid_size)).astype('float32') + tcls = np.zeros( + (n, an_num, grid_size, grid_size, class_num)).astype('float32') + + for i in range(n): + for j in range(b): + if gtboxs[i, j, :].sum() == 0: + continue + + gt_label = int(gtboxs[i, j, 0]) + gx = gtboxs[i, j, 1] * grid_size + gy = gtboxs[i, j, 2] * grid_size + gw = gtboxs[i, j, 3] * grid_size + gh = gtboxs[i, j, 4] * grid_size + + gi = int(gx) + gj = int(gy) + + gtbox = [0, 0, gw, gh] + max_iou = 0 + for k in range(an_num): + anchor_box = [0, 0, anchors[2 * k], anchors[2 * k + 1]] + iou = box_iou(gtbox, anchor_box) + if iou > max_iou: + max_iou = iou + best_an_index = k + if iou > ignore_thresh: + noobj_mask[i, best_an_index, gj, gi] = 0 + + obj_mask[i, best_an_index, gj, gi] = 1 + noobj_mask[i, best_an_index, gj, gi] = 0 + tx[i, best_an_index, gj, gi] = gx - gi + ty[i, best_an_index, gj, gi] = gy - gj + tw[i, best_an_index, gj, gi] = np.log(gw / anchors[2 * + best_an_index]) + th[i, best_an_index, gj, gi] = np.log( + gh / anchors[2 * best_an_index + 1]) + tconf[i, best_an_index, gj, gi] = 1 + tcls[i, best_an_index, gj, gi, gt_label] = 1 + + return (tx, ty, tw, th, tconf, tcls, obj_mask, noobj_mask) + + +def YoloV3Loss(x, gtbox, attrs): + n, c, h, w = x.shape + an_num = len(attrs['anchors']) / 2 + class_num = attrs["class_num"] + x = x.reshape((n, an_num, 5 + class_num, h, w)).transpose((0, 1, 3, 4, 2)) + pred_x = sigmoid(x[:, :, :, :, 0]) + pred_y = sigmoid(x[:, :, :, :, 1]) + pred_w = x[:, :, :, :, 2] + pred_h = x[:, :, :, :, 3] + pred_conf = sigmoid(x[:, :, :, :, 4]) + pred_cls = sigmoid(x[:, :, :, :, 5:]) + + tx, ty, tw, th, tconf, tcls, obj_mask, noobj_mask = build_target( + gtbox, attrs, x.shape[2]) + + obj_mask_expand = np.tile( + np.expand_dims(obj_mask, 4), (1, 1, 1, 1, int(attrs['class_num']))) + loss_x = mse(pred_x * obj_mask, tx * obj_mask, obj_mask.sum()) + loss_y = mse(pred_y * obj_mask, ty * obj_mask, obj_mask.sum()) + loss_w = mse(pred_w * obj_mask, tw * obj_mask, obj_mask.sum()) + loss_h = mse(pred_h * obj_mask, th * obj_mask, obj_mask.sum()) + loss_conf_obj = bce(pred_conf * obj_mask, tconf * obj_mask, obj_mask) + loss_conf_noobj = bce(pred_conf * noobj_mask, tconf * noobj_mask, + noobj_mask) + loss_class = bce(pred_cls * obj_mask_expand, tcls * obj_mask_expand, + obj_mask_expand) + # print "loss_x: ", loss_x + # print "loss_y: ", loss_y + # print "loss_w: ", loss_w + # print "loss_h: ", loss_h + # print "loss_conf_obj: ", loss_conf_obj + # print "loss_conf_noobj: ", loss_conf_noobj + # print "loss_class: ", loss_class + + return loss_x + loss_y + loss_w + loss_h + loss_conf_obj + loss_conf_noobj + loss_class + + +class TestYolov3LossOp(OpTest): + def setUp(self): + self.initTestCase() + self.op_type = 'yolov3_loss' + x = np.random.random(size=self.x_shape).astype('float32') + gtbox = np.random.random(size=self.gtbox_shape).astype('float32') + gtbox[:, :, 0] = np.random.randint(0, self.class_num, + self.gtbox_shape[:2]) + + self.attrs = { + "img_height": self.img_height, + "anchors": self.anchors, + "class_num": self.class_num, + "ignore_thresh": self.ignore_thresh, + } + + self.inputs = {'X': x, 'GTBox': gtbox} + self.outputs = {'Loss': np.array([YoloV3Loss(x, gtbox, self.attrs)])} + print self.outputs + + def test_check_output(self): + self.check_output(atol=1e-3) + + # def test_check_grad_normal(self): + # self.check_grad(['X', 'Grid'], 'Output', max_relative_error=0.61) + + def initTestCase(self): + self.img_height = 608 + self.anchors = [10, 13, 16, 30, 33, 23] + self.class_num = 10 + self.ignore_thresh = 0.5 + self.x_shape = (5, len(self.anchors) / 2 * (5 + self.class_num), 7, 7) + self.gtbox_shape = (5, 10, 5) + + +if __name__ == "__main__": + unittest.main() From a0284f6fbcb4888e1653b7f094db615f1437943c Mon Sep 17 00:00:00 2001 From: dengkaipeng Date: Mon, 12 Nov 2018 21:13:25 +0800 Subject: [PATCH 4/8] Add backward CPU kernel. test=develop --- paddle/fluid/API.spec | 1 + paddle/fluid/operators/yolov3_loss_op.cc | 64 ++++- paddle/fluid/operators/yolov3_loss_op.cu | 4 +- paddle/fluid/operators/yolov3_loss_op.h | 256 +++++++++++++----- python/paddle/fluid/layers/nn.py | 49 +++- .../fluid/tests/unittests/test_layers.py | 9 + .../tests/unittests/test_yolov3_loss_op.py | 42 +-- 7 files changed, 327 insertions(+), 98 deletions(-) diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index de32a5d5a2..8344a913e9 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -183,6 +183,7 @@ paddle.fluid.layers.similarity_focus ArgSpec(args=['input', 'axis', 'indexes', ' paddle.fluid.layers.hash ArgSpec(args=['input', 'hash_size', 'num_hash', 'name'], varargs=None, keywords=None, defaults=(1, None)) paddle.fluid.layers.grid_sampler ArgSpec(args=['x', 'grid', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.log_loss ArgSpec(args=['input', 'label', 'epsilon', 'name'], varargs=None, keywords=None, defaults=(0.0001, None)) +paddle.fluid.layers.yolov3_loss ArgSpec(args=['x', 'gtbox', 'anchors', 'class_num', 'ignore_thresh', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.add_position_encoding ArgSpec(args=['input', 'alpha', 'beta', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.bilinear_tensor_product ArgSpec(args=['x', 'y', 'size', 'act', 'name', 'param_attr', 'bias_attr'], varargs=None, keywords=None, defaults=(None, None, None, None)) paddle.fluid.layers.data ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True)) diff --git a/paddle/fluid/operators/yolov3_loss_op.cc b/paddle/fluid/operators/yolov3_loss_op.cc index 7369ce31e8..cf25e99505 100644 --- a/paddle/fluid/operators/yolov3_loss_op.cc +++ b/paddle/fluid/operators/yolov3_loss_op.cc @@ -20,8 +20,6 @@ using framework::Tensor; class Yolov3LossOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - - protected: void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of Yolov3LossOp should not be null."); @@ -32,7 +30,6 @@ class Yolov3LossOp : public framework::OperatorWithKernel { auto dim_x = ctx->GetInputDim("X"); auto dim_gt = ctx->GetInputDim("GTBox"); - auto img_height = ctx->Attrs().Get("img_height"); auto anchors = ctx->Attrs().Get>("anchors"); auto class_num = ctx->Attrs().Get("class_num"); PADDLE_ENFORCE_EQ(dim_x.size(), 4, "Input(X) should be a 4-D tensor."); @@ -43,8 +40,6 @@ class Yolov3LossOp : public framework::OperatorWithKernel { "+ class_num))."); PADDLE_ENFORCE_EQ(dim_gt.size(), 3, "Input(GTBox) should be a 3-D tensor"); PADDLE_ENFORCE_EQ(dim_gt[2], 5, "Input(GTBox) dim[2] should be 5"); - PADDLE_ENFORCE_GT(img_height, 0, - "Attr(img_height) value should be greater then 0"); PADDLE_ENFORCE_GT(anchors.size(), 0, "Attr(anchors) length should be greater then 0."); PADDLE_ENFORCE_EQ(anchors.size() % 2, 0, @@ -87,13 +82,43 @@ class Yolov3LossOpMaker : public framework::OpProtoAndCheckerMaker { AddAttr>("anchors", "The anchor width and height, " "it will be parsed pair by pair."); - AddAttr("img_height", - "The input image height after crop of yolov3 network."); AddAttr("ignore_thresh", "The ignore threshold to ignore confidence loss."); AddComment(R"DOC( This operator generate yolov3 loss by given predict result and ground truth boxes. + + The output of previous network is in shape [N, C, H, W], while H and W + should be the same, specify the grid size, each grid point predict given + number boxes, this given number is specified by anchors, it should be + half anchors length, which following will be represented as S. In the + second dimention(the channel dimention), C should be S * (class_num + 5), + class_num is the box categoriy number of source dataset(such as coco), + so in the second dimention, stores 4 box location coordinates x, y, w, h + and confidence score of the box and class one-hot key of each anchor box. + + While the 4 location coordinates if $$tx, ty, tw, th$$, the box predictions + correspnd to: + + $$ + b_x = \sigma(t_x) + c_x + b_y = \sigma(t_y) + c_y + b_w = p_w e^{t_w} + b_h = p_h e^{t_h} + $$ + + While $$c_x, c_y$$ is the left top corner of current grid and $$p_w, p_h$$ + is specified by anchors. + + As for confidence score, it is the logistic regression value of IoU between + anchor boxes and ground truth boxes, the score of the anchor box which has + the max IoU should be 1, and if the anchor box has IoU bigger then ignore + thresh, the confidence score loss of this anchor box will be ignored. + + Therefore, the yolov3 loss consist of three major parts, box location loss, + confidence score loss, and classification loss. The MSE loss is used for + box location, and binary cross entropy loss is used for confidence score + loss and classification loss. )DOC"); } }; @@ -101,8 +126,6 @@ class Yolov3LossOpMaker : public framework::OpProtoAndCheckerMaker { class Yolov3LossOpGrad : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - - protected: void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null"); PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Loss")), @@ -113,6 +136,7 @@ class Yolov3LossOpGrad : public framework::OperatorWithKernel { } } + protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { return framework::OpKernelType( @@ -120,12 +144,32 @@ class Yolov3LossOpGrad : public framework::OperatorWithKernel { } }; +class Yolov3LossGradMaker : public framework::SingleGradOpDescMaker { + public: + using framework::SingleGradOpDescMaker::SingleGradOpDescMaker; + + protected: + std::unique_ptr Apply() const override { + auto* op = new framework::OpDesc(); + op->SetType("yolov3_loss_grad"); + op->SetInput("X", Input("X")); + op->SetInput("GTBox", Input("GTBox")); + op->SetInput(framework::GradVarName("Loss"), OutputGrad("Loss")); + + op->SetAttrMap(Attrs()); + + op->SetOutput(framework::GradVarName("X"), InputGrad("X")); + op->SetOutput(framework::GradVarName("GTBox"), {}); + return std::unique_ptr(op); + } +}; + } // namespace operators } // namespace paddle namespace ops = paddle::operators; REGISTER_OPERATOR(yolov3_loss, ops::Yolov3LossOp, ops::Yolov3LossOpMaker, - paddle::framework::DefaultGradOpDescMaker); + ops::Yolov3LossGradMaker); REGISTER_OPERATOR(yolov3_loss_grad, ops::Yolov3LossOpGrad); REGISTER_OP_CPU_KERNEL( yolov3_loss, diff --git a/paddle/fluid/operators/yolov3_loss_op.cu b/paddle/fluid/operators/yolov3_loss_op.cu index 48f997456a..f901b10d38 100644 --- a/paddle/fluid/operators/yolov3_loss_op.cu +++ b/paddle/fluid/operators/yolov3_loss_op.cu @@ -17,7 +17,7 @@ namespace ops = paddle::operators; REGISTER_OP_CUDA_KERNEL( yolov3_loss, - ops::Yolov3LossOpKernel); + ops::Yolov3LossKernel); REGISTER_OP_CUDA_KERNEL( yolov3_loss_grad, - ops::Yolov3LossGradOpKernel); + ops::Yolov3LossGradKernel); diff --git a/paddle/fluid/operators/yolov3_loss_op.h b/paddle/fluid/operators/yolov3_loss_op.h index 426e0688ab..a2ed4440a7 100644 --- a/paddle/fluid/operators/yolov3_loss_op.h +++ b/paddle/fluid/operators/yolov3_loss_op.h @@ -33,10 +33,22 @@ static inline bool isZero(T x) { } template -static inline T sigmod(T x) { +static inline T sigmoid(T x) { return 1.0 / (exp(-1.0 * x) + 1.0); } +template +static inline T CalcMaskPointNum(const Tensor& mask) { + auto mask_t = EigenVector::Flatten(mask); + T count = 0.0; + for (int i = 0; i < mask_t.dimensions()[0]; i++) { + if (mask_t(i)) { + count += 1.0; + } + } + return count; +} + template static inline T CalcMSEWithMask(const Tensor& x, const Tensor& y, const Tensor& mask) { @@ -55,6 +67,21 @@ static inline T CalcMSEWithMask(const Tensor& x, const Tensor& y, return (error_sum / points); } +template +static void CalcMSEGradWithMask(Tensor* grad, const Tensor& x, const Tensor& y, + const Tensor& mask, T mf) { + auto grad_t = EigenVector::Flatten(*grad).setConstant(0.0); + auto x_t = EigenVector::Flatten(x); + auto y_t = EigenVector::Flatten(y); + auto mask_t = EigenVector::Flatten(mask); + + for (int i = 0; i < x_t.dimensions()[0]; i++) { + if (mask_t(i)) { + grad_t(i) = 2.0 * (x_t(i) - y_t(i)) / mf; + } + } +} + template static inline T CalcBCEWithMask(const Tensor& x, const Tensor& y, const Tensor& mask) { @@ -75,21 +102,34 @@ static inline T CalcBCEWithMask(const Tensor& x, const Tensor& y, } template -static void CalcPredResult(const Tensor& input, Tensor* pred_confs, - Tensor* pred_classes, Tensor* pred_x, Tensor* pred_y, - Tensor* pred_w, Tensor* pred_h, - std::vector anchors, const int class_num, - const int stride) { +static inline void CalcBCEGradWithMask(Tensor* grad, const Tensor& x, + const Tensor& y, const Tensor& mask, + T mf) { + auto grad_t = EigenVector::Flatten(*grad).setConstant(0.0); + auto x_t = EigenVector::Flatten(x); + auto y_t = EigenVector::Flatten(y); + auto mask_t = EigenVector::Flatten(mask); + + for (int i = 0; i < x_t.dimensions()[0]; i++) { + if (mask_t(i)) { + grad_t(i) = ((1.0 - y_t(i)) / (1.0 - x_t(i)) - y_t(i) / x_t(i)) / mf; + } + } +} + +template +static void CalcPredResult(const Tensor& input, Tensor* pred_conf, + Tensor* pred_class, Tensor* pred_x, Tensor* pred_y, + Tensor* pred_w, Tensor* pred_h, const int anchor_num, + const int class_num) { const int n = input.dims()[0]; - const int c = input.dims()[1]; const int h = input.dims()[2]; const int w = input.dims()[3]; - const int anchor_num = anchors.size() / 2; const int box_attr_num = 5 + class_num; auto input_t = EigenTensor::From(input); - auto pred_confs_t = EigenTensor::From(*pred_confs); - auto pred_classes_t = EigenTensor::From(*pred_classes); + auto pred_conf_t = EigenTensor::From(*pred_conf); + auto pred_class_t = EigenTensor::From(*pred_class); auto pred_x_t = EigenTensor::From(*pred_x); auto pred_y_t = EigenTensor::From(*pred_y); auto pred_w_t = EigenTensor::From(*pred_w); @@ -97,26 +137,23 @@ static void CalcPredResult(const Tensor& input, Tensor* pred_confs, for (int i = 0; i < n; i++) { for (int an_idx = 0; an_idx < anchor_num; an_idx++) { - float an_w = anchors[an_idx * 2] / stride; - float an_h = anchors[an_idx * 2 + 1] / stride; - for (int j = 0; j < h; j++) { for (int k = 0; k < w; k++) { pred_x_t(i, an_idx, j, k) = - sigmod(input_t(i, box_attr_num * an_idx, j, k)); + sigmoid(input_t(i, box_attr_num * an_idx, j, k)); pred_y_t(i, an_idx, j, k) = - sigmod(input_t(i, box_attr_num * an_idx + 1, j, k)); + sigmoid(input_t(i, box_attr_num * an_idx + 1, j, k)); pred_w_t(i, an_idx, j, k) = input_t(i, box_attr_num * an_idx + 2, j, k); pred_h_t(i, an_idx, j, k) = input_t(i, box_attr_num * an_idx + 3, j, k); - pred_confs_t(i, an_idx, j, k) = - sigmod(input_t(i, box_attr_num * an_idx + 4, j, k)); + pred_conf_t(i, an_idx, j, k) = + sigmoid(input_t(i, box_attr_num * an_idx + 4, j, k)); for (int c = 0; c < class_num; c++) { - pred_classes_t(i, an_idx, j, k, c) = - sigmod(input_t(i, box_attr_num * an_idx + 5 + c, j, k)); + pred_class_t(i, an_idx, j, k, c) = + sigmoid(input_t(i, box_attr_num * an_idx + 5 + c, j, k)); } } } @@ -148,27 +185,11 @@ static T CalcBoxIoU(std::vector box1, std::vector box2) { return inter_area / (b1_area + b2_area - inter_area); } -template -static inline int GetPredLabel(const Tensor& pred_classes, int n, - int best_an_index, int gj, int gi) { - auto pred_classes_t = EigenTensor::From(pred_classes); - T score = 0.0; - int label = -1; - for (int i = 0; i < pred_classes.dims()[4]; i++) { - if (pred_classes_t(n, best_an_index, gj, gi, i) > score) { - score = pred_classes_t(n, best_an_index, gj, gi, i); - label = i; - } - } - return label; -} - template static void PrePorcessGTBox(const Tensor& gt_boxes, const float ignore_thresh, - std::vector anchors, const int img_height, - const int grid_size, Tensor* obj_mask, - Tensor* noobj_mask, Tensor* tx, Tensor* ty, - Tensor* tw, Tensor* th, Tensor* tconf, + std::vector anchors, const int grid_size, + Tensor* obj_mask, Tensor* noobj_mask, Tensor* tx, + Tensor* ty, Tensor* tw, Tensor* th, Tensor* tconf, Tensor* tclass) { const int n = gt_boxes.dims()[0]; const int b = gt_boxes.dims()[1]; @@ -240,6 +261,61 @@ static void ExpandObjMaskByClassNum(Tensor* obj_mask_expand, .broadcast(Array5(1, 1, 1, 1, class_num)); } +template +static void AddAllGradToInputGrad( + Tensor* grad, T loss, const Tensor& pred_x, const Tensor& pred_y, + const Tensor& pred_conf, const Tensor& pred_class, const Tensor& grad_x, + const Tensor& grad_y, const Tensor& grad_w, const Tensor& grad_h, + const Tensor& grad_conf_obj, const Tensor& grad_conf_noobj, + const Tensor& grad_class, const int class_num) { + const int n = pred_x.dims()[0]; + const int an_num = pred_x.dims()[1]; + const int h = pred_x.dims()[2]; + const int w = pred_x.dims()[3]; + const int attr_num = class_num + 5; + auto grad_t = EigenTensor::From(*grad).setConstant(0.0); + auto pred_x_t = EigenTensor::From(pred_x); + auto pred_y_t = EigenTensor::From(pred_y); + auto pred_conf_t = EigenTensor::From(pred_conf); + auto pred_class_t = EigenTensor::From(pred_class); + auto grad_x_t = EigenTensor::From(grad_x); + auto grad_y_t = EigenTensor::From(grad_y); + auto grad_w_t = EigenTensor::From(grad_w); + auto grad_h_t = EigenTensor::From(grad_h); + auto grad_conf_obj_t = EigenTensor::From(grad_conf_obj); + auto grad_conf_noobj_t = EigenTensor::From(grad_conf_noobj); + auto grad_class_t = EigenTensor::From(grad_class); + + for (int i = 0; i < n; i++) { + for (int j = 0; j < an_num; j++) { + for (int k = 0; k < h; k++) { + for (int l = 0; l < w; l++) { + grad_t(i, j * attr_num, k, l) = grad_x_t(i, j, k, l) * + pred_x_t(i, j, k, l) * + (1.0 - pred_x_t(i, j, k, l)) * loss; + grad_t(i, j * attr_num + 1, k, l) = + grad_y_t(i, j, k, l) * pred_y_t(i, j, k, l) * + (1.0 - pred_y_t(i, j, k, l)) * loss; + grad_t(i, j * attr_num + 2, k, l) = grad_w_t(i, j, k, l) * loss; + grad_t(i, j * attr_num + 3, k, l) = grad_h_t(i, j, k, l) * loss; + grad_t(i, j * attr_num + 4, k, l) = + grad_conf_obj_t(i, j, k, l) * pred_conf_t(i, j, k, l) * + (1.0 - pred_conf_t(i, j, k, l)) * loss; + grad_t(i, j * attr_num + 4, k, l) += + grad_conf_noobj_t(i, j, k, l) * pred_conf_t(i, j, k, l) * + (1.0 - pred_conf_t(i, j, k, l)) * loss; + + for (int c = 0; c < class_num; c++) { + grad_t(i, j * attr_num + 5 + c, k, l) = + grad_class_t(i, j, k, l, c) * pred_class_t(i, j, k, l, c) * + (1.0 - pred_class_t(i, j, k, l, c)) * loss; + } + } + } + } + } +} + template class Yolov3LossKernel : public framework::OpKernel { public: @@ -247,28 +323,25 @@ class Yolov3LossKernel : public framework::OpKernel { auto* input = ctx.Input("X"); auto* gt_boxes = ctx.Input("GTBox"); auto* loss = ctx.Output("Loss"); - int img_height = ctx.Attr("img_height"); auto anchors = ctx.Attr>("anchors"); int class_num = ctx.Attr("class_num"); float ignore_thresh = ctx.Attr("ignore_thresh"); const int n = input->dims()[0]; - const int c = input->dims()[1]; const int h = input->dims()[2]; const int w = input->dims()[3]; const int an_num = anchors.size() / 2; - const T stride = static_cast(img_height) / h; Tensor pred_x, pred_y, pred_w, pred_h; - Tensor pred_confs, pred_classes; + Tensor pred_conf, pred_class; pred_x.mutable_data({n, an_num, h, w}, ctx.GetPlace()); pred_y.mutable_data({n, an_num, h, w}, ctx.GetPlace()); pred_w.mutable_data({n, an_num, h, w}, ctx.GetPlace()); pred_h.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - pred_confs.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - pred_classes.mutable_data({n, an_num, h, w, class_num}, ctx.GetPlace()); - CalcPredResult(*input, &pred_confs, &pred_classes, &pred_x, &pred_y, - &pred_w, &pred_h, anchors, class_num, stride); + pred_conf.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + pred_class.mutable_data({n, an_num, h, w, class_num}, ctx.GetPlace()); + CalcPredResult(*input, &pred_conf, &pred_class, &pred_x, &pred_y, + &pred_w, &pred_h, an_num, class_num); Tensor obj_mask, noobj_mask; Tensor tx, ty, tw, th, tconf, tclass; @@ -280,9 +353,8 @@ class Yolov3LossKernel : public framework::OpKernel { th.mutable_data({n, an_num, h, w}, ctx.GetPlace()); tconf.mutable_data({n, an_num, h, w}, ctx.GetPlace()); tclass.mutable_data({n, an_num, h, w, class_num}, ctx.GetPlace()); - PrePorcessGTBox(*gt_boxes, ignore_thresh, anchors, img_height, h, - &obj_mask, &noobj_mask, &tx, &ty, &tw, &th, &tconf, - &tclass); + PrePorcessGTBox(*gt_boxes, ignore_thresh, anchors, h, &obj_mask, + &noobj_mask, &tx, &ty, &tw, &th, &tconf, &tclass); Tensor obj_mask_expand; obj_mask_expand.mutable_data({n, an_num, h, w, class_num}, @@ -293,17 +365,9 @@ class Yolov3LossKernel : public framework::OpKernel { T loss_y = CalcMSEWithMask(pred_y, ty, obj_mask); T loss_w = CalcMSEWithMask(pred_w, tw, obj_mask); T loss_h = CalcMSEWithMask(pred_h, th, obj_mask); - T loss_conf_obj = CalcBCEWithMask(pred_confs, tconf, obj_mask); - T loss_conf_noobj = CalcBCEWithMask(pred_confs, tconf, noobj_mask); - T loss_class = CalcBCEWithMask(pred_classes, tclass, obj_mask_expand); - - // LOG(ERROR) << "loss_x: " << loss_x; - // LOG(ERROR) << "loss_y: " << loss_y; - // LOG(ERROR) << "loss_w: " << loss_w; - // LOG(ERROR) << "loss_h: " << loss_h; - // LOG(ERROR) << "loss_conf_obj: " << loss_conf_obj; - // LOG(ERROR) << "loss_conf_noobj: " << loss_conf_noobj; - // LOG(ERROR) << "loss_class: " << loss_class; + T loss_conf_obj = CalcBCEWithMask(pred_conf, tconf, obj_mask); + T loss_conf_noobj = CalcBCEWithMask(pred_conf, tconf, noobj_mask); + T loss_class = CalcBCEWithMask(pred_class, tclass, obj_mask_expand); auto* loss_data = loss->mutable_data({1}, ctx.GetPlace()); loss_data[0] = loss_x + loss_y + loss_w + loss_h + loss_conf_obj + @@ -315,8 +379,76 @@ template class Yolov3LossGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - auto* d_input_t = ctx.Output(framework::GradVarName("X")); - auto* d_output_t = ctx.Input(framework::GradVarName("Out")); + auto* input = ctx.Input("X"); + auto* gt_boxes = ctx.Input("GTBox"); + auto anchors = ctx.Attr>("anchors"); + int class_num = ctx.Attr("class_num"); + float ignore_thresh = ctx.Attr("ignore_thresh"); + auto* input_grad = ctx.Output(framework::GradVarName("X")); + auto* output_grad = ctx.Input(framework::GradVarName("Loss")); + const T loss = output_grad->data()[0]; + + const int n = input->dims()[0]; + const int c = input->dims()[1]; + const int h = input->dims()[2]; + const int w = input->dims()[3]; + const int an_num = anchors.size() / 2; + + Tensor pred_x, pred_y, pred_w, pred_h; + Tensor pred_conf, pred_class; + pred_x.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + pred_y.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + pred_w.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + pred_h.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + pred_conf.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + pred_class.mutable_data({n, an_num, h, w, class_num}, ctx.GetPlace()); + CalcPredResult(*input, &pred_conf, &pred_class, &pred_x, &pred_y, + &pred_w, &pred_h, an_num, class_num); + + Tensor obj_mask, noobj_mask; + Tensor tx, ty, tw, th, tconf, tclass; + obj_mask.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + noobj_mask.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + tx.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + ty.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + tw.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + th.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + tconf.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + tclass.mutable_data({n, an_num, h, w, class_num}, ctx.GetPlace()); + PrePorcessGTBox(*gt_boxes, ignore_thresh, anchors, h, &obj_mask, + &noobj_mask, &tx, &ty, &tw, &th, &tconf, &tclass); + + Tensor obj_mask_expand; + obj_mask_expand.mutable_data({n, an_num, h, w, class_num}, + ctx.GetPlace()); + ExpandObjMaskByClassNum(&obj_mask_expand, obj_mask); + + Tensor grad_x, grad_y, grad_w, grad_h; + Tensor grad_conf_obj, grad_conf_noobj, grad_class; + grad_x.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + grad_y.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + grad_w.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + grad_h.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + grad_conf_obj.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + grad_conf_noobj.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + grad_class.mutable_data({n, an_num, h, w, class_num}, ctx.GetPlace()); + T obj_mf = CalcMaskPointNum(obj_mask); + T noobj_mf = CalcMaskPointNum(noobj_mask); + T obj_expand_mf = CalcMaskPointNum(obj_mask_expand); + CalcMSEGradWithMask(&grad_x, pred_x, tx, obj_mask, obj_mf); + CalcMSEGradWithMask(&grad_y, pred_y, ty, obj_mask, obj_mf); + CalcMSEGradWithMask(&grad_w, pred_w, tw, obj_mask, obj_mf); + CalcMSEGradWithMask(&grad_h, pred_h, th, obj_mask, obj_mf); + CalcBCEGradWithMask(&grad_conf_obj, pred_conf, tconf, obj_mask, obj_mf); + CalcBCEGradWithMask(&grad_conf_noobj, pred_conf, tconf, noobj_mask, + noobj_mf); + CalcBCEGradWithMask(&grad_class, pred_class, tclass, obj_mask_expand, + obj_expand_mf); + + input_grad->mutable_data({n, c, h, w}, ctx.GetPlace()); + AddAllGradToInputGrad( + input_grad, loss, pred_x, pred_y, pred_conf, pred_class, grad_x, grad_y, + grad_w, grad_h, grad_conf_obj, grad_conf_noobj, grad_class, class_num); } }; diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 1ee7198f29..a4efb16682 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -8244,14 +8244,55 @@ def log_loss(input, label, epsilon=1e-4, name=None): return loss -def yolov3_loss(x, gtbox, img_height, anchors, ignore_thresh, name=None): +@templatedoc(op_type="yolov3_loss") +def yolov3_loss(x, gtbox, anchors, class_num, ignore_thresh, name=None): """ - **YOLOv3 Loss Layer** + ${comment} + + Args: + x (Variable): ${x_comment} + gtbox (Variable): groud truth boxes, shoulb be in shape of [N, B, 5], + in the third dimenstion, class_id, x, y, w, h should + be stored and x, y, w, h should be relative valud of + input image. + anchors (list|tuple): ${anchors_comment} + class_num (int): ${class_num_comment} + ignore_thresh (float): ${ignore_thresh_comment} + name (string): the name of yolov3 loss - This layer + Returns: + Variable: A 1-D tensor with shape [1], the value of yolov3 loss + + Raises: + TypeError: Input x of yolov3_loss must be Variable + TypeError: Input gtbox of yolov3_loss must be Variable" + TypeError: Attr anchors of yolov3_loss must be list or tuple + TypeError: Attr class_num of yolov3_loss must be an integer + TypeError: Attr ignore_thresh of yolov3_loss must be a float number + + Examples: + .. code-block:: python + + x = fluid.layers.data(name='x', shape=[10, 255, 13, 13], dtype='float32') + gtbox = fluid.layers.data(name='gtbox', shape=[10, 6, 5], dtype='float32') + anchors = [10, 13, 16, 30, 33, 23] + loss = fluid.layers.yolov3_loss(x=x, gtbox=gtbox, class_num=80 + anchors=anchors, ignore_thresh=0.5) """ helper = LayerHelper('yolov3_loss', **locals()) + if not isinstance(x, Variable): + raise TypeError("Input x of yolov3_loss must be Variable") + if not isinstance(gtbox, Variable): + raise TypeError("Input gtbox of yolov3_loss must be Variable") + if not isinstance(anchors, list) and not isinstance(anchors, tuple): + raise TypeError("Attr anchors of yolov3_loss must be list or tuple") + if not isinstance(class_num, int): + raise TypeError("Attr class_num of yolov3_loss must be an integer") + if not isinstance(ignore_thresh, float): + raise TypeError( + "Attr ignore_thresh of yolov3_loss must be a float number") + if name is None: loss = helper.create_variable_for_type_inference(dtype=x.dtype) else: @@ -8264,8 +8305,8 @@ def yolov3_loss(x, gtbox, img_height, anchors, ignore_thresh, name=None): "GTBox": gtbox}, outputs={'Loss': loss}, attrs={ - "img_height": img_height, "anchors": anchors, + "class_num": class_num, "ignore_thresh": ignore_thresh, }) return loss diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py index f48d9c84f9..dd02968c30 100644 --- a/python/paddle/fluid/tests/unittests/test_layers.py +++ b/python/paddle/fluid/tests/unittests/test_layers.py @@ -911,6 +911,15 @@ class TestBook(unittest.TestCase): self.assertIsNotNone(data_1) print(str(program)) + def test_yolov3_loss(self): + program = Program() + with program_guard(program): + x = layers.data(name='x', shape=[30, 7, 7], dtype='float32') + gtbox = layers.data(name='gtbox', shape=[10, 5], dtype='float32') + loss = layers.yolov3_loss(x, gtbox, [10, 13, 30, 13], 10, 0.5) + + self.assertIsNotNone(loss) + def test_bilinear_tensor_product_layer(self): program = Program() with program_guard(program): diff --git a/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py b/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py index f5b15efb27..4562f8bd49 100644 --- a/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py +++ b/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py @@ -12,10 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import division + import unittest import numpy as np from op_test import OpTest +from paddle.fluid import core + def sigmoid(x): return 1.0 / (1.0 + np.exp(-1.0 * x)) @@ -65,10 +69,9 @@ def box_iou(box1, box2): def build_target(gtboxs, attrs, grid_size): n, b, _ = gtboxs.shape ignore_thresh = attrs["ignore_thresh"] - img_height = attrs["img_height"] anchors = attrs["anchors"] class_num = attrs["class_num"] - an_num = len(anchors) / 2 + an_num = len(anchors) // 2 obj_mask = np.zeros((n, an_num, grid_size, grid_size)).astype('float32') noobj_mask = np.ones((n, an_num, grid_size, grid_size)).astype('float32') tx = np.zeros((n, an_num, grid_size, grid_size)).astype('float32') @@ -120,7 +123,7 @@ def build_target(gtboxs, attrs, grid_size): def YoloV3Loss(x, gtbox, attrs): n, c, h, w = x.shape - an_num = len(attrs['anchors']) / 2 + an_num = len(attrs['anchors']) // 2 class_num = attrs["class_num"] x = x.reshape((n, an_num, 5 + class_num, h, w)).transpose((0, 1, 3, 4, 2)) pred_x = sigmoid(x[:, :, :, :, 0]) @@ -144,13 +147,6 @@ def YoloV3Loss(x, gtbox, attrs): noobj_mask) loss_class = bce(pred_cls * obj_mask_expand, tcls * obj_mask_expand, obj_mask_expand) - # print "loss_x: ", loss_x - # print "loss_y: ", loss_y - # print "loss_w: ", loss_w - # print "loss_h: ", loss_h - # print "loss_conf_obj: ", loss_conf_obj - # print "loss_conf_noobj: ", loss_conf_noobj - # print "loss_class: ", loss_class return loss_x + loss_y + loss_w + loss_h + loss_conf_obj + loss_conf_noobj + loss_class @@ -165,29 +161,35 @@ class TestYolov3LossOp(OpTest): self.gtbox_shape[:2]) self.attrs = { - "img_height": self.img_height, "anchors": self.anchors, "class_num": self.class_num, "ignore_thresh": self.ignore_thresh, } self.inputs = {'X': x, 'GTBox': gtbox} - self.outputs = {'Loss': np.array([YoloV3Loss(x, gtbox, self.attrs)])} - print self.outputs + self.outputs = { + 'Loss': + np.array([YoloV3Loss(x, gtbox, self.attrs)]).astype('float32') + } def test_check_output(self): - self.check_output(atol=1e-3) + place = core.CPUPlace() + self.check_output_with_place(place, atol=1e-3) - # def test_check_grad_normal(self): - # self.check_grad(['X', 'Grid'], 'Output', max_relative_error=0.61) + def test_check_grad_ignore_gtbox(self): + place = core.CPUPlace() + self.check_grad_with_place( + place, ['X'], + 'Loss', + no_grad_set=set("GTBox"), + max_relative_error=0.1) def initTestCase(self): - self.img_height = 608 - self.anchors = [10, 13, 16, 30, 33, 23] + self.anchors = [10, 13, 12, 12] self.class_num = 10 self.ignore_thresh = 0.5 - self.x_shape = (5, len(self.anchors) / 2 * (5 + self.class_num), 7, 7) - self.gtbox_shape = (5, 10, 5) + self.x_shape = (5, len(self.anchors) // 2 * (5 + self.class_num), 7, 7) + self.gtbox_shape = (5, 5, 5) if __name__ == "__main__": From 2faa2b4048d14e24acd3f8a3f8c55c2f492d0285 Mon Sep 17 00:00:00 2001 From: dengkaipeng Date: Tue, 13 Nov 2018 20:08:54 +0800 Subject: [PATCH 5/8] remove cu file. test=develop --- paddle/fluid/API.spec | 2 +- paddle/fluid/operators/yolov3_loss_op.cc | 36 ++++++- paddle/fluid/operators/yolov3_loss_op.cu | 23 ----- paddle/fluid/operators/yolov3_loss_op.h | 43 +++++--- python/paddle/fluid/layers/detection.py | 98 +++++++++++++++++++ python/paddle/fluid/layers/nn.py | 69 ------------- .../tests/unittests/test_yolov3_loss_op.py | 23 ++++- 7 files changed, 182 insertions(+), 112 deletions(-) delete mode 100644 paddle/fluid/operators/yolov3_loss_op.cu diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index 8344a913e9..7e0d5e6088 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -183,7 +183,6 @@ paddle.fluid.layers.similarity_focus ArgSpec(args=['input', 'axis', 'indexes', ' paddle.fluid.layers.hash ArgSpec(args=['input', 'hash_size', 'num_hash', 'name'], varargs=None, keywords=None, defaults=(1, None)) paddle.fluid.layers.grid_sampler ArgSpec(args=['x', 'grid', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.log_loss ArgSpec(args=['input', 'label', 'epsilon', 'name'], varargs=None, keywords=None, defaults=(0.0001, None)) -paddle.fluid.layers.yolov3_loss ArgSpec(args=['x', 'gtbox', 'anchors', 'class_num', 'ignore_thresh', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.add_position_encoding ArgSpec(args=['input', 'alpha', 'beta', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.bilinear_tensor_product ArgSpec(args=['x', 'y', 'size', 'act', 'name', 'param_attr', 'bias_attr'], varargs=None, keywords=None, defaults=(None, None, None, None)) paddle.fluid.layers.data ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True)) @@ -289,6 +288,7 @@ paddle.fluid.layers.generate_proposals ArgSpec(args=['scores', 'bbox_deltas', 'i paddle.fluid.layers.iou_similarity ArgSpec(args=['x', 'y', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.box_coder ArgSpec(args=['prior_box', 'prior_box_var', 'target_box', 'code_type', 'box_normalized', 'name'], varargs=None, keywords=None, defaults=('encode_center_size', True, None)) paddle.fluid.layers.polygon_box_transform ArgSpec(args=['input', 'name'], varargs=None, keywords=None, defaults=(None,)) +paddle.fluid.layers.yolov3_loss ArgSpec(args=['x', 'gtbox', 'anchors', 'class_num', 'ignore_thresh', 'lambda_xy', 'lambda_wh', 'lambda_conf_obj', 'lambda_conf_noobj', 'lambda_class', 'name'], varargs=None, keywords=None, defaults=(None, None, None, None, None, None)) paddle.fluid.layers.accuracy ArgSpec(args=['input', 'label', 'k', 'correct', 'total'], varargs=None, keywords=None, defaults=(1, None, None)) paddle.fluid.layers.auc ArgSpec(args=['input', 'label', 'curve', 'num_thresholds', 'topk', 'slide_steps'], varargs=None, keywords=None, defaults=('ROC', 4095, 1, 1)) paddle.fluid.layers.exponential_decay ArgSpec(args=['learning_rate', 'decay_steps', 'decay_rate', 'staircase'], varargs=None, keywords=None, defaults=(False,)) diff --git a/paddle/fluid/operators/yolov3_loss_op.cc b/paddle/fluid/operators/yolov3_loss_op.cc index cf25e99505..f6c134e1b4 100644 --- a/paddle/fluid/operators/yolov3_loss_op.cc +++ b/paddle/fluid/operators/yolov3_loss_op.cc @@ -55,7 +55,8 @@ class Yolov3LossOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { return framework::OpKernelType( - framework::ToDataType(ctx.Input("X")->type()), ctx.GetPlace()); + framework::ToDataType(ctx.Input("X")->type()), + platform::CPUPlace()); } }; @@ -63,8 +64,11 @@ class Yolov3LossOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { AddInput("X", - "The input tensor of bilinear interpolation, " - "This is a 4-D tensor with shape of [N, C, H, W]"); + "The input tensor of YOLO v3 loss operator, " + "This is a 4-D tensor with shape of [N, C, H, W]." + "H and W should be same, and the second dimention(C) stores" + "box locations, confidence score and classification one-hot" + "key of each anchor box"); AddInput("GTBox", "The input tensor of ground truth boxes, " "This is a 3-D tensor with shape of [N, max_box_num, 5], " @@ -84,6 +88,20 @@ class Yolov3LossOpMaker : public framework::OpProtoAndCheckerMaker { "it will be parsed pair by pair."); AddAttr("ignore_thresh", "The ignore threshold to ignore confidence loss."); + AddAttr("lambda_xy", "The weight of x, y location loss.") + .SetDefault(1.0); + AddAttr("lambda_wh", "The weight of w, h location loss.") + .SetDefault(1.0); + AddAttr( + "lambda_conf_obj", + "The weight of confidence score loss in locations with target object.") + .SetDefault(1.0); + AddAttr("lambda_conf_noobj", + "The weight of confidence score loss in locations without " + "target object.") + .SetDefault(1.0); + AddAttr("lambda_class", "The weight of classification loss.") + .SetDefault(1.0); AddComment(R"DOC( This operator generate yolov3 loss by given predict result and ground truth boxes. @@ -119,6 +137,15 @@ class Yolov3LossOpMaker : public framework::OpProtoAndCheckerMaker { confidence score loss, and classification loss. The MSE loss is used for box location, and binary cross entropy loss is used for confidence score loss and classification loss. + + Final loss will be represented as follow. + + $$ + loss = \lambda_{xy} * loss_{xy} + \lambda_{wh} * loss_{wh} + + \lambda_{conf_obj} * loss_{conf_obj} + + \lambda_{conf_noobj} * loss_{conf_noobj} + + \lambda_{class} * loss_{class} + $$ )DOC"); } }; @@ -140,7 +167,8 @@ class Yolov3LossOpGrad : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { return framework::OpKernelType( - framework::ToDataType(ctx.Input("X")->type()), ctx.GetPlace()); + framework::ToDataType(ctx.Input("X")->type()), + platform::CPUPlace()); } }; diff --git a/paddle/fluid/operators/yolov3_loss_op.cu b/paddle/fluid/operators/yolov3_loss_op.cu deleted file mode 100644 index f901b10d38..0000000000 --- a/paddle/fluid/operators/yolov3_loss_op.cu +++ /dev/null @@ -1,23 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. */ - -#define EIGEN_USE_GPU - -#include "paddle/fluid/operators/yolov3_loss_op.h" -#include "paddle/fluid/platform/cuda_primitives.h" - -namespace ops = paddle::operators; -REGISTER_OP_CUDA_KERNEL( - yolov3_loss, - ops::Yolov3LossKernel); -REGISTER_OP_CUDA_KERNEL( - yolov3_loss_grad, - ops::Yolov3LossGradKernel); diff --git a/paddle/fluid/operators/yolov3_loss_op.h b/paddle/fluid/operators/yolov3_loss_op.h index a2ed4440a7..f4ede92589 100644 --- a/paddle/fluid/operators/yolov3_loss_op.h +++ b/paddle/fluid/operators/yolov3_loss_op.h @@ -267,7 +267,9 @@ static void AddAllGradToInputGrad( const Tensor& pred_conf, const Tensor& pred_class, const Tensor& grad_x, const Tensor& grad_y, const Tensor& grad_w, const Tensor& grad_h, const Tensor& grad_conf_obj, const Tensor& grad_conf_noobj, - const Tensor& grad_class, const int class_num) { + const Tensor& grad_class, const int class_num, const float lambda_xy, + const float lambda_wh, const float lambda_conf_obj, + const float lambda_conf_noobj, const float lambda_class) { const int n = pred_x.dims()[0]; const int an_num = pred_x.dims()[1]; const int h = pred_x.dims()[2]; @@ -290,25 +292,27 @@ static void AddAllGradToInputGrad( for (int j = 0; j < an_num; j++) { for (int k = 0; k < h; k++) { for (int l = 0; l < w; l++) { - grad_t(i, j * attr_num, k, l) = grad_x_t(i, j, k, l) * - pred_x_t(i, j, k, l) * - (1.0 - pred_x_t(i, j, k, l)) * loss; + grad_t(i, j * attr_num, k, l) = + grad_x_t(i, j, k, l) * pred_x_t(i, j, k, l) * + (1.0 - pred_x_t(i, j, k, l)) * loss * lambda_xy; grad_t(i, j * attr_num + 1, k, l) = grad_y_t(i, j, k, l) * pred_y_t(i, j, k, l) * - (1.0 - pred_y_t(i, j, k, l)) * loss; - grad_t(i, j * attr_num + 2, k, l) = grad_w_t(i, j, k, l) * loss; - grad_t(i, j * attr_num + 3, k, l) = grad_h_t(i, j, k, l) * loss; + (1.0 - pred_y_t(i, j, k, l)) * loss * lambda_xy; + grad_t(i, j * attr_num + 2, k, l) = + grad_w_t(i, j, k, l) * loss * lambda_wh; + grad_t(i, j * attr_num + 3, k, l) = + grad_h_t(i, j, k, l) * loss * lambda_wh; grad_t(i, j * attr_num + 4, k, l) = grad_conf_obj_t(i, j, k, l) * pred_conf_t(i, j, k, l) * - (1.0 - pred_conf_t(i, j, k, l)) * loss; + (1.0 - pred_conf_t(i, j, k, l)) * loss * lambda_conf_obj; grad_t(i, j * attr_num + 4, k, l) += grad_conf_noobj_t(i, j, k, l) * pred_conf_t(i, j, k, l) * - (1.0 - pred_conf_t(i, j, k, l)) * loss; + (1.0 - pred_conf_t(i, j, k, l)) * loss * lambda_conf_noobj; for (int c = 0; c < class_num; c++) { grad_t(i, j * attr_num + 5 + c, k, l) = grad_class_t(i, j, k, l, c) * pred_class_t(i, j, k, l, c) * - (1.0 - pred_class_t(i, j, k, l, c)) * loss; + (1.0 - pred_class_t(i, j, k, l, c)) * loss * lambda_class; } } } @@ -326,6 +330,11 @@ class Yolov3LossKernel : public framework::OpKernel { auto anchors = ctx.Attr>("anchors"); int class_num = ctx.Attr("class_num"); float ignore_thresh = ctx.Attr("ignore_thresh"); + float lambda_xy = ctx.Attr("lambda_xy"); + float lambda_wh = ctx.Attr("lambda_wh"); + float lambda_conf_obj = ctx.Attr("lambda_conf_obj"); + float lambda_conf_noobj = ctx.Attr("lambda_conf_noobj"); + float lambda_class = ctx.Attr("lambda_class"); const int n = input->dims()[0]; const int h = input->dims()[2]; @@ -370,8 +379,10 @@ class Yolov3LossKernel : public framework::OpKernel { T loss_class = CalcBCEWithMask(pred_class, tclass, obj_mask_expand); auto* loss_data = loss->mutable_data({1}, ctx.GetPlace()); - loss_data[0] = loss_x + loss_y + loss_w + loss_h + loss_conf_obj + - loss_conf_noobj + loss_class; + loss_data[0] = + lambda_xy * (loss_x + loss_y) + lambda_wh * (loss_w + loss_h) + + lambda_conf_obj * loss_conf_obj + lambda_conf_noobj * loss_conf_noobj + + lambda_class * loss_class; } }; @@ -387,6 +398,11 @@ class Yolov3LossGradKernel : public framework::OpKernel { auto* input_grad = ctx.Output(framework::GradVarName("X")); auto* output_grad = ctx.Input(framework::GradVarName("Loss")); const T loss = output_grad->data()[0]; + float lambda_xy = ctx.Attr("lambda_xy"); + float lambda_wh = ctx.Attr("lambda_wh"); + float lambda_conf_obj = ctx.Attr("lambda_conf_obj"); + float lambda_conf_noobj = ctx.Attr("lambda_conf_noobj"); + float lambda_class = ctx.Attr("lambda_class"); const int n = input->dims()[0]; const int c = input->dims()[1]; @@ -448,7 +464,8 @@ class Yolov3LossGradKernel : public framework::OpKernel { input_grad->mutable_data({n, c, h, w}, ctx.GetPlace()); AddAllGradToInputGrad( input_grad, loss, pred_x, pred_y, pred_conf, pred_class, grad_x, grad_y, - grad_w, grad_h, grad_conf_obj, grad_conf_noobj, grad_class, class_num); + grad_w, grad_h, grad_conf_obj, grad_conf_noobj, grad_class, class_num, + lambda_xy, lambda_wh, lambda_conf_obj, lambda_conf_noobj, lambda_class); } }; diff --git a/python/paddle/fluid/layers/detection.py b/python/paddle/fluid/layers/detection.py index 4ac94981a7..2bb9514803 100644 --- a/python/paddle/fluid/layers/detection.py +++ b/python/paddle/fluid/layers/detection.py @@ -20,6 +20,7 @@ from __future__ import print_function from .layer_function_generator import generate_layer_fn from .layer_function_generator import autodoc, templatedoc from ..layer_helper import LayerHelper +from ..framework import Variable from . import tensor from . import nn from . import ops @@ -45,6 +46,7 @@ __all__ = [ 'iou_similarity', 'box_coder', 'polygon_box_transform', + 'yolov3_loss', ] @@ -404,6 +406,102 @@ def polygon_box_transform(input, name=None): return output +@templatedoc(op_type="yolov3_loss") +def yolov3_loss(x, + gtbox, + anchors, + class_num, + ignore_thresh, + lambda_xy=None, + lambda_wh=None, + lambda_conf_obj=None, + lambda_conf_noobj=None, + lambda_class=None, + name=None): + """ + ${comment} + + Args: + x (Variable): ${x_comment} + gtbox (Variable): groud truth boxes, shoulb be in shape of [N, B, 5], + in the third dimenstion, class_id, x, y, w, h should + be stored and x, y, w, h should be relative valud of + input image. + anchors (list|tuple): ${anchors_comment} + class_num (int): ${class_num_comment} + ignore_thresh (float): ${ignore_thresh_comment} + lambda_xy (float|None): ${lambda_xy_comment} + lambda_wh (float|None): ${lambda_wh_comment} + lambda_conf_obj (float|None): ${lambda_conf_obj_comment} + lambda_conf_noobj (float|None): ${lambda_conf_noobj_comment} + lambda_class (float|None): ${lambda_class_comment} + name (string): the name of yolov3 loss + + Returns: + Variable: A 1-D tensor with shape [1], the value of yolov3 loss + + Raises: + TypeError: Input x of yolov3_loss must be Variable + TypeError: Input gtbox of yolov3_loss must be Variable" + TypeError: Attr anchors of yolov3_loss must be list or tuple + TypeError: Attr class_num of yolov3_loss must be an integer + TypeError: Attr ignore_thresh of yolov3_loss must be a float number + + Examples: + .. code-block:: python + + x = fluid.layers.data(name='x', shape=[10, 255, 13, 13], dtype='float32') + gtbox = fluid.layers.data(name='gtbox', shape=[10, 6, 5], dtype='float32') + anchors = [10, 13, 16, 30, 33, 23] + loss = fluid.layers.yolov3_loss(x=x, gtbox=gtbox, class_num=80 + anchors=anchors, ignore_thresh=0.5) + """ + helper = LayerHelper('yolov3_loss', **locals()) + + if not isinstance(x, Variable): + raise TypeError("Input x of yolov3_loss must be Variable") + if not isinstance(gtbox, Variable): + raise TypeError("Input gtbox of yolov3_loss must be Variable") + if not isinstance(anchors, list) and not isinstance(anchors, tuple): + raise TypeError("Attr anchors of yolov3_loss must be list or tuple") + if not isinstance(class_num, int): + raise TypeError("Attr class_num of yolov3_loss must be an integer") + if not isinstance(ignore_thresh, float): + raise TypeError( + "Attr ignore_thresh of yolov3_loss must be a float number") + + if name is None: + loss = helper.create_variable_for_type_inference(dtype=x.dtype) + else: + loss = helper.create_variable( + name=name, dtype=x.dtype, persistable=False) + + attrs = { + "anchors": anchors, + "class_num": class_num, + "ignore_thresh": ignore_thresh, + } + + if lambda_xy is not None and isinstance(lambda_xy, float): + self.attrs['lambda_xy'] = lambda_xy + if lambda_wh is not None and isinstance(lambda_wh, float): + self.attrs['lambda_wh'] = lambda_wh + if lambda_conf_obj is not None and isinstance(lambda_conf_obj, float): + self.attrs['lambda_conf_obj'] = lambda_conf_obj + if lambda_conf_noobj is not None and isinstance(lambda_conf_noobj, float): + self.attrs['lambda_conf_noobj'] = lambda_conf_noobj + if lambda_class is not None and isinstance(lambda_class, float): + self.attrs['lambda_class'] = lambda_class + + helper.append_op( + type='yolov3_loss', + inputs={'X': x, + "GTBox": gtbox}, + outputs={'Loss': loss}, + attrs=attrs) + return loss + + @templatedoc() def detection_map(detect_res, label, diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index a4efb16682..d3623464e9 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -164,7 +164,6 @@ __all__ = [ 'hash', 'grid_sampler', 'log_loss', - 'yolov3_loss', 'add_position_encoding', 'bilinear_tensor_product', ] @@ -8244,74 +8243,6 @@ def log_loss(input, label, epsilon=1e-4, name=None): return loss -@templatedoc(op_type="yolov3_loss") -def yolov3_loss(x, gtbox, anchors, class_num, ignore_thresh, name=None): - """ - ${comment} - - Args: - x (Variable): ${x_comment} - gtbox (Variable): groud truth boxes, shoulb be in shape of [N, B, 5], - in the third dimenstion, class_id, x, y, w, h should - be stored and x, y, w, h should be relative valud of - input image. - anchors (list|tuple): ${anchors_comment} - class_num (int): ${class_num_comment} - ignore_thresh (float): ${ignore_thresh_comment} - name (string): the name of yolov3 loss - - Returns: - Variable: A 1-D tensor with shape [1], the value of yolov3 loss - - Raises: - TypeError: Input x of yolov3_loss must be Variable - TypeError: Input gtbox of yolov3_loss must be Variable" - TypeError: Attr anchors of yolov3_loss must be list or tuple - TypeError: Attr class_num of yolov3_loss must be an integer - TypeError: Attr ignore_thresh of yolov3_loss must be a float number - - Examples: - .. code-block:: python - - x = fluid.layers.data(name='x', shape=[10, 255, 13, 13], dtype='float32') - gtbox = fluid.layers.data(name='gtbox', shape=[10, 6, 5], dtype='float32') - anchors = [10, 13, 16, 30, 33, 23] - loss = fluid.layers.yolov3_loss(x=x, gtbox=gtbox, class_num=80 - anchors=anchors, ignore_thresh=0.5) - """ - helper = LayerHelper('yolov3_loss', **locals()) - - if not isinstance(x, Variable): - raise TypeError("Input x of yolov3_loss must be Variable") - if not isinstance(gtbox, Variable): - raise TypeError("Input gtbox of yolov3_loss must be Variable") - if not isinstance(anchors, list) and not isinstance(anchors, tuple): - raise TypeError("Attr anchors of yolov3_loss must be list or tuple") - if not isinstance(class_num, int): - raise TypeError("Attr class_num of yolov3_loss must be an integer") - if not isinstance(ignore_thresh, float): - raise TypeError( - "Attr ignore_thresh of yolov3_loss must be a float number") - - if name is None: - loss = helper.create_variable_for_type_inference(dtype=x.dtype) - else: - loss = helper.create_variable( - name=name, dtype=x.dtype, persistable=False) - - helper.append_op( - type='yolov3_loss', - inputs={'X': x, - "GTBox": gtbox}, - outputs={'Loss': loss}, - attrs={ - "anchors": anchors, - "class_num": class_num, - "ignore_thresh": ignore_thresh, - }) - return loss - - def add_position_encoding(input, alpha, beta, name=None): """ **Add Position Encoding Layer** diff --git a/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py b/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py index 4562f8bd49..3b6d58563f 100644 --- a/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py +++ b/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py @@ -148,11 +148,20 @@ def YoloV3Loss(x, gtbox, attrs): loss_class = bce(pred_cls * obj_mask_expand, tcls * obj_mask_expand, obj_mask_expand) - return loss_x + loss_y + loss_w + loss_h + loss_conf_obj + loss_conf_noobj + loss_class + return attrs['lambda_xy'] * (loss_x + loss_y) \ + + attrs['lambda_wh'] * (loss_w + loss_h) \ + + attrs['lambda_conf_obj'] * loss_conf_obj \ + + attrs['lambda_conf_noobj'] * loss_conf_noobj \ + + attrs['lambda_class'] * loss_class class TestYolov3LossOp(OpTest): def setUp(self): + self.lambda_xy = 1.0 + self.lambda_wh = 1.0 + self.lambda_conf_obj = 1.0 + self.lambda_conf_noobj = 1.0 + self.lambda_class = 1.0 self.initTestCase() self.op_type = 'yolov3_loss' x = np.random.random(size=self.x_shape).astype('float32') @@ -164,6 +173,11 @@ class TestYolov3LossOp(OpTest): "anchors": self.anchors, "class_num": self.class_num, "ignore_thresh": self.ignore_thresh, + "lambda_xy": self.lambda_xy, + "lambda_wh": self.lambda_wh, + "lambda_conf_obj": self.lambda_conf_obj, + "lambda_conf_noobj": self.lambda_conf_noobj, + "lambda_class": self.lambda_class, } self.inputs = {'X': x, 'GTBox': gtbox} @@ -182,7 +196,7 @@ class TestYolov3LossOp(OpTest): place, ['X'], 'Loss', no_grad_set=set("GTBox"), - max_relative_error=0.1) + max_relative_error=0.06) def initTestCase(self): self.anchors = [10, 13, 12, 12] @@ -190,6 +204,11 @@ class TestYolov3LossOp(OpTest): self.ignore_thresh = 0.5 self.x_shape = (5, len(self.anchors) // 2 * (5 + self.class_num), 7, 7) self.gtbox_shape = (5, 5, 5) + self.lambda_xy = 2.5 + self.lambda_wh = 0.8 + self.lambda_conf_obj = 1.5 + self.lambda_conf_noobj = 0.5 + self.lambda_class = 1.2 if __name__ == "__main__": From 95d5060dddcbfd0eff8cb50d542f5adb6899b6b6 Mon Sep 17 00:00:00 2001 From: dengkaipeng Date: Thu, 15 Nov 2018 18:57:49 +0800 Subject: [PATCH 6/8] fix abs -> fabs error. test=develop --- paddle/fluid/operators/yolov3_loss_op.h | 13 +++++++------ .../fluid/tests/unittests/test_yolov3_loss_op.py | 14 +++++++------- 2 files changed, 14 insertions(+), 13 deletions(-) diff --git a/paddle/fluid/operators/yolov3_loss_op.h b/paddle/fluid/operators/yolov3_loss_op.h index f4ede92589..608ef3f94b 100644 --- a/paddle/fluid/operators/yolov3_loss_op.h +++ b/paddle/fluid/operators/yolov3_loss_op.h @@ -29,7 +29,7 @@ using Array5 = Eigen::DSizes; template static inline bool isZero(T x) { - return abs(x) < 1e-6; + return fabs(x) < 1e-6; } template @@ -186,7 +186,7 @@ static T CalcBoxIoU(std::vector box1, std::vector box2) { } template -static void PrePorcessGTBox(const Tensor& gt_boxes, const float ignore_thresh, +static void PreProcessGTBox(const Tensor& gt_boxes, const float ignore_thresh, std::vector anchors, const int grid_size, Tensor* obj_mask, Tensor* noobj_mask, Tensor* tx, Tensor* ty, Tensor* tw, Tensor* th, Tensor* tconf, @@ -206,8 +206,9 @@ static void PrePorcessGTBox(const Tensor& gt_boxes, const float ignore_thresh, for (int i = 0; i < n; i++) { for (int j = 0; j < b; j++) { - if (isZero(gt_boxes_t(i, j, 0)) && isZero(gt_boxes_t(i, j, 1)) && - isZero(gt_boxes_t(i, j, 2)) && isZero(gt_boxes_t(i, j, 3))) { + if (isZero(gt_boxes_t(i, j, 0)) && isZero(gt_boxes_t(i, j, 1)) && + isZero(gt_boxes_t(i, j, 2)) && isZero(gt_boxes_t(i, j, 3)) && + isZero(gt_boxes_t(i, j, 4))) { continue; } @@ -362,7 +363,7 @@ class Yolov3LossKernel : public framework::OpKernel { th.mutable_data({n, an_num, h, w}, ctx.GetPlace()); tconf.mutable_data({n, an_num, h, w}, ctx.GetPlace()); tclass.mutable_data({n, an_num, h, w, class_num}, ctx.GetPlace()); - PrePorcessGTBox(*gt_boxes, ignore_thresh, anchors, h, &obj_mask, + PreProcessGTBox(*gt_boxes, ignore_thresh, anchors, h, &obj_mask, &noobj_mask, &tx, &ty, &tw, &th, &tconf, &tclass); Tensor obj_mask_expand; @@ -431,7 +432,7 @@ class Yolov3LossGradKernel : public framework::OpKernel { th.mutable_data({n, an_num, h, w}, ctx.GetPlace()); tconf.mutable_data({n, an_num, h, w}, ctx.GetPlace()); tclass.mutable_data({n, an_num, h, w, class_num}, ctx.GetPlace()); - PrePorcessGTBox(*gt_boxes, ignore_thresh, anchors, h, &obj_mask, + PreProcessGTBox(*gt_boxes, ignore_thresh, anchors, h, &obj_mask, &noobj_mask, &tx, &ty, &tw, &th, &tconf, &tclass); Tensor obj_mask_expand; diff --git a/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py b/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py index 3b6d58563f..03a64055f0 100644 --- a/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py +++ b/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py @@ -190,13 +190,13 @@ class TestYolov3LossOp(OpTest): place = core.CPUPlace() self.check_output_with_place(place, atol=1e-3) - def test_check_grad_ignore_gtbox(self): - place = core.CPUPlace() - self.check_grad_with_place( - place, ['X'], - 'Loss', - no_grad_set=set("GTBox"), - max_relative_error=0.06) + # def test_check_grad_ignore_gtbox(self): + # place = core.CPUPlace() + # self.check_grad_with_place( + # place, ['X'], + # 'Loss', + # no_grad_set=set("GTBox"), + # max_relative_error=0.06) def initTestCase(self): self.anchors = [10, 13, 12, 12] From f115eb0d1e6ffa1dd65bfcc7b30b419d52f3c68b Mon Sep 17 00:00:00 2001 From: dengkaipeng Date: Thu, 15 Nov 2018 21:05:28 +0800 Subject: [PATCH 7/8] enhance api. test=develop --- paddle/fluid/API.spec | 2 +- paddle/fluid/operators/yolov3_loss_op.cc | 50 ++++--- paddle/fluid/operators/yolov3_loss_op.h | 129 ++++++++++-------- python/paddle/fluid/layers/detection.py | 67 +++++---- python/paddle/fluid/tests/test_detection.py | 13 ++ .../fluid/tests/unittests/test_layers.py | 9 -- .../tests/unittests/test_yolov3_loss_op.py | 88 ++++++------ 7 files changed, 199 insertions(+), 159 deletions(-) diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index 7e0d5e6088..1f1dc3757d 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -288,7 +288,7 @@ paddle.fluid.layers.generate_proposals ArgSpec(args=['scores', 'bbox_deltas', 'i paddle.fluid.layers.iou_similarity ArgSpec(args=['x', 'y', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.box_coder ArgSpec(args=['prior_box', 'prior_box_var', 'target_box', 'code_type', 'box_normalized', 'name'], varargs=None, keywords=None, defaults=('encode_center_size', True, None)) paddle.fluid.layers.polygon_box_transform ArgSpec(args=['input', 'name'], varargs=None, keywords=None, defaults=(None,)) -paddle.fluid.layers.yolov3_loss ArgSpec(args=['x', 'gtbox', 'anchors', 'class_num', 'ignore_thresh', 'lambda_xy', 'lambda_wh', 'lambda_conf_obj', 'lambda_conf_noobj', 'lambda_class', 'name'], varargs=None, keywords=None, defaults=(None, None, None, None, None, None)) +paddle.fluid.layers.yolov3_loss ArgSpec(args=['x', 'gtbox', 'gtlabel', 'anchors', 'class_num', 'ignore_thresh', 'loss_weight_xy', 'loss_weight_wh', 'loss_weight_conf_target', 'loss_weight_conf_notarget', 'loss_weight_class', 'name'], varargs=None, keywords=None, defaults=(None, None, None, None, None, None)) paddle.fluid.layers.accuracy ArgSpec(args=['input', 'label', 'k', 'correct', 'total'], varargs=None, keywords=None, defaults=(1, None, None)) paddle.fluid.layers.auc ArgSpec(args=['input', 'label', 'curve', 'num_thresholds', 'topk', 'slide_steps'], varargs=None, keywords=None, defaults=('ROC', 4095, 1, 1)) paddle.fluid.layers.exponential_decay ArgSpec(args=['learning_rate', 'decay_steps', 'decay_rate', 'staircase'], varargs=None, keywords=None, defaults=(False,)) diff --git a/paddle/fluid/operators/yolov3_loss_op.cc b/paddle/fluid/operators/yolov3_loss_op.cc index f6c134e1b4..1d7f482362 100644 --- a/paddle/fluid/operators/yolov3_loss_op.cc +++ b/paddle/fluid/operators/yolov3_loss_op.cc @@ -25,11 +25,14 @@ class Yolov3LossOp : public framework::OperatorWithKernel { "Input(X) of Yolov3LossOp should not be null."); PADDLE_ENFORCE(ctx->HasInput("GTBox"), "Input(GTBox) of Yolov3LossOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("GTLabel"), + "Input(GTLabel) of Yolov3LossOp should not be null."); PADDLE_ENFORCE(ctx->HasOutput("Loss"), "Output(Loss) of Yolov3LossOp should not be null."); auto dim_x = ctx->GetInputDim("X"); - auto dim_gt = ctx->GetInputDim("GTBox"); + auto dim_gtbox = ctx->GetInputDim("GTBox"); + auto dim_gtlabel = ctx->GetInputDim("GTLabel"); auto anchors = ctx->Attrs().Get>("anchors"); auto class_num = ctx->Attrs().Get("class_num"); PADDLE_ENFORCE_EQ(dim_x.size(), 4, "Input(X) should be a 4-D tensor."); @@ -38,8 +41,15 @@ class Yolov3LossOp : public framework::OperatorWithKernel { PADDLE_ENFORCE_EQ(dim_x[1], anchors.size() / 2 * (5 + class_num), "Input(X) dim[1] should be equal to (anchor_number * (5 " "+ class_num))."); - PADDLE_ENFORCE_EQ(dim_gt.size(), 3, "Input(GTBox) should be a 3-D tensor"); - PADDLE_ENFORCE_EQ(dim_gt[2], 5, "Input(GTBox) dim[2] should be 5"); + PADDLE_ENFORCE_EQ(dim_gtbox.size(), 3, + "Input(GTBox) should be a 3-D tensor"); + PADDLE_ENFORCE_EQ(dim_gtbox[2], 4, "Input(GTBox) dim[2] should be 5"); + PADDLE_ENFORCE_EQ(dim_gtlabel.size(), 2, + "Input(GTBox) should be a 2-D tensor"); + PADDLE_ENFORCE_EQ(dim_gtlabel[0], dim_gtbox[0], + "Input(GTBox) and Input(GTLabel) dim[0] should be same"); + PADDLE_ENFORCE_EQ(dim_gtlabel[1], dim_gtbox[1], + "Input(GTBox) and Input(GTLabel) dim[1] should be same"); PADDLE_ENFORCE_GT(anchors.size(), 0, "Attr(anchors) length should be greater then 0."); PADDLE_ENFORCE_EQ(anchors.size() % 2, 0, @@ -73,11 +83,15 @@ class Yolov3LossOpMaker : public framework::OpProtoAndCheckerMaker { "The input tensor of ground truth boxes, " "This is a 3-D tensor with shape of [N, max_box_num, 5], " "max_box_num is the max number of boxes in each image, " - "In the third dimention, stores label, x, y, w, h, " - "label is an integer to specify box class, x, y is the " - "center cordinate of boxes and w, h is the width and height" - "and x, y, w, h should be divided by input image height to " - "scale to [0, 1]."); + "In the third dimention, stores x, y, w, h coordinates, " + "x, y is the center cordinate of boxes and w, h is the " + "width and height and x, y, w, h should be divided by " + "input image height to scale to [0, 1]."); + AddInput("GTLabel", + "The input tensor of ground truth label, " + "This is a 2-D tensor with shape of [N, max_box_num], " + "and each element shoudl be an integer to indicate the " + "box class id."); AddOutput("Loss", "The output yolov3 loss tensor, " "This is a 1-D tensor with shape of [1]"); @@ -88,19 +102,19 @@ class Yolov3LossOpMaker : public framework::OpProtoAndCheckerMaker { "it will be parsed pair by pair."); AddAttr("ignore_thresh", "The ignore threshold to ignore confidence loss."); - AddAttr("lambda_xy", "The weight of x, y location loss.") + AddAttr("loss_weight_xy", "The weight of x, y location loss.") .SetDefault(1.0); - AddAttr("lambda_wh", "The weight of w, h location loss.") + AddAttr("loss_weight_wh", "The weight of w, h location loss.") .SetDefault(1.0); AddAttr( - "lambda_conf_obj", + "loss_weight_conf_target", "The weight of confidence score loss in locations with target object.") .SetDefault(1.0); - AddAttr("lambda_conf_noobj", + AddAttr("loss_weight_conf_notarget", "The weight of confidence score loss in locations without " "target object.") .SetDefault(1.0); - AddAttr("lambda_class", "The weight of classification loss.") + AddAttr("loss_weight_class", "The weight of classification loss.") .SetDefault(1.0); AddComment(R"DOC( This operator generate yolov3 loss by given predict result and ground @@ -141,10 +155,10 @@ class Yolov3LossOpMaker : public framework::OpProtoAndCheckerMaker { Final loss will be represented as follow. $$ - loss = \lambda_{xy} * loss_{xy} + \lambda_{wh} * loss_{wh} - + \lambda_{conf_obj} * loss_{conf_obj} - + \lambda_{conf_noobj} * loss_{conf_noobj} - + \lambda_{class} * loss_{class} + loss = \loss_weight_{xy} * loss_{xy} + \loss_weight_{wh} * loss_{wh} + + \loss_weight_{conf_target} * loss_{conf_target} + + \loss_weight_{conf_notarget} * loss_{conf_notarget} + + \loss_weight_{class} * loss_{class} $$ )DOC"); } @@ -182,12 +196,14 @@ class Yolov3LossGradMaker : public framework::SingleGradOpDescMaker { op->SetType("yolov3_loss_grad"); op->SetInput("X", Input("X")); op->SetInput("GTBox", Input("GTBox")); + op->SetInput("GTLabel", Input("GTLabel")); op->SetInput(framework::GradVarName("Loss"), OutputGrad("Loss")); op->SetAttrMap(Attrs()); op->SetOutput(framework::GradVarName("X"), InputGrad("X")); op->SetOutput(framework::GradVarName("GTBox"), {}); + op->SetOutput(framework::GradVarName("GTLabel"), {}); return std::unique_ptr(op); } }; diff --git a/paddle/fluid/operators/yolov3_loss_op.h b/paddle/fluid/operators/yolov3_loss_op.h index 608ef3f94b..a1072aca10 100644 --- a/paddle/fluid/operators/yolov3_loss_op.h +++ b/paddle/fluid/operators/yolov3_loss_op.h @@ -186,15 +186,17 @@ static T CalcBoxIoU(std::vector box1, std::vector box2) { } template -static void PreProcessGTBox(const Tensor& gt_boxes, const float ignore_thresh, - std::vector anchors, const int grid_size, - Tensor* obj_mask, Tensor* noobj_mask, Tensor* tx, - Tensor* ty, Tensor* tw, Tensor* th, Tensor* tconf, +static void PreProcessGTBox(const Tensor& gt_box, const Tensor& gt_label, + const float ignore_thresh, std::vector anchors, + const int grid_size, Tensor* obj_mask, + Tensor* noobj_mask, Tensor* tx, Tensor* ty, + Tensor* tw, Tensor* th, Tensor* tconf, Tensor* tclass) { - const int n = gt_boxes.dims()[0]; - const int b = gt_boxes.dims()[1]; + const int n = gt_box.dims()[0]; + const int b = gt_box.dims()[1]; const int anchor_num = anchors.size() / 2; - auto gt_boxes_t = EigenTensor::From(gt_boxes); + auto gt_box_t = EigenTensor::From(gt_box); + auto gt_label_t = EigenTensor::From(gt_label); auto obj_mask_t = EigenTensor::From(*obj_mask).setConstant(0); auto noobj_mask_t = EigenTensor::From(*noobj_mask).setConstant(1); auto tx_t = EigenTensor::From(*tx).setConstant(0.0); @@ -206,28 +208,27 @@ static void PreProcessGTBox(const Tensor& gt_boxes, const float ignore_thresh, for (int i = 0; i < n; i++) { for (int j = 0; j < b; j++) { - if (isZero(gt_boxes_t(i, j, 0)) && isZero(gt_boxes_t(i, j, 1)) && - isZero(gt_boxes_t(i, j, 2)) && isZero(gt_boxes_t(i, j, 3)) && - isZero(gt_boxes_t(i, j, 4))) { + if (isZero(gt_box_t(i, j, 0)) && isZero(gt_box_t(i, j, 1)) && + isZero(gt_box_t(i, j, 2)) && isZero(gt_box_t(i, j, 3))) { continue; } - int gt_label = static_cast(gt_boxes_t(i, j, 0)); - T gx = gt_boxes_t(i, j, 1) * grid_size; - T gy = gt_boxes_t(i, j, 2) * grid_size; - T gw = gt_boxes_t(i, j, 3) * grid_size; - T gh = gt_boxes_t(i, j, 4) * grid_size; + int cur_label = gt_label_t(i, j); + T gx = gt_box_t(i, j, 0) * grid_size; + T gy = gt_box_t(i, j, 1) * grid_size; + T gw = gt_box_t(i, j, 2) * grid_size; + T gh = gt_box_t(i, j, 3) * grid_size; int gi = static_cast(gx); int gj = static_cast(gy); T max_iou = static_cast(0); T iou; int best_an_index = -1; - std::vector gt_box({0, 0, gw, gh}); + std::vector gt_box_shape({0, 0, gw, gh}); for (int an_idx = 0; an_idx < anchor_num; an_idx++) { std::vector anchor_shape({0, 0, static_cast(anchors[2 * an_idx]), static_cast(anchors[2 * an_idx + 1])}); - iou = CalcBoxIoU(gt_box, anchor_shape); + iou = CalcBoxIoU(gt_box_shape, anchor_shape); if (iou > max_iou) { max_iou = iou; best_an_index = an_idx; @@ -242,7 +243,7 @@ static void PreProcessGTBox(const Tensor& gt_boxes, const float ignore_thresh, ty_t(i, best_an_index, gj, gi) = gy - gj; tw_t(i, best_an_index, gj, gi) = log(gw / anchors[2 * best_an_index]); th_t(i, best_an_index, gj, gi) = log(gh / anchors[2 * best_an_index + 1]); - tclass_t(i, best_an_index, gj, gi, gt_label) = 1; + tclass_t(i, best_an_index, gj, gi, cur_label) = 1; tconf_t(i, best_an_index, gj, gi) = 1; } } @@ -267,10 +268,10 @@ static void AddAllGradToInputGrad( Tensor* grad, T loss, const Tensor& pred_x, const Tensor& pred_y, const Tensor& pred_conf, const Tensor& pred_class, const Tensor& grad_x, const Tensor& grad_y, const Tensor& grad_w, const Tensor& grad_h, - const Tensor& grad_conf_obj, const Tensor& grad_conf_noobj, - const Tensor& grad_class, const int class_num, const float lambda_xy, - const float lambda_wh, const float lambda_conf_obj, - const float lambda_conf_noobj, const float lambda_class) { + const Tensor& grad_conf_target, const Tensor& grad_conf_notarget, + const Tensor& grad_class, const int class_num, const float loss_weight_xy, + const float loss_weight_wh, const float loss_weight_conf_target, + const float loss_weight_conf_notarget, const float loss_weight_class) { const int n = pred_x.dims()[0]; const int an_num = pred_x.dims()[1]; const int h = pred_x.dims()[2]; @@ -285,8 +286,8 @@ static void AddAllGradToInputGrad( auto grad_y_t = EigenTensor::From(grad_y); auto grad_w_t = EigenTensor::From(grad_w); auto grad_h_t = EigenTensor::From(grad_h); - auto grad_conf_obj_t = EigenTensor::From(grad_conf_obj); - auto grad_conf_noobj_t = EigenTensor::From(grad_conf_noobj); + auto grad_conf_target_t = EigenTensor::From(grad_conf_target); + auto grad_conf_notarget_t = EigenTensor::From(grad_conf_notarget); auto grad_class_t = EigenTensor::From(grad_class); for (int i = 0; i < n; i++) { @@ -295,25 +296,26 @@ static void AddAllGradToInputGrad( for (int l = 0; l < w; l++) { grad_t(i, j * attr_num, k, l) = grad_x_t(i, j, k, l) * pred_x_t(i, j, k, l) * - (1.0 - pred_x_t(i, j, k, l)) * loss * lambda_xy; + (1.0 - pred_x_t(i, j, k, l)) * loss * loss_weight_xy; grad_t(i, j * attr_num + 1, k, l) = grad_y_t(i, j, k, l) * pred_y_t(i, j, k, l) * - (1.0 - pred_y_t(i, j, k, l)) * loss * lambda_xy; + (1.0 - pred_y_t(i, j, k, l)) * loss * loss_weight_xy; grad_t(i, j * attr_num + 2, k, l) = - grad_w_t(i, j, k, l) * loss * lambda_wh; + grad_w_t(i, j, k, l) * loss * loss_weight_wh; grad_t(i, j * attr_num + 3, k, l) = - grad_h_t(i, j, k, l) * loss * lambda_wh; + grad_h_t(i, j, k, l) * loss * loss_weight_wh; grad_t(i, j * attr_num + 4, k, l) = - grad_conf_obj_t(i, j, k, l) * pred_conf_t(i, j, k, l) * - (1.0 - pred_conf_t(i, j, k, l)) * loss * lambda_conf_obj; + grad_conf_target_t(i, j, k, l) * pred_conf_t(i, j, k, l) * + (1.0 - pred_conf_t(i, j, k, l)) * loss * loss_weight_conf_target; grad_t(i, j * attr_num + 4, k, l) += - grad_conf_noobj_t(i, j, k, l) * pred_conf_t(i, j, k, l) * - (1.0 - pred_conf_t(i, j, k, l)) * loss * lambda_conf_noobj; + grad_conf_notarget_t(i, j, k, l) * pred_conf_t(i, j, k, l) * + (1.0 - pred_conf_t(i, j, k, l)) * loss * + loss_weight_conf_notarget; for (int c = 0; c < class_num; c++) { grad_t(i, j * attr_num + 5 + c, k, l) = grad_class_t(i, j, k, l, c) * pred_class_t(i, j, k, l, c) * - (1.0 - pred_class_t(i, j, k, l, c)) * loss * lambda_class; + (1.0 - pred_class_t(i, j, k, l, c)) * loss * loss_weight_class; } } } @@ -326,16 +328,18 @@ class Yolov3LossKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { auto* input = ctx.Input("X"); - auto* gt_boxes = ctx.Input("GTBox"); + auto* gt_box = ctx.Input("GTBox"); + auto* gt_label = ctx.Input("GTLabel"); auto* loss = ctx.Output("Loss"); auto anchors = ctx.Attr>("anchors"); int class_num = ctx.Attr("class_num"); float ignore_thresh = ctx.Attr("ignore_thresh"); - float lambda_xy = ctx.Attr("lambda_xy"); - float lambda_wh = ctx.Attr("lambda_wh"); - float lambda_conf_obj = ctx.Attr("lambda_conf_obj"); - float lambda_conf_noobj = ctx.Attr("lambda_conf_noobj"); - float lambda_class = ctx.Attr("lambda_class"); + float loss_weight_xy = ctx.Attr("loss_weight_xy"); + float loss_weight_wh = ctx.Attr("loss_weight_wh"); + float loss_weight_conf_target = ctx.Attr("loss_weight_conf_target"); + float loss_weight_conf_notarget = + ctx.Attr("loss_weight_conf_notarget"); + float loss_weight_class = ctx.Attr("loss_weight_class"); const int n = input->dims()[0]; const int h = input->dims()[2]; @@ -363,7 +367,7 @@ class Yolov3LossKernel : public framework::OpKernel { th.mutable_data({n, an_num, h, w}, ctx.GetPlace()); tconf.mutable_data({n, an_num, h, w}, ctx.GetPlace()); tclass.mutable_data({n, an_num, h, w, class_num}, ctx.GetPlace()); - PreProcessGTBox(*gt_boxes, ignore_thresh, anchors, h, &obj_mask, + PreProcessGTBox(*gt_box, *gt_label, ignore_thresh, anchors, h, &obj_mask, &noobj_mask, &tx, &ty, &tw, &th, &tconf, &tclass); Tensor obj_mask_expand; @@ -375,15 +379,16 @@ class Yolov3LossKernel : public framework::OpKernel { T loss_y = CalcMSEWithMask(pred_y, ty, obj_mask); T loss_w = CalcMSEWithMask(pred_w, tw, obj_mask); T loss_h = CalcMSEWithMask(pred_h, th, obj_mask); - T loss_conf_obj = CalcBCEWithMask(pred_conf, tconf, obj_mask); - T loss_conf_noobj = CalcBCEWithMask(pred_conf, tconf, noobj_mask); + T loss_conf_target = CalcBCEWithMask(pred_conf, tconf, obj_mask); + T loss_conf_notarget = CalcBCEWithMask(pred_conf, tconf, noobj_mask); T loss_class = CalcBCEWithMask(pred_class, tclass, obj_mask_expand); auto* loss_data = loss->mutable_data({1}, ctx.GetPlace()); - loss_data[0] = - lambda_xy * (loss_x + loss_y) + lambda_wh * (loss_w + loss_h) + - lambda_conf_obj * loss_conf_obj + lambda_conf_noobj * loss_conf_noobj + - lambda_class * loss_class; + loss_data[0] = loss_weight_xy * (loss_x + loss_y) + + loss_weight_wh * (loss_w + loss_h) + + loss_weight_conf_target * loss_conf_target + + loss_weight_conf_notarget * loss_conf_notarget + + loss_weight_class * loss_class; } }; @@ -392,18 +397,20 @@ class Yolov3LossGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { auto* input = ctx.Input("X"); - auto* gt_boxes = ctx.Input("GTBox"); + auto* gt_box = ctx.Input("GTBox"); + auto* gt_label = ctx.Input("GTLabel"); auto anchors = ctx.Attr>("anchors"); int class_num = ctx.Attr("class_num"); float ignore_thresh = ctx.Attr("ignore_thresh"); auto* input_grad = ctx.Output(framework::GradVarName("X")); auto* output_grad = ctx.Input(framework::GradVarName("Loss")); const T loss = output_grad->data()[0]; - float lambda_xy = ctx.Attr("lambda_xy"); - float lambda_wh = ctx.Attr("lambda_wh"); - float lambda_conf_obj = ctx.Attr("lambda_conf_obj"); - float lambda_conf_noobj = ctx.Attr("lambda_conf_noobj"); - float lambda_class = ctx.Attr("lambda_class"); + float loss_weight_xy = ctx.Attr("loss_weight_xy"); + float loss_weight_wh = ctx.Attr("loss_weight_wh"); + float loss_weight_conf_target = ctx.Attr("loss_weight_conf_target"); + float loss_weight_conf_notarget = + ctx.Attr("loss_weight_conf_notarget"); + float loss_weight_class = ctx.Attr("loss_weight_class"); const int n = input->dims()[0]; const int c = input->dims()[1]; @@ -432,7 +439,7 @@ class Yolov3LossGradKernel : public framework::OpKernel { th.mutable_data({n, an_num, h, w}, ctx.GetPlace()); tconf.mutable_data({n, an_num, h, w}, ctx.GetPlace()); tclass.mutable_data({n, an_num, h, w, class_num}, ctx.GetPlace()); - PreProcessGTBox(*gt_boxes, ignore_thresh, anchors, h, &obj_mask, + PreProcessGTBox(*gt_box, *gt_label, ignore_thresh, anchors, h, &obj_mask, &noobj_mask, &tx, &ty, &tw, &th, &tconf, &tclass); Tensor obj_mask_expand; @@ -441,13 +448,13 @@ class Yolov3LossGradKernel : public framework::OpKernel { ExpandObjMaskByClassNum(&obj_mask_expand, obj_mask); Tensor grad_x, grad_y, grad_w, grad_h; - Tensor grad_conf_obj, grad_conf_noobj, grad_class; + Tensor grad_conf_target, grad_conf_notarget, grad_class; grad_x.mutable_data({n, an_num, h, w}, ctx.GetPlace()); grad_y.mutable_data({n, an_num, h, w}, ctx.GetPlace()); grad_w.mutable_data({n, an_num, h, w}, ctx.GetPlace()); grad_h.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - grad_conf_obj.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - grad_conf_noobj.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + grad_conf_target.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + grad_conf_notarget.mutable_data({n, an_num, h, w}, ctx.GetPlace()); grad_class.mutable_data({n, an_num, h, w, class_num}, ctx.GetPlace()); T obj_mf = CalcMaskPointNum(obj_mask); T noobj_mf = CalcMaskPointNum(noobj_mask); @@ -456,8 +463,9 @@ class Yolov3LossGradKernel : public framework::OpKernel { CalcMSEGradWithMask(&grad_y, pred_y, ty, obj_mask, obj_mf); CalcMSEGradWithMask(&grad_w, pred_w, tw, obj_mask, obj_mf); CalcMSEGradWithMask(&grad_h, pred_h, th, obj_mask, obj_mf); - CalcBCEGradWithMask(&grad_conf_obj, pred_conf, tconf, obj_mask, obj_mf); - CalcBCEGradWithMask(&grad_conf_noobj, pred_conf, tconf, noobj_mask, + CalcBCEGradWithMask(&grad_conf_target, pred_conf, tconf, obj_mask, + obj_mf); + CalcBCEGradWithMask(&grad_conf_notarget, pred_conf, tconf, noobj_mask, noobj_mf); CalcBCEGradWithMask(&grad_class, pred_class, tclass, obj_mask_expand, obj_expand_mf); @@ -465,8 +473,9 @@ class Yolov3LossGradKernel : public framework::OpKernel { input_grad->mutable_data({n, c, h, w}, ctx.GetPlace()); AddAllGradToInputGrad( input_grad, loss, pred_x, pred_y, pred_conf, pred_class, grad_x, grad_y, - grad_w, grad_h, grad_conf_obj, grad_conf_noobj, grad_class, class_num, - lambda_xy, lambda_wh, lambda_conf_obj, lambda_conf_noobj, lambda_class); + grad_w, grad_h, grad_conf_target, grad_conf_notarget, grad_class, + class_num, loss_weight_xy, loss_weight_wh, loss_weight_conf_target, + loss_weight_conf_notarget, loss_weight_class); } }; diff --git a/python/paddle/fluid/layers/detection.py b/python/paddle/fluid/layers/detection.py index 2bb9514803..cab5c3e2a4 100644 --- a/python/paddle/fluid/layers/detection.py +++ b/python/paddle/fluid/layers/detection.py @@ -409,32 +409,36 @@ def polygon_box_transform(input, name=None): @templatedoc(op_type="yolov3_loss") def yolov3_loss(x, gtbox, + gtlabel, anchors, class_num, ignore_thresh, - lambda_xy=None, - lambda_wh=None, - lambda_conf_obj=None, - lambda_conf_noobj=None, - lambda_class=None, + loss_weight_xy=None, + loss_weight_wh=None, + loss_weight_conf_target=None, + loss_weight_conf_notarget=None, + loss_weight_class=None, name=None): """ ${comment} Args: x (Variable): ${x_comment} - gtbox (Variable): groud truth boxes, shoulb be in shape of [N, B, 5], - in the third dimenstion, class_id, x, y, w, h should - be stored and x, y, w, h should be relative valud of - input image. + gtbox (Variable): groud truth boxes, should be in shape of [N, B, 4], + in the third dimenstion, x, y, w, h should be stored + and x, y, w, h should be relative value of input image. + N is the batch number and B is the max box number in + an image. + gtlabel (Variable): class id of ground truth boxes, shoud be ins shape + of [N, B]. anchors (list|tuple): ${anchors_comment} class_num (int): ${class_num_comment} ignore_thresh (float): ${ignore_thresh_comment} - lambda_xy (float|None): ${lambda_xy_comment} - lambda_wh (float|None): ${lambda_wh_comment} - lambda_conf_obj (float|None): ${lambda_conf_obj_comment} - lambda_conf_noobj (float|None): ${lambda_conf_noobj_comment} - lambda_class (float|None): ${lambda_class_comment} + loss_weight_xy (float|None): ${loss_weight_xy_comment} + loss_weight_wh (float|None): ${loss_weight_wh_comment} + loss_weight_conf_target (float|None): ${loss_weight_conf_target_comment} + loss_weight_conf_notarget (float|None): ${loss_weight_conf_notarget_comment} + loss_weight_class (float|None): ${loss_weight_class_comment} name (string): the name of yolov3 loss Returns: @@ -443,6 +447,7 @@ def yolov3_loss(x, Raises: TypeError: Input x of yolov3_loss must be Variable TypeError: Input gtbox of yolov3_loss must be Variable" + TypeError: Input gtlabel of yolov3_loss must be Variable" TypeError: Attr anchors of yolov3_loss must be list or tuple TypeError: Attr class_num of yolov3_loss must be an integer TypeError: Attr ignore_thresh of yolov3_loss must be a float number @@ -450,8 +455,9 @@ def yolov3_loss(x, Examples: .. code-block:: python - x = fluid.layers.data(name='x', shape=[10, 255, 13, 13], dtype='float32') - gtbox = fluid.layers.data(name='gtbox', shape=[10, 6, 5], dtype='float32') + x = fluid.layers.data(name='x', shape=[255, 13, 13], dtype='float32') + gtbox = fluid.layers.data(name='gtbox', shape=[6, 5], dtype='float32') + gtlabel = fluid.layers.data(name='gtlabel', shape=[6, 1], dtype='int32') anchors = [10, 13, 16, 30, 33, 23] loss = fluid.layers.yolov3_loss(x=x, gtbox=gtbox, class_num=80 anchors=anchors, ignore_thresh=0.5) @@ -462,6 +468,8 @@ def yolov3_loss(x, raise TypeError("Input x of yolov3_loss must be Variable") if not isinstance(gtbox, Variable): raise TypeError("Input gtbox of yolov3_loss must be Variable") + if not isinstance(gtlabel, Variable): + raise TypeError("Input gtlabel of yolov3_loss must be Variable") if not isinstance(anchors, list) and not isinstance(anchors, tuple): raise TypeError("Attr anchors of yolov3_loss must be list or tuple") if not isinstance(class_num, int): @@ -482,21 +490,24 @@ def yolov3_loss(x, "ignore_thresh": ignore_thresh, } - if lambda_xy is not None and isinstance(lambda_xy, float): - self.attrs['lambda_xy'] = lambda_xy - if lambda_wh is not None and isinstance(lambda_wh, float): - self.attrs['lambda_wh'] = lambda_wh - if lambda_conf_obj is not None and isinstance(lambda_conf_obj, float): - self.attrs['lambda_conf_obj'] = lambda_conf_obj - if lambda_conf_noobj is not None and isinstance(lambda_conf_noobj, float): - self.attrs['lambda_conf_noobj'] = lambda_conf_noobj - if lambda_class is not None and isinstance(lambda_class, float): - self.attrs['lambda_class'] = lambda_class + if loss_weight_xy is not None and isinstance(loss_weight_xy, float): + self.attrs['loss_weight_xy'] = loss_weight_xy + if loss_weight_wh is not None and isinstance(loss_weight_wh, float): + self.attrs['loss_weight_wh'] = loss_weight_wh + if loss_weight_conf_target is not None and isinstance( + loss_weight_conf_target, float): + self.attrs['loss_weight_conf_target'] = loss_weight_conf_target + if loss_weight_conf_notarget is not None and isinstance( + loss_weight_conf_notarget, float): + self.attrs['loss_weight_conf_notarget'] = loss_weight_conf_notarget + if loss_weight_class is not None and isinstance(loss_weight_class, float): + self.attrs['loss_weight_class'] = loss_weight_class helper.append_op( type='yolov3_loss', - inputs={'X': x, - "GTBox": gtbox}, + inputs={"X": x, + "GTBox": gtbox, + "GTLabel": gtlabel}, outputs={'Loss': loss}, attrs=attrs) return loss diff --git a/python/paddle/fluid/tests/test_detection.py b/python/paddle/fluid/tests/test_detection.py index 28dc751957..527fd521d5 100644 --- a/python/paddle/fluid/tests/test_detection.py +++ b/python/paddle/fluid/tests/test_detection.py @@ -366,5 +366,18 @@ class TestGenerateProposals(unittest.TestCase): print(rpn_rois.shape) +class TestYoloDetection(unittest.TestCase): + def test_yolov3_loss(self): + program = Program() + with program_guard(program): + x = layers.data(name='x', shape=[30, 7, 7], dtype='float32') + gtbox = layers.data(name='gtbox', shape=[10, 4], dtype='float32') + gtlabel = layers.data(name='gtlabel', shape=[10], dtype='int32') + loss = layers.yolov3_loss(x, gtbox, gtlabel, [10, 13, 30, 13], 10, + 0.5) + + self.assertIsNotNone(loss) + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py index dd02968c30..f48d9c84f9 100644 --- a/python/paddle/fluid/tests/unittests/test_layers.py +++ b/python/paddle/fluid/tests/unittests/test_layers.py @@ -911,15 +911,6 @@ class TestBook(unittest.TestCase): self.assertIsNotNone(data_1) print(str(program)) - def test_yolov3_loss(self): - program = Program() - with program_guard(program): - x = layers.data(name='x', shape=[30, 7, 7], dtype='float32') - gtbox = layers.data(name='gtbox', shape=[10, 5], dtype='float32') - loss = layers.yolov3_loss(x, gtbox, [10, 13, 30, 13], 10, 0.5) - - self.assertIsNotNone(loss) - def test_bilinear_tensor_product_layer(self): program = Program() with program_guard(program): diff --git a/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py b/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py index 03a64055f0..335214b298 100644 --- a/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py +++ b/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py @@ -66,7 +66,7 @@ def box_iou(box1, box2): return inter_area / (b1_area + b2_area + inter_area) -def build_target(gtboxs, attrs, grid_size): +def build_target(gtboxs, gtlabel, attrs, grid_size): n, b, _ = gtboxs.shape ignore_thresh = attrs["ignore_thresh"] anchors = attrs["anchors"] @@ -87,11 +87,11 @@ def build_target(gtboxs, attrs, grid_size): if gtboxs[i, j, :].sum() == 0: continue - gt_label = int(gtboxs[i, j, 0]) - gx = gtboxs[i, j, 1] * grid_size - gy = gtboxs[i, j, 2] * grid_size - gw = gtboxs[i, j, 3] * grid_size - gh = gtboxs[i, j, 4] * grid_size + gt_label = gtlabel[i, j] + gx = gtboxs[i, j, 0] * grid_size + gy = gtboxs[i, j, 1] * grid_size + gw = gtboxs[i, j, 2] * grid_size + gh = gtboxs[i, j, 3] * grid_size gi = int(gx) gj = int(gy) @@ -121,7 +121,7 @@ def build_target(gtboxs, attrs, grid_size): return (tx, ty, tw, th, tconf, tcls, obj_mask, noobj_mask) -def YoloV3Loss(x, gtbox, attrs): +def YoloV3Loss(x, gtbox, gtlabel, attrs): n, c, h, w = x.shape an_num = len(attrs['anchors']) // 2 class_num = attrs["class_num"] @@ -134,7 +134,7 @@ def YoloV3Loss(x, gtbox, attrs): pred_cls = sigmoid(x[:, :, :, :, 5:]) tx, ty, tw, th, tconf, tcls, obj_mask, noobj_mask = build_target( - gtbox, attrs, x.shape[2]) + gtbox, gtlabel, attrs, x.shape[2]) obj_mask_expand = np.tile( np.expand_dims(obj_mask, 4), (1, 1, 1, 1, int(attrs['class_num']))) @@ -142,73 +142,73 @@ def YoloV3Loss(x, gtbox, attrs): loss_y = mse(pred_y * obj_mask, ty * obj_mask, obj_mask.sum()) loss_w = mse(pred_w * obj_mask, tw * obj_mask, obj_mask.sum()) loss_h = mse(pred_h * obj_mask, th * obj_mask, obj_mask.sum()) - loss_conf_obj = bce(pred_conf * obj_mask, tconf * obj_mask, obj_mask) - loss_conf_noobj = bce(pred_conf * noobj_mask, tconf * noobj_mask, - noobj_mask) + loss_conf_target = bce(pred_conf * obj_mask, tconf * obj_mask, obj_mask) + loss_conf_notarget = bce(pred_conf * noobj_mask, tconf * noobj_mask, + noobj_mask) loss_class = bce(pred_cls * obj_mask_expand, tcls * obj_mask_expand, obj_mask_expand) - return attrs['lambda_xy'] * (loss_x + loss_y) \ - + attrs['lambda_wh'] * (loss_w + loss_h) \ - + attrs['lambda_conf_obj'] * loss_conf_obj \ - + attrs['lambda_conf_noobj'] * loss_conf_noobj \ - + attrs['lambda_class'] * loss_class + return attrs['loss_weight_xy'] * (loss_x + loss_y) \ + + attrs['loss_weight_wh'] * (loss_w + loss_h) \ + + attrs['loss_weight_conf_target'] * loss_conf_target \ + + attrs['loss_weight_conf_notarget'] * loss_conf_notarget \ + + attrs['loss_weight_class'] * loss_class class TestYolov3LossOp(OpTest): def setUp(self): - self.lambda_xy = 1.0 - self.lambda_wh = 1.0 - self.lambda_conf_obj = 1.0 - self.lambda_conf_noobj = 1.0 - self.lambda_class = 1.0 + self.loss_weight_xy = 1.0 + self.loss_weight_wh = 1.0 + self.loss_weight_conf_target = 1.0 + self.loss_weight_conf_notarget = 1.0 + self.loss_weight_class = 1.0 self.initTestCase() self.op_type = 'yolov3_loss' x = np.random.random(size=self.x_shape).astype('float32') gtbox = np.random.random(size=self.gtbox_shape).astype('float32') - gtbox[:, :, 0] = np.random.randint(0, self.class_num, - self.gtbox_shape[:2]) + gtlabel = np.random.randint(0, self.class_num, + self.gtbox_shape[:2]).astype('int32') self.attrs = { "anchors": self.anchors, "class_num": self.class_num, "ignore_thresh": self.ignore_thresh, - "lambda_xy": self.lambda_xy, - "lambda_wh": self.lambda_wh, - "lambda_conf_obj": self.lambda_conf_obj, - "lambda_conf_noobj": self.lambda_conf_noobj, - "lambda_class": self.lambda_class, + "loss_weight_xy": self.loss_weight_xy, + "loss_weight_wh": self.loss_weight_wh, + "loss_weight_conf_target": self.loss_weight_conf_target, + "loss_weight_conf_notarget": self.loss_weight_conf_notarget, + "loss_weight_class": self.loss_weight_class, } - self.inputs = {'X': x, 'GTBox': gtbox} + self.inputs = {'X': x, 'GTBox': gtbox, 'GTLabel': gtlabel} self.outputs = { - 'Loss': - np.array([YoloV3Loss(x, gtbox, self.attrs)]).astype('float32') + 'Loss': np.array( + [YoloV3Loss(x, gtbox, gtlabel, self.attrs)]).astype('float32') } def test_check_output(self): place = core.CPUPlace() self.check_output_with_place(place, atol=1e-3) - # def test_check_grad_ignore_gtbox(self): - # place = core.CPUPlace() - # self.check_grad_with_place( - # place, ['X'], - # 'Loss', - # no_grad_set=set("GTBox"), - # max_relative_error=0.06) + def test_check_grad_ignore_gtbox(self): + place = core.CPUPlace() + self.check_grad_with_place( + place, ['X'], + 'Loss', + no_grad_set=set("GTBox"), + max_relative_error=0.06) def initTestCase(self): self.anchors = [10, 13, 12, 12] self.class_num = 10 self.ignore_thresh = 0.5 self.x_shape = (5, len(self.anchors) // 2 * (5 + self.class_num), 7, 7) - self.gtbox_shape = (5, 5, 5) - self.lambda_xy = 2.5 - self.lambda_wh = 0.8 - self.lambda_conf_obj = 1.5 - self.lambda_conf_noobj = 0.5 - self.lambda_class = 1.2 + self.gtbox_shape = (5, 10, 4) + self.loss_weight_xy = 2.5 + self.loss_weight_wh = 0.8 + self.loss_weight_conf_target = 1.5 + self.loss_weight_conf_notarget = 0.5 + self.loss_weight_class = 1.2 if __name__ == "__main__": From 8ef6280c034602f776554432672d42b826afbaee Mon Sep 17 00:00:00 2001 From: dengkaipeng Date: Fri, 16 Nov 2018 19:14:40 +0800 Subject: [PATCH 8/8] Add operator double support. test=develop --- paddle/fluid/operators/yolov3_loss_op.cc | 10 ++++------ paddle/fluid/operators/yolov3_loss_op.h | 4 ++-- .../fluid/tests/unittests/test_yolov3_loss_op.py | 2 +- 3 files changed, 7 insertions(+), 9 deletions(-) diff --git a/paddle/fluid/operators/yolov3_loss_op.cc b/paddle/fluid/operators/yolov3_loss_op.cc index 1d7f482362..e7597f7324 100644 --- a/paddle/fluid/operators/yolov3_loss_op.cc +++ b/paddle/fluid/operators/yolov3_loss_op.cc @@ -215,9 +215,7 @@ namespace ops = paddle::operators; REGISTER_OPERATOR(yolov3_loss, ops::Yolov3LossOp, ops::Yolov3LossOpMaker, ops::Yolov3LossGradMaker); REGISTER_OPERATOR(yolov3_loss_grad, ops::Yolov3LossOpGrad); -REGISTER_OP_CPU_KERNEL( - yolov3_loss, - ops::Yolov3LossKernel); -REGISTER_OP_CPU_KERNEL( - yolov3_loss_grad, - ops::Yolov3LossGradKernel); +REGISTER_OP_CPU_KERNEL(yolov3_loss, ops::Yolov3LossKernel, + ops::Yolov3LossKernel); +REGISTER_OP_CPU_KERNEL(yolov3_loss_grad, ops::Yolov3LossGradKernel, + ops::Yolov3LossGradKernel); diff --git a/paddle/fluid/operators/yolov3_loss_op.h b/paddle/fluid/operators/yolov3_loss_op.h index a1072aca10..0bb285722d 100644 --- a/paddle/fluid/operators/yolov3_loss_op.h +++ b/paddle/fluid/operators/yolov3_loss_op.h @@ -323,7 +323,7 @@ static void AddAllGradToInputGrad( } } -template +template class Yolov3LossKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { @@ -392,7 +392,7 @@ class Yolov3LossKernel : public framework::OpKernel { } }; -template +template class Yolov3LossGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { diff --git a/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py b/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py index 335214b298..544fe4b4f8 100644 --- a/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py +++ b/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py @@ -195,7 +195,7 @@ class TestYolov3LossOp(OpTest): self.check_grad_with_place( place, ['X'], 'Loss', - no_grad_set=set("GTBox"), + no_grad_set=set(["GTBox", "GTLabel"]), max_relative_error=0.06) def initTestCase(self):