From 10dd3b37ad26660bbd9c52c111039688e6b063b5 Mon Sep 17 00:00:00 2001 From: jerrywgz Date: Thu, 17 Jan 2019 12:13:34 +0000 Subject: [PATCH 01/53] add axis for box coder op --- paddle/fluid/API.spec | 2 +- .../fluid/operators/detection/box_coder_op.cc | 40 +++- .../fluid/operators/detection/box_coder_op.cu | 83 ++++++--- .../fluid/operators/detection/box_coder_op.h | 76 +++++--- python/paddle/fluid/layers/detection.py | 9 +- .../tests/unittests/test_box_coder_op.py | 176 ++++++++++++++---- 6 files changed, 282 insertions(+), 104 deletions(-) diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index 50ffef72ba..7068a37ef0 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -315,7 +315,7 @@ paddle.fluid.layers.roi_perspective_transform ArgSpec(args=['input', 'rois', 'tr paddle.fluid.layers.generate_proposal_labels ArgSpec(args=['rpn_rois', 'gt_classes', 'is_crowd', 'gt_boxes', 'im_info', 'batch_size_per_im', 'fg_fraction', 'fg_thresh', 'bg_thresh_hi', 'bg_thresh_lo', 'bbox_reg_weights', 'class_nums', 'use_random'], varargs=None, keywords=None, defaults=(256, 0.25, 0.25, 0.5, 0.0, [0.1, 0.1, 0.2, 0.2], None, True)) paddle.fluid.layers.generate_proposals ArgSpec(args=['scores', 'bbox_deltas', 'im_info', 'anchors', 'variances', 'pre_nms_top_n', 'post_nms_top_n', 'nms_thresh', 'min_size', 'eta', 'name'], varargs=None, keywords=None, defaults=(6000, 1000, 0.5, 0.1, 1.0, None)) paddle.fluid.layers.iou_similarity ArgSpec(args=['x', 'y', 'name'], varargs=None, keywords=None, defaults=(None,)) -paddle.fluid.layers.box_coder ArgSpec(args=['prior_box', 'prior_box_var', 'target_box', 'code_type', 'box_normalized', 'name'], varargs=None, keywords=None, defaults=('encode_center_size', True, None)) +paddle.fluid.layers.box_coder ArgSpec(args=['prior_box', 'prior_box_var', 'target_box', 'code_type', 'box_normalized', 'axis', 'name'], varargs=None, keywords=None, defaults=('encode_center_size', True, 0, None)) paddle.fluid.layers.polygon_box_transform ArgSpec(args=['input', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.yolov3_loss ArgSpec(args=['x', 'gtbox', 'gtlabel', 'anchors', 'class_num', 'ignore_thresh', 'loss_weight_xy', 'loss_weight_wh', 'loss_weight_conf_target', 'loss_weight_conf_notarget', 'loss_weight_class', 'name'], varargs=None, keywords=None, defaults=(None, None, None, None, None, None)) paddle.fluid.layers.accuracy ArgSpec(args=['input', 'label', 'k', 'correct', 'total'], varargs=None, keywords=None, defaults=(1, None, None)) diff --git a/paddle/fluid/operators/detection/box_coder_op.cc b/paddle/fluid/operators/detection/box_coder_op.cc index 06fbb9815c..5db600b19a 100644 --- a/paddle/fluid/operators/detection/box_coder_op.cc +++ b/paddle/fluid/operators/detection/box_coder_op.cc @@ -32,31 +32,53 @@ class BoxCoderOp : public framework::OperatorWithKernel { if (ctx->IsRuntime()) { PADDLE_ENFORCE_EQ(prior_box_dims.size(), 2, - "The rank of Input of PriorBoxVar must be 2"); + "The rank of Input of PriorBox must be 2"); PADDLE_ENFORCE_EQ(prior_box_dims[1], 4, "The shape of PriorBox is [N, 4]"); if (ctx->HasInput("PriorBoxVar")) { auto prior_box_var_dims = ctx->GetInputDim("PriorBoxVar"); - PADDLE_ENFORCE_EQ(prior_box_dims, prior_box_var_dims); + PADDLE_ENFORCE( + prior_box_var_dims.size() == 1 || prior_box_var_dims.size() == 2, + "Input(PriorBoxVar) of BoxCoderOp should be 1 or 2."); + if (prior_box_var_dims.size() == 1) { + PADDLE_ENFORCE_EQ( + prior_box_var_dims[0], 4, + "The 1st dimension of Input(PriorBoxVar) should be 1" + "when the rank is 1."); + } else { + PADDLE_ENFORCE_EQ( + prior_box_dims, prior_box_var_dims, + "The dimension of Input(PriorBoxVar) should be equal to" + "the dimension of Input(PriorBox when the rank is 2.)"); + } } auto code_type = GetBoxCodeType(ctx->Attrs().Get("code_type")); + int axis = ctx->Attrs().Get("axis"); if (code_type == BoxCodeType::kEncodeCenterSize) { PADDLE_ENFORCE_EQ(target_box_dims.size(), 2, "The rank of Input of TargetBox must be 2"); PADDLE_ENFORCE_EQ(target_box_dims[1], 4, "The shape of TargetBox is [M, 4]"); + ctx->SetOutputDim( + "OutputBox", + framework::make_ddim({target_box_dims[0], prior_box_dims[0], 4})); } else if (code_type == BoxCodeType::kDecodeCenterSize) { PADDLE_ENFORCE_EQ(target_box_dims.size(), 3, "The rank of Input of TargetBox must be 3"); - PADDLE_ENFORCE_EQ(target_box_dims[1], prior_box_dims[0]); + if (axis == 0) { + PADDLE_ENFORCE_EQ(target_box_dims[1], prior_box_dims[0]); + } else if (axis == 1) { + PADDLE_ENFORCE_EQ(target_box_dims[0], prior_box_dims[0]); + } else { + PADDLE_THROW("axis must be 0 or 1."); + } PADDLE_ENFORCE_EQ(target_box_dims[2], prior_box_dims[1]); + ctx->ShareDim("TargetBox", /*->*/ "OutputBox"); } } - ctx->SetOutputDim( - "OutputBox", - framework::make_ddim({target_box_dims[0], prior_box_dims[0], 4})); + ctx->ShareLoD("TargetBox", /*->*/ "OutputBox"); } }; @@ -100,6 +122,12 @@ class BoxCoderOpMaker : public framework::OpProtoAndCheckerMaker { "(bool, default true) " "whether treat the priorbox as a noramlized box") .SetDefault(true); + AddAttr("axis", + "(int, default 1)" + "which axis to broadcast for box decode, it is only valid" + "when code type is decode_center_size") + .SetDefault(0) + .InEnum({0, 1}); AddOutput("OutputBox", "(LoDTensor or Tensor) " "When code_type is 'encode_center_size', the output tensor of " diff --git a/paddle/fluid/operators/detection/box_coder_op.cu b/paddle/fluid/operators/detection/box_coder_op.cu index a7af111f63..ca62afd8ed 100644 --- a/paddle/fluid/operators/detection/box_coder_op.cu +++ b/paddle/fluid/operators/detection/box_coder_op.cu @@ -20,7 +20,8 @@ __global__ void EncodeCenterSizeKernel(const T* prior_box_data, const T* prior_box_var_data, const T* target_box_data, const int row, const int col, const int len, - const bool normalized, T* output) { + const bool normalized, + const T prior_box_var_size, T* output) { const int idx = threadIdx.x + blockIdx.x * blockDim.x; if (idx < row * col) { const int row_idx = idx / col; @@ -30,11 +31,9 @@ __global__ void EncodeCenterSizeKernel(const T* prior_box_data, T prior_box_height = prior_box_data[col_idx * len + 3] - prior_box_data[col_idx * len + 1] + (normalized == false); - T prior_box_center_x = - (prior_box_data[col_idx * len + 2] + prior_box_data[col_idx * len]) / 2; - T prior_box_center_y = (prior_box_data[col_idx * len + 3] + - prior_box_data[col_idx * len + 1]) / - 2; + T prior_box_center_x = prior_box_data[col_idx * len] + prior_box_width / 2; + T prior_box_center_y = + prior_box_data[col_idx * len + 1] + prior_box_height / 2; T target_box_center_x = (target_box_data[row_idx * len + 2] + target_box_data[row_idx * len]) / @@ -55,10 +54,14 @@ __global__ void EncodeCenterSizeKernel(const T* prior_box_data, output[idx * len + 2] = log(fabs(target_box_width / prior_box_width)); output[idx * len + 3] = log(fabs(target_box_height / prior_box_height)); if (prior_box_var_data) { - output[idx * len] /= prior_box_var_data[col_idx * len]; - output[idx * len + 1] /= prior_box_var_data[col_idx * len + 1]; - output[idx * len + 2] /= prior_box_var_data[col_idx * len + 2]; - output[idx * len + 3] /= prior_box_var_data[col_idx * len + 3]; + int prior_var_offset = 0; + if (prior_box_var_size == 2) { + prior_var_offset = col_idx * len; + } + output[idx * len] /= prior_box_var_data[prior_var_offset]; + output[idx * len + 1] /= prior_box_var_data[prior_var_offset + 1]; + output[idx * len + 2] /= prior_box_var_data[prior_var_offset + 2]; + output[idx * len + 3] /= prior_box_var_data[prior_var_offset + 3]; } } } @@ -68,33 +71,48 @@ __global__ void DecodeCenterSizeKernel(const T* prior_box_data, const T* prior_box_var_data, const T* target_box_data, const int row, const int col, const int len, - const bool normalized, T* output) { + const bool normalized, + const T prior_box_var_size, + const int axis, T* output) { const int idx = threadIdx.x + blockIdx.x * blockDim.x; + int prior_box_offset = 0; if (idx < row * col) { const int col_idx = idx % col; - T prior_box_width = prior_box_data[col_idx * len + 2] - - prior_box_data[col_idx * len] + (normalized == false); - T prior_box_height = prior_box_data[col_idx * len + 3] - - prior_box_data[col_idx * len + 1] + + const int row_idx = idx / col; + if (axis == 0) + prior_box_offset = col_idx * len; + else if (axis == 1) + prior_box_offset = row_idx * len; + T prior_box_width = prior_box_data[prior_box_offset + 2] - + prior_box_data[prior_box_offset] + + (normalized == false); + T prior_box_height = prior_box_data[prior_box_offset + 3] - + prior_box_data[prior_box_offset + 1] + (normalized == false); T prior_box_center_x = - (prior_box_data[col_idx * len + 2] + prior_box_data[col_idx * len]) / 2; - T prior_box_center_y = (prior_box_data[col_idx * len + 3] + - prior_box_data[col_idx * len + 1]) / - 2; + prior_box_data[prior_box_offset] + prior_box_width / 2; + T prior_box_center_y = + prior_box_data[prior_box_offset + 1] + prior_box_height / 2; T target_box_width, target_box_height; T target_box_center_x, target_box_center_y; if (prior_box_var_data) { - target_box_width = exp(prior_box_var_data[col_idx * len + 2] * + int prior_var_offset = 0; + if (prior_box_var_size == 2) { + if (axis == 0) + prior_var_offset = col_idx * len; + else if (axis == 1) + prior_var_offset = row_idx * len; + } + target_box_width = exp(prior_box_var_data[prior_var_offset + 2] * target_box_data[idx * len + 2]) * prior_box_width; - target_box_height = exp(prior_box_var_data[col_idx * len + 3] * + target_box_height = exp(prior_box_var_data[prior_var_offset + 3] * target_box_data[idx * len + 3]) * prior_box_height; - target_box_center_x = prior_box_var_data[col_idx * len] * + target_box_center_x = prior_box_var_data[prior_var_offset] * target_box_data[idx * len] * prior_box_width + prior_box_center_x; - target_box_center_y = prior_box_var_data[col_idx * len + 1] * + target_box_center_y = prior_box_var_data[prior_var_offset + 1] * target_box_data[idx * len + 1] * prior_box_height + prior_box_center_y; @@ -131,14 +149,25 @@ class BoxCoderCUDAKernel : public framework::OpKernel { const T* prior_box_data = prior_box->data(); const T* target_box_data = target_box->data(); const T* prior_box_var_data = nullptr; - if (prior_box_var) prior_box_var_data = prior_box_var->data(); + auto prior_box_var_size = 0; + if (prior_box_var) { + prior_box_var_data = prior_box_var->data(); + prior_box_var_size = prior_box_var->dims().size(); + } if (target_box->lod().size()) { PADDLE_ENFORCE_EQ(target_box->lod().size(), 1, "Only support 1 level of LoD."); } + auto code_type = GetBoxCodeType(context.Attr("code_type")); + bool normalized = context.Attr("box_normalized"); + int axis = context.Attr("axis"); + auto row = target_box->dims()[0]; auto col = prior_box->dims()[0]; + if (code_type == BoxCodeType::kDecodeCenterSize) { + col = target_box->dims()[1]; + } auto len = prior_box->dims()[1]; int block = 512; int grid = (row * col + block - 1) / block; @@ -147,16 +176,14 @@ class BoxCoderCUDAKernel : public framework::OpKernel { output_box->mutable_data({row, col, len}, context.GetPlace()); T* output = output_box->data(); - auto code_type = GetBoxCodeType(context.Attr("code_type")); - bool normalized = context.Attr("box_normalized"); if (code_type == BoxCodeType::kEncodeCenterSize) { EncodeCenterSizeKernel<<>>( prior_box_data, prior_box_var_data, target_box_data, row, col, len, - normalized, output); + normalized, prior_box_var_size, output); } else if (code_type == BoxCodeType::kDecodeCenterSize) { DecodeCenterSizeKernel<<>>( prior_box_data, prior_box_var_data, target_box_data, row, col, len, - normalized, output); + normalized, prior_box_var_size, axis, output); } } }; diff --git a/paddle/fluid/operators/detection/box_coder_op.h b/paddle/fluid/operators/detection/box_coder_op.h index b2a2bcdce9..986869d8a3 100644 --- a/paddle/fluid/operators/detection/box_coder_op.h +++ b/paddle/fluid/operators/detection/box_coder_op.h @@ -53,10 +53,9 @@ class BoxCoderKernel : public framework::OpKernel { T prior_box_height = prior_box_data[j * len + 3] - prior_box_data[j * len + 1] + (normalized == false); - T prior_box_center_x = - (prior_box_data[j * len + 2] + prior_box_data[j * len]) / 2; + T prior_box_center_x = prior_box_data[j * len] + prior_box_width / 2; T prior_box_center_y = - (prior_box_data[j * len + 3] + prior_box_data[j * len + 1]) / 2; + prior_box_data[j * len + 1] + prior_box_height / 2; T target_box_center_x = (target_box_data[i * len + 2] + target_box_data[i * len]) / 2; @@ -78,10 +77,14 @@ class BoxCoderKernel : public framework::OpKernel { output[offset + 3] = std::log(std::fabs(target_box_height / prior_box_height)); if (prior_box_var) { - output[offset] /= prior_box_var_data[j * len]; - output[offset + 1] /= prior_box_var_data[j * len + 1]; - output[offset + 2] /= prior_box_var_data[j * len + 2]; - output[offset + 3] /= prior_box_var_data[j * len + 3]; + int prior_var_offset = 0; + if (prior_box_var->dims().size() == 2) { + prior_var_offset = j * len; + } + output[offset] /= prior_box_var_data[prior_var_offset]; + output[offset + 1] /= prior_box_var_data[prior_var_offset + 1]; + output[offset + 2] /= prior_box_var_data[prior_var_offset + 2]; + output[offset + 3] /= prior_box_var_data[prior_var_offset + 3]; } } } @@ -89,48 +92,63 @@ class BoxCoderKernel : public framework::OpKernel { void DecodeCenterSize(const framework::Tensor* target_box, const framework::Tensor* prior_box, const framework::Tensor* prior_box_var, - const bool normalized, T* output) const { + const bool normalized, const int axis, + T* output) const { int64_t row = target_box->dims()[0]; - int64_t col = prior_box->dims()[0]; - int64_t len = prior_box->dims()[1]; + int64_t col = target_box->dims()[1]; + int64_t len = target_box->dims()[2]; auto* target_box_data = target_box->data(); auto* prior_box_data = prior_box->data(); const T* prior_box_var_data = nullptr; if (prior_box_var) prior_box_var_data = prior_box_var->data(); - + int prior_box_offset = 0; #ifdef PADDLE_WITH_MKLML #pragma omp parallel for collapse(2) #endif for (int64_t i = 0; i < row; ++i) { for (int64_t j = 0; j < col; ++j) { size_t offset = i * col * len + j * len; - T prior_box_width = prior_box_data[j * len + 2] - - prior_box_data[j * len] + (normalized == false); - T prior_box_height = prior_box_data[j * len + 3] - - prior_box_data[j * len + 1] + + if (axis == 0) { + prior_box_offset = j * len; + } else if (axis == 1) { + prior_box_offset = i * len; + } + T prior_box_width = prior_box_data[prior_box_offset + 2] - + prior_box_data[prior_box_offset] + + (normalized == false); + T prior_box_height = prior_box_data[prior_box_offset + 3] - + prior_box_data[prior_box_offset + 1] + (normalized == false); T prior_box_center_x = - (prior_box_data[j * len + 2] + prior_box_data[j * len]) / 2; + prior_box_data[prior_box_offset] + prior_box_width / 2; T prior_box_center_y = - (prior_box_data[j * len + 3] + prior_box_data[j * len + 1]) / 2; + prior_box_data[prior_box_offset + 1] + prior_box_height / 2; T target_box_center_x = 0, target_box_center_y = 0; T target_box_width = 0, target_box_height = 0; if (prior_box_var) { - target_box_center_x = prior_box_var_data[j * len] * + int prior_var_offset = 0; + if (prior_box_var->dims().size() == 2) { + if (axis == 0) + prior_var_offset = j * len; + else if (axis == 1) + prior_var_offset = i * len; + } + target_box_center_x = prior_box_var_data[prior_var_offset] * target_box_data[offset] * prior_box_width + prior_box_center_x; - target_box_center_y = prior_box_var_data[j * len + 1] * + target_box_center_y = prior_box_var_data[prior_var_offset + 1] * target_box_data[offset + 1] * prior_box_height + prior_box_center_y; - target_box_width = std::exp(prior_box_var_data[j * len + 2] * + target_box_width = std::exp(prior_box_var_data[prior_var_offset + 2] * target_box_data[offset + 2]) * prior_box_width; - target_box_height = std::exp(prior_box_var_data[j * len + 3] * - target_box_data[offset + 3]) * - prior_box_height; + target_box_height = + std::exp(prior_box_var_data[prior_var_offset + 3] * + target_box_data[offset + 3]) * + prior_box_height; } else { target_box_center_x = target_box_data[offset] * prior_box_width + prior_box_center_x; @@ -157,25 +175,29 @@ class BoxCoderKernel : public framework::OpKernel { auto* prior_box_var = context.Input("PriorBoxVar"); auto* target_box = context.Input("TargetBox"); auto* output_box = context.Output("OutputBox"); - + const int axis = context.Attr("axis"); if (target_box->lod().size()) { PADDLE_ENFORCE_EQ(target_box->lod().size(), 1UL, "Only support 1 level of LoD."); } + auto code_type = GetBoxCodeType(context.Attr("code_type")); + bool normalized = context.Attr("box_normalized"); + auto row = target_box->dims()[0]; auto col = prior_box->dims()[0]; + if (code_type == BoxCodeType::kDecodeCenterSize) { + col = target_box->dims()[1]; + } auto len = prior_box->dims()[1]; output_box->mutable_data({row, col, len}, context.GetPlace()); - auto code_type = GetBoxCodeType(context.Attr("code_type")); - bool normalized = context.Attr("box_normalized"); T* output = output_box->data(); if (code_type == BoxCodeType::kEncodeCenterSize) { EncodeCenterSize(target_box, prior_box, prior_box_var, normalized, output); } else if (code_type == BoxCodeType::kDecodeCenterSize) { - DecodeCenterSize(target_box, prior_box, prior_box_var, normalized, + DecodeCenterSize(target_box, prior_box, prior_box_var, normalized, axis, output); } } diff --git a/python/paddle/fluid/layers/detection.py b/python/paddle/fluid/layers/detection.py index 8aed97dc59..c844050c5d 100644 --- a/python/paddle/fluid/layers/detection.py +++ b/python/paddle/fluid/layers/detection.py @@ -342,6 +342,7 @@ def box_coder(prior_box, target_box, code_type="encode_center_size", box_normalized=True, + axis=0, name=None): """ ${comment} @@ -352,6 +353,7 @@ def box_coder(prior_box, target_box(${target_box_type}): ${target_box_comment} code_type(${code_type_type}): ${code_type_comment} box_normalized(${box_normalized_type}): ${box_normalized_comment} + axis(${axis_type}): ${axis_comment} Returns: output_box(${output_box_type}): ${output_box_comment} @@ -372,8 +374,11 @@ def box_coder(prior_box, "PriorBoxVar": prior_box_var, "TargetBox": target_box }, - attrs={"code_type": code_type, - "box_normalized": box_normalized}, + attrs={ + "code_type": code_type, + "box_normalized": box_normalized, + "axis": axis + }, outputs={"OutputBox": output_box}) return output_box diff --git a/python/paddle/fluid/tests/unittests/test_box_coder_op.py b/python/paddle/fluid/tests/unittests/test_box_coder_op.py index 2511c5c22e..b6f6bc1450 100644 --- a/python/paddle/fluid/tests/unittests/test_box_coder_op.py +++ b/python/paddle/fluid/tests/unittests/test_box_coder_op.py @@ -21,22 +21,32 @@ import math from op_test import OpTest -def box_coder(target_box, prior_box, prior_box_var, output_box, code_type, - box_normalized): - prior_box_x = ( - (prior_box[:, 2] + prior_box[:, 0]) / 2).reshape(1, prior_box.shape[0]) - prior_box_y = ( - (prior_box[:, 3] + prior_box[:, 1]) / 2).reshape(1, prior_box.shape[0]) - prior_box_width = ( - (prior_box[:, 2] - prior_box[:, 0])).reshape(1, prior_box.shape[0]) - prior_box_height = ( - (prior_box[:, 3] - prior_box[:, 1])).reshape(1, prior_box.shape[0]) - prior_box_var = prior_box_var.reshape(1, prior_box_var.shape[0], - prior_box_var.shape[1]) - if not box_normalized: - prior_box_height = prior_box_height + 1 - prior_box_width = prior_box_width + 1 - +def box_coder(target_box, + prior_box, + prior_box_var, + output_box, + code_type, + box_normalized, + axis=0): + prior_box_width = prior_box[:, 2] - prior_box[:, 0] + \ + (box_normalized==False) + prior_box_height = prior_box[:, 3] - prior_box[:, 1] + \ + (box_normalized==False) + prior_box_x = prior_box_width * 0.5 + prior_box[:, 0] + prior_box_y = prior_box_height * 0.5 + prior_box[:, 1] + if axis == 0: + prior_box_width = prior_box_width.reshape(1, prior_box.shape[0]) + prior_box_height = prior_box_height.reshape(1, prior_box.shape[0]) + prior_box_x = prior_box_x.reshape(1, prior_box.shape[0]) + prior_box_y = prior_box_y.reshape(1, prior_box.shape[0]) + else: + prior_box_width = prior_box_width.reshape(prior_box.shape[0], 1) + prior_box_height = prior_box_height.reshape(prior_box.shape[0], 1) + prior_box_x = prior_box_x.reshape(prior_box.shape[0], 1) + prior_box_y = prior_box_y.reshape(prior_box.shape[0], 1) + if prior_box_var.ndim == 2: + prior_box_var = prior_box_var.reshape(1, prior_box_var.shape[0], + prior_box_var.shape[1]) if (code_type == "EncodeCenterSize"): target_box_x = ((target_box[:, 2] + target_box[:, 0]) / 2).reshape( target_box.shape[0], 1) @@ -49,26 +59,52 @@ def box_coder(target_box, prior_box, prior_box_var, output_box, code_type, if not box_normalized: target_box_height = target_box_height + 1 target_box_width = target_box_width + 1 - - output_box[:,:,0] = (target_box_x - prior_box_x) / prior_box_width / \ - prior_box_var[:,:,0] - output_box[:,:,1] = (target_box_y - prior_box_y) / prior_box_height / \ - prior_box_var[:,:,1] - output_box[:,:,2] = np.log(np.fabs(target_box_width / prior_box_width)) / \ - prior_box_var[:,:,2] - output_box[:,:,3] = np.log(np.fabs(target_box_height / prior_box_height)) / \ - prior_box_var[:,:,3] + if prior_box_var.ndim == 1: + output_box[:,:,0] = (target_box_x - prior_box_x) / \ + prior_box_width / \ + prior_box_var[0] + output_box[:,:,1] = (target_box_y - prior_box_y) / \ + prior_box_height / \ + prior_box_var[1] + output_box[:,:,2] = np.log(np.fabs(target_box_width / \ + prior_box_width)) / \ + prior_box_var[2] + output_box[:,:,3] = np.log(np.fabs(target_box_height / \ + prior_box_height)) / \ + prior_box_var[3] + else: + output_box[:,:,0] = (target_box_x - prior_box_x) / \ + prior_box_width / \ + prior_box_var[:,:,0] + output_box[:,:,1] = (target_box_y - prior_box_y) / \ + prior_box_height / \ + prior_box_var[:,:,1] + output_box[:,:,2] = np.log(np.fabs(target_box_width / \ + prior_box_width)) / \ + prior_box_var[:,:,2] + output_box[:,:,3] = np.log(np.fabs(target_box_height / \ + prior_box_height)) / \ + prior_box_var[:,:,3] elif (code_type == "DecodeCenterSize"): - target_box_x = prior_box_var[:,:,0] * target_box[:,:,0] * \ - prior_box_width + prior_box_x - target_box_y = prior_box_var[:,:,1] * target_box[:,:,1] * \ - prior_box_height + prior_box_y - target_box_width = np.exp(prior_box_var[:,:,2] * target_box[:,:,2]) * \ - prior_box_width - target_box_height = np.exp(prior_box_var[:,:,3] * target_box[:,:,3]) * \ - prior_box_height - + if prior_box_var.ndim == 1: + target_box_x = prior_box_var[0] * target_box[:,:,0] * \ + prior_box_width + prior_box_x + target_box_y = prior_box_var[1] * target_box[:,:,1] * \ + prior_box_height + prior_box_y + target_box_width = np.exp(prior_box_var[2] * target_box[:,:,2]) * \ + prior_box_width + target_box_height = np.exp(prior_box_var[3] * target_box[:,:,3]) * \ + prior_box_height + else: + target_box_x = prior_box_var[:,:,0] * target_box[:,:,0] * \ + prior_box_width + prior_box_x + target_box_y = prior_box_var[:,:,1] * target_box[:,:,1] * \ + prior_box_height + prior_box_y + target_box_width = np.exp(prior_box_var[:,:,2] * \ + target_box[:,:,2]) * prior_box_width + target_box_height = np.exp(prior_box_var[:,:,3] * \ + target_box[:,:,3]) * prior_box_height output_box[:, :, 0] = target_box_x - target_box_width / 2 output_box[:, :, 1] = target_box_y - target_box_height / 2 output_box[:, :, 2] = target_box_x + target_box_width / 2 @@ -78,10 +114,17 @@ def box_coder(target_box, prior_box, prior_box_var, output_box, code_type, output_box[:, :, 3] = output_box[:, :, 3] - 1 -def batch_box_coder(prior_box, prior_box_var, target_box, lod, code_type, - box_normalized): +def batch_box_coder(prior_box, + prior_box_var, + target_box, + lod, + code_type, + box_normalized, + axis=0): n = target_box.shape[0] m = prior_box.shape[0] + if code_type == "DecodeCenterSize": + m = target_box.shape[1] output_box = np.zeros((n, m, 4), dtype=np.float32) cur_offset = 0 for i in range(len(lod)): @@ -91,10 +134,8 @@ def batch_box_coder(prior_box, prior_box_var, target_box, lod, code_type, output_box[cur_offset:(cur_offset + lod[i]), :, :], code_type, box_normalized) elif (code_type == "DecodeCenterSize"): - box_coder(target_box[cur_offset:(cur_offset + lod[i]), :, :], - prior_box, prior_box_var, - output_box[cur_offset:(cur_offset + lod[i]), :, :], - code_type, box_normalized) + box_coder(target_box, prior_box, prior_box_var, output_box, + code_type, box_normalized, axis) cur_offset += lod[i] return output_box @@ -111,6 +152,32 @@ class TestBoxCoderOp(OpTest): target_box = np.random.random((5, 10, 4)).astype('float32') code_type = "DecodeCenterSize" box_normalized = False + output_box = batch_box_coder(prior_box, prior_box_var, target_box, + lod[0], code_type, box_normalized) + self.inputs = { + 'PriorBox': prior_box, + 'PriorBoxVar': prior_box_var, + 'TargetBox': target_box, + } + self.attrs = { + 'code_type': 'decode_center_size', + 'box_normalized': False + } + self.outputs = {'OutputBox': output_box} + + +class TestBoxCoderOpWithOneRankVar(OpTest): + def test_check_output(self): + self.check_output() + + def setUp(self): + self.op_type = "box_coder" + lod = [[1, 1, 1, 1, 1]] + prior_box = np.random.random((6, 4)).astype('float32') + prior_box_var = np.random.random((4)).astype('float32') + target_box = np.random.random((3, 6, 4)).astype('float32') + code_type = "DecodeCenterSize" + box_normalized = False output_box = batch_box_coder(prior_box, prior_box_var, target_box, lod[0], code_type, box_normalized) @@ -176,5 +243,34 @@ class TestBoxCoderOpWithLoD(OpTest): self.outputs = {'OutputBox': output_box} +class TestBoxCoderOpWithAxis(OpTest): + def test_check_output(self): + self.check_output() + + def setUp(self): + self.op_type = "box_coder" + lod = [[1, 1, 1, 1, 1]] + prior_box = np.random.random((5, 4)).astype('float32') + prior_box_var = np.random.random((4)).astype('float32') + target_box = np.random.random((5, 6, 4)).astype('float32') + code_type = "DecodeCenterSize" + box_normalized = False + axis = 1 + output_box = batch_box_coder(prior_box, prior_box_var, target_box, + lod[0], code_type, box_normalized, axis) + + self.inputs = { + 'PriorBox': prior_box, + 'PriorBoxVar': prior_box_var, + 'TargetBox': target_box, + } + self.attrs = { + 'code_type': 'decode_center_size', + 'box_normalized': False, + 'axis': axis + } + self.outputs = {'OutputBox': output_box} + + if __name__ == '__main__': unittest.main() From ab9d6a4f39ee8fefceb7392f1b93131eed8db9dc Mon Sep 17 00:00:00 2001 From: jerrywgz Date: Thu, 17 Jan 2019 12:20:18 +0000 Subject: [PATCH 02/53] add comments, test=develop --- paddle/fluid/operators/detection/box_coder_op.cc | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/operators/detection/box_coder_op.cc b/paddle/fluid/operators/detection/box_coder_op.cc index 5db600b19a..e342417491 100644 --- a/paddle/fluid/operators/detection/box_coder_op.cc +++ b/paddle/fluid/operators/detection/box_coder_op.cc @@ -166,7 +166,11 @@ where `tx`, `ty`, `tw`, `th` denote the target box's center coordinates, width and height respectively. Similarly, `px`, `py`, `pw`, `ph` denote the priorbox's (anchor) center coordinates, width and height. `pxv`, `pyv`, `pwv`, `phv` denote the variance of the priorbox and `ox`, `oy`, `ow`, `oh` denote the -encoded/decoded coordinates, width and height. +encoded/decoded coordinates, width and height. + +During Box Decoding, two modes for broadcast are supported. Say target box has +shape [N, M, 4], and the shape of prior box can be [N, 4] or [M, 4]. Then prior +box will broadcast to target box along the assigned axis. )DOC"); } }; From 0d915078597f483057b25cdc2e99bdd9bee71f71 Mon Sep 17 00:00:00 2001 From: jerrywgz Date: Mon, 21 Jan 2019 05:22:47 +0000 Subject: [PATCH 03/53] fix share lod, test=develop --- paddle/fluid/operators/detection/box_coder_op.cc | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/operators/detection/box_coder_op.cc b/paddle/fluid/operators/detection/box_coder_op.cc index e342417491..b4b02124cc 100644 --- a/paddle/fluid/operators/detection/box_coder_op.cc +++ b/paddle/fluid/operators/detection/box_coder_op.cc @@ -77,9 +77,13 @@ class BoxCoderOp : public framework::OperatorWithKernel { PADDLE_ENFORCE_EQ(target_box_dims[2], prior_box_dims[1]); ctx->ShareDim("TargetBox", /*->*/ "OutputBox"); } - } - ctx->ShareLoD("TargetBox", /*->*/ "OutputBox"); + if (code_type == BoxCodeType::kDecodeCenterSize && axis == 1) { + ctx->ShareLoD("PriorBox", /*->*/ "OutputBox"); + } else { + ctx->ShareLoD("TargetBox", /*->*/ "OutputBox"); + } + } } }; From 66bb5dd760f0ce72740ca755224bb3ca85194600 Mon Sep 17 00:00:00 2001 From: jerrywgz Date: Mon, 21 Jan 2019 10:18:41 +0000 Subject: [PATCH 04/53] refine infer shape, test=develop --- .../fluid/operators/detection/box_coder_op.cc | 57 +++++++++---------- 1 file changed, 28 insertions(+), 29 deletions(-) diff --git a/paddle/fluid/operators/detection/box_coder_op.cc b/paddle/fluid/operators/detection/box_coder_op.cc index b4b02124cc..2ce844669b 100644 --- a/paddle/fluid/operators/detection/box_coder_op.cc +++ b/paddle/fluid/operators/detection/box_coder_op.cc @@ -43,7 +43,7 @@ class BoxCoderOp : public framework::OperatorWithKernel { if (prior_box_var_dims.size() == 1) { PADDLE_ENFORCE_EQ( prior_box_var_dims[0], 4, - "The 1st dimension of Input(PriorBoxVar) should be 1" + "The 1st dimension of Input(PriorBoxVar) should be 4" "when the rank is 1."); } else { PADDLE_ENFORCE_EQ( @@ -52,37 +52,36 @@ class BoxCoderOp : public framework::OperatorWithKernel { "the dimension of Input(PriorBox when the rank is 2.)"); } } + } - auto code_type = - GetBoxCodeType(ctx->Attrs().Get("code_type")); - int axis = ctx->Attrs().Get("axis"); - if (code_type == BoxCodeType::kEncodeCenterSize) { - PADDLE_ENFORCE_EQ(target_box_dims.size(), 2, - "The rank of Input of TargetBox must be 2"); - PADDLE_ENFORCE_EQ(target_box_dims[1], 4, - "The shape of TargetBox is [M, 4]"); - ctx->SetOutputDim( - "OutputBox", - framework::make_ddim({target_box_dims[0], prior_box_dims[0], 4})); - } else if (code_type == BoxCodeType::kDecodeCenterSize) { - PADDLE_ENFORCE_EQ(target_box_dims.size(), 3, - "The rank of Input of TargetBox must be 3"); - if (axis == 0) { - PADDLE_ENFORCE_EQ(target_box_dims[1], prior_box_dims[0]); - } else if (axis == 1) { - PADDLE_ENFORCE_EQ(target_box_dims[0], prior_box_dims[0]); - } else { - PADDLE_THROW("axis must be 0 or 1."); - } - PADDLE_ENFORCE_EQ(target_box_dims[2], prior_box_dims[1]); - ctx->ShareDim("TargetBox", /*->*/ "OutputBox"); - } - - if (code_type == BoxCodeType::kDecodeCenterSize && axis == 1) { - ctx->ShareLoD("PriorBox", /*->*/ "OutputBox"); + auto code_type = GetBoxCodeType(ctx->Attrs().Get("code_type")); + int axis = ctx->Attrs().Get("axis"); + if (code_type == BoxCodeType::kEncodeCenterSize) { + PADDLE_ENFORCE_EQ(target_box_dims.size(), 2, + "The rank of Input of TargetBox must be 2"); + PADDLE_ENFORCE_EQ(target_box_dims[1], 4, + "The shape of TargetBox is [M, 4]"); + ctx->SetOutputDim( + "OutputBox", + framework::make_ddim({target_box_dims[0], prior_box_dims[0], 4})); + } else if (code_type == BoxCodeType::kDecodeCenterSize) { + PADDLE_ENFORCE_EQ(target_box_dims.size(), 3, + "The rank of Input of TargetBox must be 3"); + if (axis == 0) { + PADDLE_ENFORCE_EQ(target_box_dims[1], prior_box_dims[0]); + } else if (axis == 1) { + PADDLE_ENFORCE_EQ(target_box_dims[0], prior_box_dims[0]); } else { - ctx->ShareLoD("TargetBox", /*->*/ "OutputBox"); + PADDLE_THROW("axis must be 0 or 1."); } + PADDLE_ENFORCE_EQ(target_box_dims[2], prior_box_dims[1]); + ctx->ShareDim("TargetBox", /*->*/ "OutputBox"); + } + + if (code_type == BoxCodeType::kDecodeCenterSize && axis == 1) { + ctx->ShareLoD("PriorBox", /*->*/ "OutputBox"); + } else { + ctx->ShareLoD("TargetBox", /*->*/ "OutputBox"); } } }; From 0d4b60ab8bc8d1db9fdef1a6228663c3f60a3980 Mon Sep 17 00:00:00 2001 From: jerrywgz Date: Mon, 21 Jan 2019 12:25:07 +0000 Subject: [PATCH 05/53] add lod for slice op, test=develop --- paddle/fluid/operators/slice_op.cc | 3 +++ 1 file changed, 3 insertions(+) diff --git a/paddle/fluid/operators/slice_op.cc b/paddle/fluid/operators/slice_op.cc index 789e61b2d3..94995fc996 100644 --- a/paddle/fluid/operators/slice_op.cc +++ b/paddle/fluid/operators/slice_op.cc @@ -54,6 +54,9 @@ class SliceOp : public framework::OperatorWithKernel { out_dims[axes[i]] = end - start; } ctx->SetOutputDim("Out", out_dims); + if (axes[0] != 0) { + ctx->ShareLoD("Input", /*->*/ "Out"); + } } protected: From c12a969bd446691d107ab1607be529ef9388bcd0 Mon Sep 17 00:00:00 2001 From: jerrywgz Date: Tue, 22 Jan 2019 13:27:21 +0000 Subject: [PATCH 06/53] refine comment and unittest, test=develop --- .../fluid/operators/detection/box_coder_op.cc | 13 +- .../fluid/operators/detection/box_coder_op.cu | 10 +- python/paddle/fluid/layers/detection.py | 4 +- .../tests/unittests/test_box_coder_op.py | 175 +++++++----------- 4 files changed, 79 insertions(+), 123 deletions(-) diff --git a/paddle/fluid/operators/detection/box_coder_op.cc b/paddle/fluid/operators/detection/box_coder_op.cc index 2ce844669b..f89f87663b 100644 --- a/paddle/fluid/operators/detection/box_coder_op.cc +++ b/paddle/fluid/operators/detection/box_coder_op.cc @@ -32,7 +32,7 @@ class BoxCoderOp : public framework::OperatorWithKernel { if (ctx->IsRuntime()) { PADDLE_ENFORCE_EQ(prior_box_dims.size(), 2, - "The rank of Input of PriorBox must be 2"); + "The rank of Input PriorBox must be 2"); PADDLE_ENFORCE_EQ(prior_box_dims[1], 4, "The shape of PriorBox is [N, 4]"); if (ctx->HasInput("PriorBoxVar")) { @@ -58,7 +58,7 @@ class BoxCoderOp : public framework::OperatorWithKernel { int axis = ctx->Attrs().Get("axis"); if (code_type == BoxCodeType::kEncodeCenterSize) { PADDLE_ENFORCE_EQ(target_box_dims.size(), 2, - "The rank of Input of TargetBox must be 2"); + "The rank of Input TargetBox must be 2"); PADDLE_ENFORCE_EQ(target_box_dims[1], 4, "The shape of TargetBox is [M, 4]"); ctx->SetOutputDim( @@ -66,7 +66,7 @@ class BoxCoderOp : public framework::OperatorWithKernel { framework::make_ddim({target_box_dims[0], prior_box_dims[0], 4})); } else if (code_type == BoxCodeType::kDecodeCenterSize) { PADDLE_ENFORCE_EQ(target_box_dims.size(), 3, - "The rank of Input of TargetBox must be 3"); + "The rank of Input TargetBox must be 3"); if (axis == 0) { PADDLE_ENFORCE_EQ(target_box_dims[1], prior_box_dims[0]); } else if (axis == 1) { @@ -126,8 +126,11 @@ class BoxCoderOpMaker : public framework::OpProtoAndCheckerMaker { "whether treat the priorbox as a noramlized box") .SetDefault(true); AddAttr("axis", - "(int, default 1)" - "which axis to broadcast for box decode, it is only valid" + "(int, default 0)" + "which axis in PriorBox to broadcast for box decode," + "for example, if axis is 0 and TargetBox has shape" + "[N, M, 4] and PriorBox has shape [M, 4], then PriorBox " + "will broadcast to [N, M, 4] for decoding. It is only valid" "when code type is decode_center_size") .SetDefault(0) .InEnum({0, 1}); diff --git a/paddle/fluid/operators/detection/box_coder_op.cu b/paddle/fluid/operators/detection/box_coder_op.cu index ca62afd8ed..0b64224e1e 100644 --- a/paddle/fluid/operators/detection/box_coder_op.cu +++ b/paddle/fluid/operators/detection/box_coder_op.cu @@ -79,10 +79,7 @@ __global__ void DecodeCenterSizeKernel(const T* prior_box_data, if (idx < row * col) { const int col_idx = idx % col; const int row_idx = idx / col; - if (axis == 0) - prior_box_offset = col_idx * len; - else if (axis == 1) - prior_box_offset = row_idx * len; + prior_box_offset = axis == 0 ? col_idx * len : row_idx * len; T prior_box_width = prior_box_data[prior_box_offset + 2] - prior_box_data[prior_box_offset] + (normalized == false); @@ -98,10 +95,7 @@ __global__ void DecodeCenterSizeKernel(const T* prior_box_data, if (prior_box_var_data) { int prior_var_offset = 0; if (prior_box_var_size == 2) { - if (axis == 0) - prior_var_offset = col_idx * len; - else if (axis == 1) - prior_var_offset = row_idx * len; + prior_var_offset = axis == 0 ? col_idx * len : row_idx * len; } target_box_width = exp(prior_box_var_data[prior_var_offset + 2] * target_box_data[idx * len + 2]) * diff --git a/python/paddle/fluid/layers/detection.py b/python/paddle/fluid/layers/detection.py index c844050c5d..8c8a6c6223 100644 --- a/python/paddle/fluid/layers/detection.py +++ b/python/paddle/fluid/layers/detection.py @@ -342,8 +342,8 @@ def box_coder(prior_box, target_box, code_type="encode_center_size", box_normalized=True, - axis=0, - name=None): + name=None, + axis=0): """ ${comment} diff --git a/python/paddle/fluid/tests/unittests/test_box_coder_op.py b/python/paddle/fluid/tests/unittests/test_box_coder_op.py index b6f6bc1450..6f7930c921 100644 --- a/python/paddle/fluid/tests/unittests/test_box_coder_op.py +++ b/python/paddle/fluid/tests/unittests/test_box_coder_op.py @@ -21,121 +21,80 @@ import math from op_test import OpTest -def box_coder(target_box, - prior_box, - prior_box_var, - output_box, - code_type, - box_normalized, - axis=0): - prior_box_width = prior_box[:, 2] - prior_box[:, 0] + \ - (box_normalized==False) - prior_box_height = prior_box[:, 3] - prior_box[:, 1] + \ - (box_normalized==False) - prior_box_x = prior_box_width * 0.5 + prior_box[:, 0] - prior_box_y = prior_box_height * 0.5 + prior_box[:, 1] - if axis == 0: - prior_box_width = prior_box_width.reshape(1, prior_box.shape[0]) - prior_box_height = prior_box_height.reshape(1, prior_box.shape[0]) - prior_box_x = prior_box_x.reshape(1, prior_box.shape[0]) - prior_box_y = prior_box_y.reshape(1, prior_box.shape[0]) +def box_decoder(t_box, p_box, pb_v, output_box, norm, axis=0): + pb_w = p_box[:, 2] - p_box[:, 0] + (norm == False) + pb_h = p_box[:, 3] - p_box[:, 1] + (norm == False) + pb_x = pb_w * 0.5 + p_box[:, 0] + pb_y = pb_h * 0.5 + p_box[:, 1] + shape = (1, p_box.shape[0]) if axis == 0 else (p_box.shape[0], 1) + + pb_w = pb_w.reshape(shape) + pb_h = pb_h.reshape(shape) + pb_x = pb_x.reshape(shape) + pb_y = pb_y.reshape(shape) + + if pb_v.ndim == 2: + pb_v = pb_v.reshape(1, pb_v.shape[0], pb_v.shape[1]) + if pb_v.ndim == 1: + tb_x = pb_v[0] * t_box[:, :, 0] * pb_w + pb_x + tb_y = pb_v[1] * t_box[:, :, 1] * pb_h + pb_y + tb_w = np.exp(pb_v[2] * t_box[:, :, 2]) * pb_w + tb_h = np.exp(pb_v[3] * t_box[:, :, 3]) * pb_h else: - prior_box_width = prior_box_width.reshape(prior_box.shape[0], 1) - prior_box_height = prior_box_height.reshape(prior_box.shape[0], 1) - prior_box_x = prior_box_x.reshape(prior_box.shape[0], 1) - prior_box_y = prior_box_y.reshape(prior_box.shape[0], 1) - if prior_box_var.ndim == 2: - prior_box_var = prior_box_var.reshape(1, prior_box_var.shape[0], - prior_box_var.shape[1]) - if (code_type == "EncodeCenterSize"): - target_box_x = ((target_box[:, 2] + target_box[:, 0]) / 2).reshape( - target_box.shape[0], 1) - target_box_y = ((target_box[:, 3] + target_box[:, 1]) / 2).reshape( - target_box.shape[0], 1) - target_box_width = ((target_box[:, 2] - target_box[:, 0])).reshape( - target_box.shape[0], 1) - target_box_height = ((target_box[:, 3] - target_box[:, 1])).reshape( - target_box.shape[0], 1) - if not box_normalized: - target_box_height = target_box_height + 1 - target_box_width = target_box_width + 1 - if prior_box_var.ndim == 1: - output_box[:,:,0] = (target_box_x - prior_box_x) / \ - prior_box_width / \ - prior_box_var[0] - output_box[:,:,1] = (target_box_y - prior_box_y) / \ - prior_box_height / \ - prior_box_var[1] - output_box[:,:,2] = np.log(np.fabs(target_box_width / \ - prior_box_width)) / \ - prior_box_var[2] - output_box[:,:,3] = np.log(np.fabs(target_box_height / \ - prior_box_height)) / \ - prior_box_var[3] - else: - output_box[:,:,0] = (target_box_x - prior_box_x) / \ - prior_box_width / \ - prior_box_var[:,:,0] - output_box[:,:,1] = (target_box_y - prior_box_y) / \ - prior_box_height / \ - prior_box_var[:,:,1] - output_box[:,:,2] = np.log(np.fabs(target_box_width / \ - prior_box_width)) / \ - prior_box_var[:,:,2] - output_box[:,:,3] = np.log(np.fabs(target_box_height / \ - prior_box_height)) / \ - prior_box_var[:,:,3] - - elif (code_type == "DecodeCenterSize"): - if prior_box_var.ndim == 1: - target_box_x = prior_box_var[0] * target_box[:,:,0] * \ - prior_box_width + prior_box_x - target_box_y = prior_box_var[1] * target_box[:,:,1] * \ - prior_box_height + prior_box_y - target_box_width = np.exp(prior_box_var[2] * target_box[:,:,2]) * \ - prior_box_width - target_box_height = np.exp(prior_box_var[3] * target_box[:,:,3]) * \ - prior_box_height - else: - target_box_x = prior_box_var[:,:,0] * target_box[:,:,0] * \ - prior_box_width + prior_box_x - target_box_y = prior_box_var[:,:,1] * target_box[:,:,1] * \ - prior_box_height + prior_box_y - target_box_width = np.exp(prior_box_var[:,:,2] * \ - target_box[:,:,2]) * prior_box_width - target_box_height = np.exp(prior_box_var[:,:,3] * \ - target_box[:,:,3]) * prior_box_height - output_box[:, :, 0] = target_box_x - target_box_width / 2 - output_box[:, :, 1] = target_box_y - target_box_height / 2 - output_box[:, :, 2] = target_box_x + target_box_width / 2 - output_box[:, :, 3] = target_box_y + target_box_height / 2 - if not box_normalized: - output_box[:, :, 2] = output_box[:, :, 2] - 1 - output_box[:, :, 3] = output_box[:, :, 3] - 1 - - -def batch_box_coder(prior_box, - prior_box_var, - target_box, - lod, - code_type, - box_normalized, - axis=0): - n = target_box.shape[0] - m = prior_box.shape[0] + tb_x = pb_v[:, :, 0] * t_box[:, :, 0] * pb_w + pb_x + tb_y = pb_v[:, :, 1] * t_box[:, :, 1] * pb_h + pb_y + tb_w = np.exp(pb_v[:, :, 2] * t_box[:, :, 2]) * pb_w + tb_h = np.exp(pb_v[:, :, 3] * t_box[:, :, 3]) * pb_h + output_box[:, :, 0] = tb_x - tb_w / 2 + output_box[:, :, 1] = tb_y - tb_h / 2 + output_box[:, :, 2] = tb_x + tb_w / 2 - (not norm) + output_box[:, :, 3] = tb_y + tb_h / 2 - (not norm) + + +def box_encoder(t_box, p_box, pb_v, output_box, norm): + pb_w = p_box[:, 2] - p_box[:, 0] + (norm == False) + pb_h = p_box[:, 3] - p_box[:, 1] + (norm == False) + pb_x = pb_w * 0.5 + p_box[:, 0] + pb_y = pb_h * 0.5 + p_box[:, 1] + shape = (1, p_box.shape[0]) + + pb_w = pb_w.reshape(shape) + pb_h = pb_h.reshape(shape) + pb_x = pb_x.reshape(shape) + pb_y = pb_y.reshape(shape) + + if pb_v.ndim == 2: + pb_v = pb_v.reshape(1, pb_v.shape[0], pb_v.shape[1]) + tb_x = ((t_box[:, 2] + t_box[:, 0]) / 2).reshape(t_box.shape[0], 1) + tb_y = ((t_box[:, 3] + t_box[:, 1]) / 2).reshape(t_box.shape[0], 1) + tb_w = (t_box[:, 2] - t_box[:, 0]).reshape(t_box.shape[0], 1) + (not norm) + tb_h = (t_box[:, 3] - t_box[:, 1]).reshape(t_box.shape[0], 1) + (not norm) + if pb_v.ndim == 1: + output_box[:, :, 0] = (tb_x - pb_x) / pb_w / pb_v[0] + output_box[:, :, 1] = (tb_y - pb_y) / pb_h / pb_v[1] + output_box[:, :, 2] = np.log(np.fabs(tb_w / pb_w)) / pb_v[2] + output_box[:, :, 3] = np.log(np.fabs(tb_h / pb_h)) / pb_v[3] + else: + output_box[:, :, 0] = (tb_x - pb_x) / pb_w / pb_v[:, :, 0] + output_box[:, :, 1] = (tb_y - pb_y) / pb_h / pb_v[:, :, 1] + output_box[:, :, 2] = np.log(np.fabs(tb_w / pb_w)) / pb_v[:, :, 2] + output_box[:, :, 3] = np.log(np.fabs(tb_h / pb_h)) / pb_v[:, :, 3] + + +def batch_box_coder(p_box, pb_v, t_box, lod, code_type, norm, axis=0): + n = t_box.shape[0] + m = p_box.shape[0] if code_type == "DecodeCenterSize": - m = target_box.shape[1] + m = t_box.shape[1] output_box = np.zeros((n, m, 4), dtype=np.float32) cur_offset = 0 for i in range(len(lod)): if (code_type == "EncodeCenterSize"): - box_coder(target_box[cur_offset:(cur_offset + lod[i]), :], - prior_box, prior_box_var, - output_box[cur_offset:(cur_offset + lod[i]), :, :], - code_type, box_normalized) + box_encoder(t_box[cur_offset:(cur_offset + lod[i]), :], p_box, pb_v, + output_box[cur_offset:(cur_offset + lod[i]), :, :], + norm) elif (code_type == "DecodeCenterSize"): - box_coder(target_box, prior_box, prior_box_var, output_box, - code_type, box_normalized, axis) + box_decoder(t_box, p_box, pb_v, output_box, norm, axis) cur_offset += lod[i] return output_box From f44b1507f0a3ab7d8aef7cd2b23b8cc90a55f355 Mon Sep 17 00:00:00 2001 From: jerrywgz Date: Wed, 23 Jan 2019 02:21:10 +0000 Subject: [PATCH 07/53] revised API spec, test=develop --- paddle/fluid/API.spec | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index 7068a37ef0..cdb0397ecd 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -315,7 +315,7 @@ paddle.fluid.layers.roi_perspective_transform ArgSpec(args=['input', 'rois', 'tr paddle.fluid.layers.generate_proposal_labels ArgSpec(args=['rpn_rois', 'gt_classes', 'is_crowd', 'gt_boxes', 'im_info', 'batch_size_per_im', 'fg_fraction', 'fg_thresh', 'bg_thresh_hi', 'bg_thresh_lo', 'bbox_reg_weights', 'class_nums', 'use_random'], varargs=None, keywords=None, defaults=(256, 0.25, 0.25, 0.5, 0.0, [0.1, 0.1, 0.2, 0.2], None, True)) paddle.fluid.layers.generate_proposals ArgSpec(args=['scores', 'bbox_deltas', 'im_info', 'anchors', 'variances', 'pre_nms_top_n', 'post_nms_top_n', 'nms_thresh', 'min_size', 'eta', 'name'], varargs=None, keywords=None, defaults=(6000, 1000, 0.5, 0.1, 1.0, None)) paddle.fluid.layers.iou_similarity ArgSpec(args=['x', 'y', 'name'], varargs=None, keywords=None, defaults=(None,)) -paddle.fluid.layers.box_coder ArgSpec(args=['prior_box', 'prior_box_var', 'target_box', 'code_type', 'box_normalized', 'axis', 'name'], varargs=None, keywords=None, defaults=('encode_center_size', True, 0, None)) +paddle.fluid.layers.box_coder ArgSpec(args=['prior_box', 'prior_box_var', 'target_box', 'code_type', 'box_normalized', 'name', 'axis'], varargs=None, keywords=None, defaults=('encode_center_size', True, None, 0)) paddle.fluid.layers.polygon_box_transform ArgSpec(args=['input', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.yolov3_loss ArgSpec(args=['x', 'gtbox', 'gtlabel', 'anchors', 'class_num', 'ignore_thresh', 'loss_weight_xy', 'loss_weight_wh', 'loss_weight_conf_target', 'loss_weight_conf_notarget', 'loss_weight_class', 'name'], varargs=None, keywords=None, defaults=(None, None, None, None, None, None)) paddle.fluid.layers.accuracy ArgSpec(args=['input', 'label', 'k', 'correct', 'total'], varargs=None, keywords=None, defaults=(1, None, None)) From 48cc4846430eefcd0d1b03349b982675ce853091 Mon Sep 17 00:00:00 2001 From: tink2123 Date: Wed, 23 Jan 2019 19:27:55 +0800 Subject: [PATCH 08/53] add align_corners and align_mode for image_resize test=develop --- paddle/fluid/operators/interpolate_op.cc | 73 ++++++ paddle/fluid/operators/interpolate_op.cu | 96 +++++--- paddle/fluid/operators/interpolate_op.h | 102 ++++++--- python/paddle/fluid/layers/nn.py | 207 +++++++++++++++++- .../unittests/test_bilinear_interp_op.py | 94 ++++++-- .../tests/unittests/test_nearest_interp_op.py | 57 ++++- 6 files changed, 529 insertions(+), 100 deletions(-) diff --git a/paddle/fluid/operators/interpolate_op.cc b/paddle/fluid/operators/interpolate_op.cc index 93dd3f794f..1b34d404c0 100644 --- a/paddle/fluid/operators/interpolate_op.cc +++ b/paddle/fluid/operators/interpolate_op.cc @@ -82,6 +82,18 @@ class InterpolateOpMaker : public framework::OpProtoAndCheckerMaker { "bilinear interpolation and \"nearest\" for nearest " "neighbor interpolation.") .SetDefault("bilinear"); + AddAttr( + "align_corners", + "an optinal bool. Defaults to True. " + "If True, the centers of 4 corner pixels of the input and output " + "tensors are aligned, preserving the values at the corner pixels, " + "if Flase, are not aligned") + .SetDefault(true); + AddAttr("align_mode", + "(int, default \'0\'), align_corners mode , can be \'0\' " + "for pytorch calculation method, can be \'1\' for " + "tensorflow calculation method.") + .SetDefault(0); AddComment(R"DOC( This operator samples input X to given output shape by using specified interpolation method, the interpolation methods can be \"nearest\" @@ -98,6 +110,67 @@ class InterpolateOpMaker : public framework::OpProtoAndCheckerMaker { to perform linear interpolation first in one direction, and then again in the other direction. + Align_corners and align_mode are optinal parameters,The calculation method + of interpolation can be selected by them. + + Example: + + for scale: + + if align_corners = True and out_{size}>1 : + + scale_{factor} = (in_{size}-1.0)/(out_{size}-1.0) + + else: + + scale_{factor} = float(in_{size}/out_{size}) + + + Nearest neighbor interpolation: + + case 1: + align_corners = False + + input : (N,C,H_in,W_in) + output: (N,C,H_out,W_out) where: + + H_out = \left \lfloor {H_{in} * scale_{}factor}} \right \rfloor + W_out = \left \lfloor {W_{in} * scale_{}factor}} \right \rfloor + + case 2: + align_corners = True + + input : (N,C,H_in,W_in) + output: (N,C,H_out,W_out) where: + + H_out = round(H_{in} * scale_{factor}) + W_out = round(W_{in} * scale_{factor}) + + Bilinear interpolation: + + case 1: + align_corners = False , align_mode = 0 + + input : (N,C,H_in,W_in) + output: (N,C,H_out,W_out) where: + + H_out = (H_{in}+0.5) * scale_{factor} - 0.5 + W_out = (W_{in}+0.5) * scale_{factor} - 0.5 + + + case 2: + align_corners = False , align_mode = 1 + or + align_corners = True + + input : (N,C,H_in,W_in) + output: (N,C,H_out,W_out) where: + + H_out = H_{in} * scale_{factor} + W_out = W_{in} * scale_{factor} + + + For details of nearest neighbor interpolation, please refer to Wikipedia: https://en.wikipedia.org/wiki/Nearest-neighbor_interpolation diff --git a/paddle/fluid/operators/interpolate_op.cu b/paddle/fluid/operators/interpolate_op.cu index 99ac725f73..316811d23e 100644 --- a/paddle/fluid/operators/interpolate_op.cu +++ b/paddle/fluid/operators/interpolate_op.cu @@ -23,7 +23,8 @@ __global__ void KeNearestNeighborInterpFw( const T* in, const size_t in_img_h, const size_t in_img_w, const size_t input_h, const size_t input_w, T* out, const size_t out_img_h, const size_t out_img_w, const size_t output_h, const size_t output_w, - const size_t num_channels, const float ratio_h, const float ratio_w) { + const size_t num_channels, const float ratio_h, const float ratio_w, + const bool align_corners) { int nthreads = output_h * output_w; int tid = blockIdx.x * blockDim.x + threadIdx.x; int stride = blockDim.x * gridDim.x; @@ -35,10 +36,14 @@ __global__ void KeNearestNeighborInterpFw( int channel_id = out_id_w / out_img_size; int out_img_idy = (out_id_w % out_img_size) / out_img_w; - int in_img_idy = static_cast(ratio_h * out_img_idy + 0.5); + int in_img_idy = (align_corners) + ? static_cast(ratio_h * out_img_idy + 0.5) + : static_cast(ratio_h * out_img_idy); int out_img_idx = tid % out_img_w; - int in_img_idx = static_cast(ratio_w * out_img_idx + 0.5); + int in_img_idx = (align_corners) + ? static_cast(ratio_w * out_img_idx + 0.5) + : static_cast(ratio_w * out_img_idx); out[tid] = in[out_id_h * input_w + channel_id * in_img_size + in_img_idy * in_img_w + in_img_idx]; @@ -50,7 +55,8 @@ __global__ void KeNearestNeighborInterpBw( T* in, const size_t in_img_h, const size_t in_img_w, const size_t input_h, const size_t input_w, const T* out, const size_t out_img_h, const size_t out_img_w, const size_t output_h, const size_t output_w, - const size_t num_channels, const float ratio_h, const float ratio_w) { + const size_t num_channels, const float ratio_h, const float ratio_w, + const bool align_corners) { int nthreads = output_h * output_w; int tid = blockIdx.x * blockDim.x + threadIdx.x; int stride = blockDim.x * gridDim.x; @@ -62,10 +68,14 @@ __global__ void KeNearestNeighborInterpBw( int channel_id = out_id_w / out_img_size; int out_img_idy = (out_id_w % out_img_size) / out_img_w; - int in_img_idy = static_cast(ratio_h * out_img_idy + 0.5); + int in_img_idy = (align_corners) + ? static_cast(ratio_h * out_img_idy + 0.5) + : static_cast(ratio_h * out_img_idy); int out_img_idx = tid % out_img_w; - int in_img_idx = static_cast(ratio_w * out_img_idx + 0.5); + int in_img_idx = (align_corners) + ? static_cast(ratio_w * out_img_idx + 0.5) + : static_cast(ratio_w * out_img_idx); T* in_pos = &in[out_id_h * input_w + channel_id * in_img_size + in_img_idy * in_img_w + in_img_idx]; @@ -79,7 +89,8 @@ __global__ void KeBilinearInterpFw( const T* in, const size_t in_img_h, const size_t in_img_w, const size_t input_h, const size_t input_w, T* out, const size_t out_img_h, const size_t out_img_w, const size_t output_h, const size_t output_w, - const size_t num_channels, const float ratio_h, const float ratio_w) { + const size_t num_channels, const float ratio_h, const float ratio_w, + const bool align_corners, const int align_mode) { int nthreads = output_h * output_w; int tid = blockIdx.x * blockDim.x + threadIdx.x; int stride = blockDim.x * gridDim.x; @@ -91,15 +102,23 @@ __global__ void KeBilinearInterpFw( int channel_id = out_id_w / out_img_size; int out_img_idy = (out_id_w % out_img_size) / out_img_w; - int in_img_idy = ratio_h * out_img_idy; + int in_img_idy = (align_mode == 0 && !align_corners) + ? static_cast(ratio_h * (out_img_idy + 0.5) - 0.5) + : static_cast(ratio_h * out_img_idy); int h_id = (in_img_idy < in_img_h - 1) ? 1 : 0; - T h1lambda = ratio_h * out_img_idy - in_img_idy; + T h1lambda = (align_mode == 0 && !align_corners) + ? ratio_h * (out_img_idy + 0.5) - 0.5 - in_img_idy + : ratio_h * out_img_idy - in_img_idy; T h2lambda = 1.f - h1lambda; int out_img_idx = tid % out_img_w; - int in_img_idx = ratio_w * out_img_idx; + int in_img_idx = (align_mode == 0 && !align_corners) + ? static_cast(ratio_w * (out_img_idx + 0.5) - 0.5) + : static_cast(ratio_w * out_img_idx); int w_id = (in_img_idx < in_img_w - 1) ? 1 : 0; - T w1lambda = ratio_w * out_img_idx - in_img_idx; + T w1lambda = (align_mode == 0 && !align_corners) + ? ratio_w * (out_img_idx + 0.5) - 0.5 - in_img_idx + : ratio_w * out_img_idx - in_img_idx; T w2lambda = 1.f - w1lambda; const T* in_pos = &in[out_id_h * input_w + channel_id * in_img_size + @@ -118,7 +137,8 @@ __global__ void KeBilinearInterpBw( T* in, const size_t in_img_h, const size_t in_img_w, const size_t input_h, const size_t input_w, const T* out, const size_t out_img_h, const size_t out_img_w, const size_t output_h, const size_t output_w, - const size_t num_channels, const T ratio_h, const T ratio_w) { + const size_t num_channels, const T ratio_h, const T ratio_w, + const bool align_corners, const int align_mode) { int nthreads = output_h * output_w; int tid = blockIdx.x * blockDim.x + threadIdx.x; int stride = blockDim.x * gridDim.x; @@ -130,15 +150,24 @@ __global__ void KeBilinearInterpBw( int channel_id = out_id_w / out_img_size; int out_img_idy = (out_id_w % out_img_size) / out_img_w; - int in_img_idy = ratio_h * out_img_idy; + int in_img_idy = (align_mode == 0 && !align_corners) + ? ratio_h * (out_img_idy + 0.5) - 0.5 + : ratio_h * out_img_idy; int h_id = (in_img_idy < in_img_h - 1) ? 1 : 0; - T h1lambda = ratio_h * out_img_idy - in_img_idy; + T h1lambda = (align_mode == 0 && !align_corners) + ? ratio_h * (out_img_idy + 0.5) - 0.5 - in_img_idy + : ratio_h * out_img_idy - in_img_idy; + T h2lambda = 1.f - h1lambda; int out_img_idx = tid % out_img_w; - int in_img_idx = ratio_w * out_img_idx; + int in_img_idx = (align_mode == 0 && !align_corners) + ? ratio_w * (out_img_idx + 0.5) - 0.5 + : ratio_w * out_img_idx; int w_id = (in_img_idx < in_img_w - 1) ? 1 : 0; - T w1lambda = ratio_w * out_img_idx - in_img_idx; + T w1lambda = (align_mode == 0 && !align_corners) + ? ratio_w * (out_img_idx + 0.5) - 0.5 - in_img_idx + : ratio_w * out_img_idx - in_img_idx; T w2lambda = 1.f - w1lambda; T* in_pos = &in[out_id_h * input_w + channel_id * in_img_size + @@ -175,6 +204,9 @@ class InterpolateOpCUDAKernel : public framework::OpKernel { out_w = size_data[1]; } + bool align_corners = ctx.Attr("align_corners"); + int align_mode = ctx.Attr("align_mode"); + int n = input->dims()[0]; int c = input->dims()[1]; int in_h = input->dims()[2]; @@ -188,10 +220,12 @@ class InterpolateOpCUDAKernel : public framework::OpKernel { int in_chw = c * in_hw; int out_chw = c * out_hw; - float ratio_h = - (out_h > 1) ? static_cast(in_h - 1) / (out_h - 1) : 0.f; - float ratio_w = - (out_w > 1) ? static_cast(in_w - 1) / (out_w - 1) : 0.f; + float ratio_h = (align_corners && out_h > 1) + ? static_cast(in_h - 1) / (out_h - 1) + : static_cast(in_h) / out_h; + float ratio_w = (align_corners && out_w > 1) + ? static_cast(in_w - 1) / (out_w - 1) + : static_cast(in_w) / out_w; if (in_h == out_h && in_w == out_w) { framework::TensorCopy(*input, ctx.GetPlace(), output); @@ -206,12 +240,12 @@ class InterpolateOpCUDAKernel : public framework::OpKernel { KeNearestNeighborInterpFw< T><<>>( input_data, in_h, in_w, n, in_chw, output_data, out_h, out_w, n, - out_chw, c, ratio_h, ratio_w); + out_chw, c, ratio_h, ratio_w, align_corners); } else if ("bilinear" == interp_method) { KeBilinearInterpFw< T><<>>( input_data, in_h, in_w, n, in_chw, output_data, out_h, out_w, n, - out_chw, c, ratio_h, ratio_w); + out_chw, c, ratio_h, ratio_w, align_corners, align_mode); } } }; @@ -234,6 +268,10 @@ class InterpolateGradOpCUDAKernel : public framework::OpKernel { int out_h = ctx.Attr("out_h"); int out_w = ctx.Attr("out_w"); auto out_size = ctx.Input("OutSize"); + + bool align_corners = ctx.Attr("align_corners"); + int align_mode = ctx.Attr("align_mode"); + if (out_size != nullptr) { Tensor sizes; framework::TensorCopy(*out_size, platform::CPUPlace(), &sizes); @@ -252,10 +290,12 @@ class InterpolateGradOpCUDAKernel : public framework::OpKernel { int in_chw = c * in_hw; int out_chw = c * out_hw; - float ratio_h = - (out_h > 1) ? static_cast(in_h - 1) / (out_h - 1) : 0.f; - float ratio_w = - (out_w > 1) ? static_cast(in_w - 1) / (out_w - 1) : 0.f; + float ratio_h = (align_corners && out_h > 1) + ? static_cast(in_h - 1) / (out_h - 1) + : static_cast(in_h) / out_h; + float ratio_w = (align_corners && out_w > 1) + ? static_cast(in_w - 1) / (out_w - 1) + : static_cast(in_w) / out_w; if (in_h == out_h && in_w == out_w) { framework::TensorCopy(*output_grad, ctx.GetPlace(), input_grad); @@ -270,12 +310,12 @@ class InterpolateGradOpCUDAKernel : public framework::OpKernel { KeNearestNeighborInterpBw< T><<>>( input_grad_data, in_h, in_w, n, in_chw, output_grad_data, out_h, - out_w, n, out_chw, c, ratio_h, ratio_w); + out_w, n, out_chw, c, ratio_h, ratio_w, align_corners); } else if ("bilinear" == interp_method) { KeBilinearInterpBw< T><<>>( input_grad_data, in_h, in_w, n, in_chw, output_grad_data, out_h, - out_w, n, out_chw, c, ratio_h, ratio_w); + out_w, n, out_chw, c, ratio_h, ratio_w, align_corners, align_mode); } } }; diff --git a/paddle/fluid/operators/interpolate_op.h b/paddle/fluid/operators/interpolate_op.h index 7fdb3e1f5a..95aec33eee 100644 --- a/paddle/fluid/operators/interpolate_op.h +++ b/paddle/fluid/operators/interpolate_op.h @@ -26,14 +26,17 @@ template static void NearestNeighborInterpolate(const Tensor& input, Tensor* output, const float ratio_h, const float ratio_w, const int n, const int c, - const int out_h, const int out_w) { + const int out_h, const int out_w, + const bool align_corners) { auto input_t = EigenTensor::From(input); auto output_t = EigenTensor::From(*output); for (int k = 0; k < out_h; k++) { // loop for images - int in_k = static_cast(ratio_h * k + 0.5); + int in_k = (align_corners) ? static_cast(ratio_h * k + 0.5) + : static_cast(ratio_h * k); for (int l = 0; l < out_w; l++) { - int in_l = static_cast(ratio_w * l + 0.5); + int in_l = (align_corners) ? static_cast(ratio_w * l + 0.5) + : static_cast(ratio_w * l); for (int i = 0; i < n; i++) { // loop for batches for (int j = 0; j < c; j++) { // loop for channels @@ -48,20 +51,29 @@ template static void BilinearInterpolation(const Tensor& input, Tensor* output, const float ratio_h, const float ratio_w, const int in_h, const int in_w, const int n, - const int c, const int out_h, - const int out_w) { + const int c, const int out_h, const int out_w, + const bool align_corners, + const bool align_mode) { auto input_t = EigenTensor::From(input); auto output_t = EigenTensor::From(*output); for (int k = 0; k < out_h; k++) { // loop for images - int y_n = static_cast(ratio_h * k); + int y_n = (align_mode == 0 && !align_corners) + ? static_cast(ratio_h * (k + 0.5) - 0.5) + : static_cast(ratio_h * k); int y_s = (y_n + 1) < (in_h - 1) ? (y_n + 1) : (in_h - 1); - float d_n = ratio_h * k - y_n; + float d_n = (align_mode == 0 && !align_corners) + ? ratio_h * (k + 0.5) - 0.5 - y_n + : ratio_h * k - y_n; float d_s = 1.f - d_n; for (int l = 0; l < out_w; l++) { - int x_w = static_cast(ratio_w * l); + int x_w = (align_mode == 0 && !align_corners) + ? static_cast(ratio_w * (l + 0.5) - 0.5) + : static_cast(ratio_w * l); int x_e = (x_w + 1) < (in_w - 1) ? (x_w + 1) : (in_w - 1); - float d_w = ratio_w * l - x_w; + float d_w = (align_mode == 0 && !align_corners) + ? ratio_w * (l + 0.5) - 0.5 - x_w + : ratio_w * l - x_w; float d_e = 1.f - d_w; for (int i = 0; i < n; i++) { // loop for batches @@ -78,19 +90,20 @@ static void BilinearInterpolation(const Tensor& input, Tensor* output, } template -static void NearestNeighborInterpolateGrad(const Tensor& output_grad, - Tensor* input_grad, - const float ratio_h, - const float ratio_w, const int n, - const int c, const int out_h, - const int out_w) { +static void NearestNeighborInterpolateGrad( + const Tensor& output_grad, Tensor* input_grad, const float ratio_h, + const float ratio_w, const int n, const int c, const int out_h, + const int out_w, const bool align_corners) { auto input_grad_t = EigenTensor::From(*input_grad); auto output_grad_t = EigenTensor::From(output_grad); + for (int k = 0; k < out_h; k++) { // loop for images - int in_k = static_cast(ratio_h * k + 0.5); + int in_k = (align_corners) ? static_cast(ratio_h * k + 0.5) + : static_cast(ratio_h * k); for (int l = 0; l < out_w; l++) { - int in_l = static_cast(ratio_w * l + 0.5); + int in_l = (align_corners) ? static_cast(ratio_w * l + 0.5) + : static_cast(ratio_w * l); for (int i = 0; i < n; i++) { // loop for batches for (int j = 0; j < c; j++) { // loop for channels @@ -106,19 +119,29 @@ static void BilinearInterpolationGrad(const Tensor& output_grad, Tensor* input_grad, const float ratio_h, const float ratio_w, const int in_h, const int in_w, const int n, const int c, - const int out_h, const int out_w) { + const int out_h, const int out_w, + const bool align_corners, + const int align_mode) { auto input_grad_t = EigenTensor::From(*input_grad); auto output_grad_t = EigenTensor::From(output_grad); for (int k = 0; k < out_h; k++) { // loop for images - int y_n = static_cast(ratio_h * k); + int y_n = (align_mode == 0 && !align_corners) + ? static_cast(ratio_h * (k + 0.5) - 0.5) + : static_cast(ratio_h * k); int y_s = (y_n + 1) < (in_h - 1) ? (y_n + 1) : (in_h - 1); - float d_n = ratio_h * k - y_n; + float d_n = (align_mode == 0 && !align_corners) + ? ratio_h * (k + 0.5) - 0.5 - y_n + : ratio_h * k - y_n; float d_s = 1.f - d_n; for (int l = 0; l < out_w; l++) { - int x_w = static_cast(ratio_w * l); + int x_w = (align_mode == 0 && !align_corners) + ? static_cast(ratio_w * (l + 0.5) - 0.5) + : static_cast(ratio_w * l); int x_e = (x_w + 1) < (in_w - 1) ? (x_w + 1) : (in_w - 1); - float d_w = ratio_w * l - x_w; + float d_w = (align_mode == 0 && !align_corners) + ? ratio_w * (l + 0.5) - 0.5 - x_w + : ratio_w * l - x_w; float d_e = 1.f - d_w; for (int i = 0; i < n; i++) { // loop for batches @@ -134,7 +157,6 @@ static void BilinearInterpolationGrad(const Tensor& output_grad, } } } - template class InterpolateKernel : public framework::OpKernel { public: @@ -151,6 +173,8 @@ class InterpolateKernel : public framework::OpKernel { out_h = out_size_data[0]; out_w = out_size_data[1]; } + bool align_corners = ctx.Attr("align_corners"); + int align_mode = ctx.Attr("align_mode"); const int n = input->dims()[0]; const int c = input->dims()[1]; @@ -168,17 +192,19 @@ class InterpolateKernel : public framework::OpKernel { return; } - float ratio_h = - (out_h > 1) ? static_cast(in_h - 1) / (out_h - 1) : 0.f; - float ratio_w = - (out_w > 1) ? static_cast(in_w - 1) / (out_w - 1) : 0.f; + float ratio_h = (align_corners && out_h > 1) + ? static_cast(in_h - 1) / (out_h - 1) + : static_cast(in_h) / out_h; + float ratio_w = (align_corners && out_w > 1) + ? static_cast(in_w - 1) / (out_w - 1) + : static_cast(in_w) / out_w; if ("bilinear" == interp_method) { BilinearInterpolation(*input, output, ratio_h, ratio_w, in_h, in_w, n, - c, out_h, out_w); + c, out_h, out_w, align_corners, align_mode); } else if ("nearest" == interp_method) { NearestNeighborInterpolate(*input, output, ratio_h, ratio_w, n, c, - out_h, out_w); + out_h, out_w, align_corners); } } }; @@ -200,6 +226,8 @@ class InterpolateGradKernel : public framework::OpKernel { out_h = out_size_data[0]; out_w = out_size_data[1]; } + bool align_corners = ctx.Attr("align_corners"); + int align_mode = ctx.Attr("align_mode"); const int n = input->dims()[0]; const int c = input->dims()[1]; @@ -217,17 +245,21 @@ class InterpolateGradKernel : public framework::OpKernel { return; } - float ratio_h = - (out_h > 1) ? static_cast(in_h - 1) / (out_h - 1) : 0.f; - float ratio_w = - (out_w > 1) ? static_cast(in_w - 1) / (out_w - 1) : 0.f; + float ratio_h = (align_corners && out_h > 1) + ? static_cast(in_h - 1) / (out_h - 1) + : static_cast(in_h) / out_h; + float ratio_w = (align_corners && out_w > 1) + ? static_cast(in_w - 1) / (out_w - 1) + : static_cast(in_w) / out_w; if ("bilinear" == interp_method) { BilinearInterpolationGrad(*output_grad, input_grad, ratio_h, ratio_w, - in_h, in_w, n, c, out_h, out_w); + in_h, in_w, n, c, out_h, out_w, + align_corners, align_mode); } else if ("nearest" == interp_method) { NearestNeighborInterpolateGrad(*output_grad, input_grad, ratio_h, - ratio_w, n, c, out_h, out_w); + ratio_w, n, c, out_h, out_w, + align_corners); } } }; diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 56971cff43..93e77dc113 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -913,7 +913,7 @@ def dynamic_gru(input, create ParamAttr as param_attr. If the Initializer of the param_attr is not set, the parameter is initialized with Xavier. Default: None. bias_attr (ParamAttr|bool|None): The parameter attribute for the bias - of GRU. Note that the bias with :math:`(1 \\times 3D)` concatenates + of GRU.Note that the bias with :math:`(1 \\times 3D)` concatenates the bias in the update gate, reset gate and candidate calculations. If it is set to False, no bias will be applied to the update gate, reset gate and candidate calculations. If it is set to None or one @@ -1034,7 +1034,7 @@ def gru_unit(input, create ParamAttr as param_attr. If the Initializer of the param_attr is not set, the parameter is initialized with Xavier. Default: None. bias_attr (ParamAttr|bool|None): The parameter attribute for the bias - of GRU. Note that the bias with :math:`(1 \\times 3D)` concatenates + of GRU.Note that the bias with :math:`(1 \\times 3D)` concatenates the bias in the update gate, reset gate and candidate calculations. If it is set to False, no bias will be applied to the update gate, reset gate and candidate calculations. If it is set to None or one @@ -5350,7 +5350,7 @@ def transpose(x, perm, name=None): Examples: .. code-block:: python - # use append_batch_size=False to avoid prepending extra + # use append_batch_size=False to avoid prepending extra # batch size in shape x = fluid.layers.data(name='x', shape=[5, 10, 15], dtype='float32', append_batch_size=False) @@ -5866,7 +5866,7 @@ def reshape(x, shape, actual_shape=None, act=None, inplace=False, name=None): than :attr:`shape`. act (str): The non-linear activation to be applied to the reshaped tensor variable. - inplace(bool): Must use :attr:`False` if :attr:`x` is used in multiple + inplace(bool): Must use :attr:`False` if :attr:`x` is used in multiple operators. If this flag is set :attr:`True`, reuse input :attr:`x` to reshape, which will change the shape of tensor variable :attr:`x` and might cause errors when @@ -6527,7 +6527,9 @@ def image_resize(input, scale=None, name=None, resample='BILINEAR', - actual_shape=None): + actual_shape=None, + align_corners=True, + align_mode=0): """ **Resize a Batch of Images** @@ -6540,6 +6542,83 @@ def image_resize(input, 'NEAREST' : Nearest neighbor interpolation + Nearest neighbor interpolation is to perform nearest neighbor interpolation + in both the 3rd dimention(in height direction) and the 4th dimention(in width + direction) on input tensor. + + Bilinear interpolation is an extension of linear interpolation for + interpolating functions of two variables (e.g. H-direction and + W-direction in this op) on a rectilinear 2D grid. The key idea is + to perform linear interpolation first in one direction, and then + again in the other direction. + + Align_corners and align_mode are optinal parameters,The calculation method + of interpolation can be selected by them. + + Example: + + for scale: + + if align_corners = True && out_size > 1 : + + scale_factor = (in_size-1.0)/(out_size-1.0) + + else: + + scale_factor = float(in_size/out_size) + + + Nearest neighbor interpolation: + + case 1: + align_corners = False + + input : (N,C,H_in,W_in) + output: (N,C,H_out,W_out) where: + + H_out = \left \lfloor {H_{in} * scale_{}factor}} \right \rfloor + W_out = \left \lfloor {W_{in} * scale_{}factor}} \right \rfloor + + case 2: + align_corners = True + + input : (N,C,H_in,W_in) + output: (N,C,H_out,W_out) where: + + H_out = round(H_{in} * scale_{factor}) + W_out = round(W_{in} * scale_{factor}) + + Bilinear interpolation: + + case 1: + align_corners = False , align_mode = 0 + + input : (N,C,H_in,W_in) + output: (N,C,H_out,W_out) where: + + H_out = (H_{in}+0.5) * scale_{factor} - 0.5 + W_out = (W_{in}+0.5) * scale_{factor} - 0.5 + + + case 2: + align_corners = False , align_mode = 1 + or + align_corners = True + + input : (N,C,H_in,W_in) + output: (N,C,H_out,W_out) where: + + H_out = H_{in} * scale_{factor} + W_out = W_{in} * scale_{factor} + + For details of nearest neighbor interpolation, please refer to Wikipedia: + https://en.wikipedia.org/wiki/Nearest-neighbor_interpolation. + + For details of bilinear interpolation, please refer to Wikipedia: + https://en.wikipedia.org/wiki/Bilinear_interpolation. + + + Args: input (Variable): The input tensor of image resize layer, This is a 4-D tensor of the shape @@ -6569,6 +6648,12 @@ def image_resize(input, set, otherwise errors would be occured in graph constructing stage. Default: None + align_corners(bool) : An optional bool, If True, the centers of the 4 corner pixels of the + input and output tensors are aligned, preserving the values at the + corner pixels. + Default: True + align_mode(int) : An optional input to specify align_corners mode. can be \'0\' + for pytorch calculation method, can be \'1'\ for tensorflow calculation method. Returns: Variable: The output is a 4-D tensor of the shape @@ -6581,6 +6666,8 @@ def image_resize(input, or 'NEAREST' currently. ValueError: One of out_shape and scale must not be None. ValueError: out_shape length should be 2. + TypeError: align_corners shoule be a bool value + ValueError: align_mode can only be '0' or '1' Examples: .. code-block:: python @@ -6596,6 +6683,12 @@ def image_resize(input, "The 'resample' of image_resize can only be 'BILINEAR' or 'NEAREST' currently." ) resample_type = resample_methods[resample] + + if not isinstance(align_corners, bool): + raise TypeError("Attr align_corners should be a bool value") + if align_mode != 0 and align_mode != 1: + raise ValueError("align_mode can only be 0 or 1") + if out_shape is None and scale is None: raise ValueError("One of out_shape and scale must not be None.") helper = LayerHelper('{}_interp'.format(resample_type), **locals()) @@ -6635,9 +6728,13 @@ def image_resize(input, type='{}_interp'.format(resample_type), inputs=inputs, outputs={"Out": out}, - attrs={"out_h": out_h, - "out_w": out_w, - "interp_method": resample_type}) + attrs={ + "out_h": out_h, + "out_w": out_w, + "interp_method": resample_type, + "align_corners": align_corners, + "align_mode": align_mode + }) return out @@ -6646,7 +6743,9 @@ def resize_bilinear(input, out_shape=None, scale=None, name=None, - actual_shape=None): + actual_shape=None, + align_corners=True, + align_mode=0): """ Resize input by performing bilinear interpolation based on given output shape which specified by actual_shape, out_shape and scale @@ -6661,6 +6760,50 @@ def resize_bilinear(input, For details of bilinear interpolation, please refer to Wikipedia: https://en.wikipedia.org/wiki/Bilinear_interpolation + Align_corners and align_mode are optinal parameters,The calculation + method of interpolation can be selected by them. + + + Align_corners and align_mode are optinal parameters,The calculation method + of interpolation can be selected by them. + + Example: + + for scale: + + if align_corners = True && out_size > 1 : + + scale_factor = (in_size-1.0)/(out_size-1.0) + + else: + + scale_factor = float(in_size/out_size) + + Bilinear interpolation: + + case 1: + align_corners = False , align_mode = 0 + + input : (N,C,H_in,W_in) + output: (N,C,H_out,W_out) where: + + H_out = (H_{in}+0.5) * scale_{factor} - 0.5 + W_out = (W_{in}+0.5) * scale_{factor} - 0.5 + + + case 2: + align_corners = False , align_mode = 1 + or + align_corners = True + + input : (N,C,H_in,W_in) + output: (N,C,H_out,W_out) where: + + H_out = H_{in} * scale_{factor} + W_out = W_{in} * scale_{factor} + + + Args: input(${x_type}): ${x_comment}. @@ -6684,6 +6827,8 @@ def resize_bilinear(input, set, otherwise errors would be occured in graph constructing stage. Default: None + align_corners(bool): ${align_corners_comment} + align_mode(bool): ${align_mode_comment} Returns: ${out_comment}. @@ -6694,7 +6839,8 @@ def resize_bilinear(input, out = fluid.layers.resize_bilinear(input, out_shape=[12, 12]) """ - return image_resize(input, out_shape, scale, name, 'BILINEAR', actual_shape) + return image_resize(input, out_shape, scale, name, 'BILINEAR', actual_shape, + align_corners, align_mode) @templatedoc(op_type="nearest_interp") @@ -6702,13 +6848,48 @@ def resize_nearest(input, out_shape=None, scale=None, name=None, - actual_shape=None): + actual_shape=None, + align_corners=True): """ Resize input by performing nearest neighbor interpolation in both the 3rd dimention(in height direction) and the 4th dimention(in width direction) based on given output shape which specified by actual_shape, out_shape and scale in priority order. + Example: + + for scale: + + if align_corners = True && out_size > 1 : + + scale_factor = (in_size-1.0)/(out_size-1.0) + + else: + + scale_factor = float(in_size/out_size) + + + Nearest neighbor interpolation: + + case 1: + align_corners = False + + input : (N,C,H_in,W_in) + output: (N,C,H_out,W_out) where: + + H_out = \left \lfloor {H_{in} * scale_{}factor}} \right \rfloor + W_out = \left \lfloor {W_{in} * scale_{}factor}} \right \rfloor + + case 2: + align_corners = True + + input : (N,C,H_in,W_in) + output: (N,C,H_out,W_out) where: + + H_out = round(H_{in} * scale_{factor}) + W_out = round(W_{in} * scale_{factor}) + + For details of nearest neighbor interpolation, please refer to Wikipedia: https://en.wikipedia.org/wiki/Nearest-neighbor_interpolation @@ -6735,6 +6916,7 @@ def resize_nearest(input, set, otherwise errors would be occured in graph constructing stage. Default: None + align_corners(bool): ${align_corners_comment} Returns: ${out_comment}. @@ -6745,7 +6927,8 @@ def resize_nearest(input, out = fluid.layers.resize_nearest(input, out_shape=[12, 12]) """ - return image_resize(input, out_shape, scale, name, 'NEAREST', actual_shape) + return image_resize(input, out_shape, scale, name, 'NEAREST', actual_shape, + align_corners) def image_resize_short(input, out_short_len, resample='BILINEAR'): diff --git a/python/paddle/fluid/tests/unittests/test_bilinear_interp_op.py b/python/paddle/fluid/tests/unittests/test_bilinear_interp_op.py index c8a7063dc1..4523fb54ce 100644 --- a/python/paddle/fluid/tests/unittests/test_bilinear_interp_op.py +++ b/python/paddle/fluid/tests/unittests/test_bilinear_interp_op.py @@ -20,7 +20,13 @@ from op_test import OpTest import paddle.fluid.core as core -def bilinear_interp_np(input, out_h, out_w, out_size=None, actual_shape=None): +def bilinear_interp_np(input, + out_h, + out_w, + out_size=None, + actual_shape=None, + align_corners=True, + align_mode=0): """bilinear interpolation implement in shape [N, C, H, W]""" if out_size is not None: out_h = out_size[0] @@ -29,25 +35,41 @@ def bilinear_interp_np(input, out_h, out_w, out_size=None, actual_shape=None): out_h = actual_shape[0] out_w = actual_shape[1] batch_size, channel, in_h, in_w = input.shape - if out_h > 1: + + ratio_h = ratio_w = 0.0 + if (align_corners and out_h > 1): ratio_h = (in_h - 1.0) / (out_h - 1.0) else: - ratio_h = 0.0 - if out_w > 1: + ratio_h = 1.0 * in_h / out_h + if (align_corners and out_w > 1): ratio_w = (in_w - 1.0) / (out_w - 1.0) else: - ratio_w = 0.0 + ratio_w = 1.0 * in_w / out_w out = np.zeros((batch_size, channel, out_h, out_w)) + for i in range(out_h): - h = int(ratio_h * i) + if (align_mode == 0 and not align_corners): + h = int(ratio_h * (i + 0.5) - 0.5) + else: + h = int(ratio_h * i) + hid = 1 if h < in_h - 1 else 0 - h1lambda = ratio_h * i - h + if (align_mode == 0 and not align_corners): + h1lambda = ratio_h * (i + 0.5) - 0.5 - h + else: + h1lambda = ratio_h * i - h h2lambda = 1.0 - h1lambda for j in range(out_w): - w = int(ratio_w * j) + if (align_mode == 0 and not align_corners): + w = int(ratio_w * (j + 0.5) - 0.5) + else: + w = int(ratio_w * j) wid = 1 if w < in_w - 1 else 0 - w1lambda = ratio_w * j - w + if (align_mode == 0 and not align_corners): + w1lambda = ratio_w * (j + 0.5) - 0.5 - w + else: + w1lambda = ratio_w * j - w w2lambda = 1.0 - w1lambda out[:, :, i, j] = h2lambda*(w2lambda*input[:, :, h, w] + @@ -66,7 +88,8 @@ class TestBilinearInterpOp(OpTest): input_np = np.random.random(self.input_shape).astype("float32") output_np = bilinear_interp_np(input_np, self.out_h, self.out_w, - self.out_size, self.actual_shape) + self.out_size, self.actual_shape, + self.align_corners, self.align_mode) self.inputs = {'X': input_np} if self.out_size is not None: self.inputs['OutSize'] = self.out_size @@ -75,7 +98,9 @@ class TestBilinearInterpOp(OpTest): self.attrs = { 'out_h': self.out_h, 'out_w': self.out_w, - 'interp_method': self.interp_method + 'interp_method': self.interp_method, + 'align_corners': self.align_corners, + 'align_mode': self.align_mode } self.outputs = {'Out': output_np} @@ -91,6 +116,8 @@ class TestBilinearInterpOp(OpTest): self.out_h = 2 self.out_w = 2 self.out_size = np.array([3, 3]).astype("int32") + self.align_corners = False + self.align_mode = 0 class TestBilinearInterpCase1(TestBilinearInterpOp): @@ -99,6 +126,8 @@ class TestBilinearInterpCase1(TestBilinearInterpOp): self.input_shape = [4, 1, 7, 8] self.out_h = 1 self.out_w = 1 + self.align_corners = False + self.align_mode = 0 class TestBilinearInterpCase2(TestBilinearInterpOp): @@ -107,6 +136,8 @@ class TestBilinearInterpCase2(TestBilinearInterpOp): self.input_shape = [3, 3, 9, 6] self.out_h = 12 self.out_w = 12 + self.align_corners = False + self.align_mode = 0 class TestBilinearInterpCase3(TestBilinearInterpOp): @@ -115,6 +146,8 @@ class TestBilinearInterpCase3(TestBilinearInterpOp): self.input_shape = [1, 1, 128, 64] self.out_h = 64 self.out_w = 128 + self.align_corners = False + self.align_mode = 0 class TestBilinearInterpCase4(TestBilinearInterpOp): @@ -124,6 +157,8 @@ class TestBilinearInterpCase4(TestBilinearInterpOp): self.out_h = 1 self.out_w = 1 self.out_size = np.array([2, 2]).astype("int32") + self.align_corners = False + self.align_mode = 0 class TestBilinearInterpCase5(TestBilinearInterpOp): @@ -133,6 +168,8 @@ class TestBilinearInterpCase5(TestBilinearInterpOp): self.out_h = 12 self.out_w = 12 self.out_size = np.array([11, 11]).astype("int32") + self.align_corners = False + self.align_mode = 0 class TestBilinearInterpCase6(TestBilinearInterpOp): @@ -142,6 +179,8 @@ class TestBilinearInterpCase6(TestBilinearInterpOp): self.out_h = 64 self.out_w = 128 self.out_size = np.array([65, 129]).astype("int32") + self.align_corners = False + self.align_mode = 0 class TestBilinearInterpActualShape(TestBilinearInterpOp): @@ -151,6 +190,8 @@ class TestBilinearInterpActualShape(TestBilinearInterpOp): self.out_h = 64 self.out_w = 32 self.out_size = np.array([66, 40]).astype("int32") + self.align_corners = False + self.align_mode = 0 class TestBilinearInterpOpUint8(OpTest): @@ -162,14 +203,17 @@ class TestBilinearInterpOpUint8(OpTest): input_np = np.random.randint( low=0, high=256, size=self.input_shape).astype("uint8") output_np = bilinear_interp_np(input_np, self.out_h, self.out_w, - self.out_size, self.actual_shape) + self.out_size, self.actual_shape, + self.align_corners, self.align_mode) self.inputs = {'X': input_np} if self.out_size is not None: self.inputs['OutSize'] = self.out_size self.attrs = { 'out_h': self.out_h, 'out_w': self.out_w, - 'interp_method': self.interp_method + 'interp_method': self.interp_method, + 'align_corners': self.align_corners, + 'align_mode': self.align_mode } self.outputs = {'Out': output_np} @@ -181,6 +225,8 @@ class TestBilinearInterpOpUint8(OpTest): self.input_shape = [1, 3, 9, 6] self.out_h = 10 self.out_w = 9 + self.align_corners = False + self.align_mode = 0 class TestBilinearInterpCase1Uint8(TestBilinearInterpOpUint8): @@ -189,6 +235,8 @@ class TestBilinearInterpCase1Uint8(TestBilinearInterpOpUint8): self.input_shape = [2, 3, 128, 64] self.out_h = 120 self.out_w = 50 + self.align_corners = False + self.align_mode = 0 class TestBilinearInterpCase2Uint8(TestBilinearInterpOpUint8): @@ -198,6 +246,26 @@ class TestBilinearInterpCase2Uint8(TestBilinearInterpOpUint8): self.out_h = 5 self.out_w = 13 self.out_size = np.array([6, 15]).astype("int32") + self.align_corners = False + self.align_mode = 0 + + +class TestBilinearInterpOtherMethod1(TestBilinearInterpOp): + def set_align_mode(self): + self.align_mode = 1 + self.align_corners = False + + +class TestBilinearInterpWithMethod2(TestBilinearInterpOp): + def set_align_mode(self): + self.align_corners = True + self.align_mode = 1 + + +class TestBilinearInterpWithMethod3(TestBilinearInterpOp): + def set_align_mode(self): + self.align_corners = True + self.align_mode = 0 if __name__ == "__main__": diff --git a/python/paddle/fluid/tests/unittests/test_nearest_interp_op.py b/python/paddle/fluid/tests/unittests/test_nearest_interp_op.py index 242709425f..22f7bac0be 100644 --- a/python/paddle/fluid/tests/unittests/test_nearest_interp_op.py +++ b/python/paddle/fluid/tests/unittests/test_nearest_interp_op.py @@ -24,7 +24,8 @@ def nearest_neighbor_interp_np(X, out_h, out_w, out_size=None, - actual_shape=None): + actual_shape=None, + align_corners=True): """nearest neighbor interpolation implement in shape [N, C, H, W]""" if out_size is not None: out_h = out_size[0] @@ -35,17 +36,29 @@ def nearest_neighbor_interp_np(X, n, c, in_h, in_w = X.shape ratio_h = ratio_w = 0.0 - if out_h > 1: + if (align_corners and out_h > 1): ratio_h = (in_h - 1.0) / (out_h - 1.0) - if out_w > 1: + else: + ratio_h = 1.0 * in_h / out_h + if (align_corners and out_w > 1): ratio_w = (in_w - 1.0) / (out_w - 1.0) + else: + ratio_w = 1.0 * in_w / out_w out = np.zeros((n, c, out_h, out_w)) - for i in range(out_h): - in_i = int(ratio_h * i + 0.5) - for j in range(out_w): - in_j = int(ratio_w * j + 0.5) - out[:, :, i, j] = X[:, :, in_i, in_j] + + if align_corners: + for i in range(out_h): + in_i = int(ratio_h * i + 0.5) + for j in range(out_w): + in_j = int(ratio_w * j + 0.5) + out[:, :, i, j] = X[:, :, in_i, in_j] + else: + for i in range(out_h): + in_i = int(ratio_h * i) + for j in range(out_w): + in_j = int(ratio_w * j) + out[:, :, i, j] = X[:, :, in_i, in_j] return out.astype(X.dtype) @@ -59,7 +72,8 @@ class TestNearestInterpOp(OpTest): input_np = np.random.random(self.input_shape).astype("float32") output_np = nearest_neighbor_interp_np(input_np, self.out_h, self.out_w, - self.out_size, self.actual_shape) + self.out_size, self.actual_shape, + self.align_corners) self.inputs = {'X': input_np} if self.out_size is not None: self.inputs['OutSize'] = self.out_size @@ -68,7 +82,8 @@ class TestNearestInterpOp(OpTest): self.attrs = { 'out_h': self.out_h, 'out_w': self.out_w, - 'interp_method': self.interp_method + 'interp_method': self.interp_method, + 'align_corners': self.align_corners, } self.outputs = {'Out': output_np} @@ -84,6 +99,7 @@ class TestNearestInterpOp(OpTest): self.out_h = 2 self.out_w = 2 self.out_size = np.array([3, 3]).astype("int32") + self.align_corners = True class TestNearestNeighborInterpCase1(TestNearestInterpOp): @@ -92,6 +108,7 @@ class TestNearestNeighborInterpCase1(TestNearestInterpOp): self.input_shape = [4, 1, 7, 8] self.out_h = 1 self.out_w = 1 + self.align_corners = False class TestNearestNeighborInterpCase2(TestNearestInterpOp): @@ -100,6 +117,7 @@ class TestNearestNeighborInterpCase2(TestNearestInterpOp): self.input_shape = [3, 3, 9, 6] self.out_h = 12 self.out_w = 12 + self.align_corners = True class TestNearestNeighborInterpCase3(TestNearestInterpOp): @@ -108,6 +126,7 @@ class TestNearestNeighborInterpCase3(TestNearestInterpOp): self.input_shape = [1, 1, 128, 64] self.out_h = 64 self.out_w = 128 + self.align_corners = True class TestNearestNeighborInterpCase4(TestNearestInterpOp): @@ -117,6 +136,7 @@ class TestNearestNeighborInterpCase4(TestNearestInterpOp): self.out_h = 1 self.out_w = 1 self.out_size = np.array([2, 2]).astype("int32") + self.align_corners = True class TestNearestNeighborInterpCase5(TestNearestInterpOp): @@ -126,6 +146,7 @@ class TestNearestNeighborInterpCase5(TestNearestInterpOp): self.out_h = 12 self.out_w = 12 self.out_size = np.array([11, 11]).astype("int32") + self.align_corners = True class TestNearestNeighborInterpCase6(TestNearestInterpOp): @@ -135,6 +156,7 @@ class TestNearestNeighborInterpCase6(TestNearestInterpOp): self.out_h = 64 self.out_w = 128 self.out_size = np.array([65, 129]).astype("int32") + self.align_corners = True class TestNearestNeighborInterpActualShape(TestNearestInterpOp): @@ -144,6 +166,7 @@ class TestNearestNeighborInterpActualShape(TestNearestInterpOp): self.out_h = 64 self.out_w = 32 self.out_size = np.array([66, 40]).astype("int32") + self.align_corners = True class TestNearestInterpOpUint8(OpTest): @@ -155,14 +178,16 @@ class TestNearestInterpOpUint8(OpTest): input_np = np.random.randint( low=0, high=256, size=self.input_shape).astype("uint8") output_np = nearest_neighbor_interp_np(input_np, self.out_h, self.out_w, - self.out_size, self.actual_shape) + self.out_size, self.actual_shape, + self.align_corners) self.inputs = {'X': input_np} if self.out_size is not None: self.inputs['OutSize'] = self.out_size self.attrs = { 'out_h': self.out_h, 'out_w': self.out_w, - 'interp_method': self.interp_method + 'interp_method': self.interp_method, + 'align_corners': self.align_corners } self.outputs = {'Out': output_np} @@ -174,6 +199,7 @@ class TestNearestInterpOpUint8(OpTest): self.input_shape = [1, 3, 9, 6] self.out_h = 10 self.out_w = 9 + self.align_corners = True class TestNearestNeighborInterpCase1Uint8(TestNearestInterpOpUint8): @@ -182,6 +208,7 @@ class TestNearestNeighborInterpCase1Uint8(TestNearestInterpOpUint8): self.input_shape = [2, 3, 128, 64] self.out_h = 120 self.out_w = 50 + self.align_corners = False class TestNearestNeighborInterpCase2Uint8(TestNearestInterpOpUint8): @@ -191,6 +218,12 @@ class TestNearestNeighborInterpCase2Uint8(TestNearestInterpOpUint8): self.out_h = 5 self.out_w = 13 self.out_size = np.array([6, 15]).astype("int32") + self.align_corners = True + + +class TestNearestInterpWithoutCorners(TestNearestInterpOp): + def set_align_corners(self): + self.align_corners = False if __name__ == "__main__": From 88744e4ab8002f7770b0f87e8b1cc9ae7469ea57 Mon Sep 17 00:00:00 2001 From: tink2123 Date: Thu, 24 Jan 2019 13:24:34 +0800 Subject: [PATCH 09/53] fixed some errors test=develop --- paddle/fluid/API.spec | 7 +-- paddle/fluid/operators/interpolate_op.cc | 17 +++--- paddle/fluid/operators/interpolate_op.cu | 4 ++ paddle/fluid/operators/interpolate_op.h | 4 ++ python/paddle/fluid/layers/nn.py | 27 ++++------ .../unittests/test_bilinear_interp_op.py | 52 ++++++++++--------- .../tests/unittests/test_nearest_interp_op.py | 2 +- 7 files changed, 58 insertions(+), 55 deletions(-) diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index 6937d13dba..f4e964d8c2 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -140,10 +140,10 @@ paddle.fluid.layers.label_smooth ArgSpec(args=['label', 'prior_dist', 'epsilon', paddle.fluid.layers.roi_pool ArgSpec(args=['input', 'rois', 'pooled_height', 'pooled_width', 'spatial_scale'], varargs=None, keywords=None, defaults=(1, 1, 1.0)) paddle.fluid.layers.roi_align ArgSpec(args=['input', 'rois', 'pooled_height', 'pooled_width', 'spatial_scale', 'sampling_ratio', 'name'], varargs=None, keywords=None, defaults=(1, 1, 1.0, -1, None)) paddle.fluid.layers.dice_loss ArgSpec(args=['input', 'label', 'epsilon'], varargs=None, keywords=None, defaults=(1e-05,)) -paddle.fluid.layers.image_resize ArgSpec(args=['input', 'out_shape', 'scale', 'name', 'resample', 'actual_shape'], varargs=None, keywords=None, defaults=(None, None, None, 'BILINEAR', None)) +paddle.fluid.layers.image_resize ArgSpec(args=['input', 'out_shape', 'scale', 'name', 'resample', 'actual_shape', 'align_corners', 'align_mode'], varargs=None, keywords=None, defaults=(None, None, None, 'BILINEAR', None, True, 1)) paddle.fluid.layers.image_resize_short ArgSpec(args=['input', 'out_short_len', 'resample'], varargs=None, keywords=None, defaults=('BILINEAR',)) -paddle.fluid.layers.resize_bilinear ArgSpec(args=['input', 'out_shape', 'scale', 'name', 'actual_shape'], varargs=None, keywords=None, defaults=(None, None, None, None)) -paddle.fluid.layers.resize_nearest ArgSpec(args=['input', 'out_shape', 'scale', 'name', 'actual_shape'], varargs=None, keywords=None, defaults=(None, None, None, None)) +paddle.fluid.layers.resize_bilinear ArgSpec(args=['input', 'out_shape', 'scale', 'name', 'actual_shape', 'align_corners', 'align_mode'], varargs=None, keywords=None, defaults=(None, None, None, None, True, 1)) +paddle.fluid.layers.resize_nearest ArgSpec(args=['input', 'out_shape', 'scale', 'name', 'actual_shape', 'align_corners'], varargs=None, keywords=None, defaults=(None, None, None, None, True)) paddle.fluid.layers.gather ArgSpec(args=['input', 'index'], varargs=None, keywords=None, defaults=None) paddle.fluid.layers.scatter ArgSpec(args=['input', 'index', 'updates', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.sequence_scatter ArgSpec(args=['input', 'index', 'updates', 'name'], varargs=None, keywords=None, defaults=(None,)) @@ -505,3 +505,4 @@ paddle.reader.Fake.__init__ ArgSpec(args=['self'], varargs=None, keywords=None, paddle.reader.creator.np_array ArgSpec(args=['x'], varargs=None, keywords=None, defaults=None) paddle.reader.creator.text_file ArgSpec(args=['path'], varargs=None, keywords=None, defaults=None) paddle.reader.creator.recordio ArgSpec(args=['paths', 'buf_size'], varargs=None, keywords=None, defaults=(100,)) + diff --git a/paddle/fluid/operators/interpolate_op.cc b/paddle/fluid/operators/interpolate_op.cc index 1b34d404c0..13be33a391 100644 --- a/paddle/fluid/operators/interpolate_op.cc +++ b/paddle/fluid/operators/interpolate_op.cc @@ -90,10 +90,10 @@ class InterpolateOpMaker : public framework::OpProtoAndCheckerMaker { "if Flase, are not aligned") .SetDefault(true); AddAttr("align_mode", - "(int, default \'0\'), align_corners mode , can be \'0\' " - "for pytorch calculation method, can be \'1\' for " - "tensorflow calculation method.") - .SetDefault(0); + "(int, default \'1\'), can be \'0\' for " + "src_idx = scale*(dst_indx+0.5)-0.5 , can be \'1\' for " + "src_idx = scale*dst_index .") + .SetDefault(1); AddComment(R"DOC( This operator samples input X to given output shape by using specified interpolation method, the interpolation methods can be \"nearest\" @@ -115,7 +115,7 @@ class InterpolateOpMaker : public framework::OpProtoAndCheckerMaker { Example: - for scale: + For scale: if align_corners = True and out_{size}>1 : @@ -148,7 +148,7 @@ class InterpolateOpMaker : public framework::OpProtoAndCheckerMaker { Bilinear interpolation: - case 1: + if: align_corners = False , align_mode = 0 input : (N,C,H_in,W_in) @@ -158,10 +158,7 @@ class InterpolateOpMaker : public framework::OpProtoAndCheckerMaker { W_out = (W_{in}+0.5) * scale_{factor} - 0.5 - case 2: - align_corners = False , align_mode = 1 - or - align_corners = True + else: input : (N,C,H_in,W_in) output: (N,C,H_out,W_out) where: diff --git a/paddle/fluid/operators/interpolate_op.cu b/paddle/fluid/operators/interpolate_op.cu index 316811d23e..7595511cf5 100644 --- a/paddle/fluid/operators/interpolate_op.cu +++ b/paddle/fluid/operators/interpolate_op.cu @@ -105,6 +105,7 @@ __global__ void KeBilinearInterpFw( int in_img_idy = (align_mode == 0 && !align_corners) ? static_cast(ratio_h * (out_img_idy + 0.5) - 0.5) : static_cast(ratio_h * out_img_idy); + in_img_idy = (in_img_idy > 0) ? in_img_idy : 0; int h_id = (in_img_idy < in_img_h - 1) ? 1 : 0; T h1lambda = (align_mode == 0 && !align_corners) ? ratio_h * (out_img_idy + 0.5) - 0.5 - in_img_idy @@ -115,6 +116,7 @@ __global__ void KeBilinearInterpFw( int in_img_idx = (align_mode == 0 && !align_corners) ? static_cast(ratio_w * (out_img_idx + 0.5) - 0.5) : static_cast(ratio_w * out_img_idx); + in_img_idx = (in_img_idx > 0) ? in_img_idx : 0; int w_id = (in_img_idx < in_img_w - 1) ? 1 : 0; T w1lambda = (align_mode == 0 && !align_corners) ? ratio_w * (out_img_idx + 0.5) - 0.5 - in_img_idx @@ -153,6 +155,7 @@ __global__ void KeBilinearInterpBw( int in_img_idy = (align_mode == 0 && !align_corners) ? ratio_h * (out_img_idy + 0.5) - 0.5 : ratio_h * out_img_idy; + in_img_idy = (in_img_idy > 0) ? in_img_idy : 0; int h_id = (in_img_idy < in_img_h - 1) ? 1 : 0; T h1lambda = (align_mode == 0 && !align_corners) ? ratio_h * (out_img_idy + 0.5) - 0.5 - in_img_idy @@ -164,6 +167,7 @@ __global__ void KeBilinearInterpBw( int in_img_idx = (align_mode == 0 && !align_corners) ? ratio_w * (out_img_idx + 0.5) - 0.5 : ratio_w * out_img_idx; + in_img_idx = (in_img_idx > 0) ? in_img_idx : 0; int w_id = (in_img_idx < in_img_w - 1) ? 1 : 0; T w1lambda = (align_mode == 0 && !align_corners) ? ratio_w * (out_img_idx + 0.5) - 0.5 - in_img_idx diff --git a/paddle/fluid/operators/interpolate_op.h b/paddle/fluid/operators/interpolate_op.h index 95aec33eee..ab41ff781a 100644 --- a/paddle/fluid/operators/interpolate_op.h +++ b/paddle/fluid/operators/interpolate_op.h @@ -60,6 +60,7 @@ static void BilinearInterpolation(const Tensor& input, Tensor* output, int y_n = (align_mode == 0 && !align_corners) ? static_cast(ratio_h * (k + 0.5) - 0.5) : static_cast(ratio_h * k); + y_n = (y_n > 0) ? y_n : 0; int y_s = (y_n + 1) < (in_h - 1) ? (y_n + 1) : (in_h - 1); float d_n = (align_mode == 0 && !align_corners) ? ratio_h * (k + 0.5) - 0.5 - y_n @@ -70,6 +71,7 @@ static void BilinearInterpolation(const Tensor& input, Tensor* output, int x_w = (align_mode == 0 && !align_corners) ? static_cast(ratio_w * (l + 0.5) - 0.5) : static_cast(ratio_w * l); + x_w = (x_w > 0) ? x_w : 0; int x_e = (x_w + 1) < (in_w - 1) ? (x_w + 1) : (in_w - 1); float d_w = (align_mode == 0 && !align_corners) ? ratio_w * (l + 0.5) - 0.5 - x_w @@ -128,6 +130,7 @@ static void BilinearInterpolationGrad(const Tensor& output_grad, int y_n = (align_mode == 0 && !align_corners) ? static_cast(ratio_h * (k + 0.5) - 0.5) : static_cast(ratio_h * k); + y_n = (y_n > 0) ? y_n : 0; int y_s = (y_n + 1) < (in_h - 1) ? (y_n + 1) : (in_h - 1); float d_n = (align_mode == 0 && !align_corners) ? ratio_h * (k + 0.5) - 0.5 - y_n @@ -138,6 +141,7 @@ static void BilinearInterpolationGrad(const Tensor& output_grad, int x_w = (align_mode == 0 && !align_corners) ? static_cast(ratio_w * (l + 0.5) - 0.5) : static_cast(ratio_w * l); + x_w = (x_w > 0) ? x_w : 0; int x_e = (x_w + 1) < (in_w - 1) ? (x_w + 1) : (in_w - 1); float d_w = (align_mode == 0 && !align_corners) ? ratio_w * (l + 0.5) - 0.5 - x_w diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 93e77dc113..765fa8565b 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -6557,7 +6557,7 @@ def image_resize(input, Example: - for scale: + For scale: if align_corners = True && out_size > 1 : @@ -6590,7 +6590,7 @@ def image_resize(input, Bilinear interpolation: - case 1: + if: align_corners = False , align_mode = 0 input : (N,C,H_in,W_in) @@ -6600,10 +6600,7 @@ def image_resize(input, W_out = (W_{in}+0.5) * scale_{factor} - 0.5 - case 2: - align_corners = False , align_mode = 1 - or - align_corners = True + else: input : (N,C,H_in,W_in) output: (N,C,H_out,W_out) where: @@ -6652,8 +6649,9 @@ def image_resize(input, input and output tensors are aligned, preserving the values at the corner pixels. Default: True - align_mode(int) : An optional input to specify align_corners mode. can be \'0\' - for pytorch calculation method, can be \'1'\ for tensorflow calculation method. + align_mode(int) : An optional input to specify src_idx calculation. can be \'0\' + for src_idx = scale*(dst_indx+0.5)-0.5 , can be \'1\' for + src_idx = scale*dst_index . Returns: Variable: The output is a 4-D tensor of the shape @@ -6769,7 +6767,7 @@ def resize_bilinear(input, Example: - for scale: + For scale: if align_corners = True && out_size > 1 : @@ -6781,7 +6779,7 @@ def resize_bilinear(input, Bilinear interpolation: - case 1: + if: align_corners = False , align_mode = 0 input : (N,C,H_in,W_in) @@ -6791,11 +6789,8 @@ def resize_bilinear(input, W_out = (W_{in}+0.5) * scale_{factor} - 0.5 - case 2: - align_corners = False , align_mode = 1 - or - align_corners = True - + else: + input : (N,C,H_in,W_in) output: (N,C,H_out,W_out) where: @@ -6858,7 +6853,7 @@ def resize_nearest(input, Example: - for scale: + For scale: if align_corners = True && out_size > 1 : diff --git a/python/paddle/fluid/tests/unittests/test_bilinear_interp_op.py b/python/paddle/fluid/tests/unittests/test_bilinear_interp_op.py index 4523fb54ce..2e3de58a3a 100644 --- a/python/paddle/fluid/tests/unittests/test_bilinear_interp_op.py +++ b/python/paddle/fluid/tests/unittests/test_bilinear_interp_op.py @@ -54,6 +54,7 @@ def bilinear_interp_np(input, else: h = int(ratio_h * i) + h = max(0, h) hid = 1 if h < in_h - 1 else 0 if (align_mode == 0 and not align_corners): h1lambda = ratio_h * (i + 0.5) - 0.5 - h @@ -65,6 +66,7 @@ def bilinear_interp_np(input, w = int(ratio_w * (j + 0.5) - 0.5) else: w = int(ratio_w * j) + w = max(0, w) wid = 1 if w < in_w - 1 else 0 if (align_mode == 0 and not align_corners): w1lambda = ratio_w * (j + 0.5) - 0.5 - w @@ -116,8 +118,8 @@ class TestBilinearInterpOp(OpTest): self.out_h = 2 self.out_w = 2 self.out_size = np.array([3, 3]).astype("int32") - self.align_corners = False - self.align_mode = 0 + self.align_corners = True + self.align_mode = 1 class TestBilinearInterpCase1(TestBilinearInterpOp): @@ -126,8 +128,8 @@ class TestBilinearInterpCase1(TestBilinearInterpOp): self.input_shape = [4, 1, 7, 8] self.out_h = 1 self.out_w = 1 - self.align_corners = False - self.align_mode = 0 + self.align_corners = True + self.align_mode = 1 class TestBilinearInterpCase2(TestBilinearInterpOp): @@ -136,8 +138,8 @@ class TestBilinearInterpCase2(TestBilinearInterpOp): self.input_shape = [3, 3, 9, 6] self.out_h = 12 self.out_w = 12 - self.align_corners = False - self.align_mode = 0 + self.align_corners = True + self.align_mode = 1 class TestBilinearInterpCase3(TestBilinearInterpOp): @@ -146,8 +148,8 @@ class TestBilinearInterpCase3(TestBilinearInterpOp): self.input_shape = [1, 1, 128, 64] self.out_h = 64 self.out_w = 128 - self.align_corners = False - self.align_mode = 0 + self.align_corners = True + self.align_mode = 1 class TestBilinearInterpCase4(TestBilinearInterpOp): @@ -157,8 +159,8 @@ class TestBilinearInterpCase4(TestBilinearInterpOp): self.out_h = 1 self.out_w = 1 self.out_size = np.array([2, 2]).astype("int32") - self.align_corners = False - self.align_mode = 0 + self.align_corners = True + self.align_mode = 1 class TestBilinearInterpCase5(TestBilinearInterpOp): @@ -168,8 +170,8 @@ class TestBilinearInterpCase5(TestBilinearInterpOp): self.out_h = 12 self.out_w = 12 self.out_size = np.array([11, 11]).astype("int32") - self.align_corners = False - self.align_mode = 0 + self.align_corners = True + self.align_mode = 1 class TestBilinearInterpCase6(TestBilinearInterpOp): @@ -179,8 +181,8 @@ class TestBilinearInterpCase6(TestBilinearInterpOp): self.out_h = 64 self.out_w = 128 self.out_size = np.array([65, 129]).astype("int32") - self.align_corners = False - self.align_mode = 0 + self.align_corners = True + self.align_mode = 1 class TestBilinearInterpActualShape(TestBilinearInterpOp): @@ -190,8 +192,8 @@ class TestBilinearInterpActualShape(TestBilinearInterpOp): self.out_h = 64 self.out_w = 32 self.out_size = np.array([66, 40]).astype("int32") - self.align_corners = False - self.align_mode = 0 + self.align_corners = True + self.align_mode = 1 class TestBilinearInterpOpUint8(OpTest): @@ -225,8 +227,8 @@ class TestBilinearInterpOpUint8(OpTest): self.input_shape = [1, 3, 9, 6] self.out_h = 10 self.out_w = 9 - self.align_corners = False - self.align_mode = 0 + self.align_corners = True + self.align_mode = 1 class TestBilinearInterpCase1Uint8(TestBilinearInterpOpUint8): @@ -235,8 +237,8 @@ class TestBilinearInterpCase1Uint8(TestBilinearInterpOpUint8): self.input_shape = [2, 3, 128, 64] self.out_h = 120 self.out_w = 50 - self.align_corners = False - self.align_mode = 0 + self.align_corners = True + self.align_mode = 1 class TestBilinearInterpCase2Uint8(TestBilinearInterpOpUint8): @@ -246,20 +248,20 @@ class TestBilinearInterpCase2Uint8(TestBilinearInterpOpUint8): self.out_h = 5 self.out_w = 13 self.out_size = np.array([6, 15]).astype("int32") - self.align_corners = False - self.align_mode = 0 + self.align_corners = True + self.align_mode = 1 class TestBilinearInterpOtherMethod1(TestBilinearInterpOp): def set_align_mode(self): - self.align_mode = 1 self.align_corners = False + self.align_mode = 1 class TestBilinearInterpWithMethod2(TestBilinearInterpOp): def set_align_mode(self): - self.align_corners = True - self.align_mode = 1 + self.align_corners = False + self.align_mode = 0 class TestBilinearInterpWithMethod3(TestBilinearInterpOp): diff --git a/python/paddle/fluid/tests/unittests/test_nearest_interp_op.py b/python/paddle/fluid/tests/unittests/test_nearest_interp_op.py index 22f7bac0be..c97aa886a9 100644 --- a/python/paddle/fluid/tests/unittests/test_nearest_interp_op.py +++ b/python/paddle/fluid/tests/unittests/test_nearest_interp_op.py @@ -108,7 +108,7 @@ class TestNearestNeighborInterpCase1(TestNearestInterpOp): self.input_shape = [4, 1, 7, 8] self.out_h = 1 self.out_w = 1 - self.align_corners = False + self.align_corners = True class TestNearestNeighborInterpCase2(TestNearestInterpOp): From e448bdb298aa8f32c398f9dfc2bd215e4fce6d56 Mon Sep 17 00:00:00 2001 From: tink2123 Date: Thu, 24 Jan 2019 13:35:54 +0800 Subject: [PATCH 10/53] modified some comments test=develop --- paddle/fluid/operators/interpolate_op.cc | 4 ++-- python/paddle/fluid/layers/nn.py | 8 ++++---- .../fluid/tests/unittests/test_nearest_interp_op.py | 2 +- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/paddle/fluid/operators/interpolate_op.cc b/paddle/fluid/operators/interpolate_op.cc index 13be33a391..83b2086bbb 100644 --- a/paddle/fluid/operators/interpolate_op.cc +++ b/paddle/fluid/operators/interpolate_op.cc @@ -128,7 +128,7 @@ class InterpolateOpMaker : public framework::OpProtoAndCheckerMaker { Nearest neighbor interpolation: - case 1: + if: align_corners = False input : (N,C,H_in,W_in) @@ -137,7 +137,7 @@ class InterpolateOpMaker : public framework::OpProtoAndCheckerMaker { H_out = \left \lfloor {H_{in} * scale_{}factor}} \right \rfloor W_out = \left \lfloor {W_{in} * scale_{}factor}} \right \rfloor - case 2: + else: align_corners = True input : (N,C,H_in,W_in) diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 765fa8565b..4d40f2e7c2 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -6570,7 +6570,7 @@ def image_resize(input, Nearest neighbor interpolation: - case 1: + if: align_corners = False input : (N,C,H_in,W_in) @@ -6579,7 +6579,7 @@ def image_resize(input, H_out = \left \lfloor {H_{in} * scale_{}factor}} \right \rfloor W_out = \left \lfloor {W_{in} * scale_{}factor}} \right \rfloor - case 2: + else: align_corners = True input : (N,C,H_in,W_in) @@ -6866,7 +6866,7 @@ def resize_nearest(input, Nearest neighbor interpolation: - case 1: + if: align_corners = False input : (N,C,H_in,W_in) @@ -6875,7 +6875,7 @@ def resize_nearest(input, H_out = \left \lfloor {H_{in} * scale_{}factor}} \right \rfloor W_out = \left \lfloor {W_{in} * scale_{}factor}} \right \rfloor - case 2: + else: align_corners = True input : (N,C,H_in,W_in) diff --git a/python/paddle/fluid/tests/unittests/test_nearest_interp_op.py b/python/paddle/fluid/tests/unittests/test_nearest_interp_op.py index c97aa886a9..9984a793ca 100644 --- a/python/paddle/fluid/tests/unittests/test_nearest_interp_op.py +++ b/python/paddle/fluid/tests/unittests/test_nearest_interp_op.py @@ -208,7 +208,7 @@ class TestNearestNeighborInterpCase1Uint8(TestNearestInterpOpUint8): self.input_shape = [2, 3, 128, 64] self.out_h = 120 self.out_w = 50 - self.align_corners = False + self.align_corners = True class TestNearestNeighborInterpCase2Uint8(TestNearestInterpOpUint8): From 78145c7dff12b0bfb181a0217b42ca2c261bb268 Mon Sep 17 00:00:00 2001 From: tink2123 Date: Thu, 24 Jan 2019 17:48:56 +0800 Subject: [PATCH 11/53] modified some comments test=develop --- paddle/fluid/operators/interpolate_op.cc | 6 +++--- python/paddle/fluid/layers/nn.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/operators/interpolate_op.cc b/paddle/fluid/operators/interpolate_op.cc index 83b2086bbb..357832223c 100644 --- a/paddle/fluid/operators/interpolate_op.cc +++ b/paddle/fluid/operators/interpolate_op.cc @@ -90,9 +90,9 @@ class InterpolateOpMaker : public framework::OpProtoAndCheckerMaker { "if Flase, are not aligned") .SetDefault(true); AddAttr("align_mode", - "(int, default \'1\'), can be \'0\' for " - "src_idx = scale*(dst_indx+0.5)-0.5 , can be \'1\' for " - "src_idx = scale*dst_index .") + "(int, default \'1\'), optional for bilinear interpolation" + "can be \'0\' for src_idx = scale*(dst_indx+0.5)-0.5 , " + "can be \'1\' for src_idx = scale*dst_index .") .SetDefault(1); AddComment(R"DOC( This operator samples input X to given output shape by using specified diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 4d40f2e7c2..77545d6002 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -6649,7 +6649,7 @@ def image_resize(input, input and output tensors are aligned, preserving the values at the corner pixels. Default: True - align_mode(int) : An optional input to specify src_idx calculation. can be \'0\' + align_mode(int) : An optional for bilinear interpolation. can be \'0\' for src_idx = scale*(dst_indx+0.5)-0.5 , can be \'1\' for src_idx = scale*dst_index . From a39240c3b6af17b05e5a55bf8bbb199775498696 Mon Sep 17 00:00:00 2001 From: jerrywgz Date: Fri, 25 Jan 2019 07:46:48 +0000 Subject: [PATCH 12/53] add attr variance for box coder, test=develop --- .../fluid/operators/detection/box_coder_op.cc | 7 + .../fluid/operators/detection/box_coder_op.cu | 59 +++++--- .../fluid/operators/detection/box_coder_op.h | 38 +++++- python/paddle/fluid/layers/detection.py | 126 +++++++++++++++--- python/paddle/fluid/tests/test_detection.py | 2 +- .../tests/unittests/test_box_coder_op.py | 57 ++++++-- 6 files changed, 236 insertions(+), 53 deletions(-) diff --git a/paddle/fluid/operators/detection/box_coder_op.cc b/paddle/fluid/operators/detection/box_coder_op.cc index f89f87663b..fdcff62e1f 100644 --- a/paddle/fluid/operators/detection/box_coder_op.cc +++ b/paddle/fluid/operators/detection/box_coder_op.cc @@ -10,6 +10,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/detection/box_coder_op.h" +#include namespace paddle { namespace operators { @@ -134,6 +135,12 @@ class BoxCoderOpMaker : public framework::OpProtoAndCheckerMaker { "when code type is decode_center_size") .SetDefault(0) .InEnum({0, 1}); + AddAttr>( + "variance", + "(vector, default {})," + "variance of prior box with shape [4]. PriorBoxVar and variance can" + "not be provided at the same time.") + .SetDefault(std::vector{}); AddOutput("OutputBox", "(LoDTensor or Tensor) " "When code_type is 'encode_center_size', the output tensor of " diff --git a/paddle/fluid/operators/detection/box_coder_op.cu b/paddle/fluid/operators/detection/box_coder_op.cu index 0b64224e1e..9b73572274 100644 --- a/paddle/fluid/operators/detection/box_coder_op.cu +++ b/paddle/fluid/operators/detection/box_coder_op.cu @@ -9,6 +9,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include +#include #include "paddle/fluid/operators/detection/box_coder_op.h" #include "paddle/fluid/platform/cuda_primitives.h" @@ -16,12 +18,11 @@ namespace paddle { namespace operators { template -__global__ void EncodeCenterSizeKernel(const T* prior_box_data, - const T* prior_box_var_data, - const T* target_box_data, const int row, - const int col, const int len, - const bool normalized, - const T prior_box_var_size, T* output) { +__global__ void EncodeCenterSizeKernel( + const T* prior_box_data, const T* prior_box_var_data, + const T* target_box_data, const int row, const int col, const int len, + const bool normalized, const T prior_box_var_size, const float* variance, + const int var_size, T* output) { const int idx = threadIdx.x + blockIdx.x * blockDim.x; if (idx < row * col) { const int row_idx = idx / col; @@ -62,18 +63,20 @@ __global__ void EncodeCenterSizeKernel(const T* prior_box_data, output[idx * len + 1] /= prior_box_var_data[prior_var_offset + 1]; output[idx * len + 2] /= prior_box_var_data[prior_var_offset + 2]; output[idx * len + 3] /= prior_box_var_data[prior_var_offset + 3]; + } else if (var_size == 4) { + for (int k = 0; k < 4; ++k) { + output[idx * len + k] /= static_cast(variance[k]); + } } } } template -__global__ void DecodeCenterSizeKernel(const T* prior_box_data, - const T* prior_box_var_data, - const T* target_box_data, const int row, - const int col, const int len, - const bool normalized, - const T prior_box_var_size, - const int axis, T* output) { +__global__ void DecodeCenterSizeKernel( + const T* prior_box_data, const T* prior_box_var_data, + const T* target_box_data, const int row, const int col, const int len, + const bool normalized, const T prior_box_var_size, const float* variance, + const int var_size, const int axis, T* output) { const int idx = threadIdx.x + blockIdx.x * blockDim.x; int prior_box_offset = 0; if (idx < row * col) { @@ -110,6 +113,20 @@ __global__ void DecodeCenterSizeKernel(const T* prior_box_data, target_box_data[idx * len + 1] * prior_box_height + prior_box_center_y; + } else if (var_size == 4) { + target_box_width = + exp(static_cast(variance[2]) * target_box_data[idx * len + 2]) * + prior_box_width; + target_box_height = + exp(static_cast(variance[3]) * target_box_data[idx * len + 3]) * + prior_box_height; + target_box_center_x = static_cast(variance[0]) * + target_box_data[idx * len] * prior_box_width + + prior_box_center_x; + target_box_center_y = static_cast(variance[1]) * + target_box_data[idx * len + 1] * + prior_box_height + + prior_box_center_y; } else { target_box_width = exp(target_box_data[idx * len + 2]) * prior_box_width; target_box_height = @@ -139,20 +156,30 @@ class BoxCoderCUDAKernel : public framework::OpKernel { auto* prior_box_var = context.Input("PriorBoxVar"); auto* target_box = context.Input("TargetBox"); auto* output_box = context.Output("OutputBox"); - + std::vector variance = context.Attr>("variance"); const T* prior_box_data = prior_box->data(); const T* target_box_data = target_box->data(); const T* prior_box_var_data = nullptr; auto prior_box_var_size = 0; if (prior_box_var) { + PADDLE_ENFORCE(variance.empty(), + "Input 'PriorBoxVar' and attribute 'variance' should not" + "be used at the same time."); prior_box_var_data = prior_box_var->data(); prior_box_var_size = prior_box_var->dims().size(); } + if (!(variance.empty())) { + PADDLE_ENFORCE(static_cast(variance.size()) == 4, + "Size of attribute 'variance' should be 4"); + } if (target_box->lod().size()) { PADDLE_ENFORCE_EQ(target_box->lod().size(), 1, "Only support 1 level of LoD."); } + const int var_size = static_cast(variance.size()); + thrust::device_vector dev_variance(variance.begin(), variance.end()); + const float* dev_var_data = thrust::raw_pointer_cast(dev_variance.data()); auto code_type = GetBoxCodeType(context.Attr("code_type")); bool normalized = context.Attr("box_normalized"); int axis = context.Attr("axis"); @@ -173,11 +200,11 @@ class BoxCoderCUDAKernel : public framework::OpKernel { if (code_type == BoxCodeType::kEncodeCenterSize) { EncodeCenterSizeKernel<<>>( prior_box_data, prior_box_var_data, target_box_data, row, col, len, - normalized, prior_box_var_size, output); + normalized, prior_box_var_size, dev_var_data, var_size, output); } else if (code_type == BoxCodeType::kDecodeCenterSize) { DecodeCenterSizeKernel<<>>( prior_box_data, prior_box_var_data, target_box_data, row, col, len, - normalized, prior_box_var_size, axis, output); + normalized, prior_box_var_size, dev_var_data, var_size, axis, output); } } }; diff --git a/paddle/fluid/operators/detection/box_coder_op.h b/paddle/fluid/operators/detection/box_coder_op.h index 986869d8a3..b61cff1b1d 100644 --- a/paddle/fluid/operators/detection/box_coder_op.h +++ b/paddle/fluid/operators/detection/box_coder_op.h @@ -11,6 +11,7 @@ limitations under the License. */ #pragma once #include +#include #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/math/math_function.h" @@ -34,7 +35,8 @@ class BoxCoderKernel : public framework::OpKernel { void EncodeCenterSize(const framework::Tensor* target_box, const framework::Tensor* prior_box, const framework::Tensor* prior_box_var, - const bool normalized, T* output) const { + const bool normalized, + const std::vector variance, T* output) const { int64_t row = target_box->dims()[0]; int64_t col = prior_box->dims()[0]; int64_t len = prior_box->dims()[1]; @@ -85,6 +87,10 @@ class BoxCoderKernel : public framework::OpKernel { output[offset + 1] /= prior_box_var_data[prior_var_offset + 1]; output[offset + 2] /= prior_box_var_data[prior_var_offset + 2]; output[offset + 3] /= prior_box_var_data[prior_var_offset + 3]; + } else if (!(variance.empty())) { + for (int k = 0; k < 4; ++k) { + output[offset + k] /= static_cast(variance[k]); + } } } } @@ -93,7 +99,7 @@ class BoxCoderKernel : public framework::OpKernel { const framework::Tensor* prior_box, const framework::Tensor* prior_box_var, const bool normalized, const int axis, - T* output) const { + const std::vector variance, T* output) const { int64_t row = target_box->dims()[0]; int64_t col = target_box->dims()[1]; int64_t len = target_box->dims()[2]; @@ -149,6 +155,20 @@ class BoxCoderKernel : public framework::OpKernel { std::exp(prior_box_var_data[prior_var_offset + 3] * target_box_data[offset + 3]) * prior_box_height; + } else if (!(variance.empty())) { + target_box_center_x = static_cast(variance[0]) * + target_box_data[offset] * prior_box_width + + prior_box_center_x; + target_box_center_y = static_cast(variance[1]) * + target_box_data[offset + 1] * + prior_box_height + + prior_box_center_y; + target_box_width = std::exp(static_cast(variance[2]) * + target_box_data[offset + 2]) * + prior_box_width; + target_box_height = std::exp(static_cast(variance[3]) * + target_box_data[offset + 3]) * + prior_box_height; } else { target_box_center_x = target_box_data[offset] * prior_box_width + prior_box_center_x; @@ -175,11 +195,21 @@ class BoxCoderKernel : public framework::OpKernel { auto* prior_box_var = context.Input("PriorBoxVar"); auto* target_box = context.Input("TargetBox"); auto* output_box = context.Output("OutputBox"); + std::vector variance = context.Attr>("variance"); const int axis = context.Attr("axis"); if (target_box->lod().size()) { PADDLE_ENFORCE_EQ(target_box->lod().size(), 1UL, "Only support 1 level of LoD."); } + if (prior_box_var) { + PADDLE_ENFORCE(variance.empty(), + "Input 'PriorBoxVar' and attribute 'variance' should not" + "be used at the same time."); + } + if (!(variance.empty())) { + PADDLE_ENFORCE(static_cast(variance.size()) == 4, + "Size of attribute 'variance' should be 4"); + } auto code_type = GetBoxCodeType(context.Attr("code_type")); bool normalized = context.Attr("box_normalized"); @@ -195,10 +225,10 @@ class BoxCoderKernel : public framework::OpKernel { T* output = output_box->data(); if (code_type == BoxCodeType::kEncodeCenterSize) { EncodeCenterSize(target_box, prior_box, prior_box_var, normalized, - output); + variance, output); } else if (code_type == BoxCodeType::kDecodeCenterSize) { DecodeCenterSize(target_box, prior_box, prior_box_var, normalized, axis, - output); + variance, output); } } }; diff --git a/python/paddle/fluid/layers/detection.py b/python/paddle/fluid/layers/detection.py index 1eb876cfaf..854b34d2a4 100644 --- a/python/paddle/fluid/layers/detection.py +++ b/python/paddle/fluid/layers/detection.py @@ -346,18 +346,104 @@ def box_coder(prior_box, name=None, axis=0): """ - ${comment} + **Box Coder Layer** + + Encode/Decode the target bounding box with the priorbox information. + + The Encoding schema described below: + + .. math:: + + ox = (tx - px) / pw / pxv + + oy = (ty - py) / ph / pyv + + ow = \log(\abs(tw / pw)) / pwv + + oh = \log(\abs(th / ph)) / phv + + The Decoding schema described below: + + .. math:: + + ox = (pw * pxv * tx * + px) - tw / 2 + + oy = (ph * pyv * ty * + py) - th / 2 + + ow = \exp(pwv * tw) * pw + tw / 2 + + oh = \exp(phv * th) * ph + th / 2 + + where `tx`, `ty`, `tw`, `th` denote the target box's center coordinates, + width and height respectively. Similarly, `px`, `py`, `pw`, `ph` denote + the priorbox's (anchor) center coordinates, width and height. `pxv`, + `pyv`, `pwv`, `phv` denote the variance of the priorbox and `ox`, `oy`, + `ow`, `oh` denote the encoded/decoded coordinates, width and height. + + During Box Decoding, two modes for broadcast are supported. Say target + box has shape [N, M, 4], and the shape of prior box can be [N, 4] or + [M, 4]. Then prior box will broadcast to target box along the + assigned axis. Args: - prior_box(${prior_box_type}): ${prior_box_comment} - prior_box_var(${prior_box_var_type}): ${prior_box_var_comment} - target_box(${target_box_type}): ${target_box_comment} - code_type(${code_type_type}): ${code_type_comment} - box_normalized(${box_normalized_type}): ${box_normalized_comment} - axis(${axis_type}): ${axis_comment} + prior_box(Variable): Box list prior_box is a 2-D Tensor with shape + [M, 4] holds M boxes, each box is represented as + [xmin, ymin, xmax, ymax], [xmin, ymin] is the + left top coordinate of the anchor box, if the + input is image feature map, they are close to + the origin of the coordinate system. [xmax, ymax] + is the right bottom coordinate of the anchor box. + prior_box_var(Variable|list): prior_box_var supports two types of input. + One is variable with shape [M, 4] holds M group. + The other one is list consist of 4 elements + shared by all boxes. + target_box(Variable): This input can be a 2-D LoDTensor with shape + [N, 4] when code_type is 'encode_center_size'. + This input also can be a 3-D Tensor with shape + [N, M, 4] when code_type is 'decode_center_size'. + Each box is represented as + [xmin, ymin, xmax, ymax]. This tensor can + contain LoD information to represent a batch + of inputs. + code_type(string): The code type used with the target box. It can be + encode_center_size or decode_center_size + box_normalized(int): Whether treat the priorbox as a noramlized box. + Set true by default. + name(string): The name of box coder. + axis(int): Which axis in PriorBox to broadcast for box decode, + for example, if axis is 0 and TargetBox has shape + [N, M, 4] and PriorBox has shape [M, 4], then PriorBox + will broadcast to [N, M, 4] for decoding. It is only valid + when code type is decode_center_size. Set 0 by default. Returns: - output_box(${output_box_type}): ${output_box_comment} + output_box(Variable): When code_type is 'encode_center_size', the + output tensor of box_coder_op with shape + [N, M, 4] representing the result of N target + boxes encoded with M Prior boxes and variances. + When code_type is 'decode_center_size', + N represents the batch size and M represents + the number of deocded boxes. + + Examples: + + .. code-block:: python + + prior_box = fluid.layers.data(name='prior_box', + shape=[512, 4], + dtype='float32', + append_batch_size=False) + target_box = fluid.layers.data(name='target_box', + shape=[512,81,4], + dtype='float32', + append_batch_size=False) + output = fluid.layers.box_coder(prior_box=prior_box, + prior_box_var=[0.1,0.1,0.2,0.2], + target_box=target_box, + code_type="decode_center_size", + box_normalized=False, + axis=1) + """ helper = LayerHelper("box_coder", **locals()) @@ -368,18 +454,22 @@ def box_coder(prior_box, output_box = helper.create_variable( name=name, dtype=prior_box.dtype, persistable=False) + inputs = {"PriorBox": prior_box, "TargetBox": target_box} + attrs = { + "code_type": code_type, + "box_normalized": box_normalized, + "axis": axis + } + if isinstance(prior_box_var, Variable): + inputs['PriorBoxVar'] = prior_box_var + elif isinstance(prior_box_var, list): + attrs['variance'] = prior_box_var + else: + raise TypeError("Input variance of box_coder must be Variable or lisz") helper.append_op( type="box_coder", - inputs={ - "PriorBox": prior_box, - "PriorBoxVar": prior_box_var, - "TargetBox": target_box - }, - attrs={ - "code_type": code_type, - "box_normalized": box_normalized, - "axis": axis - }, + inputs=inputs, + attrs=attrs, outputs={"OutputBox": output_box}) return output_box diff --git a/python/paddle/fluid/tests/test_detection.py b/python/paddle/fluid/tests/test_detection.py index 2d9ed9f9c6..2dbcfa31fc 100644 --- a/python/paddle/fluid/tests/test_detection.py +++ b/python/paddle/fluid/tests/test_detection.py @@ -59,7 +59,7 @@ class TestDetection(unittest.TestCase): iou = layers.iou_similarity(x=x, y=y) bcoder = layers.box_coder( prior_box=x, - prior_box_var=y, + prior_box_var=[0.2, 0.3, 0.3, 0.2], target_box=z, code_type='encode_center_size') self.assertIsNotNone(iou) diff --git a/python/paddle/fluid/tests/unittests/test_box_coder_op.py b/python/paddle/fluid/tests/unittests/test_box_coder_op.py index 6f7930c921..6156268bf2 100644 --- a/python/paddle/fluid/tests/unittests/test_box_coder_op.py +++ b/python/paddle/fluid/tests/unittests/test_box_coder_op.py @@ -106,9 +106,9 @@ class TestBoxCoderOp(OpTest): def setUp(self): self.op_type = "box_coder" lod = [[1, 1, 1, 1, 1]] - prior_box = np.random.random((10, 4)).astype('float32') - prior_box_var = np.random.random((10, 4)).astype('float32') - target_box = np.random.random((5, 10, 4)).astype('float32') + prior_box = np.random.random((81, 4)).astype('float32') + prior_box_var = np.random.random((81, 4)).astype('float32') + target_box = np.random.random((20, 81, 4)).astype('float32') code_type = "DecodeCenterSize" box_normalized = False output_box = batch_box_coder(prior_box, prior_box_var, target_box, @@ -132,9 +132,9 @@ class TestBoxCoderOpWithOneRankVar(OpTest): def setUp(self): self.op_type = "box_coder" lod = [[1, 1, 1, 1, 1]] - prior_box = np.random.random((6, 4)).astype('float32') + prior_box = np.random.random((81, 4)).astype('float32') prior_box_var = np.random.random((4)).astype('float32') - target_box = np.random.random((3, 6, 4)).astype('float32') + target_box = np.random.random((20, 81, 4)).astype('float32') code_type = "DecodeCenterSize" box_normalized = False output_box = batch_box_coder(prior_box, prior_box_var, target_box, @@ -159,9 +159,9 @@ class TestBoxCoderOpWithoutBoxVar(OpTest): def setUp(self): self.op_type = "box_coder" lod = [[0, 1, 2, 3, 4, 5]] - prior_box = np.random.random((10, 4)).astype('float32') - prior_box_var = np.ones((10, 4)).astype('float32') - target_box = np.random.random((5, 10, 4)).astype('float32') + prior_box = np.random.random((81, 4)).astype('float32') + prior_box_var = np.ones((81, 4)).astype('float32') + target_box = np.random.random((20, 81, 4)).astype('float32') code_type = "DecodeCenterSize" box_normalized = False output_box = batch_box_coder(prior_box, prior_box_var, target_box, @@ -184,10 +184,10 @@ class TestBoxCoderOpWithLoD(OpTest): def setUp(self): self.op_type = "box_coder" - lod = [[4, 8, 8]] - prior_box = np.random.random((10, 4)).astype('float32') - prior_box_var = np.random.random((10, 4)).astype('float32') - target_box = np.random.random((20, 4)).astype('float32') + lod = [[10, 20, 20]] + prior_box = np.random.random((20, 4)).astype('float32') + prior_box_var = np.random.random((20, 4)).astype('float32') + target_box = np.random.random((50, 4)).astype('float32') code_type = "EncodeCenterSize" box_normalized = True output_box = batch_box_coder(prior_box, prior_box_var, target_box, @@ -209,9 +209,9 @@ class TestBoxCoderOpWithAxis(OpTest): def setUp(self): self.op_type = "box_coder" lod = [[1, 1, 1, 1, 1]] - prior_box = np.random.random((5, 4)).astype('float32') + prior_box = np.random.random((30, 4)).astype('float32') prior_box_var = np.random.random((4)).astype('float32') - target_box = np.random.random((5, 6, 4)).astype('float32') + target_box = np.random.random((30, 81, 4)).astype('float32') code_type = "DecodeCenterSize" box_normalized = False axis = 1 @@ -231,5 +231,34 @@ class TestBoxCoderOpWithAxis(OpTest): self.outputs = {'OutputBox': output_box} +class TestBoxCoderOpWithVariance(OpTest): + def test_check_output(self): + self.check_output() + + def setUp(self): + self.op_type = "box_coder" + lod = [[1, 1, 1, 1, 1]] + prior_box = np.random.random((30, 4)).astype('float32') + prior_box_var = np.random.random((4)).astype('float32') + target_box = np.random.random((30, 81, 4)).astype('float32') + code_type = "DecodeCenterSize" + box_normalized = False + axis = 1 + output_box = batch_box_coder(prior_box, prior_box_var, target_box, + lod[0], code_type, box_normalized, axis) + + self.inputs = { + 'PriorBox': prior_box, + 'TargetBox': target_box, + } + self.attrs = { + 'code_type': 'decode_center_size', + 'box_normalized': False, + 'variance': prior_box_var.astype(np.float).flatten(), + 'axis': axis + } + self.outputs = {'OutputBox': output_box} + + if __name__ == '__main__': unittest.main() From b64cdaf6dc138c45d8aa0996c7b83091257f3611 Mon Sep 17 00:00:00 2001 From: tink2123 Date: Fri, 25 Jan 2019 00:45:56 -0800 Subject: [PATCH 13/53] modified default parameters test=develop --- python/paddle/fluid/layers/nn.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 77545d6002..a5a3aa2f3a 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -6529,7 +6529,7 @@ def image_resize(input, resample='BILINEAR', actual_shape=None, align_corners=True, - align_mode=0): + align_mode=1): """ **Resize a Batch of Images** @@ -6743,7 +6743,7 @@ def resize_bilinear(input, name=None, actual_shape=None, align_corners=True, - align_mode=0): + align_mode=1): """ Resize input by performing bilinear interpolation based on given output shape which specified by actual_shape, out_shape and scale From a0c63f11069235e66d4d0d41e996631981eae5fd Mon Sep 17 00:00:00 2001 From: tink2123 Date: Sun, 27 Jan 2019 21:46:12 -0800 Subject: [PATCH 14/53] add align_flag test=develop --- paddle/fluid/operators/interpolate_op.cc | 2 +- paddle/fluid/operators/interpolate_op.cu | 36 ++++++++++------------- paddle/fluid/operators/interpolate_op.h | 37 ++++++++++-------------- python/paddle/fluid/layers/nn.py | 6 ++-- 4 files changed, 36 insertions(+), 45 deletions(-) diff --git a/paddle/fluid/operators/interpolate_op.cc b/paddle/fluid/operators/interpolate_op.cc index 357832223c..de91ba6270 100644 --- a/paddle/fluid/operators/interpolate_op.cc +++ b/paddle/fluid/operators/interpolate_op.cc @@ -110,7 +110,7 @@ class InterpolateOpMaker : public framework::OpProtoAndCheckerMaker { to perform linear interpolation first in one direction, and then again in the other direction. - Align_corners and align_mode are optinal parameters,The calculation method + Align_corners and align_mode are optinal parameters,the calculation method of interpolation can be selected by them. Example: diff --git a/paddle/fluid/operators/interpolate_op.cu b/paddle/fluid/operators/interpolate_op.cu index 7595511cf5..1dfd4947c6 100644 --- a/paddle/fluid/operators/interpolate_op.cu +++ b/paddle/fluid/operators/interpolate_op.cu @@ -94,6 +94,7 @@ __global__ void KeBilinearInterpFw( int nthreads = output_h * output_w; int tid = blockIdx.x * blockDim.x + threadIdx.x; int stride = blockDim.x * gridDim.x; + bool align_flag = (align_mode == 0 && !align_corners); for (; tid < nthreads; tid += stride) { int out_id_h = tid / output_w; int out_id_w = tid % output_w; @@ -102,25 +103,23 @@ __global__ void KeBilinearInterpFw( int channel_id = out_id_w / out_img_size; int out_img_idy = (out_id_w % out_img_size) / out_img_w; - int in_img_idy = (align_mode == 0 && !align_corners) + int in_img_idy = align_flag ? static_cast(ratio_h * (out_img_idy + 0.5) - 0.5) : static_cast(ratio_h * out_img_idy); in_img_idy = (in_img_idy > 0) ? in_img_idy : 0; int h_id = (in_img_idy < in_img_h - 1) ? 1 : 0; - T h1lambda = (align_mode == 0 && !align_corners) - ? ratio_h * (out_img_idy + 0.5) - 0.5 - in_img_idy - : ratio_h * out_img_idy - in_img_idy; + T h1lambda = align_flag ? ratio_h * (out_img_idy + 0.5) - 0.5 - in_img_idy + : ratio_h * out_img_idy - in_img_idy; T h2lambda = 1.f - h1lambda; int out_img_idx = tid % out_img_w; - int in_img_idx = (align_mode == 0 && !align_corners) + int in_img_idx = align_flag ? static_cast(ratio_w * (out_img_idx + 0.5) - 0.5) : static_cast(ratio_w * out_img_idx); in_img_idx = (in_img_idx > 0) ? in_img_idx : 0; int w_id = (in_img_idx < in_img_w - 1) ? 1 : 0; - T w1lambda = (align_mode == 0 && !align_corners) - ? ratio_w * (out_img_idx + 0.5) - 0.5 - in_img_idx - : ratio_w * out_img_idx - in_img_idx; + T w1lambda = align_flag ? ratio_w * (out_img_idx + 0.5) - 0.5 - in_img_idx + : ratio_w * out_img_idx - in_img_idx; T w2lambda = 1.f - w1lambda; const T* in_pos = &in[out_id_h * input_w + channel_id * in_img_size + @@ -144,6 +143,7 @@ __global__ void KeBilinearInterpBw( int nthreads = output_h * output_w; int tid = blockIdx.x * blockDim.x + threadIdx.x; int stride = blockDim.x * gridDim.x; + bool align_flag = (align_mode == 0 && !align_corners); for (; tid < nthreads; tid += stride) { int out_id_h = tid / output_w; int out_id_w = tid % output_w; @@ -152,26 +152,22 @@ __global__ void KeBilinearInterpBw( int channel_id = out_id_w / out_img_size; int out_img_idy = (out_id_w % out_img_size) / out_img_w; - int in_img_idy = (align_mode == 0 && !align_corners) - ? ratio_h * (out_img_idy + 0.5) - 0.5 - : ratio_h * out_img_idy; + int in_img_idy = align_flag ? ratio_h * (out_img_idy + 0.5) - 0.5 + : ratio_h * out_img_idy; in_img_idy = (in_img_idy > 0) ? in_img_idy : 0; int h_id = (in_img_idy < in_img_h - 1) ? 1 : 0; - T h1lambda = (align_mode == 0 && !align_corners) - ? ratio_h * (out_img_idy + 0.5) - 0.5 - in_img_idy - : ratio_h * out_img_idy - in_img_idy; + T h1lambda = align_flag ? ratio_h * (out_img_idy + 0.5) - 0.5 - in_img_idy + : ratio_h * out_img_idy - in_img_idy; T h2lambda = 1.f - h1lambda; int out_img_idx = tid % out_img_w; - int in_img_idx = (align_mode == 0 && !align_corners) - ? ratio_w * (out_img_idx + 0.5) - 0.5 - : ratio_w * out_img_idx; + int in_img_idx = align_flag ? ratio_w * (out_img_idx + 0.5) - 0.5 + : ratio_w * out_img_idx; in_img_idx = (in_img_idx > 0) ? in_img_idx : 0; int w_id = (in_img_idx < in_img_w - 1) ? 1 : 0; - T w1lambda = (align_mode == 0 && !align_corners) - ? ratio_w * (out_img_idx + 0.5) - 0.5 - in_img_idx - : ratio_w * out_img_idx - in_img_idx; + T w1lambda = align_flag ? ratio_w * (out_img_idx + 0.5) - 0.5 - in_img_idx + : ratio_w * out_img_idx - in_img_idx; T w2lambda = 1.f - w1lambda; T* in_pos = &in[out_id_h * input_w + channel_id * in_img_size + diff --git a/paddle/fluid/operators/interpolate_op.h b/paddle/fluid/operators/interpolate_op.h index ab41ff781a..1ec0cb5025 100644 --- a/paddle/fluid/operators/interpolate_op.h +++ b/paddle/fluid/operators/interpolate_op.h @@ -56,15 +56,14 @@ static void BilinearInterpolation(const Tensor& input, Tensor* output, const bool align_mode) { auto input_t = EigenTensor::From(input); auto output_t = EigenTensor::From(*output); + bool align_flag = (align_mode == 0 && !align_corners); for (int k = 0; k < out_h; k++) { // loop for images - int y_n = (align_mode == 0 && !align_corners) - ? static_cast(ratio_h * (k + 0.5) - 0.5) - : static_cast(ratio_h * k); + int y_n = align_flag ? static_cast(ratio_h * (k + 0.5) - 0.5) + : static_cast(ratio_h * k); y_n = (y_n > 0) ? y_n : 0; int y_s = (y_n + 1) < (in_h - 1) ? (y_n + 1) : (in_h - 1); - float d_n = (align_mode == 0 && !align_corners) - ? ratio_h * (k + 0.5) - 0.5 - y_n - : ratio_h * k - y_n; + float d_n = + align_flag ? ratio_h * (k + 0.5) - 0.5 - y_n : ratio_h * k - y_n; float d_s = 1.f - d_n; for (int l = 0; l < out_w; l++) { @@ -73,9 +72,8 @@ static void BilinearInterpolation(const Tensor& input, Tensor* output, : static_cast(ratio_w * l); x_w = (x_w > 0) ? x_w : 0; int x_e = (x_w + 1) < (in_w - 1) ? (x_w + 1) : (in_w - 1); - float d_w = (align_mode == 0 && !align_corners) - ? ratio_w * (l + 0.5) - 0.5 - x_w - : ratio_w * l - x_w; + float d_w = + align_flag ? ratio_w * (l + 0.5) - 0.5 - x_w : ratio_w * l - x_w; float d_e = 1.f - d_w; for (int i = 0; i < n; i++) { // loop for batches @@ -126,26 +124,23 @@ static void BilinearInterpolationGrad(const Tensor& output_grad, const int align_mode) { auto input_grad_t = EigenTensor::From(*input_grad); auto output_grad_t = EigenTensor::From(output_grad); + bool align_flag = (align_mode == 0 && !align_corners); for (int k = 0; k < out_h; k++) { // loop for images - int y_n = (align_mode == 0 && !align_corners) - ? static_cast(ratio_h * (k + 0.5) - 0.5) - : static_cast(ratio_h * k); + int y_n = align_flag ? static_cast(ratio_h * (k + 0.5) - 0.5) + : static_cast(ratio_h * k); y_n = (y_n > 0) ? y_n : 0; int y_s = (y_n + 1) < (in_h - 1) ? (y_n + 1) : (in_h - 1); - float d_n = (align_mode == 0 && !align_corners) - ? ratio_h * (k + 0.5) - 0.5 - y_n - : ratio_h * k - y_n; + float d_n = + align_flag ? ratio_h * (k + 0.5) - 0.5 - y_n : ratio_h * k - y_n; float d_s = 1.f - d_n; for (int l = 0; l < out_w; l++) { - int x_w = (align_mode == 0 && !align_corners) - ? static_cast(ratio_w * (l + 0.5) - 0.5) - : static_cast(ratio_w * l); + int x_w = align_flag ? static_cast(ratio_w * (l + 0.5) - 0.5) + : static_cast(ratio_w * l); x_w = (x_w > 0) ? x_w : 0; int x_e = (x_w + 1) < (in_w - 1) ? (x_w + 1) : (in_w - 1); - float d_w = (align_mode == 0 && !align_corners) - ? ratio_w * (l + 0.5) - 0.5 - x_w - : ratio_w * l - x_w; + float d_w = + align_flag ? ratio_w * (l + 0.5) - 0.5 - x_w : ratio_w * l - x_w; float d_e = 1.f - d_w; for (int i = 0; i < n; i++) { // loop for batches diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index a5a3aa2f3a..b398f5d206 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -6552,7 +6552,7 @@ def image_resize(input, to perform linear interpolation first in one direction, and then again in the other direction. - Align_corners and align_mode are optinal parameters,The calculation method + Align_corners and align_mode are optinal parameters,the calculation method of interpolation can be selected by them. Example: @@ -6758,11 +6758,11 @@ def resize_bilinear(input, For details of bilinear interpolation, please refer to Wikipedia: https://en.wikipedia.org/wiki/Bilinear_interpolation - Align_corners and align_mode are optinal parameters,The calculation + Align_corners and align_mode are optinal parameters,the calculation method of interpolation can be selected by them. - Align_corners and align_mode are optinal parameters,The calculation method + Align_corners and align_mode are optinal parameters,the calculation method of interpolation can be selected by them. Example: From cee2e1b089f88d9a8dca530c197cb246a628e4b7 Mon Sep 17 00:00:00 2001 From: jerrywgz Date: Mon, 28 Jan 2019 05:57:33 +0000 Subject: [PATCH 15/53] refine code, test=develop --- .../fluid/operators/detection/box_coder_op.cu | 70 +++++++++---------- .../fluid/operators/detection/box_coder_op.h | 56 ++++++--------- python/paddle/fluid/tests/test_detection.py | 15 +++- 3 files changed, 67 insertions(+), 74 deletions(-) diff --git a/paddle/fluid/operators/detection/box_coder_op.cu b/paddle/fluid/operators/detection/box_coder_op.cu index 9b73572274..e078af3eb4 100644 --- a/paddle/fluid/operators/detection/box_coder_op.cu +++ b/paddle/fluid/operators/detection/box_coder_op.cu @@ -11,6 +11,7 @@ limitations under the License. */ #include #include +#include "paddle/fluid/memory/memcpy.h" #include "paddle/fluid/operators/detection/box_coder_op.h" #include "paddle/fluid/platform/cuda_primitives.h" @@ -95,47 +96,33 @@ __global__ void DecodeCenterSizeKernel( prior_box_data[prior_box_offset + 1] + prior_box_height / 2; T target_box_width, target_box_height; T target_box_center_x, target_box_center_y; + T box_var_x = T(1), box_var_y = T(1); + T box_var_w = T(1), box_var_h = T(1); if (prior_box_var_data) { int prior_var_offset = 0; if (prior_box_var_size == 2) { prior_var_offset = axis == 0 ? col_idx * len : row_idx * len; } - target_box_width = exp(prior_box_var_data[prior_var_offset + 2] * - target_box_data[idx * len + 2]) * - prior_box_width; - target_box_height = exp(prior_box_var_data[prior_var_offset + 3] * - target_box_data[idx * len + 3]) * - prior_box_height; - target_box_center_x = prior_box_var_data[prior_var_offset] * - target_box_data[idx * len] * prior_box_width + - prior_box_center_x; - target_box_center_y = prior_box_var_data[prior_var_offset + 1] * - target_box_data[idx * len + 1] * - prior_box_height + - prior_box_center_y; + box_var_x = prior_box_var_data[prior_var_offset]; + box_var_y = prior_box_var_data[prior_var_offset + 1]; + box_var_w = prior_box_var_data[prior_var_offset + 2]; + box_var_h = prior_box_var_data[prior_var_offset + 3]; } else if (var_size == 4) { - target_box_width = - exp(static_cast(variance[2]) * target_box_data[idx * len + 2]) * - prior_box_width; - target_box_height = - exp(static_cast(variance[3]) * target_box_data[idx * len + 3]) * - prior_box_height; - target_box_center_x = static_cast(variance[0]) * - target_box_data[idx * len] * prior_box_width + - prior_box_center_x; - target_box_center_y = static_cast(variance[1]) * - target_box_data[idx * len + 1] * - prior_box_height + - prior_box_center_y; - } else { - target_box_width = exp(target_box_data[idx * len + 2]) * prior_box_width; - target_box_height = - exp(target_box_data[idx * len + 3]) * prior_box_height; - target_box_center_x = - target_box_data[idx * len] * prior_box_width + prior_box_center_x; - target_box_center_y = target_box_data[idx * len + 1] * prior_box_height + - prior_box_center_y; + box_var_x = static_cast(variance[0]); + box_var_y = static_cast(variance[1]); + box_var_w = static_cast(variance[2]); + box_var_h = static_cast(variance[3]); } + target_box_width = + exp(box_var_w * target_box_data[idx * len + 2]) * prior_box_width; + target_box_height = + exp(box_var_h * target_box_data[idx * len + 3]) * prior_box_height; + target_box_center_x = + box_var_x * target_box_data[idx * len] * prior_box_width + + prior_box_center_x; + target_box_center_y = + box_var_y * target_box_data[idx * len + 1] * prior_box_height + + prior_box_center_y; output[idx * len] = target_box_center_x - target_box_width / 2; output[idx * len + 1] = target_box_center_y - target_box_height / 2; @@ -177,9 +164,8 @@ class BoxCoderCUDAKernel : public framework::OpKernel { PADDLE_ENFORCE_EQ(target_box->lod().size(), 1, "Only support 1 level of LoD."); } - const int var_size = static_cast(variance.size()); - thrust::device_vector dev_variance(variance.begin(), variance.end()); - const float* dev_var_data = thrust::raw_pointer_cast(dev_variance.data()); + const int var_size = static_cast(variance.size()); + auto code_type = GetBoxCodeType(context.Attr("code_type")); bool normalized = context.Attr("box_normalized"); int axis = context.Attr("axis"); @@ -194,6 +180,16 @@ class BoxCoderCUDAKernel : public framework::OpKernel { int grid = (row * col + block - 1) / block; auto& device_ctx = context.cuda_device_context(); + auto& allocator = + platform::DeviceTemporaryAllocator::Instance().Get(device_ctx); + int bytes = var_size * sizeof(float); + auto dev_var = allocator.Allocate(bytes); + float* dev_var_data = reinterpret_cast(dev_var->ptr()); + auto cplace = platform::CPUPlace(); + const auto gplace = boost::get(context.GetPlace()); + memory::Copy(gplace, dev_var_data, cplace, &variance[0], bytes, + device_ctx.stream()); + output_box->mutable_data({row, col, len}, context.GetPlace()); T* output = output_box->data(); diff --git a/paddle/fluid/operators/detection/box_coder_op.h b/paddle/fluid/operators/detection/box_coder_op.h index b61cff1b1d..a0b1faf7bd 100644 --- a/paddle/fluid/operators/detection/box_coder_op.h +++ b/paddle/fluid/operators/detection/box_coder_op.h @@ -133,6 +133,8 @@ class BoxCoderKernel : public framework::OpKernel { T target_box_center_x = 0, target_box_center_y = 0; T target_box_width = 0, target_box_height = 0; + T box_var_x = T(1), box_var_y = T(1); + T box_var_w = T(1), box_var_h = T(1); if (prior_box_var) { int prior_var_offset = 0; if (prior_box_var->dims().size() == 2) { @@ -141,44 +143,26 @@ class BoxCoderKernel : public framework::OpKernel { else if (axis == 1) prior_var_offset = i * len; } - target_box_center_x = prior_box_var_data[prior_var_offset] * - target_box_data[offset] * prior_box_width + - prior_box_center_x; - target_box_center_y = prior_box_var_data[prior_var_offset + 1] * - target_box_data[offset + 1] * - prior_box_height + - prior_box_center_y; - target_box_width = std::exp(prior_box_var_data[prior_var_offset + 2] * - target_box_data[offset + 2]) * - prior_box_width; - target_box_height = - std::exp(prior_box_var_data[prior_var_offset + 3] * - target_box_data[offset + 3]) * - prior_box_height; + box_var_x = prior_box_var_data[prior_var_offset]; + box_var_y = prior_box_var_data[prior_var_offset + 1]; + box_var_w = prior_box_var_data[prior_var_offset + 2]; + box_var_h = prior_box_var_data[prior_var_offset + 3]; } else if (!(variance.empty())) { - target_box_center_x = static_cast(variance[0]) * - target_box_data[offset] * prior_box_width + - prior_box_center_x; - target_box_center_y = static_cast(variance[1]) * - target_box_data[offset + 1] * - prior_box_height + - prior_box_center_y; - target_box_width = std::exp(static_cast(variance[2]) * - target_box_data[offset + 2]) * - prior_box_width; - target_box_height = std::exp(static_cast(variance[3]) * - target_box_data[offset + 3]) * - prior_box_height; - } else { - target_box_center_x = - target_box_data[offset] * prior_box_width + prior_box_center_x; - target_box_center_y = target_box_data[offset + 1] * prior_box_height + - prior_box_center_y; - target_box_width = - std::exp(target_box_data[offset + 2]) * prior_box_width; - target_box_height = - std::exp(target_box_data[offset + 3]) * prior_box_height; + box_var_x = static_cast(variance[0]); + box_var_y = static_cast(variance[1]); + box_var_w = static_cast(variance[2]); + box_var_h = static_cast(variance[3]); } + target_box_center_x = + box_var_x * target_box_data[offset] * prior_box_width + + prior_box_center_x; + target_box_center_y = + box_var_y * target_box_data[offset + 1] * prior_box_height + + prior_box_center_y; + target_box_width = + std::exp(box_var_w * target_box_data[offset + 2]) * prior_box_width; + target_box_height = std::exp(box_var_h * target_box_data[offset + 3]) * + prior_box_height; output[offset] = target_box_center_x - target_box_width / 2; output[offset + 1] = target_box_center_y - target_box_height / 2; diff --git a/python/paddle/fluid/tests/test_detection.py b/python/paddle/fluid/tests/test_detection.py index 2dbcfa31fc..869da58043 100644 --- a/python/paddle/fluid/tests/test_detection.py +++ b/python/paddle/fluid/tests/test_detection.py @@ -50,6 +50,19 @@ class TestDetection(unittest.TestCase): self.assertEqual(out.shape[-1], 6) print(str(program)) + def test_box_coder_api(self): + program = Program() + with program_guard(program): + x = layers.data(name='x', shape=[4], dtype='float32') + y = layers.data(name='z', shape=[4], dtype='float32', lod_level=1) + bcoder = layers.box_coder( + prior_box=x, + prior_box_var=[0.1, 0.2, 0.1, 0.2], + target_box=y, + code_type='encode_center_size') + self.assertIsNotNone(bcoder) + print(str(program)) + def test_detection_api(self): program = Program() with program_guard(program): @@ -59,7 +72,7 @@ class TestDetection(unittest.TestCase): iou = layers.iou_similarity(x=x, y=y) bcoder = layers.box_coder( prior_box=x, - prior_box_var=[0.2, 0.3, 0.3, 0.2], + prior_box_var=y, target_box=z, code_type='encode_center_size') self.assertIsNotNone(iou) From e7eb08febedc779ea45084b60e5a3c683c0e47c5 Mon Sep 17 00:00:00 2001 From: tink2123 Date: Sun, 27 Jan 2019 23:22:28 -0800 Subject: [PATCH 16/53] fix api.spec test=develop --- paddle/fluid/API.spec | 1 - 1 file changed, 1 deletion(-) diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index f4e964d8c2..e58b57ea54 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -505,4 +505,3 @@ paddle.reader.Fake.__init__ ArgSpec(args=['self'], varargs=None, keywords=None, paddle.reader.creator.np_array ArgSpec(args=['x'], varargs=None, keywords=None, defaults=None) paddle.reader.creator.text_file ArgSpec(args=['path'], varargs=None, keywords=None, defaults=None) paddle.reader.creator.recordio ArgSpec(args=['paths', 'buf_size'], varargs=None, keywords=None, defaults=(100,)) - From 6961a94e942796b8f32516897faf4fa95156ad66 Mon Sep 17 00:00:00 2001 From: tink2123 Date: Mon, 28 Jan 2019 22:33:37 -0800 Subject: [PATCH 17/53] avoid out_size less than 1 test=develop --- paddle/fluid/operators/interpolate_op.cu | 34 +++++++++++------- paddle/fluid/operators/interpolate_op.h | 36 ++++++++++++------- .../unittests/test_bilinear_interp_op.py | 18 +++++----- .../tests/unittests/test_nearest_interp_op.py | 18 +++++----- 4 files changed, 66 insertions(+), 40 deletions(-) diff --git a/paddle/fluid/operators/interpolate_op.cu b/paddle/fluid/operators/interpolate_op.cu index 1dfd4947c6..f86d2c4ab4 100644 --- a/paddle/fluid/operators/interpolate_op.cu +++ b/paddle/fluid/operators/interpolate_op.cu @@ -220,12 +220,17 @@ class InterpolateOpCUDAKernel : public framework::OpKernel { int in_chw = c * in_hw; int out_chw = c * out_hw; - float ratio_h = (align_corners && out_h > 1) - ? static_cast(in_h - 1) / (out_h - 1) - : static_cast(in_h) / out_h; - float ratio_w = (align_corners && out_w > 1) - ? static_cast(in_w - 1) / (out_w - 1) - : static_cast(in_w) / out_w; + float ratio_h = 0.f; + float ratio_w = 0.f; + if (out_h > 1) { + ratio_h = (align_corners) ? static_cast(in_h - 1) / (out_h - 1) + : static_cast(in_h) / out_h; + } + if (out_w > 1) { + ratio_w = (align_corners && out_w > 1) + ? static_cast(in_w - 1) / (out_w - 1) + : static_cast(in_w) / out_w; + } if (in_h == out_h && in_w == out_w) { framework::TensorCopy(*input, ctx.GetPlace(), output); @@ -290,12 +295,17 @@ class InterpolateGradOpCUDAKernel : public framework::OpKernel { int in_chw = c * in_hw; int out_chw = c * out_hw; - float ratio_h = (align_corners && out_h > 1) - ? static_cast(in_h - 1) / (out_h - 1) - : static_cast(in_h) / out_h; - float ratio_w = (align_corners && out_w > 1) - ? static_cast(in_w - 1) / (out_w - 1) - : static_cast(in_w) / out_w; + float ratio_h = 0.f; + float ratio_w = 0.f; + if (out_h > 1) { + ratio_h = (align_corners) ? static_cast(in_h - 1) / (out_h - 1) + : static_cast(in_h) / out_h; + } + if (out_w > 1) { + ratio_w = (align_corners && out_w > 1) + ? static_cast(in_w - 1) / (out_w - 1) + : static_cast(in_w) / out_w; + } if (in_h == out_h && in_w == out_w) { framework::TensorCopy(*output_grad, ctx.GetPlace(), input_grad); diff --git a/paddle/fluid/operators/interpolate_op.h b/paddle/fluid/operators/interpolate_op.h index 1ec0cb5025..acdebf73e0 100644 --- a/paddle/fluid/operators/interpolate_op.h +++ b/paddle/fluid/operators/interpolate_op.h @@ -191,12 +191,18 @@ class InterpolateKernel : public framework::OpKernel { return; } - float ratio_h = (align_corners && out_h > 1) - ? static_cast(in_h - 1) / (out_h - 1) - : static_cast(in_h) / out_h; - float ratio_w = (align_corners && out_w > 1) - ? static_cast(in_w - 1) / (out_w - 1) - : static_cast(in_w) / out_w; + float ratio_h = 0.f; + float ratio_w = 0.f; + + if (out_h > 1) { + ratio_h = (align_corners) ? static_cast(in_h - 1) / (out_h - 1) + : static_cast(in_h) / out_h; + } + if (out_w > 1) { + ratio_w = (align_corners && out_w > 1) + ? static_cast(in_w - 1) / (out_w - 1) + : static_cast(in_w) / out_w; + } if ("bilinear" == interp_method) { BilinearInterpolation(*input, output, ratio_h, ratio_w, in_h, in_w, n, @@ -244,12 +250,18 @@ class InterpolateGradKernel : public framework::OpKernel { return; } - float ratio_h = (align_corners && out_h > 1) - ? static_cast(in_h - 1) / (out_h - 1) - : static_cast(in_h) / out_h; - float ratio_w = (align_corners && out_w > 1) - ? static_cast(in_w - 1) / (out_w - 1) - : static_cast(in_w) / out_w; + float ratio_h = 0.f; + float ratio_w = 0.f; + + if (out_h > 1) { + ratio_h = (align_corners) ? static_cast(in_h - 1) / (out_h - 1) + : static_cast(in_h) / out_h; + } + if (out_w > 1) { + ratio_w = (align_corners && out_w > 1) + ? static_cast(in_w - 1) / (out_w - 1) + : static_cast(in_w) / out_w; + } if ("bilinear" == interp_method) { BilinearInterpolationGrad(*output_grad, input_grad, ratio_h, ratio_w, diff --git a/python/paddle/fluid/tests/unittests/test_bilinear_interp_op.py b/python/paddle/fluid/tests/unittests/test_bilinear_interp_op.py index 2e3de58a3a..f60ed1d79a 100644 --- a/python/paddle/fluid/tests/unittests/test_bilinear_interp_op.py +++ b/python/paddle/fluid/tests/unittests/test_bilinear_interp_op.py @@ -37,14 +37,16 @@ def bilinear_interp_np(input, batch_size, channel, in_h, in_w = input.shape ratio_h = ratio_w = 0.0 - if (align_corners and out_h > 1): - ratio_h = (in_h - 1.0) / (out_h - 1.0) - else: - ratio_h = 1.0 * in_h / out_h - if (align_corners and out_w > 1): - ratio_w = (in_w - 1.0) / (out_w - 1.0) - else: - ratio_w = 1.0 * in_w / out_w + if out_h > 1: + if (align_corners): + ratio_h = (in_h - 1.0) / (out_h - 1.0) + else: + ratio_h = 1.0 * in_h / out_h + if out_w > 1: + if (align_corners): + ratio_w = (in_w - 1.0) / (out_w - 1.0) + else: + ratio_w = 1.0 * in_w / out_w out = np.zeros((batch_size, channel, out_h, out_w)) diff --git a/python/paddle/fluid/tests/unittests/test_nearest_interp_op.py b/python/paddle/fluid/tests/unittests/test_nearest_interp_op.py index 9984a793ca..5bb2260ef7 100644 --- a/python/paddle/fluid/tests/unittests/test_nearest_interp_op.py +++ b/python/paddle/fluid/tests/unittests/test_nearest_interp_op.py @@ -36,14 +36,16 @@ def nearest_neighbor_interp_np(X, n, c, in_h, in_w = X.shape ratio_h = ratio_w = 0.0 - if (align_corners and out_h > 1): - ratio_h = (in_h - 1.0) / (out_h - 1.0) - else: - ratio_h = 1.0 * in_h / out_h - if (align_corners and out_w > 1): - ratio_w = (in_w - 1.0) / (out_w - 1.0) - else: - ratio_w = 1.0 * in_w / out_w + if (out_h > 1): + if (align_corners): + ratio_h = (in_h - 1.0) / (out_h - 1.0) + else: + ratio_h = 1.0 * in_h / out_h + if (out_w > 1): + if (align_corners): + ratio_w = (in_w - 1.0) / (out_w - 1.0) + else: + ratio_w = 1.0 * in_w / out_w out = np.zeros((n, c, out_h, out_w)) From bb881199f23427e10bb868694bd362582b53493d Mon Sep 17 00:00:00 2001 From: JiabinYang Date: Tue, 29 Jan 2019 06:37:03 +0000 Subject: [PATCH 18/53] test=develop, polish code and fix wrong change in /paddle/fluid/inference/utils/CMakeLists.txt --- paddle/fluid/inference/utils/CMakeLists.txt | 4 ++-- .../paddle/fluid/tests/unittests/test_imperative_ptb_rnn.py | 3 --- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/paddle/fluid/inference/utils/CMakeLists.txt b/paddle/fluid/inference/utils/CMakeLists.txt index a7b239731b..c43eaf7f98 100644 --- a/paddle/fluid/inference/utils/CMakeLists.txt +++ b/paddle/fluid/inference/utils/CMakeLists.txt @@ -1,4 +1,4 @@ cc_library(benchmark SRCS benchmark.cc DEPS enforce) cc_test(test_benchmark SRCS benchmark_tester.cc DEPS benchmark) -#cc_binary(visualizer SRCS visualizer.cc DEPS analysis -# paddle_pass_builder ir_pass_manager pass graph_viz_pass analysis_passes) +cc_binary(visualizer SRCS visualizer.cc DEPS analysis + paddle_pass_builder ir_pass_manager pass graph_viz_pass analysis_passes) diff --git a/python/paddle/fluid/tests/unittests/test_imperative_ptb_rnn.py b/python/paddle/fluid/tests/unittests/test_imperative_ptb_rnn.py index 5877e91f92..afe990e74f 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_ptb_rnn.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_ptb_rnn.py @@ -268,9 +268,6 @@ class TestImperativePtbRnn(unittest.TestCase): sgd.minimize(dy_loss) for param in ptb_model.parameters(): dy_param_updated[param.name] = param._numpy() - # print("dy_loss is {}".format(dy_loss._numpy())) - # print("last_hidden is {}".format(last_hidden._numpy())) - # print("last_cell is {}".format(last_cell._numpy())) with new_program_scope(): fluid.default_startup_program().random_seed = seed From 909f864a9bff2812bfea39c230ec779bccd54ca5 Mon Sep 17 00:00:00 2001 From: tink2123 Date: Mon, 28 Jan 2019 22:45:11 -0800 Subject: [PATCH 19/53] remove unnecessary flags test=develop --- paddle/fluid/operators/interpolate_op.cu | 10 ++++------ paddle/fluid/operators/interpolate_op.h | 10 ++++------ 2 files changed, 8 insertions(+), 12 deletions(-) diff --git a/paddle/fluid/operators/interpolate_op.cu b/paddle/fluid/operators/interpolate_op.cu index f86d2c4ab4..b887878ea2 100644 --- a/paddle/fluid/operators/interpolate_op.cu +++ b/paddle/fluid/operators/interpolate_op.cu @@ -227,9 +227,8 @@ class InterpolateOpCUDAKernel : public framework::OpKernel { : static_cast(in_h) / out_h; } if (out_w > 1) { - ratio_w = (align_corners && out_w > 1) - ? static_cast(in_w - 1) / (out_w - 1) - : static_cast(in_w) / out_w; + ratio_w = (align_corners) ? static_cast(in_w - 1) / (out_w - 1) + : static_cast(in_w) / out_w; } if (in_h == out_h && in_w == out_w) { @@ -302,9 +301,8 @@ class InterpolateGradOpCUDAKernel : public framework::OpKernel { : static_cast(in_h) / out_h; } if (out_w > 1) { - ratio_w = (align_corners && out_w > 1) - ? static_cast(in_w - 1) / (out_w - 1) - : static_cast(in_w) / out_w; + ratio_w = (align_corners) ? static_cast(in_w - 1) / (out_w - 1) + : static_cast(in_w) / out_w; } if (in_h == out_h && in_w == out_w) { diff --git a/paddle/fluid/operators/interpolate_op.h b/paddle/fluid/operators/interpolate_op.h index acdebf73e0..c631ad1dd1 100644 --- a/paddle/fluid/operators/interpolate_op.h +++ b/paddle/fluid/operators/interpolate_op.h @@ -199,9 +199,8 @@ class InterpolateKernel : public framework::OpKernel { : static_cast(in_h) / out_h; } if (out_w > 1) { - ratio_w = (align_corners && out_w > 1) - ? static_cast(in_w - 1) / (out_w - 1) - : static_cast(in_w) / out_w; + ratio_w = (align_corners) ? static_cast(in_w - 1) / (out_w - 1) + : static_cast(in_w) / out_w; } if ("bilinear" == interp_method) { @@ -258,9 +257,8 @@ class InterpolateGradKernel : public framework::OpKernel { : static_cast(in_h) / out_h; } if (out_w > 1) { - ratio_w = (align_corners && out_w > 1) - ? static_cast(in_w - 1) / (out_w - 1) - : static_cast(in_w) / out_w; + ratio_w = (align_corners) ? static_cast(in_w - 1) / (out_w - 1) + : static_cast(in_w) / out_w; } if ("bilinear" == interp_method) { From 192d293854b93d86bbb27ed37af199dd6e4ee1c6 Mon Sep 17 00:00:00 2001 From: dengkaipeng Date: Thu, 6 Dec 2018 19:53:41 +0800 Subject: [PATCH 20/53] use stable Sigmoid Cross Entropy implement. test=develop --- paddle/fluid/operators/yolov3_loss_op.cc | 4 + paddle/fluid/operators/yolov3_loss_op.h | 283 ++++++++++-------- python/paddle/fluid/layers/detection.py | 3 + python/paddle/fluid/tests/test_detection.py | 2 +- .../tests/unittests/test_yolov3_loss_op.py | 90 +++--- 5 files changed, 208 insertions(+), 174 deletions(-) diff --git a/paddle/fluid/operators/yolov3_loss_op.cc b/paddle/fluid/operators/yolov3_loss_op.cc index 60508f7ab8..66d618de59 100644 --- a/paddle/fluid/operators/yolov3_loss_op.cc +++ b/paddle/fluid/operators/yolov3_loss_op.cc @@ -99,6 +99,10 @@ class Yolov3LossOpMaker : public framework::OpProtoAndCheckerMaker { AddAttr>("anchors", "The anchor width and height, " "it will be parsed pair by pair."); + AddAttr("input_size", + "The input size of YOLOv3 net, " + "generally this is set as 320, 416 or 608.") + .SetDefault(406); AddAttr("ignore_thresh", "The ignore threshold to ignore confidence loss."); AddAttr("loss_weight_xy", "The weight of x, y location loss.") diff --git a/paddle/fluid/operators/yolov3_loss_op.h b/paddle/fluid/operators/yolov3_loss_op.h index 0bb285722d..fac06b4204 100644 --- a/paddle/fluid/operators/yolov3_loss_op.h +++ b/paddle/fluid/operators/yolov3_loss_op.h @@ -33,87 +33,91 @@ static inline bool isZero(T x) { } template -static inline T sigmoid(T x) { - return 1.0 / (exp(-1.0 * x) + 1.0); -} +static inline T CalcMSEWithWeight(const Tensor& x, const Tensor& y, + const Tensor& weight, const T mf) { + int numel = static_cast(x.numel()); + const T* x_data = x.data(); + const T* y_data = y.data(); + const T* weight_data = weight.data(); -template -static inline T CalcMaskPointNum(const Tensor& mask) { - auto mask_t = EigenVector::Flatten(mask); - T count = 0.0; - for (int i = 0; i < mask_t.dimensions()[0]; i++) { - if (mask_t(i)) { - count += 1.0; - } + T error_sum = 0.0; + for (int i = 0; i < numel; i++) { + T xi = x_data[i]; + T yi = y_data[i]; + T weighti = weight_data[i]; + error_sum += pow(yi - xi, 2) * weighti; } - return count; + + return error_sum / mf; } template -static inline T CalcMSEWithMask(const Tensor& x, const Tensor& y, - const Tensor& mask) { - auto x_t = EigenVector::Flatten(x); - auto y_t = EigenVector::Flatten(y); - auto mask_t = EigenVector::Flatten(mask); - - T error_sum = 0.0; - T points = 0.0; - for (int i = 0; i < x_t.dimensions()[0]; i++) { - if (mask_t(i)) { - error_sum += pow(x_t(i) - y_t(i), 2); - points += 1; - } +static void CalcMSEGradWithWeight(Tensor* grad, const Tensor& x, + const Tensor& y, const Tensor& weight, + const T mf) { + int numel = static_cast(grad->numel()); + T* grad_data = grad->data(); + const T* x_data = x.data(); + const T* y_data = y.data(); + const T* weight_data = weight.data(); + + for (int i = 0; i < numel; i++) { + grad_data[i] = 2.0 * weight_data[i] * (x_data[i] - y_data[i]) / mf; } - return (error_sum / points); } template -static void CalcMSEGradWithMask(Tensor* grad, const Tensor& x, const Tensor& y, - const Tensor& mask, T mf) { - auto grad_t = EigenVector::Flatten(*grad).setConstant(0.0); - auto x_t = EigenVector::Flatten(x); - auto y_t = EigenVector::Flatten(y); - auto mask_t = EigenVector::Flatten(mask); - - for (int i = 0; i < x_t.dimensions()[0]; i++) { - if (mask_t(i)) { - grad_t(i) = 2.0 * (x_t(i) - y_t(i)) / mf; - } +struct SigmoidCrossEntropyForward { + T operator()(const T& x, const T& label) const { + T term1 = (x > 0) ? x : 0; + T term2 = x * label; + T term3 = std::log(static_cast(1.0) + std::exp(-(std::abs(x)))); + return term1 - term2 + term3; } -} +}; template -static inline T CalcBCEWithMask(const Tensor& x, const Tensor& y, - const Tensor& mask) { - auto x_t = EigenVector::Flatten(x); - auto y_t = EigenVector::Flatten(y); - auto mask_t = EigenVector::Flatten(mask); +struct SigmoidCrossEntropyBackward { + T operator()(const T& x, const T& label) const { + T sigmoid_x = + static_cast(1.0) / (static_cast(1.0) + std::exp(-1.0 * x)); + return sigmoid_x - label; + } +}; - T error_sum = 0.0; - T points = 0.0; - for (int i = 0; i < x_t.dimensions()[0]; i++) { - if (mask_t(i)) { - error_sum += - -1.0 * (y_t(i) * log(x_t(i)) + (1.0 - y_t(i)) * log(1.0 - x_t(i))); - points += 1; - } +template +static inline T CalcSCEWithWeight(const Tensor& x, const Tensor& labels, + const Tensor& weight, const T mf) { + int numel = x.numel(); + const T* x_data = x.data(); + const T* labels_data = labels.data(); + const T* weight_data = weight.data(); + + T loss = 0.0; + for (int i = 0; i < numel; i++) { + T xi = x_data[i]; + T labeli = labels_data[i]; + T weighti = weight_data[i]; + loss += ((xi > 0.0 ? xi : 0.0) - xi * labeli + + std::log(1.0 + std::exp(-1.0 * std::abs(xi)))) * + weighti; } - return (error_sum / points); + return loss / mf; } template -static inline void CalcBCEGradWithMask(Tensor* grad, const Tensor& x, - const Tensor& y, const Tensor& mask, - T mf) { - auto grad_t = EigenVector::Flatten(*grad).setConstant(0.0); - auto x_t = EigenVector::Flatten(x); - auto y_t = EigenVector::Flatten(y); - auto mask_t = EigenVector::Flatten(mask); - - for (int i = 0; i < x_t.dimensions()[0]; i++) { - if (mask_t(i)) { - grad_t(i) = ((1.0 - y_t(i)) / (1.0 - x_t(i)) - y_t(i) / x_t(i)) / mf; - } +static inline void CalcSCEGradWithWeight(Tensor* grad, const Tensor& x, + const Tensor& labels, + const Tensor& weight, const T mf) { + int numel = grad->numel(); + T* grad_data = grad->data(); + const T* x_data = x.data(); + const T* labels_data = labels.data(); + const T* weight_data = weight.data(); + + for (int i = 0; i < numel; i++) { + grad_data[i] = (1.0 / (1.0 + std::exp(-1.0 * x_data[i])) - labels_data[i]) * + weight_data[i] / mf; } } @@ -139,21 +143,20 @@ static void CalcPredResult(const Tensor& input, Tensor* pred_conf, for (int an_idx = 0; an_idx < anchor_num; an_idx++) { for (int j = 0; j < h; j++) { for (int k = 0; k < w; k++) { - pred_x_t(i, an_idx, j, k) = - sigmoid(input_t(i, box_attr_num * an_idx, j, k)); + pred_x_t(i, an_idx, j, k) = input_t(i, box_attr_num * an_idx, j, k); pred_y_t(i, an_idx, j, k) = - sigmoid(input_t(i, box_attr_num * an_idx + 1, j, k)); + input_t(i, box_attr_num * an_idx + 1, j, k); pred_w_t(i, an_idx, j, k) = input_t(i, box_attr_num * an_idx + 2, j, k); pred_h_t(i, an_idx, j, k) = input_t(i, box_attr_num * an_idx + 3, j, k); pred_conf_t(i, an_idx, j, k) = - sigmoid(input_t(i, box_attr_num * an_idx + 4, j, k)); + input_t(i, box_attr_num * an_idx + 4, j, k); for (int c = 0; c < class_num; c++) { pred_class_t(i, an_idx, j, k, c) = - sigmoid(input_t(i, box_attr_num * an_idx + 5 + c, j, k)); + input_t(i, box_attr_num * an_idx + 5 + c, j, k); } } } @@ -188,21 +191,22 @@ static T CalcBoxIoU(std::vector box1, std::vector box2) { template static void PreProcessGTBox(const Tensor& gt_box, const Tensor& gt_label, const float ignore_thresh, std::vector anchors, - const int grid_size, Tensor* obj_mask, - Tensor* noobj_mask, Tensor* tx, Tensor* ty, - Tensor* tw, Tensor* th, Tensor* tconf, - Tensor* tclass) { + const int input_size, const int grid_size, + Tensor* obj_mask, Tensor* noobj_mask, Tensor* tx, + Tensor* ty, Tensor* tw, Tensor* th, Tensor* tweight, + Tensor* tconf, Tensor* tclass) { const int n = gt_box.dims()[0]; const int b = gt_box.dims()[1]; const int anchor_num = anchors.size() / 2; auto gt_box_t = EigenTensor::From(gt_box); auto gt_label_t = EigenTensor::From(gt_label); - auto obj_mask_t = EigenTensor::From(*obj_mask).setConstant(0); - auto noobj_mask_t = EigenTensor::From(*noobj_mask).setConstant(1); + auto obj_mask_t = EigenTensor::From(*obj_mask).setConstant(0); + auto noobj_mask_t = EigenTensor::From(*noobj_mask).setConstant(1); auto tx_t = EigenTensor::From(*tx).setConstant(0.0); auto ty_t = EigenTensor::From(*ty).setConstant(0.0); auto tw_t = EigenTensor::From(*tw).setConstant(0.0); auto th_t = EigenTensor::From(*th).setConstant(0.0); + auto tweight_t = EigenTensor::From(*tweight).setConstant(0.0); auto tconf_t = EigenTensor::From(*tconf).setConstant(0.0); auto tclass_t = EigenTensor::From(*tclass).setConstant(0.0); @@ -216,8 +220,8 @@ static void PreProcessGTBox(const Tensor& gt_box, const Tensor& gt_label, int cur_label = gt_label_t(i, j); T gx = gt_box_t(i, j, 0) * grid_size; T gy = gt_box_t(i, j, 1) * grid_size; - T gw = gt_box_t(i, j, 2) * grid_size; - T gh = gt_box_t(i, j, 3) * grid_size; + T gw = gt_box_t(i, j, 2) * input_size; + T gh = gt_box_t(i, j, 3) * input_size; int gi = static_cast(gx); int gj = static_cast(gy); @@ -234,15 +238,17 @@ static void PreProcessGTBox(const Tensor& gt_box, const Tensor& gt_label, best_an_index = an_idx; } if (iou > ignore_thresh) { - noobj_mask_t(i, an_idx, gj, gi) = 0; + noobj_mask_t(i, an_idx, gj, gi) = static_cast(0.0); } } - obj_mask_t(i, best_an_index, gj, gi) = 1; - noobj_mask_t(i, best_an_index, gj, gi) = 0; + obj_mask_t(i, best_an_index, gj, gi) = static_cast(1.0); + noobj_mask_t(i, best_an_index, gj, gi) = static_cast(0.0); tx_t(i, best_an_index, gj, gi) = gx - gi; ty_t(i, best_an_index, gj, gi) = gy - gj; tw_t(i, best_an_index, gj, gi) = log(gw / anchors[2 * best_an_index]); th_t(i, best_an_index, gj, gi) = log(gh / anchors[2 * best_an_index + 1]); + tweight_t(i, best_an_index, gj, gi) = + 2.0 - gt_box_t(i, j, 2) * gt_box_t(i, j, 3); tclass_t(i, best_an_index, gj, gi, cur_label) = 1; tconf_t(i, best_an_index, gj, gi) = 1; } @@ -295,27 +301,22 @@ static void AddAllGradToInputGrad( for (int k = 0; k < h; k++) { for (int l = 0; l < w; l++) { grad_t(i, j * attr_num, k, l) = - grad_x_t(i, j, k, l) * pred_x_t(i, j, k, l) * - (1.0 - pred_x_t(i, j, k, l)) * loss * loss_weight_xy; + grad_x_t(i, j, k, l) * loss * loss_weight_xy; grad_t(i, j * attr_num + 1, k, l) = - grad_y_t(i, j, k, l) * pred_y_t(i, j, k, l) * - (1.0 - pred_y_t(i, j, k, l)) * loss * loss_weight_xy; + grad_y_t(i, j, k, l) * loss * loss_weight_xy; grad_t(i, j * attr_num + 2, k, l) = grad_w_t(i, j, k, l) * loss * loss_weight_wh; grad_t(i, j * attr_num + 3, k, l) = grad_h_t(i, j, k, l) * loss * loss_weight_wh; grad_t(i, j * attr_num + 4, k, l) = - grad_conf_target_t(i, j, k, l) * pred_conf_t(i, j, k, l) * - (1.0 - pred_conf_t(i, j, k, l)) * loss * loss_weight_conf_target; + grad_conf_target_t(i, j, k, l) * loss * loss_weight_conf_target; grad_t(i, j * attr_num + 4, k, l) += - grad_conf_notarget_t(i, j, k, l) * pred_conf_t(i, j, k, l) * - (1.0 - pred_conf_t(i, j, k, l)) * loss * + grad_conf_notarget_t(i, j, k, l) * loss * loss_weight_conf_notarget; for (int c = 0; c < class_num; c++) { grad_t(i, j * attr_num + 5 + c, k, l) = - grad_class_t(i, j, k, l, c) * pred_class_t(i, j, k, l, c) * - (1.0 - pred_class_t(i, j, k, l, c)) * loss * loss_weight_class; + grad_class_t(i, j, k, l, c) * loss * loss_weight_class; } } } @@ -333,6 +334,7 @@ class Yolov3LossKernel : public framework::OpKernel { auto* loss = ctx.Output("Loss"); auto anchors = ctx.Attr>("anchors"); int class_num = ctx.Attr("class_num"); + int input_size = ctx.Attr("input_size"); float ignore_thresh = ctx.Attr("ignore_thresh"); float loss_weight_xy = ctx.Attr("loss_weight_xy"); float loss_weight_wh = ctx.Attr("loss_weight_wh"); @@ -358,30 +360,46 @@ class Yolov3LossKernel : public framework::OpKernel { &pred_w, &pred_h, an_num, class_num); Tensor obj_mask, noobj_mask; - Tensor tx, ty, tw, th, tconf, tclass; - obj_mask.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - noobj_mask.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + Tensor tx, ty, tw, th, tweight, tconf, tclass; + obj_mask.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + noobj_mask.mutable_data({n, an_num, h, w}, ctx.GetPlace()); tx.mutable_data({n, an_num, h, w}, ctx.GetPlace()); ty.mutable_data({n, an_num, h, w}, ctx.GetPlace()); tw.mutable_data({n, an_num, h, w}, ctx.GetPlace()); th.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + tweight.mutable_data({n, an_num, h, w}, ctx.GetPlace()); tconf.mutable_data({n, an_num, h, w}, ctx.GetPlace()); tclass.mutable_data({n, an_num, h, w, class_num}, ctx.GetPlace()); - PreProcessGTBox(*gt_box, *gt_label, ignore_thresh, anchors, h, &obj_mask, - &noobj_mask, &tx, &ty, &tw, &th, &tconf, &tclass); + PreProcessGTBox(*gt_box, *gt_label, ignore_thresh, anchors, input_size, + h, &obj_mask, &noobj_mask, &tx, &ty, &tw, &th, &tweight, + &tconf, &tclass); + + Tensor obj_weight; + obj_weight.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + auto obj_weight_t = EigenTensor::From(obj_weight); + auto obj_mask_t = EigenTensor::From(obj_mask); + auto tweight_t = EigenTensor::From(tweight); + obj_weight_t = obj_mask_t * tweight_t; Tensor obj_mask_expand; - obj_mask_expand.mutable_data({n, an_num, h, w, class_num}, - ctx.GetPlace()); - ExpandObjMaskByClassNum(&obj_mask_expand, obj_mask); - - T loss_x = CalcMSEWithMask(pred_x, tx, obj_mask); - T loss_y = CalcMSEWithMask(pred_y, ty, obj_mask); - T loss_w = CalcMSEWithMask(pred_w, tw, obj_mask); - T loss_h = CalcMSEWithMask(pred_h, th, obj_mask); - T loss_conf_target = CalcBCEWithMask(pred_conf, tconf, obj_mask); - T loss_conf_notarget = CalcBCEWithMask(pred_conf, tconf, noobj_mask); - T loss_class = CalcBCEWithMask(pred_class, tclass, obj_mask_expand); + obj_mask_expand.mutable_data({n, an_num, h, w, class_num}, + ctx.GetPlace()); + auto obj_mask_expand_t = EigenTensor::From(obj_mask_expand); + obj_mask_expand_t = obj_mask_t.reshape(Array5(n, an_num, h, w, 1)) + .broadcast(Array5(1, 1, 1, 1, class_num)); + + T box_f = static_cast(an_num * h * w); + T class_f = static_cast(an_num * h * w * class_num); + T loss_x = CalcSCEWithWeight(pred_x, tx, obj_weight, box_f); + T loss_y = CalcSCEWithWeight(pred_y, ty, obj_weight, box_f); + T loss_w = CalcMSEWithWeight(pred_w, tw, obj_weight, box_f); + T loss_h = CalcMSEWithWeight(pred_h, th, obj_weight, box_f); + T loss_conf_target = + CalcSCEWithWeight(pred_conf, tconf, obj_mask, box_f); + T loss_conf_notarget = + CalcSCEWithWeight(pred_conf, tconf, noobj_mask, box_f); + T loss_class = + CalcSCEWithWeight(pred_class, tclass, obj_mask_expand, class_f); auto* loss_data = loss->mutable_data({1}, ctx.GetPlace()); loss_data[0] = loss_weight_xy * (loss_x + loss_y) + @@ -405,6 +423,7 @@ class Yolov3LossGradKernel : public framework::OpKernel { auto* input_grad = ctx.Output(framework::GradVarName("X")); auto* output_grad = ctx.Input(framework::GradVarName("Loss")); const T loss = output_grad->data()[0]; + int input_size = ctx.Attr("input_size"); float loss_weight_xy = ctx.Attr("loss_weight_xy"); float loss_weight_wh = ctx.Attr("loss_weight_wh"); float loss_weight_conf_target = ctx.Attr("loss_weight_conf_target"); @@ -430,22 +449,33 @@ class Yolov3LossGradKernel : public framework::OpKernel { &pred_w, &pred_h, an_num, class_num); Tensor obj_mask, noobj_mask; - Tensor tx, ty, tw, th, tconf, tclass; - obj_mask.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - noobj_mask.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + Tensor tx, ty, tw, th, tweight, tconf, tclass; + obj_mask.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + noobj_mask.mutable_data({n, an_num, h, w}, ctx.GetPlace()); tx.mutable_data({n, an_num, h, w}, ctx.GetPlace()); ty.mutable_data({n, an_num, h, w}, ctx.GetPlace()); tw.mutable_data({n, an_num, h, w}, ctx.GetPlace()); th.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + tweight.mutable_data({n, an_num, h, w}, ctx.GetPlace()); tconf.mutable_data({n, an_num, h, w}, ctx.GetPlace()); tclass.mutable_data({n, an_num, h, w, class_num}, ctx.GetPlace()); - PreProcessGTBox(*gt_box, *gt_label, ignore_thresh, anchors, h, &obj_mask, - &noobj_mask, &tx, &ty, &tw, &th, &tconf, &tclass); + PreProcessGTBox(*gt_box, *gt_label, ignore_thresh, anchors, input_size, + h, &obj_mask, &noobj_mask, &tx, &ty, &tw, &th, &tweight, + &tconf, &tclass); + + Tensor obj_weight; + obj_weight.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + auto obj_weight_t = EigenTensor::From(obj_weight); + auto obj_mask_t = EigenTensor::From(obj_mask); + auto tweight_t = EigenTensor::From(tweight); + obj_weight_t = obj_mask_t * tweight_t; Tensor obj_mask_expand; - obj_mask_expand.mutable_data({n, an_num, h, w, class_num}, - ctx.GetPlace()); - ExpandObjMaskByClassNum(&obj_mask_expand, obj_mask); + obj_mask_expand.mutable_data({n, an_num, h, w, class_num}, + ctx.GetPlace()); + auto obj_mask_expand_t = EigenTensor::From(obj_mask_expand); + obj_mask_expand_t = obj_mask_t.reshape(Array5(n, an_num, h, w, 1)) + .broadcast(Array5(1, 1, 1, 1, class_num)); Tensor grad_x, grad_y, grad_w, grad_h; Tensor grad_conf_target, grad_conf_notarget, grad_class; @@ -456,19 +486,18 @@ class Yolov3LossGradKernel : public framework::OpKernel { grad_conf_target.mutable_data({n, an_num, h, w}, ctx.GetPlace()); grad_conf_notarget.mutable_data({n, an_num, h, w}, ctx.GetPlace()); grad_class.mutable_data({n, an_num, h, w, class_num}, ctx.GetPlace()); - T obj_mf = CalcMaskPointNum(obj_mask); - T noobj_mf = CalcMaskPointNum(noobj_mask); - T obj_expand_mf = CalcMaskPointNum(obj_mask_expand); - CalcMSEGradWithMask(&grad_x, pred_x, tx, obj_mask, obj_mf); - CalcMSEGradWithMask(&grad_y, pred_y, ty, obj_mask, obj_mf); - CalcMSEGradWithMask(&grad_w, pred_w, tw, obj_mask, obj_mf); - CalcMSEGradWithMask(&grad_h, pred_h, th, obj_mask, obj_mf); - CalcBCEGradWithMask(&grad_conf_target, pred_conf, tconf, obj_mask, - obj_mf); - CalcBCEGradWithMask(&grad_conf_notarget, pred_conf, tconf, noobj_mask, - noobj_mf); - CalcBCEGradWithMask(&grad_class, pred_class, tclass, obj_mask_expand, - obj_expand_mf); + T box_f = static_cast(an_num * h * w); + T class_f = static_cast(an_num * h * w * class_num); + CalcSCEGradWithWeight(&grad_x, pred_x, tx, obj_weight, box_f); + CalcSCEGradWithWeight(&grad_y, pred_y, ty, obj_weight, box_f); + CalcMSEGradWithWeight(&grad_w, pred_w, tw, obj_weight, box_f); + CalcMSEGradWithWeight(&grad_h, pred_h, th, obj_weight, box_f); + CalcSCEGradWithWeight(&grad_conf_target, pred_conf, tconf, obj_mask, + box_f); + CalcSCEGradWithWeight(&grad_conf_notarget, pred_conf, tconf, noobj_mask, + box_f); + CalcSCEGradWithWeight(&grad_class, pred_class, tclass, obj_mask_expand, + class_f); input_grad->mutable_data({n, c, h, w}, ctx.GetPlace()); AddAllGradToInputGrad( diff --git a/python/paddle/fluid/layers/detection.py b/python/paddle/fluid/layers/detection.py index 7cf575d253..5fb4588e0b 100644 --- a/python/paddle/fluid/layers/detection.py +++ b/python/paddle/fluid/layers/detection.py @@ -415,6 +415,7 @@ def yolov3_loss(x, anchors, class_num, ignore_thresh, + input_size, loss_weight_xy=None, loss_weight_wh=None, loss_weight_conf_target=None, @@ -436,6 +437,7 @@ def yolov3_loss(x, anchors (list|tuple): ${anchors_comment} class_num (int): ${class_num_comment} ignore_thresh (float): ${ignore_thresh_comment} + input_size (int): ${input_size_comment} loss_weight_xy (float|None): ${loss_weight_xy_comment} loss_weight_wh (float|None): ${loss_weight_wh_comment} loss_weight_conf_target (float|None): ${loss_weight_conf_target_comment} @@ -490,6 +492,7 @@ def yolov3_loss(x, "anchors": anchors, "class_num": class_num, "ignore_thresh": ignore_thresh, + "input_size": input_size, } if loss_weight_xy is not None and isinstance(loss_weight_xy, float): diff --git a/python/paddle/fluid/tests/test_detection.py b/python/paddle/fluid/tests/test_detection.py index 8723d9842a..7d75562900 100644 --- a/python/paddle/fluid/tests/test_detection.py +++ b/python/paddle/fluid/tests/test_detection.py @@ -464,7 +464,7 @@ class TestYoloDetection(unittest.TestCase): gtbox = layers.data(name='gtbox', shape=[10, 4], dtype='float32') gtlabel = layers.data(name='gtlabel', shape=[10], dtype='int32') loss = layers.yolov3_loss(x, gtbox, gtlabel, [10, 13, 30, 13], 10, - 0.5) + 0.7, 416) self.assertIsNotNone(loss) diff --git a/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py b/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py index 544fe4b4f8..07e7155bbf 100644 --- a/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py +++ b/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py @@ -16,31 +16,22 @@ from __future__ import division import unittest import numpy as np +from scipy.special import logit +from scipy.special import expit from op_test import OpTest from paddle.fluid import core -def sigmoid(x): - return 1.0 / (1.0 + np.exp(-1.0 * x)) +def mse(x, y, weight, num): + return ((y - x)**2 * weight).sum() / num -def mse(x, y, num): - return ((y - x)**2).sum() / num - - -def bce(x, y, mask): - x = x.reshape((-1)) - y = y.reshape((-1)) - mask = mask.reshape((-1)) - - error_sum = 0.0 - count = 0 - for i in range(x.shape[0]): - if mask[i] > 0: - error_sum += y[i] * np.log(x[i]) + (1 - y[i]) * np.log(1 - x[i]) - count += 1 - return error_sum / (-1.0 * count) +def sce(x, label, weight, num): + sigmoid_x = expit(x) + term1 = label * np.log(sigmoid_x) + term2 = (1.0 - label) * np.log(1.0 - sigmoid_x) + return ((-term1 - term2) * weight).sum() / num def box_iou(box1, box2): @@ -66,11 +57,12 @@ def box_iou(box1, box2): return inter_area / (b1_area + b2_area + inter_area) -def build_target(gtboxs, gtlabel, attrs, grid_size): - n, b, _ = gtboxs.shape +def build_target(gtboxes, gtlabel, attrs, grid_size): + n, b, _ = gtboxes.shape ignore_thresh = attrs["ignore_thresh"] anchors = attrs["anchors"] class_num = attrs["class_num"] + input_size = attrs["input_size"] an_num = len(anchors) // 2 obj_mask = np.zeros((n, an_num, grid_size, grid_size)).astype('float32') noobj_mask = np.ones((n, an_num, grid_size, grid_size)).astype('float32') @@ -78,20 +70,21 @@ def build_target(gtboxs, gtlabel, attrs, grid_size): ty = np.zeros((n, an_num, grid_size, grid_size)).astype('float32') tw = np.zeros((n, an_num, grid_size, grid_size)).astype('float32') th = np.zeros((n, an_num, grid_size, grid_size)).astype('float32') + tweight = np.zeros((n, an_num, grid_size, grid_size)).astype('float32') tconf = np.zeros((n, an_num, grid_size, grid_size)).astype('float32') tcls = np.zeros( (n, an_num, grid_size, grid_size, class_num)).astype('float32') for i in range(n): for j in range(b): - if gtboxs[i, j, :].sum() == 0: + if gtboxes[i, j, :].sum() == 0: continue gt_label = gtlabel[i, j] - gx = gtboxs[i, j, 0] * grid_size - gy = gtboxs[i, j, 1] * grid_size - gw = gtboxs[i, j, 2] * grid_size - gh = gtboxs[i, j, 3] * grid_size + gx = gtboxes[i, j, 0] * grid_size + gy = gtboxes[i, j, 1] * grid_size + gw = gtboxes[i, j, 2] * input_size + gh = gtboxes[i, j, 3] * input_size gi = int(gx) gj = int(gy) @@ -115,10 +108,12 @@ def build_target(gtboxs, gtlabel, attrs, grid_size): best_an_index]) th[i, best_an_index, gj, gi] = np.log( gh / anchors[2 * best_an_index + 1]) + tweight[i, best_an_index, gj, gi] = 2.0 - gtboxes[ + i, j, 2] * gtboxes[i, j, 3] tconf[i, best_an_index, gj, gi] = 1 tcls[i, best_an_index, gj, gi, gt_label] = 1 - return (tx, ty, tw, th, tconf, tcls, obj_mask, noobj_mask) + return (tx, ty, tw, th, tweight, tconf, tcls, obj_mask, noobj_mask) def YoloV3Loss(x, gtbox, gtlabel, attrs): @@ -126,27 +121,28 @@ def YoloV3Loss(x, gtbox, gtlabel, attrs): an_num = len(attrs['anchors']) // 2 class_num = attrs["class_num"] x = x.reshape((n, an_num, 5 + class_num, h, w)).transpose((0, 1, 3, 4, 2)) - pred_x = sigmoid(x[:, :, :, :, 0]) - pred_y = sigmoid(x[:, :, :, :, 1]) + pred_x = x[:, :, :, :, 0] + pred_y = x[:, :, :, :, 1] pred_w = x[:, :, :, :, 2] pred_h = x[:, :, :, :, 3] - pred_conf = sigmoid(x[:, :, :, :, 4]) - pred_cls = sigmoid(x[:, :, :, :, 5:]) + pred_conf = x[:, :, :, :, 4] + pred_cls = x[:, :, :, :, 5:] - tx, ty, tw, th, tconf, tcls, obj_mask, noobj_mask = build_target( + tx, ty, tw, th, tweight, tconf, tcls, obj_mask, noobj_mask = build_target( gtbox, gtlabel, attrs, x.shape[2]) + obj_weight = obj_mask * tweight obj_mask_expand = np.tile( np.expand_dims(obj_mask, 4), (1, 1, 1, 1, int(attrs['class_num']))) - loss_x = mse(pred_x * obj_mask, tx * obj_mask, obj_mask.sum()) - loss_y = mse(pred_y * obj_mask, ty * obj_mask, obj_mask.sum()) - loss_w = mse(pred_w * obj_mask, tw * obj_mask, obj_mask.sum()) - loss_h = mse(pred_h * obj_mask, th * obj_mask, obj_mask.sum()) - loss_conf_target = bce(pred_conf * obj_mask, tconf * obj_mask, obj_mask) - loss_conf_notarget = bce(pred_conf * noobj_mask, tconf * noobj_mask, - noobj_mask) - loss_class = bce(pred_cls * obj_mask_expand, tcls * obj_mask_expand, - obj_mask_expand) + box_f = an_num * h * w + class_f = an_num * h * w * class_num + loss_x = sce(pred_x, tx, obj_weight, box_f) + loss_y = sce(pred_y, ty, obj_weight, box_f) + loss_w = mse(pred_w, tw, obj_weight, box_f) + loss_h = mse(pred_h, th, obj_weight, box_f) + loss_conf_target = sce(pred_conf, tconf, obj_mask, box_f) + loss_conf_notarget = sce(pred_conf, tconf, noobj_mask, box_f) + loss_class = sce(pred_cls, tcls, obj_mask_expand, class_f) return attrs['loss_weight_xy'] * (loss_x + loss_y) \ + attrs['loss_weight_wh'] * (loss_w + loss_h) \ @@ -164,7 +160,7 @@ class TestYolov3LossOp(OpTest): self.loss_weight_class = 1.0 self.initTestCase() self.op_type = 'yolov3_loss' - x = np.random.random(size=self.x_shape).astype('float32') + x = logit(np.random.uniform(0, 1, self.x_shape).astype('float32')) gtbox = np.random.random(size=self.gtbox_shape).astype('float32') gtlabel = np.random.randint(0, self.class_num, self.gtbox_shape[:2]).astype('int32') @@ -173,6 +169,7 @@ class TestYolov3LossOp(OpTest): "anchors": self.anchors, "class_num": self.class_num, "ignore_thresh": self.ignore_thresh, + "input_size": self.input_size, "loss_weight_xy": self.loss_weight_xy, "loss_weight_wh": self.loss_weight_wh, "loss_weight_conf_target": self.loss_weight_conf_target, @@ -196,18 +193,19 @@ class TestYolov3LossOp(OpTest): place, ['X'], 'Loss', no_grad_set=set(["GTBox", "GTLabel"]), - max_relative_error=0.06) + max_relative_error=0.3) def initTestCase(self): self.anchors = [10, 13, 12, 12] self.class_num = 10 - self.ignore_thresh = 0.5 + self.ignore_thresh = 0.7 + self.input_size = 416 self.x_shape = (5, len(self.anchors) // 2 * (5 + self.class_num), 7, 7) self.gtbox_shape = (5, 10, 4) - self.loss_weight_xy = 2.5 + self.loss_weight_xy = 1.4 self.loss_weight_wh = 0.8 - self.loss_weight_conf_target = 1.5 - self.loss_weight_conf_notarget = 0.5 + self.loss_weight_conf_target = 1.1 + self.loss_weight_conf_notarget = 0.9 self.loss_weight_class = 1.2 From 3841983aa01dbb633e1d40b84f046ddfbf41beb8 Mon Sep 17 00:00:00 2001 From: dengkaipeng Date: Fri, 7 Dec 2018 11:44:50 +0800 Subject: [PATCH 21/53] fix division error in mean process. test=develop --- paddle/fluid/operators/yolov3_loss_op.cc | 4 +- paddle/fluid/operators/yolov3_loss_op.h | 263 ++++++++---------- .../paddle/fluid/tests/unittests/op_test.py | 2 + .../tests/unittests/test_yolov3_loss_op.py | 69 +++-- 4 files changed, 166 insertions(+), 172 deletions(-) diff --git a/paddle/fluid/operators/yolov3_loss_op.cc b/paddle/fluid/operators/yolov3_loss_op.cc index 66d618de59..c76767dfdd 100644 --- a/paddle/fluid/operators/yolov3_loss_op.cc +++ b/paddle/fluid/operators/yolov3_loss_op.cc @@ -57,7 +57,7 @@ class Yolov3LossOp : public framework::OperatorWithKernel { PADDLE_ENFORCE_GT(class_num, 0, "Attr(class_num) should be an integer greater then 0."); - std::vector dim_out({1}); + std::vector dim_out({dim_x[0]}); ctx->SetOutputDim("Loss", framework::make_ddim(dim_out)); } @@ -93,7 +93,7 @@ class Yolov3LossOpMaker : public framework::OpProtoAndCheckerMaker { "box class id."); AddOutput("Loss", "The output yolov3 loss tensor, " - "This is a 1-D tensor with shape of [1]"); + "This is a 1-D tensor with shape of [N]"); AddAttr("class_num", "The number of classes to predict."); AddAttr>("anchors", diff --git a/paddle/fluid/operators/yolov3_loss_op.h b/paddle/fluid/operators/yolov3_loss_op.h index fac06b4204..837ea15601 100644 --- a/paddle/fluid/operators/yolov3_loss_op.h +++ b/paddle/fluid/operators/yolov3_loss_op.h @@ -33,99 +33,102 @@ static inline bool isZero(T x) { } template -static inline T CalcMSEWithWeight(const Tensor& x, const Tensor& y, - const Tensor& weight, const T mf) { - int numel = static_cast(x.numel()); +static inline void CalcMSEWithWeight(const Tensor& x, const Tensor& y, + const Tensor& weight, const T loss_weight, + T* loss) { + int n = x.dims()[0]; + int stride = x.numel() / n; const T* x_data = x.data(); const T* y_data = y.data(); const T* weight_data = weight.data(); - T error_sum = 0.0; - for (int i = 0; i < numel; i++) { - T xi = x_data[i]; - T yi = y_data[i]; - T weighti = weight_data[i]; - error_sum += pow(yi - xi, 2) * weighti; + for (int i = 0; i < n; i++) { + for (int j = 0; j < stride; j++) { + loss[i] += pow(y_data[j] - x_data[j], 2) * weight_data[j] * loss_weight; + } + x_data += stride; + y_data += stride; + weight_data += stride; } - - return error_sum / mf; } template -static void CalcMSEGradWithWeight(Tensor* grad, const Tensor& x, - const Tensor& y, const Tensor& weight, - const T mf) { - int numel = static_cast(grad->numel()); +static void CalcMSEGradWithWeight(const T* loss_grad, Tensor* grad, + const Tensor& x, const Tensor& y, + const Tensor& weight) { + int n = x.dims()[0]; + int stride = x.numel() / n; T* grad_data = grad->data(); const T* x_data = x.data(); const T* y_data = y.data(); const T* weight_data = weight.data(); - for (int i = 0; i < numel; i++) { - grad_data[i] = 2.0 * weight_data[i] * (x_data[i] - y_data[i]) / mf; + for (int i = 0; i < n; i++) { + for (int j = 0; j < stride; j++) { + grad_data[j] = + 2.0 * weight_data[j] * (x_data[j] - y_data[j]) * loss_grad[i]; + } + grad_data += stride; + x_data += stride; + y_data += stride; + weight_data += stride; } } template -struct SigmoidCrossEntropyForward { - T operator()(const T& x, const T& label) const { - T term1 = (x > 0) ? x : 0; - T term2 = x * label; - T term3 = std::log(static_cast(1.0) + std::exp(-(std::abs(x)))); - return term1 - term2 + term3; - } -}; - -template -struct SigmoidCrossEntropyBackward { - T operator()(const T& x, const T& label) const { - T sigmoid_x = - static_cast(1.0) / (static_cast(1.0) + std::exp(-1.0 * x)); - return sigmoid_x - label; - } -}; - -template -static inline T CalcSCEWithWeight(const Tensor& x, const Tensor& labels, - const Tensor& weight, const T mf) { - int numel = x.numel(); +static inline void CalcSCEWithWeight(const Tensor& x, const Tensor& label, + const Tensor& weight, const T loss_weight, + T* loss) { + int n = x.dims()[0]; + int stride = x.numel() / n; const T* x_data = x.data(); - const T* labels_data = labels.data(); + const T* label_data = label.data(); const T* weight_data = weight.data(); - T loss = 0.0; - for (int i = 0; i < numel; i++) { - T xi = x_data[i]; - T labeli = labels_data[i]; - T weighti = weight_data[i]; - loss += ((xi > 0.0 ? xi : 0.0) - xi * labeli + - std::log(1.0 + std::exp(-1.0 * std::abs(xi)))) * - weighti; + for (int i = 0; i < n; i++) { + for (int j = 0; j < stride; j++) { + T term1 = (x_data[j] > 0) ? x_data[j] : 0; + T term2 = x_data[j] * label_data[j]; + T term3 = std::log(1.0 + std::exp(-std::abs(x_data[j]))); + loss[i] += (term1 - term2 + term3) * weight_data[j] * loss_weight; + } + x_data += stride; + label_data += stride; + weight_data += stride; } - return loss / mf; } template -static inline void CalcSCEGradWithWeight(Tensor* grad, const Tensor& x, - const Tensor& labels, - const Tensor& weight, const T mf) { - int numel = grad->numel(); +static inline void CalcSCEGradWithWeight(const T* loss_grad, Tensor* grad, + const Tensor& x, const Tensor& label, + const Tensor& weight) { + int n = x.dims()[0]; + int stride = x.numel() / n; T* grad_data = grad->data(); const T* x_data = x.data(); - const T* labels_data = labels.data(); + const T* label_data = label.data(); const T* weight_data = weight.data(); - for (int i = 0; i < numel; i++) { - grad_data[i] = (1.0 / (1.0 + std::exp(-1.0 * x_data[i])) - labels_data[i]) * - weight_data[i] / mf; + // LOG(ERROR) << "SCE grad start"; + for (int i = 0; i < n; i++) { + for (int j = 0; j < stride; j++) { + grad_data[j] = (1.0 / (1.0 + std::exp(-x_data[j])) - label_data[j]) * + weight_data[j] * loss_grad[i]; + // if (j == 18) LOG(ERROR) << x_data[j] << " " << label_data[j] << " " << + // weight_data[j] << " " << loss_grad[i]; + } + grad_data += stride; + x_data += stride; + label_data += stride; + weight_data += stride; } } template -static void CalcPredResult(const Tensor& input, Tensor* pred_conf, - Tensor* pred_class, Tensor* pred_x, Tensor* pred_y, - Tensor* pred_w, Tensor* pred_h, const int anchor_num, - const int class_num) { +static void SplitPredResult(const Tensor& input, Tensor* pred_conf, + Tensor* pred_class, Tensor* pred_x, Tensor* pred_y, + Tensor* pred_w, Tensor* pred_h, + const int anchor_num, const int class_num) { const int n = input.dims()[0]; const int h = input.dims()[2]; const int w = input.dims()[3]; @@ -255,39 +258,20 @@ static void PreProcessGTBox(const Tensor& gt_box, const Tensor& gt_label, } } -static void ExpandObjMaskByClassNum(Tensor* obj_mask_expand, - const Tensor& obj_mask) { - const int n = obj_mask_expand->dims()[0]; - const int an_num = obj_mask_expand->dims()[1]; - const int h = obj_mask_expand->dims()[2]; - const int w = obj_mask_expand->dims()[3]; - const int class_num = obj_mask_expand->dims()[4]; - auto obj_mask_expand_t = EigenTensor::From(*obj_mask_expand); - auto obj_mask_t = EigenTensor::From(obj_mask); - - obj_mask_expand_t = obj_mask_t.reshape(Array5(n, an_num, h, w, 1)) - .broadcast(Array5(1, 1, 1, 1, class_num)); -} - template static void AddAllGradToInputGrad( - Tensor* grad, T loss, const Tensor& pred_x, const Tensor& pred_y, - const Tensor& pred_conf, const Tensor& pred_class, const Tensor& grad_x, - const Tensor& grad_y, const Tensor& grad_w, const Tensor& grad_h, - const Tensor& grad_conf_target, const Tensor& grad_conf_notarget, - const Tensor& grad_class, const int class_num, const float loss_weight_xy, - const float loss_weight_wh, const float loss_weight_conf_target, - const float loss_weight_conf_notarget, const float loss_weight_class) { - const int n = pred_x.dims()[0]; - const int an_num = pred_x.dims()[1]; - const int h = pred_x.dims()[2]; - const int w = pred_x.dims()[3]; + Tensor* grad, const Tensor& grad_x, const Tensor& grad_y, + const Tensor& grad_w, const Tensor& grad_h, const Tensor& grad_conf_target, + const Tensor& grad_conf_notarget, const Tensor& grad_class, + const int class_num, const float loss_weight_xy, const float loss_weight_wh, + const float loss_weight_conf_target, const float loss_weight_conf_notarget, + const float loss_weight_class) { + const int n = grad_x.dims()[0]; + const int an_num = grad_x.dims()[1]; + const int h = grad_x.dims()[2]; + const int w = grad_x.dims()[3]; const int attr_num = class_num + 5; auto grad_t = EigenTensor::From(*grad).setConstant(0.0); - auto pred_x_t = EigenTensor::From(pred_x); - auto pred_y_t = EigenTensor::From(pred_y); - auto pred_conf_t = EigenTensor::From(pred_conf); - auto pred_class_t = EigenTensor::From(pred_class); auto grad_x_t = EigenTensor::From(grad_x); auto grad_y_t = EigenTensor::From(grad_y); auto grad_w_t = EigenTensor::From(grad_w); @@ -300,23 +284,21 @@ static void AddAllGradToInputGrad( for (int j = 0; j < an_num; j++) { for (int k = 0; k < h; k++) { for (int l = 0; l < w; l++) { - grad_t(i, j * attr_num, k, l) = - grad_x_t(i, j, k, l) * loss * loss_weight_xy; + grad_t(i, j * attr_num, k, l) = grad_x_t(i, j, k, l) * loss_weight_xy; grad_t(i, j * attr_num + 1, k, l) = - grad_y_t(i, j, k, l) * loss * loss_weight_xy; + grad_y_t(i, j, k, l) * loss_weight_xy; grad_t(i, j * attr_num + 2, k, l) = - grad_w_t(i, j, k, l) * loss * loss_weight_wh; + grad_w_t(i, j, k, l) * loss_weight_wh; grad_t(i, j * attr_num + 3, k, l) = - grad_h_t(i, j, k, l) * loss * loss_weight_wh; + grad_h_t(i, j, k, l) * loss_weight_wh; grad_t(i, j * attr_num + 4, k, l) = - grad_conf_target_t(i, j, k, l) * loss * loss_weight_conf_target; + grad_conf_target_t(i, j, k, l) * loss_weight_conf_target; grad_t(i, j * attr_num + 4, k, l) += - grad_conf_notarget_t(i, j, k, l) * loss * - loss_weight_conf_notarget; + grad_conf_notarget_t(i, j, k, l) * loss_weight_conf_notarget; for (int c = 0; c < class_num; c++) { grad_t(i, j * attr_num + 5 + c, k, l) = - grad_class_t(i, j, k, l, c) * loss * loss_weight_class; + grad_class_t(i, j, k, l, c) * loss_weight_class; } } } @@ -356,8 +338,8 @@ class Yolov3LossKernel : public framework::OpKernel { pred_h.mutable_data({n, an_num, h, w}, ctx.GetPlace()); pred_conf.mutable_data({n, an_num, h, w}, ctx.GetPlace()); pred_class.mutable_data({n, an_num, h, w, class_num}, ctx.GetPlace()); - CalcPredResult(*input, &pred_conf, &pred_class, &pred_x, &pred_y, - &pred_w, &pred_h, an_num, class_num); + SplitPredResult(*input, &pred_conf, &pred_class, &pred_x, &pred_y, + &pred_w, &pred_h, an_num, class_num); Tensor obj_mask, noobj_mask; Tensor tx, ty, tw, th, tweight, tconf, tclass; @@ -388,25 +370,24 @@ class Yolov3LossKernel : public framework::OpKernel { obj_mask_expand_t = obj_mask_t.reshape(Array5(n, an_num, h, w, 1)) .broadcast(Array5(1, 1, 1, 1, class_num)); - T box_f = static_cast(an_num * h * w); - T class_f = static_cast(an_num * h * w * class_num); - T loss_x = CalcSCEWithWeight(pred_x, tx, obj_weight, box_f); - T loss_y = CalcSCEWithWeight(pred_y, ty, obj_weight, box_f); - T loss_w = CalcMSEWithWeight(pred_w, tw, obj_weight, box_f); - T loss_h = CalcMSEWithWeight(pred_h, th, obj_weight, box_f); - T loss_conf_target = - CalcSCEWithWeight(pred_conf, tconf, obj_mask, box_f); - T loss_conf_notarget = - CalcSCEWithWeight(pred_conf, tconf, noobj_mask, box_f); - T loss_class = - CalcSCEWithWeight(pred_class, tclass, obj_mask_expand, class_f); - - auto* loss_data = loss->mutable_data({1}, ctx.GetPlace()); - loss_data[0] = loss_weight_xy * (loss_x + loss_y) + - loss_weight_wh * (loss_w + loss_h) + - loss_weight_conf_target * loss_conf_target + - loss_weight_conf_notarget * loss_conf_notarget + - loss_weight_class * loss_class; + T* loss_data = loss->mutable_data({n}, ctx.GetPlace()); + memset(loss_data, 0, n * sizeof(T)); + CalcSCEWithWeight(pred_x, tx, obj_weight, loss_weight_xy, loss_data); + CalcSCEWithWeight(pred_y, ty, obj_weight, loss_weight_xy, loss_data); + CalcMSEWithWeight(pred_w, tw, obj_weight, loss_weight_wh, loss_data); + CalcMSEWithWeight(pred_h, th, obj_weight, loss_weight_wh, loss_data); + CalcSCEWithWeight(pred_conf, tconf, obj_mask, loss_weight_conf_target, + loss_data); + CalcSCEWithWeight(pred_conf, tconf, noobj_mask, + loss_weight_conf_notarget, loss_data); + CalcSCEWithWeight(pred_class, tclass, obj_mask_expand, loss_weight_class, + loss_data); + + // loss_data[0] = (loss_weight_xy * (loss_x + loss_y) + + // loss_weight_wh * (loss_w + loss_h) + + // loss_weight_conf_target * loss_conf_target + + // loss_weight_conf_notarget * loss_conf_notarget + + // loss_weight_class * loss_class) / n; } }; @@ -421,8 +402,8 @@ class Yolov3LossGradKernel : public framework::OpKernel { int class_num = ctx.Attr("class_num"); float ignore_thresh = ctx.Attr("ignore_thresh"); auto* input_grad = ctx.Output(framework::GradVarName("X")); - auto* output_grad = ctx.Input(framework::GradVarName("Loss")); - const T loss = output_grad->data()[0]; + auto* loss_grad = ctx.Input(framework::GradVarName("Loss")); + const T* loss_grad_data = loss_grad->data(); int input_size = ctx.Attr("input_size"); float loss_weight_xy = ctx.Attr("loss_weight_xy"); float loss_weight_wh = ctx.Attr("loss_weight_wh"); @@ -445,8 +426,8 @@ class Yolov3LossGradKernel : public framework::OpKernel { pred_h.mutable_data({n, an_num, h, w}, ctx.GetPlace()); pred_conf.mutable_data({n, an_num, h, w}, ctx.GetPlace()); pred_class.mutable_data({n, an_num, h, w, class_num}, ctx.GetPlace()); - CalcPredResult(*input, &pred_conf, &pred_class, &pred_x, &pred_y, - &pred_w, &pred_h, an_num, class_num); + SplitPredResult(*input, &pred_conf, &pred_class, &pred_x, &pred_y, + &pred_w, &pred_h, an_num, class_num); Tensor obj_mask, noobj_mask; Tensor tx, ty, tw, th, tweight, tconf, tclass; @@ -470,6 +451,8 @@ class Yolov3LossGradKernel : public framework::OpKernel { auto tweight_t = EigenTensor::From(tweight); obj_weight_t = obj_mask_t * tweight_t; + // LOG(ERROR) << obj_mask_t; + Tensor obj_mask_expand; obj_mask_expand.mutable_data({n, an_num, h, w, class_num}, ctx.GetPlace()); @@ -486,25 +469,23 @@ class Yolov3LossGradKernel : public framework::OpKernel { grad_conf_target.mutable_data({n, an_num, h, w}, ctx.GetPlace()); grad_conf_notarget.mutable_data({n, an_num, h, w}, ctx.GetPlace()); grad_class.mutable_data({n, an_num, h, w, class_num}, ctx.GetPlace()); - T box_f = static_cast(an_num * h * w); - T class_f = static_cast(an_num * h * w * class_num); - CalcSCEGradWithWeight(&grad_x, pred_x, tx, obj_weight, box_f); - CalcSCEGradWithWeight(&grad_y, pred_y, ty, obj_weight, box_f); - CalcMSEGradWithWeight(&grad_w, pred_w, tw, obj_weight, box_f); - CalcMSEGradWithWeight(&grad_h, pred_h, th, obj_weight, box_f); - CalcSCEGradWithWeight(&grad_conf_target, pred_conf, tconf, obj_mask, - box_f); - CalcSCEGradWithWeight(&grad_conf_notarget, pred_conf, tconf, noobj_mask, - box_f); - CalcSCEGradWithWeight(&grad_class, pred_class, tclass, obj_mask_expand, - class_f); + CalcSCEGradWithWeight(loss_grad_data, &grad_x, pred_x, tx, obj_weight); + CalcSCEGradWithWeight(loss_grad_data, &grad_y, pred_y, ty, obj_weight); + CalcMSEGradWithWeight(loss_grad_data, &grad_w, pred_w, tw, obj_weight); + CalcMSEGradWithWeight(loss_grad_data, &grad_h, pred_h, th, obj_weight); + CalcSCEGradWithWeight(loss_grad_data, &grad_conf_target, pred_conf, + tconf, obj_mask); + CalcSCEGradWithWeight(loss_grad_data, &grad_conf_notarget, pred_conf, + tconf, noobj_mask); + CalcSCEGradWithWeight(loss_grad_data, &grad_class, pred_class, tclass, + obj_mask_expand); input_grad->mutable_data({n, c, h, w}, ctx.GetPlace()); - AddAllGradToInputGrad( - input_grad, loss, pred_x, pred_y, pred_conf, pred_class, grad_x, grad_y, - grad_w, grad_h, grad_conf_target, grad_conf_notarget, grad_class, - class_num, loss_weight_xy, loss_weight_wh, loss_weight_conf_target, - loss_weight_conf_notarget, loss_weight_class); + AddAllGradToInputGrad(input_grad, grad_x, grad_y, grad_w, grad_h, + grad_conf_target, grad_conf_notarget, grad_class, + class_num, loss_weight_xy, loss_weight_wh, + loss_weight_conf_target, loss_weight_conf_notarget, + loss_weight_class); } }; diff --git a/python/paddle/fluid/tests/unittests/op_test.py b/python/paddle/fluid/tests/unittests/op_test.py index 0fe836683b..9cf398f18f 100644 --- a/python/paddle/fluid/tests/unittests/op_test.py +++ b/python/paddle/fluid/tests/unittests/op_test.py @@ -470,6 +470,8 @@ class OpTest(unittest.TestCase): ] analytic_grads = self._get_gradient(inputs_to_check, place, output_names, no_grad_set) + # print(numeric_grads[0][0, 4, :, :]) + # print(analytic_grads[0][0, 4, :, :]) self._assert_is_close(numeric_grads, analytic_grads, inputs_to_check, max_relative_error, diff --git a/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py b/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py index 07e7155bbf..26367f213b 100644 --- a/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py +++ b/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py @@ -23,15 +23,23 @@ from op_test import OpTest from paddle.fluid import core -def mse(x, y, weight, num): - return ((y - x)**2 * weight).sum() / num - - -def sce(x, label, weight, num): +def mse(x, y, weight): + n = x.shape[0] + x = x.reshape((n, -1)) + y = y.reshape((n, -1)) + weight = weight.reshape((n, -1)) + return ((y - x)**2 * weight).sum(axis=1) + + +def sce(x, label, weight): + n = x.shape[0] + x = x.reshape((n, -1)) + label = label.reshape((n, -1)) + weight = weight.reshape((n, -1)) sigmoid_x = expit(x) term1 = label * np.log(sigmoid_x) term2 = (1.0 - label) * np.log(1.0 - sigmoid_x) - return ((-term1 - term2) * weight).sum() / num + return ((-term1 - term2) * weight).sum(axis=1) def box_iou(box1, box2): @@ -131,18 +139,24 @@ def YoloV3Loss(x, gtbox, gtlabel, attrs): tx, ty, tw, th, tweight, tconf, tcls, obj_mask, noobj_mask = build_target( gtbox, gtlabel, attrs, x.shape[2]) + # print("obj_mask: ", obj_mask[0, 0, :, :]) + # print("noobj_mask: ", noobj_mask[0, 0, :, :]) obj_weight = obj_mask * tweight obj_mask_expand = np.tile( np.expand_dims(obj_mask, 4), (1, 1, 1, 1, int(attrs['class_num']))) - box_f = an_num * h * w - class_f = an_num * h * w * class_num - loss_x = sce(pred_x, tx, obj_weight, box_f) - loss_y = sce(pred_y, ty, obj_weight, box_f) - loss_w = mse(pred_w, tw, obj_weight, box_f) - loss_h = mse(pred_h, th, obj_weight, box_f) - loss_conf_target = sce(pred_conf, tconf, obj_mask, box_f) - loss_conf_notarget = sce(pred_conf, tconf, noobj_mask, box_f) - loss_class = sce(pred_cls, tcls, obj_mask_expand, class_f) + loss_x = sce(pred_x, tx, obj_weight) + loss_y = sce(pred_y, ty, obj_weight) + loss_w = mse(pred_w, tw, obj_weight) + loss_h = mse(pred_h, th, obj_weight) + loss_conf_target = sce(pred_conf, tconf, obj_mask) + loss_conf_notarget = sce(pred_conf, tconf, noobj_mask) + loss_class = sce(pred_cls, tcls, obj_mask_expand) + + # print("loss_xy: ", loss_x + loss_y) + # print("loss_wh: ", loss_w + loss_h) + # print("loss_conf_target: ", loss_conf_target) + # print("loss_conf_notarget: ", loss_conf_notarget) + # print("loss_class: ", loss_class) return attrs['loss_weight_xy'] * (loss_x + loss_y) \ + attrs['loss_weight_wh'] * (loss_w + loss_h) \ @@ -178,10 +192,7 @@ class TestYolov3LossOp(OpTest): } self.inputs = {'X': x, 'GTBox': gtbox, 'GTLabel': gtlabel} - self.outputs = { - 'Loss': np.array( - [YoloV3Loss(x, gtbox, gtlabel, self.attrs)]).astype('float32') - } + self.outputs = {'Loss': YoloV3Loss(x, gtbox, gtlabel, self.attrs)} def test_check_output(self): place = core.CPUPlace() @@ -193,20 +204,20 @@ class TestYolov3LossOp(OpTest): place, ['X'], 'Loss', no_grad_set=set(["GTBox", "GTLabel"]), - max_relative_error=0.3) + max_relative_error=0.31) def initTestCase(self): - self.anchors = [10, 13, 12, 12] - self.class_num = 10 - self.ignore_thresh = 0.7 + self.anchors = [12, 12] + self.class_num = 5 + self.ignore_thresh = 0.3 self.input_size = 416 - self.x_shape = (5, len(self.anchors) // 2 * (5 + self.class_num), 7, 7) - self.gtbox_shape = (5, 10, 4) - self.loss_weight_xy = 1.4 + self.x_shape = (3, len(self.anchors) // 2 * (5 + self.class_num), 5, 5) + self.gtbox_shape = (3, 5, 4) + self.loss_weight_xy = 1.2 self.loss_weight_wh = 0.8 - self.loss_weight_conf_target = 1.1 - self.loss_weight_conf_notarget = 0.9 - self.loss_weight_class = 1.2 + self.loss_weight_conf_target = 2.0 + self.loss_weight_conf_notarget = 1.0 + self.loss_weight_class = 1.5 if __name__ == "__main__": From c0fa8d2eec4d6986c4b224a9183207160ea44107 Mon Sep 17 00:00:00 2001 From: dengkaipeng Date: Mon, 10 Dec 2018 20:14:57 +0800 Subject: [PATCH 22/53] use L1Loss for w, h. test=develop --- paddle/fluid/operators/yolov3_loss_op.h | 53 +++++++++++++++++-- .../tests/unittests/test_yolov3_loss_op.py | 12 ++++- 2 files changed, 59 insertions(+), 6 deletions(-) diff --git a/paddle/fluid/operators/yolov3_loss_op.h b/paddle/fluid/operators/yolov3_loss_op.h index 837ea15601..4661747261 100644 --- a/paddle/fluid/operators/yolov3_loss_op.h +++ b/paddle/fluid/operators/yolov3_loss_op.h @@ -32,6 +32,49 @@ static inline bool isZero(T x) { return fabs(x) < 1e-6; } +template +static inline void CalcL1LossWithWeight(const Tensor& x, const Tensor& y, + const Tensor& weight, + const T loss_weight, T* loss) { + int n = x.dims()[0]; + int stride = x.numel() / n; + const T* x_data = x.data(); + const T* y_data = y.data(); + const T* weight_data = weight.data(); + + for (int i = 0; i < n; i++) { + for (int j = 0; j < stride; j++) { + loss[i] += fabs(y_data[j] - x_data[j]) * weight_data[j] * loss_weight; + } + x_data += stride; + y_data += stride; + weight_data += stride; + } +} + +template +static void CalcL1LossGradWithWeight(const T* loss_grad, Tensor* grad, + const Tensor& x, const Tensor& y, + const Tensor& weight) { + int n = x.dims()[0]; + int stride = x.numel() / n; + T* grad_data = grad->data(); + const T* x_data = x.data(); + const T* y_data = y.data(); + const T* weight_data = weight.data(); + + for (int i = 0; i < n; i++) { + for (int j = 0; j < stride; j++) { + grad_data[j] = weight_data[j] * loss_grad[i]; + if (x_data[j] < y_data[j]) grad_data[j] *= -1.0; + } + grad_data += stride; + x_data += stride; + y_data += stride; + weight_data += stride; + } +} + template static inline void CalcMSEWithWeight(const Tensor& x, const Tensor& y, const Tensor& weight, const T loss_weight, @@ -374,8 +417,8 @@ class Yolov3LossKernel : public framework::OpKernel { memset(loss_data, 0, n * sizeof(T)); CalcSCEWithWeight(pred_x, tx, obj_weight, loss_weight_xy, loss_data); CalcSCEWithWeight(pred_y, ty, obj_weight, loss_weight_xy, loss_data); - CalcMSEWithWeight(pred_w, tw, obj_weight, loss_weight_wh, loss_data); - CalcMSEWithWeight(pred_h, th, obj_weight, loss_weight_wh, loss_data); + CalcL1LossWithWeight(pred_w, tw, obj_weight, loss_weight_wh, loss_data); + CalcL1LossWithWeight(pred_h, th, obj_weight, loss_weight_wh, loss_data); CalcSCEWithWeight(pred_conf, tconf, obj_mask, loss_weight_conf_target, loss_data); CalcSCEWithWeight(pred_conf, tconf, noobj_mask, @@ -471,8 +514,10 @@ class Yolov3LossGradKernel : public framework::OpKernel { grad_class.mutable_data({n, an_num, h, w, class_num}, ctx.GetPlace()); CalcSCEGradWithWeight(loss_grad_data, &grad_x, pred_x, tx, obj_weight); CalcSCEGradWithWeight(loss_grad_data, &grad_y, pred_y, ty, obj_weight); - CalcMSEGradWithWeight(loss_grad_data, &grad_w, pred_w, tw, obj_weight); - CalcMSEGradWithWeight(loss_grad_data, &grad_h, pred_h, th, obj_weight); + CalcL1LossGradWithWeight(loss_grad_data, &grad_w, pred_w, tw, + obj_weight); + CalcL1LossGradWithWeight(loss_grad_data, &grad_h, pred_h, th, + obj_weight); CalcSCEGradWithWeight(loss_grad_data, &grad_conf_target, pred_conf, tconf, obj_mask); CalcSCEGradWithWeight(loss_grad_data, &grad_conf_notarget, pred_conf, diff --git a/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py b/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py index 26367f213b..e218031286 100644 --- a/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py +++ b/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py @@ -23,6 +23,14 @@ from op_test import OpTest from paddle.fluid import core +def l1loss(x, y, weight): + n = x.shape[0] + x = x.reshape((n, -1)) + y = y.reshape((n, -1)) + weight = weight.reshape((n, -1)) + return (np.abs(y - x) * weight).sum(axis=1) + + def mse(x, y, weight): n = x.shape[0] x = x.reshape((n, -1)) @@ -146,8 +154,8 @@ def YoloV3Loss(x, gtbox, gtlabel, attrs): np.expand_dims(obj_mask, 4), (1, 1, 1, 1, int(attrs['class_num']))) loss_x = sce(pred_x, tx, obj_weight) loss_y = sce(pred_y, ty, obj_weight) - loss_w = mse(pred_w, tw, obj_weight) - loss_h = mse(pred_h, th, obj_weight) + loss_w = l1loss(pred_w, tw, obj_weight) + loss_h = l1loss(pred_h, th, obj_weight) loss_conf_target = sce(pred_conf, tconf, obj_mask) loss_conf_notarget = sce(pred_conf, tconf, noobj_mask) loss_class = sce(pred_cls, tcls, obj_mask_expand) From 2fbfef2ec9683ac18903ca8cf7cb69c5389ba3ba Mon Sep 17 00:00:00 2001 From: dengkaipeng Date: Thu, 13 Dec 2018 19:15:52 +0800 Subject: [PATCH 23/53] fix no box expression. test=develop --- paddle/fluid/operators/yolov3_loss_op.h | 14 +------------- 1 file changed, 1 insertion(+), 13 deletions(-) diff --git a/paddle/fluid/operators/yolov3_loss_op.h b/paddle/fluid/operators/yolov3_loss_op.h index 4661747261..d0064a8190 100644 --- a/paddle/fluid/operators/yolov3_loss_op.h +++ b/paddle/fluid/operators/yolov3_loss_op.h @@ -152,13 +152,10 @@ static inline void CalcSCEGradWithWeight(const T* loss_grad, Tensor* grad, const T* label_data = label.data(); const T* weight_data = weight.data(); - // LOG(ERROR) << "SCE grad start"; for (int i = 0; i < n; i++) { for (int j = 0; j < stride; j++) { grad_data[j] = (1.0 / (1.0 + std::exp(-x_data[j])) - label_data[j]) * weight_data[j] * loss_grad[i]; - // if (j == 18) LOG(ERROR) << x_data[j] << " " << label_data[j] << " " << - // weight_data[j] << " " << loss_grad[i]; } grad_data += stride; x_data += stride; @@ -258,8 +255,7 @@ static void PreProcessGTBox(const Tensor& gt_box, const Tensor& gt_label, for (int i = 0; i < n; i++) { for (int j = 0; j < b; j++) { - if (isZero(gt_box_t(i, j, 0)) && isZero(gt_box_t(i, j, 1)) && - isZero(gt_box_t(i, j, 2)) && isZero(gt_box_t(i, j, 3))) { + if (isZero(gt_box_t(i, j, 2)) && isZero(gt_box_t(i, j, 3))) { continue; } @@ -425,12 +421,6 @@ class Yolov3LossKernel : public framework::OpKernel { loss_weight_conf_notarget, loss_data); CalcSCEWithWeight(pred_class, tclass, obj_mask_expand, loss_weight_class, loss_data); - - // loss_data[0] = (loss_weight_xy * (loss_x + loss_y) + - // loss_weight_wh * (loss_w + loss_h) + - // loss_weight_conf_target * loss_conf_target + - // loss_weight_conf_notarget * loss_conf_notarget + - // loss_weight_class * loss_class) / n; } }; @@ -494,8 +484,6 @@ class Yolov3LossGradKernel : public framework::OpKernel { auto tweight_t = EigenTensor::From(tweight); obj_weight_t = obj_mask_t * tweight_t; - // LOG(ERROR) << obj_mask_t; - Tensor obj_mask_expand; obj_mask_expand.mutable_data({n, an_num, h, w, class_num}, ctx.GetPlace()); From 0c4acc83050fb83860884ea02ac241a5ddd6800e Mon Sep 17 00:00:00 2001 From: dengkaipeng Date: Sun, 16 Dec 2018 17:50:41 +0800 Subject: [PATCH 24/53] imporve yolo loss implement. test=develop --- paddle/fluid/operators/yolov3_loss_op.cc | 17 +- paddle/fluid/operators/yolov3_loss_op.h | 432 ++++++++++-------- python/paddle/fluid/layers/detection.py | 34 +- .../paddle/fluid/tests/unittests/op_test.py | 2 - .../tests/unittests/test_yolov3_loss_op.py | 49 +- 5 files changed, 267 insertions(+), 267 deletions(-) diff --git a/paddle/fluid/operators/yolov3_loss_op.cc b/paddle/fluid/operators/yolov3_loss_op.cc index c76767dfdd..3bd0db8b59 100644 --- a/paddle/fluid/operators/yolov3_loss_op.cc +++ b/paddle/fluid/operators/yolov3_loss_op.cc @@ -34,11 +34,12 @@ class Yolov3LossOp : public framework::OperatorWithKernel { auto dim_gtbox = ctx->GetInputDim("GTBox"); auto dim_gtlabel = ctx->GetInputDim("GTLabel"); auto anchors = ctx->Attrs().Get>("anchors"); + int anchor_num = anchors.size() / 2; auto class_num = ctx->Attrs().Get("class_num"); PADDLE_ENFORCE_EQ(dim_x.size(), 4, "Input(X) should be a 4-D tensor."); PADDLE_ENFORCE_EQ(dim_x[2], dim_x[3], "Input(X) dim[3] and dim[4] should be euqal."); - PADDLE_ENFORCE_EQ(dim_x[1], anchors.size() / 2 * (5 + class_num), + PADDLE_ENFORCE_EQ(dim_x[1], anchor_num * (5 + class_num), "Input(X) dim[1] should be equal to (anchor_number * (5 " "+ class_num))."); PADDLE_ENFORCE_EQ(dim_gtbox.size(), 3, @@ -105,20 +106,6 @@ class Yolov3LossOpMaker : public framework::OpProtoAndCheckerMaker { .SetDefault(406); AddAttr("ignore_thresh", "The ignore threshold to ignore confidence loss."); - AddAttr("loss_weight_xy", "The weight of x, y location loss.") - .SetDefault(1.0); - AddAttr("loss_weight_wh", "The weight of w, h location loss.") - .SetDefault(1.0); - AddAttr( - "loss_weight_conf_target", - "The weight of confidence score loss in locations with target object.") - .SetDefault(1.0); - AddAttr("loss_weight_conf_notarget", - "The weight of confidence score loss in locations without " - "target object.") - .SetDefault(1.0); - AddAttr("loss_weight_class", "The weight of classification loss.") - .SetDefault(1.0); AddComment(R"DOC( This operator generate yolov3 loss by given predict result and ground truth boxes. diff --git a/paddle/fluid/operators/yolov3_loss_op.h b/paddle/fluid/operators/yolov3_loss_op.h index d0064a8190..5de5b4efc7 100644 --- a/paddle/fluid/operators/yolov3_loss_op.h +++ b/paddle/fluid/operators/yolov3_loss_op.h @@ -164,48 +164,50 @@ static inline void CalcSCEGradWithWeight(const T* loss_grad, Tensor* grad, } } -template -static void SplitPredResult(const Tensor& input, Tensor* pred_conf, - Tensor* pred_class, Tensor* pred_x, Tensor* pred_y, - Tensor* pred_w, Tensor* pred_h, - const int anchor_num, const int class_num) { - const int n = input.dims()[0]; - const int h = input.dims()[2]; - const int w = input.dims()[3]; - const int box_attr_num = 5 + class_num; - - auto input_t = EigenTensor::From(input); - auto pred_conf_t = EigenTensor::From(*pred_conf); - auto pred_class_t = EigenTensor::From(*pred_class); - auto pred_x_t = EigenTensor::From(*pred_x); - auto pred_y_t = EigenTensor::From(*pred_y); - auto pred_w_t = EigenTensor::From(*pred_w); - auto pred_h_t = EigenTensor::From(*pred_h); - - for (int i = 0; i < n; i++) { - for (int an_idx = 0; an_idx < anchor_num; an_idx++) { - for (int j = 0; j < h; j++) { - for (int k = 0; k < w; k++) { - pred_x_t(i, an_idx, j, k) = input_t(i, box_attr_num * an_idx, j, k); - pred_y_t(i, an_idx, j, k) = - input_t(i, box_attr_num * an_idx + 1, j, k); - pred_w_t(i, an_idx, j, k) = - input_t(i, box_attr_num * an_idx + 2, j, k); - pred_h_t(i, an_idx, j, k) = - input_t(i, box_attr_num * an_idx + 3, j, k); - - pred_conf_t(i, an_idx, j, k) = - input_t(i, box_attr_num * an_idx + 4, j, k); - - for (int c = 0; c < class_num; c++) { - pred_class_t(i, an_idx, j, k, c) = - input_t(i, box_attr_num * an_idx + 5 + c, j, k); - } - } - } - } - } -} +// template +// static void SplitPredResult(const Tensor& input, Tensor* pred_conf, +// Tensor* pred_class, Tensor* pred_x, Tensor* +// pred_y, +// Tensor* pred_w, Tensor* pred_h, +// const int anchor_num, const int class_num) { +// const int n = input.dims()[0]; +// const int h = input.dims()[2]; +// const int w = input.dims()[3]; +// const int box_attr_num = 5 + class_num; +// +// auto input_t = EigenTensor::From(input); +// auto pred_conf_t = EigenTensor::From(*pred_conf); +// auto pred_class_t = EigenTensor::From(*pred_class); +// auto pred_x_t = EigenTensor::From(*pred_x); +// auto pred_y_t = EigenTensor::From(*pred_y); +// auto pred_w_t = EigenTensor::From(*pred_w); +// auto pred_h_t = EigenTensor::From(*pred_h); +// +// for (int i = 0; i < n; i++) { +// for (int an_idx = 0; an_idx < anchor_num; an_idx++) { +// for (int j = 0; j < h; j++) { +// for (int k = 0; k < w; k++) { +// pred_x_t(i, an_idx, j, k) = input_t(i, box_attr_num * an_idx, j, +// k); +// pred_y_t(i, an_idx, j, k) = +// input_t(i, box_attr_num * an_idx + 1, j, k); +// pred_w_t(i, an_idx, j, k) = +// input_t(i, box_attr_num * an_idx + 2, j, k); +// pred_h_t(i, an_idx, j, k) = +// input_t(i, box_attr_num * an_idx + 3, j, k); +// +// pred_conf_t(i, an_idx, j, k) = +// input_t(i, box_attr_num * an_idx + 4, j, k); +// +// for (int c = 0; c < class_num; c++) { +// pred_class_t(i, an_idx, j, k, c) = +// input_t(i, box_attr_num * an_idx + 5 + c, j, k); +// } +// } +// } +// } +// } +// } template static T CalcBoxIoU(std::vector box1, std::vector box2) { @@ -235,7 +237,7 @@ template static void PreProcessGTBox(const Tensor& gt_box, const Tensor& gt_label, const float ignore_thresh, std::vector anchors, const int input_size, const int grid_size, - Tensor* obj_mask, Tensor* noobj_mask, Tensor* tx, + Tensor* conf_mask, Tensor* obj_mask, Tensor* tx, Tensor* ty, Tensor* tw, Tensor* th, Tensor* tweight, Tensor* tconf, Tensor* tclass) { const int n = gt_box.dims()[0]; @@ -243,8 +245,8 @@ static void PreProcessGTBox(const Tensor& gt_box, const Tensor& gt_label, const int anchor_num = anchors.size() / 2; auto gt_box_t = EigenTensor::From(gt_box); auto gt_label_t = EigenTensor::From(gt_label); - auto obj_mask_t = EigenTensor::From(*obj_mask).setConstant(0); - auto noobj_mask_t = EigenTensor::From(*noobj_mask).setConstant(1); + auto conf_mask_t = EigenTensor::From(*conf_mask).setConstant(1.0); + auto obj_mask_t = EigenTensor::From(*obj_mask).setConstant(0.0); auto tx_t = EigenTensor::From(*tx).setConstant(0.0); auto ty_t = EigenTensor::From(*ty).setConstant(0.0); auto tw_t = EigenTensor::From(*tw).setConstant(0.0); @@ -280,11 +282,11 @@ static void PreProcessGTBox(const Tensor& gt_box, const Tensor& gt_label, best_an_index = an_idx; } if (iou > ignore_thresh) { - noobj_mask_t(i, an_idx, gj, gi) = static_cast(0.0); + conf_mask_t(i, an_idx, gj, gi) = static_cast(0.0); } } + conf_mask_t(i, best_an_index, gj, gi) = static_cast(1.0); obj_mask_t(i, best_an_index, gj, gi) = static_cast(1.0); - noobj_mask_t(i, best_an_index, gj, gi) = static_cast(0.0); tx_t(i, best_an_index, gj, gi) = gx - gi; ty_t(i, best_an_index, gj, gi) = gy - gj; tw_t(i, best_an_index, gj, gi) = log(gw / anchors[2 * best_an_index]); @@ -298,53 +300,194 @@ static void PreProcessGTBox(const Tensor& gt_box, const Tensor& gt_label, } template -static void AddAllGradToInputGrad( - Tensor* grad, const Tensor& grad_x, const Tensor& grad_y, - const Tensor& grad_w, const Tensor& grad_h, const Tensor& grad_conf_target, - const Tensor& grad_conf_notarget, const Tensor& grad_class, - const int class_num, const float loss_weight_xy, const float loss_weight_wh, - const float loss_weight_conf_target, const float loss_weight_conf_notarget, - const float loss_weight_class) { - const int n = grad_x.dims()[0]; - const int an_num = grad_x.dims()[1]; - const int h = grad_x.dims()[2]; - const int w = grad_x.dims()[3]; - const int attr_num = class_num + 5; - auto grad_t = EigenTensor::From(*grad).setConstant(0.0); - auto grad_x_t = EigenTensor::From(grad_x); - auto grad_y_t = EigenTensor::From(grad_y); - auto grad_w_t = EigenTensor::From(grad_w); - auto grad_h_t = EigenTensor::From(grad_h); - auto grad_conf_target_t = EigenTensor::From(grad_conf_target); - auto grad_conf_notarget_t = EigenTensor::From(grad_conf_notarget); - auto grad_class_t = EigenTensor::From(grad_class); +static T SCE(T x, T label) { + return (x > 0 ? x : 0.0) - x * label + std::log(1.0 + std::exp(-std::abs(x))); +} + +template +static T L1Loss(T x, T y) { + return std::abs(y - x); +} + +template +static T SCEGrad(T x, T label) { + return 1.0 / (1.0 + std::exp(-x)) - label; +} + +template +static T L1LossGrad(T x, T y) { + return x > y ? 1.0 : -1.0; +} + +template +static void CalcSCE(T* loss_data, const T* input, const T* target, + const T* weight, const T* mask, const int n, + const int an_num, const int grid_num, const int class_num, + const int num) { + for (int i = 0; i < n; i++) { + for (int j = 0; j < an_num; j++) { + for (int k = 0; k < grid_num; k++) { + int sub_idx = k * num; + for (int l = 0; l < num; l++) { + loss_data[i] += SCE(input[l * grid_num + k], target[sub_idx + l]) * + weight[k] * mask[k]; + } + } + input += (class_num + 5) * grid_num; + target += grid_num * num; + weight += grid_num; + mask += grid_num; + } + } +} +template +static void CalcSCEGrad(T* input_grad, const T* loss_grad, const T* input, + const T* target, const T* weight, const T* mask, + const int n, const int an_num, const int grid_num, + const int class_num, const int num) { for (int i = 0; i < n; i++) { for (int j = 0; j < an_num; j++) { - for (int k = 0; k < h; k++) { - for (int l = 0; l < w; l++) { - grad_t(i, j * attr_num, k, l) = grad_x_t(i, j, k, l) * loss_weight_xy; - grad_t(i, j * attr_num + 1, k, l) = - grad_y_t(i, j, k, l) * loss_weight_xy; - grad_t(i, j * attr_num + 2, k, l) = - grad_w_t(i, j, k, l) * loss_weight_wh; - grad_t(i, j * attr_num + 3, k, l) = - grad_h_t(i, j, k, l) * loss_weight_wh; - grad_t(i, j * attr_num + 4, k, l) = - grad_conf_target_t(i, j, k, l) * loss_weight_conf_target; - grad_t(i, j * attr_num + 4, k, l) += - grad_conf_notarget_t(i, j, k, l) * loss_weight_conf_notarget; - - for (int c = 0; c < class_num; c++) { - grad_t(i, j * attr_num + 5 + c, k, l) = - grad_class_t(i, j, k, l, c) * loss_weight_class; - } + for (int k = 0; k < grid_num; k++) { + int sub_idx = k * num; + for (int l = 0; l < num; l++) { + input_grad[l * grid_num + k] = + SCEGrad(input[l * grid_num + k], target[sub_idx + l]) * + weight[k] * mask[k] * loss_grad[i]; } } + input_grad += (class_num + 5) * grid_num; + input += (class_num + 5) * grid_num; + target += grid_num * num; + weight += grid_num; + mask += grid_num; + } + } +} + +template +static void CalcL1Loss(T* loss_data, const T* input, const T* target, + const T* weight, const T* mask, const int n, + const int an_num, const int grid_num, + const int class_num) { + for (int i = 0; i < n; i++) { + for (int j = 0; j < an_num; j++) { + for (int k = 0; k < grid_num; k++) { + loss_data[i] += L1Loss(input[k], target[k]) * weight[k] * mask[k]; + } + input += (class_num + 5) * grid_num; + target += grid_num; + weight += grid_num; + mask += grid_num; + } + } +} + +template +static void CalcL1LossGrad(T* input_grad, const T* loss_grad, const T* input, + const T* target, const T* weight, const T* mask, + const int n, const int an_num, const int grid_num, + const int class_num) { + for (int i = 0; i < n; i++) { + for (int j = 0; j < an_num; j++) { + for (int k = 0; k < grid_num; k++) { + input_grad[k] = L1LossGrad(input[k], target[k]) * weight[k] * + mask[k] * loss_grad[i]; + } + input_grad += (class_num + 5) * grid_num; + input += (class_num + 5) * grid_num; + target += grid_num; + weight += grid_num; + mask += grid_num; } } } +template +static void CalcYolov3Loss(T* loss_data, const Tensor& input, const Tensor& tx, + const Tensor& ty, const Tensor& tw, const Tensor& th, + const Tensor& tweight, const Tensor& tconf, + const Tensor& tclass, const Tensor& conf_mask, + const Tensor& obj_mask) { + const T* input_data = input.data(); + const T* tx_data = tx.data(); + const T* ty_data = ty.data(); + const T* tw_data = tw.data(); + const T* th_data = th.data(); + const T* tweight_data = tweight.data(); + const T* tconf_data = tconf.data(); + const T* tclass_data = tclass.data(); + const T* conf_mask_data = conf_mask.data(); + const T* obj_mask_data = obj_mask.data(); + + const int n = tclass.dims()[0]; + const int an_num = tclass.dims()[1]; + const int h = tclass.dims()[2]; + const int w = tclass.dims()[3]; + const int class_num = tclass.dims()[4]; + const int grid_num = h * w; + + CalcSCE(loss_data, input_data, tx_data, tweight_data, obj_mask_data, n, + an_num, grid_num, class_num, 1); + CalcSCE(loss_data, input_data + grid_num, ty_data, tweight_data, + obj_mask_data, n, an_num, grid_num, class_num, 1); + CalcL1Loss(loss_data, input_data + 2 * grid_num, tw_data, tweight_data, + obj_mask_data, n, an_num, grid_num, class_num); + CalcL1Loss(loss_data, input_data + 3 * grid_num, th_data, tweight_data, + obj_mask_data, n, an_num, grid_num, class_num); + CalcSCE(loss_data, input_data + 4 * grid_num, tconf_data, conf_mask_data, + conf_mask_data, n, an_num, grid_num, class_num, 1); + CalcSCE(loss_data, input_data + 5 * grid_num, tclass_data, obj_mask_data, + obj_mask_data, n, an_num, grid_num, class_num, class_num); +} + +template +static void CalcYolov3LossGrad(T* input_grad_data, const Tensor& loss_grad, + const Tensor& input, const Tensor& tx, + const Tensor& ty, const Tensor& tw, + const Tensor& th, const Tensor& tweight, + const Tensor& tconf, const Tensor& tclass, + const Tensor& conf_mask, + const Tensor& obj_mask) { + const T* loss_grad_data = loss_grad.data(); + const T* input_data = input.data(); + const T* tx_data = tx.data(); + const T* ty_data = ty.data(); + const T* tw_data = tw.data(); + const T* th_data = th.data(); + const T* tweight_data = tweight.data(); + const T* tconf_data = tconf.data(); + const T* tclass_data = tclass.data(); + const T* conf_mask_data = conf_mask.data(); + const T* obj_mask_data = obj_mask.data(); + + const int n = tclass.dims()[0]; + const int an_num = tclass.dims()[1]; + const int h = tclass.dims()[2]; + const int w = tclass.dims()[3]; + const int class_num = tclass.dims()[4]; + const int grid_num = h * w; + + CalcSCEGrad(input_grad_data, loss_grad_data, input_data, tx_data, + tweight_data, obj_mask_data, n, an_num, grid_num, class_num, + 1); + CalcSCEGrad(input_grad_data + grid_num, loss_grad_data, + input_data + grid_num, ty_data, tweight_data, obj_mask_data, n, + an_num, grid_num, class_num, 1); + CalcL1LossGrad(input_grad_data + 2 * grid_num, loss_grad_data, + input_data + 2 * grid_num, tw_data, tweight_data, + obj_mask_data, n, an_num, grid_num, class_num); + CalcL1LossGrad(input_grad_data + 3 * grid_num, loss_grad_data, + input_data + 3 * grid_num, th_data, tweight_data, + obj_mask_data, n, an_num, grid_num, class_num); + CalcSCEGrad(input_grad_data + 4 * grid_num, loss_grad_data, + input_data + 4 * grid_num, tconf_data, conf_mask_data, + conf_mask_data, n, an_num, grid_num, class_num, 1); + CalcSCEGrad(input_grad_data + 5 * grid_num, loss_grad_data, + input_data + 5 * grid_num, tclass_data, obj_mask_data, + obj_mask_data, n, an_num, grid_num, class_num, class_num); +} + template class Yolov3LossKernel : public framework::OpKernel { public: @@ -357,33 +500,16 @@ class Yolov3LossKernel : public framework::OpKernel { int class_num = ctx.Attr("class_num"); int input_size = ctx.Attr("input_size"); float ignore_thresh = ctx.Attr("ignore_thresh"); - float loss_weight_xy = ctx.Attr("loss_weight_xy"); - float loss_weight_wh = ctx.Attr("loss_weight_wh"); - float loss_weight_conf_target = ctx.Attr("loss_weight_conf_target"); - float loss_weight_conf_notarget = - ctx.Attr("loss_weight_conf_notarget"); - float loss_weight_class = ctx.Attr("loss_weight_class"); const int n = input->dims()[0]; const int h = input->dims()[2]; const int w = input->dims()[3]; const int an_num = anchors.size() / 2; - Tensor pred_x, pred_y, pred_w, pred_h; - Tensor pred_conf, pred_class; - pred_x.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - pred_y.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - pred_w.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - pred_h.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - pred_conf.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - pred_class.mutable_data({n, an_num, h, w, class_num}, ctx.GetPlace()); - SplitPredResult(*input, &pred_conf, &pred_class, &pred_x, &pred_y, - &pred_w, &pred_h, an_num, class_num); - - Tensor obj_mask, noobj_mask; + Tensor conf_mask, obj_mask; Tensor tx, ty, tw, th, tweight, tconf, tclass; + conf_mask.mutable_data({n, an_num, h, w}, ctx.GetPlace()); obj_mask.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - noobj_mask.mutable_data({n, an_num, h, w}, ctx.GetPlace()); tx.mutable_data({n, an_num, h, w}, ctx.GetPlace()); ty.mutable_data({n, an_num, h, w}, ctx.GetPlace()); tw.mutable_data({n, an_num, h, w}, ctx.GetPlace()); @@ -392,35 +518,13 @@ class Yolov3LossKernel : public framework::OpKernel { tconf.mutable_data({n, an_num, h, w}, ctx.GetPlace()); tclass.mutable_data({n, an_num, h, w, class_num}, ctx.GetPlace()); PreProcessGTBox(*gt_box, *gt_label, ignore_thresh, anchors, input_size, - h, &obj_mask, &noobj_mask, &tx, &ty, &tw, &th, &tweight, + h, &conf_mask, &obj_mask, &tx, &ty, &tw, &th, &tweight, &tconf, &tclass); - Tensor obj_weight; - obj_weight.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - auto obj_weight_t = EigenTensor::From(obj_weight); - auto obj_mask_t = EigenTensor::From(obj_mask); - auto tweight_t = EigenTensor::From(tweight); - obj_weight_t = obj_mask_t * tweight_t; - - Tensor obj_mask_expand; - obj_mask_expand.mutable_data({n, an_num, h, w, class_num}, - ctx.GetPlace()); - auto obj_mask_expand_t = EigenTensor::From(obj_mask_expand); - obj_mask_expand_t = obj_mask_t.reshape(Array5(n, an_num, h, w, 1)) - .broadcast(Array5(1, 1, 1, 1, class_num)); - T* loss_data = loss->mutable_data({n}, ctx.GetPlace()); memset(loss_data, 0, n * sizeof(T)); - CalcSCEWithWeight(pred_x, tx, obj_weight, loss_weight_xy, loss_data); - CalcSCEWithWeight(pred_y, ty, obj_weight, loss_weight_xy, loss_data); - CalcL1LossWithWeight(pred_w, tw, obj_weight, loss_weight_wh, loss_data); - CalcL1LossWithWeight(pred_h, th, obj_weight, loss_weight_wh, loss_data); - CalcSCEWithWeight(pred_conf, tconf, obj_mask, loss_weight_conf_target, - loss_data); - CalcSCEWithWeight(pred_conf, tconf, noobj_mask, - loss_weight_conf_notarget, loss_data); - CalcSCEWithWeight(pred_class, tclass, obj_mask_expand, loss_weight_class, - loss_data); + CalcYolov3Loss(loss_data, *input, tx, ty, tw, th, tweight, tconf, tclass, + conf_mask, obj_mask); } }; @@ -436,14 +540,7 @@ class Yolov3LossGradKernel : public framework::OpKernel { float ignore_thresh = ctx.Attr("ignore_thresh"); auto* input_grad = ctx.Output(framework::GradVarName("X")); auto* loss_grad = ctx.Input(framework::GradVarName("Loss")); - const T* loss_grad_data = loss_grad->data(); int input_size = ctx.Attr("input_size"); - float loss_weight_xy = ctx.Attr("loss_weight_xy"); - float loss_weight_wh = ctx.Attr("loss_weight_wh"); - float loss_weight_conf_target = ctx.Attr("loss_weight_conf_target"); - float loss_weight_conf_notarget = - ctx.Attr("loss_weight_conf_notarget"); - float loss_weight_class = ctx.Attr("loss_weight_class"); const int n = input->dims()[0]; const int c = input->dims()[1]; @@ -451,21 +548,10 @@ class Yolov3LossGradKernel : public framework::OpKernel { const int w = input->dims()[3]; const int an_num = anchors.size() / 2; - Tensor pred_x, pred_y, pred_w, pred_h; - Tensor pred_conf, pred_class; - pred_x.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - pred_y.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - pred_w.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - pred_h.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - pred_conf.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - pred_class.mutable_data({n, an_num, h, w, class_num}, ctx.GetPlace()); - SplitPredResult(*input, &pred_conf, &pred_class, &pred_x, &pred_y, - &pred_w, &pred_h, an_num, class_num); - - Tensor obj_mask, noobj_mask; + Tensor conf_mask, obj_mask; Tensor tx, ty, tw, th, tweight, tconf, tclass; + conf_mask.mutable_data({n, an_num, h, w}, ctx.GetPlace()); obj_mask.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - noobj_mask.mutable_data({n, an_num, h, w}, ctx.GetPlace()); tx.mutable_data({n, an_num, h, w}, ctx.GetPlace()); ty.mutable_data({n, an_num, h, w}, ctx.GetPlace()); tw.mutable_data({n, an_num, h, w}, ctx.GetPlace()); @@ -474,51 +560,13 @@ class Yolov3LossGradKernel : public framework::OpKernel { tconf.mutable_data({n, an_num, h, w}, ctx.GetPlace()); tclass.mutable_data({n, an_num, h, w, class_num}, ctx.GetPlace()); PreProcessGTBox(*gt_box, *gt_label, ignore_thresh, anchors, input_size, - h, &obj_mask, &noobj_mask, &tx, &ty, &tw, &th, &tweight, + h, &conf_mask, &obj_mask, &tx, &ty, &tw, &th, &tweight, &tconf, &tclass); - Tensor obj_weight; - obj_weight.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - auto obj_weight_t = EigenTensor::From(obj_weight); - auto obj_mask_t = EigenTensor::From(obj_mask); - auto tweight_t = EigenTensor::From(tweight); - obj_weight_t = obj_mask_t * tweight_t; - - Tensor obj_mask_expand; - obj_mask_expand.mutable_data({n, an_num, h, w, class_num}, - ctx.GetPlace()); - auto obj_mask_expand_t = EigenTensor::From(obj_mask_expand); - obj_mask_expand_t = obj_mask_t.reshape(Array5(n, an_num, h, w, 1)) - .broadcast(Array5(1, 1, 1, 1, class_num)); - - Tensor grad_x, grad_y, grad_w, grad_h; - Tensor grad_conf_target, grad_conf_notarget, grad_class; - grad_x.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - grad_y.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - grad_w.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - grad_h.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - grad_conf_target.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - grad_conf_notarget.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - grad_class.mutable_data({n, an_num, h, w, class_num}, ctx.GetPlace()); - CalcSCEGradWithWeight(loss_grad_data, &grad_x, pred_x, tx, obj_weight); - CalcSCEGradWithWeight(loss_grad_data, &grad_y, pred_y, ty, obj_weight); - CalcL1LossGradWithWeight(loss_grad_data, &grad_w, pred_w, tw, - obj_weight); - CalcL1LossGradWithWeight(loss_grad_data, &grad_h, pred_h, th, - obj_weight); - CalcSCEGradWithWeight(loss_grad_data, &grad_conf_target, pred_conf, - tconf, obj_mask); - CalcSCEGradWithWeight(loss_grad_data, &grad_conf_notarget, pred_conf, - tconf, noobj_mask); - CalcSCEGradWithWeight(loss_grad_data, &grad_class, pred_class, tclass, - obj_mask_expand); - - input_grad->mutable_data({n, c, h, w}, ctx.GetPlace()); - AddAllGradToInputGrad(input_grad, grad_x, grad_y, grad_w, grad_h, - grad_conf_target, grad_conf_notarget, grad_class, - class_num, loss_weight_xy, loss_weight_wh, - loss_weight_conf_target, loss_weight_conf_notarget, - loss_weight_class); + T* input_grad_data = + input_grad->mutable_data({n, c, h, w}, ctx.GetPlace()); + CalcYolov3LossGrad(input_grad_data, *loss_grad, *input, tx, ty, tw, th, + tweight, tconf, tclass, conf_mask, obj_mask); } }; diff --git a/python/paddle/fluid/layers/detection.py b/python/paddle/fluid/layers/detection.py index 5fb4588e0b..caa9b1c3d4 100644 --- a/python/paddle/fluid/layers/detection.py +++ b/python/paddle/fluid/layers/detection.py @@ -416,11 +416,6 @@ def yolov3_loss(x, class_num, ignore_thresh, input_size, - loss_weight_xy=None, - loss_weight_wh=None, - loss_weight_conf_target=None, - loss_weight_conf_notarget=None, - loss_weight_class=None, name=None): """ ${comment} @@ -438,11 +433,6 @@ def yolov3_loss(x, class_num (int): ${class_num_comment} ignore_thresh (float): ${ignore_thresh_comment} input_size (int): ${input_size_comment} - loss_weight_xy (float|None): ${loss_weight_xy_comment} - loss_weight_wh (float|None): ${loss_weight_wh_comment} - loss_weight_conf_target (float|None): ${loss_weight_conf_target_comment} - loss_weight_conf_notarget (float|None): ${loss_weight_conf_notarget_comment} - loss_weight_class (float|None): ${loss_weight_class_comment} name (string): the name of yolov3 loss Returns: @@ -495,18 +485,18 @@ def yolov3_loss(x, "input_size": input_size, } - if loss_weight_xy is not None and isinstance(loss_weight_xy, float): - self.attrs['loss_weight_xy'] = loss_weight_xy - if loss_weight_wh is not None and isinstance(loss_weight_wh, float): - self.attrs['loss_weight_wh'] = loss_weight_wh - if loss_weight_conf_target is not None and isinstance( - loss_weight_conf_target, float): - self.attrs['loss_weight_conf_target'] = loss_weight_conf_target - if loss_weight_conf_notarget is not None and isinstance( - loss_weight_conf_notarget, float): - self.attrs['loss_weight_conf_notarget'] = loss_weight_conf_notarget - if loss_weight_class is not None and isinstance(loss_weight_class, float): - self.attrs['loss_weight_class'] = loss_weight_class + # if loss_weight_xy is not None and isinstance(loss_weight_xy, float): + # self.attrs['loss_weight_xy'] = loss_weight_xy + # if loss_weight_wh is not None and isinstance(loss_weight_wh, float): + # self.attrs['loss_weight_wh'] = loss_weight_wh + # if loss_weight_conf_target is not None and isinstance( + # loss_weight_conf_target, float): + # self.attrs['loss_weight_conf_target'] = loss_weight_conf_target + # if loss_weight_conf_notarget is not None and isinstance( + # loss_weight_conf_notarget, float): + # self.attrs['loss_weight_conf_notarget'] = loss_weight_conf_notarget + # if loss_weight_class is not None and isinstance(loss_weight_class, float): + # self.attrs['loss_weight_class'] = loss_weight_class helper.append_op( type='yolov3_loss', diff --git a/python/paddle/fluid/tests/unittests/op_test.py b/python/paddle/fluid/tests/unittests/op_test.py index 9cf398f18f..0fe836683b 100644 --- a/python/paddle/fluid/tests/unittests/op_test.py +++ b/python/paddle/fluid/tests/unittests/op_test.py @@ -470,8 +470,6 @@ class OpTest(unittest.TestCase): ] analytic_grads = self._get_gradient(inputs_to_check, place, output_names, no_grad_set) - # print(numeric_grads[0][0, 4, :, :]) - # print(analytic_grads[0][0, 4, :, :]) self._assert_is_close(numeric_grads, analytic_grads, inputs_to_check, max_relative_error, diff --git a/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py b/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py index e218031286..cf7e2c5289 100644 --- a/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py +++ b/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py @@ -80,8 +80,8 @@ def build_target(gtboxes, gtlabel, attrs, grid_size): class_num = attrs["class_num"] input_size = attrs["input_size"] an_num = len(anchors) // 2 + conf_mask = np.ones((n, an_num, grid_size, grid_size)).astype('float32') obj_mask = np.zeros((n, an_num, grid_size, grid_size)).astype('float32') - noobj_mask = np.ones((n, an_num, grid_size, grid_size)).astype('float32') tx = np.zeros((n, an_num, grid_size, grid_size)).astype('float32') ty = np.zeros((n, an_num, grid_size, grid_size)).astype('float32') tw = np.zeros((n, an_num, grid_size, grid_size)).astype('float32') @@ -114,10 +114,10 @@ def build_target(gtboxes, gtlabel, attrs, grid_size): max_iou = iou best_an_index = k if iou > ignore_thresh: - noobj_mask[i, best_an_index, gj, gi] = 0 + conf_mask[i, best_an_index, gj, gi] = 0 + conf_mask[i, best_an_index, gj, gi] = 1 obj_mask[i, best_an_index, gj, gi] = 1 - noobj_mask[i, best_an_index, gj, gi] = 0 tx[i, best_an_index, gj, gi] = gx - gi ty[i, best_an_index, gj, gi] = gy - gj tw[i, best_an_index, gj, gi] = np.log(gw / anchors[2 * @@ -129,7 +129,7 @@ def build_target(gtboxes, gtlabel, attrs, grid_size): tconf[i, best_an_index, gj, gi] = 1 tcls[i, best_an_index, gj, gi, gt_label] = 1 - return (tx, ty, tw, th, tweight, tconf, tcls, obj_mask, noobj_mask) + return (tx, ty, tw, th, tweight, tconf, tcls, conf_mask, obj_mask) def YoloV3Loss(x, gtbox, gtlabel, attrs): @@ -144,11 +144,9 @@ def YoloV3Loss(x, gtbox, gtlabel, attrs): pred_conf = x[:, :, :, :, 4] pred_cls = x[:, :, :, :, 5:] - tx, ty, tw, th, tweight, tconf, tcls, obj_mask, noobj_mask = build_target( + tx, ty, tw, th, tweight, tconf, tcls, conf_mask, obj_mask = build_target( gtbox, gtlabel, attrs, x.shape[2]) - # print("obj_mask: ", obj_mask[0, 0, :, :]) - # print("noobj_mask: ", noobj_mask[0, 0, :, :]) obj_weight = obj_mask * tweight obj_mask_expand = np.tile( np.expand_dims(obj_mask, 4), (1, 1, 1, 1, int(attrs['class_num']))) @@ -156,30 +154,19 @@ def YoloV3Loss(x, gtbox, gtlabel, attrs): loss_y = sce(pred_y, ty, obj_weight) loss_w = l1loss(pred_w, tw, obj_weight) loss_h = l1loss(pred_h, th, obj_weight) - loss_conf_target = sce(pred_conf, tconf, obj_mask) - loss_conf_notarget = sce(pred_conf, tconf, noobj_mask) + loss_obj = sce(pred_conf, tconf, conf_mask) loss_class = sce(pred_cls, tcls, obj_mask_expand) - # print("loss_xy: ", loss_x + loss_y) - # print("loss_wh: ", loss_w + loss_h) - # print("loss_conf_target: ", loss_conf_target) - # print("loss_conf_notarget: ", loss_conf_notarget) - # print("loss_class: ", loss_class) + # print("python loss_xy: ", loss_x + loss_y) + # print("python loss_wh: ", loss_w + loss_h) + # print("python loss_obj: ", loss_obj) + # print("python loss_class: ", loss_class) - return attrs['loss_weight_xy'] * (loss_x + loss_y) \ - + attrs['loss_weight_wh'] * (loss_w + loss_h) \ - + attrs['loss_weight_conf_target'] * loss_conf_target \ - + attrs['loss_weight_conf_notarget'] * loss_conf_notarget \ - + attrs['loss_weight_class'] * loss_class + return loss_x + loss_y + loss_w + loss_h + loss_obj + loss_class class TestYolov3LossOp(OpTest): def setUp(self): - self.loss_weight_xy = 1.0 - self.loss_weight_wh = 1.0 - self.loss_weight_conf_target = 1.0 - self.loss_weight_conf_notarget = 1.0 - self.loss_weight_class = 1.0 self.initTestCase() self.op_type = 'yolov3_loss' x = logit(np.random.uniform(0, 1, self.x_shape).astype('float32')) @@ -192,11 +179,6 @@ class TestYolov3LossOp(OpTest): "class_num": self.class_num, "ignore_thresh": self.ignore_thresh, "input_size": self.input_size, - "loss_weight_xy": self.loss_weight_xy, - "loss_weight_wh": self.loss_weight_wh, - "loss_weight_conf_target": self.loss_weight_conf_target, - "loss_weight_conf_notarget": self.loss_weight_conf_notarget, - "loss_weight_class": self.loss_weight_class, } self.inputs = {'X': x, 'GTBox': gtbox, 'GTLabel': gtlabel} @@ -215,17 +197,12 @@ class TestYolov3LossOp(OpTest): max_relative_error=0.31) def initTestCase(self): - self.anchors = [12, 12] + self.anchors = [12, 12, 11, 13] self.class_num = 5 - self.ignore_thresh = 0.3 + self.ignore_thresh = 0.5 self.input_size = 416 self.x_shape = (3, len(self.anchors) // 2 * (5 + self.class_num), 5, 5) self.gtbox_shape = (3, 5, 4) - self.loss_weight_xy = 1.2 - self.loss_weight_wh = 0.8 - self.loss_weight_conf_target = 2.0 - self.loss_weight_conf_notarget = 1.0 - self.loss_weight_class = 1.5 if __name__ == "__main__": From 577a92d99203a67042f2b7fd6db25ecae09a1938 Mon Sep 17 00:00:00 2001 From: dengkaipeng Date: Mon, 17 Dec 2018 11:45:16 +0800 Subject: [PATCH 25/53] use typename DeviceContext. test=develop --- paddle/fluid/operators/yolov3_loss_op.cc | 12 +- paddle/fluid/operators/yolov3_loss_op.h | 301 ++++++------------ .../tests/unittests/test_yolov3_loss_op.py | 6 +- 3 files changed, 103 insertions(+), 216 deletions(-) diff --git a/paddle/fluid/operators/yolov3_loss_op.cc b/paddle/fluid/operators/yolov3_loss_op.cc index 3bd0db8b59..495a8f6c01 100644 --- a/paddle/fluid/operators/yolov3_loss_op.cc +++ b/paddle/fluid/operators/yolov3_loss_op.cc @@ -204,7 +204,11 @@ namespace ops = paddle::operators; REGISTER_OPERATOR(yolov3_loss, ops::Yolov3LossOp, ops::Yolov3LossOpMaker, ops::Yolov3LossGradMaker); REGISTER_OPERATOR(yolov3_loss_grad, ops::Yolov3LossOpGrad); -REGISTER_OP_CPU_KERNEL(yolov3_loss, ops::Yolov3LossKernel, - ops::Yolov3LossKernel); -REGISTER_OP_CPU_KERNEL(yolov3_loss_grad, ops::Yolov3LossGradKernel, - ops::Yolov3LossGradKernel); +REGISTER_OP_CPU_KERNEL( + yolov3_loss, + ops::Yolov3LossKernel, + ops::Yolov3LossKernel); +REGISTER_OP_CPU_KERNEL( + yolov3_loss_grad, + ops::Yolov3LossGradKernel, + ops::Yolov3LossGradKernel); diff --git a/paddle/fluid/operators/yolov3_loss_op.h b/paddle/fluid/operators/yolov3_loss_op.h index 5de5b4efc7..f086e89a99 100644 --- a/paddle/fluid/operators/yolov3_loss_op.h +++ b/paddle/fluid/operators/yolov3_loss_op.h @@ -13,6 +13,7 @@ #include #include #include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/math/math_function.h" namespace paddle { namespace operators { @@ -32,183 +33,6 @@ static inline bool isZero(T x) { return fabs(x) < 1e-6; } -template -static inline void CalcL1LossWithWeight(const Tensor& x, const Tensor& y, - const Tensor& weight, - const T loss_weight, T* loss) { - int n = x.dims()[0]; - int stride = x.numel() / n; - const T* x_data = x.data(); - const T* y_data = y.data(); - const T* weight_data = weight.data(); - - for (int i = 0; i < n; i++) { - for (int j = 0; j < stride; j++) { - loss[i] += fabs(y_data[j] - x_data[j]) * weight_data[j] * loss_weight; - } - x_data += stride; - y_data += stride; - weight_data += stride; - } -} - -template -static void CalcL1LossGradWithWeight(const T* loss_grad, Tensor* grad, - const Tensor& x, const Tensor& y, - const Tensor& weight) { - int n = x.dims()[0]; - int stride = x.numel() / n; - T* grad_data = grad->data(); - const T* x_data = x.data(); - const T* y_data = y.data(); - const T* weight_data = weight.data(); - - for (int i = 0; i < n; i++) { - for (int j = 0; j < stride; j++) { - grad_data[j] = weight_data[j] * loss_grad[i]; - if (x_data[j] < y_data[j]) grad_data[j] *= -1.0; - } - grad_data += stride; - x_data += stride; - y_data += stride; - weight_data += stride; - } -} - -template -static inline void CalcMSEWithWeight(const Tensor& x, const Tensor& y, - const Tensor& weight, const T loss_weight, - T* loss) { - int n = x.dims()[0]; - int stride = x.numel() / n; - const T* x_data = x.data(); - const T* y_data = y.data(); - const T* weight_data = weight.data(); - - for (int i = 0; i < n; i++) { - for (int j = 0; j < stride; j++) { - loss[i] += pow(y_data[j] - x_data[j], 2) * weight_data[j] * loss_weight; - } - x_data += stride; - y_data += stride; - weight_data += stride; - } -} - -template -static void CalcMSEGradWithWeight(const T* loss_grad, Tensor* grad, - const Tensor& x, const Tensor& y, - const Tensor& weight) { - int n = x.dims()[0]; - int stride = x.numel() / n; - T* grad_data = grad->data(); - const T* x_data = x.data(); - const T* y_data = y.data(); - const T* weight_data = weight.data(); - - for (int i = 0; i < n; i++) { - for (int j = 0; j < stride; j++) { - grad_data[j] = - 2.0 * weight_data[j] * (x_data[j] - y_data[j]) * loss_grad[i]; - } - grad_data += stride; - x_data += stride; - y_data += stride; - weight_data += stride; - } -} - -template -static inline void CalcSCEWithWeight(const Tensor& x, const Tensor& label, - const Tensor& weight, const T loss_weight, - T* loss) { - int n = x.dims()[0]; - int stride = x.numel() / n; - const T* x_data = x.data(); - const T* label_data = label.data(); - const T* weight_data = weight.data(); - - for (int i = 0; i < n; i++) { - for (int j = 0; j < stride; j++) { - T term1 = (x_data[j] > 0) ? x_data[j] : 0; - T term2 = x_data[j] * label_data[j]; - T term3 = std::log(1.0 + std::exp(-std::abs(x_data[j]))); - loss[i] += (term1 - term2 + term3) * weight_data[j] * loss_weight; - } - x_data += stride; - label_data += stride; - weight_data += stride; - } -} - -template -static inline void CalcSCEGradWithWeight(const T* loss_grad, Tensor* grad, - const Tensor& x, const Tensor& label, - const Tensor& weight) { - int n = x.dims()[0]; - int stride = x.numel() / n; - T* grad_data = grad->data(); - const T* x_data = x.data(); - const T* label_data = label.data(); - const T* weight_data = weight.data(); - - for (int i = 0; i < n; i++) { - for (int j = 0; j < stride; j++) { - grad_data[j] = (1.0 / (1.0 + std::exp(-x_data[j])) - label_data[j]) * - weight_data[j] * loss_grad[i]; - } - grad_data += stride; - x_data += stride; - label_data += stride; - weight_data += stride; - } -} - -// template -// static void SplitPredResult(const Tensor& input, Tensor* pred_conf, -// Tensor* pred_class, Tensor* pred_x, Tensor* -// pred_y, -// Tensor* pred_w, Tensor* pred_h, -// const int anchor_num, const int class_num) { -// const int n = input.dims()[0]; -// const int h = input.dims()[2]; -// const int w = input.dims()[3]; -// const int box_attr_num = 5 + class_num; -// -// auto input_t = EigenTensor::From(input); -// auto pred_conf_t = EigenTensor::From(*pred_conf); -// auto pred_class_t = EigenTensor::From(*pred_class); -// auto pred_x_t = EigenTensor::From(*pred_x); -// auto pred_y_t = EigenTensor::From(*pred_y); -// auto pred_w_t = EigenTensor::From(*pred_w); -// auto pred_h_t = EigenTensor::From(*pred_h); -// -// for (int i = 0; i < n; i++) { -// for (int an_idx = 0; an_idx < anchor_num; an_idx++) { -// for (int j = 0; j < h; j++) { -// for (int k = 0; k < w; k++) { -// pred_x_t(i, an_idx, j, k) = input_t(i, box_attr_num * an_idx, j, -// k); -// pred_y_t(i, an_idx, j, k) = -// input_t(i, box_attr_num * an_idx + 1, j, k); -// pred_w_t(i, an_idx, j, k) = -// input_t(i, box_attr_num * an_idx + 2, j, k); -// pred_h_t(i, an_idx, j, k) = -// input_t(i, box_attr_num * an_idx + 3, j, k); -// -// pred_conf_t(i, an_idx, j, k) = -// input_t(i, box_attr_num * an_idx + 4, j, k); -// -// for (int c = 0; c < class_num; c++) { -// pred_class_t(i, an_idx, j, k, c) = -// input_t(i, box_attr_num * an_idx + 5 + c, j, k); -// } -// } -// } -// } -// } -// } - template static T CalcBoxIoU(std::vector box1, std::vector box2) { T b1_x1 = box1[0] - box1[2] / 2; @@ -242,30 +66,36 @@ static void PreProcessGTBox(const Tensor& gt_box, const Tensor& gt_label, Tensor* tconf, Tensor* tclass) { const int n = gt_box.dims()[0]; const int b = gt_box.dims()[1]; - const int anchor_num = anchors.size() / 2; - auto gt_box_t = EigenTensor::From(gt_box); - auto gt_label_t = EigenTensor::From(gt_label); - auto conf_mask_t = EigenTensor::From(*conf_mask).setConstant(1.0); - auto obj_mask_t = EigenTensor::From(*obj_mask).setConstant(0.0); - auto tx_t = EigenTensor::From(*tx).setConstant(0.0); - auto ty_t = EigenTensor::From(*ty).setConstant(0.0); - auto tw_t = EigenTensor::From(*tw).setConstant(0.0); - auto th_t = EigenTensor::From(*th).setConstant(0.0); - auto tweight_t = EigenTensor::From(*tweight).setConstant(0.0); - auto tconf_t = EigenTensor::From(*tconf).setConstant(0.0); - auto tclass_t = EigenTensor::From(*tclass).setConstant(0.0); + const int an_num = anchors.size() / 2; + const int h = tclass->dims()[2]; + const int w = tclass->dims()[3]; + const int class_num = tclass->dims()[4]; + + const T* gt_box_data = gt_box.data(); + const int* gt_label_data = gt_label.data(); + T* conf_mask_data = conf_mask->data(); + T* obj_mask_data = obj_mask->data(); + T* tx_data = tx->data(); + T* ty_data = ty->data(); + T* tw_data = tw->data(); + T* th_data = th->data(); + T* tweight_data = tweight->data(); + T* tconf_data = tconf->data(); + T* tclass_data = tclass->data(); for (int i = 0; i < n; i++) { for (int j = 0; j < b; j++) { - if (isZero(gt_box_t(i, j, 2)) && isZero(gt_box_t(i, j, 3))) { + int box_idx = (i * b + j) * 4; + if (isZero(gt_box_data[box_idx + 2]) && + isZero(gt_box_data[box_idx + 3])) { continue; } - int cur_label = gt_label_t(i, j); - T gx = gt_box_t(i, j, 0) * grid_size; - T gy = gt_box_t(i, j, 1) * grid_size; - T gw = gt_box_t(i, j, 2) * input_size; - T gh = gt_box_t(i, j, 3) * input_size; + int cur_label = gt_label_data[i * b + j]; + T gx = gt_box_data[box_idx] * grid_size; + T gy = gt_box_data[box_idx + 1] * grid_size; + T gw = gt_box_data[box_idx + 2] * input_size; + T gh = gt_box_data[box_idx + 3] * input_size; int gi = static_cast(gx); int gj = static_cast(gy); @@ -273,7 +103,7 @@ static void PreProcessGTBox(const Tensor& gt_box, const Tensor& gt_label, T iou; int best_an_index = -1; std::vector gt_box_shape({0, 0, gw, gh}); - for (int an_idx = 0; an_idx < anchor_num; an_idx++) { + for (int an_idx = 0; an_idx < an_num; an_idx++) { std::vector anchor_shape({0, 0, static_cast(anchors[2 * an_idx]), static_cast(anchors[2 * an_idx + 1])}); iou = CalcBoxIoU(gt_box_shape, anchor_shape); @@ -282,19 +112,22 @@ static void PreProcessGTBox(const Tensor& gt_box, const Tensor& gt_label, best_an_index = an_idx; } if (iou > ignore_thresh) { - conf_mask_t(i, an_idx, gj, gi) = static_cast(0.0); + int conf_idx = ((i * an_num + an_idx) * h + gj) * w + gi; + conf_mask_data[conf_idx] = static_cast(0.0); } } - conf_mask_t(i, best_an_index, gj, gi) = static_cast(1.0); - obj_mask_t(i, best_an_index, gj, gi) = static_cast(1.0); - tx_t(i, best_an_index, gj, gi) = gx - gi; - ty_t(i, best_an_index, gj, gi) = gy - gj; - tw_t(i, best_an_index, gj, gi) = log(gw / anchors[2 * best_an_index]); - th_t(i, best_an_index, gj, gi) = log(gh / anchors[2 * best_an_index + 1]); - tweight_t(i, best_an_index, gj, gi) = - 2.0 - gt_box_t(i, j, 2) * gt_box_t(i, j, 3); - tclass_t(i, best_an_index, gj, gi, cur_label) = 1; - tconf_t(i, best_an_index, gj, gi) = 1; + + int obj_idx = ((i * an_num + best_an_index) * h + gj) * w + gi; + conf_mask_data[obj_idx] = static_cast(1.0); + obj_mask_data[obj_idx] = static_cast(1.0); + tx_data[obj_idx] = gx - gi; + ty_data[obj_idx] = gy - gj; + tw_data[obj_idx] = log(gw / anchors[2 * best_an_index]); + th_data[obj_idx] = log(gh / anchors[2 * best_an_index + 1]); + tweight_data[obj_idx] = + 2.0 - gt_box_data[box_idx + 2] * gt_box_data[box_idx + 3]; + tconf_data[obj_idx] = static_cast(1.0); + tclass_data[obj_idx * class_num + cur_label] = static_cast(1.0); } } } @@ -427,18 +260,26 @@ static void CalcYolov3Loss(T* loss_data, const Tensor& input, const Tensor& tx, const int class_num = tclass.dims()[4]; const int grid_num = h * w; + // T l = 0.0; CalcSCE(loss_data, input_data, tx_data, tweight_data, obj_mask_data, n, an_num, grid_num, class_num, 1); CalcSCE(loss_data, input_data + grid_num, ty_data, tweight_data, obj_mask_data, n, an_num, grid_num, class_num, 1); + // LOG(ERROR) << "C++ xy: " << loss_data[0] - l; + // l = loss_data[0]; CalcL1Loss(loss_data, input_data + 2 * grid_num, tw_data, tweight_data, obj_mask_data, n, an_num, grid_num, class_num); CalcL1Loss(loss_data, input_data + 3 * grid_num, th_data, tweight_data, obj_mask_data, n, an_num, grid_num, class_num); + // LOG(ERROR) << "C++ wh: " << loss_data[0] - l; + // l = loss_data[0]; CalcSCE(loss_data, input_data + 4 * grid_num, tconf_data, conf_mask_data, conf_mask_data, n, an_num, grid_num, class_num, 1); + // LOG(ERROR) << "C++ conf: " << loss_data[0] - l; + // l = loss_data[0]; CalcSCE(loss_data, input_data + 5 * grid_num, tclass_data, obj_mask_data, obj_mask_data, n, an_num, grid_num, class_num, class_num); + // LOG(ERROR) << "C++ class: " << loss_data[0] - l; } template @@ -488,7 +329,7 @@ static void CalcYolov3LossGrad(T* input_grad_data, const Tensor& loss_grad, obj_mask_data, n, an_num, grid_num, class_num, class_num); } -template +template class Yolov3LossKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { @@ -517,6 +358,27 @@ class Yolov3LossKernel : public framework::OpKernel { tweight.mutable_data({n, an_num, h, w}, ctx.GetPlace()); tconf.mutable_data({n, an_num, h, w}, ctx.GetPlace()); tclass.mutable_data({n, an_num, h, w, class_num}, ctx.GetPlace()); + + math::SetConstant constant; + constant(ctx.template device_context(), &conf_mask, + static_cast(1.0)); + constant(ctx.template device_context(), &obj_mask, + static_cast(0.0)); + constant(ctx.template device_context(), &tx, + static_cast(0.0)); + constant(ctx.template device_context(), &ty, + static_cast(0.0)); + constant(ctx.template device_context(), &tw, + static_cast(0.0)); + constant(ctx.template device_context(), &th, + static_cast(0.0)); + constant(ctx.template device_context(), &tweight, + static_cast(0.0)); + constant(ctx.template device_context(), &tconf, + static_cast(0.0)); + constant(ctx.template device_context(), &tclass, + static_cast(0.0)); + PreProcessGTBox(*gt_box, *gt_label, ignore_thresh, anchors, input_size, h, &conf_mask, &obj_mask, &tx, &ty, &tw, &th, &tweight, &tconf, &tclass); @@ -528,7 +390,7 @@ class Yolov3LossKernel : public framework::OpKernel { } }; -template +template class Yolov3LossGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { @@ -559,6 +421,27 @@ class Yolov3LossGradKernel : public framework::OpKernel { tweight.mutable_data({n, an_num, h, w}, ctx.GetPlace()); tconf.mutable_data({n, an_num, h, w}, ctx.GetPlace()); tclass.mutable_data({n, an_num, h, w, class_num}, ctx.GetPlace()); + + math::SetConstant constant; + constant(ctx.template device_context(), &conf_mask, + static_cast(1.0)); + constant(ctx.template device_context(), &obj_mask, + static_cast(0.0)); + constant(ctx.template device_context(), &tx, + static_cast(0.0)); + constant(ctx.template device_context(), &ty, + static_cast(0.0)); + constant(ctx.template device_context(), &tw, + static_cast(0.0)); + constant(ctx.template device_context(), &th, + static_cast(0.0)); + constant(ctx.template device_context(), &tweight, + static_cast(0.0)); + constant(ctx.template device_context(), &tconf, + static_cast(0.0)); + constant(ctx.template device_context(), &tclass, + static_cast(0.0)); + PreProcessGTBox(*gt_box, *gt_label, ignore_thresh, anchors, input_size, h, &conf_mask, &obj_mask, &tx, &ty, &tw, &th, &tweight, &tconf, &tclass); diff --git a/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py b/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py index cf7e2c5289..862e77e663 100644 --- a/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py +++ b/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py @@ -197,12 +197,12 @@ class TestYolov3LossOp(OpTest): max_relative_error=0.31) def initTestCase(self): - self.anchors = [12, 12, 11, 13] + self.anchors = [12, 12] self.class_num = 5 self.ignore_thresh = 0.5 self.input_size = 416 - self.x_shape = (3, len(self.anchors) // 2 * (5 + self.class_num), 5, 5) - self.gtbox_shape = (3, 5, 4) + self.x_shape = (1, len(self.anchors) // 2 * (5 + self.class_num), 3, 3) + self.gtbox_shape = (1, 5, 4) if __name__ == "__main__": From db8ff57a61cbeec30b61111850b3e768661e8de8 Mon Sep 17 00:00:00 2001 From: dengkaipeng Date: Mon, 17 Dec 2018 14:43:06 +0800 Subject: [PATCH 26/53] remove useless code and update doc. test=develop --- paddle/fluid/operators/yolov3_loss_op.cc | 32 +++++----- paddle/fluid/operators/yolov3_loss_op.h | 64 ++++++++----------- python/paddle/fluid/layers/detection.py | 13 ---- .../tests/unittests/test_yolov3_loss_op.py | 5 -- 4 files changed, 45 insertions(+), 69 deletions(-) diff --git a/paddle/fluid/operators/yolov3_loss_op.cc b/paddle/fluid/operators/yolov3_loss_op.cc index 495a8f6c01..aa4ba3b62e 100644 --- a/paddle/fluid/operators/yolov3_loss_op.cc +++ b/paddle/fluid/operators/yolov3_loss_op.cc @@ -138,17 +138,23 @@ class Yolov3LossOpMaker : public framework::OpProtoAndCheckerMaker { thresh, the confidence score loss of this anchor box will be ignored. Therefore, the yolov3 loss consist of three major parts, box location loss, - confidence score loss, and classification loss. The MSE loss is used for - box location, and binary cross entropy loss is used for confidence score - loss and classification loss. + confidence score loss, and classification loss. The L1 loss is used for + box coordinates (w, h), and sigmoid cross entropy loss is used for box + coordinates (x, y), confidence score loss and classification loss. + + In order to trade off box coordinate losses between big boxes and small + boxes, box coordinate losses will be mutiplied by scale weight, which is + calculated as follow. + + $$ + weight_{box} = 2.0 - t_w * t_h + $$ Final loss will be represented as follow. $$ - loss = \loss_weight_{xy} * loss_{xy} + \loss_weight_{wh} * loss_{wh} - + \loss_weight_{conf_target} * loss_{conf_target} - + \loss_weight_{conf_notarget} * loss_{conf_notarget} - + \loss_weight_{class} * loss_{class} + loss = (loss_{xy} + loss_{wh}) * weight_{box} + + loss_{conf} + loss_{class} $$ )DOC"); } @@ -204,11 +210,7 @@ namespace ops = paddle::operators; REGISTER_OPERATOR(yolov3_loss, ops::Yolov3LossOp, ops::Yolov3LossOpMaker, ops::Yolov3LossGradMaker); REGISTER_OPERATOR(yolov3_loss_grad, ops::Yolov3LossOpGrad); -REGISTER_OP_CPU_KERNEL( - yolov3_loss, - ops::Yolov3LossKernel, - ops::Yolov3LossKernel); -REGISTER_OP_CPU_KERNEL( - yolov3_loss_grad, - ops::Yolov3LossGradKernel, - ops::Yolov3LossGradKernel); +REGISTER_OP_CPU_KERNEL(yolov3_loss, ops::Yolov3LossKernel, + ops::Yolov3LossKernel); +REGISTER_OP_CPU_KERNEL(yolov3_loss_grad, ops::Yolov3LossGradKernel, + ops::Yolov3LossGradKernel); diff --git a/paddle/fluid/operators/yolov3_loss_op.h b/paddle/fluid/operators/yolov3_loss_op.h index f086e89a99..e32cd30967 100644 --- a/paddle/fluid/operators/yolov3_loss_op.h +++ b/paddle/fluid/operators/yolov3_loss_op.h @@ -260,26 +260,18 @@ static void CalcYolov3Loss(T* loss_data, const Tensor& input, const Tensor& tx, const int class_num = tclass.dims()[4]; const int grid_num = h * w; - // T l = 0.0; CalcSCE(loss_data, input_data, tx_data, tweight_data, obj_mask_data, n, an_num, grid_num, class_num, 1); CalcSCE(loss_data, input_data + grid_num, ty_data, tweight_data, obj_mask_data, n, an_num, grid_num, class_num, 1); - // LOG(ERROR) << "C++ xy: " << loss_data[0] - l; - // l = loss_data[0]; CalcL1Loss(loss_data, input_data + 2 * grid_num, tw_data, tweight_data, obj_mask_data, n, an_num, grid_num, class_num); CalcL1Loss(loss_data, input_data + 3 * grid_num, th_data, tweight_data, obj_mask_data, n, an_num, grid_num, class_num); - // LOG(ERROR) << "C++ wh: " << loss_data[0] - l; - // l = loss_data[0]; CalcSCE(loss_data, input_data + 4 * grid_num, tconf_data, conf_mask_data, conf_mask_data, n, an_num, grid_num, class_num, 1); - // LOG(ERROR) << "C++ conf: " << loss_data[0] - l; - // l = loss_data[0]; CalcSCE(loss_data, input_data + 5 * grid_num, tclass_data, obj_mask_data, obj_mask_data, n, an_num, grid_num, class_num, class_num); - // LOG(ERROR) << "C++ class: " << loss_data[0] - l; } template @@ -329,7 +321,7 @@ static void CalcYolov3LossGrad(T* input_grad_data, const Tensor& loss_grad, obj_mask_data, n, an_num, grid_num, class_num, class_num); } -template +template class Yolov3LossKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { @@ -359,24 +351,24 @@ class Yolov3LossKernel : public framework::OpKernel { tconf.mutable_data({n, an_num, h, w}, ctx.GetPlace()); tclass.mutable_data({n, an_num, h, w, class_num}, ctx.GetPlace()); - math::SetConstant constant; - constant(ctx.template device_context(), &conf_mask, - static_cast(1.0)); - constant(ctx.template device_context(), &obj_mask, - static_cast(0.0)); - constant(ctx.template device_context(), &tx, - static_cast(0.0)); - constant(ctx.template device_context(), &ty, + math::SetConstant constant; + constant(ctx.template device_context(), + &conf_mask, static_cast(1.0)); + constant(ctx.template device_context(), + &obj_mask, static_cast(0.0)); + constant(ctx.template device_context(), &tx, static_cast(0.0)); - constant(ctx.template device_context(), &tw, + constant(ctx.template device_context(), &ty, static_cast(0.0)); - constant(ctx.template device_context(), &th, + constant(ctx.template device_context(), &tw, static_cast(0.0)); - constant(ctx.template device_context(), &tweight, + constant(ctx.template device_context(), &th, static_cast(0.0)); - constant(ctx.template device_context(), &tconf, + constant(ctx.template device_context(), + &tweight, static_cast(0.0)); + constant(ctx.template device_context(), &tconf, static_cast(0.0)); - constant(ctx.template device_context(), &tclass, + constant(ctx.template device_context(), &tclass, static_cast(0.0)); PreProcessGTBox(*gt_box, *gt_label, ignore_thresh, anchors, input_size, @@ -390,7 +382,7 @@ class Yolov3LossKernel : public framework::OpKernel { } }; -template +template class Yolov3LossGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { @@ -422,24 +414,24 @@ class Yolov3LossGradKernel : public framework::OpKernel { tconf.mutable_data({n, an_num, h, w}, ctx.GetPlace()); tclass.mutable_data({n, an_num, h, w, class_num}, ctx.GetPlace()); - math::SetConstant constant; - constant(ctx.template device_context(), &conf_mask, - static_cast(1.0)); - constant(ctx.template device_context(), &obj_mask, - static_cast(0.0)); - constant(ctx.template device_context(), &tx, - static_cast(0.0)); - constant(ctx.template device_context(), &ty, + math::SetConstant constant; + constant(ctx.template device_context(), + &conf_mask, static_cast(1.0)); + constant(ctx.template device_context(), + &obj_mask, static_cast(0.0)); + constant(ctx.template device_context(), &tx, static_cast(0.0)); - constant(ctx.template device_context(), &tw, + constant(ctx.template device_context(), &ty, static_cast(0.0)); - constant(ctx.template device_context(), &th, + constant(ctx.template device_context(), &tw, static_cast(0.0)); - constant(ctx.template device_context(), &tweight, + constant(ctx.template device_context(), &th, static_cast(0.0)); - constant(ctx.template device_context(), &tconf, + constant(ctx.template device_context(), + &tweight, static_cast(0.0)); + constant(ctx.template device_context(), &tconf, static_cast(0.0)); - constant(ctx.template device_context(), &tclass, + constant(ctx.template device_context(), &tclass, static_cast(0.0)); PreProcessGTBox(*gt_box, *gt_label, ignore_thresh, anchors, input_size, diff --git a/python/paddle/fluid/layers/detection.py b/python/paddle/fluid/layers/detection.py index caa9b1c3d4..92823af1e0 100644 --- a/python/paddle/fluid/layers/detection.py +++ b/python/paddle/fluid/layers/detection.py @@ -485,19 +485,6 @@ def yolov3_loss(x, "input_size": input_size, } - # if loss_weight_xy is not None and isinstance(loss_weight_xy, float): - # self.attrs['loss_weight_xy'] = loss_weight_xy - # if loss_weight_wh is not None and isinstance(loss_weight_wh, float): - # self.attrs['loss_weight_wh'] = loss_weight_wh - # if loss_weight_conf_target is not None and isinstance( - # loss_weight_conf_target, float): - # self.attrs['loss_weight_conf_target'] = loss_weight_conf_target - # if loss_weight_conf_notarget is not None and isinstance( - # loss_weight_conf_notarget, float): - # self.attrs['loss_weight_conf_notarget'] = loss_weight_conf_notarget - # if loss_weight_class is not None and isinstance(loss_weight_class, float): - # self.attrs['loss_weight_class'] = loss_weight_class - helper.append_op( type='yolov3_loss', inputs={"X": x, diff --git a/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py b/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py index 862e77e663..e52047b0ad 100644 --- a/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py +++ b/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py @@ -157,11 +157,6 @@ def YoloV3Loss(x, gtbox, gtlabel, attrs): loss_obj = sce(pred_conf, tconf, conf_mask) loss_class = sce(pred_cls, tcls, obj_mask_expand) - # print("python loss_xy: ", loss_x + loss_y) - # print("python loss_wh: ", loss_w + loss_h) - # print("python loss_obj: ", loss_obj) - # print("python loss_class: ", loss_class) - return loss_x + loss_y + loss_w + loss_h + loss_obj + loss_class From bd6deb1a8bc0b39cde425117b6c6048f4a945a7f Mon Sep 17 00:00:00 2001 From: dengkaipeng Date: Mon, 17 Dec 2018 15:09:56 +0800 Subject: [PATCH 27/53] fix API.spec change. test=develop --- paddle/fluid/API.spec | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index 4acccd0899..f293b0d30e 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -324,7 +324,7 @@ paddle.fluid.layers.generate_mask_labels ArgSpec(args=['im_info', 'gt_classes', paddle.fluid.layers.iou_similarity ArgSpec(args=['x', 'y', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.box_coder ArgSpec(args=['prior_box', 'prior_box_var', 'target_box', 'code_type', 'box_normalized', 'name'], varargs=None, keywords=None, defaults=('encode_center_size', True, None)) paddle.fluid.layers.polygon_box_transform ArgSpec(args=['input', 'name'], varargs=None, keywords=None, defaults=(None,)) -paddle.fluid.layers.yolov3_loss ArgSpec(args=['x', 'gtbox', 'gtlabel', 'anchors', 'class_num', 'ignore_thresh', 'loss_weight_xy', 'loss_weight_wh', 'loss_weight_conf_target', 'loss_weight_conf_notarget', 'loss_weight_class', 'name'], varargs=None, keywords=None, defaults=(None, None, None, None, None, None)) +paddle.fluid.layers.yolov3_loss ArgSpec(args=['x', 'gtbox', 'gtlabel', 'anchors', 'class_num', 'ignore_thresh', 'input_size', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.multiclass_nms ArgSpec(args=['bboxes', 'scores', 'score_threshold', 'nms_top_k', 'keep_top_k', 'nms_threshold', 'normalized', 'nms_eta', 'background_label', 'name'], varargs=None, keywords=None, defaults=(0.3, True, 1.0, 0, None)) paddle.fluid.layers.accuracy ArgSpec(args=['input', 'label', 'k', 'correct', 'total'], varargs=None, keywords=None, defaults=(1, None, None)) paddle.fluid.layers.auc ArgSpec(args=['input', 'label', 'curve', 'num_thresholds', 'topk', 'slide_steps'], varargs=None, keywords=None, defaults=('ROC', 4095, 1, 1)) From e7e4f084e51a3f3a91a32b9eb03bff71963f9e45 Mon Sep 17 00:00:00 2001 From: dengkaipeng Date: Thu, 20 Dec 2018 21:34:05 +0800 Subject: [PATCH 28/53] ignore pred overlap gt > 0.7. test=develop --- paddle/fluid/operators/yolov3_loss_op.cc | 35 +- paddle/fluid/operators/yolov3_loss_op.h | 556 +++++++++++++++--- python/paddle/fluid/layers/detection.py | 14 +- python/paddle/fluid/tests/test_detection.py | 4 +- .../tests/unittests/test_yolov3_loss_op.py | 184 +++++- 5 files changed, 668 insertions(+), 125 deletions(-) diff --git a/paddle/fluid/operators/yolov3_loss_op.cc b/paddle/fluid/operators/yolov3_loss_op.cc index aa4ba3b62e..8c46e341d6 100644 --- a/paddle/fluid/operators/yolov3_loss_op.cc +++ b/paddle/fluid/operators/yolov3_loss_op.cc @@ -35,13 +35,16 @@ class Yolov3LossOp : public framework::OperatorWithKernel { auto dim_gtlabel = ctx->GetInputDim("GTLabel"); auto anchors = ctx->Attrs().Get>("anchors"); int anchor_num = anchors.size() / 2; + auto anchor_mask = ctx->Attrs().Get>("anchor_mask"); + int mask_num = anchor_mask.size(); auto class_num = ctx->Attrs().Get("class_num"); PADDLE_ENFORCE_EQ(dim_x.size(), 4, "Input(X) should be a 4-D tensor."); PADDLE_ENFORCE_EQ(dim_x[2], dim_x[3], "Input(X) dim[3] and dim[4] should be euqal."); - PADDLE_ENFORCE_EQ(dim_x[1], anchor_num * (5 + class_num), - "Input(X) dim[1] should be equal to (anchor_number * (5 " - "+ class_num))."); + PADDLE_ENFORCE_EQ( + dim_x[1], mask_num * (5 + class_num), + "Input(X) dim[1] should be equal to (anchor_mask_number * (5 " + "+ class_num))."); PADDLE_ENFORCE_EQ(dim_gtbox.size(), 3, "Input(GTBox) should be a 3-D tensor"); PADDLE_ENFORCE_EQ(dim_gtbox[2], 4, "Input(GTBox) dim[2] should be 5"); @@ -55,6 +58,11 @@ class Yolov3LossOp : public framework::OperatorWithKernel { "Attr(anchors) length should be greater then 0."); PADDLE_ENFORCE_EQ(anchors.size() % 2, 0, "Attr(anchors) length should be even integer."); + for (size_t i = 0; i < anchor_mask.size(); i++) { + PADDLE_ENFORCE_LT( + anchor_mask[i], anchor_num, + "Attr(anchor_mask) should not crossover Attr(anchors)."); + } PADDLE_ENFORCE_GT(class_num, 0, "Attr(class_num) should be an integer greater then 0."); @@ -74,7 +82,7 @@ class Yolov3LossOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { AddInput("X", - "The input tensor of YOLO v3 loss operator, " + "The input tensor of YOLOv3 loss operator, " "This is a 4-D tensor with shape of [N, C, H, W]." "H and W should be same, and the second dimention(C) stores" "box locations, confidence score and classification one-hot" @@ -99,13 +107,20 @@ class Yolov3LossOpMaker : public framework::OpProtoAndCheckerMaker { AddAttr("class_num", "The number of classes to predict."); AddAttr>("anchors", "The anchor width and height, " - "it will be parsed pair by pair."); - AddAttr("input_size", - "The input size of YOLOv3 net, " - "generally this is set as 320, 416 or 608.") - .SetDefault(406); + "it will be parsed pair by pair.") + .SetDefault(std::vector{}); + AddAttr>("anchor_mask", + "The mask index of anchors used in " + "current YOLOv3 loss calculation.") + .SetDefault(std::vector{}); + AddAttr("downsample", + "The downsample ratio from network input to YOLOv3 loss " + "input, so 32, 16, 8 should be set for the first, second, " + "and thrid YOLOv3 loss operators.") + .SetDefault(32); AddAttr("ignore_thresh", - "The ignore threshold to ignore confidence loss."); + "The ignore threshold to ignore confidence loss.") + .SetDefault(0.7); AddComment(R"DOC( This operator generate yolov3 loss by given predict result and ground truth boxes. diff --git a/paddle/fluid/operators/yolov3_loss_op.h b/paddle/fluid/operators/yolov3_loss_op.h index e32cd30967..9254a6cf6f 100644 --- a/paddle/fluid/operators/yolov3_loss_op.h +++ b/paddle/fluid/operators/yolov3_loss_op.h @@ -321,6 +321,182 @@ static void CalcYolov3LossGrad(T* input_grad_data, const Tensor& loss_grad, obj_mask_data, n, an_num, grid_num, class_num, class_num); } +static int mask_index(std::vector mask, int val) { + for (int i = 0; i < mask.size(); i++) { + if (mask[i] == val) { + return i; + } + } + return -1; +} + +template +struct Box { + float x, y, w, h; +}; + +template +static inline T sigmoid(T x) { + return 1.0 / (1.0 + std::exp(-x)); +} + +template +static inline void sigmoid_arrray(T* arr, int len) { + for (int i = 0; i < len; i++) { + arr[i] = sigmoid(arr[i]); + } +} + +template +static inline Box get_yolo_box(const T* x, std::vector anchors, int i, + int j, int an_idx, int grid_size, + int input_size, int index, int stride) { + Box b; + b.x = (i + sigmoid(x[index])) / grid_size; + b.y = (j + sigmoid(x[index + stride])) / grid_size; + b.w = std::exp(x[index + 2 * stride]) * anchors[2 * an_idx] / input_size; + b.h = std::exp(x[index + 3 * stride]) * anchors[2 * an_idx + 1] / input_size; + return b; +} + +template +static inline Box get_gt_box(const T* gt, int batch, int max_boxes, + int idx) { + Box b; + b.x = gt[(batch * max_boxes + idx) * 4]; + b.y = gt[(batch * max_boxes + idx) * 4 + 1]; + b.w = gt[(batch * max_boxes + idx) * 4 + 2]; + b.h = gt[(batch * max_boxes + idx) * 4 + 3]; + return b; +} + +template +static inline T overlap(T c1, T w1, T c2, T w2) { + T l1 = c1 - w1 / 2.0; + T l2 = c2 - w2 / 2.0; + T left = l1 > l2 ? l1 : l2; + T r1 = c1 + w1 / 2.0; + T r2 = c2 + w2 / 2.0; + T right = r1 < r2 ? r1 : r2; + return right - left; +} + +template +static inline T box_iou(Box b1, Box b2) { + T w = overlap(b1.x, b1.w, b2.x, b2.w); + T h = overlap(b1.y, b1.h, b2.y, b2.h); + T inter_area = (w < 0 || h < 0) ? 0.0 : w * h; + T union_area = b1.w * b1.h + b2.w * b2.h - inter_area; + return inter_area / union_area; +} + +static inline int entry_index(int batch, int an_idx, int hw_idx, int an_num, + int an_stride, int stride, int entry) { + return (batch * an_num + an_idx) * an_stride + entry * stride + hw_idx; +} + +template +static void CalcBoxLocationLoss(T* loss, const T* input, Box gt, + std::vector anchors, int an_idx, + int box_idx, int gi, int gj, int grid_size, + int input_size, int stride) { + T tx = gt.x * grid_size - gi; + T ty = gt.y * grid_size - gj; + T tw = std::log(gt.w * input_size / anchors[2 * an_idx]); + T th = std::log(gt.h * input_size / anchors[2 * an_idx + 1]); + + T scale = 2.0 - gt.w * gt.h; + loss[0] += SCE(input[box_idx], tx) * scale; + loss[0] += SCE(input[box_idx + stride], ty) * scale; + loss[0] += L1Loss(input[box_idx + 2 * stride], tw) * scale; + loss[0] += L1Loss(input[box_idx + 3 * stride], th) * scale; +} + +template +static void CalcBoxLocationLossGrad(T* input_grad, const T loss, const T* input, + Box gt, std::vector anchors, + int an_idx, int box_idx, int gi, int gj, + int grid_size, int input_size, int stride) { + T tx = gt.x * grid_size - gi; + T ty = gt.y * grid_size - gj; + T tw = std::log(gt.w * input_size / anchors[2 * an_idx]); + T th = std::log(gt.h * input_size / anchors[2 * an_idx + 1]); + + T scale = 2.0 - gt.w * gt.h; + input_grad[box_idx] = SCEGrad(input[box_idx], tx) * scale * loss; + input_grad[box_idx + stride] = + SCEGrad(input[box_idx + stride], ty) * scale * loss; + input_grad[box_idx + 2 * stride] = + L1LossGrad(input[box_idx + 2 * stride], tw) * scale * loss; + input_grad[box_idx + 3 * stride] = + L1LossGrad(input[box_idx + 3 * stride], th) * scale * loss; +} + +template +static inline void CalcLabelLoss(T* loss, const T* input, const int index, + const int label, const int class_num, + const int stride) { + for (int i = 0; i < class_num; i++) { + loss[0] += SCE(input[index + i * stride], (i == label) ? 1.0 : 0.0); + } +} + +template +static inline void CalcLabelLossGrad(T* input_grad, const T loss, + const T* input, const int index, + const int label, const int class_num, + const int stride) { + for (int i = 0; i < class_num; i++) { + input_grad[index + i * stride] = + SCEGrad(input[index + i * stride], (i == label) ? 1.0 : 0.0) * loss; + } +} + +template +static inline void CalcObjnessLoss(T* loss, const T* input, const int* objness, + const int n, const int an_num, const int h, + const int w, const int stride, + const int an_stride) { + for (int i = 0; i < n; i++) { + for (int j = 0; j < an_num; j++) { + for (int k = 0; k < h; k++) { + for (int l = 0; l < w; l++) { + int obj = objness[k * w + l]; + if (obj >= 0) { + loss[i] += SCE(input[k * w + l], static_cast(obj)); + } + } + } + objness += stride; + input += an_stride; + } + } +} + +template +static inline void CalcObjnessLossGrad(T* input_grad, const T* loss, + const T* input, const int* objness, + const int n, const int an_num, + const int h, const int w, + const int stride, const int an_stride) { + for (int i = 0; i < n; i++) { + for (int j = 0; j < an_num; j++) { + for (int k = 0; k < h; k++) { + for (int l = 0; l < w; l++) { + int obj = objness[k * w + l]; + if (obj >= 0) { + input_grad[k * w + l] = + SCEGrad(input[k * w + l], static_cast(obj)) * loss[i]; + } + } + } + objness += stride; + input += an_stride; + input_grad += an_stride; + } + } +} + template class Yolov3LossKernel : public framework::OpKernel { public: @@ -330,55 +506,158 @@ class Yolov3LossKernel : public framework::OpKernel { auto* gt_label = ctx.Input("GTLabel"); auto* loss = ctx.Output("Loss"); auto anchors = ctx.Attr>("anchors"); + auto anchor_mask = ctx.Attr>("anchor_mask"); int class_num = ctx.Attr("class_num"); - int input_size = ctx.Attr("input_size"); float ignore_thresh = ctx.Attr("ignore_thresh"); + int downsample = ctx.Attr("downsample"); const int n = input->dims()[0]; const int h = input->dims()[2]; const int w = input->dims()[3]; const int an_num = anchors.size() / 2; + const int mask_num = anchor_mask.size(); + const int b = gt_box->dims()[1]; + int input_size = downsample * h; - Tensor conf_mask, obj_mask; - Tensor tx, ty, tw, th, tweight, tconf, tclass; - conf_mask.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - obj_mask.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - tx.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - ty.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - tw.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - th.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - tweight.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - tconf.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - tclass.mutable_data({n, an_num, h, w, class_num}, ctx.GetPlace()); - - math::SetConstant constant; - constant(ctx.template device_context(), - &conf_mask, static_cast(1.0)); - constant(ctx.template device_context(), - &obj_mask, static_cast(0.0)); - constant(ctx.template device_context(), &tx, - static_cast(0.0)); - constant(ctx.template device_context(), &ty, - static_cast(0.0)); - constant(ctx.template device_context(), &tw, - static_cast(0.0)); - constant(ctx.template device_context(), &th, - static_cast(0.0)); - constant(ctx.template device_context(), - &tweight, static_cast(0.0)); - constant(ctx.template device_context(), &tconf, - static_cast(0.0)); - constant(ctx.template device_context(), &tclass, - static_cast(0.0)); - - PreProcessGTBox(*gt_box, *gt_label, ignore_thresh, anchors, input_size, - h, &conf_mask, &obj_mask, &tx, &ty, &tw, &th, &tweight, - &tconf, &tclass); - + const T* input_data = input->data(); + const T* gt_box_data = gt_box->data(); + const int* gt_label_data = gt_label->data(); T* loss_data = loss->mutable_data({n}, ctx.GetPlace()); - memset(loss_data, 0, n * sizeof(T)); - CalcYolov3Loss(loss_data, *input, tx, ty, tw, th, tweight, tconf, tclass, - conf_mask, obj_mask); + memset(loss_data, 0, n * sizeof(int)); + + Tensor objness; + int* objness_data = + objness.mutable_data({n, mask_num, h, w}, ctx.GetPlace()); + memset(objness_data, 0, objness.numel() * sizeof(int)); + + const int stride = h * w; + const int an_stride = (class_num + 5) * stride; + + for (int i = 0; i < n; i++) { + for (int j = 0; j < mask_num; j++) { + for (int k = 0; k < h; k++) { + for (int l = 0; l < w; l++) { + int box_idx = + entry_index(i, j, k * w + l, mask_num, an_stride, stride, 0); + Box pred = + get_yolo_box(input_data, anchors, l, k, anchor_mask[j], h, + input_size, box_idx, stride); + T best_iou = 0; + // int best_t = 0; + for (int t = 0; t < b; t++) { + if (isZero(gt_box_data[i * b * 4 + t * 4]) && + isZero(gt_box_data[i * b * 4 + t * 4 + 1])) { + continue; + } + Box gt = get_gt_box(gt_box_data, i, b, t); + T iou = box_iou(pred, gt); + if (iou > best_iou) { + best_iou = iou; + // best_t = t; + } + } + + if (best_iou > ignore_thresh) { + int obj_idx = (i * mask_num + j) * stride + k * w + l; + objness_data[obj_idx] = -1; + } + } + } + } + for (int t = 0; t < b; t++) { + if (isZero(gt_box_data[i * b * 4 + t * 4]) && + isZero(gt_box_data[i * b * 4 + t * 4 + 1])) { + continue; + } + Box gt = get_gt_box(gt_box_data, i, b, t); + int gi = static_cast(gt.x * w); + int gj = static_cast(gt.y * h); + Box gt_shift = gt; + gt_shift.x = 0.0; + gt_shift.y = 0.0; + T best_iou = 0.0; + int best_n = 0; + for (int an_idx = 0; an_idx < an_num; an_idx++) { + Box an_box; + an_box.x = 0.0; + an_box.y = 0.0; + an_box.w = anchors[2 * an_idx] / static_cast(input_size); + an_box.h = anchors[2 * an_idx + 1] / static_cast(input_size); + float iou = box_iou(an_box, gt_shift); + // TO DO: iou > 0.5 ? + if (iou > best_iou) { + best_iou = iou; + best_n = an_idx; + } + } + + int mask_idx = mask_index(anchor_mask, best_n); + if (mask_idx >= 0) { + int box_idx = entry_index(i, mask_idx, gj * w + gi, mask_num, + an_stride, stride, 0); + CalcBoxLocationLoss(loss_data + i, input_data, gt, anchors, best_n, + box_idx, gi, gj, h, input_size, stride); + + int obj_idx = (i * mask_num + mask_idx) * stride + gj * w + gi; + objness_data[obj_idx] = 1; + + int label = gt_label_data[i * b + t]; + int label_idx = entry_index(i, mask_idx, gj * w + gi, mask_num, + an_stride, stride, 5); + CalcLabelLoss(loss_data + i, input_data, label_idx, label, + class_num, stride); + } + } + } + + CalcObjnessLoss(loss_data, input_data + 4 * stride, objness_data, n, + mask_num, h, w, stride, an_stride); + + // Tensor conf_mask, obj_mask; + // Tensor tx, ty, tw, th, tweight, tconf, tclass; + // conf_mask.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + // obj_mask.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + // tx.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + // ty.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + // tw.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + // th.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + // tweight.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + // tconf.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + // tclass.mutable_data({n, an_num, h, w, class_num}, ctx.GetPlace()); + // + // math::SetConstant constant; + // constant(ctx.template device_context(), + // &conf_mask, static_cast(1.0)); + // constant(ctx.template device_context(), + // &obj_mask, static_cast(0.0)); + // constant(ctx.template device_context(), &tx, + // static_cast(0.0)); + // constant(ctx.template device_context(), &ty, + // static_cast(0.0)); + // constant(ctx.template device_context(), &tw, + // static_cast(0.0)); + // constant(ctx.template device_context(), &th, + // static_cast(0.0)); + // constant(ctx.template device_context(), + // &tweight, static_cast(0.0)); + // constant(ctx.template device_context(), + // &tconf, + // static_cast(0.0)); + // constant(ctx.template device_context(), + // &tclass, + // static_cast(0.0)); + // + // PreProcessGTBox(*gt_box, *gt_label, ignore_thresh, anchors, + // input_size, + // h, &conf_mask, &obj_mask, &tx, &ty, &tw, &th, + // &tweight, + // &tconf, &tclass); + // + // T* loss_data = loss->mutable_data({n}, ctx.GetPlace()); + // memset(loss_data, 0, n * sizeof(T)); + // CalcYolov3Loss(loss_data, *input, tx, ty, tw, th, tweight, tconf, + // tclass, + // conf_mask, obj_mask); } }; @@ -389,59 +668,172 @@ class Yolov3LossGradKernel : public framework::OpKernel { auto* input = ctx.Input("X"); auto* gt_box = ctx.Input("GTBox"); auto* gt_label = ctx.Input("GTLabel"); + auto* input_grad = ctx.Output(framework::GradVarName("X")); + auto* loss_grad = ctx.Input(framework::GradVarName("Loss")); auto anchors = ctx.Attr>("anchors"); + auto anchor_mask = ctx.Attr>("anchor_mask"); int class_num = ctx.Attr("class_num"); float ignore_thresh = ctx.Attr("ignore_thresh"); - auto* input_grad = ctx.Output(framework::GradVarName("X")); - auto* loss_grad = ctx.Input(framework::GradVarName("Loss")); - int input_size = ctx.Attr("input_size"); + int downsample = ctx.Attr("downsample"); const int n = input->dims()[0]; const int c = input->dims()[1]; const int h = input->dims()[2]; const int w = input->dims()[3]; const int an_num = anchors.size() / 2; - - Tensor conf_mask, obj_mask; - Tensor tx, ty, tw, th, tweight, tconf, tclass; - conf_mask.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - obj_mask.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - tx.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - ty.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - tw.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - th.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - tweight.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - tconf.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - tclass.mutable_data({n, an_num, h, w, class_num}, ctx.GetPlace()); - - math::SetConstant constant; - constant(ctx.template device_context(), - &conf_mask, static_cast(1.0)); - constant(ctx.template device_context(), - &obj_mask, static_cast(0.0)); - constant(ctx.template device_context(), &tx, - static_cast(0.0)); - constant(ctx.template device_context(), &ty, - static_cast(0.0)); - constant(ctx.template device_context(), &tw, - static_cast(0.0)); - constant(ctx.template device_context(), &th, - static_cast(0.0)); - constant(ctx.template device_context(), - &tweight, static_cast(0.0)); - constant(ctx.template device_context(), &tconf, - static_cast(0.0)); - constant(ctx.template device_context(), &tclass, - static_cast(0.0)); - - PreProcessGTBox(*gt_box, *gt_label, ignore_thresh, anchors, input_size, - h, &conf_mask, &obj_mask, &tx, &ty, &tw, &th, &tweight, - &tconf, &tclass); - + const int mask_num = anchor_mask.size(); + const int b = gt_box->dims()[1]; + int input_size = downsample * h; + + const T* input_data = input->data(); + const T* gt_box_data = gt_box->data(); + const int* gt_label_data = gt_label->data(); + const T* loss_grad_data = loss_grad->data(); T* input_grad_data = input_grad->mutable_data({n, c, h, w}, ctx.GetPlace()); - CalcYolov3LossGrad(input_grad_data, *loss_grad, *input, tx, ty, tw, th, - tweight, tconf, tclass, conf_mask, obj_mask); + memset(input_grad_data, 0, input_grad->numel() * sizeof(T)); + + Tensor objness; + int* objness_data = + objness.mutable_data({n, mask_num, h, w}, ctx.GetPlace()); + memset(objness_data, 0, objness.numel() * sizeof(int)); + + const int stride = h * w; + const int an_stride = (class_num + 5) * stride; + + for (int i = 0; i < n; i++) { + for (int j = 0; j < mask_num; j++) { + for (int k = 0; k < h; k++) { + for (int l = 0; l < w; l++) { + int box_idx = + entry_index(i, j, k * w + l, mask_num, an_stride, stride, 0); + Box pred = + get_yolo_box(input_data, anchors, l, k, anchor_mask[j], h, + input_size, box_idx, stride); + T best_iou = 0; + // int best_t = 0; + for (int t = 0; t < b; t++) { + if (isZero(gt_box_data[i * b * 4 + t * 4]) && + isZero(gt_box_data[i * b * 4 + t * 4 + 1])) { + continue; + } + Box gt = get_gt_box(gt_box_data, i, b, t); + T iou = box_iou(pred, gt); + if (iou > best_iou) { + best_iou = iou; + // best_t = t; + } + } + + if (best_iou > ignore_thresh) { + int obj_idx = (i * mask_num + j) * stride + k * w + l; + objness_data[obj_idx] = -1; + } + } + } + } + for (int t = 0; t < b; t++) { + if (isZero(gt_box_data[i * b * 4 + t * 4]) && + isZero(gt_box_data[i * b * 4 + t * 4 + 1])) { + continue; + } + Box gt = get_gt_box(gt_box_data, i, b, t); + int gi = static_cast(gt.x * w); + int gj = static_cast(gt.y * h); + Box gt_shift = gt; + gt_shift.x = 0.0; + gt_shift.y = 0.0; + T best_iou = 0.0; + int best_n = 0; + for (int an_idx = 0; an_idx < an_num; an_idx++) { + Box an_box; + an_box.x = 0.0; + an_box.y = 0.0; + an_box.w = anchors[2 * an_idx] / static_cast(input_size); + an_box.h = anchors[2 * an_idx + 1] / static_cast(input_size); + float iou = box_iou(an_box, gt_shift); + // TO DO: iou > 0.5 ? + if (iou > best_iou) { + best_iou = iou; + best_n = an_idx; + } + } + + int mask_idx = mask_index(anchor_mask, best_n); + if (mask_idx >= 0) { + int box_idx = entry_index(i, mask_idx, gj * w + gi, mask_num, + an_stride, stride, 0); + CalcBoxLocationLossGrad(input_grad_data, loss_grad_data[i], + input_data, gt, anchors, best_n, box_idx, + gi, gj, h, input_size, stride); + + int obj_idx = (i * mask_num + mask_idx) * stride + gj * w + gi; + objness_data[obj_idx] = 1; + + int label = gt_label_data[i * b + t]; + int label_idx = entry_index(i, mask_idx, gj * w + gi, mask_num, + an_stride, stride, 5); + CalcLabelLossGrad(input_grad_data, loss_grad_data[i], input_data, + label_idx, label, class_num, stride); + } + } + } + + CalcObjnessLossGrad(input_grad_data + 4 * stride, loss_grad_data, + input_data + 4 * stride, objness_data, n, mask_num, + h, w, stride, an_stride); + + // const int n = input->dims()[0]; + // const int c = input->dims()[1]; + // const int h = input->dims()[2]; + // const int w = input->dims()[3]; + // const int an_num = anchors.size() / 2; + // + // Tensor conf_mask, obj_mask; + // Tensor tx, ty, tw, th, tweight, tconf, tclass; + // conf_mask.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + // obj_mask.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + // tx.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + // ty.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + // tw.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + // th.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + // tweight.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + // tconf.mutable_data({n, an_num, h, w}, ctx.GetPlace()); + // tclass.mutable_data({n, an_num, h, w, class_num}, ctx.GetPlace()); + // + // math::SetConstant constant; + // constant(ctx.template device_context(), + // &conf_mask, static_cast(1.0)); + // constant(ctx.template device_context(), + // &obj_mask, static_cast(0.0)); + // constant(ctx.template device_context(), &tx, + // static_cast(0.0)); + // constant(ctx.template device_context(), &ty, + // static_cast(0.0)); + // constant(ctx.template device_context(), &tw, + // static_cast(0.0)); + // constant(ctx.template device_context(), &th, + // static_cast(0.0)); + // constant(ctx.template device_context(), + // &tweight, static_cast(0.0)); + // constant(ctx.template device_context(), + // &tconf, + // static_cast(0.0)); + // constant(ctx.template device_context(), + // &tclass, + // static_cast(0.0)); + // + // PreProcessGTBox(*gt_box, *gt_label, ignore_thresh, anchors, + // input_size, + // h, &conf_mask, &obj_mask, &tx, &ty, &tw, &th, + // &tweight, + // &tconf, &tclass); + // + // T* input_grad_data = + // input_grad->mutable_data({n, c, h, w}, ctx.GetPlace()); + // CalcYolov3LossGrad(input_grad_data, *loss_grad, *input, tx, ty, tw, + // th, + // tweight, tconf, tclass, conf_mask, obj_mask); } }; diff --git a/python/paddle/fluid/layers/detection.py b/python/paddle/fluid/layers/detection.py index 92823af1e0..542162b7f4 100644 --- a/python/paddle/fluid/layers/detection.py +++ b/python/paddle/fluid/layers/detection.py @@ -413,9 +413,10 @@ def yolov3_loss(x, gtbox, gtlabel, anchors, + anchor_mask, class_num, ignore_thresh, - input_size, + downsample, name=None): """ ${comment} @@ -430,9 +431,10 @@ def yolov3_loss(x, gtlabel (Variable): class id of ground truth boxes, shoud be ins shape of [N, B]. anchors (list|tuple): ${anchors_comment} + anchor_mask (list|tuple): ${anchor_mask_comment} class_num (int): ${class_num_comment} ignore_thresh (float): ${ignore_thresh_comment} - input_size (int): ${input_size_comment} + downsample (int): ${downsample_comment} name (string): the name of yolov3 loss Returns: @@ -452,7 +454,8 @@ def yolov3_loss(x, x = fluid.layers.data(name='x', shape=[255, 13, 13], dtype='float32') gtbox = fluid.layers.data(name='gtbox', shape=[6, 5], dtype='float32') gtlabel = fluid.layers.data(name='gtlabel', shape=[6, 1], dtype='int32') - anchors = [10, 13, 16, 30, 33, 23] + anchors = [10, 13, 16, 30, 33, 23, 30, 61, 62, 45, 59, 119, 116, 90, 156, 198, 373, 326] + anchors = [0, 1, 2] loss = fluid.layers.yolov3_loss(x=x, gtbox=gtbox, class_num=80 anchors=anchors, ignore_thresh=0.5) """ @@ -466,6 +469,8 @@ def yolov3_loss(x, raise TypeError("Input gtlabel of yolov3_loss must be Variable") if not isinstance(anchors, list) and not isinstance(anchors, tuple): raise TypeError("Attr anchors of yolov3_loss must be list or tuple") + if not isinstance(anchor_mask, list) and not isinstance(anchor_mask, tuple): + raise TypeError("Attr anchor_mask of yolov3_loss must be list or tuple") if not isinstance(class_num, int): raise TypeError("Attr class_num of yolov3_loss must be an integer") if not isinstance(ignore_thresh, float): @@ -480,9 +485,10 @@ def yolov3_loss(x, attrs = { "anchors": anchors, + "anchor_mask": anchor_mask, "class_num": class_num, "ignore_thresh": ignore_thresh, - "input_size": input_size, + "downsample": downsample, } helper.append_op( diff --git a/python/paddle/fluid/tests/test_detection.py b/python/paddle/fluid/tests/test_detection.py index 7d75562900..e11205d2bf 100644 --- a/python/paddle/fluid/tests/test_detection.py +++ b/python/paddle/fluid/tests/test_detection.py @@ -463,8 +463,8 @@ class TestYoloDetection(unittest.TestCase): x = layers.data(name='x', shape=[30, 7, 7], dtype='float32') gtbox = layers.data(name='gtbox', shape=[10, 4], dtype='float32') gtlabel = layers.data(name='gtlabel', shape=[10], dtype='int32') - loss = layers.yolov3_loss(x, gtbox, gtlabel, [10, 13, 30, 13], 10, - 0.7, 416) + loss = layers.yolov3_loss(x, gtbox, gtlabel, [10, 13, 30, 13], + [0, 1], 10, 0.7, 32) self.assertIsNotNone(loss) diff --git a/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py b/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py index e52047b0ad..3cada49647 100644 --- a/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py +++ b/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py @@ -22,32 +22,42 @@ from op_test import OpTest from paddle.fluid import core - -def l1loss(x, y, weight): - n = x.shape[0] - x = x.reshape((n, -1)) - y = y.reshape((n, -1)) - weight = weight.reshape((n, -1)) - return (np.abs(y - x) * weight).sum(axis=1) +# def l1loss(x, y, weight): +# n = x.shape[0] +# x = x.reshape((n, -1)) +# y = y.reshape((n, -1)) +# weight = weight.reshape((n, -1)) +# return (np.abs(y - x) * weight).sum(axis=1) +# +# +# def mse(x, y, weight): +# n = x.shape[0] +# x = x.reshape((n, -1)) +# y = y.reshape((n, -1)) +# weight = weight.reshape((n, -1)) +# return ((y - x)**2 * weight).sum(axis=1) +# +# +# def sce(x, label, weight): +# n = x.shape[0] +# x = x.reshape((n, -1)) +# label = label.reshape((n, -1)) +# weight = weight.reshape((n, -1)) +# sigmoid_x = expit(x) +# term1 = label * np.log(sigmoid_x) +# term2 = (1.0 - label) * np.log(1.0 - sigmoid_x) +# return ((-term1 - term2) * weight).sum(axis=1) -def mse(x, y, weight): - n = x.shape[0] - x = x.reshape((n, -1)) - y = y.reshape((n, -1)) - weight = weight.reshape((n, -1)) - return ((y - x)**2 * weight).sum(axis=1) +def l1loss(x, y): + return abs(x - y) -def sce(x, label, weight): - n = x.shape[0] - x = x.reshape((n, -1)) - label = label.reshape((n, -1)) - weight = weight.reshape((n, -1)) +def sce(x, label): sigmoid_x = expit(x) term1 = label * np.log(sigmoid_x) term2 = (1.0 - label) * np.log(1.0 - sigmoid_x) - return ((-term1 - term2) * weight).sum(axis=1) + return -term1 - term2 def box_iou(box1, box2): @@ -160,6 +170,121 @@ def YoloV3Loss(x, gtbox, gtlabel, attrs): return loss_x + loss_y + loss_w + loss_h + loss_obj + loss_class +def sigmoid(x): + return 1.0 / (1.0 + np.exp(-1.0 * x)) + + +def batch_xywh_box_iou(box1, box2): + b1_left = box1[:, :, 0] - box1[:, :, 2] / 2 + b1_right = box1[:, :, 0] + box1[:, :, 2] / 2 + b1_top = box1[:, :, 1] - box1[:, :, 3] / 2 + b1_bottom = box1[:, :, 1] + box1[:, :, 3] / 2 + + b2_left = box2[:, :, 0] - box2[:, :, 2] / 2 + b2_right = box2[:, :, 0] + box2[:, :, 2] / 2 + b2_top = box2[:, :, 1] - box2[:, :, 3] / 2 + b2_bottom = box2[:, :, 1] + box2[:, :, 3] / 2 + + left = np.maximum(b1_left[:, :, np.newaxis], b2_left[:, np.newaxis, :]) + right = np.minimum(b1_right[:, :, np.newaxis], b2_right[:, np.newaxis, :]) + top = np.maximum(b1_top[:, :, np.newaxis], b2_top[:, np.newaxis, :]) + bottom = np.minimum(b1_bottom[:, :, np.newaxis], + b2_bottom[:, np.newaxis, :]) + + inter_w = np.clip(right - left, 0., 1.) + inter_h = np.clip(bottom - top, 0., 1.) + inter_area = inter_w * inter_h + + b1_area = (b1_right - b1_left) * (b1_bottom - b1_top) + b2_area = (b2_right - b2_left) * (b2_bottom - b2_top) + union = b1_area[:, :, np.newaxis] + b2_area[:, np.newaxis, :] - inter_area + + return inter_area / union + + +def YOLOv3Loss(x, gtbox, gtlabel, attrs): + n, c, h, w = x.shape + b = gtbox.shape[1] + anchors = attrs['anchors'] + an_num = len(anchors) // 2 + anchor_mask = attrs['anchor_mask'] + mask_num = len(anchor_mask) + class_num = attrs["class_num"] + ignore_thresh = attrs['ignore_thresh'] + downsample = attrs['downsample'] + input_size = downsample * h + x = x.reshape((n, mask_num, 5 + class_num, h, w)).transpose((0, 1, 3, 4, 2)) + loss = np.zeros((n)).astype('float32') + + pred_box = x[:, :, :, :, :4].copy() + grid_x = np.tile(np.arange(w).reshape((1, w)), (h, 1)) + grid_y = np.tile(np.arange(h).reshape((h, 1)), (1, w)) + pred_box[:, :, :, :, 0] = (grid_x + sigmoid(pred_box[:, :, :, :, 0])) / w + pred_box[:, :, :, :, 1] = (grid_y + sigmoid(pred_box[:, :, :, :, 1])) / h + + mask_anchors = [] + for m in anchor_mask: + mask_anchors.append((anchors[2 * m], anchors[2 * m + 1])) + anchors_s = np.array( + [(an_w / input_size, an_h / input_size) for an_w, an_h in mask_anchors]) + anchor_w = anchors_s[:, 0:1].reshape((1, mask_num, 1, 1)) + anchor_h = anchors_s[:, 1:2].reshape((1, mask_num, 1, 1)) + pred_box[:, :, :, :, 2] = np.exp(pred_box[:, :, :, :, 2]) * anchor_w + pred_box[:, :, :, :, 3] = np.exp(pred_box[:, :, :, :, 3]) * anchor_h + + pred_box = pred_box.reshape((n, -1, 4)) + pred_obj = x[:, :, :, :, 4].reshape((n, -1)) + objness = np.zeros(pred_box.shape[:2]) + ious = batch_xywh_box_iou(pred_box, gtbox) + ious_max = np.max(ious, axis=-1) + objness = np.where(ious_max > ignore_thresh, -np.ones_like(objness), + objness) + + gtbox_shift = gtbox.copy() + gtbox_shift[:, :, 0] = 0 + gtbox_shift[:, :, 1] = 0 + + anchors = [(anchors[2 * i], anchors[2 * i + 1]) for i in range(0, an_num)] + anchors_s = np.array( + [(an_w / input_size, an_h / input_size) for an_w, an_h in anchors]) + anchor_boxes = np.concatenate( + [np.zeros_like(anchors_s), anchors_s], axis=-1) + anchor_boxes = np.tile(anchor_boxes[np.newaxis, :, :], (n, 1, 1)) + ious = batch_xywh_box_iou(gtbox_shift, anchor_boxes) + iou_matches = np.argmax(ious, axis=-1) + for i in range(n): + for j in range(b): + if gtbox[i, j, 2:].sum() == 0: + continue + if iou_matches[i, j] not in anchor_mask: + continue + an_idx = anchor_mask.index(iou_matches[i, j]) + gi = int(gtbox[i, j, 0] * w) + gj = int(gtbox[i, j, 1] * h) + + tx = gtbox[i, j, 0] * w - gi + ty = gtbox[i, j, 1] * w - gj + tw = np.log(gtbox[i, j, 2] * input_size / mask_anchors[an_idx][0]) + th = np.log(gtbox[i, j, 3] * input_size / mask_anchors[an_idx][1]) + scale = 2.0 - gtbox[i, j, 2] * gtbox[i, j, 3] + loss[i] += sce(x[i, an_idx, gj, gi, 0], tx) * scale + loss[i] += sce(x[i, an_idx, gj, gi, 1], ty) * scale + loss[i] += l1loss(x[i, an_idx, gj, gi, 2], tw) * scale + loss[i] += l1loss(x[i, an_idx, gj, gi, 3], th) * scale + + objness[i, an_idx * h * w + gj * w + gi] = 1 + + for label_idx in range(class_num): + loss[i] += sce(x[i, an_idx, gj, gi, 5 + label_idx], + int(label_idx == gtlabel[i, j])) + + for j in range(mask_num * h * w): + if objness[i, j] >= 0: + loss[i] += sce(pred_obj[i, j], objness[i, j]) + + return loss + + class TestYolov3LossOp(OpTest): def setUp(self): self.initTestCase() @@ -171,13 +296,14 @@ class TestYolov3LossOp(OpTest): self.attrs = { "anchors": self.anchors, + "anchor_mask": self.anchor_mask, "class_num": self.class_num, "ignore_thresh": self.ignore_thresh, - "input_size": self.input_size, + "downsample": self.downsample, } self.inputs = {'X': x, 'GTBox': gtbox, 'GTLabel': gtlabel} - self.outputs = {'Loss': YoloV3Loss(x, gtbox, gtlabel, self.attrs)} + self.outputs = {'Loss': YOLOv3Loss(x, gtbox, gtlabel, self.attrs)} def test_check_output(self): place = core.CPUPlace() @@ -189,15 +315,19 @@ class TestYolov3LossOp(OpTest): place, ['X'], 'Loss', no_grad_set=set(["GTBox", "GTLabel"]), - max_relative_error=0.31) + max_relative_error=0.15) def initTestCase(self): - self.anchors = [12, 12] + self.anchors = [ + 10, 13, 16, 30, 33, 23, 30, 61, 62, 45, 59, 119, 116, 90, 156, 198, + 373, 326 + ] + self.anchor_mask = [0, 1, 2] self.class_num = 5 - self.ignore_thresh = 0.5 - self.input_size = 416 - self.x_shape = (1, len(self.anchors) // 2 * (5 + self.class_num), 3, 3) - self.gtbox_shape = (1, 5, 4) + self.ignore_thresh = 0.7 + self.downsample = 32 + self.x_shape = (3, len(self.anchor_mask) * (5 + self.class_num), 5, 5) + self.gtbox_shape = (3, 10, 4) if __name__ == "__main__": From 6c5a5d078920d7be79e5346e5cc6870b1b6b3aa3 Mon Sep 17 00:00:00 2001 From: dengkaipeng Date: Fri, 21 Dec 2018 12:13:57 +0800 Subject: [PATCH 29/53] format code. test=develop --- paddle/fluid/API.spec | 2 +- paddle/fluid/operators/yolov3_loss_op.h | 472 ++---------------- .../tests/unittests/test_yolov3_loss_op.py | 148 +----- 3 files changed, 53 insertions(+), 569 deletions(-) diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index f293b0d30e..6c6ac9c7ea 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -324,7 +324,7 @@ paddle.fluid.layers.generate_mask_labels ArgSpec(args=['im_info', 'gt_classes', paddle.fluid.layers.iou_similarity ArgSpec(args=['x', 'y', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.box_coder ArgSpec(args=['prior_box', 'prior_box_var', 'target_box', 'code_type', 'box_normalized', 'name'], varargs=None, keywords=None, defaults=('encode_center_size', True, None)) paddle.fluid.layers.polygon_box_transform ArgSpec(args=['input', 'name'], varargs=None, keywords=None, defaults=(None,)) -paddle.fluid.layers.yolov3_loss ArgSpec(args=['x', 'gtbox', 'gtlabel', 'anchors', 'class_num', 'ignore_thresh', 'input_size', 'name'], varargs=None, keywords=None, defaults=(None,)) +paddle.fluid.layers.yolov3_loss ArgSpec(args=['x', 'gtbox', 'gtlabel', 'anchors', 'anchor_mask', 'class_num', 'ignore_thresh', 'downsample', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.multiclass_nms ArgSpec(args=['bboxes', 'scores', 'score_threshold', 'nms_top_k', 'keep_top_k', 'nms_threshold', 'normalized', 'nms_eta', 'background_label', 'name'], varargs=None, keywords=None, defaults=(0.3, True, 1.0, 0, None)) paddle.fluid.layers.accuracy ArgSpec(args=['input', 'label', 'k', 'correct', 'total'], varargs=None, keywords=None, defaults=(1, None, None)) paddle.fluid.layers.auc ArgSpec(args=['input', 'label', 'curve', 'num_thresholds', 'topk', 'slide_steps'], varargs=None, keywords=None, defaults=('ROC', 4095, 1, 1)) diff --git a/paddle/fluid/operators/yolov3_loss_op.h b/paddle/fluid/operators/yolov3_loss_op.h index 9254a6cf6f..12499befca 100644 --- a/paddle/fluid/operators/yolov3_loss_op.h +++ b/paddle/fluid/operators/yolov3_loss_op.h @@ -26,110 +26,9 @@ template using EigenVector = framework::EigenVector; -using Array5 = Eigen::DSizes; - -template -static inline bool isZero(T x) { - return fabs(x) < 1e-6; -} - template -static T CalcBoxIoU(std::vector box1, std::vector box2) { - T b1_x1 = box1[0] - box1[2] / 2; - T b1_x2 = box1[0] + box1[2] / 2; - T b1_y1 = box1[1] - box1[3] / 2; - T b1_y2 = box1[1] + box1[3] / 2; - T b2_x1 = box2[0] - box2[2] / 2; - T b2_x2 = box2[0] + box2[2] / 2; - T b2_y1 = box2[1] - box2[3] / 2; - T b2_y2 = box2[1] + box2[3] / 2; - - T b1_area = (b1_x2 - b1_x1) * (b1_y2 - b1_y1); - T b2_area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1); - - T inter_rect_x1 = std::max(b1_x1, b2_x1); - T inter_rect_y1 = std::max(b1_y1, b2_y1); - T inter_rect_x2 = std::min(b1_x2, b2_x2); - T inter_rect_y2 = std::min(b1_y2, b2_y2); - T inter_area = std::max(inter_rect_x2 - inter_rect_x1, static_cast(0.0)) * - std::max(inter_rect_y2 - inter_rect_y1, static_cast(0.0)); - - return inter_area / (b1_area + b2_area - inter_area); -} - -template -static void PreProcessGTBox(const Tensor& gt_box, const Tensor& gt_label, - const float ignore_thresh, std::vector anchors, - const int input_size, const int grid_size, - Tensor* conf_mask, Tensor* obj_mask, Tensor* tx, - Tensor* ty, Tensor* tw, Tensor* th, Tensor* tweight, - Tensor* tconf, Tensor* tclass) { - const int n = gt_box.dims()[0]; - const int b = gt_box.dims()[1]; - const int an_num = anchors.size() / 2; - const int h = tclass->dims()[2]; - const int w = tclass->dims()[3]; - const int class_num = tclass->dims()[4]; - - const T* gt_box_data = gt_box.data(); - const int* gt_label_data = gt_label.data(); - T* conf_mask_data = conf_mask->data(); - T* obj_mask_data = obj_mask->data(); - T* tx_data = tx->data(); - T* ty_data = ty->data(); - T* tw_data = tw->data(); - T* th_data = th->data(); - T* tweight_data = tweight->data(); - T* tconf_data = tconf->data(); - T* tclass_data = tclass->data(); - - for (int i = 0; i < n; i++) { - for (int j = 0; j < b; j++) { - int box_idx = (i * b + j) * 4; - if (isZero(gt_box_data[box_idx + 2]) && - isZero(gt_box_data[box_idx + 3])) { - continue; - } - - int cur_label = gt_label_data[i * b + j]; - T gx = gt_box_data[box_idx] * grid_size; - T gy = gt_box_data[box_idx + 1] * grid_size; - T gw = gt_box_data[box_idx + 2] * input_size; - T gh = gt_box_data[box_idx + 3] * input_size; - int gi = static_cast(gx); - int gj = static_cast(gy); - - T max_iou = static_cast(0); - T iou; - int best_an_index = -1; - std::vector gt_box_shape({0, 0, gw, gh}); - for (int an_idx = 0; an_idx < an_num; an_idx++) { - std::vector anchor_shape({0, 0, static_cast(anchors[2 * an_idx]), - static_cast(anchors[2 * an_idx + 1])}); - iou = CalcBoxIoU(gt_box_shape, anchor_shape); - if (iou > max_iou) { - max_iou = iou; - best_an_index = an_idx; - } - if (iou > ignore_thresh) { - int conf_idx = ((i * an_num + an_idx) * h + gj) * w + gi; - conf_mask_data[conf_idx] = static_cast(0.0); - } - } - - int obj_idx = ((i * an_num + best_an_index) * h + gj) * w + gi; - conf_mask_data[obj_idx] = static_cast(1.0); - obj_mask_data[obj_idx] = static_cast(1.0); - tx_data[obj_idx] = gx - gi; - ty_data[obj_idx] = gy - gj; - tw_data[obj_idx] = log(gw / anchors[2 * best_an_index]); - th_data[obj_idx] = log(gh / anchors[2 * best_an_index + 1]); - tweight_data[obj_idx] = - 2.0 - gt_box_data[box_idx + 2] * gt_box_data[box_idx + 3]; - tconf_data[obj_idx] = static_cast(1.0); - tclass_data[obj_idx * class_num + cur_label] = static_cast(1.0); - } - } +static inline bool LessEqualZero(T x) { + return x < 1e-6; } template @@ -152,177 +51,8 @@ static T L1LossGrad(T x, T y) { return x > y ? 1.0 : -1.0; } -template -static void CalcSCE(T* loss_data, const T* input, const T* target, - const T* weight, const T* mask, const int n, - const int an_num, const int grid_num, const int class_num, - const int num) { - for (int i = 0; i < n; i++) { - for (int j = 0; j < an_num; j++) { - for (int k = 0; k < grid_num; k++) { - int sub_idx = k * num; - for (int l = 0; l < num; l++) { - loss_data[i] += SCE(input[l * grid_num + k], target[sub_idx + l]) * - weight[k] * mask[k]; - } - } - input += (class_num + 5) * grid_num; - target += grid_num * num; - weight += grid_num; - mask += grid_num; - } - } -} - -template -static void CalcSCEGrad(T* input_grad, const T* loss_grad, const T* input, - const T* target, const T* weight, const T* mask, - const int n, const int an_num, const int grid_num, - const int class_num, const int num) { - for (int i = 0; i < n; i++) { - for (int j = 0; j < an_num; j++) { - for (int k = 0; k < grid_num; k++) { - int sub_idx = k * num; - for (int l = 0; l < num; l++) { - input_grad[l * grid_num + k] = - SCEGrad(input[l * grid_num + k], target[sub_idx + l]) * - weight[k] * mask[k] * loss_grad[i]; - } - } - input_grad += (class_num + 5) * grid_num; - input += (class_num + 5) * grid_num; - target += grid_num * num; - weight += grid_num; - mask += grid_num; - } - } -} - -template -static void CalcL1Loss(T* loss_data, const T* input, const T* target, - const T* weight, const T* mask, const int n, - const int an_num, const int grid_num, - const int class_num) { - for (int i = 0; i < n; i++) { - for (int j = 0; j < an_num; j++) { - for (int k = 0; k < grid_num; k++) { - loss_data[i] += L1Loss(input[k], target[k]) * weight[k] * mask[k]; - } - input += (class_num + 5) * grid_num; - target += grid_num; - weight += grid_num; - mask += grid_num; - } - } -} - -template -static void CalcL1LossGrad(T* input_grad, const T* loss_grad, const T* input, - const T* target, const T* weight, const T* mask, - const int n, const int an_num, const int grid_num, - const int class_num) { - for (int i = 0; i < n; i++) { - for (int j = 0; j < an_num; j++) { - for (int k = 0; k < grid_num; k++) { - input_grad[k] = L1LossGrad(input[k], target[k]) * weight[k] * - mask[k] * loss_grad[i]; - } - input_grad += (class_num + 5) * grid_num; - input += (class_num + 5) * grid_num; - target += grid_num; - weight += grid_num; - mask += grid_num; - } - } -} - -template -static void CalcYolov3Loss(T* loss_data, const Tensor& input, const Tensor& tx, - const Tensor& ty, const Tensor& tw, const Tensor& th, - const Tensor& tweight, const Tensor& tconf, - const Tensor& tclass, const Tensor& conf_mask, - const Tensor& obj_mask) { - const T* input_data = input.data(); - const T* tx_data = tx.data(); - const T* ty_data = ty.data(); - const T* tw_data = tw.data(); - const T* th_data = th.data(); - const T* tweight_data = tweight.data(); - const T* tconf_data = tconf.data(); - const T* tclass_data = tclass.data(); - const T* conf_mask_data = conf_mask.data(); - const T* obj_mask_data = obj_mask.data(); - - const int n = tclass.dims()[0]; - const int an_num = tclass.dims()[1]; - const int h = tclass.dims()[2]; - const int w = tclass.dims()[3]; - const int class_num = tclass.dims()[4]; - const int grid_num = h * w; - - CalcSCE(loss_data, input_data, tx_data, tweight_data, obj_mask_data, n, - an_num, grid_num, class_num, 1); - CalcSCE(loss_data, input_data + grid_num, ty_data, tweight_data, - obj_mask_data, n, an_num, grid_num, class_num, 1); - CalcL1Loss(loss_data, input_data + 2 * grid_num, tw_data, tweight_data, - obj_mask_data, n, an_num, grid_num, class_num); - CalcL1Loss(loss_data, input_data + 3 * grid_num, th_data, tweight_data, - obj_mask_data, n, an_num, grid_num, class_num); - CalcSCE(loss_data, input_data + 4 * grid_num, tconf_data, conf_mask_data, - conf_mask_data, n, an_num, grid_num, class_num, 1); - CalcSCE(loss_data, input_data + 5 * grid_num, tclass_data, obj_mask_data, - obj_mask_data, n, an_num, grid_num, class_num, class_num); -} - -template -static void CalcYolov3LossGrad(T* input_grad_data, const Tensor& loss_grad, - const Tensor& input, const Tensor& tx, - const Tensor& ty, const Tensor& tw, - const Tensor& th, const Tensor& tweight, - const Tensor& tconf, const Tensor& tclass, - const Tensor& conf_mask, - const Tensor& obj_mask) { - const T* loss_grad_data = loss_grad.data(); - const T* input_data = input.data(); - const T* tx_data = tx.data(); - const T* ty_data = ty.data(); - const T* tw_data = tw.data(); - const T* th_data = th.data(); - const T* tweight_data = tweight.data(); - const T* tconf_data = tconf.data(); - const T* tclass_data = tclass.data(); - const T* conf_mask_data = conf_mask.data(); - const T* obj_mask_data = obj_mask.data(); - - const int n = tclass.dims()[0]; - const int an_num = tclass.dims()[1]; - const int h = tclass.dims()[2]; - const int w = tclass.dims()[3]; - const int class_num = tclass.dims()[4]; - const int grid_num = h * w; - - CalcSCEGrad(input_grad_data, loss_grad_data, input_data, tx_data, - tweight_data, obj_mask_data, n, an_num, grid_num, class_num, - 1); - CalcSCEGrad(input_grad_data + grid_num, loss_grad_data, - input_data + grid_num, ty_data, tweight_data, obj_mask_data, n, - an_num, grid_num, class_num, 1); - CalcL1LossGrad(input_grad_data + 2 * grid_num, loss_grad_data, - input_data + 2 * grid_num, tw_data, tweight_data, - obj_mask_data, n, an_num, grid_num, class_num); - CalcL1LossGrad(input_grad_data + 3 * grid_num, loss_grad_data, - input_data + 3 * grid_num, th_data, tweight_data, - obj_mask_data, n, an_num, grid_num, class_num); - CalcSCEGrad(input_grad_data + 4 * grid_num, loss_grad_data, - input_data + 4 * grid_num, tconf_data, conf_mask_data, - conf_mask_data, n, an_num, grid_num, class_num, 1); - CalcSCEGrad(input_grad_data + 5 * grid_num, loss_grad_data, - input_data + 5 * grid_num, tclass_data, obj_mask_data, - obj_mask_data, n, an_num, grid_num, class_num, class_num); -} - -static int mask_index(std::vector mask, int val) { - for (int i = 0; i < mask.size(); i++) { +static int GetMaskIndex(std::vector mask, int val) { + for (size_t i = 0; i < mask.size(); i++) { if (mask[i] == val) { return i; } @@ -341,16 +71,9 @@ static inline T sigmoid(T x) { } template -static inline void sigmoid_arrray(T* arr, int len) { - for (int i = 0; i < len; i++) { - arr[i] = sigmoid(arr[i]); - } -} - -template -static inline Box get_yolo_box(const T* x, std::vector anchors, int i, - int j, int an_idx, int grid_size, - int input_size, int index, int stride) { +static inline Box GetYoloBox(const T* x, std::vector anchors, int i, + int j, int an_idx, int grid_size, + int input_size, int index, int stride) { Box b; b.x = (i + sigmoid(x[index])) / grid_size; b.y = (j + sigmoid(x[index + stride])) / grid_size; @@ -360,8 +83,7 @@ static inline Box get_yolo_box(const T* x, std::vector anchors, int i, } template -static inline Box get_gt_box(const T* gt, int batch, int max_boxes, - int idx) { +static inline Box GetGtBox(const T* gt, int batch, int max_boxes, int idx) { Box b; b.x = gt[(batch * max_boxes + idx) * 4]; b.y = gt[(batch * max_boxes + idx) * 4 + 1]; @@ -371,7 +93,7 @@ static inline Box get_gt_box(const T* gt, int batch, int max_boxes, } template -static inline T overlap(T c1, T w1, T c2, T w2) { +static inline T BoxOverlap(T c1, T w1, T c2, T w2) { T l1 = c1 - w1 / 2.0; T l2 = c2 - w2 / 2.0; T left = l1 > l2 ? l1 : l2; @@ -382,16 +104,16 @@ static inline T overlap(T c1, T w1, T c2, T w2) { } template -static inline T box_iou(Box b1, Box b2) { - T w = overlap(b1.x, b1.w, b2.x, b2.w); - T h = overlap(b1.y, b1.h, b2.y, b2.h); +static inline T CalcBoxIoU(Box b1, Box b2) { + T w = BoxOverlap(b1.x, b1.w, b2.x, b2.w); + T h = BoxOverlap(b1.y, b1.h, b2.y, b2.h); T inter_area = (w < 0 || h < 0) ? 0.0 : w * h; T union_area = b1.w * b1.h + b2.w * b2.h - inter_area; return inter_area / union_area; } -static inline int entry_index(int batch, int an_idx, int hw_idx, int an_num, - int an_stride, int stride, int entry) { +static inline int GetEntryIndex(int batch, int an_idx, int hw_idx, int an_num, + int an_stride, int stride, int entry) { return (batch * an_num + an_idx) * an_stride + entry * stride + hw_idx; } @@ -523,7 +245,7 @@ class Yolov3LossKernel : public framework::OpKernel { const T* gt_box_data = gt_box->data(); const int* gt_label_data = gt_label->data(); T* loss_data = loss->mutable_data({n}, ctx.GetPlace()); - memset(loss_data, 0, n * sizeof(int)); + memset(loss_data, 0, loss->numel() * sizeof(T)); Tensor objness; int* objness_data = @@ -538,22 +260,18 @@ class Yolov3LossKernel : public framework::OpKernel { for (int k = 0; k < h; k++) { for (int l = 0; l < w; l++) { int box_idx = - entry_index(i, j, k * w + l, mask_num, an_stride, stride, 0); - Box pred = - get_yolo_box(input_data, anchors, l, k, anchor_mask[j], h, - input_size, box_idx, stride); + GetEntryIndex(i, j, k * w + l, mask_num, an_stride, stride, 0); + Box pred = GetYoloBox(input_data, anchors, l, k, anchor_mask[j], + h, input_size, box_idx, stride); T best_iou = 0; - // int best_t = 0; for (int t = 0; t < b; t++) { - if (isZero(gt_box_data[i * b * 4 + t * 4]) && - isZero(gt_box_data[i * b * 4 + t * 4 + 1])) { + Box gt = GetGtBox(gt_box_data, i, b, t); + if (LessEqualZero(gt.w) || LessEqualZero(gt.h)) { continue; } - Box gt = get_gt_box(gt_box_data, i, b, t); - T iou = box_iou(pred, gt); + T iou = CalcBoxIoU(pred, gt); if (iou > best_iou) { best_iou = iou; - // best_t = t; } } @@ -565,11 +283,10 @@ class Yolov3LossKernel : public framework::OpKernel { } } for (int t = 0; t < b; t++) { - if (isZero(gt_box_data[i * b * 4 + t * 4]) && - isZero(gt_box_data[i * b * 4 + t * 4 + 1])) { + Box gt = GetGtBox(gt_box_data, i, b, t); + if (LessEqualZero(gt.w) || LessEqualZero(gt.h)) { continue; } - Box gt = get_gt_box(gt_box_data, i, b, t); int gi = static_cast(gt.x * w); int gj = static_cast(gt.y * h); Box gt_shift = gt; @@ -583,7 +300,7 @@ class Yolov3LossKernel : public framework::OpKernel { an_box.y = 0.0; an_box.w = anchors[2 * an_idx] / static_cast(input_size); an_box.h = anchors[2 * an_idx + 1] / static_cast(input_size); - float iou = box_iou(an_box, gt_shift); + float iou = CalcBoxIoU(an_box, gt_shift); // TO DO: iou > 0.5 ? if (iou > best_iou) { best_iou = iou; @@ -591,10 +308,10 @@ class Yolov3LossKernel : public framework::OpKernel { } } - int mask_idx = mask_index(anchor_mask, best_n); + int mask_idx = GetMaskIndex(anchor_mask, best_n); if (mask_idx >= 0) { - int box_idx = entry_index(i, mask_idx, gj * w + gi, mask_num, - an_stride, stride, 0); + int box_idx = GetEntryIndex(i, mask_idx, gj * w + gi, mask_num, + an_stride, stride, 0); CalcBoxLocationLoss(loss_data + i, input_data, gt, anchors, best_n, box_idx, gi, gj, h, input_size, stride); @@ -602,8 +319,8 @@ class Yolov3LossKernel : public framework::OpKernel { objness_data[obj_idx] = 1; int label = gt_label_data[i * b + t]; - int label_idx = entry_index(i, mask_idx, gj * w + gi, mask_num, - an_stride, stride, 5); + int label_idx = GetEntryIndex(i, mask_idx, gj * w + gi, mask_num, + an_stride, stride, 5); CalcLabelLoss(loss_data + i, input_data, label_idx, label, class_num, stride); } @@ -612,52 +329,6 @@ class Yolov3LossKernel : public framework::OpKernel { CalcObjnessLoss(loss_data, input_data + 4 * stride, objness_data, n, mask_num, h, w, stride, an_stride); - - // Tensor conf_mask, obj_mask; - // Tensor tx, ty, tw, th, tweight, tconf, tclass; - // conf_mask.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - // obj_mask.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - // tx.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - // ty.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - // tw.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - // th.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - // tweight.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - // tconf.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - // tclass.mutable_data({n, an_num, h, w, class_num}, ctx.GetPlace()); - // - // math::SetConstant constant; - // constant(ctx.template device_context(), - // &conf_mask, static_cast(1.0)); - // constant(ctx.template device_context(), - // &obj_mask, static_cast(0.0)); - // constant(ctx.template device_context(), &tx, - // static_cast(0.0)); - // constant(ctx.template device_context(), &ty, - // static_cast(0.0)); - // constant(ctx.template device_context(), &tw, - // static_cast(0.0)); - // constant(ctx.template device_context(), &th, - // static_cast(0.0)); - // constant(ctx.template device_context(), - // &tweight, static_cast(0.0)); - // constant(ctx.template device_context(), - // &tconf, - // static_cast(0.0)); - // constant(ctx.template device_context(), - // &tclass, - // static_cast(0.0)); - // - // PreProcessGTBox(*gt_box, *gt_label, ignore_thresh, anchors, - // input_size, - // h, &conf_mask, &obj_mask, &tx, &ty, &tw, &th, - // &tweight, - // &tconf, &tclass); - // - // T* loss_data = loss->mutable_data({n}, ctx.GetPlace()); - // memset(loss_data, 0, n * sizeof(T)); - // CalcYolov3Loss(loss_data, *input, tx, ty, tw, th, tweight, tconf, - // tclass, - // conf_mask, obj_mask); } }; @@ -706,22 +377,18 @@ class Yolov3LossGradKernel : public framework::OpKernel { for (int k = 0; k < h; k++) { for (int l = 0; l < w; l++) { int box_idx = - entry_index(i, j, k * w + l, mask_num, an_stride, stride, 0); - Box pred = - get_yolo_box(input_data, anchors, l, k, anchor_mask[j], h, - input_size, box_idx, stride); + GetEntryIndex(i, j, k * w + l, mask_num, an_stride, stride, 0); + Box pred = GetYoloBox(input_data, anchors, l, k, anchor_mask[j], + h, input_size, box_idx, stride); T best_iou = 0; - // int best_t = 0; for (int t = 0; t < b; t++) { - if (isZero(gt_box_data[i * b * 4 + t * 4]) && - isZero(gt_box_data[i * b * 4 + t * 4 + 1])) { + Box gt = GetGtBox(gt_box_data, i, b, t); + if (LessEqualZero(gt.w) || LessEqualZero(gt.h)) { continue; } - Box gt = get_gt_box(gt_box_data, i, b, t); - T iou = box_iou(pred, gt); + T iou = CalcBoxIoU(pred, gt); if (iou > best_iou) { best_iou = iou; - // best_t = t; } } @@ -733,11 +400,10 @@ class Yolov3LossGradKernel : public framework::OpKernel { } } for (int t = 0; t < b; t++) { - if (isZero(gt_box_data[i * b * 4 + t * 4]) && - isZero(gt_box_data[i * b * 4 + t * 4 + 1])) { + Box gt = GetGtBox(gt_box_data, i, b, t); + if (LessEqualZero(gt.w) || LessEqualZero(gt.h)) { continue; } - Box gt = get_gt_box(gt_box_data, i, b, t); int gi = static_cast(gt.x * w); int gj = static_cast(gt.y * h); Box gt_shift = gt; @@ -751,7 +417,7 @@ class Yolov3LossGradKernel : public framework::OpKernel { an_box.y = 0.0; an_box.w = anchors[2 * an_idx] / static_cast(input_size); an_box.h = anchors[2 * an_idx + 1] / static_cast(input_size); - float iou = box_iou(an_box, gt_shift); + float iou = CalcBoxIoU(an_box, gt_shift); // TO DO: iou > 0.5 ? if (iou > best_iou) { best_iou = iou; @@ -759,10 +425,10 @@ class Yolov3LossGradKernel : public framework::OpKernel { } } - int mask_idx = mask_index(anchor_mask, best_n); + int mask_idx = GetMaskIndex(anchor_mask, best_n); if (mask_idx >= 0) { - int box_idx = entry_index(i, mask_idx, gj * w + gi, mask_num, - an_stride, stride, 0); + int box_idx = GetEntryIndex(i, mask_idx, gj * w + gi, mask_num, + an_stride, stride, 0); CalcBoxLocationLossGrad(input_grad_data, loss_grad_data[i], input_data, gt, anchors, best_n, box_idx, gi, gj, h, input_size, stride); @@ -771,8 +437,8 @@ class Yolov3LossGradKernel : public framework::OpKernel { objness_data[obj_idx] = 1; int label = gt_label_data[i * b + t]; - int label_idx = entry_index(i, mask_idx, gj * w + gi, mask_num, - an_stride, stride, 5); + int label_idx = GetEntryIndex(i, mask_idx, gj * w + gi, mask_num, + an_stride, stride, 5); CalcLabelLossGrad(input_grad_data, loss_grad_data[i], input_data, label_idx, label, class_num, stride); } @@ -782,58 +448,6 @@ class Yolov3LossGradKernel : public framework::OpKernel { CalcObjnessLossGrad(input_grad_data + 4 * stride, loss_grad_data, input_data + 4 * stride, objness_data, n, mask_num, h, w, stride, an_stride); - - // const int n = input->dims()[0]; - // const int c = input->dims()[1]; - // const int h = input->dims()[2]; - // const int w = input->dims()[3]; - // const int an_num = anchors.size() / 2; - // - // Tensor conf_mask, obj_mask; - // Tensor tx, ty, tw, th, tweight, tconf, tclass; - // conf_mask.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - // obj_mask.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - // tx.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - // ty.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - // tw.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - // th.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - // tweight.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - // tconf.mutable_data({n, an_num, h, w}, ctx.GetPlace()); - // tclass.mutable_data({n, an_num, h, w, class_num}, ctx.GetPlace()); - // - // math::SetConstant constant; - // constant(ctx.template device_context(), - // &conf_mask, static_cast(1.0)); - // constant(ctx.template device_context(), - // &obj_mask, static_cast(0.0)); - // constant(ctx.template device_context(), &tx, - // static_cast(0.0)); - // constant(ctx.template device_context(), &ty, - // static_cast(0.0)); - // constant(ctx.template device_context(), &tw, - // static_cast(0.0)); - // constant(ctx.template device_context(), &th, - // static_cast(0.0)); - // constant(ctx.template device_context(), - // &tweight, static_cast(0.0)); - // constant(ctx.template device_context(), - // &tconf, - // static_cast(0.0)); - // constant(ctx.template device_context(), - // &tclass, - // static_cast(0.0)); - // - // PreProcessGTBox(*gt_box, *gt_label, ignore_thresh, anchors, - // input_size, - // h, &conf_mask, &obj_mask, &tx, &ty, &tw, &th, - // &tweight, - // &tconf, &tclass); - // - // T* input_grad_data = - // input_grad->mutable_data({n, c, h, w}, ctx.GetPlace()); - // CalcYolov3LossGrad(input_grad_data, *loss_grad, *input, tx, ty, tw, - // th, - // tweight, tconf, tclass, conf_mask, obj_mask); } }; diff --git a/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py b/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py index 3cada49647..188acea2b9 100644 --- a/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py +++ b/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py @@ -22,32 +22,6 @@ from op_test import OpTest from paddle.fluid import core -# def l1loss(x, y, weight): -# n = x.shape[0] -# x = x.reshape((n, -1)) -# y = y.reshape((n, -1)) -# weight = weight.reshape((n, -1)) -# return (np.abs(y - x) * weight).sum(axis=1) -# -# -# def mse(x, y, weight): -# n = x.shape[0] -# x = x.reshape((n, -1)) -# y = y.reshape((n, -1)) -# weight = weight.reshape((n, -1)) -# return ((y - x)**2 * weight).sum(axis=1) -# -# -# def sce(x, label, weight): -# n = x.shape[0] -# x = x.reshape((n, -1)) -# label = label.reshape((n, -1)) -# weight = weight.reshape((n, -1)) -# sigmoid_x = expit(x) -# term1 = label * np.log(sigmoid_x) -# term2 = (1.0 - label) * np.log(1.0 - sigmoid_x) -# return ((-term1 - term2) * weight).sum(axis=1) - def l1loss(x, y): return abs(x - y) @@ -60,116 +34,6 @@ def sce(x, label): return -term1 - term2 -def box_iou(box1, box2): - b1_x1 = box1[0] - box1[2] / 2 - b1_x2 = box1[0] + box1[2] / 2 - b1_y1 = box1[1] - box1[3] / 2 - b1_y2 = box1[1] + box1[3] / 2 - b2_x1 = box2[0] - box2[2] / 2 - b2_x2 = box2[0] + box2[2] / 2 - b2_y1 = box2[1] - box2[3] / 2 - b2_y2 = box2[1] + box2[3] / 2 - - b1_area = (b1_x2 - b1_x1) * (b1_y2 - b1_y1) - b2_area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1) - - inter_rect_x1 = max(b1_x1, b2_x1) - inter_rect_y1 = max(b1_y1, b2_y1) - inter_rect_x2 = min(b1_x2, b2_x2) - inter_rect_y2 = min(b1_y2, b2_y2) - inter_area = max(inter_rect_x2 - inter_rect_x1, 0) * max( - inter_rect_y2 - inter_rect_y1, 0) - - return inter_area / (b1_area + b2_area + inter_area) - - -def build_target(gtboxes, gtlabel, attrs, grid_size): - n, b, _ = gtboxes.shape - ignore_thresh = attrs["ignore_thresh"] - anchors = attrs["anchors"] - class_num = attrs["class_num"] - input_size = attrs["input_size"] - an_num = len(anchors) // 2 - conf_mask = np.ones((n, an_num, grid_size, grid_size)).astype('float32') - obj_mask = np.zeros((n, an_num, grid_size, grid_size)).astype('float32') - tx = np.zeros((n, an_num, grid_size, grid_size)).astype('float32') - ty = np.zeros((n, an_num, grid_size, grid_size)).astype('float32') - tw = np.zeros((n, an_num, grid_size, grid_size)).astype('float32') - th = np.zeros((n, an_num, grid_size, grid_size)).astype('float32') - tweight = np.zeros((n, an_num, grid_size, grid_size)).astype('float32') - tconf = np.zeros((n, an_num, grid_size, grid_size)).astype('float32') - tcls = np.zeros( - (n, an_num, grid_size, grid_size, class_num)).astype('float32') - - for i in range(n): - for j in range(b): - if gtboxes[i, j, :].sum() == 0: - continue - - gt_label = gtlabel[i, j] - gx = gtboxes[i, j, 0] * grid_size - gy = gtboxes[i, j, 1] * grid_size - gw = gtboxes[i, j, 2] * input_size - gh = gtboxes[i, j, 3] * input_size - - gi = int(gx) - gj = int(gy) - - gtbox = [0, 0, gw, gh] - max_iou = 0 - for k in range(an_num): - anchor_box = [0, 0, anchors[2 * k], anchors[2 * k + 1]] - iou = box_iou(gtbox, anchor_box) - if iou > max_iou: - max_iou = iou - best_an_index = k - if iou > ignore_thresh: - conf_mask[i, best_an_index, gj, gi] = 0 - - conf_mask[i, best_an_index, gj, gi] = 1 - obj_mask[i, best_an_index, gj, gi] = 1 - tx[i, best_an_index, gj, gi] = gx - gi - ty[i, best_an_index, gj, gi] = gy - gj - tw[i, best_an_index, gj, gi] = np.log(gw / anchors[2 * - best_an_index]) - th[i, best_an_index, gj, gi] = np.log( - gh / anchors[2 * best_an_index + 1]) - tweight[i, best_an_index, gj, gi] = 2.0 - gtboxes[ - i, j, 2] * gtboxes[i, j, 3] - tconf[i, best_an_index, gj, gi] = 1 - tcls[i, best_an_index, gj, gi, gt_label] = 1 - - return (tx, ty, tw, th, tweight, tconf, tcls, conf_mask, obj_mask) - - -def YoloV3Loss(x, gtbox, gtlabel, attrs): - n, c, h, w = x.shape - an_num = len(attrs['anchors']) // 2 - class_num = attrs["class_num"] - x = x.reshape((n, an_num, 5 + class_num, h, w)).transpose((0, 1, 3, 4, 2)) - pred_x = x[:, :, :, :, 0] - pred_y = x[:, :, :, :, 1] - pred_w = x[:, :, :, :, 2] - pred_h = x[:, :, :, :, 3] - pred_conf = x[:, :, :, :, 4] - pred_cls = x[:, :, :, :, 5:] - - tx, ty, tw, th, tweight, tconf, tcls, conf_mask, obj_mask = build_target( - gtbox, gtlabel, attrs, x.shape[2]) - - obj_weight = obj_mask * tweight - obj_mask_expand = np.tile( - np.expand_dims(obj_mask, 4), (1, 1, 1, 1, int(attrs['class_num']))) - loss_x = sce(pred_x, tx, obj_weight) - loss_y = sce(pred_y, ty, obj_weight) - loss_w = l1loss(pred_w, tw, obj_weight) - loss_h = l1loss(pred_h, th, obj_weight) - loss_obj = sce(pred_conf, tconf, conf_mask) - loss_class = sce(pred_cls, tcls, obj_mask_expand) - - return loss_x + loss_y + loss_w + loss_h + loss_obj + loss_class - - def sigmoid(x): return 1.0 / (1.0 + np.exp(-1.0 * x)) @@ -291,8 +155,10 @@ class TestYolov3LossOp(OpTest): self.op_type = 'yolov3_loss' x = logit(np.random.uniform(0, 1, self.x_shape).astype('float32')) gtbox = np.random.random(size=self.gtbox_shape).astype('float32') - gtlabel = np.random.randint(0, self.class_num, - self.gtbox_shape[:2]).astype('int32') + gtlabel = np.random.randint(0, self.class_num, self.gtbox_shape[:2]) + gtmask = np.random.randint(0, 2, self.gtbox_shape[:2]) + gtbox = gtbox * gtmask[:, :, np.newaxis] + gtlabel = gtlabel * gtmask self.attrs = { "anchors": self.anchors, @@ -302,7 +168,11 @@ class TestYolov3LossOp(OpTest): "downsample": self.downsample, } - self.inputs = {'X': x, 'GTBox': gtbox, 'GTLabel': gtlabel} + self.inputs = { + 'X': x, + 'GTBox': gtbox.astype('float32'), + 'GTLabel': gtlabel.astype('int32') + } self.outputs = {'Loss': YOLOv3Loss(x, gtbox, gtlabel, self.attrs)} def test_check_output(self): From 32d533c2cd9aa6dcd0d3cbe9b9685f97d378337e Mon Sep 17 00:00:00 2001 From: dengkaipeng Date: Fri, 28 Dec 2018 17:49:02 +0800 Subject: [PATCH 30/53] cache obj_mask and gt_match_mask. test=develop --- paddle/fluid/operators/yolov3_loss_op.cc | 23 ++++ paddle/fluid/operators/yolov3_loss_op.h | 110 +++++------------- python/paddle/fluid/layers/detection.py | 9 +- .../tests/unittests/test_yolov3_loss_op.py | 16 ++- 4 files changed, 76 insertions(+), 82 deletions(-) diff --git a/paddle/fluid/operators/yolov3_loss_op.cc b/paddle/fluid/operators/yolov3_loss_op.cc index 8c46e341d6..5b777f0448 100644 --- a/paddle/fluid/operators/yolov3_loss_op.cc +++ b/paddle/fluid/operators/yolov3_loss_op.cc @@ -29,6 +29,11 @@ class Yolov3LossOp : public framework::OperatorWithKernel { "Input(GTLabel) of Yolov3LossOp should not be null."); PADDLE_ENFORCE(ctx->HasOutput("Loss"), "Output(Loss) of Yolov3LossOp should not be null."); + PADDLE_ENFORCE( + ctx->HasOutput("ObjectnessMask"), + "Output(ObjectnessMask) of Yolov3LossOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("GTMatchMask"), + "Output(GTMatchMask) of Yolov3LossOp should not be null."); auto dim_x = ctx->GetInputDim("X"); auto dim_gtbox = ctx->GetInputDim("GTBox"); @@ -68,6 +73,12 @@ class Yolov3LossOp : public framework::OperatorWithKernel { std::vector dim_out({dim_x[0]}); ctx->SetOutputDim("Loss", framework::make_ddim(dim_out)); + + std::vector dim_obj_mask({dim_x[0], mask_num, dim_x[2], dim_x[3]}); + ctx->SetOutputDim("ObjectnessMask", framework::make_ddim(dim_obj_mask)); + + std::vector dim_gt_match_mask({dim_gtbox[0], dim_gtbox[1]}); + ctx->SetOutputDim("GTMatchMask", framework::make_ddim(dim_gt_match_mask)); } protected: @@ -103,6 +114,16 @@ class Yolov3LossOpMaker : public framework::OpProtoAndCheckerMaker { AddOutput("Loss", "The output yolov3 loss tensor, " "This is a 1-D tensor with shape of [N]"); + AddOutput("ObjectnessMask", + "This is an intermediate tensor with shape of [N, M, H, W], " + "M is the number of anchor masks. This parameter caches the " + "mask for calculate objectness loss in gradient kernel.") + .AsIntermediate(); + AddOutput("GTMatchMask", + "This is an intermediate tensor with shape if [N, B], " + "B is the max box number of GT boxes. This parameter caches " + "matched mask index of each GT boxes for gradient calculate.") + .AsIntermediate(); AddAttr("class_num", "The number of classes to predict."); AddAttr>("anchors", @@ -208,6 +229,8 @@ class Yolov3LossGradMaker : public framework::SingleGradOpDescMaker { op->SetInput("GTBox", Input("GTBox")); op->SetInput("GTLabel", Input("GTLabel")); op->SetInput(framework::GradVarName("Loss"), OutputGrad("Loss")); + op->SetInput("ObjectnessMask", Output("ObjectnessMask")); + op->SetInput("GTMatchMask", Output("GTMatchMask")); op->SetAttrMap(Attrs()); diff --git a/paddle/fluid/operators/yolov3_loss_op.h b/paddle/fluid/operators/yolov3_loss_op.h index 12499befca..85d93cf96f 100644 --- a/paddle/fluid/operators/yolov3_loss_op.h +++ b/paddle/fluid/operators/yolov3_loss_op.h @@ -227,6 +227,8 @@ class Yolov3LossKernel : public framework::OpKernel { auto* gt_box = ctx.Input("GTBox"); auto* gt_label = ctx.Input("GTLabel"); auto* loss = ctx.Output("Loss"); + auto* objness_mask = ctx.Output("ObjectnessMask"); + auto* gt_match_mask = ctx.Output("GTMatchMask"); auto anchors = ctx.Attr>("anchors"); auto anchor_mask = ctx.Attr>("anchor_mask"); int class_num = ctx.Attr("class_num"); @@ -241,19 +243,19 @@ class Yolov3LossKernel : public framework::OpKernel { const int b = gt_box->dims()[1]; int input_size = downsample * h; + const int stride = h * w; + const int an_stride = (class_num + 5) * stride; + const T* input_data = input->data(); const T* gt_box_data = gt_box->data(); const int* gt_label_data = gt_label->data(); T* loss_data = loss->mutable_data({n}, ctx.GetPlace()); memset(loss_data, 0, loss->numel() * sizeof(T)); - - Tensor objness; - int* objness_data = - objness.mutable_data({n, mask_num, h, w}, ctx.GetPlace()); - memset(objness_data, 0, objness.numel() * sizeof(int)); - - const int stride = h * w; - const int an_stride = (class_num + 5) * stride; + int* obj_mask_data = + objness_mask->mutable_data({n, mask_num, h, w}, ctx.GetPlace()); + memset(obj_mask_data, 0, objness_mask->numel() * sizeof(int)); + int* gt_match_mask_data = + gt_match_mask->mutable_data({n, b}, ctx.GetPlace()); for (int i = 0; i < n; i++) { for (int j = 0; j < mask_num; j++) { @@ -277,7 +279,7 @@ class Yolov3LossKernel : public framework::OpKernel { if (best_iou > ignore_thresh) { int obj_idx = (i * mask_num + j) * stride + k * w + l; - objness_data[obj_idx] = -1; + obj_mask_data[obj_idx] = -1; } } } @@ -285,6 +287,7 @@ class Yolov3LossKernel : public framework::OpKernel { for (int t = 0; t < b; t++) { Box gt = GetGtBox(gt_box_data, i, b, t); if (LessEqualZero(gt.w) || LessEqualZero(gt.h)) { + gt_match_mask_data[i * b + t] = -1; continue; } int gi = static_cast(gt.x * w); @@ -309,6 +312,7 @@ class Yolov3LossKernel : public framework::OpKernel { } int mask_idx = GetMaskIndex(anchor_mask, best_n); + gt_match_mask_data[i * b + t] = mask_idx; if (mask_idx >= 0) { int box_idx = GetEntryIndex(i, mask_idx, gj * w + gi, mask_num, an_stride, stride, 0); @@ -316,7 +320,7 @@ class Yolov3LossKernel : public framework::OpKernel { box_idx, gi, gj, h, input_size, stride); int obj_idx = (i * mask_num + mask_idx) * stride + gj * w + gi; - objness_data[obj_idx] = 1; + obj_mask_data[obj_idx] = 1; int label = gt_label_data[i * b + t]; int label_idx = GetEntryIndex(i, mask_idx, gj * w + gi, mask_num, @@ -327,7 +331,7 @@ class Yolov3LossKernel : public framework::OpKernel { } } - CalcObjnessLoss(loss_data, input_data + 4 * stride, objness_data, n, + CalcObjnessLoss(loss_data, input_data + 4 * stride, obj_mask_data, n, mask_num, h, w, stride, an_stride); } }; @@ -341,64 +345,35 @@ class Yolov3LossGradKernel : public framework::OpKernel { auto* gt_label = ctx.Input("GTLabel"); auto* input_grad = ctx.Output(framework::GradVarName("X")); auto* loss_grad = ctx.Input(framework::GradVarName("Loss")); + auto* objness_mask = ctx.Input("ObjectnessMask"); + auto* gt_match_mask = ctx.Input("GTMatchMask"); auto anchors = ctx.Attr>("anchors"); auto anchor_mask = ctx.Attr>("anchor_mask"); int class_num = ctx.Attr("class_num"); - float ignore_thresh = ctx.Attr("ignore_thresh"); int downsample = ctx.Attr("downsample"); - const int n = input->dims()[0]; - const int c = input->dims()[1]; - const int h = input->dims()[2]; - const int w = input->dims()[3]; - const int an_num = anchors.size() / 2; + const int n = input_grad->dims()[0]; + const int c = input_grad->dims()[1]; + const int h = input_grad->dims()[2]; + const int w = input_grad->dims()[3]; const int mask_num = anchor_mask.size(); - const int b = gt_box->dims()[1]; + const int b = gt_match_mask->dims()[1]; int input_size = downsample * h; + const int stride = h * w; + const int an_stride = (class_num + 5) * stride; + const T* input_data = input->data(); const T* gt_box_data = gt_box->data(); const int* gt_label_data = gt_label->data(); const T* loss_grad_data = loss_grad->data(); + const int* obj_mask_data = objness_mask->data(); + const int* gt_match_mask_data = gt_match_mask->data(); T* input_grad_data = input_grad->mutable_data({n, c, h, w}, ctx.GetPlace()); memset(input_grad_data, 0, input_grad->numel() * sizeof(T)); - Tensor objness; - int* objness_data = - objness.mutable_data({n, mask_num, h, w}, ctx.GetPlace()); - memset(objness_data, 0, objness.numel() * sizeof(int)); - - const int stride = h * w; - const int an_stride = (class_num + 5) * stride; - for (int i = 0; i < n; i++) { - for (int j = 0; j < mask_num; j++) { - for (int k = 0; k < h; k++) { - for (int l = 0; l < w; l++) { - int box_idx = - GetEntryIndex(i, j, k * w + l, mask_num, an_stride, stride, 0); - Box pred = GetYoloBox(input_data, anchors, l, k, anchor_mask[j], - h, input_size, box_idx, stride); - T best_iou = 0; - for (int t = 0; t < b; t++) { - Box gt = GetGtBox(gt_box_data, i, b, t); - if (LessEqualZero(gt.w) || LessEqualZero(gt.h)) { - continue; - } - T iou = CalcBoxIoU(pred, gt); - if (iou > best_iou) { - best_iou = iou; - } - } - - if (best_iou > ignore_thresh) { - int obj_idx = (i * mask_num + j) * stride + k * w + l; - objness_data[obj_idx] = -1; - } - } - } - } for (int t = 0; t < b; t++) { Box gt = GetGtBox(gt_box_data, i, b, t); if (LessEqualZero(gt.w) || LessEqualZero(gt.h)) { @@ -406,35 +381,14 @@ class Yolov3LossGradKernel : public framework::OpKernel { } int gi = static_cast(gt.x * w); int gj = static_cast(gt.y * h); - Box gt_shift = gt; - gt_shift.x = 0.0; - gt_shift.y = 0.0; - T best_iou = 0.0; - int best_n = 0; - for (int an_idx = 0; an_idx < an_num; an_idx++) { - Box an_box; - an_box.x = 0.0; - an_box.y = 0.0; - an_box.w = anchors[2 * an_idx] / static_cast(input_size); - an_box.h = anchors[2 * an_idx + 1] / static_cast(input_size); - float iou = CalcBoxIoU(an_box, gt_shift); - // TO DO: iou > 0.5 ? - if (iou > best_iou) { - best_iou = iou; - best_n = an_idx; - } - } - int mask_idx = GetMaskIndex(anchor_mask, best_n); + int mask_idx = gt_match_mask_data[i * b + t]; if (mask_idx >= 0) { int box_idx = GetEntryIndex(i, mask_idx, gj * w + gi, mask_num, an_stride, stride, 0); - CalcBoxLocationLossGrad(input_grad_data, loss_grad_data[i], - input_data, gt, anchors, best_n, box_idx, - gi, gj, h, input_size, stride); - - int obj_idx = (i * mask_num + mask_idx) * stride + gj * w + gi; - objness_data[obj_idx] = 1; + CalcBoxLocationLossGrad( + input_grad_data, loss_grad_data[i], input_data, gt, anchors, + anchor_mask[mask_idx], box_idx, gi, gj, h, input_size, stride); int label = gt_label_data[i * b + t]; int label_idx = GetEntryIndex(i, mask_idx, gj * w + gi, mask_num, @@ -446,7 +400,7 @@ class Yolov3LossGradKernel : public framework::OpKernel { } CalcObjnessLossGrad(input_grad_data + 4 * stride, loss_grad_data, - input_data + 4 * stride, objness_data, n, mask_num, + input_data + 4 * stride, obj_mask_data, n, mask_num, h, w, stride, an_stride); } }; diff --git a/python/paddle/fluid/layers/detection.py b/python/paddle/fluid/layers/detection.py index 542162b7f4..90d112aa01 100644 --- a/python/paddle/fluid/layers/detection.py +++ b/python/paddle/fluid/layers/detection.py @@ -483,6 +483,9 @@ def yolov3_loss(x, loss = helper.create_variable( name=name, dtype=x.dtype, persistable=False) + objectness_mask = helper.create_variable_for_type_inference(dtype='int32') + gt_match_mask = helper.create_variable_for_type_inference(dtype='int32') + attrs = { "anchors": anchors, "anchor_mask": anchor_mask, @@ -496,7 +499,11 @@ def yolov3_loss(x, inputs={"X": x, "GTBox": gtbox, "GTLabel": gtlabel}, - outputs={'Loss': loss}, + outputs={ + 'Loss': loss, + 'ObjectnessMask': objectness_mask, + 'GTMatchMask': gt_match_mask + }, attrs=attrs) return loss diff --git a/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py b/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py index 188acea2b9..904bee00c1 100644 --- a/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py +++ b/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py @@ -116,13 +116,17 @@ def YOLOv3Loss(x, gtbox, gtlabel, attrs): anchor_boxes = np.tile(anchor_boxes[np.newaxis, :, :], (n, 1, 1)) ious = batch_xywh_box_iou(gtbox_shift, anchor_boxes) iou_matches = np.argmax(ious, axis=-1) + gt_matches = iou_matches.copy() for i in range(n): for j in range(b): if gtbox[i, j, 2:].sum() == 0: + gt_matches[i, j] = -1 continue if iou_matches[i, j] not in anchor_mask: + gt_matches[i, j] = -1 continue an_idx = anchor_mask.index(iou_matches[i, j]) + gt_matches[i, j] = an_idx gi = int(gtbox[i, j, 0] * w) gj = int(gtbox[i, j, 1] * h) @@ -146,7 +150,8 @@ def YOLOv3Loss(x, gtbox, gtlabel, attrs): if objness[i, j] >= 0: loss[i] += sce(pred_obj[i, j], objness[i, j]) - return loss + return (loss, objness.reshape((n, mask_num, h, w)).astype('int32'), \ + gt_matches.astype('int32')) class TestYolov3LossOp(OpTest): @@ -173,11 +178,16 @@ class TestYolov3LossOp(OpTest): 'GTBox': gtbox.astype('float32'), 'GTLabel': gtlabel.astype('int32') } - self.outputs = {'Loss': YOLOv3Loss(x, gtbox, gtlabel, self.attrs)} + loss, objness, gt_matches = YOLOv3Loss(x, gtbox, gtlabel, self.attrs) + self.outputs = { + 'Loss': loss, + 'ObjectnessMask': objness, + "GTMatchMask": gt_matches + } def test_check_output(self): place = core.CPUPlace() - self.check_output_with_place(place, atol=1e-3) + self.check_output_with_place(place, atol=2e-3) def test_check_grad_ignore_gtbox(self): place = core.CPUPlace() From cc01db6029c84b5e059d355b95dd73d18894594f Mon Sep 17 00:00:00 2001 From: dengkaipeng Date: Fri, 28 Dec 2018 20:06:52 +0800 Subject: [PATCH 31/53] calc valid gt before loss calc. test=develop --- paddle/fluid/operators/yolov3_loss_op.h | 41 ++++++++++++++++++++++--- 1 file changed, 36 insertions(+), 5 deletions(-) diff --git a/paddle/fluid/operators/yolov3_loss_op.h b/paddle/fluid/operators/yolov3_loss_op.h index 85d93cf96f..301e2f4033 100644 --- a/paddle/fluid/operators/yolov3_loss_op.h +++ b/paddle/fluid/operators/yolov3_loss_op.h @@ -219,6 +219,22 @@ static inline void CalcObjnessLossGrad(T* input_grad, const T* loss, } } +template +static void inline GtValid(bool* valid, const T* gtbox, const int n, + const int b) { + for (int i = 0; i < n; i++) { + for (int j = 0; j < b; j++) { + if (LessEqualZero(gtbox[j * 4 + 2]) || LessEqualZero(gtbox[j * 4 + 3])) { + valid[j] = false; + } else { + valid[j] = true; + } + } + valid += b; + gtbox += b * 4; + } +} + template class Yolov3LossKernel : public framework::OpKernel { public: @@ -257,20 +273,28 @@ class Yolov3LossKernel : public framework::OpKernel { int* gt_match_mask_data = gt_match_mask->mutable_data({n, b}, ctx.GetPlace()); + // calc valid gt box mask, avoid calc duplicately in following code + Tensor gt_valid_mask; + bool* gt_valid_mask_data = + gt_valid_mask.mutable_data({n, b}, ctx.GetPlace()); + GtValid(gt_valid_mask_data, gt_box_data, n, b); + for (int i = 0; i < n; i++) { for (int j = 0; j < mask_num; j++) { for (int k = 0; k < h; k++) { for (int l = 0; l < w; l++) { + // each predict box find a best match gt box, if overlap is bigger + // then ignore_thresh, ignore the objectness loss. int box_idx = GetEntryIndex(i, j, k * w + l, mask_num, an_stride, stride, 0); Box pred = GetYoloBox(input_data, anchors, l, k, anchor_mask[j], h, input_size, box_idx, stride); T best_iou = 0; for (int t = 0; t < b; t++) { - Box gt = GetGtBox(gt_box_data, i, b, t); - if (LessEqualZero(gt.w) || LessEqualZero(gt.h)) { + if (!gt_valid_mask_data[i * b + t]) { continue; } + Box gt = GetGtBox(gt_box_data, i, b, t); T iou = CalcBoxIoU(pred, gt); if (iou > best_iou) { best_iou = iou; @@ -281,15 +305,18 @@ class Yolov3LossKernel : public framework::OpKernel { int obj_idx = (i * mask_num + j) * stride + k * w + l; obj_mask_data[obj_idx] = -1; } + // TODO(dengkaipeng): all losses should be calculated if best IoU + // is bigger then truth thresh should be calculated here, but + // currently, truth thresh is an unreachable value as 1.0. } } } for (int t = 0; t < b; t++) { - Box gt = GetGtBox(gt_box_data, i, b, t); - if (LessEqualZero(gt.w) || LessEqualZero(gt.h)) { + if (!gt_valid_mask_data[i * b + t]) { gt_match_mask_data[i * b + t] = -1; continue; } + Box gt = GetGtBox(gt_box_data, i, b, t); int gi = static_cast(gt.x * w); int gj = static_cast(gt.y * h); Box gt_shift = gt; @@ -297,6 +324,9 @@ class Yolov3LossKernel : public framework::OpKernel { gt_shift.y = 0.0; T best_iou = 0.0; int best_n = 0; + // each gt box find a best match anchor box as positive sample, + // for positive sample, all losses should be calculated, and for + // other samples, only objectness loss is required. for (int an_idx = 0; an_idx < an_num; an_idx++) { Box an_box; an_box.x = 0.0; @@ -304,7 +334,8 @@ class Yolov3LossKernel : public framework::OpKernel { an_box.w = anchors[2 * an_idx] / static_cast(input_size); an_box.h = anchors[2 * an_idx + 1] / static_cast(input_size); float iou = CalcBoxIoU(an_box, gt_shift); - // TO DO: iou > 0.5 ? + // TODO(dengkaipeng): In paper, objectness loss is ignore when + // best IoU > 0.5, but darknet code didn't implement this. if (iou > best_iou) { best_iou = iou; best_n = an_idx; From 3c08f620c248c506116dbb5a58224de9743bb048 Mon Sep 17 00:00:00 2001 From: dengkaipeng Date: Thu, 3 Jan 2019 11:16:29 +0800 Subject: [PATCH 32/53] add label smooth. test=develop --- paddle/fluid/operators/yolov3_loss_op.h | 19 ++++++++++--------- .../tests/unittests/test_yolov3_loss_op.py | 6 +++++- 2 files changed, 15 insertions(+), 10 deletions(-) diff --git a/paddle/fluid/operators/yolov3_loss_op.h b/paddle/fluid/operators/yolov3_loss_op.h index 301e2f4033..34119b1a02 100644 --- a/paddle/fluid/operators/yolov3_loss_op.h +++ b/paddle/fluid/operators/yolov3_loss_op.h @@ -159,7 +159,9 @@ static inline void CalcLabelLoss(T* loss, const T* input, const int index, const int label, const int class_num, const int stride) { for (int i = 0; i < class_num; i++) { - loss[0] += SCE(input[index + i * stride], (i == label) ? 1.0 : 0.0); + T pred = input[index + i * stride] < -0.5 ? input[index + i * stride] + : 1.0 / class_num; + loss[0] += SCE(pred, (i == label) ? 1.0 : 0.0); } } @@ -169,8 +171,10 @@ static inline void CalcLabelLossGrad(T* input_grad, const T loss, const int label, const int class_num, const int stride) { for (int i = 0; i < class_num; i++) { + T pred = input[index + i * stride] < -0.5 ? input[index + i * stride] + : 1.0 / class_num; input_grad[index + i * stride] = - SCEGrad(input[index + i * stride], (i == label) ? 1.0 : 0.0) * loss; + SCEGrad(pred, (i == label) ? 1.0 : 0.0) * loss; } } @@ -406,15 +410,12 @@ class Yolov3LossGradKernel : public framework::OpKernel { for (int i = 0; i < n; i++) { for (int t = 0; t < b; t++) { - Box gt = GetGtBox(gt_box_data, i, b, t); - if (LessEqualZero(gt.w) || LessEqualZero(gt.h)) { - continue; - } - int gi = static_cast(gt.x * w); - int gj = static_cast(gt.y * h); - int mask_idx = gt_match_mask_data[i * b + t]; if (mask_idx >= 0) { + Box gt = GetGtBox(gt_box_data, i, b, t); + int gi = static_cast(gt.x * w); + int gj = static_cast(gt.y * h); + int box_idx = GetEntryIndex(i, mask_idx, gj * w + gi, mask_num, an_stride, stride, 0); CalcBoxLocationLossGrad( diff --git a/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py b/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py index 904bee00c1..27fb92c589 100644 --- a/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py +++ b/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py @@ -86,6 +86,10 @@ def YOLOv3Loss(x, gtbox, gtlabel, attrs): pred_box[:, :, :, :, 0] = (grid_x + sigmoid(pred_box[:, :, :, :, 0])) / w pred_box[:, :, :, :, 1] = (grid_y + sigmoid(pred_box[:, :, :, :, 1])) / h + x[:, :, :, :, 5:] = np.where(x[:, :, :, :, 5:] < -0.5, x[:, :, :, :, 5:], + np.ones_like(x[:, :, :, :, 5:]) * 1.0 / + class_num) + mask_anchors = [] for m in anchor_mask: mask_anchors.append((anchors[2 * m], anchors[2 * m + 1])) @@ -207,7 +211,7 @@ class TestYolov3LossOp(OpTest): self.ignore_thresh = 0.7 self.downsample = 32 self.x_shape = (3, len(self.anchor_mask) * (5 + self.class_num), 5, 5) - self.gtbox_shape = (3, 10, 4) + self.gtbox_shape = (3, 5, 4) if __name__ == "__main__": From 8218e30176c6bdaccd11cd0141c6f47878233b54 Mon Sep 17 00:00:00 2001 From: dengkaipeng Date: Fri, 4 Jan 2019 11:40:08 +0800 Subject: [PATCH 33/53] add gtscore. test=develop --- paddle/fluid/API.spec | 2 +- paddle/fluid/operators/yolov3_loss_op.cc | 20 +++++++++++++++-- paddle/fluid/operators/yolov3_loss_op.h | 22 ++++++++++++------- python/paddle/fluid/layers/detection.py | 17 ++++++++++---- .../tests/unittests/test_yolov3_loss_op.py | 19 +++++++++------- 5 files changed, 57 insertions(+), 23 deletions(-) diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index 6c6ac9c7ea..bf0916a076 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -324,7 +324,7 @@ paddle.fluid.layers.generate_mask_labels ArgSpec(args=['im_info', 'gt_classes', paddle.fluid.layers.iou_similarity ArgSpec(args=['x', 'y', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.box_coder ArgSpec(args=['prior_box', 'prior_box_var', 'target_box', 'code_type', 'box_normalized', 'name'], varargs=None, keywords=None, defaults=('encode_center_size', True, None)) paddle.fluid.layers.polygon_box_transform ArgSpec(args=['input', 'name'], varargs=None, keywords=None, defaults=(None,)) -paddle.fluid.layers.yolov3_loss ArgSpec(args=['x', 'gtbox', 'gtlabel', 'anchors', 'anchor_mask', 'class_num', 'ignore_thresh', 'downsample', 'name'], varargs=None, keywords=None, defaults=(None,)) +paddle.fluid.layers.yolov3_loss ArgSpec(args=['x', 'gtbox', 'gtlabel', 'gtscore', 'anchors', 'anchor_mask', 'class_num', 'ignore_thresh', 'downsample', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.multiclass_nms ArgSpec(args=['bboxes', 'scores', 'score_threshold', 'nms_top_k', 'keep_top_k', 'nms_threshold', 'normalized', 'nms_eta', 'background_label', 'name'], varargs=None, keywords=None, defaults=(0.3, True, 1.0, 0, None)) paddle.fluid.layers.accuracy ArgSpec(args=['input', 'label', 'k', 'correct', 'total'], varargs=None, keywords=None, defaults=(1, None, None)) paddle.fluid.layers.auc ArgSpec(args=['input', 'label', 'curve', 'num_thresholds', 'topk', 'slide_steps'], varargs=None, keywords=None, defaults=('ROC', 4095, 1, 1)) diff --git a/paddle/fluid/operators/yolov3_loss_op.cc b/paddle/fluid/operators/yolov3_loss_op.cc index 5b777f0448..c146035f9d 100644 --- a/paddle/fluid/operators/yolov3_loss_op.cc +++ b/paddle/fluid/operators/yolov3_loss_op.cc @@ -27,6 +27,8 @@ class Yolov3LossOp : public framework::OperatorWithKernel { "Input(GTBox) of Yolov3LossOp should not be null."); PADDLE_ENFORCE(ctx->HasInput("GTLabel"), "Input(GTLabel) of Yolov3LossOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("GTScore"), + "Input(GTScore) of Yolov3LossOp should not be null."); PADDLE_ENFORCE(ctx->HasOutput("Loss"), "Output(Loss) of Yolov3LossOp should not be null."); PADDLE_ENFORCE( @@ -38,6 +40,7 @@ class Yolov3LossOp : public framework::OperatorWithKernel { auto dim_x = ctx->GetInputDim("X"); auto dim_gtbox = ctx->GetInputDim("GTBox"); auto dim_gtlabel = ctx->GetInputDim("GTLabel"); + auto dim_gtscore = ctx->GetInputDim("GTScore"); auto anchors = ctx->Attrs().Get>("anchors"); int anchor_num = anchors.size() / 2; auto anchor_mask = ctx->Attrs().Get>("anchor_mask"); @@ -54,11 +57,17 @@ class Yolov3LossOp : public framework::OperatorWithKernel { "Input(GTBox) should be a 3-D tensor"); PADDLE_ENFORCE_EQ(dim_gtbox[2], 4, "Input(GTBox) dim[2] should be 5"); PADDLE_ENFORCE_EQ(dim_gtlabel.size(), 2, - "Input(GTBox) should be a 2-D tensor"); + "Input(GTLabel) should be a 2-D tensor"); PADDLE_ENFORCE_EQ(dim_gtlabel[0], dim_gtbox[0], "Input(GTBox) and Input(GTLabel) dim[0] should be same"); PADDLE_ENFORCE_EQ(dim_gtlabel[1], dim_gtbox[1], "Input(GTBox) and Input(GTLabel) dim[1] should be same"); + PADDLE_ENFORCE_EQ(dim_gtscore.size(), 2, + "Input(GTScore) should be a 2-D tensor"); + PADDLE_ENFORCE_EQ(dim_gtscore[0], dim_gtbox[0], + "Input(GTBox) and Input(GTScore) dim[0] should be same"); + PADDLE_ENFORCE_EQ(dim_gtscore[1], dim_gtbox[1], + "Input(GTBox) and Input(GTScore) dim[1] should be same"); PADDLE_ENFORCE_GT(anchors.size(), 0, "Attr(anchors) length should be greater then 0."); PADDLE_ENFORCE_EQ(anchors.size() % 2, 0, @@ -109,8 +118,13 @@ class Yolov3LossOpMaker : public framework::OpProtoAndCheckerMaker { AddInput("GTLabel", "The input tensor of ground truth label, " "This is a 2-D tensor with shape of [N, max_box_num], " - "and each element shoudl be an integer to indicate the " + "and each element should be an integer to indicate the " "box class id."); + AddInput("GTScore", + "The score of GTLabel, This is a 2-D tensor in same shape " + "GTLabel, and score values should in range (0, 1). This " + "input is for GTLabel score can be not 1.0 in image mixup " + "augmentation."); AddOutput("Loss", "The output yolov3 loss tensor, " "This is a 1-D tensor with shape of [N]"); @@ -228,6 +242,7 @@ class Yolov3LossGradMaker : public framework::SingleGradOpDescMaker { op->SetInput("X", Input("X")); op->SetInput("GTBox", Input("GTBox")); op->SetInput("GTLabel", Input("GTLabel")); + op->SetInput("GTScore", Input("GTScore")); op->SetInput(framework::GradVarName("Loss"), OutputGrad("Loss")); op->SetInput("ObjectnessMask", Output("ObjectnessMask")); op->SetInput("GTMatchMask", Output("GTMatchMask")); @@ -237,6 +252,7 @@ class Yolov3LossGradMaker : public framework::SingleGradOpDescMaker { op->SetOutput(framework::GradVarName("X"), InputGrad("X")); op->SetOutput(framework::GradVarName("GTBox"), {}); op->SetOutput(framework::GradVarName("GTLabel"), {}); + op->SetOutput(framework::GradVarName("GTScore"), {}); return std::unique_ptr(op); } }; diff --git a/paddle/fluid/operators/yolov3_loss_op.h b/paddle/fluid/operators/yolov3_loss_op.h index 34119b1a02..c4095b8ca5 100644 --- a/paddle/fluid/operators/yolov3_loss_op.h +++ b/paddle/fluid/operators/yolov3_loss_op.h @@ -156,25 +156,25 @@ static void CalcBoxLocationLossGrad(T* input_grad, const T loss, const T* input, template static inline void CalcLabelLoss(T* loss, const T* input, const int index, - const int label, const int class_num, - const int stride) { + const int label, const T score, + const int class_num, const int stride) { for (int i = 0; i < class_num; i++) { T pred = input[index + i * stride] < -0.5 ? input[index + i * stride] : 1.0 / class_num; - loss[0] += SCE(pred, (i == label) ? 1.0 : 0.0); + loss[0] += SCE(pred, (i == label) ? score : 0.0); } } template static inline void CalcLabelLossGrad(T* input_grad, const T loss, const T* input, const int index, - const int label, const int class_num, - const int stride) { + const int label, const T score, + const int class_num, const int stride) { for (int i = 0; i < class_num; i++) { T pred = input[index + i * stride] < -0.5 ? input[index + i * stride] : 1.0 / class_num; input_grad[index + i * stride] = - SCEGrad(pred, (i == label) ? 1.0 : 0.0) * loss; + SCEGrad(pred, (i == label) ? score : 0.0) * loss; } } @@ -246,6 +246,7 @@ class Yolov3LossKernel : public framework::OpKernel { auto* input = ctx.Input("X"); auto* gt_box = ctx.Input("GTBox"); auto* gt_label = ctx.Input("GTLabel"); + auto* gt_score = ctx.Input("GTScore"); auto* loss = ctx.Output("Loss"); auto* objness_mask = ctx.Output("ObjectnessMask"); auto* gt_match_mask = ctx.Output("GTMatchMask"); @@ -269,6 +270,7 @@ class Yolov3LossKernel : public framework::OpKernel { const T* input_data = input->data(); const T* gt_box_data = gt_box->data(); const int* gt_label_data = gt_label->data(); + const T* gt_score_data = gt_score->data(); T* loss_data = loss->mutable_data({n}, ctx.GetPlace()); memset(loss_data, 0, loss->numel() * sizeof(T)); int* obj_mask_data = @@ -358,9 +360,10 @@ class Yolov3LossKernel : public framework::OpKernel { obj_mask_data[obj_idx] = 1; int label = gt_label_data[i * b + t]; + T score = gt_score_data[i * b + t]; int label_idx = GetEntryIndex(i, mask_idx, gj * w + gi, mask_num, an_stride, stride, 5); - CalcLabelLoss(loss_data + i, input_data, label_idx, label, + CalcLabelLoss(loss_data + i, input_data, label_idx, label, score, class_num, stride); } } @@ -378,6 +381,7 @@ class Yolov3LossGradKernel : public framework::OpKernel { auto* input = ctx.Input("X"); auto* gt_box = ctx.Input("GTBox"); auto* gt_label = ctx.Input("GTLabel"); + auto* gt_score = ctx.Input("GTScore"); auto* input_grad = ctx.Output(framework::GradVarName("X")); auto* loss_grad = ctx.Input(framework::GradVarName("Loss")); auto* objness_mask = ctx.Input("ObjectnessMask"); @@ -401,6 +405,7 @@ class Yolov3LossGradKernel : public framework::OpKernel { const T* input_data = input->data(); const T* gt_box_data = gt_box->data(); const int* gt_label_data = gt_label->data(); + const T* gt_score_data = gt_score->data(); const T* loss_grad_data = loss_grad->data(); const int* obj_mask_data = objness_mask->data(); const int* gt_match_mask_data = gt_match_mask->data(); @@ -423,10 +428,11 @@ class Yolov3LossGradKernel : public framework::OpKernel { anchor_mask[mask_idx], box_idx, gi, gj, h, input_size, stride); int label = gt_label_data[i * b + t]; + T score = gt_score_data[i * b + t]; int label_idx = GetEntryIndex(i, mask_idx, gj * w + gi, mask_num, an_stride, stride, 5); CalcLabelLossGrad(input_grad_data, loss_grad_data[i], input_data, - label_idx, label, class_num, stride); + label_idx, label, score, class_num, stride); } } } diff --git a/python/paddle/fluid/layers/detection.py b/python/paddle/fluid/layers/detection.py index 90d112aa01..10573cc4c6 100644 --- a/python/paddle/fluid/layers/detection.py +++ b/python/paddle/fluid/layers/detection.py @@ -412,6 +412,7 @@ def polygon_box_transform(input, name=None): def yolov3_loss(x, gtbox, gtlabel, + gtscore, anchors, anchor_mask, class_num, @@ -428,8 +429,10 @@ def yolov3_loss(x, and x, y, w, h should be relative value of input image. N is the batch number and B is the max box number in an image. - gtlabel (Variable): class id of ground truth boxes, shoud be ins shape + gtlabel (Variable): class id of ground truth boxes, shoud be in shape of [N, B]. + gtscore (Variable): score of gtlabel, should be in same shape with gtlabel + and score value in range (0, 1). anchors (list|tuple): ${anchors_comment} anchor_mask (list|tuple): ${anchor_mask_comment} class_num (int): ${class_num_comment} @@ -444,6 +447,7 @@ def yolov3_loss(x, TypeError: Input x of yolov3_loss must be Variable TypeError: Input gtbox of yolov3_loss must be Variable" TypeError: Input gtlabel of yolov3_loss must be Variable" + TypeError: Input gtscore of yolov3_loss must be Variable" TypeError: Attr anchors of yolov3_loss must be list or tuple TypeError: Attr class_num of yolov3_loss must be an integer TypeError: Attr ignore_thresh of yolov3_loss must be a float number @@ -467,6 +471,8 @@ def yolov3_loss(x, raise TypeError("Input gtbox of yolov3_loss must be Variable") if not isinstance(gtlabel, Variable): raise TypeError("Input gtlabel of yolov3_loss must be Variable") + if not isinstance(gtscore, Variable): + raise TypeError("Input gtscore of yolov3_loss must be Variable") if not isinstance(anchors, list) and not isinstance(anchors, tuple): raise TypeError("Attr anchors of yolov3_loss must be list or tuple") if not isinstance(anchor_mask, list) and not isinstance(anchor_mask, tuple): @@ -496,9 +502,12 @@ def yolov3_loss(x, helper.append_op( type='yolov3_loss', - inputs={"X": x, - "GTBox": gtbox, - "GTLabel": gtlabel}, + inputs={ + "X": x, + "GTBox": gtbox, + "GTLabel": gtlabel, + "GTScore": gtscore + }, outputs={ 'Loss': loss, 'ObjectnessMask': objectness_mask, diff --git a/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py b/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py index 27fb92c589..c65570d7c1 100644 --- a/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py +++ b/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py @@ -66,7 +66,7 @@ def batch_xywh_box_iou(box1, box2): return inter_area / union -def YOLOv3Loss(x, gtbox, gtlabel, attrs): +def YOLOv3Loss(x, gtbox, gtlabel, gtscore, attrs): n, c, h, w = x.shape b = gtbox.shape[1] anchors = attrs['anchors'] @@ -148,7 +148,7 @@ def YOLOv3Loss(x, gtbox, gtlabel, attrs): for label_idx in range(class_num): loss[i] += sce(x[i, an_idx, gj, gi, 5 + label_idx], - int(label_idx == gtlabel[i, j])) + int(label_idx == gtlabel[i, j]) * gtscore[i, j]) for j in range(mask_num * h * w): if objness[i, j] >= 0: @@ -165,6 +165,7 @@ class TestYolov3LossOp(OpTest): x = logit(np.random.uniform(0, 1, self.x_shape).astype('float32')) gtbox = np.random.random(size=self.gtbox_shape).astype('float32') gtlabel = np.random.randint(0, self.class_num, self.gtbox_shape[:2]) + gtscore = np.random.random(self.gtbox_shape[:2]).astype('float32') gtmask = np.random.randint(0, 2, self.gtbox_shape[:2]) gtbox = gtbox * gtmask[:, :, np.newaxis] gtlabel = gtlabel * gtmask @@ -180,9 +181,11 @@ class TestYolov3LossOp(OpTest): self.inputs = { 'X': x, 'GTBox': gtbox.astype('float32'), - 'GTLabel': gtlabel.astype('int32') + 'GTLabel': gtlabel.astype('int32'), + 'GTScore': gtscore.astype('float32') } - loss, objness, gt_matches = YOLOv3Loss(x, gtbox, gtlabel, self.attrs) + loss, objness, gt_matches = YOLOv3Loss(x, gtbox, gtlabel, gtscore, + self.attrs) self.outputs = { 'Loss': loss, 'ObjectnessMask': objness, @@ -198,8 +201,8 @@ class TestYolov3LossOp(OpTest): self.check_grad_with_place( place, ['X'], 'Loss', - no_grad_set=set(["GTBox", "GTLabel"]), - max_relative_error=0.15) + no_grad_set=set(["GTBox", "GTLabel", "GTScore"]), + max_relative_error=0.2) def initTestCase(self): self.anchors = [ @@ -207,11 +210,11 @@ class TestYolov3LossOp(OpTest): 373, 326 ] self.anchor_mask = [0, 1, 2] - self.class_num = 5 + self.class_num = 10 self.ignore_thresh = 0.7 self.downsample = 32 self.x_shape = (3, len(self.anchor_mask) * (5 + self.class_num), 5, 5) - self.gtbox_shape = (3, 5, 4) + self.gtbox_shape = (3, 10, 4) if __name__ == "__main__": From 2b89f590559bc76d6f821789edee42cf56a68582 Mon Sep 17 00:00:00 2001 From: tink2123 Date: Thu, 10 Jan 2019 06:57:28 +0000 Subject: [PATCH 34/53] add attr use_label_smooth test=develop --- paddle/fluid/API.spec | 2 +- paddle/fluid/operators/yolov3_loss_op.cc | 3 ++ paddle/fluid/operators/yolov3_loss_op.h | 46 +++++++++++++------ python/paddle/fluid/layers/detection.py | 6 +++ .../tests/unittests/test_yolov3_loss_op.py | 8 ++++ 5 files changed, 51 insertions(+), 14 deletions(-) diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index bf0916a076..d773c2518c 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -324,7 +324,7 @@ paddle.fluid.layers.generate_mask_labels ArgSpec(args=['im_info', 'gt_classes', paddle.fluid.layers.iou_similarity ArgSpec(args=['x', 'y', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.box_coder ArgSpec(args=['prior_box', 'prior_box_var', 'target_box', 'code_type', 'box_normalized', 'name'], varargs=None, keywords=None, defaults=('encode_center_size', True, None)) paddle.fluid.layers.polygon_box_transform ArgSpec(args=['input', 'name'], varargs=None, keywords=None, defaults=(None,)) -paddle.fluid.layers.yolov3_loss ArgSpec(args=['x', 'gtbox', 'gtlabel', 'gtscore', 'anchors', 'anchor_mask', 'class_num', 'ignore_thresh', 'downsample', 'name'], varargs=None, keywords=None, defaults=(None,)) +paddle.fluid.layers.yolov3_loss ArgSpec(args=['x', 'gtbox', 'gtlabel', 'gtscore', 'anchors', 'anchor_mask', 'class_num', 'ignore_thresh', 'downsample', 'label_smooth', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.multiclass_nms ArgSpec(args=['bboxes', 'scores', 'score_threshold', 'nms_top_k', 'keep_top_k', 'nms_threshold', 'normalized', 'nms_eta', 'background_label', 'name'], varargs=None, keywords=None, defaults=(0.3, True, 1.0, 0, None)) paddle.fluid.layers.accuracy ArgSpec(args=['input', 'label', 'k', 'correct', 'total'], varargs=None, keywords=None, defaults=(1, None, None)) paddle.fluid.layers.auc ArgSpec(args=['input', 'label', 'curve', 'num_thresholds', 'topk', 'slide_steps'], varargs=None, keywords=None, defaults=('ROC', 4095, 1, 1)) diff --git a/paddle/fluid/operators/yolov3_loss_op.cc b/paddle/fluid/operators/yolov3_loss_op.cc index c146035f9d..0c5426728b 100644 --- a/paddle/fluid/operators/yolov3_loss_op.cc +++ b/paddle/fluid/operators/yolov3_loss_op.cc @@ -46,6 +46,7 @@ class Yolov3LossOp : public framework::OperatorWithKernel { auto anchor_mask = ctx->Attrs().Get>("anchor_mask"); int mask_num = anchor_mask.size(); auto class_num = ctx->Attrs().Get("class_num"); + PADDLE_ENFORCE_EQ(dim_x.size(), 4, "Input(X) should be a 4-D tensor."); PADDLE_ENFORCE_EQ(dim_x[2], dim_x[3], "Input(X) dim[3] and dim[4] should be euqal."); @@ -156,6 +157,8 @@ class Yolov3LossOpMaker : public framework::OpProtoAndCheckerMaker { AddAttr("ignore_thresh", "The ignore threshold to ignore confidence loss.") .SetDefault(0.7); + AddAttr("use_label_smooth", "bool,default True", "use label smooth") + .SetDefault(true); AddComment(R"DOC( This operator generate yolov3 loss by given predict result and ground truth boxes. diff --git a/paddle/fluid/operators/yolov3_loss_op.h b/paddle/fluid/operators/yolov3_loss_op.h index c4095b8ca5..f601651f06 100644 --- a/paddle/fluid/operators/yolov3_loss_op.h +++ b/paddle/fluid/operators/yolov3_loss_op.h @@ -157,11 +157,19 @@ static void CalcBoxLocationLossGrad(T* input_grad, const T loss, const T* input, template static inline void CalcLabelLoss(T* loss, const T* input, const int index, const int label, const T score, - const int class_num, const int stride) { - for (int i = 0; i < class_num; i++) { - T pred = input[index + i * stride] < -0.5 ? input[index + i * stride] - : 1.0 / class_num; - loss[0] += SCE(pred, (i == label) ? score : 0.0); + const int class_num, const int stride, + const bool use_label_smooth) { + if (use_label_smooth) { + for (int i = 0; i < class_num; i++) { + T pred = input[index + i * stride] < -0.5 ? input[index + i * stride] + : 1.0 / class_num; + loss[0] += SCE(pred, (i == label) ? score : 0.0); + } + } else { + for (int i = 0; i < class_num; i++) { + T pred = input[index + i * stride]; + loss[0] += SCE(pred, (i == label) ? score : 0.0); + } } } @@ -169,12 +177,21 @@ template static inline void CalcLabelLossGrad(T* input_grad, const T loss, const T* input, const int index, const int label, const T score, - const int class_num, const int stride) { - for (int i = 0; i < class_num; i++) { - T pred = input[index + i * stride] < -0.5 ? input[index + i * stride] - : 1.0 / class_num; - input_grad[index + i * stride] = - SCEGrad(pred, (i == label) ? score : 0.0) * loss; + const int class_num, const int stride, + const bool use_label_smooth) { + if (use_label_smooth) { + for (int i = 0; i < class_num; i++) { + T pred = input[index + i * stride] < -0.5 ? input[index + i * stride] + : 1.0 / class_num; + input_grad[index + i * stride] = + SCEGrad(pred, (i == label) ? score : 0.0) * loss; + } + } else { + for (int i = 0; i < class_num; i++) { + T pred = input[index + i * stride]; + input_grad[index + i * stride] = + SCEGrad(pred, (i == label) ? score : 0.0) * loss; + } } } @@ -255,6 +272,7 @@ class Yolov3LossKernel : public framework::OpKernel { int class_num = ctx.Attr("class_num"); float ignore_thresh = ctx.Attr("ignore_thresh"); int downsample = ctx.Attr("downsample"); + bool use_label_smooth = ctx.Attr("use_label_smooth"); const int n = input->dims()[0]; const int h = input->dims()[2]; @@ -364,7 +382,7 @@ class Yolov3LossKernel : public framework::OpKernel { int label_idx = GetEntryIndex(i, mask_idx, gj * w + gi, mask_num, an_stride, stride, 5); CalcLabelLoss(loss_data + i, input_data, label_idx, label, score, - class_num, stride); + class_num, stride, use_label_smooth); } } } @@ -390,6 +408,7 @@ class Yolov3LossGradKernel : public framework::OpKernel { auto anchor_mask = ctx.Attr>("anchor_mask"); int class_num = ctx.Attr("class_num"); int downsample = ctx.Attr("downsample"); + bool use_label_smooth = ctx.Attr("use_label_smooth"); const int n = input_grad->dims()[0]; const int c = input_grad->dims()[1]; @@ -432,7 +451,8 @@ class Yolov3LossGradKernel : public framework::OpKernel { int label_idx = GetEntryIndex(i, mask_idx, gj * w + gi, mask_num, an_stride, stride, 5); CalcLabelLossGrad(input_grad_data, loss_grad_data[i], input_data, - label_idx, label, score, class_num, stride); + label_idx, label, score, class_num, stride, + use_label_smooth); } } } diff --git a/python/paddle/fluid/layers/detection.py b/python/paddle/fluid/layers/detection.py index 10573cc4c6..e984576ffe 100644 --- a/python/paddle/fluid/layers/detection.py +++ b/python/paddle/fluid/layers/detection.py @@ -418,6 +418,7 @@ def yolov3_loss(x, class_num, ignore_thresh, downsample, + use_label_smooth=True, name=None): """ ${comment} @@ -438,6 +439,7 @@ def yolov3_loss(x, class_num (int): ${class_num_comment} ignore_thresh (float): ${ignore_thresh_comment} downsample (int): ${downsample_comment} + use_label_smooth(bool): ${use_label_smooth_comment} name (string): the name of yolov3 loss Returns: @@ -451,6 +453,7 @@ def yolov3_loss(x, TypeError: Attr anchors of yolov3_loss must be list or tuple TypeError: Attr class_num of yolov3_loss must be an integer TypeError: Attr ignore_thresh of yolov3_loss must be a float number + TypeError: Attr use_label_smooth of yolov3_loss must be a bool value Examples: .. code-block:: python @@ -479,6 +482,8 @@ def yolov3_loss(x, raise TypeError("Attr anchor_mask of yolov3_loss must be list or tuple") if not isinstance(class_num, int): raise TypeError("Attr class_num of yolov3_loss must be an integer") + if not isinstance(class_num, int): + raise TypeError("Attr ues_label_smooth of yolov3 must be a bool value") if not isinstance(ignore_thresh, float): raise TypeError( "Attr ignore_thresh of yolov3_loss must be a float number") @@ -498,6 +503,7 @@ def yolov3_loss(x, "class_num": class_num, "ignore_thresh": ignore_thresh, "downsample": downsample, + "use_label_smooth": use_label_smooth } helper.append_op( diff --git a/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py b/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py index c65570d7c1..1746a1da1d 100644 --- a/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py +++ b/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py @@ -76,6 +76,7 @@ def YOLOv3Loss(x, gtbox, gtlabel, gtscore, attrs): class_num = attrs["class_num"] ignore_thresh = attrs['ignore_thresh'] downsample = attrs['downsample'] + #use_label_smooth = attrs['use_label_smooth'] input_size = downsample * h x = x.reshape((n, mask_num, 5 + class_num, h, w)).transpose((0, 1, 3, 4, 2)) loss = np.zeros((n)).astype('float32') @@ -176,6 +177,7 @@ class TestYolov3LossOp(OpTest): "class_num": self.class_num, "ignore_thresh": self.ignore_thresh, "downsample": self.downsample, + "use_label_smooth": self.use_label_smooth, } self.inputs = { @@ -215,6 +217,12 @@ class TestYolov3LossOp(OpTest): self.downsample = 32 self.x_shape = (3, len(self.anchor_mask) * (5 + self.class_num), 5, 5) self.gtbox_shape = (3, 10, 4) + self.use_label_smooth = True + + +class TestYolov3LossWithLabelSmooth(TestYolov3LossOp): + def set_label_smooth(self): + self.use_label_smooth = True if __name__ == "__main__": From 20200e126d0bfcc9e98e278764768f38ff1831e8 Mon Sep 17 00:00:00 2001 From: tink2123 Date: Thu, 10 Jan 2019 07:15:35 +0000 Subject: [PATCH 35/53] fix some typo test=develop --- python/paddle/fluid/layers/detection.py | 2 +- python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/python/paddle/fluid/layers/detection.py b/python/paddle/fluid/layers/detection.py index e984576ffe..febfc8e127 100644 --- a/python/paddle/fluid/layers/detection.py +++ b/python/paddle/fluid/layers/detection.py @@ -482,7 +482,7 @@ def yolov3_loss(x, raise TypeError("Attr anchor_mask of yolov3_loss must be list or tuple") if not isinstance(class_num, int): raise TypeError("Attr class_num of yolov3_loss must be an integer") - if not isinstance(class_num, int): + if not isinstance(use_label_smooth, int): raise TypeError("Attr ues_label_smooth of yolov3 must be a bool value") if not isinstance(ignore_thresh, float): raise TypeError( diff --git a/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py b/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py index 1746a1da1d..79c953bbd1 100644 --- a/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py +++ b/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py @@ -76,7 +76,7 @@ def YOLOv3Loss(x, gtbox, gtlabel, gtscore, attrs): class_num = attrs["class_num"] ignore_thresh = attrs['ignore_thresh'] downsample = attrs['downsample'] - #use_label_smooth = attrs['use_label_smooth'] + use_label_smooth = attrs['use_label_smooth'] input_size = downsample * h x = x.reshape((n, mask_num, 5 + class_num, h, w)).transpose((0, 1, 3, 4, 2)) loss = np.zeros((n)).astype('float32') From c945ffa7f8949277e1053c430918147d9e908303 Mon Sep 17 00:00:00 2001 From: dengkaipeng Date: Mon, 14 Jan 2019 21:16:06 +0800 Subject: [PATCH 36/53] fix label_smooth and mixup score --- paddle/fluid/operators/yolov3_loss_op.h | 98 +++++++++---------- .../tests/unittests/test_yolov3_loss_op.py | 17 ++-- 2 files changed, 55 insertions(+), 60 deletions(-) diff --git a/paddle/fluid/operators/yolov3_loss_op.h b/paddle/fluid/operators/yolov3_loss_op.h index f601651f06..5cb48b7cdf 100644 --- a/paddle/fluid/operators/yolov3_loss_op.h +++ b/paddle/fluid/operators/yolov3_loss_op.h @@ -156,47 +156,29 @@ static void CalcBoxLocationLossGrad(T* input_grad, const T loss, const T* input, template static inline void CalcLabelLoss(T* loss, const T* input, const int index, - const int label, const T score, - const int class_num, const int stride, - const bool use_label_smooth) { - if (use_label_smooth) { - for (int i = 0; i < class_num; i++) { - T pred = input[index + i * stride] < -0.5 ? input[index + i * stride] - : 1.0 / class_num; - loss[0] += SCE(pred, (i == label) ? score : 0.0); - } - } else { - for (int i = 0; i < class_num; i++) { - T pred = input[index + i * stride]; - loss[0] += SCE(pred, (i == label) ? score : 0.0); - } + const int label, const int class_num, + const int stride, const T pos, const T neg) { + for (int i = 0; i < class_num; i++) { + T pred = input[index + i * stride]; + loss[0] += SCE(pred, (i == label) ? pos : neg); } } template static inline void CalcLabelLossGrad(T* input_grad, const T loss, const T* input, const int index, - const int label, const T score, - const int class_num, const int stride, - const bool use_label_smooth) { - if (use_label_smooth) { - for (int i = 0; i < class_num; i++) { - T pred = input[index + i * stride] < -0.5 ? input[index + i * stride] - : 1.0 / class_num; - input_grad[index + i * stride] = - SCEGrad(pred, (i == label) ? score : 0.0) * loss; - } - } else { - for (int i = 0; i < class_num; i++) { - T pred = input[index + i * stride]; - input_grad[index + i * stride] = - SCEGrad(pred, (i == label) ? score : 0.0) * loss; - } + const int label, const int class_num, + const int stride, const T pos, + const T neg) { + for (int i = 0; i < class_num; i++) { + T pred = input[index + i * stride]; + input_grad[index + i * stride] = + SCEGrad(pred, (i == label) ? pos : neg) * loss; } } template -static inline void CalcObjnessLoss(T* loss, const T* input, const int* objness, +static inline void CalcObjnessLoss(T* loss, const T* input, const T* objness, const int n, const int an_num, const int h, const int w, const int stride, const int an_stride) { @@ -204,9 +186,9 @@ static inline void CalcObjnessLoss(T* loss, const T* input, const int* objness, for (int j = 0; j < an_num; j++) { for (int k = 0; k < h; k++) { for (int l = 0; l < w; l++) { - int obj = objness[k * w + l]; - if (obj >= 0) { - loss[i] += SCE(input[k * w + l], static_cast(obj)); + T obj = objness[k * w + l]; + if (obj > -0.5) { + loss[i] += SCE(input[k * w + l], obj); } } } @@ -218,7 +200,7 @@ static inline void CalcObjnessLoss(T* loss, const T* input, const int* objness, template static inline void CalcObjnessLossGrad(T* input_grad, const T* loss, - const T* input, const int* objness, + const T* input, const T* objness, const int n, const int an_num, const int h, const int w, const int stride, const int an_stride) { @@ -226,10 +208,9 @@ static inline void CalcObjnessLossGrad(T* input_grad, const T* loss, for (int j = 0; j < an_num; j++) { for (int k = 0; k < h; k++) { for (int l = 0; l < w; l++) { - int obj = objness[k * w + l]; - if (obj >= 0) { - input_grad[k * w + l] = - SCEGrad(input[k * w + l], static_cast(obj)) * loss[i]; + T obj = objness[k * w + l]; + if (obj > -0.5) { + input_grad[k * w + l] = SCEGrad(input[k * w + l], obj) * loss[i]; } } } @@ -285,15 +266,22 @@ class Yolov3LossKernel : public framework::OpKernel { const int stride = h * w; const int an_stride = (class_num + 5) * stride; + T label_pos = 1.0; + T label_neg = 0.0; + if (use_label_smooth) { + label_pos = 1.0 - 1.0 / static_cast(class_num); + label_neg = 1.0 / static_cast(class_num); + } + const T* input_data = input->data(); const T* gt_box_data = gt_box->data(); const int* gt_label_data = gt_label->data(); const T* gt_score_data = gt_score->data(); T* loss_data = loss->mutable_data({n}, ctx.GetPlace()); memset(loss_data, 0, loss->numel() * sizeof(T)); - int* obj_mask_data = - objness_mask->mutable_data({n, mask_num, h, w}, ctx.GetPlace()); - memset(obj_mask_data, 0, objness_mask->numel() * sizeof(int)); + T* obj_mask_data = + objness_mask->mutable_data({n, mask_num, h, w}, ctx.GetPlace()); + memset(obj_mask_data, 0, objness_mask->numel() * sizeof(T)); int* gt_match_mask_data = gt_match_mask->mutable_data({n, b}, ctx.GetPlace()); @@ -327,7 +315,7 @@ class Yolov3LossKernel : public framework::OpKernel { if (best_iou > ignore_thresh) { int obj_idx = (i * mask_num + j) * stride + k * w + l; - obj_mask_data[obj_idx] = -1; + obj_mask_data[obj_idx] = static_cast(-1.0); } // TODO(dengkaipeng): all losses should be calculated if best IoU // is bigger then truth thresh should be calculated here, but @@ -374,15 +362,15 @@ class Yolov3LossKernel : public framework::OpKernel { CalcBoxLocationLoss(loss_data + i, input_data, gt, anchors, best_n, box_idx, gi, gj, h, input_size, stride); + T score = gt_score_data[i * b + t]; int obj_idx = (i * mask_num + mask_idx) * stride + gj * w + gi; - obj_mask_data[obj_idx] = 1; + obj_mask_data[obj_idx] = score; int label = gt_label_data[i * b + t]; - T score = gt_score_data[i * b + t]; int label_idx = GetEntryIndex(i, mask_idx, gj * w + gi, mask_num, an_stride, stride, 5); - CalcLabelLoss(loss_data + i, input_data, label_idx, label, score, - class_num, stride, use_label_smooth); + CalcLabelLoss(loss_data + i, input_data, label_idx, label, + class_num, stride, label_pos, label_neg); } } } @@ -399,7 +387,6 @@ class Yolov3LossGradKernel : public framework::OpKernel { auto* input = ctx.Input("X"); auto* gt_box = ctx.Input("GTBox"); auto* gt_label = ctx.Input("GTLabel"); - auto* gt_score = ctx.Input("GTScore"); auto* input_grad = ctx.Output(framework::GradVarName("X")); auto* loss_grad = ctx.Input(framework::GradVarName("Loss")); auto* objness_mask = ctx.Input("ObjectnessMask"); @@ -421,12 +408,18 @@ class Yolov3LossGradKernel : public framework::OpKernel { const int stride = h * w; const int an_stride = (class_num + 5) * stride; + T label_pos = 1.0; + T label_neg = 0.0; + if (use_label_smooth) { + label_pos = 1.0 - 1.0 / static_cast(class_num); + label_neg = 1.0 / static_cast(class_num); + } + const T* input_data = input->data(); const T* gt_box_data = gt_box->data(); const int* gt_label_data = gt_label->data(); - const T* gt_score_data = gt_score->data(); const T* loss_grad_data = loss_grad->data(); - const int* obj_mask_data = objness_mask->data(); + const T* obj_mask_data = objness_mask->data(); const int* gt_match_mask_data = gt_match_mask->data(); T* input_grad_data = input_grad->mutable_data({n, c, h, w}, ctx.GetPlace()); @@ -447,12 +440,11 @@ class Yolov3LossGradKernel : public framework::OpKernel { anchor_mask[mask_idx], box_idx, gi, gj, h, input_size, stride); int label = gt_label_data[i * b + t]; - T score = gt_score_data[i * b + t]; int label_idx = GetEntryIndex(i, mask_idx, gj * w + gi, mask_num, an_stride, stride, 5); CalcLabelLossGrad(input_grad_data, loss_grad_data[i], input_data, - label_idx, label, score, class_num, stride, - use_label_smooth); + label_idx, label, class_num, stride, label_pos, + label_neg); } } } diff --git a/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py b/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py index 79c953bbd1..426a64f7a2 100644 --- a/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py +++ b/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py @@ -81,6 +81,9 @@ def YOLOv3Loss(x, gtbox, gtlabel, gtscore, attrs): x = x.reshape((n, mask_num, 5 + class_num, h, w)).transpose((0, 1, 3, 4, 2)) loss = np.zeros((n)).astype('float32') + label_pos = 1.0 - 1.0 / class_num if use_label_smooth else 1.0 + label_neg = 1.0 / class_num if use_label_smooth else 0.0 + pred_box = x[:, :, :, :, :4].copy() grid_x = np.tile(np.arange(w).reshape((1, w)), (h, 1)) grid_y = np.tile(np.arange(h).reshape((h, 1)), (1, w)) @@ -103,7 +106,7 @@ def YOLOv3Loss(x, gtbox, gtlabel, gtscore, attrs): pred_box = pred_box.reshape((n, -1, 4)) pred_obj = x[:, :, :, :, 4].reshape((n, -1)) - objness = np.zeros(pred_box.shape[:2]) + objness = np.zeros(pred_box.shape[:2]).astype('float32') ious = batch_xywh_box_iou(pred_box, gtbox) ious_max = np.max(ious, axis=-1) objness = np.where(ious_max > ignore_thresh, -np.ones_like(objness), @@ -145,17 +148,17 @@ def YOLOv3Loss(x, gtbox, gtlabel, gtscore, attrs): loss[i] += l1loss(x[i, an_idx, gj, gi, 2], tw) * scale loss[i] += l1loss(x[i, an_idx, gj, gi, 3], th) * scale - objness[i, an_idx * h * w + gj * w + gi] = 1 + objness[i, an_idx * h * w + gj * w + gi] = gtscore[i, j] for label_idx in range(class_num): - loss[i] += sce(x[i, an_idx, gj, gi, 5 + label_idx], - int(label_idx == gtlabel[i, j]) * gtscore[i, j]) + loss[i] += sce(x[i, an_idx, gj, gi, 5 + label_idx], label_pos + if label_idx == gtlabel[i, j] else label_neg) for j in range(mask_num * h * w): if objness[i, j] >= 0: loss[i] += sce(pred_obj[i, j], objness[i, j]) - return (loss, objness.reshape((n, mask_num, h, w)).astype('int32'), \ + return (loss, objness.reshape((n, mask_num, h, w)).astype('float32'), \ gt_matches.astype('int32')) @@ -220,9 +223,9 @@ class TestYolov3LossOp(OpTest): self.use_label_smooth = True -class TestYolov3LossWithLabelSmooth(TestYolov3LossOp): +class TestYolov3LossWithoutLabelSmooth(TestYolov3LossOp): def set_label_smooth(self): - self.use_label_smooth = True + self.use_label_smooth = False if __name__ == "__main__": From af124dcdf6891390202fffb7c30daf70aa3c8659 Mon Sep 17 00:00:00 2001 From: dengkaipeng Date: Mon, 14 Jan 2019 21:30:25 +0800 Subject: [PATCH 37/53] fix API error --- paddle/fluid/API.spec | 2 +- paddle/fluid/operators/yolov3_loss_op.h | 55 ++++++++++++------- python/paddle/fluid/layers/detection.py | 2 +- .../tests/unittests/test_yolov3_loss_op.py | 11 ++-- 4 files changed, 43 insertions(+), 27 deletions(-) diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index d773c2518c..e71e494f9d 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -324,7 +324,7 @@ paddle.fluid.layers.generate_mask_labels ArgSpec(args=['im_info', 'gt_classes', paddle.fluid.layers.iou_similarity ArgSpec(args=['x', 'y', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.box_coder ArgSpec(args=['prior_box', 'prior_box_var', 'target_box', 'code_type', 'box_normalized', 'name'], varargs=None, keywords=None, defaults=('encode_center_size', True, None)) paddle.fluid.layers.polygon_box_transform ArgSpec(args=['input', 'name'], varargs=None, keywords=None, defaults=(None,)) -paddle.fluid.layers.yolov3_loss ArgSpec(args=['x', 'gtbox', 'gtlabel', 'gtscore', 'anchors', 'anchor_mask', 'class_num', 'ignore_thresh', 'downsample', 'label_smooth', 'name'], varargs=None, keywords=None, defaults=(None,)) +paddle.fluid.layers.yolov3_loss ArgSpec(args=['x', 'gtbox', 'gtlabel', 'gtscore', 'anchors', 'anchor_mask', 'class_num', 'ignore_thresh', 'downsample', 'use_label_smooth', 'name'], varargs=None, keywords=None, defaults=(True, None,)) paddle.fluid.layers.multiclass_nms ArgSpec(args=['bboxes', 'scores', 'score_threshold', 'nms_top_k', 'keep_top_k', 'nms_threshold', 'normalized', 'nms_eta', 'background_label', 'name'], varargs=None, keywords=None, defaults=(0.3, True, 1.0, 0, None)) paddle.fluid.layers.accuracy ArgSpec(args=['input', 'label', 'k', 'correct', 'total'], varargs=None, keywords=None, defaults=(1, None, None)) paddle.fluid.layers.auc ArgSpec(args=['input', 'label', 'curve', 'num_thresholds', 'topk', 'slide_steps'], varargs=None, keywords=None, defaults=('ROC', 4095, 1, 1)) diff --git a/paddle/fluid/operators/yolov3_loss_op.h b/paddle/fluid/operators/yolov3_loss_op.h index 5cb48b7cdf..de01a01a4f 100644 --- a/paddle/fluid/operators/yolov3_loss_op.h +++ b/paddle/fluid/operators/yolov3_loss_op.h @@ -121,13 +121,13 @@ template static void CalcBoxLocationLoss(T* loss, const T* input, Box gt, std::vector anchors, int an_idx, int box_idx, int gi, int gj, int grid_size, - int input_size, int stride) { + int input_size, int stride, T score) { T tx = gt.x * grid_size - gi; T ty = gt.y * grid_size - gj; T tw = std::log(gt.w * input_size / anchors[2 * an_idx]); T th = std::log(gt.h * input_size / anchors[2 * an_idx + 1]); - T scale = 2.0 - gt.w * gt.h; + T scale = (2.0 - gt.w * gt.h) * score; loss[0] += SCE(input[box_idx], tx) * scale; loss[0] += SCE(input[box_idx + stride], ty) * scale; loss[0] += L1Loss(input[box_idx + 2 * stride], tw) * scale; @@ -138,13 +138,14 @@ template static void CalcBoxLocationLossGrad(T* input_grad, const T loss, const T* input, Box gt, std::vector anchors, int an_idx, int box_idx, int gi, int gj, - int grid_size, int input_size, int stride) { + int grid_size, int input_size, int stride, + T score) { T tx = gt.x * grid_size - gi; T ty = gt.y * grid_size - gj; T tw = std::log(gt.w * input_size / anchors[2 * an_idx]); T th = std::log(gt.h * input_size / anchors[2 * an_idx + 1]); - T scale = 2.0 - gt.w * gt.h; + T scale = (2.0 - gt.w * gt.h) * score; input_grad[box_idx] = SCEGrad(input[box_idx], tx) * scale * loss; input_grad[box_idx + stride] = SCEGrad(input[box_idx + stride], ty) * scale * loss; @@ -157,10 +158,11 @@ static void CalcBoxLocationLossGrad(T* input_grad, const T loss, const T* input, template static inline void CalcLabelLoss(T* loss, const T* input, const int index, const int label, const int class_num, - const int stride, const T pos, const T neg) { + const int stride, const T pos, const T neg, + T score) { for (int i = 0; i < class_num; i++) { T pred = input[index + i * stride]; - loss[0] += SCE(pred, (i == label) ? pos : neg); + loss[0] += SCE(pred, (i == label) ? pos : neg) * score; } } @@ -168,12 +170,12 @@ template static inline void CalcLabelLossGrad(T* input_grad, const T loss, const T* input, const int index, const int label, const int class_num, - const int stride, const T pos, - const T neg) { + const int stride, const T pos, const T neg, + T score) { for (int i = 0; i < class_num; i++) { T pred = input[index + i * stride]; input_grad[index + i * stride] = - SCEGrad(pred, (i == label) ? pos : neg) * loss; + SCEGrad(pred, (i == label) ? pos : neg) * score * loss; } } @@ -187,8 +189,12 @@ static inline void CalcObjnessLoss(T* loss, const T* input, const T* objness, for (int k = 0; k < h; k++) { for (int l = 0; l < w; l++) { T obj = objness[k * w + l]; - if (obj > -0.5) { - loss[i] += SCE(input[k * w + l], obj); + if (obj > 1e-5) { + // positive sample: obj = mixup score + loss[i] += SCE(input[k * w + l], 1.0) * obj; + } else if (obj > -0.5) { + // negetive sample: obj = 0 + loss[i] += SCE(input[k * w + l], 0.0); } } } @@ -209,8 +215,11 @@ static inline void CalcObjnessLossGrad(T* input_grad, const T* loss, for (int k = 0; k < h; k++) { for (int l = 0; l < w; l++) { T obj = objness[k * w + l]; - if (obj > -0.5) { - input_grad[k * w + l] = SCEGrad(input[k * w + l], obj) * loss[i]; + if (obj > 1e-5) { + input_grad[k * w + l] = + SCEGrad(input[k * w + l], 1.0) * obj * loss[i]; + } else if (obj > -0.5) { + input_grad[k * w + l] = SCEGrad(input[k * w + l], 0.0) * loss[i]; } } } @@ -315,7 +324,7 @@ class Yolov3LossKernel : public framework::OpKernel { if (best_iou > ignore_thresh) { int obj_idx = (i * mask_num + j) * stride + k * w + l; - obj_mask_data[obj_idx] = static_cast(-1.0); + obj_mask_data[obj_idx] = static_cast(-1); } // TODO(dengkaipeng): all losses should be calculated if best IoU // is bigger then truth thresh should be calculated here, but @@ -357,12 +366,12 @@ class Yolov3LossKernel : public framework::OpKernel { int mask_idx = GetMaskIndex(anchor_mask, best_n); gt_match_mask_data[i * b + t] = mask_idx; if (mask_idx >= 0) { + T score = gt_score_data[i * b + t]; int box_idx = GetEntryIndex(i, mask_idx, gj * w + gi, mask_num, an_stride, stride, 0); CalcBoxLocationLoss(loss_data + i, input_data, gt, anchors, best_n, - box_idx, gi, gj, h, input_size, stride); + box_idx, gi, gj, h, input_size, stride, score); - T score = gt_score_data[i * b + t]; int obj_idx = (i * mask_num + mask_idx) * stride + gj * w + gi; obj_mask_data[obj_idx] = score; @@ -370,7 +379,7 @@ class Yolov3LossKernel : public framework::OpKernel { int label_idx = GetEntryIndex(i, mask_idx, gj * w + gi, mask_num, an_stride, stride, 5); CalcLabelLoss(loss_data + i, input_data, label_idx, label, - class_num, stride, label_pos, label_neg); + class_num, stride, label_pos, label_neg, score); } } } @@ -387,6 +396,7 @@ class Yolov3LossGradKernel : public framework::OpKernel { auto* input = ctx.Input("X"); auto* gt_box = ctx.Input("GTBox"); auto* gt_label = ctx.Input("GTLabel"); + auto* gt_score = ctx.Input("GTScore"); auto* input_grad = ctx.Output(framework::GradVarName("X")); auto* loss_grad = ctx.Input(framework::GradVarName("Loss")); auto* objness_mask = ctx.Input("ObjectnessMask"); @@ -418,6 +428,7 @@ class Yolov3LossGradKernel : public framework::OpKernel { const T* input_data = input->data(); const T* gt_box_data = gt_box->data(); const int* gt_label_data = gt_label->data(); + const T* gt_score_data = gt_score->data(); const T* loss_grad_data = loss_grad->data(); const T* obj_mask_data = objness_mask->data(); const int* gt_match_mask_data = gt_match_mask->data(); @@ -429,22 +440,24 @@ class Yolov3LossGradKernel : public framework::OpKernel { for (int t = 0; t < b; t++) { int mask_idx = gt_match_mask_data[i * b + t]; if (mask_idx >= 0) { + T score = gt_score_data[i * b + t]; Box gt = GetGtBox(gt_box_data, i, b, t); int gi = static_cast(gt.x * w); int gj = static_cast(gt.y * h); int box_idx = GetEntryIndex(i, mask_idx, gj * w + gi, mask_num, an_stride, stride, 0); - CalcBoxLocationLossGrad( - input_grad_data, loss_grad_data[i], input_data, gt, anchors, - anchor_mask[mask_idx], box_idx, gi, gj, h, input_size, stride); + CalcBoxLocationLossGrad(input_grad_data, loss_grad_data[i], + input_data, gt, anchors, + anchor_mask[mask_idx], box_idx, gi, gj, h, + input_size, stride, score); int label = gt_label_data[i * b + t]; int label_idx = GetEntryIndex(i, mask_idx, gj * w + gi, mask_num, an_stride, stride, 5); CalcLabelLossGrad(input_grad_data, loss_grad_data[i], input_data, label_idx, label, class_num, stride, label_pos, - label_neg); + label_neg, score); } } } diff --git a/python/paddle/fluid/layers/detection.py b/python/paddle/fluid/layers/detection.py index febfc8e127..07df601697 100644 --- a/python/paddle/fluid/layers/detection.py +++ b/python/paddle/fluid/layers/detection.py @@ -482,7 +482,7 @@ def yolov3_loss(x, raise TypeError("Attr anchor_mask of yolov3_loss must be list or tuple") if not isinstance(class_num, int): raise TypeError("Attr class_num of yolov3_loss must be an integer") - if not isinstance(use_label_smooth, int): + if not isinstance(use_label_smooth, bool): raise TypeError("Attr ues_label_smooth of yolov3 must be a bool value") if not isinstance(ignore_thresh, float): raise TypeError( diff --git a/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py b/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py index 426a64f7a2..ff76b76366 100644 --- a/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py +++ b/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py @@ -142,7 +142,7 @@ def YOLOv3Loss(x, gtbox, gtlabel, gtscore, attrs): ty = gtbox[i, j, 1] * w - gj tw = np.log(gtbox[i, j, 2] * input_size / mask_anchors[an_idx][0]) th = np.log(gtbox[i, j, 3] * input_size / mask_anchors[an_idx][1]) - scale = 2.0 - gtbox[i, j, 2] * gtbox[i, j, 3] + scale = (2.0 - gtbox[i, j, 2] * gtbox[i, j, 3]) * gtscore[i, j] loss[i] += sce(x[i, an_idx, gj, gi, 0], tx) * scale loss[i] += sce(x[i, an_idx, gj, gi, 1], ty) * scale loss[i] += l1loss(x[i, an_idx, gj, gi, 2], tw) * scale @@ -152,11 +152,14 @@ def YOLOv3Loss(x, gtbox, gtlabel, gtscore, attrs): for label_idx in range(class_num): loss[i] += sce(x[i, an_idx, gj, gi, 5 + label_idx], label_pos - if label_idx == gtlabel[i, j] else label_neg) + if label_idx == gtlabel[i, j] else + label_neg) * gtscore[i, j] for j in range(mask_num * h * w): - if objness[i, j] >= 0: - loss[i] += sce(pred_obj[i, j], objness[i, j]) + if objness[i, j] > 0: + loss[i] += sce(pred_obj[i, j], 1.0) * objness[i, j] + elif objness[i, j] == 0: + loss[i] += sce(pred_obj[i, j], 0.0) return (loss, objness.reshape((n, mask_num, h, w)).astype('float32'), \ gt_matches.astype('int32')) From 042fecefab41a61fdf5f83913b96a039f75b15c5 Mon Sep 17 00:00:00 2001 From: dengkaipeng Date: Mon, 21 Jan 2019 15:04:26 +0800 Subject: [PATCH 38/53] use L2Loss. test=develop --- paddle/fluid/operators/yolov3_loss_op.h | 18 ++++++++++--- .../tests/unittests/test_yolov3_loss_op.py | 25 ++++++++++--------- 2 files changed, 27 insertions(+), 16 deletions(-) diff --git a/paddle/fluid/operators/yolov3_loss_op.h b/paddle/fluid/operators/yolov3_loss_op.h index de01a01a4f..2131289860 100644 --- a/paddle/fluid/operators/yolov3_loss_op.h +++ b/paddle/fluid/operators/yolov3_loss_op.h @@ -41,6 +41,11 @@ static T L1Loss(T x, T y) { return std::abs(y - x); } +template +static T L2Loss(T x, T y) { + return 0.5 * (y - x) * (y - x); +} + template static T SCEGrad(T x, T label) { return 1.0 / (1.0 + std::exp(-x)) - label; @@ -51,6 +56,11 @@ static T L1LossGrad(T x, T y) { return x > y ? 1.0 : -1.0; } +template +static T L2LossGrad(T x, T y) { + return x - y; +} + static int GetMaskIndex(std::vector mask, int val) { for (size_t i = 0; i < mask.size(); i++) { if (mask[i] == val) { @@ -130,8 +140,8 @@ static void CalcBoxLocationLoss(T* loss, const T* input, Box gt, T scale = (2.0 - gt.w * gt.h) * score; loss[0] += SCE(input[box_idx], tx) * scale; loss[0] += SCE(input[box_idx + stride], ty) * scale; - loss[0] += L1Loss(input[box_idx + 2 * stride], tw) * scale; - loss[0] += L1Loss(input[box_idx + 3 * stride], th) * scale; + loss[0] += L2Loss(input[box_idx + 2 * stride], tw) * scale; + loss[0] += L2Loss(input[box_idx + 3 * stride], th) * scale; } template @@ -150,9 +160,9 @@ static void CalcBoxLocationLossGrad(T* input_grad, const T loss, const T* input, input_grad[box_idx + stride] = SCEGrad(input[box_idx + stride], ty) * scale * loss; input_grad[box_idx + 2 * stride] = - L1LossGrad(input[box_idx + 2 * stride], tw) * scale * loss; + L2LossGrad(input[box_idx + 2 * stride], tw) * scale * loss; input_grad[box_idx + 3 * stride] = - L1LossGrad(input[box_idx + 3 * stride], th) * scale * loss; + L2LossGrad(input[box_idx + 3 * stride], th) * scale * loss; } template diff --git a/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py b/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py index ff76b76366..0e17eb3130 100644 --- a/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py +++ b/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py @@ -27,6 +27,10 @@ def l1loss(x, y): return abs(x - y) +def l2loss(x, y): + return 0.5 * (y - x) * (y - x) + + def sce(x, label): sigmoid_x = expit(x) term1 = label * np.log(sigmoid_x) @@ -145,8 +149,8 @@ def YOLOv3Loss(x, gtbox, gtlabel, gtscore, attrs): scale = (2.0 - gtbox[i, j, 2] * gtbox[i, j, 3]) * gtscore[i, j] loss[i] += sce(x[i, an_idx, gj, gi, 0], tx) * scale loss[i] += sce(x[i, an_idx, gj, gi, 1], ty) * scale - loss[i] += l1loss(x[i, an_idx, gj, gi, 2], tw) * scale - loss[i] += l1loss(x[i, an_idx, gj, gi, 3], th) * scale + loss[i] += l2loss(x[i, an_idx, gj, gi, 2], tw) * scale + loss[i] += l2loss(x[i, an_idx, gj, gi, 3], th) * scale objness[i, an_idx * h * w + gj * w + gi] = gtscore[i, j] @@ -202,7 +206,7 @@ class TestYolov3LossOp(OpTest): def test_check_output(self): place = core.CPUPlace() - self.check_output_with_place(place, atol=2e-3) + self.check_output_with_place(place, atol=1e-3) def test_check_grad_ignore_gtbox(self): place = core.CPUPlace() @@ -210,19 +214,16 @@ class TestYolov3LossOp(OpTest): place, ['X'], 'Loss', no_grad_set=set(["GTBox", "GTLabel", "GTScore"]), - max_relative_error=0.2) + max_relative_error=0.3) def initTestCase(self): - self.anchors = [ - 10, 13, 16, 30, 33, 23, 30, 61, 62, 45, 59, 119, 116, 90, 156, 198, - 373, 326 - ] - self.anchor_mask = [0, 1, 2] - self.class_num = 10 - self.ignore_thresh = 0.7 + self.anchors = [10, 13, 16, 30, 33, 23] + self.anchor_mask = [1, 2] + self.class_num = 5 + self.ignore_thresh = 0.5 self.downsample = 32 self.x_shape = (3, len(self.anchor_mask) * (5 + self.class_num), 5, 5) - self.gtbox_shape = (3, 10, 4) + self.gtbox_shape = (3, 5, 4) self.use_label_smooth = True From 577424e5ecc47446ee0796794004acf5a5852b19 Mon Sep 17 00:00:00 2001 From: dengkaipeng Date: Mon, 28 Jan 2019 16:53:15 +0800 Subject: [PATCH 39/53] use darknet loss and trick --- paddle/fluid/API.spec | 2 +- paddle/fluid/operators/yolov3_loss_op.cc | 18 ----- paddle/fluid/operators/yolov3_loss_op.h | 72 +++++-------------- python/paddle/fluid/layers/detection.py | 13 ---- .../tests/unittests/test_yolov3_loss_op.py | 35 +++------ 5 files changed, 26 insertions(+), 114 deletions(-) diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index e71e494f9d..6c6ac9c7ea 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -324,7 +324,7 @@ paddle.fluid.layers.generate_mask_labels ArgSpec(args=['im_info', 'gt_classes', paddle.fluid.layers.iou_similarity ArgSpec(args=['x', 'y', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.box_coder ArgSpec(args=['prior_box', 'prior_box_var', 'target_box', 'code_type', 'box_normalized', 'name'], varargs=None, keywords=None, defaults=('encode_center_size', True, None)) paddle.fluid.layers.polygon_box_transform ArgSpec(args=['input', 'name'], varargs=None, keywords=None, defaults=(None,)) -paddle.fluid.layers.yolov3_loss ArgSpec(args=['x', 'gtbox', 'gtlabel', 'gtscore', 'anchors', 'anchor_mask', 'class_num', 'ignore_thresh', 'downsample', 'use_label_smooth', 'name'], varargs=None, keywords=None, defaults=(True, None,)) +paddle.fluid.layers.yolov3_loss ArgSpec(args=['x', 'gtbox', 'gtlabel', 'anchors', 'anchor_mask', 'class_num', 'ignore_thresh', 'downsample', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.multiclass_nms ArgSpec(args=['bboxes', 'scores', 'score_threshold', 'nms_top_k', 'keep_top_k', 'nms_threshold', 'normalized', 'nms_eta', 'background_label', 'name'], varargs=None, keywords=None, defaults=(0.3, True, 1.0, 0, None)) paddle.fluid.layers.accuracy ArgSpec(args=['input', 'label', 'k', 'correct', 'total'], varargs=None, keywords=None, defaults=(1, None, None)) paddle.fluid.layers.auc ArgSpec(args=['input', 'label', 'curve', 'num_thresholds', 'topk', 'slide_steps'], varargs=None, keywords=None, defaults=('ROC', 4095, 1, 1)) diff --git a/paddle/fluid/operators/yolov3_loss_op.cc b/paddle/fluid/operators/yolov3_loss_op.cc index 0c5426728b..46374db49a 100644 --- a/paddle/fluid/operators/yolov3_loss_op.cc +++ b/paddle/fluid/operators/yolov3_loss_op.cc @@ -27,8 +27,6 @@ class Yolov3LossOp : public framework::OperatorWithKernel { "Input(GTBox) of Yolov3LossOp should not be null."); PADDLE_ENFORCE(ctx->HasInput("GTLabel"), "Input(GTLabel) of Yolov3LossOp should not be null."); - PADDLE_ENFORCE(ctx->HasInput("GTScore"), - "Input(GTScore) of Yolov3LossOp should not be null."); PADDLE_ENFORCE(ctx->HasOutput("Loss"), "Output(Loss) of Yolov3LossOp should not be null."); PADDLE_ENFORCE( @@ -40,7 +38,6 @@ class Yolov3LossOp : public framework::OperatorWithKernel { auto dim_x = ctx->GetInputDim("X"); auto dim_gtbox = ctx->GetInputDim("GTBox"); auto dim_gtlabel = ctx->GetInputDim("GTLabel"); - auto dim_gtscore = ctx->GetInputDim("GTScore"); auto anchors = ctx->Attrs().Get>("anchors"); int anchor_num = anchors.size() / 2; auto anchor_mask = ctx->Attrs().Get>("anchor_mask"); @@ -63,12 +60,6 @@ class Yolov3LossOp : public framework::OperatorWithKernel { "Input(GTBox) and Input(GTLabel) dim[0] should be same"); PADDLE_ENFORCE_EQ(dim_gtlabel[1], dim_gtbox[1], "Input(GTBox) and Input(GTLabel) dim[1] should be same"); - PADDLE_ENFORCE_EQ(dim_gtscore.size(), 2, - "Input(GTScore) should be a 2-D tensor"); - PADDLE_ENFORCE_EQ(dim_gtscore[0], dim_gtbox[0], - "Input(GTBox) and Input(GTScore) dim[0] should be same"); - PADDLE_ENFORCE_EQ(dim_gtscore[1], dim_gtbox[1], - "Input(GTBox) and Input(GTScore) dim[1] should be same"); PADDLE_ENFORCE_GT(anchors.size(), 0, "Attr(anchors) length should be greater then 0."); PADDLE_ENFORCE_EQ(anchors.size() % 2, 0, @@ -121,11 +112,6 @@ class Yolov3LossOpMaker : public framework::OpProtoAndCheckerMaker { "This is a 2-D tensor with shape of [N, max_box_num], " "and each element should be an integer to indicate the " "box class id."); - AddInput("GTScore", - "The score of GTLabel, This is a 2-D tensor in same shape " - "GTLabel, and score values should in range (0, 1). This " - "input is for GTLabel score can be not 1.0 in image mixup " - "augmentation."); AddOutput("Loss", "The output yolov3 loss tensor, " "This is a 1-D tensor with shape of [N]"); @@ -157,8 +143,6 @@ class Yolov3LossOpMaker : public framework::OpProtoAndCheckerMaker { AddAttr("ignore_thresh", "The ignore threshold to ignore confidence loss.") .SetDefault(0.7); - AddAttr("use_label_smooth", "bool,default True", "use label smooth") - .SetDefault(true); AddComment(R"DOC( This operator generate yolov3 loss by given predict result and ground truth boxes. @@ -245,7 +229,6 @@ class Yolov3LossGradMaker : public framework::SingleGradOpDescMaker { op->SetInput("X", Input("X")); op->SetInput("GTBox", Input("GTBox")); op->SetInput("GTLabel", Input("GTLabel")); - op->SetInput("GTScore", Input("GTScore")); op->SetInput(framework::GradVarName("Loss"), OutputGrad("Loss")); op->SetInput("ObjectnessMask", Output("ObjectnessMask")); op->SetInput("GTMatchMask", Output("GTMatchMask")); @@ -255,7 +238,6 @@ class Yolov3LossGradMaker : public framework::SingleGradOpDescMaker { op->SetOutput(framework::GradVarName("X"), InputGrad("X")); op->SetOutput(framework::GradVarName("GTBox"), {}); op->SetOutput(framework::GradVarName("GTLabel"), {}); - op->SetOutput(framework::GradVarName("GTScore"), {}); return std::unique_ptr(op); } }; diff --git a/paddle/fluid/operators/yolov3_loss_op.h b/paddle/fluid/operators/yolov3_loss_op.h index 2131289860..5c9851232d 100644 --- a/paddle/fluid/operators/yolov3_loss_op.h +++ b/paddle/fluid/operators/yolov3_loss_op.h @@ -36,11 +36,6 @@ static T SCE(T x, T label) { return (x > 0 ? x : 0.0) - x * label + std::log(1.0 + std::exp(-std::abs(x))); } -template -static T L1Loss(T x, T y) { - return std::abs(y - x); -} - template static T L2Loss(T x, T y) { return 0.5 * (y - x) * (y - x); @@ -51,11 +46,6 @@ static T SCEGrad(T x, T label) { return 1.0 / (1.0 + std::exp(-x)) - label; } -template -static T L1LossGrad(T x, T y) { - return x > y ? 1.0 : -1.0; -} - template static T L2LossGrad(T x, T y) { return x - y; @@ -131,13 +121,13 @@ template static void CalcBoxLocationLoss(T* loss, const T* input, Box gt, std::vector anchors, int an_idx, int box_idx, int gi, int gj, int grid_size, - int input_size, int stride, T score) { + int input_size, int stride) { T tx = gt.x * grid_size - gi; T ty = gt.y * grid_size - gj; T tw = std::log(gt.w * input_size / anchors[2 * an_idx]); T th = std::log(gt.h * input_size / anchors[2 * an_idx + 1]); - T scale = (2.0 - gt.w * gt.h) * score; + T scale = (2.0 - gt.w * gt.h); loss[0] += SCE(input[box_idx], tx) * scale; loss[0] += SCE(input[box_idx + stride], ty) * scale; loss[0] += L2Loss(input[box_idx + 2 * stride], tw) * scale; @@ -148,14 +138,13 @@ template static void CalcBoxLocationLossGrad(T* input_grad, const T loss, const T* input, Box gt, std::vector anchors, int an_idx, int box_idx, int gi, int gj, - int grid_size, int input_size, int stride, - T score) { + int grid_size, int input_size, int stride) { T tx = gt.x * grid_size - gi; T ty = gt.y * grid_size - gj; T tw = std::log(gt.w * input_size / anchors[2 * an_idx]); T th = std::log(gt.h * input_size / anchors[2 * an_idx + 1]); - T scale = (2.0 - gt.w * gt.h) * score; + T scale = (2.0 - gt.w * gt.h); input_grad[box_idx] = SCEGrad(input[box_idx], tx) * scale * loss; input_grad[box_idx + stride] = SCEGrad(input[box_idx + stride], ty) * scale * loss; @@ -168,11 +157,10 @@ static void CalcBoxLocationLossGrad(T* input_grad, const T loss, const T* input, template static inline void CalcLabelLoss(T* loss, const T* input, const int index, const int label, const int class_num, - const int stride, const T pos, const T neg, - T score) { + const int stride) { for (int i = 0; i < class_num; i++) { T pred = input[index + i * stride]; - loss[0] += SCE(pred, (i == label) ? pos : neg) * score; + loss[0] += SCE(pred, (i == label) ? 1.0 : 0.0); } } @@ -180,12 +168,11 @@ template static inline void CalcLabelLossGrad(T* input_grad, const T loss, const T* input, const int index, const int label, const int class_num, - const int stride, const T pos, const T neg, - T score) { + const int stride) { for (int i = 0; i < class_num; i++) { T pred = input[index + i * stride]; input_grad[index + i * stride] = - SCEGrad(pred, (i == label) ? pos : neg) * score * loss; + SCEGrad(pred, (i == label) ? 1.0 : 0.0) * loss; } } @@ -201,7 +188,7 @@ static inline void CalcObjnessLoss(T* loss, const T* input, const T* objness, T obj = objness[k * w + l]; if (obj > 1e-5) { // positive sample: obj = mixup score - loss[i] += SCE(input[k * w + l], 1.0) * obj; + loss[i] += SCE(input[k * w + l], 1.0); } else if (obj > -0.5) { // negetive sample: obj = 0 loss[i] += SCE(input[k * w + l], 0.0); @@ -226,8 +213,7 @@ static inline void CalcObjnessLossGrad(T* input_grad, const T* loss, for (int l = 0; l < w; l++) { T obj = objness[k * w + l]; if (obj > 1e-5) { - input_grad[k * w + l] = - SCEGrad(input[k * w + l], 1.0) * obj * loss[i]; + input_grad[k * w + l] = SCEGrad(input[k * w + l], 1.0) * loss[i]; } else if (obj > -0.5) { input_grad[k * w + l] = SCEGrad(input[k * w + l], 0.0) * loss[i]; } @@ -263,7 +249,6 @@ class Yolov3LossKernel : public framework::OpKernel { auto* input = ctx.Input("X"); auto* gt_box = ctx.Input("GTBox"); auto* gt_label = ctx.Input("GTLabel"); - auto* gt_score = ctx.Input("GTScore"); auto* loss = ctx.Output("Loss"); auto* objness_mask = ctx.Output("ObjectnessMask"); auto* gt_match_mask = ctx.Output("GTMatchMask"); @@ -272,7 +257,6 @@ class Yolov3LossKernel : public framework::OpKernel { int class_num = ctx.Attr("class_num"); float ignore_thresh = ctx.Attr("ignore_thresh"); int downsample = ctx.Attr("downsample"); - bool use_label_smooth = ctx.Attr("use_label_smooth"); const int n = input->dims()[0]; const int h = input->dims()[2]; @@ -285,17 +269,9 @@ class Yolov3LossKernel : public framework::OpKernel { const int stride = h * w; const int an_stride = (class_num + 5) * stride; - T label_pos = 1.0; - T label_neg = 0.0; - if (use_label_smooth) { - label_pos = 1.0 - 1.0 / static_cast(class_num); - label_neg = 1.0 / static_cast(class_num); - } - const T* input_data = input->data(); const T* gt_box_data = gt_box->data(); const int* gt_label_data = gt_label->data(); - const T* gt_score_data = gt_score->data(); T* loss_data = loss->mutable_data({n}, ctx.GetPlace()); memset(loss_data, 0, loss->numel() * sizeof(T)); T* obj_mask_data = @@ -376,20 +352,19 @@ class Yolov3LossKernel : public framework::OpKernel { int mask_idx = GetMaskIndex(anchor_mask, best_n); gt_match_mask_data[i * b + t] = mask_idx; if (mask_idx >= 0) { - T score = gt_score_data[i * b + t]; int box_idx = GetEntryIndex(i, mask_idx, gj * w + gi, mask_num, an_stride, stride, 0); CalcBoxLocationLoss(loss_data + i, input_data, gt, anchors, best_n, - box_idx, gi, gj, h, input_size, stride, score); + box_idx, gi, gj, h, input_size, stride); int obj_idx = (i * mask_num + mask_idx) * stride + gj * w + gi; - obj_mask_data[obj_idx] = score; + obj_mask_data[obj_idx] = 1.0; int label = gt_label_data[i * b + t]; int label_idx = GetEntryIndex(i, mask_idx, gj * w + gi, mask_num, an_stride, stride, 5); CalcLabelLoss(loss_data + i, input_data, label_idx, label, - class_num, stride, label_pos, label_neg, score); + class_num, stride); } } } @@ -406,7 +381,6 @@ class Yolov3LossGradKernel : public framework::OpKernel { auto* input = ctx.Input("X"); auto* gt_box = ctx.Input("GTBox"); auto* gt_label = ctx.Input("GTLabel"); - auto* gt_score = ctx.Input("GTScore"); auto* input_grad = ctx.Output(framework::GradVarName("X")); auto* loss_grad = ctx.Input(framework::GradVarName("Loss")); auto* objness_mask = ctx.Input("ObjectnessMask"); @@ -415,7 +389,6 @@ class Yolov3LossGradKernel : public framework::OpKernel { auto anchor_mask = ctx.Attr>("anchor_mask"); int class_num = ctx.Attr("class_num"); int downsample = ctx.Attr("downsample"); - bool use_label_smooth = ctx.Attr("use_label_smooth"); const int n = input_grad->dims()[0]; const int c = input_grad->dims()[1]; @@ -428,17 +401,9 @@ class Yolov3LossGradKernel : public framework::OpKernel { const int stride = h * w; const int an_stride = (class_num + 5) * stride; - T label_pos = 1.0; - T label_neg = 0.0; - if (use_label_smooth) { - label_pos = 1.0 - 1.0 / static_cast(class_num); - label_neg = 1.0 / static_cast(class_num); - } - const T* input_data = input->data(); const T* gt_box_data = gt_box->data(); const int* gt_label_data = gt_label->data(); - const T* gt_score_data = gt_score->data(); const T* loss_grad_data = loss_grad->data(); const T* obj_mask_data = objness_mask->data(); const int* gt_match_mask_data = gt_match_mask->data(); @@ -450,24 +415,21 @@ class Yolov3LossGradKernel : public framework::OpKernel { for (int t = 0; t < b; t++) { int mask_idx = gt_match_mask_data[i * b + t]; if (mask_idx >= 0) { - T score = gt_score_data[i * b + t]; Box gt = GetGtBox(gt_box_data, i, b, t); int gi = static_cast(gt.x * w); int gj = static_cast(gt.y * h); int box_idx = GetEntryIndex(i, mask_idx, gj * w + gi, mask_num, an_stride, stride, 0); - CalcBoxLocationLossGrad(input_grad_data, loss_grad_data[i], - input_data, gt, anchors, - anchor_mask[mask_idx], box_idx, gi, gj, h, - input_size, stride, score); + CalcBoxLocationLossGrad( + input_grad_data, loss_grad_data[i], input_data, gt, anchors, + anchor_mask[mask_idx], box_idx, gi, gj, h, input_size, stride); int label = gt_label_data[i * b + t]; int label_idx = GetEntryIndex(i, mask_idx, gj * w + gi, mask_num, an_stride, stride, 5); CalcLabelLossGrad(input_grad_data, loss_grad_data[i], input_data, - label_idx, label, class_num, stride, label_pos, - label_neg, score); + label_idx, label, class_num, stride); } } } diff --git a/python/paddle/fluid/layers/detection.py b/python/paddle/fluid/layers/detection.py index 07df601697..ea130bb279 100644 --- a/python/paddle/fluid/layers/detection.py +++ b/python/paddle/fluid/layers/detection.py @@ -412,13 +412,11 @@ def polygon_box_transform(input, name=None): def yolov3_loss(x, gtbox, gtlabel, - gtscore, anchors, anchor_mask, class_num, ignore_thresh, downsample, - use_label_smooth=True, name=None): """ ${comment} @@ -432,14 +430,11 @@ def yolov3_loss(x, an image. gtlabel (Variable): class id of ground truth boxes, shoud be in shape of [N, B]. - gtscore (Variable): score of gtlabel, should be in same shape with gtlabel - and score value in range (0, 1). anchors (list|tuple): ${anchors_comment} anchor_mask (list|tuple): ${anchor_mask_comment} class_num (int): ${class_num_comment} ignore_thresh (float): ${ignore_thresh_comment} downsample (int): ${downsample_comment} - use_label_smooth(bool): ${use_label_smooth_comment} name (string): the name of yolov3 loss Returns: @@ -449,11 +444,9 @@ def yolov3_loss(x, TypeError: Input x of yolov3_loss must be Variable TypeError: Input gtbox of yolov3_loss must be Variable" TypeError: Input gtlabel of yolov3_loss must be Variable" - TypeError: Input gtscore of yolov3_loss must be Variable" TypeError: Attr anchors of yolov3_loss must be list or tuple TypeError: Attr class_num of yolov3_loss must be an integer TypeError: Attr ignore_thresh of yolov3_loss must be a float number - TypeError: Attr use_label_smooth of yolov3_loss must be a bool value Examples: .. code-block:: python @@ -474,16 +467,12 @@ def yolov3_loss(x, raise TypeError("Input gtbox of yolov3_loss must be Variable") if not isinstance(gtlabel, Variable): raise TypeError("Input gtlabel of yolov3_loss must be Variable") - if not isinstance(gtscore, Variable): - raise TypeError("Input gtscore of yolov3_loss must be Variable") if not isinstance(anchors, list) and not isinstance(anchors, tuple): raise TypeError("Attr anchors of yolov3_loss must be list or tuple") if not isinstance(anchor_mask, list) and not isinstance(anchor_mask, tuple): raise TypeError("Attr anchor_mask of yolov3_loss must be list or tuple") if not isinstance(class_num, int): raise TypeError("Attr class_num of yolov3_loss must be an integer") - if not isinstance(use_label_smooth, bool): - raise TypeError("Attr ues_label_smooth of yolov3 must be a bool value") if not isinstance(ignore_thresh, float): raise TypeError( "Attr ignore_thresh of yolov3_loss must be a float number") @@ -503,7 +492,6 @@ def yolov3_loss(x, "class_num": class_num, "ignore_thresh": ignore_thresh, "downsample": downsample, - "use_label_smooth": use_label_smooth } helper.append_op( @@ -512,7 +500,6 @@ def yolov3_loss(x, "X": x, "GTBox": gtbox, "GTLabel": gtlabel, - "GTScore": gtscore }, outputs={ 'Loss': loss, diff --git a/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py b/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py index 0e17eb3130..020c113923 100644 --- a/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py +++ b/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py @@ -23,10 +23,6 @@ from op_test import OpTest from paddle.fluid import core -def l1loss(x, y): - return abs(x - y) - - def l2loss(x, y): return 0.5 * (y - x) * (y - x) @@ -70,7 +66,7 @@ def batch_xywh_box_iou(box1, box2): return inter_area / union -def YOLOv3Loss(x, gtbox, gtlabel, gtscore, attrs): +def YOLOv3Loss(x, gtbox, gtlabel, attrs): n, c, h, w = x.shape b = gtbox.shape[1] anchors = attrs['anchors'] @@ -80,14 +76,10 @@ def YOLOv3Loss(x, gtbox, gtlabel, gtscore, attrs): class_num = attrs["class_num"] ignore_thresh = attrs['ignore_thresh'] downsample = attrs['downsample'] - use_label_smooth = attrs['use_label_smooth'] input_size = downsample * h x = x.reshape((n, mask_num, 5 + class_num, h, w)).transpose((0, 1, 3, 4, 2)) loss = np.zeros((n)).astype('float32') - label_pos = 1.0 - 1.0 / class_num if use_label_smooth else 1.0 - label_neg = 1.0 / class_num if use_label_smooth else 0.0 - pred_box = x[:, :, :, :, :4].copy() grid_x = np.tile(np.arange(w).reshape((1, w)), (h, 1)) grid_y = np.tile(np.arange(h).reshape((h, 1)), (1, w)) @@ -146,22 +138,21 @@ def YOLOv3Loss(x, gtbox, gtlabel, gtscore, attrs): ty = gtbox[i, j, 1] * w - gj tw = np.log(gtbox[i, j, 2] * input_size / mask_anchors[an_idx][0]) th = np.log(gtbox[i, j, 3] * input_size / mask_anchors[an_idx][1]) - scale = (2.0 - gtbox[i, j, 2] * gtbox[i, j, 3]) * gtscore[i, j] + scale = (2.0 - gtbox[i, j, 2] * gtbox[i, j, 3]) loss[i] += sce(x[i, an_idx, gj, gi, 0], tx) * scale loss[i] += sce(x[i, an_idx, gj, gi, 1], ty) * scale loss[i] += l2loss(x[i, an_idx, gj, gi, 2], tw) * scale loss[i] += l2loss(x[i, an_idx, gj, gi, 3], th) * scale - objness[i, an_idx * h * w + gj * w + gi] = gtscore[i, j] + objness[i, an_idx * h * w + gj * w + gi] = 1.0 for label_idx in range(class_num): - loss[i] += sce(x[i, an_idx, gj, gi, 5 + label_idx], label_pos - if label_idx == gtlabel[i, j] else - label_neg) * gtscore[i, j] + loss[i] += sce(x[i, an_idx, gj, gi, 5 + label_idx], + float(label_idx == gtlabel[i, j])) for j in range(mask_num * h * w): if objness[i, j] > 0: - loss[i] += sce(pred_obj[i, j], 1.0) * objness[i, j] + loss[i] += sce(pred_obj[i, j], 1.0) elif objness[i, j] == 0: loss[i] += sce(pred_obj[i, j], 0.0) @@ -176,7 +167,6 @@ class TestYolov3LossOp(OpTest): x = logit(np.random.uniform(0, 1, self.x_shape).astype('float32')) gtbox = np.random.random(size=self.gtbox_shape).astype('float32') gtlabel = np.random.randint(0, self.class_num, self.gtbox_shape[:2]) - gtscore = np.random.random(self.gtbox_shape[:2]).astype('float32') gtmask = np.random.randint(0, 2, self.gtbox_shape[:2]) gtbox = gtbox * gtmask[:, :, np.newaxis] gtlabel = gtlabel * gtmask @@ -187,17 +177,14 @@ class TestYolov3LossOp(OpTest): "class_num": self.class_num, "ignore_thresh": self.ignore_thresh, "downsample": self.downsample, - "use_label_smooth": self.use_label_smooth, } self.inputs = { 'X': x, 'GTBox': gtbox.astype('float32'), 'GTLabel': gtlabel.astype('int32'), - 'GTScore': gtscore.astype('float32') } - loss, objness, gt_matches = YOLOv3Loss(x, gtbox, gtlabel, gtscore, - self.attrs) + loss, objness, gt_matches = YOLOv3Loss(x, gtbox, gtlabel, self.attrs) self.outputs = { 'Loss': loss, 'ObjectnessMask': objness, @@ -213,7 +200,7 @@ class TestYolov3LossOp(OpTest): self.check_grad_with_place( place, ['X'], 'Loss', - no_grad_set=set(["GTBox", "GTLabel", "GTScore"]), + no_grad_set=set(["GTBox", "GTLabel"]), max_relative_error=0.3) def initTestCase(self): @@ -224,12 +211,6 @@ class TestYolov3LossOp(OpTest): self.downsample = 32 self.x_shape = (3, len(self.anchor_mask) * (5 + self.class_num), 5, 5) self.gtbox_shape = (3, 5, 4) - self.use_label_smooth = True - - -class TestYolov3LossWithoutLabelSmooth(TestYolov3LossOp): - def set_label_smooth(self): - self.use_label_smooth = False if __name__ == "__main__": From 56e21c558e37395ead098d588902464cb09c206a Mon Sep 17 00:00:00 2001 From: dengkaipeng Date: Mon, 28 Jan 2019 17:10:47 +0800 Subject: [PATCH 40/53] add comments and docs. test=develop --- paddle/fluid/operators/yolov3_loss_op.cc | 7 ++++++- paddle/fluid/operators/yolov3_loss_op.h | 10 +++++----- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/paddle/fluid/operators/yolov3_loss_op.cc b/paddle/fluid/operators/yolov3_loss_op.cc index 46374db49a..0d13d8fff4 100644 --- a/paddle/fluid/operators/yolov3_loss_op.cc +++ b/paddle/fluid/operators/yolov3_loss_op.cc @@ -98,7 +98,7 @@ class Yolov3LossOpMaker : public framework::OpProtoAndCheckerMaker { "This is a 4-D tensor with shape of [N, C, H, W]." "H and W should be same, and the second dimention(C) stores" "box locations, confidence score and classification one-hot" - "key of each anchor box"); + "keys of each anchor box"); AddInput("GTBox", "The input tensor of ground truth boxes, " "This is a 3-D tensor with shape of [N, max_box_num, 5], " @@ -179,6 +179,11 @@ class Yolov3LossOpMaker : public framework::OpProtoAndCheckerMaker { box coordinates (w, h), and sigmoid cross entropy loss is used for box coordinates (x, y), confidence score loss and classification loss. + Each groud truth box find a best matching anchor box in all anchors, + prediction of this anchor box will incur all three parts of losses, and + prediction of anchor boxes with no GT box matched will only incur objectness + loss. + In order to trade off box coordinate losses between big boxes and small boxes, box coordinate losses will be mutiplied by scale weight, which is calculated as follow. diff --git a/paddle/fluid/operators/yolov3_loss_op.h b/paddle/fluid/operators/yolov3_loss_op.h index 5c9851232d..fce8195668 100644 --- a/paddle/fluid/operators/yolov3_loss_op.h +++ b/paddle/fluid/operators/yolov3_loss_op.h @@ -308,13 +308,15 @@ class Yolov3LossKernel : public framework::OpKernel { } } + // If best IoU is greater then ignore_thresh, + // ignore the objectness loss. if (best_iou > ignore_thresh) { int obj_idx = (i * mask_num + j) * stride + k * w + l; obj_mask_data[obj_idx] = static_cast(-1); } - // TODO(dengkaipeng): all losses should be calculated if best IoU - // is bigger then truth thresh should be calculated here, but - // currently, truth thresh is an unreachable value as 1.0. + // all losses should be calculated if best IoU + // is bigger then truth thresh, but currently, + // truth thresh is an unreachable value as 1.0. } } } @@ -341,8 +343,6 @@ class Yolov3LossKernel : public framework::OpKernel { an_box.w = anchors[2 * an_idx] / static_cast(input_size); an_box.h = anchors[2 * an_idx + 1] / static_cast(input_size); float iou = CalcBoxIoU(an_box, gt_shift); - // TODO(dengkaipeng): In paper, objectness loss is ignore when - // best IoU > 0.5, but darknet code didn't implement this. if (iou > best_iou) { best_iou = iou; best_n = an_idx; From ae0b0d5f9362b11fb78355d9d56b7f9ff1cc9c6b Mon Sep 17 00:00:00 2001 From: dengkaipeng Date: Mon, 28 Jan 2019 22:58:46 +0800 Subject: [PATCH 41/53] fix doc. test=develop --- paddle/fluid/operators/yolov3_loss_op.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/operators/yolov3_loss_op.cc b/paddle/fluid/operators/yolov3_loss_op.cc index 0d13d8fff4..30f0c08463 100644 --- a/paddle/fluid/operators/yolov3_loss_op.cc +++ b/paddle/fluid/operators/yolov3_loss_op.cc @@ -121,7 +121,7 @@ class Yolov3LossOpMaker : public framework::OpProtoAndCheckerMaker { "mask for calculate objectness loss in gradient kernel.") .AsIntermediate(); AddOutput("GTMatchMask", - "This is an intermediate tensor with shape if [N, B], " + "This is an intermediate tensor with shape of [N, B], " "B is the max box number of GT boxes. This parameter caches " "matched mask index of each GT boxes for gradient calculate.") .AsIntermediate(); @@ -175,7 +175,7 @@ class Yolov3LossOpMaker : public framework::OpProtoAndCheckerMaker { thresh, the confidence score loss of this anchor box will be ignored. Therefore, the yolov3 loss consist of three major parts, box location loss, - confidence score loss, and classification loss. The L1 loss is used for + confidence score loss, and classification loss. The L2 loss is used for box coordinates (w, h), and sigmoid cross entropy loss is used for box coordinates (x, y), confidence score loss and classification loss. From 733bb82ec0d7ba4bbe9f0ed2aa5c36bc81829fa0 Mon Sep 17 00:00:00 2001 From: dengkaipeng Date: Tue, 29 Jan 2019 14:38:47 +0800 Subject: [PATCH 42/53] downsample -> downsample_ratio. test=develop --- paddle/fluid/API.spec | 2 +- paddle/fluid/operators/yolov3_loss_op.cc | 2 +- paddle/fluid/operators/yolov3_loss_op.h | 41 +++++++++++++----------- python/paddle/fluid/layers/detection.py | 10 +++--- 4 files changed, 29 insertions(+), 26 deletions(-) diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index 6c6ac9c7ea..5fdab448cb 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -324,7 +324,7 @@ paddle.fluid.layers.generate_mask_labels ArgSpec(args=['im_info', 'gt_classes', paddle.fluid.layers.iou_similarity ArgSpec(args=['x', 'y', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.box_coder ArgSpec(args=['prior_box', 'prior_box_var', 'target_box', 'code_type', 'box_normalized', 'name'], varargs=None, keywords=None, defaults=('encode_center_size', True, None)) paddle.fluid.layers.polygon_box_transform ArgSpec(args=['input', 'name'], varargs=None, keywords=None, defaults=(None,)) -paddle.fluid.layers.yolov3_loss ArgSpec(args=['x', 'gtbox', 'gtlabel', 'anchors', 'anchor_mask', 'class_num', 'ignore_thresh', 'downsample', 'name'], varargs=None, keywords=None, defaults=(None,)) +paddle.fluid.layers.yolov3_loss ArgSpec(args=['x', 'gtbox', 'gtlabel', 'anchors', 'anchor_mask', 'class_num', 'ignore_thresh', 'downsample_ratio', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.multiclass_nms ArgSpec(args=['bboxes', 'scores', 'score_threshold', 'nms_top_k', 'keep_top_k', 'nms_threshold', 'normalized', 'nms_eta', 'background_label', 'name'], varargs=None, keywords=None, defaults=(0.3, True, 1.0, 0, None)) paddle.fluid.layers.accuracy ArgSpec(args=['input', 'label', 'k', 'correct', 'total'], varargs=None, keywords=None, defaults=(1, None, None)) paddle.fluid.layers.auc ArgSpec(args=['input', 'label', 'curve', 'num_thresholds', 'topk', 'slide_steps'], varargs=None, keywords=None, defaults=('ROC', 4095, 1, 1)) diff --git a/paddle/fluid/operators/yolov3_loss_op.cc b/paddle/fluid/operators/yolov3_loss_op.cc index 30f0c08463..81fd87b4ac 100644 --- a/paddle/fluid/operators/yolov3_loss_op.cc +++ b/paddle/fluid/operators/yolov3_loss_op.cc @@ -135,7 +135,7 @@ class Yolov3LossOpMaker : public framework::OpProtoAndCheckerMaker { "The mask index of anchors used in " "current YOLOv3 loss calculation.") .SetDefault(std::vector{}); - AddAttr("downsample", + AddAttr("downsample_ratio", "The downsample ratio from network input to YOLOv3 loss " "input, so 32, 16, 8 should be set for the first, second, " "and thrid YOLOv3 loss operators.") diff --git a/paddle/fluid/operators/yolov3_loss_op.h b/paddle/fluid/operators/yolov3_loss_op.h index fce8195668..8407d4e6e8 100644 --- a/paddle/fluid/operators/yolov3_loss_op.h +++ b/paddle/fluid/operators/yolov3_loss_op.h @@ -32,7 +32,7 @@ static inline bool LessEqualZero(T x) { } template -static T SCE(T x, T label) { +static T SigmoidCrossEntropy(T x, T label) { return (x > 0 ? x : 0.0) - x * label + std::log(1.0 + std::exp(-std::abs(x))); } @@ -42,7 +42,7 @@ static T L2Loss(T x, T y) { } template -static T SCEGrad(T x, T label) { +static T SigmoidCrossEntropyGrad(T x, T label) { return 1.0 / (1.0 + std::exp(-x)) - label; } @@ -62,7 +62,7 @@ static int GetMaskIndex(std::vector mask, int val) { template struct Box { - float x, y, w, h; + T x, y, w, h; }; template @@ -128,8 +128,8 @@ static void CalcBoxLocationLoss(T* loss, const T* input, Box gt, T th = std::log(gt.h * input_size / anchors[2 * an_idx + 1]); T scale = (2.0 - gt.w * gt.h); - loss[0] += SCE(input[box_idx], tx) * scale; - loss[0] += SCE(input[box_idx + stride], ty) * scale; + loss[0] += SigmoidCrossEntropy(input[box_idx], tx) * scale; + loss[0] += SigmoidCrossEntropy(input[box_idx + stride], ty) * scale; loss[0] += L2Loss(input[box_idx + 2 * stride], tw) * scale; loss[0] += L2Loss(input[box_idx + 3 * stride], th) * scale; } @@ -145,9 +145,10 @@ static void CalcBoxLocationLossGrad(T* input_grad, const T loss, const T* input, T th = std::log(gt.h * input_size / anchors[2 * an_idx + 1]); T scale = (2.0 - gt.w * gt.h); - input_grad[box_idx] = SCEGrad(input[box_idx], tx) * scale * loss; + input_grad[box_idx] = + SigmoidCrossEntropyGrad(input[box_idx], tx) * scale * loss; input_grad[box_idx + stride] = - SCEGrad(input[box_idx + stride], ty) * scale * loss; + SigmoidCrossEntropyGrad(input[box_idx + stride], ty) * scale * loss; input_grad[box_idx + 2 * stride] = L2LossGrad(input[box_idx + 2 * stride], tw) * scale * loss; input_grad[box_idx + 3 * stride] = @@ -160,7 +161,7 @@ static inline void CalcLabelLoss(T* loss, const T* input, const int index, const int stride) { for (int i = 0; i < class_num; i++) { T pred = input[index + i * stride]; - loss[0] += SCE(pred, (i == label) ? 1.0 : 0.0); + loss[0] += SigmoidCrossEntropy(pred, (i == label) ? 1.0 : 0.0); } } @@ -172,7 +173,7 @@ static inline void CalcLabelLossGrad(T* input_grad, const T loss, for (int i = 0; i < class_num; i++) { T pred = input[index + i * stride]; input_grad[index + i * stride] = - SCEGrad(pred, (i == label) ? 1.0 : 0.0) * loss; + SigmoidCrossEntropyGrad(pred, (i == label) ? 1.0 : 0.0) * loss; } } @@ -187,11 +188,11 @@ static inline void CalcObjnessLoss(T* loss, const T* input, const T* objness, for (int l = 0; l < w; l++) { T obj = objness[k * w + l]; if (obj > 1e-5) { - // positive sample: obj = mixup score - loss[i] += SCE(input[k * w + l], 1.0); + // positive sample: obj = 1 + loss[i] += SigmoidCrossEntropy(input[k * w + l], 1.0); } else if (obj > -0.5) { // negetive sample: obj = 0 - loss[i] += SCE(input[k * w + l], 0.0); + loss[i] += SigmoidCrossEntropy(input[k * w + l], 0.0); } } } @@ -213,9 +214,11 @@ static inline void CalcObjnessLossGrad(T* input_grad, const T* loss, for (int l = 0; l < w; l++) { T obj = objness[k * w + l]; if (obj > 1e-5) { - input_grad[k * w + l] = SCEGrad(input[k * w + l], 1.0) * loss[i]; + input_grad[k * w + l] = + SigmoidCrossEntropyGrad(input[k * w + l], 1.0) * loss[i]; } else if (obj > -0.5) { - input_grad[k * w + l] = SCEGrad(input[k * w + l], 0.0) * loss[i]; + input_grad[k * w + l] = + SigmoidCrossEntropyGrad(input[k * w + l], 0.0) * loss[i]; } } } @@ -256,7 +259,7 @@ class Yolov3LossKernel : public framework::OpKernel { auto anchor_mask = ctx.Attr>("anchor_mask"); int class_num = ctx.Attr("class_num"); float ignore_thresh = ctx.Attr("ignore_thresh"); - int downsample = ctx.Attr("downsample"); + int downsample_ratio = ctx.Attr("downsample_ratio"); const int n = input->dims()[0]; const int h = input->dims()[2]; @@ -264,7 +267,7 @@ class Yolov3LossKernel : public framework::OpKernel { const int an_num = anchors.size() / 2; const int mask_num = anchor_mask.size(); const int b = gt_box->dims()[1]; - int input_size = downsample * h; + int input_size = downsample_ratio * h; const int stride = h * w; const int an_stride = (class_num + 5) * stride; @@ -308,7 +311,7 @@ class Yolov3LossKernel : public framework::OpKernel { } } - // If best IoU is greater then ignore_thresh, + // If best IoU is bigger then ignore_thresh, // ignore the objectness loss. if (best_iou > ignore_thresh) { int obj_idx = (i * mask_num + j) * stride + k * w + l; @@ -388,7 +391,7 @@ class Yolov3LossGradKernel : public framework::OpKernel { auto anchors = ctx.Attr>("anchors"); auto anchor_mask = ctx.Attr>("anchor_mask"); int class_num = ctx.Attr("class_num"); - int downsample = ctx.Attr("downsample"); + int downsample_ratio = ctx.Attr("downsample_ratio"); const int n = input_grad->dims()[0]; const int c = input_grad->dims()[1]; @@ -396,7 +399,7 @@ class Yolov3LossGradKernel : public framework::OpKernel { const int w = input_grad->dims()[3]; const int mask_num = anchor_mask.size(); const int b = gt_match_mask->dims()[1]; - int input_size = downsample * h; + int input_size = downsample_ratio * h; const int stride = h * w; const int an_stride = (class_num + 5) * stride; diff --git a/python/paddle/fluid/layers/detection.py b/python/paddle/fluid/layers/detection.py index ea130bb279..486503c871 100644 --- a/python/paddle/fluid/layers/detection.py +++ b/python/paddle/fluid/layers/detection.py @@ -416,7 +416,7 @@ def yolov3_loss(x, anchor_mask, class_num, ignore_thresh, - downsample, + downsample_ratio, name=None): """ ${comment} @@ -434,7 +434,7 @@ def yolov3_loss(x, anchor_mask (list|tuple): ${anchor_mask_comment} class_num (int): ${class_num_comment} ignore_thresh (float): ${ignore_thresh_comment} - downsample (int): ${downsample_comment} + downsample_ratio (int): ${downsample_ratio_comment} name (string): the name of yolov3 loss Returns: @@ -456,8 +456,8 @@ def yolov3_loss(x, gtlabel = fluid.layers.data(name='gtlabel', shape=[6, 1], dtype='int32') anchors = [10, 13, 16, 30, 33, 23, 30, 61, 62, 45, 59, 119, 116, 90, 156, 198, 373, 326] anchors = [0, 1, 2] - loss = fluid.layers.yolov3_loss(x=x, gtbox=gtbox, class_num=80 - anchors=anchors, ignore_thresh=0.5) + loss = fluid.layers.yolov3_loss(x=x, gtbox=gtbox, class_num=80, anchors=anchors, + ignore_thresh=0.5, downsample_ratio=32) """ helper = LayerHelper('yolov3_loss', **locals()) @@ -491,7 +491,7 @@ def yolov3_loss(x, "anchor_mask": anchor_mask, "class_num": class_num, "ignore_thresh": ignore_thresh, - "downsample": downsample, + "downsample_ratio": downsample_ratio, } helper.append_op( From 23d34d1f7e553bdcf4ac1d270f9e828f8cf99baf Mon Sep 17 00:00:00 2001 From: dengkaipeng Date: Tue, 29 Jan 2019 16:15:38 +0800 Subject: [PATCH 43/53] move yolov3_loss to detection. test=develop --- paddle/fluid/operators/detection/CMakeLists.txt | 1 + paddle/fluid/operators/{ => detection}/yolov3_loss_op.cc | 2 +- paddle/fluid/operators/{ => detection}/yolov3_loss_op.h | 0 3 files changed, 2 insertions(+), 1 deletion(-) rename paddle/fluid/operators/{ => detection}/yolov3_loss_op.cc (99%) rename paddle/fluid/operators/{ => detection}/yolov3_loss_op.h (100%) diff --git a/paddle/fluid/operators/detection/CMakeLists.txt b/paddle/fluid/operators/detection/CMakeLists.txt index d3a61dc367..cace42bc1b 100644 --- a/paddle/fluid/operators/detection/CMakeLists.txt +++ b/paddle/fluid/operators/detection/CMakeLists.txt @@ -31,6 +31,7 @@ detection_library(polygon_box_transform_op SRCS polygon_box_transform_op.cc polygon_box_transform_op.cu) detection_library(rpn_target_assign_op SRCS rpn_target_assign_op.cc) detection_library(generate_proposal_labels_op SRCS generate_proposal_labels_op.cc) +detection_library(yolov3_loss_op SRCS yolov3_loss_op.cc) if(WITH_GPU) detection_library(generate_proposals_op SRCS generate_proposals_op.cc generate_proposals_op.cu DEPS memory cub) diff --git a/paddle/fluid/operators/yolov3_loss_op.cc b/paddle/fluid/operators/detection/yolov3_loss_op.cc similarity index 99% rename from paddle/fluid/operators/yolov3_loss_op.cc rename to paddle/fluid/operators/detection/yolov3_loss_op.cc index 81fd87b4ac..2a69ad4b53 100644 --- a/paddle/fluid/operators/yolov3_loss_op.cc +++ b/paddle/fluid/operators/detection/yolov3_loss_op.cc @@ -9,7 +9,7 @@ See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/yolov3_loss_op.h" +#include "paddle/fluid/operators/detection/yolov3_loss_op.h" #include "paddle/fluid/framework/op_registry.h" namespace paddle { diff --git a/paddle/fluid/operators/yolov3_loss_op.h b/paddle/fluid/operators/detection/yolov3_loss_op.h similarity index 100% rename from paddle/fluid/operators/yolov3_loss_op.h rename to paddle/fluid/operators/detection/yolov3_loss_op.h From b1bdcd4de8b7b0fea2868d664563e425426f6834 Mon Sep 17 00:00:00 2001 From: Krzysztof Binias Date: Mon, 28 Jan 2019 05:34:41 +0100 Subject: [PATCH 44/53] Make separate folders for mkldnn codes test=develop --- cmake/operators.cmake | 4 +-- paddle/fluid/framework/ir/CMakeLists.txt | 32 +++++++++++++------ .../conv_bias_mkldnn_fuse_pass.cc | 2 +- .../{ => mkldnn}/conv_bias_mkldnn_fuse_pass.h | 0 .../conv_elementwise_add_mkldnn_fuse_pass.cc | 2 +- .../conv_elementwise_add_mkldnn_fuse_pass.h | 0 ...elementwise_add_mkldnn_fuse_pass_tester.cc | 2 +- .../conv_relu_mkldnn_fuse_pass.cc | 2 +- .../{ => mkldnn}/conv_relu_mkldnn_fuse_pass.h | 0 .../conv_relu_mkldnn_fuse_pass_tester.cc | 2 +- .../depthwise_conv_mkldnn_pass.cc | 2 +- .../{ => mkldnn}/depthwise_conv_mkldnn_pass.h | 0 .../depthwise_conv_mkldnn_pass_tester.cc | 2 +- .../ir/{ => mkldnn}/mkldnn_placement_pass.cc | 2 +- .../ir/{ => mkldnn}/mkldnn_placement_pass.h | 0 paddle/fluid/operators/activation_op.cc | 2 +- paddle/fluid/operators/mkldnn/CMakeLists.txt | 2 ++ .../{ => mkldnn}/activation_mkldnn_op.cc | 0 .../{ => mkldnn}/batch_norm_mkldnn_op.cc | 0 .../{ => mkldnn}/concat_mkldnn_op.cc | 0 .../operators/{ => mkldnn}/conv_mkldnn_op.cc | 0 .../{ => mkldnn}/conv_transpose_mkldnn_op.cc | 0 .../{ => mkldnn}/dequantize_mkldnn_op.cc | 0 .../elementwise/elementwise_add_mkldnn_op.cc | 0 .../elementwise/elementwise_mul_mkldnn_op.cc | 0 .../operators/{ => mkldnn}/fc_mkldnn_op.cc | 0 .../{ => mkldnn}/gaussian_random_mkldnn_op.cc | 0 .../operators/{ => mkldnn}/lrn_mkldnn_op.cc | 0 .../{ => mkldnn}/mkldnn_activation_op.h | 0 .../operators/{ => mkldnn}/pool_mkldnn_op.cc | 0 .../{ => mkldnn}/quantize_mkldnn_op.cc | 0 .../{ => mkldnn}/softmax_mkldnn_op.cc | 0 .../operators/{ => mkldnn}/sum_mkldnn_op.cc | 0 .../{ => mkldnn}/transpose_mkldnn_op.cc | 0 .../fluid/tests/unittests/CMakeLists.txt | 13 +++----- .../tests/unittests/mkldnn/CMakeLists.txt | 6 ++++ .../fluid/tests/unittests/mkldnn/__init__.py | 13 ++++++++ .../{ => mkldnn}/test_activation_mkldnn_op.py | 4 +-- .../{ => mkldnn}/test_batch_norm_mkldnn_op.py | 4 +-- .../{ => mkldnn}/test_concat_mkldnn_op.py | 2 +- .../test_conv2d_int8_mkldnn_op.py | 4 +-- .../{ => mkldnn}/test_conv2d_mkldnn_op.py | 2 +- .../test_conv2d_transpose_mkldnn_op.py | 2 +- .../{ => mkldnn}/test_conv3d_mkldnn_op.py | 2 +- .../{ => mkldnn}/test_dequantize_mkldnn_op.py | 2 +- .../test_elementwise_add_mkldnn_op.py | 4 +-- .../test_elementwise_mul_mkldnn_op.py | 4 +-- .../{ => mkldnn}/test_fc_mkldnn_op.py | 2 +- .../test_gaussian_random_mkldnn_op.py | 2 +- .../{ => mkldnn}/test_lrn_mkldnn_op.py | 2 +- .../test_pool2d_int8_mkldnn_op.py | 4 +-- .../{ => mkldnn}/test_pool2d_mkldnn_op.py | 2 +- .../{ => mkldnn}/test_quantize_mkldnn_op.py | 2 +- .../{ => mkldnn}/test_sum_mkldnn_op.py | 2 +- .../{ => mkldnn}/test_transpose_mkldnn_op.py | 2 +- 55 files changed, 83 insertions(+), 53 deletions(-) rename paddle/fluid/framework/ir/{ => mkldnn}/conv_bias_mkldnn_fuse_pass.cc (98%) rename paddle/fluid/framework/ir/{ => mkldnn}/conv_bias_mkldnn_fuse_pass.h (100%) rename paddle/fluid/framework/ir/{ => mkldnn}/conv_elementwise_add_mkldnn_fuse_pass.cc (99%) rename paddle/fluid/framework/ir/{ => mkldnn}/conv_elementwise_add_mkldnn_fuse_pass.h (100%) rename paddle/fluid/framework/ir/{ => mkldnn}/conv_elementwise_add_mkldnn_fuse_pass_tester.cc (98%) rename paddle/fluid/framework/ir/{ => mkldnn}/conv_relu_mkldnn_fuse_pass.cc (97%) rename paddle/fluid/framework/ir/{ => mkldnn}/conv_relu_mkldnn_fuse_pass.h (100%) rename paddle/fluid/framework/ir/{ => mkldnn}/conv_relu_mkldnn_fuse_pass_tester.cc (98%) rename paddle/fluid/framework/ir/{ => mkldnn}/depthwise_conv_mkldnn_pass.cc (96%) rename paddle/fluid/framework/ir/{ => mkldnn}/depthwise_conv_mkldnn_pass.h (100%) rename paddle/fluid/framework/ir/{ => mkldnn}/depthwise_conv_mkldnn_pass_tester.cc (98%) rename paddle/fluid/framework/ir/{ => mkldnn}/mkldnn_placement_pass.cc (95%) rename paddle/fluid/framework/ir/{ => mkldnn}/mkldnn_placement_pass.h (100%) create mode 100644 paddle/fluid/operators/mkldnn/CMakeLists.txt rename paddle/fluid/operators/{ => mkldnn}/activation_mkldnn_op.cc (100%) rename paddle/fluid/operators/{ => mkldnn}/batch_norm_mkldnn_op.cc (100%) rename paddle/fluid/operators/{ => mkldnn}/concat_mkldnn_op.cc (100%) rename paddle/fluid/operators/{ => mkldnn}/conv_mkldnn_op.cc (100%) rename paddle/fluid/operators/{ => mkldnn}/conv_transpose_mkldnn_op.cc (100%) rename paddle/fluid/operators/{ => mkldnn}/dequantize_mkldnn_op.cc (100%) rename paddle/fluid/operators/{ => mkldnn}/elementwise/elementwise_add_mkldnn_op.cc (100%) rename paddle/fluid/operators/{ => mkldnn}/elementwise/elementwise_mul_mkldnn_op.cc (100%) rename paddle/fluid/operators/{ => mkldnn}/fc_mkldnn_op.cc (100%) rename paddle/fluid/operators/{ => mkldnn}/gaussian_random_mkldnn_op.cc (100%) rename paddle/fluid/operators/{ => mkldnn}/lrn_mkldnn_op.cc (100%) rename paddle/fluid/operators/{ => mkldnn}/mkldnn_activation_op.h (100%) rename paddle/fluid/operators/{ => mkldnn}/pool_mkldnn_op.cc (100%) rename paddle/fluid/operators/{ => mkldnn}/quantize_mkldnn_op.cc (100%) rename paddle/fluid/operators/{ => mkldnn}/softmax_mkldnn_op.cc (100%) rename paddle/fluid/operators/{ => mkldnn}/sum_mkldnn_op.cc (100%) rename paddle/fluid/operators/{ => mkldnn}/transpose_mkldnn_op.cc (100%) create mode 100644 python/paddle/fluid/tests/unittests/mkldnn/CMakeLists.txt create mode 100644 python/paddle/fluid/tests/unittests/mkldnn/__init__.py rename python/paddle/fluid/tests/unittests/{ => mkldnn}/test_activation_mkldnn_op.py (94%) rename python/paddle/fluid/tests/unittests/{ => mkldnn}/test_batch_norm_mkldnn_op.py (92%) rename python/paddle/fluid/tests/unittests/{ => mkldnn}/test_concat_mkldnn_op.py (94%) rename python/paddle/fluid/tests/unittests/{ => mkldnn}/test_conv2d_int8_mkldnn_op.py (98%) rename python/paddle/fluid/tests/unittests/{ => mkldnn}/test_conv2d_mkldnn_op.py (91%) rename python/paddle/fluid/tests/unittests/{ => mkldnn}/test_conv2d_transpose_mkldnn_op.py (94%) rename python/paddle/fluid/tests/unittests/{ => mkldnn}/test_conv3d_mkldnn_op.py (91%) rename python/paddle/fluid/tests/unittests/{ => mkldnn}/test_dequantize_mkldnn_op.py (97%) rename python/paddle/fluid/tests/unittests/{ => mkldnn}/test_elementwise_add_mkldnn_op.py (97%) rename python/paddle/fluid/tests/unittests/{ => mkldnn}/test_elementwise_mul_mkldnn_op.py (98%) rename python/paddle/fluid/tests/unittests/{ => mkldnn}/test_fc_mkldnn_op.py (98%) rename python/paddle/fluid/tests/unittests/{ => mkldnn}/test_gaussian_random_mkldnn_op.py (90%) rename python/paddle/fluid/tests/unittests/{ => mkldnn}/test_lrn_mkldnn_op.py (96%) rename python/paddle/fluid/tests/unittests/{ => mkldnn}/test_pool2d_int8_mkldnn_op.py (94%) rename python/paddle/fluid/tests/unittests/{ => mkldnn}/test_pool2d_mkldnn_op.py (90%) rename python/paddle/fluid/tests/unittests/{ => mkldnn}/test_quantize_mkldnn_op.py (97%) rename python/paddle/fluid/tests/unittests/{ => mkldnn}/test_sum_mkldnn_op.py (92%) rename python/paddle/fluid/tests/unittests/{ => mkldnn}/test_transpose_mkldnn_op.py (95%) diff --git a/cmake/operators.cmake b/cmake/operators.cmake index 59c40a0e5d..c2d0482856 100644 --- a/cmake/operators.cmake +++ b/cmake/operators.cmake @@ -52,8 +52,8 @@ function(op_library TARGET) endif() if(WITH_MKLDNN) string(REPLACE "_op" "_mkldnn_op" MKLDNN_FILE "${TARGET}") - if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${MKLDNN_FILE}.cc) - list(APPEND mkldnn_cc_srcs ${MKLDNN_FILE}.cc) + if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/mkldnn/${MKLDNN_FILE}.cc) + list(APPEND mkldnn_cc_srcs mkldnn/${MKLDNN_FILE}.cc) endif() endif() else() diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index b118dccd1b..914bcce775 100644 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -10,8 +10,22 @@ function(pass_library TARGET DEST) set(options "") set(oneValueArgs "") set(multiValueArgs SRCS DEPS) + set(targetPrefix "") + + # Get optional argument + set(extraMacroArgs ${ARGN}) + list(LENGTH extraMacroArgs numExtraMacroArgs) + if(numExtraMacroArgs GREATER 0) + list(GET extraMacroArgs 0 targetPrefix) + endif() + cmake_parse_arguments(op_library "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) - cc_library(${TARGET} SRCS ${TARGET}.cc DEPS graph_pattern_detector pass fuse_pass_base ${op_library_DEPS}) + if(targetPrefix) + cc_library(${TARGET} SRCS ${targetPrefix}/${TARGET}.cc DEPS graph_pattern_detector pass fuse_pass_base ${op_library_DEPS}) + else() + cc_library(${TARGET} SRCS ${TARGET}.cc DEPS graph_pattern_detector pass fuse_pass_base ${op_library_DEPS}) + endif() + # add more DEST here, such as train, dist and collect USE_PASS into a file automatically. if (${DEST} STREQUAL "base" OR ${DEST} STREQUAL "inference") message(STATUS "add pass ${TARGET} ${DEST}") @@ -62,11 +76,11 @@ foreach (index RANGE 3 6) endforeach() if(WITH_MKLDNN) - pass_library(mkldnn_placement_pass base) - pass_library(depthwise_conv_mkldnn_pass base) - pass_library(conv_bias_mkldnn_fuse_pass inference) - pass_library(conv_relu_mkldnn_fuse_pass inference) - pass_library(conv_elementwise_add_mkldnn_fuse_pass inference) + pass_library(mkldnn_placement_pass base mkldnn) + pass_library(depthwise_conv_mkldnn_pass base mkldnn) + pass_library(conv_bias_mkldnn_fuse_pass inference mkldnn) + pass_library(conv_relu_mkldnn_fuse_pass inference mkldnn) + pass_library(conv_elementwise_add_mkldnn_fuse_pass inference mkldnn) endif() cc_library(fuse_elewise_add_act_pass SRCS fuse_elewise_add_act_pass.cc DEPS pass graph_pattern_detector ) @@ -86,7 +100,7 @@ cc_test(test_fc_fuse_pass SRCS fc_fuse_pass_tester.cc DEPS fc_fuse_pass framewor cc_test(test_seqpool_concat_fuse_pass SRCS seqpool_concat_fuse_pass_tester.cc DEPS seqpool_concat_fuse_pass framework_proto) cc_test(test_is_test_pass SRCS is_test_pass_tester.cc DEPS is_test_pass) if (WITH_MKLDNN) - cc_test(test_depthwise_conv_mkldnn_pass SRCS depthwise_conv_mkldnn_pass_tester.cc DEPS depthwise_conv_mkldnn_pass) - cc_test(test_conv_relu_mkldnn_fuse_pass SRCS conv_relu_mkldnn_fuse_pass_tester.cc DEPS conv_relu_mkldnn_fuse_pass) - cc_test(test_conv_elementwise_add_mkldnn_fuse_pass SRCS conv_elementwise_add_mkldnn_fuse_pass_tester.cc DEPS conv_elementwise_add_mkldnn_fuse_pass) + cc_test(test_depthwise_conv_mkldnn_pass SRCS mkldnn/depthwise_conv_mkldnn_pass_tester.cc DEPS depthwise_conv_mkldnn_pass) + cc_test(test_conv_relu_mkldnn_fuse_pass SRCS mkldnn/conv_relu_mkldnn_fuse_pass_tester.cc DEPS conv_relu_mkldnn_fuse_pass) + cc_test(test_conv_elementwise_add_mkldnn_fuse_pass SRCS mkldnn/conv_elementwise_add_mkldnn_fuse_pass_tester.cc DEPS conv_elementwise_add_mkldnn_fuse_pass) endif () diff --git a/paddle/fluid/framework/ir/conv_bias_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.cc similarity index 98% rename from paddle/fluid/framework/ir/conv_bias_mkldnn_fuse_pass.cc rename to paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.cc index d4a701e0b1..5d0b294f6f 100644 --- a/paddle/fluid/framework/ir/conv_bias_mkldnn_fuse_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/framework/ir/conv_bias_mkldnn_fuse_pass.h" +#include "paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.h" #include #include #include diff --git a/paddle/fluid/framework/ir/conv_bias_mkldnn_fuse_pass.h b/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.h similarity index 100% rename from paddle/fluid/framework/ir/conv_bias_mkldnn_fuse_pass.h rename to paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.h diff --git a/paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.cc similarity index 99% rename from paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.cc rename to paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.cc index a8029e67e6..fb3db81347 100644 --- a/paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.h" +#include "paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.h" #include #include #include diff --git a/paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.h b/paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.h similarity index 100% rename from paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.h rename to paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.h diff --git a/paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass_tester.cc similarity index 98% rename from paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass_tester.cc rename to paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass_tester.cc index 61ba097fd8..9ef5c298b8 100644 --- a/paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass_tester.cc +++ b/paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass_tester.cc @@ -15,8 +15,8 @@ #include #include -#include "paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.h" #include "paddle/fluid/framework/ir/graph_traits.h" +#include "paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.h" namespace paddle { namespace framework { diff --git a/paddle/fluid/framework/ir/conv_relu_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/conv_relu_mkldnn_fuse_pass.cc similarity index 97% rename from paddle/fluid/framework/ir/conv_relu_mkldnn_fuse_pass.cc rename to paddle/fluid/framework/ir/mkldnn/conv_relu_mkldnn_fuse_pass.cc index e359a3832e..4f4605398a 100644 --- a/paddle/fluid/framework/ir/conv_relu_mkldnn_fuse_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/conv_relu_mkldnn_fuse_pass.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/framework/ir/conv_relu_mkldnn_fuse_pass.h" +#include "paddle/fluid/framework/ir/mkldnn/conv_relu_mkldnn_fuse_pass.h" #include #include #include "paddle/fluid/platform/enforce.h" diff --git a/paddle/fluid/framework/ir/conv_relu_mkldnn_fuse_pass.h b/paddle/fluid/framework/ir/mkldnn/conv_relu_mkldnn_fuse_pass.h similarity index 100% rename from paddle/fluid/framework/ir/conv_relu_mkldnn_fuse_pass.h rename to paddle/fluid/framework/ir/mkldnn/conv_relu_mkldnn_fuse_pass.h diff --git a/paddle/fluid/framework/ir/conv_relu_mkldnn_fuse_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/conv_relu_mkldnn_fuse_pass_tester.cc similarity index 98% rename from paddle/fluid/framework/ir/conv_relu_mkldnn_fuse_pass_tester.cc rename to paddle/fluid/framework/ir/mkldnn/conv_relu_mkldnn_fuse_pass_tester.cc index 19248b4dfe..06d56f6222 100644 --- a/paddle/fluid/framework/ir/conv_relu_mkldnn_fuse_pass_tester.cc +++ b/paddle/fluid/framework/ir/mkldnn/conv_relu_mkldnn_fuse_pass_tester.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/framework/ir/conv_relu_mkldnn_fuse_pass.h" +#include "paddle/fluid/framework/ir/mkldnn/conv_relu_mkldnn_fuse_pass.h" #include #include "paddle/fluid/framework/op_proto_maker.h" diff --git a/paddle/fluid/framework/ir/depthwise_conv_mkldnn_pass.cc b/paddle/fluid/framework/ir/mkldnn/depthwise_conv_mkldnn_pass.cc similarity index 96% rename from paddle/fluid/framework/ir/depthwise_conv_mkldnn_pass.cc rename to paddle/fluid/framework/ir/mkldnn/depthwise_conv_mkldnn_pass.cc index 19056e18aa..7851e8c84b 100644 --- a/paddle/fluid/framework/ir/depthwise_conv_mkldnn_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/depthwise_conv_mkldnn_pass.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/framework/ir/depthwise_conv_mkldnn_pass.h" +#include "paddle/fluid/framework/ir/mkldnn/depthwise_conv_mkldnn_pass.h" #include "paddle/fluid/framework/ir/graph_pattern_detector.h" namespace paddle { diff --git a/paddle/fluid/framework/ir/depthwise_conv_mkldnn_pass.h b/paddle/fluid/framework/ir/mkldnn/depthwise_conv_mkldnn_pass.h similarity index 100% rename from paddle/fluid/framework/ir/depthwise_conv_mkldnn_pass.h rename to paddle/fluid/framework/ir/mkldnn/depthwise_conv_mkldnn_pass.h diff --git a/paddle/fluid/framework/ir/depthwise_conv_mkldnn_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/depthwise_conv_mkldnn_pass_tester.cc similarity index 98% rename from paddle/fluid/framework/ir/depthwise_conv_mkldnn_pass_tester.cc rename to paddle/fluid/framework/ir/mkldnn/depthwise_conv_mkldnn_pass_tester.cc index 09d0b15f46..1783e3322b 100644 --- a/paddle/fluid/framework/ir/depthwise_conv_mkldnn_pass_tester.cc +++ b/paddle/fluid/framework/ir/mkldnn/depthwise_conv_mkldnn_pass_tester.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/framework/ir/depthwise_conv_mkldnn_pass.h" +#include "paddle/fluid/framework/ir/mkldnn/depthwise_conv_mkldnn_pass.h" #include diff --git a/paddle/fluid/framework/ir/mkldnn_placement_pass.cc b/paddle/fluid/framework/ir/mkldnn/mkldnn_placement_pass.cc similarity index 95% rename from paddle/fluid/framework/ir/mkldnn_placement_pass.cc rename to paddle/fluid/framework/ir/mkldnn/mkldnn_placement_pass.cc index 951fcb066c..20e52410ff 100644 --- a/paddle/fluid/framework/ir/mkldnn_placement_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/mkldnn_placement_pass.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/framework/ir/mkldnn_placement_pass.h" +#include "paddle/fluid/framework/ir/mkldnn/mkldnn_placement_pass.h" #include namespace paddle { diff --git a/paddle/fluid/framework/ir/mkldnn_placement_pass.h b/paddle/fluid/framework/ir/mkldnn/mkldnn_placement_pass.h similarity index 100% rename from paddle/fluid/framework/ir/mkldnn_placement_pass.h rename to paddle/fluid/framework/ir/mkldnn/mkldnn_placement_pass.h diff --git a/paddle/fluid/operators/activation_op.cc b/paddle/fluid/operators/activation_op.cc index 9c5b8604f4..7ec9d2fed5 100644 --- a/paddle/fluid/operators/activation_op.cc +++ b/paddle/fluid/operators/activation_op.cc @@ -14,7 +14,7 @@ limitations under the License. */ #include "paddle/fluid/operators/activation_op.h" #include -#include "paddle/fluid/operators/mkldnn_activation_op.h" +#include "paddle/fluid/operators/mkldnn/mkldnn_activation_op.h" #include "paddle/fluid/platform/port.h" namespace paddle { diff --git a/paddle/fluid/operators/mkldnn/CMakeLists.txt b/paddle/fluid/operators/mkldnn/CMakeLists.txt new file mode 100644 index 0000000000..5d468316e8 --- /dev/null +++ b/paddle/fluid/operators/mkldnn/CMakeLists.txt @@ -0,0 +1,2 @@ +include(operators) +register_operators() diff --git a/paddle/fluid/operators/activation_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/activation_mkldnn_op.cc similarity index 100% rename from paddle/fluid/operators/activation_mkldnn_op.cc rename to paddle/fluid/operators/mkldnn/activation_mkldnn_op.cc diff --git a/paddle/fluid/operators/batch_norm_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/batch_norm_mkldnn_op.cc similarity index 100% rename from paddle/fluid/operators/batch_norm_mkldnn_op.cc rename to paddle/fluid/operators/mkldnn/batch_norm_mkldnn_op.cc diff --git a/paddle/fluid/operators/concat_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/concat_mkldnn_op.cc similarity index 100% rename from paddle/fluid/operators/concat_mkldnn_op.cc rename to paddle/fluid/operators/mkldnn/concat_mkldnn_op.cc diff --git a/paddle/fluid/operators/conv_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc similarity index 100% rename from paddle/fluid/operators/conv_mkldnn_op.cc rename to paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc diff --git a/paddle/fluid/operators/conv_transpose_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/conv_transpose_mkldnn_op.cc similarity index 100% rename from paddle/fluid/operators/conv_transpose_mkldnn_op.cc rename to paddle/fluid/operators/mkldnn/conv_transpose_mkldnn_op.cc diff --git a/paddle/fluid/operators/dequantize_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/dequantize_mkldnn_op.cc similarity index 100% rename from paddle/fluid/operators/dequantize_mkldnn_op.cc rename to paddle/fluid/operators/mkldnn/dequantize_mkldnn_op.cc diff --git a/paddle/fluid/operators/elementwise/elementwise_add_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/elementwise/elementwise_add_mkldnn_op.cc similarity index 100% rename from paddle/fluid/operators/elementwise/elementwise_add_mkldnn_op.cc rename to paddle/fluid/operators/mkldnn/elementwise/elementwise_add_mkldnn_op.cc diff --git a/paddle/fluid/operators/elementwise/elementwise_mul_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/elementwise/elementwise_mul_mkldnn_op.cc similarity index 100% rename from paddle/fluid/operators/elementwise/elementwise_mul_mkldnn_op.cc rename to paddle/fluid/operators/mkldnn/elementwise/elementwise_mul_mkldnn_op.cc diff --git a/paddle/fluid/operators/fc_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/fc_mkldnn_op.cc similarity index 100% rename from paddle/fluid/operators/fc_mkldnn_op.cc rename to paddle/fluid/operators/mkldnn/fc_mkldnn_op.cc diff --git a/paddle/fluid/operators/gaussian_random_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/gaussian_random_mkldnn_op.cc similarity index 100% rename from paddle/fluid/operators/gaussian_random_mkldnn_op.cc rename to paddle/fluid/operators/mkldnn/gaussian_random_mkldnn_op.cc diff --git a/paddle/fluid/operators/lrn_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/lrn_mkldnn_op.cc similarity index 100% rename from paddle/fluid/operators/lrn_mkldnn_op.cc rename to paddle/fluid/operators/mkldnn/lrn_mkldnn_op.cc diff --git a/paddle/fluid/operators/mkldnn_activation_op.h b/paddle/fluid/operators/mkldnn/mkldnn_activation_op.h similarity index 100% rename from paddle/fluid/operators/mkldnn_activation_op.h rename to paddle/fluid/operators/mkldnn/mkldnn_activation_op.h diff --git a/paddle/fluid/operators/pool_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/pool_mkldnn_op.cc similarity index 100% rename from paddle/fluid/operators/pool_mkldnn_op.cc rename to paddle/fluid/operators/mkldnn/pool_mkldnn_op.cc diff --git a/paddle/fluid/operators/quantize_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/quantize_mkldnn_op.cc similarity index 100% rename from paddle/fluid/operators/quantize_mkldnn_op.cc rename to paddle/fluid/operators/mkldnn/quantize_mkldnn_op.cc diff --git a/paddle/fluid/operators/softmax_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/softmax_mkldnn_op.cc similarity index 100% rename from paddle/fluid/operators/softmax_mkldnn_op.cc rename to paddle/fluid/operators/mkldnn/softmax_mkldnn_op.cc diff --git a/paddle/fluid/operators/sum_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/sum_mkldnn_op.cc similarity index 100% rename from paddle/fluid/operators/sum_mkldnn_op.cc rename to paddle/fluid/operators/mkldnn/sum_mkldnn_op.cc diff --git a/paddle/fluid/operators/transpose_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/transpose_mkldnn_op.cc similarity index 100% rename from paddle/fluid/operators/transpose_mkldnn_op.cc rename to paddle/fluid/operators/mkldnn/transpose_mkldnn_op.cc diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index 7e693c6a41..699181d01d 100644 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -1,15 +1,6 @@ file(GLOB TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py") string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}") -# The MKLDNN tests are skiped when the MKLDNN flag is OFF -if(NOT WITH_MKLDNN) - foreach(src ${TEST_OPS}) - if(${src} MATCHES ".*_mkldnn_op$") - list(REMOVE_ITEM TEST_OPS ${src}) - endif() - endforeach() -endif(NOT WITH_MKLDNN) - if(NOT WITH_DISTRIBUTE) list(REMOVE_ITEM TEST_OPS test_recv_op) list(REMOVE_ITEM TEST_OPS test_dist_transpiler) @@ -123,3 +114,7 @@ endif() if (WITH_NGRAPH) add_subdirectory(ngraph) endif() + +if (WITH_MKLDNN) + add_subdirectory(mkldnn) +endif() diff --git a/python/paddle/fluid/tests/unittests/mkldnn/CMakeLists.txt b/python/paddle/fluid/tests/unittests/mkldnn/CMakeLists.txt new file mode 100644 index 0000000000..f71e04c09a --- /dev/null +++ b/python/paddle/fluid/tests/unittests/mkldnn/CMakeLists.txt @@ -0,0 +1,6 @@ +file(GLOB TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py") +string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}") + +foreach(TEST_OP ${TEST_OPS}) + py_test_modules(${TEST_OP} MODULES ${TEST_OP}) +endforeach(TEST_OP) diff --git a/python/paddle/fluid/tests/unittests/mkldnn/__init__.py b/python/paddle/fluid/tests/unittests/mkldnn/__init__.py new file mode 100644 index 0000000000..b94a21a7e4 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/mkldnn/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/python/paddle/fluid/tests/unittests/test_activation_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_activation_mkldnn_op.py similarity index 94% rename from python/paddle/fluid/tests/unittests/test_activation_mkldnn_op.py rename to python/paddle/fluid/tests/unittests/mkldnn/test_activation_mkldnn_op.py index 611d0dd076..ad94a4b21c 100644 --- a/python/paddle/fluid/tests/unittests/test_activation_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_activation_mkldnn_op.py @@ -17,9 +17,9 @@ from __future__ import print_function import unittest import numpy as np import paddle.fluid.core as core -from op_test import OpTest +from paddle.fluid.tests.unittests.op_test import OpTest from scipy.special import expit -from test_activation_op import TestRelu, TestTanh, TestSqrt, TestAbs +from paddle.fluid.tests.unittests.test_activation_op import TestRelu, TestTanh, TestSqrt, TestAbs class TestMKLDNNReluDim2(TestRelu): diff --git a/python/paddle/fluid/tests/unittests/test_batch_norm_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_batch_norm_mkldnn_op.py similarity index 92% rename from python/paddle/fluid/tests/unittests/test_batch_norm_mkldnn_op.py rename to python/paddle/fluid/tests/unittests/mkldnn/test_batch_norm_mkldnn_op.py index 1286cee8dc..5fce90372d 100644 --- a/python/paddle/fluid/tests/unittests/test_batch_norm_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_batch_norm_mkldnn_op.py @@ -19,9 +19,9 @@ import numpy as np import paddle.fluid.core as core from paddle.fluid.op import Operator import paddle.fluid as fluid -from op_test import OpTest +from paddle.fluid.tests.unittests.op_test import OpTest from paddle.fluid.framework import grad_var_name -from test_batch_norm_op import TestBatchNormOpInference, TestBatchNormOpTraining, _reference_training, _reference_grad +from paddle.fluid.tests.unittests.test_batch_norm_op import TestBatchNormOpInference, TestBatchNormOpTraining, _reference_training, _reference_grad class TestMKLDNNBatchNormOpTraining(TestBatchNormOpTraining): diff --git a/python/paddle/fluid/tests/unittests/test_concat_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_concat_mkldnn_op.py similarity index 94% rename from python/paddle/fluid/tests/unittests/test_concat_mkldnn_op.py rename to python/paddle/fluid/tests/unittests/mkldnn/test_concat_mkldnn_op.py index 0f2130f904..1a39974069 100644 --- a/python/paddle/fluid/tests/unittests/test_concat_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_concat_mkldnn_op.py @@ -15,7 +15,7 @@ from __future__ import print_function import unittest -from test_concat_op import TestConcatOp, TestConcatOp2, TestConcatOp3 +from paddle.fluid.tests.unittests.test_concat_op import TestConcatOp, TestConcatOp2, TestConcatOp3 class TestMKLDNNConcatOp(TestConcatOp): diff --git a/python/paddle/fluid/tests/unittests/test_conv2d_int8_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_conv2d_int8_mkldnn_op.py similarity index 98% rename from python/paddle/fluid/tests/unittests/test_conv2d_int8_mkldnn_op.py rename to python/paddle/fluid/tests/unittests/mkldnn/test_conv2d_int8_mkldnn_op.py index 5ad376cb08..100a03cea0 100644 --- a/python/paddle/fluid/tests/unittests/test_conv2d_int8_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_conv2d_int8_mkldnn_op.py @@ -18,8 +18,8 @@ import unittest import numpy as np import paddle.fluid.core as core -from op_test import OpTest -from test_conv2d_op import conv2d_forward_naive, TestConv2dOp +from paddle.fluid.tests.unittests.op_test import OpTest +from paddle.fluid.tests.unittests.test_conv2d_op import conv2d_forward_naive, TestConv2dOp def conv2d_forward_refer(input, filter, group, conv_param): diff --git a/python/paddle/fluid/tests/unittests/test_conv2d_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_conv2d_mkldnn_op.py similarity index 91% rename from python/paddle/fluid/tests/unittests/test_conv2d_mkldnn_op.py rename to python/paddle/fluid/tests/unittests/mkldnn/test_conv2d_mkldnn_op.py index 438d45b840..0542eef800 100644 --- a/python/paddle/fluid/tests/unittests/test_conv2d_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_conv2d_mkldnn_op.py @@ -16,7 +16,7 @@ from __future__ import print_function import unittest -from test_conv2d_op import TestConv2dOp, TestWithPad, TestWithStride, TestWithGroup, TestWith1x1, TestWithInput1x1Filter1x1 +from paddle.fluid.tests.unittests.test_conv2d_op import TestConv2dOp, TestWithPad, TestWithStride, TestWithGroup, TestWith1x1, TestWithInput1x1Filter1x1 class TestMKLDNN(TestConv2dOp): diff --git a/python/paddle/fluid/tests/unittests/test_conv2d_transpose_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_conv2d_transpose_mkldnn_op.py similarity index 94% rename from python/paddle/fluid/tests/unittests/test_conv2d_transpose_mkldnn_op.py rename to python/paddle/fluid/tests/unittests/mkldnn/test_conv2d_transpose_mkldnn_op.py index deefdd09ab..9bcdb7b2a9 100644 --- a/python/paddle/fluid/tests/unittests/test_conv2d_transpose_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_conv2d_transpose_mkldnn_op.py @@ -16,7 +16,7 @@ from __future__ import print_function import unittest -from test_conv2d_transpose_op import TestConv2dTransposeOp, TestWithPad, TestWithStride +from paddle.fluid.tests.unittests.test_conv2d_transpose_op import TestConv2dTransposeOp, TestWithPad, TestWithStride class TestMKLDNN(TestConv2dTransposeOp): diff --git a/python/paddle/fluid/tests/unittests/test_conv3d_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_conv3d_mkldnn_op.py similarity index 91% rename from python/paddle/fluid/tests/unittests/test_conv3d_mkldnn_op.py rename to python/paddle/fluid/tests/unittests/mkldnn/test_conv3d_mkldnn_op.py index f0e1265e14..080b74502f 100644 --- a/python/paddle/fluid/tests/unittests/test_conv3d_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_conv3d_mkldnn_op.py @@ -16,7 +16,7 @@ from __future__ import print_function import unittest -from test_conv3d_op import TestConv3dOp, TestCase1, TestWithGroup1, TestWithGroup2, TestWith1x1, TestWithInput1x1Filter1x1 +from paddle.fluid.tests.unittests.test_conv3d_op import TestConv3dOp, TestCase1, TestWithGroup1, TestWithGroup2, TestWith1x1, TestWithInput1x1Filter1x1 class TestMKLDNN(TestConv3dOp): diff --git a/python/paddle/fluid/tests/unittests/test_dequantize_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_dequantize_mkldnn_op.py similarity index 97% rename from python/paddle/fluid/tests/unittests/test_dequantize_mkldnn_op.py rename to python/paddle/fluid/tests/unittests/mkldnn/test_dequantize_mkldnn_op.py index 0c5e1abd7c..9a54f927cb 100644 --- a/python/paddle/fluid/tests/unittests/test_dequantize_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_dequantize_mkldnn_op.py @@ -16,7 +16,7 @@ from __future__ import print_function import unittest import numpy as np -from op_test import OpTest +from paddle.fluid.tests.unittests.op_test import OpTest class TestDeQuantizeOp(OpTest): diff --git a/python/paddle/fluid/tests/unittests/test_elementwise_add_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_elementwise_add_mkldnn_op.py similarity index 97% rename from python/paddle/fluid/tests/unittests/test_elementwise_add_mkldnn_op.py rename to python/paddle/fluid/tests/unittests/mkldnn/test_elementwise_add_mkldnn_op.py index d85cc1f856..c3a42656b7 100644 --- a/python/paddle/fluid/tests/unittests/test_elementwise_add_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_elementwise_add_mkldnn_op.py @@ -16,8 +16,8 @@ from __future__ import print_function import unittest import numpy as np import paddle.fluid.core as core -from op_test import OpTest -from test_elementwise_add_op import * +from paddle.fluid.tests.unittests.op_test import OpTest +from paddle.fluid.tests.unittests.test_elementwise_add_op import * ''' Some tests differ from the tests defined in test_elementwise_add_op.py because MKLDNN does not support tensors of number of dimensions 3. diff --git a/python/paddle/fluid/tests/unittests/test_elementwise_mul_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_elementwise_mul_mkldnn_op.py similarity index 98% rename from python/paddle/fluid/tests/unittests/test_elementwise_mul_mkldnn_op.py rename to python/paddle/fluid/tests/unittests/mkldnn/test_elementwise_mul_mkldnn_op.py index 536e9a1c58..738715dd70 100644 --- a/python/paddle/fluid/tests/unittests/test_elementwise_mul_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_elementwise_mul_mkldnn_op.py @@ -15,10 +15,10 @@ from __future__ import print_function import unittest import numpy as np -from op_test import OpTest +from paddle.fluid.tests.unittests.op_test import OpTest import paddle.fluid.core as core from paddle.fluid.op import Operator -from test_elementwise_mul_op import * +from paddle.fluid.tests.unittests.test_elementwise_mul_op import * class TestElementwiseMulMKLDNNOp_BroadcastNCHW16c(ElementwiseMulOp): diff --git a/python/paddle/fluid/tests/unittests/test_fc_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_fc_mkldnn_op.py similarity index 98% rename from python/paddle/fluid/tests/unittests/test_fc_mkldnn_op.py rename to python/paddle/fluid/tests/unittests/mkldnn/test_fc_mkldnn_op.py index 45951a34d6..84229a5cff 100644 --- a/python/paddle/fluid/tests/unittests/test_fc_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_fc_mkldnn_op.py @@ -16,7 +16,7 @@ from __future__ import print_function import unittest import numpy as np -from op_test import OpTest +from paddle.fluid.tests.unittests.op_test import OpTest def fully_connected_naive(input, weights, bias_data=None): diff --git a/python/paddle/fluid/tests/unittests/test_gaussian_random_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_gaussian_random_mkldnn_op.py similarity index 90% rename from python/paddle/fluid/tests/unittests/test_gaussian_random_mkldnn_op.py rename to python/paddle/fluid/tests/unittests/mkldnn/test_gaussian_random_mkldnn_op.py index 9777ec3906..c18bd77bd3 100644 --- a/python/paddle/fluid/tests/unittests/test_gaussian_random_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_gaussian_random_mkldnn_op.py @@ -16,7 +16,7 @@ from __future__ import print_function import unittest -from test_gaussian_random_op import TestGaussianRandomOp +from paddle.fluid.tests.unittests.test_gaussian_random_op import TestGaussianRandomOp class TestMKLDNN(TestGaussianRandomOp): diff --git a/python/paddle/fluid/tests/unittests/test_lrn_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_lrn_mkldnn_op.py similarity index 96% rename from python/paddle/fluid/tests/unittests/test_lrn_mkldnn_op.py rename to python/paddle/fluid/tests/unittests/mkldnn/test_lrn_mkldnn_op.py index f6bb2ab7a6..a5e6e116a5 100644 --- a/python/paddle/fluid/tests/unittests/test_lrn_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_lrn_mkldnn_op.py @@ -15,7 +15,7 @@ from __future__ import print_function import unittest -from test_lrn_op import TestLRNOp +from paddle.fluid.tests.unittests.test_lrn_op import TestLRNOp class TestLRNMKLDNNOp(TestLRNOp): diff --git a/python/paddle/fluid/tests/unittests/test_pool2d_int8_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_pool2d_int8_mkldnn_op.py similarity index 94% rename from python/paddle/fluid/tests/unittests/test_pool2d_int8_mkldnn_op.py rename to python/paddle/fluid/tests/unittests/mkldnn/test_pool2d_int8_mkldnn_op.py index f4495d0bc8..fca906fecc 100644 --- a/python/paddle/fluid/tests/unittests/test_pool2d_int8_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_pool2d_int8_mkldnn_op.py @@ -19,8 +19,8 @@ import unittest import numpy as np import paddle.fluid.core as core -from op_test import OpTest -from test_pool2d_op import TestPool2D_Op, avg_pool2D_forward_naive, max_pool2D_forward_naive +from paddle.fluid.tests.unittests.op_test import OpTest +from paddle.fluid.tests.unittests.test_pool2d_op import TestPool2D_Op, avg_pool2D_forward_naive, max_pool2D_forward_naive class TestPool2dMKLDNNInt8_Op(TestPool2D_Op): diff --git a/python/paddle/fluid/tests/unittests/test_pool2d_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_pool2d_mkldnn_op.py similarity index 90% rename from python/paddle/fluid/tests/unittests/test_pool2d_mkldnn_op.py rename to python/paddle/fluid/tests/unittests/mkldnn/test_pool2d_mkldnn_op.py index 7de5fefc14..6de43dd46e 100644 --- a/python/paddle/fluid/tests/unittests/test_pool2d_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_pool2d_mkldnn_op.py @@ -15,7 +15,7 @@ from __future__ import print_function import unittest -from test_pool2d_op import TestPool2D_Op, TestCase1, TestCase2, TestCase3, TestCase4, TestCase5 +from paddle.fluid.tests.unittests.test_pool2d_op import TestPool2D_Op, TestCase1, TestCase2, TestCase3, TestCase4, TestCase5 def create_test_mkldnn_class(parent): diff --git a/python/paddle/fluid/tests/unittests/test_quantize_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_quantize_mkldnn_op.py similarity index 97% rename from python/paddle/fluid/tests/unittests/test_quantize_mkldnn_op.py rename to python/paddle/fluid/tests/unittests/mkldnn/test_quantize_mkldnn_op.py index 9960792864..132f7bd039 100644 --- a/python/paddle/fluid/tests/unittests/test_quantize_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_quantize_mkldnn_op.py @@ -16,7 +16,7 @@ from __future__ import print_function import unittest import numpy as np -from op_test import OpTest +from paddle.fluid.tests.unittests.op_test import OpTest class TestQuantizeOp(OpTest): diff --git a/python/paddle/fluid/tests/unittests/test_sum_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_sum_mkldnn_op.py similarity index 92% rename from python/paddle/fluid/tests/unittests/test_sum_mkldnn_op.py rename to python/paddle/fluid/tests/unittests/mkldnn/test_sum_mkldnn_op.py index 55820f31b8..5928047b51 100644 --- a/python/paddle/fluid/tests/unittests/test_sum_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_sum_mkldnn_op.py @@ -16,7 +16,7 @@ from __future__ import print_function import unittest -from test_sum_op import TestSumOp +from paddle.fluid.tests.unittests.test_sum_op import TestSumOp class TestMKLDNN(TestSumOp): diff --git a/python/paddle/fluid/tests/unittests/test_transpose_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_transpose_mkldnn_op.py similarity index 95% rename from python/paddle/fluid/tests/unittests/test_transpose_mkldnn_op.py rename to python/paddle/fluid/tests/unittests/mkldnn/test_transpose_mkldnn_op.py index 0c201b9e4f..4845eefe36 100644 --- a/python/paddle/fluid/tests/unittests/test_transpose_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_transpose_mkldnn_op.py @@ -16,7 +16,7 @@ from __future__ import print_function import unittest -from test_transpose_op import TestTransposeOp +from paddle.fluid.tests.unittests.test_transpose_op import TestTransposeOp class TestTransposeMKLDNN(TestTransposeOp): From 69b7c595d6ba43fe7c79b6f8618355979e236427 Mon Sep 17 00:00:00 2001 From: Krzysztof Binias Date: Tue, 29 Jan 2019 09:57:06 +0100 Subject: [PATCH 45/53] Small fix test=develop --- .../mkldnn}/elementwise_add_mkldnn_op.cc | 0 .../mkldnn}/elementwise_mul_mkldnn_op.cc | 0 paddle/fluid/operators/mkldnn/CMakeLists.txt | 2 -- 3 files changed, 2 deletions(-) rename paddle/fluid/operators/{mkldnn/elementwise => elementwise/mkldnn}/elementwise_add_mkldnn_op.cc (100%) rename paddle/fluid/operators/{mkldnn/elementwise => elementwise/mkldnn}/elementwise_mul_mkldnn_op.cc (100%) delete mode 100644 paddle/fluid/operators/mkldnn/CMakeLists.txt diff --git a/paddle/fluid/operators/mkldnn/elementwise/elementwise_add_mkldnn_op.cc b/paddle/fluid/operators/elementwise/mkldnn/elementwise_add_mkldnn_op.cc similarity index 100% rename from paddle/fluid/operators/mkldnn/elementwise/elementwise_add_mkldnn_op.cc rename to paddle/fluid/operators/elementwise/mkldnn/elementwise_add_mkldnn_op.cc diff --git a/paddle/fluid/operators/mkldnn/elementwise/elementwise_mul_mkldnn_op.cc b/paddle/fluid/operators/elementwise/mkldnn/elementwise_mul_mkldnn_op.cc similarity index 100% rename from paddle/fluid/operators/mkldnn/elementwise/elementwise_mul_mkldnn_op.cc rename to paddle/fluid/operators/elementwise/mkldnn/elementwise_mul_mkldnn_op.cc diff --git a/paddle/fluid/operators/mkldnn/CMakeLists.txt b/paddle/fluid/operators/mkldnn/CMakeLists.txt deleted file mode 100644 index 5d468316e8..0000000000 --- a/paddle/fluid/operators/mkldnn/CMakeLists.txt +++ /dev/null @@ -1,2 +0,0 @@ -include(operators) -register_operators() From 8f0c2b07f249bb1a8c479b1a2dcd552401fe63e4 Mon Sep 17 00:00:00 2001 From: luotao1 Date: Tue, 29 Jan 2019 18:32:46 +0800 Subject: [PATCH 46/53] use embedding=128 bert model for test test=develop --- paddle/fluid/inference/tests/api/CMakeLists.txt | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/inference/tests/api/CMakeLists.txt b/paddle/fluid/inference/tests/api/CMakeLists.txt index aa3da397ff..7ecd9e3533 100644 --- a/paddle/fluid/inference/tests/api/CMakeLists.txt +++ b/paddle/fluid/inference/tests/api/CMakeLists.txt @@ -128,9 +128,9 @@ inference_analysis_api_test_with_fake_data(test_analyzer_resnet50 inference_analysis_api_test_with_fake_data(test_analyzer_mobilenet_depthwise_conv "${INFERENCE_DEMO_INSTALL_DIR}/mobilenet_depthwise_conv" analyzer_resnet50_tester.cc "mobilenet_model.tar.gz" SERIAL) -# bert, max_len=20 -set(BERT_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/bert20") -download_model_and_data(${BERT_INSTALL_DIR} "bert_model.tar.gz" "bert_data_len20.txt.tar.gz") +# bert, max_len=20, embedding_dim=128 +set(BERT_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/bert_emb128") +download_model_and_data(${BERT_INSTALL_DIR} "bert_emb128_model.tar.gz" "bert_data_len20.txt.tar.gz") inference_analysis_api_test(test_analyzer_bert ${BERT_INSTALL_DIR} analyzer_bert_tester.cc SERIAL) # anakin From 16d54f7f23cac51988de6937cfdf3d3f66991afa Mon Sep 17 00:00:00 2001 From: Yiqun Liu Date: Wed, 30 Jan 2019 11:24:45 +0800 Subject: [PATCH 47/53] Return parent_idx in beam_search op (#15520) * Refine beam_search_op to output an extra parent_idx tensor. test=develop * Fix the unittest test_beam_search_op. test=develop * Fix the merging mistake. test=develop --- paddle/fluid/API.spec | 2 +- paddle/fluid/operators/beam_search_op.cc | 3 + paddle/fluid/operators/beam_search_op.h | 6 +- paddle/fluid/operators/gather_op.cu | 5 +- paddle/fluid/operators/gather_op.h | 4 +- paddle/fluid/operators/math/beam_search.cc | 8 ++- paddle/fluid/operators/math/beam_search.cu | 68 ++++++++++--------- paddle/fluid/operators/math/beam_search.h | 14 ++-- .../fluid/operators/math/beam_search_test.cc | 3 +- python/paddle/fluid/layers/nn.py | 25 +++++-- .../tests/unittests/test_beam_search_op.py | 5 ++ 11 files changed, 88 insertions(+), 55 deletions(-) diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index 349460ad98..fe8d6dd425 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -122,7 +122,7 @@ paddle.fluid.layers.transpose ArgSpec(args=['x', 'perm', 'name'], varargs=None, paddle.fluid.layers.im2sequence ArgSpec(args=['input', 'filter_size', 'stride', 'padding', 'input_image_size', 'out_stride', 'name'], varargs=None, keywords=None, defaults=(1, 1, 0, None, 1, None)) paddle.fluid.layers.nce ArgSpec(args=['input', 'label', 'num_total_classes', 'sample_weight', 'param_attr', 'bias_attr', 'num_neg_samples', 'name', 'sampler', 'custom_dist', 'seed', 'is_sparse'], varargs=None, keywords=None, defaults=(None, None, None, None, None, 'uniform', None, 0, False)) paddle.fluid.layers.hsigmoid ArgSpec(args=['input', 'label', 'num_classes', 'param_attr', 'bias_attr', 'name', 'path_table', 'path_code', 'is_custom', 'is_sparse'], varargs=None, keywords=None, defaults=(None, None, None, None, None, False, False)) -paddle.fluid.layers.beam_search ArgSpec(args=['pre_ids', 'pre_scores', 'ids', 'scores', 'beam_size', 'end_id', 'level', 'is_accumulated', 'name'], varargs=None, keywords=None, defaults=(0, True, None)) +paddle.fluid.layers.beam_search ArgSpec(args=['pre_ids', 'pre_scores', 'ids', 'scores', 'beam_size', 'end_id', 'level', 'is_accumulated', 'name', 'return_parent_idx'], varargs=None, keywords=None, defaults=(0, True, None, False)) paddle.fluid.layers.row_conv ArgSpec(args=['input', 'future_context_size', 'param_attr', 'act'], varargs=None, keywords=None, defaults=(None, None)) paddle.fluid.layers.multiplex ArgSpec(args=['inputs', 'index'], varargs=None, keywords=None, defaults=None) paddle.fluid.layers.layer_norm ArgSpec(args=['input', 'scale', 'shift', 'begin_norm_axis', 'epsilon', 'param_attr', 'bias_attr', 'act', 'name'], varargs=None, keywords=None, defaults=(True, True, 1, 1e-05, None, None, None, None)) diff --git a/paddle/fluid/operators/beam_search_op.cc b/paddle/fluid/operators/beam_search_op.cc index e78ecc1a12..e93cd8615e 100644 --- a/paddle/fluid/operators/beam_search_op.cc +++ b/paddle/fluid/operators/beam_search_op.cc @@ -51,6 +51,9 @@ class BeamSearchOpMaker : public framework::OpProtoAndCheckerMaker { AddOutput("selected_scores", "A LoDTensor containing the accumulated scores corresponding to " "Output(selected_ids)."); + AddOutput( + "parent_idx", + "A Tensor preserving the selected_ids' parent indice in pre_ids."); // Attributes stored in AttributeMap AddAttr("level", "the level of LoDTensor"); diff --git a/paddle/fluid/operators/beam_search_op.h b/paddle/fluid/operators/beam_search_op.h index 1b939e742d..f808020cc7 100644 --- a/paddle/fluid/operators/beam_search_op.h +++ b/paddle/fluid/operators/beam_search_op.h @@ -41,13 +41,15 @@ class BeamSearchOpKernel : public framework::OpKernel { auto selected_ids = context.Output("selected_ids"); auto selected_scores = context.Output("selected_scores"); + auto* parent_idx = context.Output("parent_idx"); PADDLE_ENFORCE_NOT_NULL(selected_ids); PADDLE_ENFORCE_NOT_NULL(selected_scores); + PADDLE_ENFORCE_NOT_NULL(parent_idx); math::BeamSearchFunctor alg; alg(context.template device_context(), pre_ids, pre_scores, - ids, scores, selected_ids, selected_scores, level, beam_size, end_id, - is_accumulated); + ids, scores, selected_ids, selected_scores, parent_idx, level, + beam_size, end_id, is_accumulated); } }; diff --git a/paddle/fluid/operators/gather_op.cu b/paddle/fluid/operators/gather_op.cu index 9f4aef08cd..427ac61858 100644 --- a/paddle/fluid/operators/gather_op.cu +++ b/paddle/fluid/operators/gather_op.cu @@ -31,7 +31,7 @@ class GatherOpCUDAKernel : public framework::OpKernel { auto *output = ctx.Output("Out"); output->mutable_data(ctx.GetPlace()); - + if (x->numel() == 0) return; GPUGather(ctx.device_context(), *x, *index, output); } }; @@ -45,14 +45,13 @@ class GatherGradOpCUDAKernel : public framework::OpKernel { auto *Index = ctx.Input("Index"); auto *dX = ctx.Output(framework::GradVarName("X")); auto *dO = ctx.Input(framework::GradVarName("Out")); - auto *x = ctx.Input("X"); dX->mutable_data(ctx.GetPlace()); auto dxt = framework::EigenVector::Flatten(*dX); auto &place = *ctx.template device_context() .eigen_device(); dxt.device(place) = dxt.constant(static_cast(0)); - + if (dO->numel() == 0) return; GPUScatterAssign(ctx.device_context(), *dO, *Index, dX); } }; diff --git a/paddle/fluid/operators/gather_op.h b/paddle/fluid/operators/gather_op.h index 2dd726bebb..2e18298cf8 100644 --- a/paddle/fluid/operators/gather_op.h +++ b/paddle/fluid/operators/gather_op.h @@ -35,7 +35,7 @@ class GatherOpKernel : public framework::OpKernel { auto *output = ctx.Output("Out"); output->mutable_data(ctx.GetPlace()); - + if (x->numel() == 0) return; CPUGather(ctx.device_context(), *x, *index, output); } }; @@ -56,7 +56,7 @@ class GatherGradientOpKernel : public framework::OpKernel { auto &place = *ctx.template device_context() .eigen_device(); dxt.device(place) = dxt.constant(static_cast(0)); - + if (dO->numel() == 0) return; ScatterAssign(ctx.device_context(), *dO, *Index, dX); } }; diff --git a/paddle/fluid/operators/math/beam_search.cc b/paddle/fluid/operators/math/beam_search.cc index fb7119273a..69971ef742 100644 --- a/paddle/fluid/operators/math/beam_search.cc +++ b/paddle/fluid/operators/math/beam_search.cc @@ -29,8 +29,9 @@ class BeamSearchFunctor { const framework::LoDTensor *ids, const framework::LoDTensor *scores, framework::LoDTensor *selected_ids, - framework::LoDTensor *selected_scores, size_t level, - size_t beam_size, int end_id, bool is_accumulated) { + framework::LoDTensor *selected_scores, + framework::Tensor *parent_idx, size_t level, size_t beam_size, + int end_id, bool is_accumulated) { auto abs_lod = framework::ToAbsOffset(scores->lod()); auto &high_level = abs_lod[level]; @@ -57,11 +58,13 @@ class BeamSearchFunctor { std::vector({static_cast(num_instances), 1})); selected_ids->Resize(dims); selected_scores->Resize(dims); + parent_idx->Resize({static_cast(num_instances)}); auto *selected_ids_data = selected_ids->mutable_data(platform::CPUPlace()); auto *selected_scores_data = selected_scores->mutable_data(platform::CPUPlace()); + auto *parent_idx_data = parent_idx->mutable_data(platform::CPUPlace()); // fill in data std::vector low_level; @@ -69,6 +72,7 @@ class BeamSearchFunctor { for (auto &items : selected_items) { low_level.push_back(low_offset); for (auto &item : items) { + parent_idx_data[low_offset] = static_cast(low_level.size() - 1); selected_ids_data[low_offset] = item.id; selected_scores_data[low_offset] = item.score; low_offset++; diff --git a/paddle/fluid/operators/math/beam_search.cu b/paddle/fluid/operators/math/beam_search.cu index d94e3023ce..61d021ef62 100644 --- a/paddle/fluid/operators/math/beam_search.cu +++ b/paddle/fluid/operators/math/beam_search.cu @@ -157,10 +157,10 @@ __device__ __forceinline__ bool PruneEndBeams(Triple* top_beam_local, } __device__ __forceinline__ void WriteBack( - int64_t* selected_ids, float* selected_scores, size_t* selected_offsets, - Triple* top_beam_local, const int seq_offset_start, - const int seq_offset_end, const int selected_seq_start, - const int selected_seq_length) { + int64_t* selected_ids, float* selected_scores, int* parent_idx, + size_t* selected_offsets, Triple* top_beam_local, + const int seq_offset_start, const int seq_offset_end, + const int selected_seq_start, const int selected_seq_length) { const int tid = threadIdx.x; // use 1 thread only for each sequence int global_index = selected_seq_start; for (int global_offset = seq_offset_start; global_offset < seq_offset_end; @@ -171,6 +171,7 @@ __device__ __forceinline__ void WriteBack( selected_ids[global_index] = static_cast(top_beam_local[local_index].id); selected_scores[global_index] = top_beam_local[local_index].score; + parent_idx[global_index] = static_cast(global_offset); global_index++; } } @@ -180,11 +181,11 @@ __device__ __forceinline__ void WriteBack( template __device__ void BeamSearchDetails( - int64_t* selected_ids, float* selected_scores, size_t* selected_offsets, - const int64_t* pre_ids, const float* pre_scores, const int64_t* ids, - const float* scores, const int seq_offset_start, const int seq_offset_end, - const int seq_width, int beam_size, int end_id, bool is_accumulated, - int num_used_threads) { + int64_t* selected_ids, float* selected_scores, int* parent_idx, + size_t* selected_offsets, const int64_t* pre_ids, const float* pre_scores, + const int64_t* ids, const float* scores, const int seq_offset_start, + const int seq_offset_end, const int seq_width, int beam_size, int end_id, + bool is_accumulated, int num_used_threads) { __shared__ Triple top_beam[MaxLength]; int num_items = 0; @@ -228,15 +229,15 @@ __device__ void BeamSearchDetails( selected_offsets[0] = 0; } - WriteBack(selected_ids, selected_scores, selected_offsets, top_beam_local, - seq_offset_start, seq_offset_end, selected_seq_start, - selected_seq_length); + WriteBack(selected_ids, selected_scores, parent_idx, selected_offsets, + top_beam_local, seq_offset_start, seq_offset_end, + selected_seq_start, selected_seq_length); } } template __global__ void BeamSearchKernel(int64_t* selected_ids, float* selected_scores, - size_t* selected_offsets, + int* parent_idx, size_t* selected_offsets, const int64_t* pre_ids, const float* pre_scores, const int64_t* ids, const float* scores, const size_t* seq_offsets, @@ -250,24 +251,25 @@ __global__ void BeamSearchKernel(int64_t* selected_ids, float* selected_scores, int seq_offset_end = static_cast(seq_offsets[seq_id + 1]); BeamSearchDetails( - selected_ids, selected_scores, selected_offsets, pre_ids, pre_scores, ids, - scores, seq_offset_start, seq_offset_end, seq_width, beam_size, end_id, - is_accumulated, num_used_threads); + selected_ids, selected_scores, parent_idx, selected_offsets, pre_ids, + pre_scores, ids, scores, seq_offset_start, seq_offset_end, seq_width, + beam_size, end_id, is_accumulated, num_used_threads); } template __global__ void BeamSearchKernelSingle( - int64_t* selected_ids, float* selected_scores, size_t* selected_offsets, - const int64_t* pre_ids, const float* pre_scores, const int64_t* ids, - const float* scores, const int seq_length, const int seq_width, - int beam_size, int end_id, bool is_accumulated, int num_used_threads) { + int64_t* selected_ids, float* selected_scores, int* parent_idx, + size_t* selected_offsets, const int64_t* pre_ids, const float* pre_scores, + const int64_t* ids, const float* scores, const int seq_length, + const int seq_width, int beam_size, int end_id, bool is_accumulated, + int num_used_threads) { const int seq_offset_start = 0; const int seq_offset_end = seq_length; BeamSearchDetails( - selected_ids, selected_scores, selected_offsets, pre_ids, pre_scores, ids, - scores, seq_offset_start, seq_offset_end, seq_width, beam_size, end_id, - is_accumulated, num_used_threads); + selected_ids, selected_scores, parent_idx, selected_offsets, pre_ids, + pre_scores, ids, scores, seq_offset_start, seq_offset_end, seq_width, + beam_size, end_id, is_accumulated, num_used_threads); } static inline int GetNumUsedThreads(const int max_threads_per_seq, @@ -300,8 +302,9 @@ class BeamSearchFunctor { const framework::LoDTensor* ids, const framework::LoDTensor* scores, framework::LoDTensor* selected_ids, - framework::LoDTensor* selected_scores, size_t level, - size_t beam_size, int end_id, bool is_accumulated) { + framework::LoDTensor* selected_scores, + framework::Tensor* parent_idx, size_t level, size_t beam_size, + int end_id, bool is_accumulated) { auto abs_lod = framework::ToAbsOffset(scores->lod()); const int64_t* pre_ids_data = pre_ids->data(); @@ -322,6 +325,8 @@ class BeamSearchFunctor { selected_ids->mutable_data(selected_dims, context.GetPlace()); float* selected_scores_data = selected_scores->mutable_data(selected_dims, context.GetPlace()); + int* parent_idx_data = parent_idx->mutable_data( + {static_cast(num_seqs * beam_size)}, context.GetPlace()); framework::LoD selected_lod(2); selected_lod[0].assign(abs_lod[level].begin(), abs_lod[level].end()); @@ -339,9 +344,9 @@ class BeamSearchFunctor { CUDA_LAUNCH_KERNEL_HELPER( BeamSearchKernelSingle<<< 1, kMaxThreadsPerSeq, 0, context.stream()>>>( - selected_ids_data, selected_scores_data, selected_offsets, - pre_ids_data, pre_scores_data, ids_data, scores_data, - seq_length, static_cast(seq_width), + selected_ids_data, selected_scores_data, parent_idx_data, + selected_offsets, pre_ids_data, pre_scores_data, ids_data, + scores_data, seq_length, static_cast(seq_width), static_cast(beam_size), static_cast(end_id), is_accumulated, num_used_threads)); } @@ -357,9 +362,9 @@ class BeamSearchFunctor { CUDA_LAUNCH_KERNEL_HELPER( BeamSearchKernel<<< 1, num_seqs * kMaxThreadsPerSeq, 0, context.stream()>>>( - selected_ids_data, selected_scores_data, selected_offsets, - pre_ids_data, pre_scores_data, ids_data, scores_data, - seq_offsets, static_cast(num_seqs), + selected_ids_data, selected_scores_data, parent_idx_data, + selected_offsets, pre_ids_data, pre_scores_data, ids_data, + scores_data, seq_offsets, static_cast(num_seqs), static_cast(seq_width), static_cast(beam_size), end_id, is_accumulated, num_used_threads)); } @@ -379,6 +384,7 @@ class BeamSearchFunctor { {static_cast(selected_lod[1].back()), 1}); selected_ids->Resize(final_selected_dims); selected_scores->Resize(final_selected_dims); + parent_idx->Resize({static_cast(selected_lod[1].back())}); } } }; diff --git a/paddle/fluid/operators/math/beam_search.h b/paddle/fluid/operators/math/beam_search.h index 3cd17f426c..4474e7ea52 100644 --- a/paddle/fluid/operators/math/beam_search.h +++ b/paddle/fluid/operators/math/beam_search.h @@ -104,14 +104,12 @@ class BeamSearchFunctor { * Return false if all the input tensor is empty, in machine translation task * that means no candidates is provided, and the task will stop running. */ - void operator()(const DeviceContext& context, - const framework::LoDTensor* pre_ids, - const framework::LoDTensor* pre_scores, - const framework::LoDTensor* ids, - const framework::LoDTensor* scores, - framework::LoDTensor* selected_ids, - framework::LoDTensor* selected_scores, size_t level, - size_t beam_size, int end_id, bool is_accumulated); + void operator()( + const DeviceContext& context, const framework::LoDTensor* pre_ids, + const framework::LoDTensor* pre_scores, const framework::LoDTensor* ids, + const framework::LoDTensor* scores, framework::LoDTensor* selected_ids, + framework::LoDTensor* selected_scores, framework::Tensor* parent_idx, + size_t level, size_t beam_size, int end_id, bool is_accumulated); }; } // namespace math diff --git a/paddle/fluid/operators/math/beam_search_test.cc b/paddle/fluid/operators/math/beam_search_test.cc index 1c29ee95f6..7ea8eb8b00 100644 --- a/paddle/fluid/operators/math/beam_search_test.cc +++ b/paddle/fluid/operators/math/beam_search_test.cc @@ -93,13 +93,14 @@ void TestBeamSearch() { paddle::framework::LoDTensor selected_ids; paddle::framework::LoDTensor selected_scores; + paddle::framework::LoDTensor parent_idx; size_t level = 0; size_t beam_size = 2; int end_id = 0; paddle::operators::math::BeamSearchFunctor beamsearch; beamsearch(*context, &pre_ids, &pre_scores, &ids, &scores, &selected_ids, - &selected_scores, level, beam_size, end_id, true); + &selected_scores, &parent_idx, level, beam_size, end_id, true); ASSERT_EQ(selected_ids.lod(), selected_scores.lod()); diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 0dbcf442a3..0e4b5aadc0 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -3877,7 +3877,8 @@ def beam_search(pre_ids, end_id, level=0, is_accumulated=True, - name=None): + name=None, + return_parent_idx=False): """ Beam search is a classical algorithm for selecting candidate words in a machine translation task. @@ -3933,10 +3934,16 @@ def beam_search(pre_ids, accumulated scores. name(str|None): A name for this layer(optional). If set None, the layer will be named automatically. + return_parent_idx(bool): Whether to return an extra Tensor variable + preserving the selected_ids' parent indice in pre_ids + in output, which can be used to gather cell states at + the next time step. Returns: - Variable: The LodTensor pair containing the selected ids and the \ - corresponding scores. + Variable: The LodTensor tuple containing the selected ids and the \ + corresponding scores. If :attr:`return_parent_idx` is :attr:`True`, \ + an extra Tensor variable preserving the selected_ids' parent indice \ + is included. Examples: .. code-block:: python @@ -3969,6 +3976,11 @@ def beam_search(pre_ids, selected_scores = helper.create_variable_for_type_inference( dtype=score_type) selected_ids = helper.create_variable_for_type_inference(dtype=id_type) + # parent_idx is a tensor used to gather cell states at the next time + # step. Though lod in selected_ids can also be used to gather by + # sequence_expand, it is not efficient. + # gather_op's index input only supports int32 dtype currently + parent_idx = helper.create_variable_for_type_inference(dtype="int32") helper.append_op( type='beam_search', @@ -3976,6 +3988,7 @@ def beam_search(pre_ids, outputs={ 'selected_ids': selected_ids, 'selected_scores': selected_scores, + 'parent_idx': parent_idx }, attrs={ # TODO(ChunweiYan) to assure other value support @@ -3984,8 +3997,10 @@ def beam_search(pre_ids, 'end_id': end_id, 'is_accumulated': is_accumulated, }) - - return selected_ids, selected_scores + if return_parent_idx: + return selected_ids, selected_scores, parent_idx + else: + return selected_ids, selected_scores def beam_search_decode(ids, scores, beam_size, end_id, name=None): diff --git a/python/paddle/fluid/tests/unittests/test_beam_search_op.py b/python/paddle/fluid/tests/unittests/test_beam_search_op.py index c28dda4b53..1d9f4b78f3 100644 --- a/python/paddle/fluid/tests/unittests/test_beam_search_op.py +++ b/python/paddle/fluid/tests/unittests/test_beam_search_op.py @@ -38,6 +38,7 @@ class BeamSearchOpTester(unittest.TestCase): self._create_pre_ids() self.scope.var('selected_ids') self.scope.var('selected_scores') + self.scope.var('parent_idx') def test_run(self): op = Operator( @@ -48,12 +49,14 @@ class BeamSearchOpTester(unittest.TestCase): scores='scores', selected_ids='selected_ids', selected_scores='selected_scores', + parent_idx='parent_idx', level=0, beam_size=2, end_id=0, ) op.run(self.scope, core.CPUPlace()) selected_ids = self.scope.find_var("selected_ids").get_tensor() selected_scores = self.scope.find_var("selected_scores").get_tensor() + parent_idx = self.scope.find_var("parent_idx").get_tensor() self.assertTrue( np.allclose( np.array(selected_ids), np.array([4, 2, 3, 8])[:, np.newaxis])) @@ -62,6 +65,8 @@ class BeamSearchOpTester(unittest.TestCase): np.array(selected_scores), np.array([0.5, 0.6, 0.9, 0.7])[:, np.newaxis])) self.assertEqual(selected_ids.lod(), [[0, 2, 4], [0, 1, 2, 3, 4]]) + self.assertTrue( + np.allclose(np.array(parent_idx), np.array([0, 1, 2, 3]))) def _create_pre_ids(self): np_data = np.array([[1, 2, 3, 4]], dtype='int64') From 170842cbb4c61c12a2eb8a93f1cc66fc6ae06f02 Mon Sep 17 00:00:00 2001 From: Yibing Liu Date: Wed, 30 Jan 2019 11:28:14 +0800 Subject: [PATCH 48/53] Some improvements to support bert mixed precision training (#15585) * Some improvements to support bert mixed precision training test=develop * Revert the cast in layer_norm test=develop --- paddle/fluid/operators/dropout_op.cu | 1 + paddle/fluid/operators/gather_op.cu | 7 ++++-- paddle/fluid/operators/lookup_table_op.cu | 8 +++++-- paddle/fluid/operators/reshape_op.cc | 9 ++++++-- paddle/fluid/operators/stack_op.cu | 21 ++++++++++-------- paddle/fluid/operators/transpose_op.cu.cc | 16 ++++++++++---- python/paddle/fluid/initializer.py | 27 +++++++++++++++++++++-- 7 files changed, 68 insertions(+), 21 deletions(-) diff --git a/paddle/fluid/operators/dropout_op.cu b/paddle/fluid/operators/dropout_op.cu index d65491267d..7a6927d3e5 100644 --- a/paddle/fluid/operators/dropout_op.cu +++ b/paddle/fluid/operators/dropout_op.cu @@ -114,4 +114,5 @@ REGISTER_OP_CUDA_KERNEL( ops::GPUDropoutKernel); REGISTER_OP_CUDA_KERNEL( dropout_grad, ops::DropoutGradKernel, + ops::DropoutGradKernel, ops::DropoutGradKernel); diff --git a/paddle/fluid/operators/gather_op.cu b/paddle/fluid/operators/gather_op.cu index 427ac61858..490ba9a585 100644 --- a/paddle/fluid/operators/gather_op.cu +++ b/paddle/fluid/operators/gather_op.cu @@ -60,11 +60,14 @@ class GatherGradOpCUDAKernel : public framework::OpKernel { } // namespace paddle namespace ops = paddle::operators; +namespace plat = paddle::platform; REGISTER_OP_CUDA_KERNEL(gather, ops::GatherOpCUDAKernel, ops::GatherOpCUDAKernel, ops::GatherOpCUDAKernel, - ops::GatherOpCUDAKernel); + ops::GatherOpCUDAKernel, + ops::GatherOpCUDAKernel); REGISTER_OP_CUDA_KERNEL(gather_grad, ops::GatherGradOpCUDAKernel, ops::GatherGradOpCUDAKernel, ops::GatherGradOpCUDAKernel, - ops::GatherGradOpCUDAKernel); + ops::GatherGradOpCUDAKernel, + ops::GatherGradOpCUDAKernel); diff --git a/paddle/fluid/operators/lookup_table_op.cu b/paddle/fluid/operators/lookup_table_op.cu index fd15539f7b..0af8b9e69c 100644 --- a/paddle/fluid/operators/lookup_table_op.cu +++ b/paddle/fluid/operators/lookup_table_op.cu @@ -17,6 +17,7 @@ limitations under the License. */ #include "paddle/fluid/operators/lookup_table_op.h" #include "paddle/fluid/platform/assert.h" #include "paddle/fluid/platform/cuda_primitives.h" +#include "paddle/fluid/platform/float16.h" namespace paddle { namespace operators { @@ -193,8 +194,11 @@ class LookupTableGradCUDAKernel : public framework::OpKernel { } // namespace paddle namespace ops = paddle::operators; +namespace plat = paddle::platform; REGISTER_OP_CUDA_KERNEL(lookup_table, ops::LookupTableCUDAKernel, - ops::LookupTableCUDAKernel); + ops::LookupTableCUDAKernel, + ops::LookupTableCUDAKernel); REGISTER_OP_CUDA_KERNEL(lookup_table_grad, ops::LookupTableGradCUDAKernel, - ops::LookupTableGradCUDAKernel); + ops::LookupTableGradCUDAKernel, + ops::LookupTableGradCUDAKernel); diff --git a/paddle/fluid/operators/reshape_op.cc b/paddle/fluid/operators/reshape_op.cc index 8eab3a6f89..32365d6a96 100644 --- a/paddle/fluid/operators/reshape_op.cc +++ b/paddle/fluid/operators/reshape_op.cc @@ -330,6 +330,7 @@ class Reshape2GradOp : public framework::OperatorWithKernel { } // namespace operators } // namespace paddle namespace ops = paddle::operators; +namespace plat = paddle::platform; REGISTER_OPERATOR(reshape, ops::ReshapeOp, ops::ReshapeOpMaker, paddle::framework::DefaultGradOpDescMaker); @@ -356,16 +357,20 @@ REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape2_grad, float, ops::ReshapeGradKernel, #ifdef PADDLE_WITH_CUDA REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape, float, ops::ReshapeKernel, double, ops::ReshapeKernel, int, ops::ReshapeKernel, - int64_t, ops::ReshapeKernel); + int64_t, ops::ReshapeKernel, plat::float16, + ops::ReshapeKernel); REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape_grad, float, ops::ReshapeGradKernel, double, ops::ReshapeGradKernel, int, ops::ReshapeGradKernel, int64_t, + ops::ReshapeGradKernel, plat::float16, ops::ReshapeGradKernel); REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape2, float, ops::ReshapeKernel, double, ops::ReshapeKernel, int, ops::ReshapeKernel, - int64_t, ops::ReshapeKernel); + int64_t, ops::ReshapeKernel, plat::float16, + ops::ReshapeKernel); REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape2_grad, float, ops::ReshapeGradKernel, double, ops::ReshapeGradKernel, int, ops::ReshapeGradKernel, int64_t, + ops::ReshapeGradKernel, plat::float16, ops::ReshapeGradKernel); #endif diff --git a/paddle/fluid/operators/stack_op.cu b/paddle/fluid/operators/stack_op.cu index bf2a9e5b3d..24d0b2f906 100644 --- a/paddle/fluid/operators/stack_op.cu +++ b/paddle/fluid/operators/stack_op.cu @@ -17,13 +17,16 @@ namespace plat = paddle::platform; namespace ops = paddle::operators; -REGISTER_OP_CUDA_KERNEL(stack, ops::StackKernel, - ops::StackKernel, - ops::StackKernel, - ops::StackKernel); +REGISTER_OP_CUDA_KERNEL( + stack, ops::StackKernel, + ops::StackKernel, + ops::StackKernel, + ops::StackKernel, + ops::StackKernel); -REGISTER_OP_CUDA_KERNEL(stack_grad, - ops::StackGradKernel, - ops::StackGradKernel, - ops::StackGradKernel, - ops::StackGradKernel); +REGISTER_OP_CUDA_KERNEL( + stack_grad, ops::StackGradKernel, + ops::StackGradKernel, + ops::StackGradKernel, + ops::StackGradKernel, + ops::StackGradKernel); diff --git a/paddle/fluid/operators/transpose_op.cu.cc b/paddle/fluid/operators/transpose_op.cu.cc index b4025350fa..915774e5f3 100644 --- a/paddle/fluid/operators/transpose_op.cu.cc +++ b/paddle/fluid/operators/transpose_op.cu.cc @@ -15,19 +15,27 @@ limitations under the License. */ #include "paddle/fluid/operators/transpose_op.h" namespace ops = paddle::operators; +namespace plat = paddle::platform; + REGISTER_OP_CUDA_KERNEL( transpose, ops::TransposeKernel, - ops::TransposeKernel); + ops::TransposeKernel, + ops::TransposeKernel); REGISTER_OP_CUDA_KERNEL( transpose_grad, ops::TransposeGradKernel, - ops::TransposeGradKernel); + ops::TransposeGradKernel, + ops::TransposeGradKernel); REGISTER_OP_CUDA_KERNEL( transpose2, ops::TransposeKernel, - ops::TransposeKernel); + ops::TransposeKernel, + ops::TransposeKernel); REGISTER_OP_CUDA_KERNEL( transpose2_grad, ops::TransposeGradKernel, - ops::TransposeGradKernel); + ops::TransposeGradKernel, + ops::TransposeGradKernel); diff --git a/python/paddle/fluid/initializer.py b/python/paddle/fluid/initializer.py index 4f434328e4..5be21ff7f7 100644 --- a/python/paddle/fluid/initializer.py +++ b/python/paddle/fluid/initializer.py @@ -366,17 +366,40 @@ class TruncatedNormalInitializer(Initializer): # Initialization Ops should be prepended and not appended if self._seed == 0: self._seed = block.program.random_seed + + # to be compatible of fp16 initalizers + if var.dtype == VarDesc.VarType.FP16: + out_dtype = VarDesc.VarType.FP32 + out_var = block.create_var( + name=unique_name.generate(".".join( + ['truncated_gaussian_random', 'tmp'])), + shape=var.shape, + dtype=out_dtype, + type=VarDesc.VarType.LOD_TENSOR, + persistable=False) + else: + out_dtype = var.dtype + out_var = var + op = block._prepend_op( type="truncated_gaussian_random", - outputs={"Out": var}, + outputs={"Out": out_var}, attrs={ "shape": var.shape, - "dtype": int(var.dtype), + "dtype": out_dtype, "mean": self._mean, "std": self._std_dev, "seed": self._seed }, stop_gradient=True) + + if var.dtype == VarDesc.VarType.FP16: + block.append_op( + type="cast", + inputs={"X": out_var}, + outputs={"Out": var}, + attrs={"in_dtype": out_var.dtype, + "out_dtype": var.dtype}) var.op = op return op From c4b9eac11af34d340db876fae54d93aee427e5d6 Mon Sep 17 00:00:00 2001 From: chengduo Date: Tue, 29 Jan 2019 23:37:04 -0600 Subject: [PATCH 49/53] fix threshold_relu_op (#15594) test=develop --- python/paddle/fluid/layers/ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/fluid/layers/ops.py b/python/paddle/fluid/layers/ops.py index 6c18af7283..3dcf9dc069 100644 --- a/python/paddle/fluid/layers/ops.py +++ b/python/paddle/fluid/layers/ops.py @@ -135,7 +135,7 @@ def thresholded_relu(x, threshold=None): if val is not None: kwargs[name] = val - _thresholded_relu_(**kwargs) + return _thresholded_relu_(**kwargs) thresholded_relu.__doc__ = _thresholded_relu_.__doc__ + """ From 294d594450c9168995e1cc27caf86dddf98993f3 Mon Sep 17 00:00:00 2001 From: Haihao Shen Date: Wed, 30 Jan 2019 14:20:22 +0800 Subject: [PATCH 50/53] Enable performance measurement in INT8 calibration unit test (#15560) * Enable performance measurement in INT8 calibration unit test --- .../fluid/contrib/tests/test_calibration.py | 144 +++++++++++++----- 1 file changed, 106 insertions(+), 38 deletions(-) diff --git a/python/paddle/fluid/contrib/tests/test_calibration.py b/python/paddle/fluid/contrib/tests/test_calibration.py index f07fefe7e0..cd6b7ba166 100644 --- a/python/paddle/fluid/contrib/tests/test_calibration.py +++ b/python/paddle/fluid/contrib/tests/test_calibration.py @@ -19,10 +19,8 @@ import sys import random import paddle import paddle.fluid as fluid -import argparse import functools import contextlib -import paddle.fluid.profiler as profiler from paddle.dataset.common import download from PIL import Image, ImageEnhance import math @@ -43,7 +41,7 @@ img_mean = np.array([0.485, 0.456, 0.406]).reshape((3, 1, 1)) img_std = np.array([0.229, 0.224, 0.225]).reshape((3, 1, 1)) -# TODO(guomingz): Remove duplicated code from line 45 ~ line 114 +# TODO(guomingz): Remove duplicated code from resize_short, crop_image, process_image, _reader_creator def resize_short(img, target_size): percent = float(target_size) / min(img.size[0], img.size[1]) resized_width = int(round(img.size[0] * percent)) @@ -123,16 +121,37 @@ class TestCalibrationForResnet50(unittest.TestCase): self.cache_folder = os.path.expanduser('~/.cache/paddle/dataset/' + self.int8_download) - data_url = 'http://paddle-inference-dist.cdn.bcebos.com/int8/calibration_test_data.tar.gz' - data_md5 = '1b6c1c434172cca1bf9ba1e4d7a3157d' - self.data_cache_folder = self.download_data(data_url, data_md5, "data") + data_urls = [] + data_md5s = [] + self.data_cache_folder = '' + if os.environ.get('DATASET') == 'full': + data_urls.append( + 'https://paddle-inference-dist.bj.bcebos.com/int8/ILSVRC2012_img_val.tar.gz.partaa' + ) + data_md5s.append('60f6525b0e1d127f345641d75d41f0a8') + data_urls.append( + 'https://paddle-inference-dist.bj.bcebos.com/int8/ILSVRC2012_img_val.tar.gz.partab' + ) + data_md5s.append('1e9f15f64e015e58d6f9ec3210ed18b5') + self.data_cache_folder = self.download_data(data_urls, data_md5s, + "full_data", False) + else: + data_urls.append( + 'http://paddle-inference-dist.cdn.bcebos.com/int8/calibration_test_data.tar.gz' + ) + data_md5s.append('1b6c1c434172cca1bf9ba1e4d7a3157d') + self.data_cache_folder = self.download_data(data_urls, data_md5s, + "small_data", False) # reader/decorator.py requires the relative path to the data folder cmd = 'rm -rf {0} && ln -s {1} {0}'.format("data", self.data_cache_folder) os.system(cmd) - self.iterations = 50 + self.batch_size = 1 + self.sample_iterations = 50 + self.infer_iterations = 50000 if os.environ.get( + 'DATASET') == 'full' else 50 def cache_unzipping(self, target_folder, zip_path): if not os.path.exists(target_folder): @@ -140,20 +159,44 @@ class TestCalibrationForResnet50(unittest.TestCase): zip_path) os.system(cmd) - def download_data(self, data_url, data_md5, folder_name): - download(data_url, self.int8_download, data_md5) + def download_data(self, data_urls, data_md5s, folder_name, is_model=True): data_cache_folder = os.path.join(self.cache_folder, folder_name) - file_name = data_url.split('/')[-1] - zip_path = os.path.join(self.cache_folder, file_name) + zip_path = '' + if os.environ.get('DATASET') == 'full': + file_names = [] + for i in range(0, len(data_urls)): + download(data_urls[i], self.int8_download, data_md5s[i]) + file_names.append(data_urls[i].split('/')[-1]) + + zip_path = os.path.join(self.cache_folder, + 'full_imagenet_val.tar.gz') + if not os.path.exists(zip_path): + cat_command = 'cat' + for file_name in file_names: + cat_command += ' ' + os.path.join(self.cache_folder, + file_name) + cat_command += ' > ' + zip_path + os.system(cat_command) + + if os.environ.get('DATASET') != 'full' or is_model: + download(data_urls[0], self.int8_download, data_md5s[0]) + file_name = data_urls[0].split('/')[-1] + zip_path = os.path.join(self.cache_folder, file_name) + + print('Data is downloaded at {0}').format(zip_path) self.cache_unzipping(data_cache_folder, zip_path) return data_cache_folder - def download_resnet50_model(self): + def download_model(self): # resnet50 fp32 data - data_url = 'http://paddle-inference-dist.cdn.bcebos.com/int8/resnet50_int8_model.tar.gz' - data_md5 = '4a5194524823d9b76da6e738e1367881' - self.model_cache_folder = self.download_data(data_url, data_md5, + data_urls = [ + 'http://paddle-inference-dist.cdn.bcebos.com/int8/resnet50_int8_model.tar.gz' + ] + data_md5s = ['4a5194524823d9b76da6e738e1367881'] + self.model_cache_folder = self.download_data(data_urls, data_md5s, "resnet50_fp32") + self.model = "ResNet-50" + self.algo = "direct" def run_program(self, model_path, generate_int8=False, algo='direct'): image_shape = [3, 224, 224] @@ -169,17 +212,17 @@ class TestCalibrationForResnet50(unittest.TestCase): t = fluid.transpiler.InferenceTranspiler() t.transpile(infer_program, fluid.CPUPlace()) - val_reader = paddle.batch(val(), batch_size=1) + val_reader = paddle.batch(val(), self.batch_size) + iterations = self.infer_iterations if generate_int8: int8_model = os.path.join(os.getcwd(), "calibration_out") + iterations = self.sample_iterations if os.path.exists(int8_model): os.system("rm -rf " + int8_model) os.system("mkdir " + int8_model) - print("Start calibration ...") - calibrator = int8_utility.Calibrator( program=infer_program, pretrained_model=model_path, @@ -191,6 +234,7 @@ class TestCalibrationForResnet50(unittest.TestCase): test_info = [] cnt = 0 + periods = [] for batch_id, data in enumerate(val_reader()): image = np.array( [x[0].reshape(image_shape) for x in data]).astype("float32") @@ -202,21 +246,28 @@ class TestCalibrationForResnet50(unittest.TestCase): if op.has_attr("use_mkldnn"): op._set_attr("use_mkldnn", True) + t1 = time.time() _, acc1, _ = exe.run( running_program, feed={feed_dict[0]: image, feed_dict[1]: label}, fetch_list=fetch_targets) + t2 = time.time() + period = t2 - t1 + periods.append(period) + if generate_int8: calibrator.sample_data() test_info.append(np.mean(acc1) * len(data)) cnt += len(data) - if batch_id != self.iterations - 1: - continue + if (batch_id + 1) % 100 == 0: + print("{0} images,".format(batch_id + 1)) + sys.stdout.flush() - break + if (batch_id + 1) == iterations: + break if generate_int8: calibrator.save_int8_model() @@ -225,32 +276,49 @@ class TestCalibrationForResnet50(unittest.TestCase): "Calibration is done and the corresponding files are generated at {}". format(os.path.abspath("calibration_out"))) else: - return np.sum(test_info) / cnt + throughput = cnt / np.sum(periods) + latency = np.average(periods) + acc1 = np.sum(test_info) / cnt + return (throughput, latency, acc1) def test_calibration(self): - self.download_resnet50_model() - fp32_acc1 = self.run_program(self.model_cache_folder + "/model") - self.run_program(self.model_cache_folder + "/model", True) - int8_acc1 = self.run_program("calibration_out") + self.download_model() + print("Start FP32 inference for {0} on {1} images ...").format( + self.model, self.infer_iterations) + (fp32_throughput, fp32_latency, + fp32_acc1) = self.run_program(self.model_cache_folder + "/model") + print("Start INT8 calibration for {0} on {1} images ...").format( + self.model, self.sample_iterations) + self.run_program( + self.model_cache_folder + "/model", True, algo=self.algo) + print("Start INT8 inference for {0} on {1} images ...").format( + self.model, self.infer_iterations) + (int8_throughput, int8_latency, + int8_acc1) = self.run_program("calibration_out") delta_value = np.abs(fp32_acc1 - int8_acc1) self.assertLess(delta_value, 0.01) + print( + "FP32 {0}: batch_size {1}, throughput {2} images/second, latency {3} second, accuracy {4}". + format(self.model, self.batch_size, fp32_throughput, fp32_latency, + fp32_acc1)) + print( + "INT8 {0}: batch_size {1}, throughput {2} images/second, latency {3} second, accuracy {4}". + format(self.model, self.batch_size, int8_throughput, int8_latency, + int8_acc1)) + sys.stdout.flush() class TestCalibrationForMobilenetv1(TestCalibrationForResnet50): - def download_mobilenetv1_model(self): + def download_model(self): # mobilenetv1 fp32 data - data_url = 'http://paddle-inference-dist.cdn.bcebos.com/int8/mobilenetv1_int8_model.tar.gz' - data_md5 = '13892b0716d26443a8cdea15b3c6438b' - self.model_cache_folder = self.download_data(data_url, data_md5, + data_urls = [ + 'http://paddle-inference-dist.cdn.bcebos.com/int8/mobilenetv1_int8_model.tar.gz' + ] + data_md5s = ['13892b0716d26443a8cdea15b3c6438b'] + self.model_cache_folder = self.download_data(data_urls, data_md5s, "mobilenetv1_fp32") - - def test_calibration(self): - self.download_mobilenetv1_model() - fp32_acc1 = self.run_program(self.model_cache_folder + "/model") - self.run_program(self.model_cache_folder + "/model", True, algo='KL') - int8_acc1 = self.run_program("calibration_out") - delta_value = np.abs(fp32_acc1 - int8_acc1) - self.assertLess(delta_value, 0.01) + self.model = "MobileNet-V1" + self.algo = "KL" if __name__ == '__main__': From 90df7ff3789869bd4d9161c2914eedc8521c4703 Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Wed, 30 Jan 2019 14:36:35 +0800 Subject: [PATCH 51/53] transpiler.py code clean (#15555) * move var strusted to vars_distributed.py, add optimizer's block name, test=develop * rename optimzier's seems complex, revert it, test=develop * replace * with details, test=develop --- .../fluid/transpiler/details/__init__.py | 1 + .../transpiler/details/vars_distributed.py | 269 ++++++++++++++++++ .../fluid/transpiler/distribute_transpiler.py | 268 +---------------- 3 files changed, 279 insertions(+), 259 deletions(-) create mode 100644 python/paddle/fluid/transpiler/details/vars_distributed.py diff --git a/python/paddle/fluid/transpiler/details/__init__.py b/python/paddle/fluid/transpiler/details/__init__.py index f33c05ed2f..82d0d336e5 100644 --- a/python/paddle/fluid/transpiler/details/__init__.py +++ b/python/paddle/fluid/transpiler/details/__init__.py @@ -17,3 +17,4 @@ from __future__ import print_function from .program_utils import * from .ufind import * from .checkport import * +from .vars_distributed import * diff --git a/python/paddle/fluid/transpiler/details/vars_distributed.py b/python/paddle/fluid/transpiler/details/vars_distributed.py new file mode 100644 index 0000000000..05e7f6e3e7 --- /dev/null +++ b/python/paddle/fluid/transpiler/details/vars_distributed.py @@ -0,0 +1,269 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import print_function +from paddle.fluid.framework import Variable + + +class VarStruct(object): + """ + record part properties of a Variable in python. + """ + + def __init__(self, name, shape, dtype, type, lod_level, persistable): + self.name = name + self.shape = shape + self.dtype = dtype + self.type = type + self.lod_level = lod_level + self.persistable = persistable + + +class VarDistributed(object): + """ + a class to record the var distributed on parameter servers. + the class will record the relationship between origin var and slice var. + the slice var's properties, such as type/shape/offset/endpoint. + """ + + def __init__(self, + origin_var, + slice_var, + is_slice=None, + block_id=None, + offset=None, + vtype=None, + endpoint=None): + """ + Args: + origin_var(Variable|VarStruct): origin var properties + slice_var(Variable|VarStruct): slice var properties + is_slice(bool|None): slice or not, slice_var=True/False and its block size > 8192 are the judgement standard. + block_id(int|None): the number about the slice var. + offset(int|None): if the slice var is sliced, offset is the numel before the var. + vtype(str|None): a tag, such as Optimizer/Param/RemoteProfetch. + endpoint(str|None): which parameter the slice var on, such as "127.0.0.1:1001" + """ + + if isinstance(origin_var, Variable): + self.origin = self.__create_var_struct(origin_var) + else: + self.origin = origin_var + + if isinstance(slice_var, Variable): + self.slice = self.__create_var_struct(slice_var) + else: + self.slice = slice_var + + if self.equal(self.origin, self.slice): + self.is_slice = False + self.block_id = 0 + self.offset = 0 + else: + self.is_slice = True + self.block_id = 0 + self.offset = 0 + + if is_slice is not None: + self.is_slice = is_slice + if block_id is not None: + self.block_id = block_id + if offset is not None: + self.offset = offset + + self.vtype = vtype + self.endpoint = endpoint + + @staticmethod + def __create_var_struct(var): + return VarStruct(var.name, var.shape, var.dtype, var.type, + var.lod_level, var.persistable) + + @staticmethod + def equal(var1, var2): + """ + the two var is equal or not. + Returns: + bool: equal will return True else False + """ + assert isinstance(var1, VarStruct) and isinstance(var2, VarStruct) + + return var1.name == var2.name and \ + var1.type == var2.type and \ + var1.shape == var2.shape and \ + var1.dtype == var2.dtype and \ + var1.lod_level == var2.lod_level and \ + var1.persistable == var2.persistable + + def __str__(self): + origin_var_str = "{name} : fluid.{type}.shape{shape}.astype({dtype})". \ + format(i="{", e="}", name=self.origin.name, type=self.origin.type, + shape=self.origin.shape, dtype=self.origin.dtype) + + slice_var_str = "{name} : fluid.{type}.shape{shape}.astype({dtype})" \ + ".slice({is_slice}).block({block_id}).offset({offset})". \ + format(i="{", e="}", name=self.slice.name, type=self.slice.type, + shape=self.slice.shape, dtype=self.slice.dtype, + is_slice=self.is_slice, block_id=self.block_id, offset=self.offset) + + return "var owned: {}, origin var: ( {} ), slice var: ( {} ), endpoint: {} ".format( + self.vtype, origin_var_str, slice_var_str, self.endpoint) + + +class VarsDistributed(object): + """ + a gather about VarDistributed with many methods to find distributed vars. + through the class, we can get overview about the distributed parameters on parameter servers. + this class may centralized and convenient for developer to manage and get variable's distribute. + other module can also use this to find variables such io.py. + """ + + def __init__(self): + self.distributed_vars = [] + + def add_distributed_var(self, + origin_var, + slice_var, + is_slice=None, + block_id=None, + offset=None, + vtype=None, + endpoint=None): + """ + add distributed var in this. + + Args: + origin_var(Variable|VarStruct): origin var properties + slice_var(Variable|VarStruct): slice var properties + is_slice(bool|None): slice or not, slice_var=True/False and its block size > 8192 are the judgement standard. + block_id(int|None): the number about the slice var. + offset(int|None): if the slice var is sliced, offset is the numel before the var. + vtype(str|None): a tag, such as Optimizer/Param/RemoteProfetch. + endpoint(str|None): which parameter the slice var on, such as "127.0.0.1:1001" + Returns: + None + """ + self.distributed_vars.append( + VarDistributed(origin_var, slice_var, is_slice, block_id, offset, + vtype, endpoint)) + + def get_distributed_var_by_slice(self, var_name): + """ + get distributed var by conditions. + + Args: + var_name(str): slice var name, such as "w.traier0.block1" + Returns: + VarDistributed: distributed var. + """ + for dist_var in self.distributed_vars: + if dist_var.slice.name == var_name: + return dist_var + return None + + @staticmethod + def equal(var1, var2): + """ + the two var is equal or not. + Returns: + bool: equal will return True else False + """ + return var1.name == var2.name and \ + var1.type == var2.type and \ + var1.shape == var2.shape and \ + var1.dtype == var2.dtype and \ + var1.lod_level == var2.lod_level and \ + var1.persistable == var2.persistable + + def get_distributed_var_by_origin_and_ep(self, origin_var_name, endpoint): + """ + get distributed var by conditions. + + Args: + origin_var_name(str): + endpoint(str): the parameter endpoint, such as "127.0.0.1:1001" + Returns: + VarDistributed: distributed var. + """ + for dist_var in self.distributed_vars: + if dist_var.origin.name == origin_var_name and dist_var.endpoint == endpoint: + return dist_var + return None + + def get_distributed_vars_by_vtypes(self, vtypes, groupby=False): + """ + get distributed vars by conditions. + + Args: + vtype(str|None): distributed var's vtype, such as "Optimizer", "RemotePrefetch" + groupby(bool|False): group by origin var or not. + + Returns: + list: distributed var list. + dict: distributed var map when groupby=True + """ + vtype_vars = [] + for var in self.distributed_vars: + if var.vtype in vtypes: + vtype_vars.append(var) + if not groupby: + return vtype_vars + + params_map = {} + for var in vtype_vars: + origin_var_name = var.origin.name + + if origin_var_name in params_map.keys(): + optimizers = params_map.get(origin_var_name) + else: + optimizers = [] + optimizers.append(var) + params_map[origin_var_name] = optimizers + return params_map + + def get_distributed_vars_by_ep(self, endpoint, vtype=None): + """ + get distributed vars by conditions. + + Args: + endpoint(str): the parameter server endpoint, such as "127.0.0.1:2001" + vtype(str|None): distributed var's vtype, such as "Optimizer", "RemotePrefetch" + + Returns: + list: distributed var list. + """ + endpoint_vars = [] + for var in self.distributed_vars: + if var.endpoint == endpoint: + endpoint_vars.append(var) + if not vtype: + return endpoint_vars + + vtype_vars = [] + for var in endpoint_vars: + if var.vtype == vtype: + vtype_vars.append(var) + return vtype_vars + + def overview(self): + """ + get the overview string about all params on all parameter servers. + + Returns: + Str: overview string. + + """ + vars_str = [] + for var in self.distributed_vars: + vars_str.append(str(var)) + return "\n".join(vars_str) diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py index e58f34e375..a3293afbbd 100644 --- a/python/paddle/fluid/transpiler/distribute_transpiler.py +++ b/python/paddle/fluid/transpiler/distribute_transpiler.py @@ -30,19 +30,23 @@ Steps to transpile pserver: 5. add listen_and_serv op """ +import sys import math -import numpy as np +from functools import reduce + import collections +import six import logging +import numpy as np + from .ps_dispatcher import RoundRobin, PSDispatcher from .. import core, framework, unique_name from ..framework import Program, default_main_program, \ - default_startup_program, Block, \ - Parameter, Variable, grad_var_name -from .details import * + default_startup_program, Block, Parameter, grad_var_name +from .details import wait_server_ready, UnionFind, VarStruct, VarsDistributed +from .details import delete_ops, find_op_by_output_arg from ..distribute_lookup_table import find_distributed_lookup_table -from functools import reduce LOOKUP_TABLE_TYPE = "lookup_table" LOOKUP_TABLE_GRAD_TYPE = "lookup_table_grad" @@ -62,260 +66,6 @@ def log(*args): print(args) -class VarStruct(object): - """ - record part properties of a Variable in python. - """ - - def __init__(self, name, shape, dtype, type, lod_level, persistable): - self.name = name - self.shape = shape - self.dtype = dtype - self.type = type - self.lod_level = lod_level - self.persistable = persistable - - -class VarDistributed(object): - """ - a class to record the var distributed on parameter servers. - the class will record the relationship between origin var and slice var. - the slice var's properties, such as type/shape/offset/endpoint. - """ - - def __init__(self, - origin_var, - slice_var, - is_slice=None, - block_id=None, - offset=None, - vtype=None, - endpoint=None): - """ - Args: - origin_var(Variable|VarStruct): origin var properties - slice_var(Variable|VarStruct): slice var properties - is_slice(bool|None): slice or not, slice_var=True/False and its block size > 8192 are the judgement standard. - block_id(int|None): the number about the slice var. - offset(int|None): if the slice var is sliced, offset is the numel before the var. - vtype(str|None): a tag, such as Optimizer/Param/RemoteProfetch. - endpoint(str|None): which parameter the slice var on, such as "127.0.0.1:1001" - """ - - if isinstance(origin_var, Variable): - self.origin = self.__create_var_struct(origin_var) - else: - self.origin = origin_var - - if isinstance(slice_var, Variable): - self.slice = self.__create_var_struct(slice_var) - else: - self.slice = slice_var - - if self.equal(self.origin, self.slice): - self.is_slice = False - self.block_id = 0 - self.offset = 0 - else: - self.is_slice = True - self.block_id = 0 - self.offset = 0 - - if is_slice is not None: - self.is_slice = is_slice - if block_id is not None: - self.block_id = block_id - if offset is not None: - self.offset = offset - - self.vtype = vtype - self.endpoint = endpoint - - @staticmethod - def __create_var_struct(var): - return VarStruct(var.name, var.shape, var.dtype, var.type, - var.lod_level, var.persistable) - - @staticmethod - def equal(var1, var2): - """ - the two var is equal or not. - Returns: - bool: equal will return True else False - """ - assert isinstance(var1, VarStruct) and isinstance(var2, VarStruct) - - return var1.name == var2.name and \ - var1.type == var2.type and \ - var1.shape == var2.shape and \ - var1.dtype == var2.dtype and \ - var1.lod_level == var2.lod_level and \ - var1.persistable == var2.persistable - - def __str__(self): - origin_var_str = "{name} : fluid.{type}.shape{shape}.astype({dtype})". \ - format(i="{", e="}", name=self.origin.name, type=self.origin.type, - shape=self.origin.shape, dtype=self.origin.dtype) - - slice_var_str = "{name} : fluid.{type}.shape{shape}.astype({dtype})" \ - ".slice({is_slice}).block({block_id}).offset({offset})". \ - format(i="{", e="}", name=self.slice.name, type=self.slice.type, - shape=self.slice.shape, dtype=self.slice.dtype, - is_slice=self.is_slice, block_id=self.block_id, offset=self.offset) - - return "var owned: {}, origin var: ( {} ), slice var: ( {} ), endpoint: {} ".format( - self.vtype, origin_var_str, slice_var_str, self.endpoint) - - -class VarsDistributed(object): - """ - a gather about VarDistributed with many methods to find distributed vars. - through the class, we can get overview about the distributed parameters on parameter servers. - this class may centralized and convenient for developer to manage and get variable's distribute. - other module can also use this to find variables such io.py. - """ - - def __init__(self): - self.distributed_vars = [] - - def add_distributed_var(self, - origin_var, - slice_var, - is_slice=None, - block_id=None, - offset=None, - vtype=None, - endpoint=None): - """ - add distributed var in this. - - Args: - origin_var(Variable|VarStruct): origin var properties - slice_var(Variable|VarStruct): slice var properties - is_slice(bool|None): slice or not, slice_var=True/False and its block size > 8192 are the judgement standard. - block_id(int|None): the number about the slice var. - offset(int|None): if the slice var is sliced, offset is the numel before the var. - vtype(str|None): a tag, such as Optimizer/Param/RemoteProfetch. - endpoint(str|None): which parameter the slice var on, such as "127.0.0.1:1001" - Returns: - None - """ - self.distributed_vars.append( - VarDistributed(origin_var, slice_var, is_slice, block_id, offset, - vtype, endpoint)) - - def get_distributed_var_by_slice(self, var_name): - """ - get distributed var by conditions. - - Args: - var_name(str): slice var name, such as "w.traier0.block1" - Returns: - VarDistributed: distributed var. - """ - for dist_var in self.distributed_vars: - if dist_var.slice.name == var_name: - return dist_var - return None - - @staticmethod - def equal(var1, var2): - """ - the two var is equal or not. - Returns: - bool: equal will return True else False - """ - return var1.name == var2.name and \ - var1.type == var2.type and \ - var1.shape == var2.shape and \ - var1.dtype == var2.dtype and \ - var1.lod_level == var2.lod_level and \ - var1.persistable == var2.persistable - - def get_distributed_var_by_origin_and_ep(self, origin_var_name, endpoint): - """ - get distributed var by conditions. - - Args: - origin_var_name(str): - endpoint(str): the parameter endpoint, such as "127.0.0.1:1001" - Returns: - VarDistributed: distributed var. - """ - for dist_var in self.distributed_vars: - if dist_var.origin.name == origin_var_name and dist_var.endpoint == endpoint: - return dist_var - return None - - def get_distributed_vars_by_vtypes(self, vtypes, groupby=False): - """ - get distributed vars by conditions. - - Args: - vtype(str|None): distributed var's vtype, such as "Optimizer", "RemotePrefetch" - groupby(bool|False): group by origin var or not. - - Returns: - list: distributed var list. - dict: distributed var map when groupby=True - """ - vtype_vars = [] - for var in self.distributed_vars: - if var.vtype in vtypes: - vtype_vars.append(var) - if not groupby: - return vtype_vars - - params_map = {} - for var in vtype_vars: - origin_var_name = var.origin.name - - if origin_var_name in params_map.keys(): - optimizers = params_map.get(origin_var_name) - else: - optimizers = [] - optimizers.append(var) - params_map[origin_var_name] = optimizers - return params_map - - def get_distributed_vars_by_ep(self, endpoint, vtype=None): - """ - get distributed vars by conditions. - - Args: - endpoint(str): the parameter server endpoint, such as "127.0.0.1:2001" - vtype(str|None): distributed var's vtype, such as "Optimizer", "RemotePrefetch" - - Returns: - list: distributed var list. - """ - endpoint_vars = [] - for var in self.distributed_vars: - if var.endpoint == endpoint: - endpoint_vars.append(var) - if not vtype: - return endpoint_vars - - vtype_vars = [] - for var in endpoint_vars: - if var.vtype == vtype: - vtype_vars.append(var) - return vtype_vars - - def overview(self): - """ - get the overview string about all params on all parameter servers. - - Returns: - Str: overview string. - - """ - vars_str = [] - for var in self.distributed_vars: - vars_str.append(str(var)) - return "\n".join(vars_str) - - class VarBlock: def __init__(self, varname, offset, size): self.varname = varname From 312500dcb509ff40d990f1180e92ff333dd37821 Mon Sep 17 00:00:00 2001 From: mozga-intel Date: Wed, 30 Jan 2019 07:51:26 +0100 Subject: [PATCH 52/53] Enable pool2d operator for a ngraph engine (#15395) * Enable pool2d operator for a ngraph engine test=develop * Update test=develop --- .../fluid/operators/ngraph/ngraph_bridge.cc | 2 + paddle/fluid/operators/ngraph/ngraph_ops.h | 1 + paddle/fluid/operators/ngraph/ops/pool2d_op.h | 174 ++++++++++++++++++ .../unittests/ngraph/test_pool2d_ngraph_op.py | 51 +++++ 4 files changed, 228 insertions(+) create mode 100644 paddle/fluid/operators/ngraph/ops/pool2d_op.h create mode 100644 python/paddle/fluid/tests/unittests/ngraph/test_pool2d_ngraph_op.py diff --git a/paddle/fluid/operators/ngraph/ngraph_bridge.cc b/paddle/fluid/operators/ngraph/ngraph_bridge.cc index d6e897ed46..13b168ce45 100644 --- a/paddle/fluid/operators/ngraph/ngraph_bridge.cc +++ b/paddle/fluid/operators/ngraph/ngraph_bridge.cc @@ -38,6 +38,8 @@ std::map +#include + +#include "ngraph/ngraph.hpp" +#include "paddle/fluid/platform/ngraph_helper.h" + +namespace paddle { +namespace operators { +namespace ngraphs { + +void BuildPool2dNode( + const std::shared_ptr& op, + std::shared_ptr< + std::unordered_map>> + ngb_node_map) { + auto op_attrs = paddle::framework::AttrReader(op->Attrs()); + auto x = paddle::platform::GetInputNode(op, "X", ngb_node_map); + auto x_shape = x->get_shape(); + + std::string pooling_type = op_attrs.Get("pooling_type"); + std::vector ksize = op_attrs.Get>("ksize"); + std::vector strides = op_attrs.Get>("strides"); + std::vector paddings = op_attrs.Get>("paddings"); + + PADDLE_ENFORCE_EQ(x_shape.size() - 2, ksize.size(), + "Handling 2d pooling only"); + + if (op_attrs.Get("global_pooling")) { + for (size_t i = 0; i < ksize.size(); ++i) { + paddings[i] = 0; + ksize[i] = static_cast(x_shape.at(i + 2)); + } + } + + ngraph::Shape ng_padding_below{static_cast(paddings.at(0)), + static_cast(paddings.at(1))}; + ngraph::Shape ng_padding_above{static_cast(paddings.at(0)), + static_cast(paddings.at(1))}; + ngraph::Shape ng_ksize_shape{static_cast(ksize.at(0)), + static_cast(ksize.at(1))}; + ngraph::Strides ng_strides{static_cast(strides.at(0)), + static_cast(strides.at(1))}; + + auto ComputeCeiledOutput = [](size_t in, size_t k, size_t p, size_t s) { + return (in - k + 2 * p) / s + 1; + }; + + if (op_attrs.Get("ceil_mode")) { + auto dummy_out = paddle::platform::GetOutputNode(op, "Out", ngb_node_map); + auto dummpy_shape = dummy_out->get_shape(); + for (size_t i = 0; i < ng_padding_above.size(); ++i) { + auto desired_size = ComputeCeiledOutput(x_shape[i + 2], ksize[i], + paddings[i], strides[i]); + if (desired_size != dummpy_shape[i + 2]) { + ng_padding_above[i] += strides[i]; + } + } + } + + bool padding_exclusive = op_attrs.Get("exclusive"); + if (pooling_type == "max") { + auto pool2d = std::make_shared( + x, ng_ksize_shape, ng_strides, ng_padding_below, ng_padding_above); + paddle::platform::SetOutputNode(op, "Out", pool2d, ngb_node_map); + } else if (pooling_type == "avg") { + std::shared_ptr pool2d; + if (op_attrs.Get("adaptive")) { + auto ComputeAdaptive = [](size_t in, size_t k) { + return std::floor(in / k); + }; + ng_strides[0] = x_shape.size() == 4 + ? ComputeAdaptive(x_shape[3], ksize[0]) + : ng_strides[0]; + ng_strides[1] = x_shape.size() == 4 + ? ComputeAdaptive(x_shape[3], ksize[0]) + : ng_strides[1]; + pool2d = + std::make_shared(x, ng_ksize_shape, ng_strides); + } else { + pool2d = std::make_shared( + x, ng_ksize_shape, ng_strides, ng_padding_below, ng_padding_above, + !padding_exclusive); + } + paddle::platform::SetOutputNode(op, "Out", pool2d, ngb_node_map); + } else { + PADDLE_THROW("Support max and avg pooling only"); + } +} + +void BuildPool2dGradNode( + const std::shared_ptr& op, + std::shared_ptr< + std::unordered_map>> + ngb_node_map) { + auto op_attrs = paddle::framework::AttrReader(op->Attrs()); + auto out = paddle::platform::GetInputNode(op, "Out", ngb_node_map); + auto dout = paddle::platform::GetInputNode(op, "Out@GRAD", ngb_node_map); + auto x = paddle::platform::GetInputNode(op, "X", ngb_node_map); + auto x_shape = x->get_shape(); + + std::string pooling_type = op_attrs.Get("pooling_type"); + std::vector ksize = op_attrs.Get>("ksize"); + std::vector strides = op_attrs.Get>("strides"); + std::vector paddings = op_attrs.Get>("paddings"); + + PADDLE_ENFORCE_EQ(x_shape.size() - 2, ksize.size(), + "Handling 2d pooling only"); + + if (op_attrs.Get("global_pooling")) { + for (size_t i = 0; i < ksize.size(); ++i) { + paddings[i] = 0; + ksize[i] = static_cast(x_shape.at(i + 2)); + } + } + + ngraph::Shape ng_padding_below{static_cast(paddings.at(0)), + static_cast(paddings.at(1))}; + ngraph::Shape ng_padding_above{static_cast(paddings.at(0)), + static_cast(paddings.at(1))}; + ngraph::Shape ng_ksize_shape{static_cast(ksize.at(0)), + static_cast(ksize.at(1))}; + ngraph::Strides ng_strides{static_cast(strides.at(0)), + static_cast(strides.at(1))}; + + bool padding_exclusive = op_attrs.Get("exclusive"); + if (pooling_type == "max") { + auto pool2d_grad = std::make_shared( + x, dout, out, ng_ksize_shape, ng_strides, ng_padding_below, + ng_padding_above); + paddle::platform::SetOutputNode(op, "X@GRAD", pool2d_grad, ngb_node_map); + } else if (pooling_type == "avg") { + std::shared_ptr pool2d_grad; + if (op_attrs.Get("adaptive")) { + auto ComputeAdaptive = [](size_t in, size_t k) { + return std::floor(in / k); + }; + ng_strides[0] = x_shape.size() == 4 + ? ComputeAdaptive(x_shape[3], ksize[0]) + : ng_strides[0]; + ng_strides[1] = x_shape.size() == 4 + ? ComputeAdaptive(x_shape[3], ksize[0]) + : ng_strides[1]; + pool2d_grad = std::make_shared( + x->get_shape(), dout, ng_ksize_shape, ng_strides, ng_padding_below, + ng_padding_above, !padding_exclusive); + } else { + pool2d_grad = std::make_shared( + x->get_shape(), dout, ng_ksize_shape, ng_strides, ng_padding_below, + ng_padding_above, !padding_exclusive); + } + paddle::platform::SetOutputNode(op, "X@GRAD", pool2d_grad, ngb_node_map); + } else { + PADDLE_THROW("Support max and avg pooling only"); + } +} +} // namespace ngraphs +} // namespace operators +} // namespace paddle diff --git a/python/paddle/fluid/tests/unittests/ngraph/test_pool2d_ngraph_op.py b/python/paddle/fluid/tests/unittests/ngraph/test_pool2d_ngraph_op.py new file mode 100644 index 0000000000..95e592e8ec --- /dev/null +++ b/python/paddle/fluid/tests/unittests/ngraph/test_pool2d_ngraph_op.py @@ -0,0 +1,51 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +from paddle.fluid.tests.unittests.test_pool2d_op import * + + +class TestNGRAPHPool2D_Op(TestPool2D_Op): + def init_test_case(self): + super(TestNGRAPHPool2D_Op, self).init_test_case() + + +class TestNGRAPHCase1(TestCase1): + def init_test_case(self): + super(TestNGRAPHCase1, self).init_test_case() + + +class TestNGRAPHCase2(TestCase2): + def init_test_case(self): + super(TestNGRAPHCase2, self).init_test_case() + + +class TestNGRAPHCase3(TestCase3): + def init_pool_type(self): + super(TestNGRAPHCase3, self).init_pool_type() + + +class TestNGRAPHCase4(TestCase4): + def init_pool_type(self): + super(TestNGRAPHCase4, self).init_pool_type() + + +class TestNGRAPHCase5(TestCase5): + def init_pool_type(self): + super(TestNGRAPHCase5, self).init_pool_type() + + +if __name__ == '__main__': + unittest.main() From 1b8047b712c58b751b627faff486a613e2058bf5 Mon Sep 17 00:00:00 2001 From: Haihao Shen Date: Wed, 30 Jan 2019 14:57:24 +0800 Subject: [PATCH 53/53] Add INT8 calibration support in Paddle package (#15569) * Add INT8 calibration support in Paddle package; test=develop --- paddle/fluid/API.spec | 3 +++ python/paddle/fluid/contrib/__init__.py | 3 +++ .../fluid/contrib/int8_inference/__init__.py | 7 +++++++ .../fluid/contrib/int8_inference/utility.py | 17 ++++++++++------- .../fluid/contrib/tests/test_calibration.py | 3 +-- python/setup.py.in | 1 + 6 files changed, 25 insertions(+), 9 deletions(-) diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index fe8d6dd425..b793bb23fc 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -361,6 +361,9 @@ paddle.fluid.contrib.QuantizeTranspiler.__init__ ArgSpec(args=['self', 'weight_b paddle.fluid.contrib.QuantizeTranspiler.convert_to_int8 ArgSpec(args=['self', 'program', 'place', 'scope'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.contrib.QuantizeTranspiler.freeze_program ArgSpec(args=['self', 'program', 'place', 'fuse_bn', 'scope'], varargs=None, keywords=None, defaults=(False, None)) paddle.fluid.contrib.QuantizeTranspiler.training_transpile ArgSpec(args=['self', 'program', 'startup_program'], varargs=None, keywords=None, defaults=(None, None)) +paddle.fluid.contrib.Calibrator.__init__ ArgSpec(args=['self'], varargs='args', keywords='kwargs', defaults=None) +paddle.fluid.contrib.Calibrator.sample_data ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None) +paddle.fluid.contrib.Calibrator.save_int8_model ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None) paddle.fluid.contrib.reader.ctr_reader.ctr_reader ArgSpec(args=['feed_dict', 'file_type', 'file_format', 'dense_slot_index', 'sparse_slot_index', 'capacity', 'thread_num', 'batch_size', 'file_list', 'slots', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.contrib.build_compressor ArgSpec(args=['place', 'data_reader', 'data_feeder', 'scope', 'metrics', 'epoch', 'config'], varargs=None, keywords=None, defaults=(None, None, None, None, None, None, None)) paddle.fluid.contrib.CompressPass.__init__ ArgSpec(args=['self', 'place', 'data_reader', 'data_feeder', 'scope', 'metrics', 'epoch', 'program_exe'], varargs=None, keywords=None, defaults=(None, None, None, None, None, None, None)) diff --git a/python/paddle/fluid/contrib/__init__.py b/python/paddle/fluid/contrib/__init__.py index 6127ca8a3e..870c57e540 100644 --- a/python/paddle/fluid/contrib/__init__.py +++ b/python/paddle/fluid/contrib/__init__.py @@ -22,6 +22,8 @@ from . import op_frequence from .op_frequence import * from . import quantize from .quantize import * +from . import int8_inference +from .int8_inference import * from . import reader from .reader import * from . import slim @@ -34,6 +36,7 @@ __all__ += decoder.__all__ __all__ += memory_usage_calc.__all__ __all__ += op_frequence.__all__ __all__ += quantize.__all__ +__all__ += int8_inference.__all__ __all__ += reader.__all__ __all__ += slim.__all__ __all__ += utils.__all__ diff --git a/python/paddle/fluid/contrib/int8_inference/__init__.py b/python/paddle/fluid/contrib/int8_inference/__init__.py index eca2dce114..45547201d5 100644 --- a/python/paddle/fluid/contrib/int8_inference/__init__.py +++ b/python/paddle/fluid/contrib/int8_inference/__init__.py @@ -11,3 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from __future__ import print_function + +from . import utility +from .utility import * + +__all__ = utility.__all__ diff --git a/python/paddle/fluid/contrib/int8_inference/utility.py b/python/paddle/fluid/contrib/int8_inference/utility.py index 40de038f28..b35d9f2424 100644 --- a/python/paddle/fluid/contrib/int8_inference/utility.py +++ b/python/paddle/fluid/contrib/int8_inference/utility.py @@ -11,11 +11,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import paddle.fluid.core as core + +from paddle.fluid import core import numpy as np import math import os -import paddle.fluid as fluid +from paddle.fluid.executor import global_scope +from paddle.fluid import io + +__all__ = ['Calibrator'] class Calibrator(object): @@ -76,8 +80,7 @@ class Calibrator(object): ''' for i in self.sampling_program.list_vars(): if i.name in self.sampling_vars: - np_data = np.array(fluid.global_scope().find_var(i.name) - .get_tensor()) + np_data = np.array(global_scope().find_var(i.name).get_tensor()) if i.name not in self._sampling_data: self._sampling_data[i.name] = [] self._sampling_data[i.name].append(np_data) @@ -86,9 +89,9 @@ class Calibrator(object): ''' Save the quantized model to the disk. ''' - fluid.io.save_inference_model(self.output, self.feed_var_names, - self.fetch_list, self.exe, - self.sampling_program) + io.save_inference_model(self.output, self.feed_var_names, + self.fetch_list, self.exe, + self.sampling_program) def __display_debug(self): if self.debug: diff --git a/python/paddle/fluid/contrib/tests/test_calibration.py b/python/paddle/fluid/contrib/tests/test_calibration.py index cd6b7ba166..424ea245a0 100644 --- a/python/paddle/fluid/contrib/tests/test_calibration.py +++ b/python/paddle/fluid/contrib/tests/test_calibration.py @@ -24,8 +24,7 @@ import contextlib from paddle.dataset.common import download from PIL import Image, ImageEnhance import math -sys.path.append('..') -import int8_inference.utility as int8_utility +import paddle.fluid.contrib.int8_inference.utility as int8_utility random.seed(0) np.random.seed(0) diff --git a/python/setup.py.in b/python/setup.py.in index c947785cbf..f93f0cd130 100644 --- a/python/setup.py.in +++ b/python/setup.py.in @@ -109,6 +109,7 @@ packages=['paddle', 'paddle.fluid.contrib', 'paddle.fluid.contrib.decoder', 'paddle.fluid.contrib.quantize', + 'paddle.fluid.contrib.int8_inference', 'paddle.fluid.contrib.reader', 'paddle.fluid.contrib.slim', 'paddle.fluid.contrib.slim.core',