From 10dd3b37ad26660bbd9c52c111039688e6b063b5 Mon Sep 17 00:00:00 2001 From: jerrywgz Date: Thu, 17 Jan 2019 12:13:34 +0000 Subject: [PATCH 1/9] add axis for box coder op --- paddle/fluid/API.spec | 2 +- .../fluid/operators/detection/box_coder_op.cc | 40 +++- .../fluid/operators/detection/box_coder_op.cu | 83 ++++++--- .../fluid/operators/detection/box_coder_op.h | 76 +++++--- python/paddle/fluid/layers/detection.py | 9 +- .../tests/unittests/test_box_coder_op.py | 176 ++++++++++++++---- 6 files changed, 282 insertions(+), 104 deletions(-) diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index 50ffef72ba..7068a37ef0 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -315,7 +315,7 @@ paddle.fluid.layers.roi_perspective_transform ArgSpec(args=['input', 'rois', 'tr paddle.fluid.layers.generate_proposal_labels ArgSpec(args=['rpn_rois', 'gt_classes', 'is_crowd', 'gt_boxes', 'im_info', 'batch_size_per_im', 'fg_fraction', 'fg_thresh', 'bg_thresh_hi', 'bg_thresh_lo', 'bbox_reg_weights', 'class_nums', 'use_random'], varargs=None, keywords=None, defaults=(256, 0.25, 0.25, 0.5, 0.0, [0.1, 0.1, 0.2, 0.2], None, True)) paddle.fluid.layers.generate_proposals ArgSpec(args=['scores', 'bbox_deltas', 'im_info', 'anchors', 'variances', 'pre_nms_top_n', 'post_nms_top_n', 'nms_thresh', 'min_size', 'eta', 'name'], varargs=None, keywords=None, defaults=(6000, 1000, 0.5, 0.1, 1.0, None)) paddle.fluid.layers.iou_similarity ArgSpec(args=['x', 'y', 'name'], varargs=None, keywords=None, defaults=(None,)) -paddle.fluid.layers.box_coder ArgSpec(args=['prior_box', 'prior_box_var', 'target_box', 'code_type', 'box_normalized', 'name'], varargs=None, keywords=None, defaults=('encode_center_size', True, None)) +paddle.fluid.layers.box_coder ArgSpec(args=['prior_box', 'prior_box_var', 'target_box', 'code_type', 'box_normalized', 'axis', 'name'], varargs=None, keywords=None, defaults=('encode_center_size', True, 0, None)) paddle.fluid.layers.polygon_box_transform ArgSpec(args=['input', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.yolov3_loss ArgSpec(args=['x', 'gtbox', 'gtlabel', 'anchors', 'class_num', 'ignore_thresh', 'loss_weight_xy', 'loss_weight_wh', 'loss_weight_conf_target', 'loss_weight_conf_notarget', 'loss_weight_class', 'name'], varargs=None, keywords=None, defaults=(None, None, None, None, None, None)) paddle.fluid.layers.accuracy ArgSpec(args=['input', 'label', 'k', 'correct', 'total'], varargs=None, keywords=None, defaults=(1, None, None)) diff --git a/paddle/fluid/operators/detection/box_coder_op.cc b/paddle/fluid/operators/detection/box_coder_op.cc index 06fbb9815c..5db600b19a 100644 --- a/paddle/fluid/operators/detection/box_coder_op.cc +++ b/paddle/fluid/operators/detection/box_coder_op.cc @@ -32,31 +32,53 @@ class BoxCoderOp : public framework::OperatorWithKernel { if (ctx->IsRuntime()) { PADDLE_ENFORCE_EQ(prior_box_dims.size(), 2, - "The rank of Input of PriorBoxVar must be 2"); + "The rank of Input of PriorBox must be 2"); PADDLE_ENFORCE_EQ(prior_box_dims[1], 4, "The shape of PriorBox is [N, 4]"); if (ctx->HasInput("PriorBoxVar")) { auto prior_box_var_dims = ctx->GetInputDim("PriorBoxVar"); - PADDLE_ENFORCE_EQ(prior_box_dims, prior_box_var_dims); + PADDLE_ENFORCE( + prior_box_var_dims.size() == 1 || prior_box_var_dims.size() == 2, + "Input(PriorBoxVar) of BoxCoderOp should be 1 or 2."); + if (prior_box_var_dims.size() == 1) { + PADDLE_ENFORCE_EQ( + prior_box_var_dims[0], 4, + "The 1st dimension of Input(PriorBoxVar) should be 1" + "when the rank is 1."); + } else { + PADDLE_ENFORCE_EQ( + prior_box_dims, prior_box_var_dims, + "The dimension of Input(PriorBoxVar) should be equal to" + "the dimension of Input(PriorBox when the rank is 2.)"); + } } auto code_type = GetBoxCodeType(ctx->Attrs().Get("code_type")); + int axis = ctx->Attrs().Get("axis"); if (code_type == BoxCodeType::kEncodeCenterSize) { PADDLE_ENFORCE_EQ(target_box_dims.size(), 2, "The rank of Input of TargetBox must be 2"); PADDLE_ENFORCE_EQ(target_box_dims[1], 4, "The shape of TargetBox is [M, 4]"); + ctx->SetOutputDim( + "OutputBox", + framework::make_ddim({target_box_dims[0], prior_box_dims[0], 4})); } else if (code_type == BoxCodeType::kDecodeCenterSize) { PADDLE_ENFORCE_EQ(target_box_dims.size(), 3, "The rank of Input of TargetBox must be 3"); - PADDLE_ENFORCE_EQ(target_box_dims[1], prior_box_dims[0]); + if (axis == 0) { + PADDLE_ENFORCE_EQ(target_box_dims[1], prior_box_dims[0]); + } else if (axis == 1) { + PADDLE_ENFORCE_EQ(target_box_dims[0], prior_box_dims[0]); + } else { + PADDLE_THROW("axis must be 0 or 1."); + } PADDLE_ENFORCE_EQ(target_box_dims[2], prior_box_dims[1]); + ctx->ShareDim("TargetBox", /*->*/ "OutputBox"); } } - ctx->SetOutputDim( - "OutputBox", - framework::make_ddim({target_box_dims[0], prior_box_dims[0], 4})); + ctx->ShareLoD("TargetBox", /*->*/ "OutputBox"); } }; @@ -100,6 +122,12 @@ class BoxCoderOpMaker : public framework::OpProtoAndCheckerMaker { "(bool, default true) " "whether treat the priorbox as a noramlized box") .SetDefault(true); + AddAttr("axis", + "(int, default 1)" + "which axis to broadcast for box decode, it is only valid" + "when code type is decode_center_size") + .SetDefault(0) + .InEnum({0, 1}); AddOutput("OutputBox", "(LoDTensor or Tensor) " "When code_type is 'encode_center_size', the output tensor of " diff --git a/paddle/fluid/operators/detection/box_coder_op.cu b/paddle/fluid/operators/detection/box_coder_op.cu index a7af111f63..ca62afd8ed 100644 --- a/paddle/fluid/operators/detection/box_coder_op.cu +++ b/paddle/fluid/operators/detection/box_coder_op.cu @@ -20,7 +20,8 @@ __global__ void EncodeCenterSizeKernel(const T* prior_box_data, const T* prior_box_var_data, const T* target_box_data, const int row, const int col, const int len, - const bool normalized, T* output) { + const bool normalized, + const T prior_box_var_size, T* output) { const int idx = threadIdx.x + blockIdx.x * blockDim.x; if (idx < row * col) { const int row_idx = idx / col; @@ -30,11 +31,9 @@ __global__ void EncodeCenterSizeKernel(const T* prior_box_data, T prior_box_height = prior_box_data[col_idx * len + 3] - prior_box_data[col_idx * len + 1] + (normalized == false); - T prior_box_center_x = - (prior_box_data[col_idx * len + 2] + prior_box_data[col_idx * len]) / 2; - T prior_box_center_y = (prior_box_data[col_idx * len + 3] + - prior_box_data[col_idx * len + 1]) / - 2; + T prior_box_center_x = prior_box_data[col_idx * len] + prior_box_width / 2; + T prior_box_center_y = + prior_box_data[col_idx * len + 1] + prior_box_height / 2; T target_box_center_x = (target_box_data[row_idx * len + 2] + target_box_data[row_idx * len]) / @@ -55,10 +54,14 @@ __global__ void EncodeCenterSizeKernel(const T* prior_box_data, output[idx * len + 2] = log(fabs(target_box_width / prior_box_width)); output[idx * len + 3] = log(fabs(target_box_height / prior_box_height)); if (prior_box_var_data) { - output[idx * len] /= prior_box_var_data[col_idx * len]; - output[idx * len + 1] /= prior_box_var_data[col_idx * len + 1]; - output[idx * len + 2] /= prior_box_var_data[col_idx * len + 2]; - output[idx * len + 3] /= prior_box_var_data[col_idx * len + 3]; + int prior_var_offset = 0; + if (prior_box_var_size == 2) { + prior_var_offset = col_idx * len; + } + output[idx * len] /= prior_box_var_data[prior_var_offset]; + output[idx * len + 1] /= prior_box_var_data[prior_var_offset + 1]; + output[idx * len + 2] /= prior_box_var_data[prior_var_offset + 2]; + output[idx * len + 3] /= prior_box_var_data[prior_var_offset + 3]; } } } @@ -68,33 +71,48 @@ __global__ void DecodeCenterSizeKernel(const T* prior_box_data, const T* prior_box_var_data, const T* target_box_data, const int row, const int col, const int len, - const bool normalized, T* output) { + const bool normalized, + const T prior_box_var_size, + const int axis, T* output) { const int idx = threadIdx.x + blockIdx.x * blockDim.x; + int prior_box_offset = 0; if (idx < row * col) { const int col_idx = idx % col; - T prior_box_width = prior_box_data[col_idx * len + 2] - - prior_box_data[col_idx * len] + (normalized == false); - T prior_box_height = prior_box_data[col_idx * len + 3] - - prior_box_data[col_idx * len + 1] + + const int row_idx = idx / col; + if (axis == 0) + prior_box_offset = col_idx * len; + else if (axis == 1) + prior_box_offset = row_idx * len; + T prior_box_width = prior_box_data[prior_box_offset + 2] - + prior_box_data[prior_box_offset] + + (normalized == false); + T prior_box_height = prior_box_data[prior_box_offset + 3] - + prior_box_data[prior_box_offset + 1] + (normalized == false); T prior_box_center_x = - (prior_box_data[col_idx * len + 2] + prior_box_data[col_idx * len]) / 2; - T prior_box_center_y = (prior_box_data[col_idx * len + 3] + - prior_box_data[col_idx * len + 1]) / - 2; + prior_box_data[prior_box_offset] + prior_box_width / 2; + T prior_box_center_y = + prior_box_data[prior_box_offset + 1] + prior_box_height / 2; T target_box_width, target_box_height; T target_box_center_x, target_box_center_y; if (prior_box_var_data) { - target_box_width = exp(prior_box_var_data[col_idx * len + 2] * + int prior_var_offset = 0; + if (prior_box_var_size == 2) { + if (axis == 0) + prior_var_offset = col_idx * len; + else if (axis == 1) + prior_var_offset = row_idx * len; + } + target_box_width = exp(prior_box_var_data[prior_var_offset + 2] * target_box_data[idx * len + 2]) * prior_box_width; - target_box_height = exp(prior_box_var_data[col_idx * len + 3] * + target_box_height = exp(prior_box_var_data[prior_var_offset + 3] * target_box_data[idx * len + 3]) * prior_box_height; - target_box_center_x = prior_box_var_data[col_idx * len] * + target_box_center_x = prior_box_var_data[prior_var_offset] * target_box_data[idx * len] * prior_box_width + prior_box_center_x; - target_box_center_y = prior_box_var_data[col_idx * len + 1] * + target_box_center_y = prior_box_var_data[prior_var_offset + 1] * target_box_data[idx * len + 1] * prior_box_height + prior_box_center_y; @@ -131,14 +149,25 @@ class BoxCoderCUDAKernel : public framework::OpKernel { const T* prior_box_data = prior_box->data(); const T* target_box_data = target_box->data(); const T* prior_box_var_data = nullptr; - if (prior_box_var) prior_box_var_data = prior_box_var->data(); + auto prior_box_var_size = 0; + if (prior_box_var) { + prior_box_var_data = prior_box_var->data(); + prior_box_var_size = prior_box_var->dims().size(); + } if (target_box->lod().size()) { PADDLE_ENFORCE_EQ(target_box->lod().size(), 1, "Only support 1 level of LoD."); } + auto code_type = GetBoxCodeType(context.Attr("code_type")); + bool normalized = context.Attr("box_normalized"); + int axis = context.Attr("axis"); + auto row = target_box->dims()[0]; auto col = prior_box->dims()[0]; + if (code_type == BoxCodeType::kDecodeCenterSize) { + col = target_box->dims()[1]; + } auto len = prior_box->dims()[1]; int block = 512; int grid = (row * col + block - 1) / block; @@ -147,16 +176,14 @@ class BoxCoderCUDAKernel : public framework::OpKernel { output_box->mutable_data({row, col, len}, context.GetPlace()); T* output = output_box->data(); - auto code_type = GetBoxCodeType(context.Attr("code_type")); - bool normalized = context.Attr("box_normalized"); if (code_type == BoxCodeType::kEncodeCenterSize) { EncodeCenterSizeKernel<<>>( prior_box_data, prior_box_var_data, target_box_data, row, col, len, - normalized, output); + normalized, prior_box_var_size, output); } else if (code_type == BoxCodeType::kDecodeCenterSize) { DecodeCenterSizeKernel<<>>( prior_box_data, prior_box_var_data, target_box_data, row, col, len, - normalized, output); + normalized, prior_box_var_size, axis, output); } } }; diff --git a/paddle/fluid/operators/detection/box_coder_op.h b/paddle/fluid/operators/detection/box_coder_op.h index b2a2bcdce9..986869d8a3 100644 --- a/paddle/fluid/operators/detection/box_coder_op.h +++ b/paddle/fluid/operators/detection/box_coder_op.h @@ -53,10 +53,9 @@ class BoxCoderKernel : public framework::OpKernel { T prior_box_height = prior_box_data[j * len + 3] - prior_box_data[j * len + 1] + (normalized == false); - T prior_box_center_x = - (prior_box_data[j * len + 2] + prior_box_data[j * len]) / 2; + T prior_box_center_x = prior_box_data[j * len] + prior_box_width / 2; T prior_box_center_y = - (prior_box_data[j * len + 3] + prior_box_data[j * len + 1]) / 2; + prior_box_data[j * len + 1] + prior_box_height / 2; T target_box_center_x = (target_box_data[i * len + 2] + target_box_data[i * len]) / 2; @@ -78,10 +77,14 @@ class BoxCoderKernel : public framework::OpKernel { output[offset + 3] = std::log(std::fabs(target_box_height / prior_box_height)); if (prior_box_var) { - output[offset] /= prior_box_var_data[j * len]; - output[offset + 1] /= prior_box_var_data[j * len + 1]; - output[offset + 2] /= prior_box_var_data[j * len + 2]; - output[offset + 3] /= prior_box_var_data[j * len + 3]; + int prior_var_offset = 0; + if (prior_box_var->dims().size() == 2) { + prior_var_offset = j * len; + } + output[offset] /= prior_box_var_data[prior_var_offset]; + output[offset + 1] /= prior_box_var_data[prior_var_offset + 1]; + output[offset + 2] /= prior_box_var_data[prior_var_offset + 2]; + output[offset + 3] /= prior_box_var_data[prior_var_offset + 3]; } } } @@ -89,48 +92,63 @@ class BoxCoderKernel : public framework::OpKernel { void DecodeCenterSize(const framework::Tensor* target_box, const framework::Tensor* prior_box, const framework::Tensor* prior_box_var, - const bool normalized, T* output) const { + const bool normalized, const int axis, + T* output) const { int64_t row = target_box->dims()[0]; - int64_t col = prior_box->dims()[0]; - int64_t len = prior_box->dims()[1]; + int64_t col = target_box->dims()[1]; + int64_t len = target_box->dims()[2]; auto* target_box_data = target_box->data(); auto* prior_box_data = prior_box->data(); const T* prior_box_var_data = nullptr; if (prior_box_var) prior_box_var_data = prior_box_var->data(); - + int prior_box_offset = 0; #ifdef PADDLE_WITH_MKLML #pragma omp parallel for collapse(2) #endif for (int64_t i = 0; i < row; ++i) { for (int64_t j = 0; j < col; ++j) { size_t offset = i * col * len + j * len; - T prior_box_width = prior_box_data[j * len + 2] - - prior_box_data[j * len] + (normalized == false); - T prior_box_height = prior_box_data[j * len + 3] - - prior_box_data[j * len + 1] + + if (axis == 0) { + prior_box_offset = j * len; + } else if (axis == 1) { + prior_box_offset = i * len; + } + T prior_box_width = prior_box_data[prior_box_offset + 2] - + prior_box_data[prior_box_offset] + + (normalized == false); + T prior_box_height = prior_box_data[prior_box_offset + 3] - + prior_box_data[prior_box_offset + 1] + (normalized == false); T prior_box_center_x = - (prior_box_data[j * len + 2] + prior_box_data[j * len]) / 2; + prior_box_data[prior_box_offset] + prior_box_width / 2; T prior_box_center_y = - (prior_box_data[j * len + 3] + prior_box_data[j * len + 1]) / 2; + prior_box_data[prior_box_offset + 1] + prior_box_height / 2; T target_box_center_x = 0, target_box_center_y = 0; T target_box_width = 0, target_box_height = 0; if (prior_box_var) { - target_box_center_x = prior_box_var_data[j * len] * + int prior_var_offset = 0; + if (prior_box_var->dims().size() == 2) { + if (axis == 0) + prior_var_offset = j * len; + else if (axis == 1) + prior_var_offset = i * len; + } + target_box_center_x = prior_box_var_data[prior_var_offset] * target_box_data[offset] * prior_box_width + prior_box_center_x; - target_box_center_y = prior_box_var_data[j * len + 1] * + target_box_center_y = prior_box_var_data[prior_var_offset + 1] * target_box_data[offset + 1] * prior_box_height + prior_box_center_y; - target_box_width = std::exp(prior_box_var_data[j * len + 2] * + target_box_width = std::exp(prior_box_var_data[prior_var_offset + 2] * target_box_data[offset + 2]) * prior_box_width; - target_box_height = std::exp(prior_box_var_data[j * len + 3] * - target_box_data[offset + 3]) * - prior_box_height; + target_box_height = + std::exp(prior_box_var_data[prior_var_offset + 3] * + target_box_data[offset + 3]) * + prior_box_height; } else { target_box_center_x = target_box_data[offset] * prior_box_width + prior_box_center_x; @@ -157,25 +175,29 @@ class BoxCoderKernel : public framework::OpKernel { auto* prior_box_var = context.Input("PriorBoxVar"); auto* target_box = context.Input("TargetBox"); auto* output_box = context.Output("OutputBox"); - + const int axis = context.Attr("axis"); if (target_box->lod().size()) { PADDLE_ENFORCE_EQ(target_box->lod().size(), 1UL, "Only support 1 level of LoD."); } + auto code_type = GetBoxCodeType(context.Attr("code_type")); + bool normalized = context.Attr("box_normalized"); + auto row = target_box->dims()[0]; auto col = prior_box->dims()[0]; + if (code_type == BoxCodeType::kDecodeCenterSize) { + col = target_box->dims()[1]; + } auto len = prior_box->dims()[1]; output_box->mutable_data({row, col, len}, context.GetPlace()); - auto code_type = GetBoxCodeType(context.Attr("code_type")); - bool normalized = context.Attr("box_normalized"); T* output = output_box->data(); if (code_type == BoxCodeType::kEncodeCenterSize) { EncodeCenterSize(target_box, prior_box, prior_box_var, normalized, output); } else if (code_type == BoxCodeType::kDecodeCenterSize) { - DecodeCenterSize(target_box, prior_box, prior_box_var, normalized, + DecodeCenterSize(target_box, prior_box, prior_box_var, normalized, axis, output); } } diff --git a/python/paddle/fluid/layers/detection.py b/python/paddle/fluid/layers/detection.py index 8aed97dc59..c844050c5d 100644 --- a/python/paddle/fluid/layers/detection.py +++ b/python/paddle/fluid/layers/detection.py @@ -342,6 +342,7 @@ def box_coder(prior_box, target_box, code_type="encode_center_size", box_normalized=True, + axis=0, name=None): """ ${comment} @@ -352,6 +353,7 @@ def box_coder(prior_box, target_box(${target_box_type}): ${target_box_comment} code_type(${code_type_type}): ${code_type_comment} box_normalized(${box_normalized_type}): ${box_normalized_comment} + axis(${axis_type}): ${axis_comment} Returns: output_box(${output_box_type}): ${output_box_comment} @@ -372,8 +374,11 @@ def box_coder(prior_box, "PriorBoxVar": prior_box_var, "TargetBox": target_box }, - attrs={"code_type": code_type, - "box_normalized": box_normalized}, + attrs={ + "code_type": code_type, + "box_normalized": box_normalized, + "axis": axis + }, outputs={"OutputBox": output_box}) return output_box diff --git a/python/paddle/fluid/tests/unittests/test_box_coder_op.py b/python/paddle/fluid/tests/unittests/test_box_coder_op.py index 2511c5c22e..b6f6bc1450 100644 --- a/python/paddle/fluid/tests/unittests/test_box_coder_op.py +++ b/python/paddle/fluid/tests/unittests/test_box_coder_op.py @@ -21,22 +21,32 @@ import math from op_test import OpTest -def box_coder(target_box, prior_box, prior_box_var, output_box, code_type, - box_normalized): - prior_box_x = ( - (prior_box[:, 2] + prior_box[:, 0]) / 2).reshape(1, prior_box.shape[0]) - prior_box_y = ( - (prior_box[:, 3] + prior_box[:, 1]) / 2).reshape(1, prior_box.shape[0]) - prior_box_width = ( - (prior_box[:, 2] - prior_box[:, 0])).reshape(1, prior_box.shape[0]) - prior_box_height = ( - (prior_box[:, 3] - prior_box[:, 1])).reshape(1, prior_box.shape[0]) - prior_box_var = prior_box_var.reshape(1, prior_box_var.shape[0], - prior_box_var.shape[1]) - if not box_normalized: - prior_box_height = prior_box_height + 1 - prior_box_width = prior_box_width + 1 - +def box_coder(target_box, + prior_box, + prior_box_var, + output_box, + code_type, + box_normalized, + axis=0): + prior_box_width = prior_box[:, 2] - prior_box[:, 0] + \ + (box_normalized==False) + prior_box_height = prior_box[:, 3] - prior_box[:, 1] + \ + (box_normalized==False) + prior_box_x = prior_box_width * 0.5 + prior_box[:, 0] + prior_box_y = prior_box_height * 0.5 + prior_box[:, 1] + if axis == 0: + prior_box_width = prior_box_width.reshape(1, prior_box.shape[0]) + prior_box_height = prior_box_height.reshape(1, prior_box.shape[0]) + prior_box_x = prior_box_x.reshape(1, prior_box.shape[0]) + prior_box_y = prior_box_y.reshape(1, prior_box.shape[0]) + else: + prior_box_width = prior_box_width.reshape(prior_box.shape[0], 1) + prior_box_height = prior_box_height.reshape(prior_box.shape[0], 1) + prior_box_x = prior_box_x.reshape(prior_box.shape[0], 1) + prior_box_y = prior_box_y.reshape(prior_box.shape[0], 1) + if prior_box_var.ndim == 2: + prior_box_var = prior_box_var.reshape(1, prior_box_var.shape[0], + prior_box_var.shape[1]) if (code_type == "EncodeCenterSize"): target_box_x = ((target_box[:, 2] + target_box[:, 0]) / 2).reshape( target_box.shape[0], 1) @@ -49,26 +59,52 @@ def box_coder(target_box, prior_box, prior_box_var, output_box, code_type, if not box_normalized: target_box_height = target_box_height + 1 target_box_width = target_box_width + 1 - - output_box[:,:,0] = (target_box_x - prior_box_x) / prior_box_width / \ - prior_box_var[:,:,0] - output_box[:,:,1] = (target_box_y - prior_box_y) / prior_box_height / \ - prior_box_var[:,:,1] - output_box[:,:,2] = np.log(np.fabs(target_box_width / prior_box_width)) / \ - prior_box_var[:,:,2] - output_box[:,:,3] = np.log(np.fabs(target_box_height / prior_box_height)) / \ - prior_box_var[:,:,3] + if prior_box_var.ndim == 1: + output_box[:,:,0] = (target_box_x - prior_box_x) / \ + prior_box_width / \ + prior_box_var[0] + output_box[:,:,1] = (target_box_y - prior_box_y) / \ + prior_box_height / \ + prior_box_var[1] + output_box[:,:,2] = np.log(np.fabs(target_box_width / \ + prior_box_width)) / \ + prior_box_var[2] + output_box[:,:,3] = np.log(np.fabs(target_box_height / \ + prior_box_height)) / \ + prior_box_var[3] + else: + output_box[:,:,0] = (target_box_x - prior_box_x) / \ + prior_box_width / \ + prior_box_var[:,:,0] + output_box[:,:,1] = (target_box_y - prior_box_y) / \ + prior_box_height / \ + prior_box_var[:,:,1] + output_box[:,:,2] = np.log(np.fabs(target_box_width / \ + prior_box_width)) / \ + prior_box_var[:,:,2] + output_box[:,:,3] = np.log(np.fabs(target_box_height / \ + prior_box_height)) / \ + prior_box_var[:,:,3] elif (code_type == "DecodeCenterSize"): - target_box_x = prior_box_var[:,:,0] * target_box[:,:,0] * \ - prior_box_width + prior_box_x - target_box_y = prior_box_var[:,:,1] * target_box[:,:,1] * \ - prior_box_height + prior_box_y - target_box_width = np.exp(prior_box_var[:,:,2] * target_box[:,:,2]) * \ - prior_box_width - target_box_height = np.exp(prior_box_var[:,:,3] * target_box[:,:,3]) * \ - prior_box_height - + if prior_box_var.ndim == 1: + target_box_x = prior_box_var[0] * target_box[:,:,0] * \ + prior_box_width + prior_box_x + target_box_y = prior_box_var[1] * target_box[:,:,1] * \ + prior_box_height + prior_box_y + target_box_width = np.exp(prior_box_var[2] * target_box[:,:,2]) * \ + prior_box_width + target_box_height = np.exp(prior_box_var[3] * target_box[:,:,3]) * \ + prior_box_height + else: + target_box_x = prior_box_var[:,:,0] * target_box[:,:,0] * \ + prior_box_width + prior_box_x + target_box_y = prior_box_var[:,:,1] * target_box[:,:,1] * \ + prior_box_height + prior_box_y + target_box_width = np.exp(prior_box_var[:,:,2] * \ + target_box[:,:,2]) * prior_box_width + target_box_height = np.exp(prior_box_var[:,:,3] * \ + target_box[:,:,3]) * prior_box_height output_box[:, :, 0] = target_box_x - target_box_width / 2 output_box[:, :, 1] = target_box_y - target_box_height / 2 output_box[:, :, 2] = target_box_x + target_box_width / 2 @@ -78,10 +114,17 @@ def box_coder(target_box, prior_box, prior_box_var, output_box, code_type, output_box[:, :, 3] = output_box[:, :, 3] - 1 -def batch_box_coder(prior_box, prior_box_var, target_box, lod, code_type, - box_normalized): +def batch_box_coder(prior_box, + prior_box_var, + target_box, + lod, + code_type, + box_normalized, + axis=0): n = target_box.shape[0] m = prior_box.shape[0] + if code_type == "DecodeCenterSize": + m = target_box.shape[1] output_box = np.zeros((n, m, 4), dtype=np.float32) cur_offset = 0 for i in range(len(lod)): @@ -91,10 +134,8 @@ def batch_box_coder(prior_box, prior_box_var, target_box, lod, code_type, output_box[cur_offset:(cur_offset + lod[i]), :, :], code_type, box_normalized) elif (code_type == "DecodeCenterSize"): - box_coder(target_box[cur_offset:(cur_offset + lod[i]), :, :], - prior_box, prior_box_var, - output_box[cur_offset:(cur_offset + lod[i]), :, :], - code_type, box_normalized) + box_coder(target_box, prior_box, prior_box_var, output_box, + code_type, box_normalized, axis) cur_offset += lod[i] return output_box @@ -111,6 +152,32 @@ class TestBoxCoderOp(OpTest): target_box = np.random.random((5, 10, 4)).astype('float32') code_type = "DecodeCenterSize" box_normalized = False + output_box = batch_box_coder(prior_box, prior_box_var, target_box, + lod[0], code_type, box_normalized) + self.inputs = { + 'PriorBox': prior_box, + 'PriorBoxVar': prior_box_var, + 'TargetBox': target_box, + } + self.attrs = { + 'code_type': 'decode_center_size', + 'box_normalized': False + } + self.outputs = {'OutputBox': output_box} + + +class TestBoxCoderOpWithOneRankVar(OpTest): + def test_check_output(self): + self.check_output() + + def setUp(self): + self.op_type = "box_coder" + lod = [[1, 1, 1, 1, 1]] + prior_box = np.random.random((6, 4)).astype('float32') + prior_box_var = np.random.random((4)).astype('float32') + target_box = np.random.random((3, 6, 4)).astype('float32') + code_type = "DecodeCenterSize" + box_normalized = False output_box = batch_box_coder(prior_box, prior_box_var, target_box, lod[0], code_type, box_normalized) @@ -176,5 +243,34 @@ class TestBoxCoderOpWithLoD(OpTest): self.outputs = {'OutputBox': output_box} +class TestBoxCoderOpWithAxis(OpTest): + def test_check_output(self): + self.check_output() + + def setUp(self): + self.op_type = "box_coder" + lod = [[1, 1, 1, 1, 1]] + prior_box = np.random.random((5, 4)).astype('float32') + prior_box_var = np.random.random((4)).astype('float32') + target_box = np.random.random((5, 6, 4)).astype('float32') + code_type = "DecodeCenterSize" + box_normalized = False + axis = 1 + output_box = batch_box_coder(prior_box, prior_box_var, target_box, + lod[0], code_type, box_normalized, axis) + + self.inputs = { + 'PriorBox': prior_box, + 'PriorBoxVar': prior_box_var, + 'TargetBox': target_box, + } + self.attrs = { + 'code_type': 'decode_center_size', + 'box_normalized': False, + 'axis': axis + } + self.outputs = {'OutputBox': output_box} + + if __name__ == '__main__': unittest.main() From ab9d6a4f39ee8fefceb7392f1b93131eed8db9dc Mon Sep 17 00:00:00 2001 From: jerrywgz Date: Thu, 17 Jan 2019 12:20:18 +0000 Subject: [PATCH 2/9] add comments, test=develop --- paddle/fluid/operators/detection/box_coder_op.cc | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/operators/detection/box_coder_op.cc b/paddle/fluid/operators/detection/box_coder_op.cc index 5db600b19a..e342417491 100644 --- a/paddle/fluid/operators/detection/box_coder_op.cc +++ b/paddle/fluid/operators/detection/box_coder_op.cc @@ -166,7 +166,11 @@ where `tx`, `ty`, `tw`, `th` denote the target box's center coordinates, width and height respectively. Similarly, `px`, `py`, `pw`, `ph` denote the priorbox's (anchor) center coordinates, width and height. `pxv`, `pyv`, `pwv`, `phv` denote the variance of the priorbox and `ox`, `oy`, `ow`, `oh` denote the -encoded/decoded coordinates, width and height. +encoded/decoded coordinates, width and height. + +During Box Decoding, two modes for broadcast are supported. Say target box has +shape [N, M, 4], and the shape of prior box can be [N, 4] or [M, 4]. Then prior +box will broadcast to target box along the assigned axis. )DOC"); } }; From 0d915078597f483057b25cdc2e99bdd9bee71f71 Mon Sep 17 00:00:00 2001 From: jerrywgz Date: Mon, 21 Jan 2019 05:22:47 +0000 Subject: [PATCH 3/9] fix share lod, test=develop --- paddle/fluid/operators/detection/box_coder_op.cc | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/operators/detection/box_coder_op.cc b/paddle/fluid/operators/detection/box_coder_op.cc index e342417491..b4b02124cc 100644 --- a/paddle/fluid/operators/detection/box_coder_op.cc +++ b/paddle/fluid/operators/detection/box_coder_op.cc @@ -77,9 +77,13 @@ class BoxCoderOp : public framework::OperatorWithKernel { PADDLE_ENFORCE_EQ(target_box_dims[2], prior_box_dims[1]); ctx->ShareDim("TargetBox", /*->*/ "OutputBox"); } - } - ctx->ShareLoD("TargetBox", /*->*/ "OutputBox"); + if (code_type == BoxCodeType::kDecodeCenterSize && axis == 1) { + ctx->ShareLoD("PriorBox", /*->*/ "OutputBox"); + } else { + ctx->ShareLoD("TargetBox", /*->*/ "OutputBox"); + } + } } }; From 66bb5dd760f0ce72740ca755224bb3ca85194600 Mon Sep 17 00:00:00 2001 From: jerrywgz Date: Mon, 21 Jan 2019 10:18:41 +0000 Subject: [PATCH 4/9] refine infer shape, test=develop --- .../fluid/operators/detection/box_coder_op.cc | 57 +++++++++---------- 1 file changed, 28 insertions(+), 29 deletions(-) diff --git a/paddle/fluid/operators/detection/box_coder_op.cc b/paddle/fluid/operators/detection/box_coder_op.cc index b4b02124cc..2ce844669b 100644 --- a/paddle/fluid/operators/detection/box_coder_op.cc +++ b/paddle/fluid/operators/detection/box_coder_op.cc @@ -43,7 +43,7 @@ class BoxCoderOp : public framework::OperatorWithKernel { if (prior_box_var_dims.size() == 1) { PADDLE_ENFORCE_EQ( prior_box_var_dims[0], 4, - "The 1st dimension of Input(PriorBoxVar) should be 1" + "The 1st dimension of Input(PriorBoxVar) should be 4" "when the rank is 1."); } else { PADDLE_ENFORCE_EQ( @@ -52,37 +52,36 @@ class BoxCoderOp : public framework::OperatorWithKernel { "the dimension of Input(PriorBox when the rank is 2.)"); } } + } - auto code_type = - GetBoxCodeType(ctx->Attrs().Get("code_type")); - int axis = ctx->Attrs().Get("axis"); - if (code_type == BoxCodeType::kEncodeCenterSize) { - PADDLE_ENFORCE_EQ(target_box_dims.size(), 2, - "The rank of Input of TargetBox must be 2"); - PADDLE_ENFORCE_EQ(target_box_dims[1], 4, - "The shape of TargetBox is [M, 4]"); - ctx->SetOutputDim( - "OutputBox", - framework::make_ddim({target_box_dims[0], prior_box_dims[0], 4})); - } else if (code_type == BoxCodeType::kDecodeCenterSize) { - PADDLE_ENFORCE_EQ(target_box_dims.size(), 3, - "The rank of Input of TargetBox must be 3"); - if (axis == 0) { - PADDLE_ENFORCE_EQ(target_box_dims[1], prior_box_dims[0]); - } else if (axis == 1) { - PADDLE_ENFORCE_EQ(target_box_dims[0], prior_box_dims[0]); - } else { - PADDLE_THROW("axis must be 0 or 1."); - } - PADDLE_ENFORCE_EQ(target_box_dims[2], prior_box_dims[1]); - ctx->ShareDim("TargetBox", /*->*/ "OutputBox"); - } - - if (code_type == BoxCodeType::kDecodeCenterSize && axis == 1) { - ctx->ShareLoD("PriorBox", /*->*/ "OutputBox"); + auto code_type = GetBoxCodeType(ctx->Attrs().Get("code_type")); + int axis = ctx->Attrs().Get("axis"); + if (code_type == BoxCodeType::kEncodeCenterSize) { + PADDLE_ENFORCE_EQ(target_box_dims.size(), 2, + "The rank of Input of TargetBox must be 2"); + PADDLE_ENFORCE_EQ(target_box_dims[1], 4, + "The shape of TargetBox is [M, 4]"); + ctx->SetOutputDim( + "OutputBox", + framework::make_ddim({target_box_dims[0], prior_box_dims[0], 4})); + } else if (code_type == BoxCodeType::kDecodeCenterSize) { + PADDLE_ENFORCE_EQ(target_box_dims.size(), 3, + "The rank of Input of TargetBox must be 3"); + if (axis == 0) { + PADDLE_ENFORCE_EQ(target_box_dims[1], prior_box_dims[0]); + } else if (axis == 1) { + PADDLE_ENFORCE_EQ(target_box_dims[0], prior_box_dims[0]); } else { - ctx->ShareLoD("TargetBox", /*->*/ "OutputBox"); + PADDLE_THROW("axis must be 0 or 1."); } + PADDLE_ENFORCE_EQ(target_box_dims[2], prior_box_dims[1]); + ctx->ShareDim("TargetBox", /*->*/ "OutputBox"); + } + + if (code_type == BoxCodeType::kDecodeCenterSize && axis == 1) { + ctx->ShareLoD("PriorBox", /*->*/ "OutputBox"); + } else { + ctx->ShareLoD("TargetBox", /*->*/ "OutputBox"); } } }; From 0d4b60ab8bc8d1db9fdef1a6228663c3f60a3980 Mon Sep 17 00:00:00 2001 From: jerrywgz Date: Mon, 21 Jan 2019 12:25:07 +0000 Subject: [PATCH 5/9] add lod for slice op, test=develop --- paddle/fluid/operators/slice_op.cc | 3 +++ 1 file changed, 3 insertions(+) diff --git a/paddle/fluid/operators/slice_op.cc b/paddle/fluid/operators/slice_op.cc index 789e61b2d3..94995fc996 100644 --- a/paddle/fluid/operators/slice_op.cc +++ b/paddle/fluid/operators/slice_op.cc @@ -54,6 +54,9 @@ class SliceOp : public framework::OperatorWithKernel { out_dims[axes[i]] = end - start; } ctx->SetOutputDim("Out", out_dims); + if (axes[0] != 0) { + ctx->ShareLoD("Input", /*->*/ "Out"); + } } protected: From c12a969bd446691d107ab1607be529ef9388bcd0 Mon Sep 17 00:00:00 2001 From: jerrywgz Date: Tue, 22 Jan 2019 13:27:21 +0000 Subject: [PATCH 6/9] refine comment and unittest, test=develop --- .../fluid/operators/detection/box_coder_op.cc | 13 +- .../fluid/operators/detection/box_coder_op.cu | 10 +- python/paddle/fluid/layers/detection.py | 4 +- .../tests/unittests/test_box_coder_op.py | 175 +++++++----------- 4 files changed, 79 insertions(+), 123 deletions(-) diff --git a/paddle/fluid/operators/detection/box_coder_op.cc b/paddle/fluid/operators/detection/box_coder_op.cc index 2ce844669b..f89f87663b 100644 --- a/paddle/fluid/operators/detection/box_coder_op.cc +++ b/paddle/fluid/operators/detection/box_coder_op.cc @@ -32,7 +32,7 @@ class BoxCoderOp : public framework::OperatorWithKernel { if (ctx->IsRuntime()) { PADDLE_ENFORCE_EQ(prior_box_dims.size(), 2, - "The rank of Input of PriorBox must be 2"); + "The rank of Input PriorBox must be 2"); PADDLE_ENFORCE_EQ(prior_box_dims[1], 4, "The shape of PriorBox is [N, 4]"); if (ctx->HasInput("PriorBoxVar")) { @@ -58,7 +58,7 @@ class BoxCoderOp : public framework::OperatorWithKernel { int axis = ctx->Attrs().Get("axis"); if (code_type == BoxCodeType::kEncodeCenterSize) { PADDLE_ENFORCE_EQ(target_box_dims.size(), 2, - "The rank of Input of TargetBox must be 2"); + "The rank of Input TargetBox must be 2"); PADDLE_ENFORCE_EQ(target_box_dims[1], 4, "The shape of TargetBox is [M, 4]"); ctx->SetOutputDim( @@ -66,7 +66,7 @@ class BoxCoderOp : public framework::OperatorWithKernel { framework::make_ddim({target_box_dims[0], prior_box_dims[0], 4})); } else if (code_type == BoxCodeType::kDecodeCenterSize) { PADDLE_ENFORCE_EQ(target_box_dims.size(), 3, - "The rank of Input of TargetBox must be 3"); + "The rank of Input TargetBox must be 3"); if (axis == 0) { PADDLE_ENFORCE_EQ(target_box_dims[1], prior_box_dims[0]); } else if (axis == 1) { @@ -126,8 +126,11 @@ class BoxCoderOpMaker : public framework::OpProtoAndCheckerMaker { "whether treat the priorbox as a noramlized box") .SetDefault(true); AddAttr("axis", - "(int, default 1)" - "which axis to broadcast for box decode, it is only valid" + "(int, default 0)" + "which axis in PriorBox to broadcast for box decode," + "for example, if axis is 0 and TargetBox has shape" + "[N, M, 4] and PriorBox has shape [M, 4], then PriorBox " + "will broadcast to [N, M, 4] for decoding. It is only valid" "when code type is decode_center_size") .SetDefault(0) .InEnum({0, 1}); diff --git a/paddle/fluid/operators/detection/box_coder_op.cu b/paddle/fluid/operators/detection/box_coder_op.cu index ca62afd8ed..0b64224e1e 100644 --- a/paddle/fluid/operators/detection/box_coder_op.cu +++ b/paddle/fluid/operators/detection/box_coder_op.cu @@ -79,10 +79,7 @@ __global__ void DecodeCenterSizeKernel(const T* prior_box_data, if (idx < row * col) { const int col_idx = idx % col; const int row_idx = idx / col; - if (axis == 0) - prior_box_offset = col_idx * len; - else if (axis == 1) - prior_box_offset = row_idx * len; + prior_box_offset = axis == 0 ? col_idx * len : row_idx * len; T prior_box_width = prior_box_data[prior_box_offset + 2] - prior_box_data[prior_box_offset] + (normalized == false); @@ -98,10 +95,7 @@ __global__ void DecodeCenterSizeKernel(const T* prior_box_data, if (prior_box_var_data) { int prior_var_offset = 0; if (prior_box_var_size == 2) { - if (axis == 0) - prior_var_offset = col_idx * len; - else if (axis == 1) - prior_var_offset = row_idx * len; + prior_var_offset = axis == 0 ? col_idx * len : row_idx * len; } target_box_width = exp(prior_box_var_data[prior_var_offset + 2] * target_box_data[idx * len + 2]) * diff --git a/python/paddle/fluid/layers/detection.py b/python/paddle/fluid/layers/detection.py index c844050c5d..8c8a6c6223 100644 --- a/python/paddle/fluid/layers/detection.py +++ b/python/paddle/fluid/layers/detection.py @@ -342,8 +342,8 @@ def box_coder(prior_box, target_box, code_type="encode_center_size", box_normalized=True, - axis=0, - name=None): + name=None, + axis=0): """ ${comment} diff --git a/python/paddle/fluid/tests/unittests/test_box_coder_op.py b/python/paddle/fluid/tests/unittests/test_box_coder_op.py index b6f6bc1450..6f7930c921 100644 --- a/python/paddle/fluid/tests/unittests/test_box_coder_op.py +++ b/python/paddle/fluid/tests/unittests/test_box_coder_op.py @@ -21,121 +21,80 @@ import math from op_test import OpTest -def box_coder(target_box, - prior_box, - prior_box_var, - output_box, - code_type, - box_normalized, - axis=0): - prior_box_width = prior_box[:, 2] - prior_box[:, 0] + \ - (box_normalized==False) - prior_box_height = prior_box[:, 3] - prior_box[:, 1] + \ - (box_normalized==False) - prior_box_x = prior_box_width * 0.5 + prior_box[:, 0] - prior_box_y = prior_box_height * 0.5 + prior_box[:, 1] - if axis == 0: - prior_box_width = prior_box_width.reshape(1, prior_box.shape[0]) - prior_box_height = prior_box_height.reshape(1, prior_box.shape[0]) - prior_box_x = prior_box_x.reshape(1, prior_box.shape[0]) - prior_box_y = prior_box_y.reshape(1, prior_box.shape[0]) +def box_decoder(t_box, p_box, pb_v, output_box, norm, axis=0): + pb_w = p_box[:, 2] - p_box[:, 0] + (norm == False) + pb_h = p_box[:, 3] - p_box[:, 1] + (norm == False) + pb_x = pb_w * 0.5 + p_box[:, 0] + pb_y = pb_h * 0.5 + p_box[:, 1] + shape = (1, p_box.shape[0]) if axis == 0 else (p_box.shape[0], 1) + + pb_w = pb_w.reshape(shape) + pb_h = pb_h.reshape(shape) + pb_x = pb_x.reshape(shape) + pb_y = pb_y.reshape(shape) + + if pb_v.ndim == 2: + pb_v = pb_v.reshape(1, pb_v.shape[0], pb_v.shape[1]) + if pb_v.ndim == 1: + tb_x = pb_v[0] * t_box[:, :, 0] * pb_w + pb_x + tb_y = pb_v[1] * t_box[:, :, 1] * pb_h + pb_y + tb_w = np.exp(pb_v[2] * t_box[:, :, 2]) * pb_w + tb_h = np.exp(pb_v[3] * t_box[:, :, 3]) * pb_h else: - prior_box_width = prior_box_width.reshape(prior_box.shape[0], 1) - prior_box_height = prior_box_height.reshape(prior_box.shape[0], 1) - prior_box_x = prior_box_x.reshape(prior_box.shape[0], 1) - prior_box_y = prior_box_y.reshape(prior_box.shape[0], 1) - if prior_box_var.ndim == 2: - prior_box_var = prior_box_var.reshape(1, prior_box_var.shape[0], - prior_box_var.shape[1]) - if (code_type == "EncodeCenterSize"): - target_box_x = ((target_box[:, 2] + target_box[:, 0]) / 2).reshape( - target_box.shape[0], 1) - target_box_y = ((target_box[:, 3] + target_box[:, 1]) / 2).reshape( - target_box.shape[0], 1) - target_box_width = ((target_box[:, 2] - target_box[:, 0])).reshape( - target_box.shape[0], 1) - target_box_height = ((target_box[:, 3] - target_box[:, 1])).reshape( - target_box.shape[0], 1) - if not box_normalized: - target_box_height = target_box_height + 1 - target_box_width = target_box_width + 1 - if prior_box_var.ndim == 1: - output_box[:,:,0] = (target_box_x - prior_box_x) / \ - prior_box_width / \ - prior_box_var[0] - output_box[:,:,1] = (target_box_y - prior_box_y) / \ - prior_box_height / \ - prior_box_var[1] - output_box[:,:,2] = np.log(np.fabs(target_box_width / \ - prior_box_width)) / \ - prior_box_var[2] - output_box[:,:,3] = np.log(np.fabs(target_box_height / \ - prior_box_height)) / \ - prior_box_var[3] - else: - output_box[:,:,0] = (target_box_x - prior_box_x) / \ - prior_box_width / \ - prior_box_var[:,:,0] - output_box[:,:,1] = (target_box_y - prior_box_y) / \ - prior_box_height / \ - prior_box_var[:,:,1] - output_box[:,:,2] = np.log(np.fabs(target_box_width / \ - prior_box_width)) / \ - prior_box_var[:,:,2] - output_box[:,:,3] = np.log(np.fabs(target_box_height / \ - prior_box_height)) / \ - prior_box_var[:,:,3] - - elif (code_type == "DecodeCenterSize"): - if prior_box_var.ndim == 1: - target_box_x = prior_box_var[0] * target_box[:,:,0] * \ - prior_box_width + prior_box_x - target_box_y = prior_box_var[1] * target_box[:,:,1] * \ - prior_box_height + prior_box_y - target_box_width = np.exp(prior_box_var[2] * target_box[:,:,2]) * \ - prior_box_width - target_box_height = np.exp(prior_box_var[3] * target_box[:,:,3]) * \ - prior_box_height - else: - target_box_x = prior_box_var[:,:,0] * target_box[:,:,0] * \ - prior_box_width + prior_box_x - target_box_y = prior_box_var[:,:,1] * target_box[:,:,1] * \ - prior_box_height + prior_box_y - target_box_width = np.exp(prior_box_var[:,:,2] * \ - target_box[:,:,2]) * prior_box_width - target_box_height = np.exp(prior_box_var[:,:,3] * \ - target_box[:,:,3]) * prior_box_height - output_box[:, :, 0] = target_box_x - target_box_width / 2 - output_box[:, :, 1] = target_box_y - target_box_height / 2 - output_box[:, :, 2] = target_box_x + target_box_width / 2 - output_box[:, :, 3] = target_box_y + target_box_height / 2 - if not box_normalized: - output_box[:, :, 2] = output_box[:, :, 2] - 1 - output_box[:, :, 3] = output_box[:, :, 3] - 1 - - -def batch_box_coder(prior_box, - prior_box_var, - target_box, - lod, - code_type, - box_normalized, - axis=0): - n = target_box.shape[0] - m = prior_box.shape[0] + tb_x = pb_v[:, :, 0] * t_box[:, :, 0] * pb_w + pb_x + tb_y = pb_v[:, :, 1] * t_box[:, :, 1] * pb_h + pb_y + tb_w = np.exp(pb_v[:, :, 2] * t_box[:, :, 2]) * pb_w + tb_h = np.exp(pb_v[:, :, 3] * t_box[:, :, 3]) * pb_h + output_box[:, :, 0] = tb_x - tb_w / 2 + output_box[:, :, 1] = tb_y - tb_h / 2 + output_box[:, :, 2] = tb_x + tb_w / 2 - (not norm) + output_box[:, :, 3] = tb_y + tb_h / 2 - (not norm) + + +def box_encoder(t_box, p_box, pb_v, output_box, norm): + pb_w = p_box[:, 2] - p_box[:, 0] + (norm == False) + pb_h = p_box[:, 3] - p_box[:, 1] + (norm == False) + pb_x = pb_w * 0.5 + p_box[:, 0] + pb_y = pb_h * 0.5 + p_box[:, 1] + shape = (1, p_box.shape[0]) + + pb_w = pb_w.reshape(shape) + pb_h = pb_h.reshape(shape) + pb_x = pb_x.reshape(shape) + pb_y = pb_y.reshape(shape) + + if pb_v.ndim == 2: + pb_v = pb_v.reshape(1, pb_v.shape[0], pb_v.shape[1]) + tb_x = ((t_box[:, 2] + t_box[:, 0]) / 2).reshape(t_box.shape[0], 1) + tb_y = ((t_box[:, 3] + t_box[:, 1]) / 2).reshape(t_box.shape[0], 1) + tb_w = (t_box[:, 2] - t_box[:, 0]).reshape(t_box.shape[0], 1) + (not norm) + tb_h = (t_box[:, 3] - t_box[:, 1]).reshape(t_box.shape[0], 1) + (not norm) + if pb_v.ndim == 1: + output_box[:, :, 0] = (tb_x - pb_x) / pb_w / pb_v[0] + output_box[:, :, 1] = (tb_y - pb_y) / pb_h / pb_v[1] + output_box[:, :, 2] = np.log(np.fabs(tb_w / pb_w)) / pb_v[2] + output_box[:, :, 3] = np.log(np.fabs(tb_h / pb_h)) / pb_v[3] + else: + output_box[:, :, 0] = (tb_x - pb_x) / pb_w / pb_v[:, :, 0] + output_box[:, :, 1] = (tb_y - pb_y) / pb_h / pb_v[:, :, 1] + output_box[:, :, 2] = np.log(np.fabs(tb_w / pb_w)) / pb_v[:, :, 2] + output_box[:, :, 3] = np.log(np.fabs(tb_h / pb_h)) / pb_v[:, :, 3] + + +def batch_box_coder(p_box, pb_v, t_box, lod, code_type, norm, axis=0): + n = t_box.shape[0] + m = p_box.shape[0] if code_type == "DecodeCenterSize": - m = target_box.shape[1] + m = t_box.shape[1] output_box = np.zeros((n, m, 4), dtype=np.float32) cur_offset = 0 for i in range(len(lod)): if (code_type == "EncodeCenterSize"): - box_coder(target_box[cur_offset:(cur_offset + lod[i]), :], - prior_box, prior_box_var, - output_box[cur_offset:(cur_offset + lod[i]), :, :], - code_type, box_normalized) + box_encoder(t_box[cur_offset:(cur_offset + lod[i]), :], p_box, pb_v, + output_box[cur_offset:(cur_offset + lod[i]), :, :], + norm) elif (code_type == "DecodeCenterSize"): - box_coder(target_box, prior_box, prior_box_var, output_box, - code_type, box_normalized, axis) + box_decoder(t_box, p_box, pb_v, output_box, norm, axis) cur_offset += lod[i] return output_box From f44b1507f0a3ab7d8aef7cd2b23b8cc90a55f355 Mon Sep 17 00:00:00 2001 From: jerrywgz Date: Wed, 23 Jan 2019 02:21:10 +0000 Subject: [PATCH 7/9] revised API spec, test=develop --- paddle/fluid/API.spec | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index 7068a37ef0..cdb0397ecd 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -315,7 +315,7 @@ paddle.fluid.layers.roi_perspective_transform ArgSpec(args=['input', 'rois', 'tr paddle.fluid.layers.generate_proposal_labels ArgSpec(args=['rpn_rois', 'gt_classes', 'is_crowd', 'gt_boxes', 'im_info', 'batch_size_per_im', 'fg_fraction', 'fg_thresh', 'bg_thresh_hi', 'bg_thresh_lo', 'bbox_reg_weights', 'class_nums', 'use_random'], varargs=None, keywords=None, defaults=(256, 0.25, 0.25, 0.5, 0.0, [0.1, 0.1, 0.2, 0.2], None, True)) paddle.fluid.layers.generate_proposals ArgSpec(args=['scores', 'bbox_deltas', 'im_info', 'anchors', 'variances', 'pre_nms_top_n', 'post_nms_top_n', 'nms_thresh', 'min_size', 'eta', 'name'], varargs=None, keywords=None, defaults=(6000, 1000, 0.5, 0.1, 1.0, None)) paddle.fluid.layers.iou_similarity ArgSpec(args=['x', 'y', 'name'], varargs=None, keywords=None, defaults=(None,)) -paddle.fluid.layers.box_coder ArgSpec(args=['prior_box', 'prior_box_var', 'target_box', 'code_type', 'box_normalized', 'axis', 'name'], varargs=None, keywords=None, defaults=('encode_center_size', True, 0, None)) +paddle.fluid.layers.box_coder ArgSpec(args=['prior_box', 'prior_box_var', 'target_box', 'code_type', 'box_normalized', 'name', 'axis'], varargs=None, keywords=None, defaults=('encode_center_size', True, None, 0)) paddle.fluid.layers.polygon_box_transform ArgSpec(args=['input', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.yolov3_loss ArgSpec(args=['x', 'gtbox', 'gtlabel', 'anchors', 'class_num', 'ignore_thresh', 'loss_weight_xy', 'loss_weight_wh', 'loss_weight_conf_target', 'loss_weight_conf_notarget', 'loss_weight_class', 'name'], varargs=None, keywords=None, defaults=(None, None, None, None, None, None)) paddle.fluid.layers.accuracy ArgSpec(args=['input', 'label', 'k', 'correct', 'total'], varargs=None, keywords=None, defaults=(1, None, None)) From a39240c3b6af17b05e5a55bf8bbb199775498696 Mon Sep 17 00:00:00 2001 From: jerrywgz Date: Fri, 25 Jan 2019 07:46:48 +0000 Subject: [PATCH 8/9] add attr variance for box coder, test=develop --- .../fluid/operators/detection/box_coder_op.cc | 7 + .../fluid/operators/detection/box_coder_op.cu | 59 +++++--- .../fluid/operators/detection/box_coder_op.h | 38 +++++- python/paddle/fluid/layers/detection.py | 126 +++++++++++++++--- python/paddle/fluid/tests/test_detection.py | 2 +- .../tests/unittests/test_box_coder_op.py | 57 ++++++-- 6 files changed, 236 insertions(+), 53 deletions(-) diff --git a/paddle/fluid/operators/detection/box_coder_op.cc b/paddle/fluid/operators/detection/box_coder_op.cc index f89f87663b..fdcff62e1f 100644 --- a/paddle/fluid/operators/detection/box_coder_op.cc +++ b/paddle/fluid/operators/detection/box_coder_op.cc @@ -10,6 +10,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/detection/box_coder_op.h" +#include namespace paddle { namespace operators { @@ -134,6 +135,12 @@ class BoxCoderOpMaker : public framework::OpProtoAndCheckerMaker { "when code type is decode_center_size") .SetDefault(0) .InEnum({0, 1}); + AddAttr>( + "variance", + "(vector, default {})," + "variance of prior box with shape [4]. PriorBoxVar and variance can" + "not be provided at the same time.") + .SetDefault(std::vector{}); AddOutput("OutputBox", "(LoDTensor or Tensor) " "When code_type is 'encode_center_size', the output tensor of " diff --git a/paddle/fluid/operators/detection/box_coder_op.cu b/paddle/fluid/operators/detection/box_coder_op.cu index 0b64224e1e..9b73572274 100644 --- a/paddle/fluid/operators/detection/box_coder_op.cu +++ b/paddle/fluid/operators/detection/box_coder_op.cu @@ -9,6 +9,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include +#include #include "paddle/fluid/operators/detection/box_coder_op.h" #include "paddle/fluid/platform/cuda_primitives.h" @@ -16,12 +18,11 @@ namespace paddle { namespace operators { template -__global__ void EncodeCenterSizeKernel(const T* prior_box_data, - const T* prior_box_var_data, - const T* target_box_data, const int row, - const int col, const int len, - const bool normalized, - const T prior_box_var_size, T* output) { +__global__ void EncodeCenterSizeKernel( + const T* prior_box_data, const T* prior_box_var_data, + const T* target_box_data, const int row, const int col, const int len, + const bool normalized, const T prior_box_var_size, const float* variance, + const int var_size, T* output) { const int idx = threadIdx.x + blockIdx.x * blockDim.x; if (idx < row * col) { const int row_idx = idx / col; @@ -62,18 +63,20 @@ __global__ void EncodeCenterSizeKernel(const T* prior_box_data, output[idx * len + 1] /= prior_box_var_data[prior_var_offset + 1]; output[idx * len + 2] /= prior_box_var_data[prior_var_offset + 2]; output[idx * len + 3] /= prior_box_var_data[prior_var_offset + 3]; + } else if (var_size == 4) { + for (int k = 0; k < 4; ++k) { + output[idx * len + k] /= static_cast(variance[k]); + } } } } template -__global__ void DecodeCenterSizeKernel(const T* prior_box_data, - const T* prior_box_var_data, - const T* target_box_data, const int row, - const int col, const int len, - const bool normalized, - const T prior_box_var_size, - const int axis, T* output) { +__global__ void DecodeCenterSizeKernel( + const T* prior_box_data, const T* prior_box_var_data, + const T* target_box_data, const int row, const int col, const int len, + const bool normalized, const T prior_box_var_size, const float* variance, + const int var_size, const int axis, T* output) { const int idx = threadIdx.x + blockIdx.x * blockDim.x; int prior_box_offset = 0; if (idx < row * col) { @@ -110,6 +113,20 @@ __global__ void DecodeCenterSizeKernel(const T* prior_box_data, target_box_data[idx * len + 1] * prior_box_height + prior_box_center_y; + } else if (var_size == 4) { + target_box_width = + exp(static_cast(variance[2]) * target_box_data[idx * len + 2]) * + prior_box_width; + target_box_height = + exp(static_cast(variance[3]) * target_box_data[idx * len + 3]) * + prior_box_height; + target_box_center_x = static_cast(variance[0]) * + target_box_data[idx * len] * prior_box_width + + prior_box_center_x; + target_box_center_y = static_cast(variance[1]) * + target_box_data[idx * len + 1] * + prior_box_height + + prior_box_center_y; } else { target_box_width = exp(target_box_data[idx * len + 2]) * prior_box_width; target_box_height = @@ -139,20 +156,30 @@ class BoxCoderCUDAKernel : public framework::OpKernel { auto* prior_box_var = context.Input("PriorBoxVar"); auto* target_box = context.Input("TargetBox"); auto* output_box = context.Output("OutputBox"); - + std::vector variance = context.Attr>("variance"); const T* prior_box_data = prior_box->data(); const T* target_box_data = target_box->data(); const T* prior_box_var_data = nullptr; auto prior_box_var_size = 0; if (prior_box_var) { + PADDLE_ENFORCE(variance.empty(), + "Input 'PriorBoxVar' and attribute 'variance' should not" + "be used at the same time."); prior_box_var_data = prior_box_var->data(); prior_box_var_size = prior_box_var->dims().size(); } + if (!(variance.empty())) { + PADDLE_ENFORCE(static_cast(variance.size()) == 4, + "Size of attribute 'variance' should be 4"); + } if (target_box->lod().size()) { PADDLE_ENFORCE_EQ(target_box->lod().size(), 1, "Only support 1 level of LoD."); } + const int var_size = static_cast(variance.size()); + thrust::device_vector dev_variance(variance.begin(), variance.end()); + const float* dev_var_data = thrust::raw_pointer_cast(dev_variance.data()); auto code_type = GetBoxCodeType(context.Attr("code_type")); bool normalized = context.Attr("box_normalized"); int axis = context.Attr("axis"); @@ -173,11 +200,11 @@ class BoxCoderCUDAKernel : public framework::OpKernel { if (code_type == BoxCodeType::kEncodeCenterSize) { EncodeCenterSizeKernel<<>>( prior_box_data, prior_box_var_data, target_box_data, row, col, len, - normalized, prior_box_var_size, output); + normalized, prior_box_var_size, dev_var_data, var_size, output); } else if (code_type == BoxCodeType::kDecodeCenterSize) { DecodeCenterSizeKernel<<>>( prior_box_data, prior_box_var_data, target_box_data, row, col, len, - normalized, prior_box_var_size, axis, output); + normalized, prior_box_var_size, dev_var_data, var_size, axis, output); } } }; diff --git a/paddle/fluid/operators/detection/box_coder_op.h b/paddle/fluid/operators/detection/box_coder_op.h index 986869d8a3..b61cff1b1d 100644 --- a/paddle/fluid/operators/detection/box_coder_op.h +++ b/paddle/fluid/operators/detection/box_coder_op.h @@ -11,6 +11,7 @@ limitations under the License. */ #pragma once #include +#include #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/math/math_function.h" @@ -34,7 +35,8 @@ class BoxCoderKernel : public framework::OpKernel { void EncodeCenterSize(const framework::Tensor* target_box, const framework::Tensor* prior_box, const framework::Tensor* prior_box_var, - const bool normalized, T* output) const { + const bool normalized, + const std::vector variance, T* output) const { int64_t row = target_box->dims()[0]; int64_t col = prior_box->dims()[0]; int64_t len = prior_box->dims()[1]; @@ -85,6 +87,10 @@ class BoxCoderKernel : public framework::OpKernel { output[offset + 1] /= prior_box_var_data[prior_var_offset + 1]; output[offset + 2] /= prior_box_var_data[prior_var_offset + 2]; output[offset + 3] /= prior_box_var_data[prior_var_offset + 3]; + } else if (!(variance.empty())) { + for (int k = 0; k < 4; ++k) { + output[offset + k] /= static_cast(variance[k]); + } } } } @@ -93,7 +99,7 @@ class BoxCoderKernel : public framework::OpKernel { const framework::Tensor* prior_box, const framework::Tensor* prior_box_var, const bool normalized, const int axis, - T* output) const { + const std::vector variance, T* output) const { int64_t row = target_box->dims()[0]; int64_t col = target_box->dims()[1]; int64_t len = target_box->dims()[2]; @@ -149,6 +155,20 @@ class BoxCoderKernel : public framework::OpKernel { std::exp(prior_box_var_data[prior_var_offset + 3] * target_box_data[offset + 3]) * prior_box_height; + } else if (!(variance.empty())) { + target_box_center_x = static_cast(variance[0]) * + target_box_data[offset] * prior_box_width + + prior_box_center_x; + target_box_center_y = static_cast(variance[1]) * + target_box_data[offset + 1] * + prior_box_height + + prior_box_center_y; + target_box_width = std::exp(static_cast(variance[2]) * + target_box_data[offset + 2]) * + prior_box_width; + target_box_height = std::exp(static_cast(variance[3]) * + target_box_data[offset + 3]) * + prior_box_height; } else { target_box_center_x = target_box_data[offset] * prior_box_width + prior_box_center_x; @@ -175,11 +195,21 @@ class BoxCoderKernel : public framework::OpKernel { auto* prior_box_var = context.Input("PriorBoxVar"); auto* target_box = context.Input("TargetBox"); auto* output_box = context.Output("OutputBox"); + std::vector variance = context.Attr>("variance"); const int axis = context.Attr("axis"); if (target_box->lod().size()) { PADDLE_ENFORCE_EQ(target_box->lod().size(), 1UL, "Only support 1 level of LoD."); } + if (prior_box_var) { + PADDLE_ENFORCE(variance.empty(), + "Input 'PriorBoxVar' and attribute 'variance' should not" + "be used at the same time."); + } + if (!(variance.empty())) { + PADDLE_ENFORCE(static_cast(variance.size()) == 4, + "Size of attribute 'variance' should be 4"); + } auto code_type = GetBoxCodeType(context.Attr("code_type")); bool normalized = context.Attr("box_normalized"); @@ -195,10 +225,10 @@ class BoxCoderKernel : public framework::OpKernel { T* output = output_box->data(); if (code_type == BoxCodeType::kEncodeCenterSize) { EncodeCenterSize(target_box, prior_box, prior_box_var, normalized, - output); + variance, output); } else if (code_type == BoxCodeType::kDecodeCenterSize) { DecodeCenterSize(target_box, prior_box, prior_box_var, normalized, axis, - output); + variance, output); } } }; diff --git a/python/paddle/fluid/layers/detection.py b/python/paddle/fluid/layers/detection.py index 1eb876cfaf..854b34d2a4 100644 --- a/python/paddle/fluid/layers/detection.py +++ b/python/paddle/fluid/layers/detection.py @@ -346,18 +346,104 @@ def box_coder(prior_box, name=None, axis=0): """ - ${comment} + **Box Coder Layer** + + Encode/Decode the target bounding box with the priorbox information. + + The Encoding schema described below: + + .. math:: + + ox = (tx - px) / pw / pxv + + oy = (ty - py) / ph / pyv + + ow = \log(\abs(tw / pw)) / pwv + + oh = \log(\abs(th / ph)) / phv + + The Decoding schema described below: + + .. math:: + + ox = (pw * pxv * tx * + px) - tw / 2 + + oy = (ph * pyv * ty * + py) - th / 2 + + ow = \exp(pwv * tw) * pw + tw / 2 + + oh = \exp(phv * th) * ph + th / 2 + + where `tx`, `ty`, `tw`, `th` denote the target box's center coordinates, + width and height respectively. Similarly, `px`, `py`, `pw`, `ph` denote + the priorbox's (anchor) center coordinates, width and height. `pxv`, + `pyv`, `pwv`, `phv` denote the variance of the priorbox and `ox`, `oy`, + `ow`, `oh` denote the encoded/decoded coordinates, width and height. + + During Box Decoding, two modes for broadcast are supported. Say target + box has shape [N, M, 4], and the shape of prior box can be [N, 4] or + [M, 4]. Then prior box will broadcast to target box along the + assigned axis. Args: - prior_box(${prior_box_type}): ${prior_box_comment} - prior_box_var(${prior_box_var_type}): ${prior_box_var_comment} - target_box(${target_box_type}): ${target_box_comment} - code_type(${code_type_type}): ${code_type_comment} - box_normalized(${box_normalized_type}): ${box_normalized_comment} - axis(${axis_type}): ${axis_comment} + prior_box(Variable): Box list prior_box is a 2-D Tensor with shape + [M, 4] holds M boxes, each box is represented as + [xmin, ymin, xmax, ymax], [xmin, ymin] is the + left top coordinate of the anchor box, if the + input is image feature map, they are close to + the origin of the coordinate system. [xmax, ymax] + is the right bottom coordinate of the anchor box. + prior_box_var(Variable|list): prior_box_var supports two types of input. + One is variable with shape [M, 4] holds M group. + The other one is list consist of 4 elements + shared by all boxes. + target_box(Variable): This input can be a 2-D LoDTensor with shape + [N, 4] when code_type is 'encode_center_size'. + This input also can be a 3-D Tensor with shape + [N, M, 4] when code_type is 'decode_center_size'. + Each box is represented as + [xmin, ymin, xmax, ymax]. This tensor can + contain LoD information to represent a batch + of inputs. + code_type(string): The code type used with the target box. It can be + encode_center_size or decode_center_size + box_normalized(int): Whether treat the priorbox as a noramlized box. + Set true by default. + name(string): The name of box coder. + axis(int): Which axis in PriorBox to broadcast for box decode, + for example, if axis is 0 and TargetBox has shape + [N, M, 4] and PriorBox has shape [M, 4], then PriorBox + will broadcast to [N, M, 4] for decoding. It is only valid + when code type is decode_center_size. Set 0 by default. Returns: - output_box(${output_box_type}): ${output_box_comment} + output_box(Variable): When code_type is 'encode_center_size', the + output tensor of box_coder_op with shape + [N, M, 4] representing the result of N target + boxes encoded with M Prior boxes and variances. + When code_type is 'decode_center_size', + N represents the batch size and M represents + the number of deocded boxes. + + Examples: + + .. code-block:: python + + prior_box = fluid.layers.data(name='prior_box', + shape=[512, 4], + dtype='float32', + append_batch_size=False) + target_box = fluid.layers.data(name='target_box', + shape=[512,81,4], + dtype='float32', + append_batch_size=False) + output = fluid.layers.box_coder(prior_box=prior_box, + prior_box_var=[0.1,0.1,0.2,0.2], + target_box=target_box, + code_type="decode_center_size", + box_normalized=False, + axis=1) + """ helper = LayerHelper("box_coder", **locals()) @@ -368,18 +454,22 @@ def box_coder(prior_box, output_box = helper.create_variable( name=name, dtype=prior_box.dtype, persistable=False) + inputs = {"PriorBox": prior_box, "TargetBox": target_box} + attrs = { + "code_type": code_type, + "box_normalized": box_normalized, + "axis": axis + } + if isinstance(prior_box_var, Variable): + inputs['PriorBoxVar'] = prior_box_var + elif isinstance(prior_box_var, list): + attrs['variance'] = prior_box_var + else: + raise TypeError("Input variance of box_coder must be Variable or lisz") helper.append_op( type="box_coder", - inputs={ - "PriorBox": prior_box, - "PriorBoxVar": prior_box_var, - "TargetBox": target_box - }, - attrs={ - "code_type": code_type, - "box_normalized": box_normalized, - "axis": axis - }, + inputs=inputs, + attrs=attrs, outputs={"OutputBox": output_box}) return output_box diff --git a/python/paddle/fluid/tests/test_detection.py b/python/paddle/fluid/tests/test_detection.py index 2d9ed9f9c6..2dbcfa31fc 100644 --- a/python/paddle/fluid/tests/test_detection.py +++ b/python/paddle/fluid/tests/test_detection.py @@ -59,7 +59,7 @@ class TestDetection(unittest.TestCase): iou = layers.iou_similarity(x=x, y=y) bcoder = layers.box_coder( prior_box=x, - prior_box_var=y, + prior_box_var=[0.2, 0.3, 0.3, 0.2], target_box=z, code_type='encode_center_size') self.assertIsNotNone(iou) diff --git a/python/paddle/fluid/tests/unittests/test_box_coder_op.py b/python/paddle/fluid/tests/unittests/test_box_coder_op.py index 6f7930c921..6156268bf2 100644 --- a/python/paddle/fluid/tests/unittests/test_box_coder_op.py +++ b/python/paddle/fluid/tests/unittests/test_box_coder_op.py @@ -106,9 +106,9 @@ class TestBoxCoderOp(OpTest): def setUp(self): self.op_type = "box_coder" lod = [[1, 1, 1, 1, 1]] - prior_box = np.random.random((10, 4)).astype('float32') - prior_box_var = np.random.random((10, 4)).astype('float32') - target_box = np.random.random((5, 10, 4)).astype('float32') + prior_box = np.random.random((81, 4)).astype('float32') + prior_box_var = np.random.random((81, 4)).astype('float32') + target_box = np.random.random((20, 81, 4)).astype('float32') code_type = "DecodeCenterSize" box_normalized = False output_box = batch_box_coder(prior_box, prior_box_var, target_box, @@ -132,9 +132,9 @@ class TestBoxCoderOpWithOneRankVar(OpTest): def setUp(self): self.op_type = "box_coder" lod = [[1, 1, 1, 1, 1]] - prior_box = np.random.random((6, 4)).astype('float32') + prior_box = np.random.random((81, 4)).astype('float32') prior_box_var = np.random.random((4)).astype('float32') - target_box = np.random.random((3, 6, 4)).astype('float32') + target_box = np.random.random((20, 81, 4)).astype('float32') code_type = "DecodeCenterSize" box_normalized = False output_box = batch_box_coder(prior_box, prior_box_var, target_box, @@ -159,9 +159,9 @@ class TestBoxCoderOpWithoutBoxVar(OpTest): def setUp(self): self.op_type = "box_coder" lod = [[0, 1, 2, 3, 4, 5]] - prior_box = np.random.random((10, 4)).astype('float32') - prior_box_var = np.ones((10, 4)).astype('float32') - target_box = np.random.random((5, 10, 4)).astype('float32') + prior_box = np.random.random((81, 4)).astype('float32') + prior_box_var = np.ones((81, 4)).astype('float32') + target_box = np.random.random((20, 81, 4)).astype('float32') code_type = "DecodeCenterSize" box_normalized = False output_box = batch_box_coder(prior_box, prior_box_var, target_box, @@ -184,10 +184,10 @@ class TestBoxCoderOpWithLoD(OpTest): def setUp(self): self.op_type = "box_coder" - lod = [[4, 8, 8]] - prior_box = np.random.random((10, 4)).astype('float32') - prior_box_var = np.random.random((10, 4)).astype('float32') - target_box = np.random.random((20, 4)).astype('float32') + lod = [[10, 20, 20]] + prior_box = np.random.random((20, 4)).astype('float32') + prior_box_var = np.random.random((20, 4)).astype('float32') + target_box = np.random.random((50, 4)).astype('float32') code_type = "EncodeCenterSize" box_normalized = True output_box = batch_box_coder(prior_box, prior_box_var, target_box, @@ -209,9 +209,9 @@ class TestBoxCoderOpWithAxis(OpTest): def setUp(self): self.op_type = "box_coder" lod = [[1, 1, 1, 1, 1]] - prior_box = np.random.random((5, 4)).astype('float32') + prior_box = np.random.random((30, 4)).astype('float32') prior_box_var = np.random.random((4)).astype('float32') - target_box = np.random.random((5, 6, 4)).astype('float32') + target_box = np.random.random((30, 81, 4)).astype('float32') code_type = "DecodeCenterSize" box_normalized = False axis = 1 @@ -231,5 +231,34 @@ class TestBoxCoderOpWithAxis(OpTest): self.outputs = {'OutputBox': output_box} +class TestBoxCoderOpWithVariance(OpTest): + def test_check_output(self): + self.check_output() + + def setUp(self): + self.op_type = "box_coder" + lod = [[1, 1, 1, 1, 1]] + prior_box = np.random.random((30, 4)).astype('float32') + prior_box_var = np.random.random((4)).astype('float32') + target_box = np.random.random((30, 81, 4)).astype('float32') + code_type = "DecodeCenterSize" + box_normalized = False + axis = 1 + output_box = batch_box_coder(prior_box, prior_box_var, target_box, + lod[0], code_type, box_normalized, axis) + + self.inputs = { + 'PriorBox': prior_box, + 'TargetBox': target_box, + } + self.attrs = { + 'code_type': 'decode_center_size', + 'box_normalized': False, + 'variance': prior_box_var.astype(np.float).flatten(), + 'axis': axis + } + self.outputs = {'OutputBox': output_box} + + if __name__ == '__main__': unittest.main() From cee2e1b089f88d9a8dca530c197cb246a628e4b7 Mon Sep 17 00:00:00 2001 From: jerrywgz Date: Mon, 28 Jan 2019 05:57:33 +0000 Subject: [PATCH 9/9] refine code, test=develop --- .../fluid/operators/detection/box_coder_op.cu | 70 +++++++++---------- .../fluid/operators/detection/box_coder_op.h | 56 ++++++--------- python/paddle/fluid/tests/test_detection.py | 15 +++- 3 files changed, 67 insertions(+), 74 deletions(-) diff --git a/paddle/fluid/operators/detection/box_coder_op.cu b/paddle/fluid/operators/detection/box_coder_op.cu index 9b73572274..e078af3eb4 100644 --- a/paddle/fluid/operators/detection/box_coder_op.cu +++ b/paddle/fluid/operators/detection/box_coder_op.cu @@ -11,6 +11,7 @@ limitations under the License. */ #include #include +#include "paddle/fluid/memory/memcpy.h" #include "paddle/fluid/operators/detection/box_coder_op.h" #include "paddle/fluid/platform/cuda_primitives.h" @@ -95,47 +96,33 @@ __global__ void DecodeCenterSizeKernel( prior_box_data[prior_box_offset + 1] + prior_box_height / 2; T target_box_width, target_box_height; T target_box_center_x, target_box_center_y; + T box_var_x = T(1), box_var_y = T(1); + T box_var_w = T(1), box_var_h = T(1); if (prior_box_var_data) { int prior_var_offset = 0; if (prior_box_var_size == 2) { prior_var_offset = axis == 0 ? col_idx * len : row_idx * len; } - target_box_width = exp(prior_box_var_data[prior_var_offset + 2] * - target_box_data[idx * len + 2]) * - prior_box_width; - target_box_height = exp(prior_box_var_data[prior_var_offset + 3] * - target_box_data[idx * len + 3]) * - prior_box_height; - target_box_center_x = prior_box_var_data[prior_var_offset] * - target_box_data[idx * len] * prior_box_width + - prior_box_center_x; - target_box_center_y = prior_box_var_data[prior_var_offset + 1] * - target_box_data[idx * len + 1] * - prior_box_height + - prior_box_center_y; + box_var_x = prior_box_var_data[prior_var_offset]; + box_var_y = prior_box_var_data[prior_var_offset + 1]; + box_var_w = prior_box_var_data[prior_var_offset + 2]; + box_var_h = prior_box_var_data[prior_var_offset + 3]; } else if (var_size == 4) { - target_box_width = - exp(static_cast(variance[2]) * target_box_data[idx * len + 2]) * - prior_box_width; - target_box_height = - exp(static_cast(variance[3]) * target_box_data[idx * len + 3]) * - prior_box_height; - target_box_center_x = static_cast(variance[0]) * - target_box_data[idx * len] * prior_box_width + - prior_box_center_x; - target_box_center_y = static_cast(variance[1]) * - target_box_data[idx * len + 1] * - prior_box_height + - prior_box_center_y; - } else { - target_box_width = exp(target_box_data[idx * len + 2]) * prior_box_width; - target_box_height = - exp(target_box_data[idx * len + 3]) * prior_box_height; - target_box_center_x = - target_box_data[idx * len] * prior_box_width + prior_box_center_x; - target_box_center_y = target_box_data[idx * len + 1] * prior_box_height + - prior_box_center_y; + box_var_x = static_cast(variance[0]); + box_var_y = static_cast(variance[1]); + box_var_w = static_cast(variance[2]); + box_var_h = static_cast(variance[3]); } + target_box_width = + exp(box_var_w * target_box_data[idx * len + 2]) * prior_box_width; + target_box_height = + exp(box_var_h * target_box_data[idx * len + 3]) * prior_box_height; + target_box_center_x = + box_var_x * target_box_data[idx * len] * prior_box_width + + prior_box_center_x; + target_box_center_y = + box_var_y * target_box_data[idx * len + 1] * prior_box_height + + prior_box_center_y; output[idx * len] = target_box_center_x - target_box_width / 2; output[idx * len + 1] = target_box_center_y - target_box_height / 2; @@ -177,9 +164,8 @@ class BoxCoderCUDAKernel : public framework::OpKernel { PADDLE_ENFORCE_EQ(target_box->lod().size(), 1, "Only support 1 level of LoD."); } - const int var_size = static_cast(variance.size()); - thrust::device_vector dev_variance(variance.begin(), variance.end()); - const float* dev_var_data = thrust::raw_pointer_cast(dev_variance.data()); + const int var_size = static_cast(variance.size()); + auto code_type = GetBoxCodeType(context.Attr("code_type")); bool normalized = context.Attr("box_normalized"); int axis = context.Attr("axis"); @@ -194,6 +180,16 @@ class BoxCoderCUDAKernel : public framework::OpKernel { int grid = (row * col + block - 1) / block; auto& device_ctx = context.cuda_device_context(); + auto& allocator = + platform::DeviceTemporaryAllocator::Instance().Get(device_ctx); + int bytes = var_size * sizeof(float); + auto dev_var = allocator.Allocate(bytes); + float* dev_var_data = reinterpret_cast(dev_var->ptr()); + auto cplace = platform::CPUPlace(); + const auto gplace = boost::get(context.GetPlace()); + memory::Copy(gplace, dev_var_data, cplace, &variance[0], bytes, + device_ctx.stream()); + output_box->mutable_data({row, col, len}, context.GetPlace()); T* output = output_box->data(); diff --git a/paddle/fluid/operators/detection/box_coder_op.h b/paddle/fluid/operators/detection/box_coder_op.h index b61cff1b1d..a0b1faf7bd 100644 --- a/paddle/fluid/operators/detection/box_coder_op.h +++ b/paddle/fluid/operators/detection/box_coder_op.h @@ -133,6 +133,8 @@ class BoxCoderKernel : public framework::OpKernel { T target_box_center_x = 0, target_box_center_y = 0; T target_box_width = 0, target_box_height = 0; + T box_var_x = T(1), box_var_y = T(1); + T box_var_w = T(1), box_var_h = T(1); if (prior_box_var) { int prior_var_offset = 0; if (prior_box_var->dims().size() == 2) { @@ -141,44 +143,26 @@ class BoxCoderKernel : public framework::OpKernel { else if (axis == 1) prior_var_offset = i * len; } - target_box_center_x = prior_box_var_data[prior_var_offset] * - target_box_data[offset] * prior_box_width + - prior_box_center_x; - target_box_center_y = prior_box_var_data[prior_var_offset + 1] * - target_box_data[offset + 1] * - prior_box_height + - prior_box_center_y; - target_box_width = std::exp(prior_box_var_data[prior_var_offset + 2] * - target_box_data[offset + 2]) * - prior_box_width; - target_box_height = - std::exp(prior_box_var_data[prior_var_offset + 3] * - target_box_data[offset + 3]) * - prior_box_height; + box_var_x = prior_box_var_data[prior_var_offset]; + box_var_y = prior_box_var_data[prior_var_offset + 1]; + box_var_w = prior_box_var_data[prior_var_offset + 2]; + box_var_h = prior_box_var_data[prior_var_offset + 3]; } else if (!(variance.empty())) { - target_box_center_x = static_cast(variance[0]) * - target_box_data[offset] * prior_box_width + - prior_box_center_x; - target_box_center_y = static_cast(variance[1]) * - target_box_data[offset + 1] * - prior_box_height + - prior_box_center_y; - target_box_width = std::exp(static_cast(variance[2]) * - target_box_data[offset + 2]) * - prior_box_width; - target_box_height = std::exp(static_cast(variance[3]) * - target_box_data[offset + 3]) * - prior_box_height; - } else { - target_box_center_x = - target_box_data[offset] * prior_box_width + prior_box_center_x; - target_box_center_y = target_box_data[offset + 1] * prior_box_height + - prior_box_center_y; - target_box_width = - std::exp(target_box_data[offset + 2]) * prior_box_width; - target_box_height = - std::exp(target_box_data[offset + 3]) * prior_box_height; + box_var_x = static_cast(variance[0]); + box_var_y = static_cast(variance[1]); + box_var_w = static_cast(variance[2]); + box_var_h = static_cast(variance[3]); } + target_box_center_x = + box_var_x * target_box_data[offset] * prior_box_width + + prior_box_center_x; + target_box_center_y = + box_var_y * target_box_data[offset + 1] * prior_box_height + + prior_box_center_y; + target_box_width = + std::exp(box_var_w * target_box_data[offset + 2]) * prior_box_width; + target_box_height = std::exp(box_var_h * target_box_data[offset + 3]) * + prior_box_height; output[offset] = target_box_center_x - target_box_width / 2; output[offset + 1] = target_box_center_y - target_box_height / 2; diff --git a/python/paddle/fluid/tests/test_detection.py b/python/paddle/fluid/tests/test_detection.py index 2dbcfa31fc..869da58043 100644 --- a/python/paddle/fluid/tests/test_detection.py +++ b/python/paddle/fluid/tests/test_detection.py @@ -50,6 +50,19 @@ class TestDetection(unittest.TestCase): self.assertEqual(out.shape[-1], 6) print(str(program)) + def test_box_coder_api(self): + program = Program() + with program_guard(program): + x = layers.data(name='x', shape=[4], dtype='float32') + y = layers.data(name='z', shape=[4], dtype='float32', lod_level=1) + bcoder = layers.box_coder( + prior_box=x, + prior_box_var=[0.1, 0.2, 0.1, 0.2], + target_box=y, + code_type='encode_center_size') + self.assertIsNotNone(bcoder) + print(str(program)) + def test_detection_api(self): program = Program() with program_guard(program): @@ -59,7 +72,7 @@ class TestDetection(unittest.TestCase): iou = layers.iou_similarity(x=x, y=y) bcoder = layers.box_coder( prior_box=x, - prior_box_var=[0.2, 0.3, 0.3, 0.2], + prior_box_var=y, target_box=z, code_type='encode_center_size') self.assertIsNotNone(iou)