|
|
@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
|
|
|
|
limitations under the License. */
|
|
|
|
limitations under the License. */
|
|
|
|
|
|
|
|
|
|
|
|
#include "paddle/fluid/operators/detection/yolo_box_op.h"
|
|
|
|
#include "paddle/fluid/operators/detection/yolo_box_op.h"
|
|
|
|
|
|
|
|
#include "paddle/fluid/operators/math/math_function.h"
|
|
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
namespace paddle {
|
|
|
|
namespace operators {
|
|
|
|
namespace operators {
|
|
|
@ -22,11 +23,12 @@ using Tensor = framework::Tensor;
|
|
|
|
template <typename T>
|
|
|
|
template <typename T>
|
|
|
|
__global__ void KeYoloBoxFw(const T* input, const int* imgsize, T* boxes,
|
|
|
|
__global__ void KeYoloBoxFw(const T* input, const int* imgsize, T* boxes,
|
|
|
|
T* scores, const float conf_thresh,
|
|
|
|
T* scores, const float conf_thresh,
|
|
|
|
std::vector<int> anchors, const int h, const in w,
|
|
|
|
const int* anchors, const int h, const int w,
|
|
|
|
const int an_num, const int class_num,
|
|
|
|
const int an_num, const int class_num,
|
|
|
|
const int box_num, const int input_size) {
|
|
|
|
const int box_num, int input_size) {
|
|
|
|
int tid = blockIdx.x * blockDim.x + threadIdx.x;
|
|
|
|
int tid = blockIdx.x * blockDim.x + threadIdx.x;
|
|
|
|
int stride = blockDim.x * gridDim.x;
|
|
|
|
int stride = blockDim.x * gridDim.x;
|
|
|
|
|
|
|
|
T box[4];
|
|
|
|
for (; tid < box_num; tid += stride) {
|
|
|
|
for (; tid < box_num; tid += stride) {
|
|
|
|
int grid_num = h * w;
|
|
|
|
int grid_num = h * w;
|
|
|
|
int i = tid / box_num;
|
|
|
|
int i = tid / box_num;
|
|
|
@ -47,10 +49,10 @@ __global__ void KeYoloBoxFw(const T* input, const int* imgsize, T* boxes,
|
|
|
|
|
|
|
|
|
|
|
|
int box_idx =
|
|
|
|
int box_idx =
|
|
|
|
GetEntryIndex(i, j, k * w + l, an_num, an_stride, grid_num, 0);
|
|
|
|
GetEntryIndex(i, j, k * w + l, an_num, an_stride, grid_num, 0);
|
|
|
|
Box<T> pred = GetYoloBox<T>(input, anchors, l, k, j, h, input_size, box_idx,
|
|
|
|
GetYoloBox<T>(box, input, anchors, l, k, j, h, input_size, box_idx,
|
|
|
|
grid_num, img_height, img_width);
|
|
|
|
grid_num, img_height, img_width);
|
|
|
|
box_idx = (i * box_num + j * grid_num + k * w + l) * 4;
|
|
|
|
box_idx = (i * box_num + j * grid_num + k * w + l) * 4;
|
|
|
|
CalcDetectionBox<T>(boxes, pred, box_idx);
|
|
|
|
CalcDetectionBox<T>(boxes, box, box_idx);
|
|
|
|
|
|
|
|
|
|
|
|
int label_idx =
|
|
|
|
int label_idx =
|
|
|
|
GetEntryIndex(i, j, k * w + l, an_num, an_stride, grid_num, 5);
|
|
|
|
GetEntryIndex(i, j, k * w + l, an_num, an_stride, grid_num, 5);
|
|
|
@ -64,7 +66,7 @@ template <typename T>
|
|
|
|
class YoloBoxOpCUDAKernel : public framework::OpKernel<T> {
|
|
|
|
class YoloBoxOpCUDAKernel : public framework::OpKernel<T> {
|
|
|
|
public:
|
|
|
|
public:
|
|
|
|
void Compute(const framework::ExecutionContext& ctx) const override {
|
|
|
|
void Compute(const framework::ExecutionContext& ctx) const override {
|
|
|
|
auto* input = ctx.Input<Tensor>("Input");
|
|
|
|
auto* input = ctx.Input<Tensor>("X");
|
|
|
|
auto* img_size = ctx.Input<Tensor>("ImgSize");
|
|
|
|
auto* img_size = ctx.Input<Tensor>("ImgSize");
|
|
|
|
auto* boxes = ctx.Output<Tensor>("Boxes");
|
|
|
|
auto* boxes = ctx.Output<Tensor>("Boxes");
|
|
|
|
auto* scores = ctx.Output<Tensor>("Scores");
|
|
|
|
auto* scores = ctx.Output<Tensor>("Scores");
|
|
|
@ -81,23 +83,35 @@ class YoloBoxOpCUDAKernel : public framework::OpKernel<T> {
|
|
|
|
const int an_num = anchors.size() / 2;
|
|
|
|
const int an_num = anchors.size() / 2;
|
|
|
|
int input_size = downsample_ratio * h;
|
|
|
|
int input_size = downsample_ratio * h;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Tensor anchors_t, cpu_anchors_t;
|
|
|
|
|
|
|
|
auto cpu_anchors_data = cpu_anchors_t.mutable_data<int>({an_num*2}, platform::CPUPlace());
|
|
|
|
|
|
|
|
std::copy(anchors.begin(), anchors.end(), cpu_anchors_data);
|
|
|
|
|
|
|
|
TensorCopySync(cpu_anchors_t, ctx.GetPlace(), &anchors_t);
|
|
|
|
|
|
|
|
auto anchors_data = anchors_t.data<int>();
|
|
|
|
|
|
|
|
|
|
|
|
const T* input_data = input->data<T>();
|
|
|
|
const T* input_data = input->data<T>();
|
|
|
|
const int* imgsize_data = imgsize->data<int>();
|
|
|
|
const int* imgsize_data = img_size->data<int>();
|
|
|
|
T* boxes_data = boxes->mutable_data<T>({n, box_num, 4}, ctx.GetPlace());
|
|
|
|
T* boxes_data = boxes->mutable_data<T>({n, box_num, 4}, ctx.GetPlace());
|
|
|
|
memset(boxes_data, 0, boxes->numel() * sizeof(T));
|
|
|
|
|
|
|
|
T* scores_data =
|
|
|
|
T* scores_data =
|
|
|
|
scores->mutable_data<T>({n, box_num, class_num}, ctx.GetPlace());
|
|
|
|
scores->mutable_data<T>({n, box_num, class_num}, ctx.GetPlace());
|
|
|
|
memset(scores_data, 0, scores->numel() * sizeof(T));
|
|
|
|
math::SetConstant<platform::CUDADeviceContext, T> set_zero;
|
|
|
|
|
|
|
|
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
|
|
|
|
|
|
|
|
set_zero(dev_ctx, boxes, static_cast<T>(0));
|
|
|
|
|
|
|
|
set_zero(dev_ctx, scores, static_cast<T>(0));
|
|
|
|
|
|
|
|
|
|
|
|
int grid_dim = (n * box_num + 512 - 1) / 512;
|
|
|
|
int grid_dim = (n * box_num + 512 - 1) / 512;
|
|
|
|
grid_dim = grid_dim > 8 ? 8 : grid_dim;
|
|
|
|
grid_dim = grid_dim > 8 ? 8 : grid_dim;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
KeYoloBoxFw<T><<<grid_dim, 512, 0, ctx.cuda_device_context().stream()>>>(
|
|
|
|
|
|
|
|
input_data, imgsize_data, boxes_data, scores_data, conf_thresh,
|
|
|
|
|
|
|
|
anchors_data, h, w, an_num, class_num, box_num, input_size);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}; // namespace operators
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
} // namespace operators
|
|
|
|
} // namespace operators
|
|
|
|
} // namespace paddle
|
|
|
|
} // namespace paddle
|
|
|
|
|
|
|
|
|
|
|
|
namespace ops = paddle::operators;
|
|
|
|
namespace ops = paddle::operators;
|
|
|
|
REGISTER_OP_CUDA_KERNEL(density_prior_box,
|
|
|
|
REGISTER_OP_CUDA_KERNEL(yolo_box,
|
|
|
|
ops::DensityPriorBoxOpCUDAKernel<float>,
|
|
|
|
ops::YoloBoxOpCUDAKernel<float>,
|
|
|
|
ops::DensityPriorBoxOpCUDAKernel<double>);
|
|
|
|
ops::YoloBoxOpCUDAKernel<double>);
|
|
|
|