|
|
|
@ -19,10 +19,10 @@ namespace paddle {
|
|
|
|
|
namespace operators {
|
|
|
|
|
|
|
|
|
|
using Tensor = framework::Tensor;
|
|
|
|
|
using LoDTensor = framework::LoDTensor;
|
|
|
|
|
|
|
|
|
|
static constexpr int kNumCUDAThreads = 512;
|
|
|
|
|
static constexpr int kNumMaxinumNumBlocks = 4096;
|
|
|
|
|
static constexpr int kROISize = 5;
|
|
|
|
|
|
|
|
|
|
static inline int NumBlocks(const int N) {
|
|
|
|
|
return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads,
|
|
|
|
@ -30,13 +30,11 @@ static inline int NumBlocks(const int N) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
__global__ void GPUROIPoolForward(const int nthreads, const T* input_data,
|
|
|
|
|
const int64_t* input_rois,
|
|
|
|
|
const float spatial_scale, const int channels,
|
|
|
|
|
const int height, const int width,
|
|
|
|
|
const int pooled_height,
|
|
|
|
|
const int pooled_width, T* output_data,
|
|
|
|
|
int64_t* argmax_data) {
|
|
|
|
|
__global__ void GPUROIPoolForward(
|
|
|
|
|
const int nthreads, const T* input_data, const int64_t* input_rois,
|
|
|
|
|
const float spatial_scale, const int channels, const int height,
|
|
|
|
|
const int width, const int pooled_height, const int pooled_width,
|
|
|
|
|
int* roi_batch_id_data, T* output_data, int64_t* argmax_data) {
|
|
|
|
|
int index = blockIdx.x * blockDim.x + threadIdx.x;
|
|
|
|
|
int offset = blockDim.x * gridDim.x;
|
|
|
|
|
for (size_t i = index; i < nthreads; i += offset) {
|
|
|
|
@ -46,11 +44,11 @@ __global__ void GPUROIPoolForward(const int nthreads, const T* input_data,
|
|
|
|
|
int n = index / pooled_width / pooled_height / channels;
|
|
|
|
|
|
|
|
|
|
const int64_t* offset_input_rois = input_rois + n * kROISize;
|
|
|
|
|
int roi_batch_ind = offset_input_rois[0];
|
|
|
|
|
int roi_start_w = round(offset_input_rois[1] * spatial_scale);
|
|
|
|
|
int roi_start_h = round(offset_input_rois[2] * spatial_scale);
|
|
|
|
|
int roi_end_w = round(offset_input_rois[3] * spatial_scale);
|
|
|
|
|
int roi_end_h = round(offset_input_rois[4] * spatial_scale);
|
|
|
|
|
int roi_batch_ind = roi_batch_id_data[n];
|
|
|
|
|
int roi_start_w = round(offset_input_rois[0] * spatial_scale);
|
|
|
|
|
int roi_start_h = round(offset_input_rois[1] * spatial_scale);
|
|
|
|
|
int roi_end_w = round(offset_input_rois[2] * spatial_scale);
|
|
|
|
|
int roi_end_h = round(offset_input_rois[3] * spatial_scale);
|
|
|
|
|
|
|
|
|
|
int roi_width = max(roi_end_w - roi_start_w + 1, 1);
|
|
|
|
|
int roi_height = max(roi_end_h - roi_start_h + 1, 1);
|
|
|
|
@ -93,7 +91,8 @@ __global__ void GPUROIPoolBackward(
|
|
|
|
|
const int nthreads, const int64_t* input_rois, const T* output_grad,
|
|
|
|
|
const int64_t* argmax_data, const int num_rois, const float spatial_scale,
|
|
|
|
|
const int channels, const int height, const int width,
|
|
|
|
|
const int pooled_height, const int pooled_width, T* input_grad) {
|
|
|
|
|
const int pooled_height, const int pooled_width, int* roi_batch_id_data,
|
|
|
|
|
T* input_grad) {
|
|
|
|
|
int index = blockIdx.x * blockDim.x + threadIdx.x;
|
|
|
|
|
int offset = blockDim.x * gridDim.x;
|
|
|
|
|
for (int i = index; i < nthreads; i += offset) {
|
|
|
|
@ -102,8 +101,7 @@ __global__ void GPUROIPoolBackward(
|
|
|
|
|
int c = (index / pooled_width / pooled_height) % channels;
|
|
|
|
|
int n = index / pooled_width / pooled_height / channels;
|
|
|
|
|
|
|
|
|
|
const int64_t* offset_input_rois = input_rois + n * kROISize;
|
|
|
|
|
int roi_batch_ind = offset_input_rois[0];
|
|
|
|
|
int roi_batch_ind = roi_batch_id_data[n];
|
|
|
|
|
int input_offset = (roi_batch_ind * channels + c) * height * width;
|
|
|
|
|
int output_offset = (n * channels + c) * pooled_height * pooled_width;
|
|
|
|
|
const T* offset_output_grad = output_grad + output_offset;
|
|
|
|
@ -124,7 +122,7 @@ class GPUROIPoolOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
public:
|
|
|
|
|
void Compute(const framework::ExecutionContext& ctx) const override {
|
|
|
|
|
auto* in = ctx.Input<Tensor>("X");
|
|
|
|
|
auto* rois = ctx.Input<Tensor>("ROIs");
|
|
|
|
|
auto* rois = ctx.Input<LoDTensor>("ROIs");
|
|
|
|
|
auto* out = ctx.Output<Tensor>("Out");
|
|
|
|
|
auto* argmax = ctx.Output<Tensor>("Argmax");
|
|
|
|
|
|
|
|
|
@ -133,23 +131,46 @@ class GPUROIPoolOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
auto spatial_scale = ctx.Attr<float>("spatial_scale");
|
|
|
|
|
|
|
|
|
|
auto in_dims = in->dims();
|
|
|
|
|
int batch_size = in_dims[0];
|
|
|
|
|
auto in_stride = framework::stride(in_dims);
|
|
|
|
|
int channels = in_dims[1];
|
|
|
|
|
int height = in_dims[2];
|
|
|
|
|
int width = in_dims[3];
|
|
|
|
|
|
|
|
|
|
size_t rois_num = rois->dims()[0];
|
|
|
|
|
int rois_num = rois->dims()[0];
|
|
|
|
|
if (rois_num == 0) return;
|
|
|
|
|
|
|
|
|
|
int output_size = out->numel();
|
|
|
|
|
int blocks = NumBlocks(output_size);
|
|
|
|
|
int threads = kNumCUDAThreads;
|
|
|
|
|
|
|
|
|
|
framework::Tensor roi_batch_id_list;
|
|
|
|
|
roi_batch_id_list.Resize({rois_num});
|
|
|
|
|
int* roi_batch_id_data =
|
|
|
|
|
roi_batch_id_list.mutable_data<int>(platform::CPUPlace());
|
|
|
|
|
auto rois_lod = rois->lod().back();
|
|
|
|
|
int rois_batch_size = rois_lod.size() - 1;
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
rois_batch_size, batch_size,
|
|
|
|
|
"The rois_batch_size and imgs batch_size must be the same.");
|
|
|
|
|
int rois_num_with_lod = rois_lod[rois_batch_size];
|
|
|
|
|
PADDLE_ENFORCE_EQ(rois_num, rois_num_with_lod,
|
|
|
|
|
"The rois_num from input and lod must be the same.");
|
|
|
|
|
for (int n = 0; n < rois_batch_size; ++n) {
|
|
|
|
|
for (size_t i = rois_lod[n]; i < rois_lod[n + 1]; ++i) {
|
|
|
|
|
roi_batch_id_data[i] = n;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
framework::Tensor roi_batch_id_list_gpu;
|
|
|
|
|
framework::TensorCopy(roi_batch_id_list, ctx.GetPlace(),
|
|
|
|
|
ctx.device_context(), &roi_batch_id_list_gpu);
|
|
|
|
|
|
|
|
|
|
GPUROIPoolForward<
|
|
|
|
|
T><<<blocks, threads, 0, ctx.cuda_device_context().stream()>>>(
|
|
|
|
|
output_size, in->data<T>(), rois->data<int64_t>(), spatial_scale,
|
|
|
|
|
channels, height, width, pooled_height, pooled_width,
|
|
|
|
|
out->mutable_data<T>(ctx.GetPlace()),
|
|
|
|
|
roi_batch_id_list_gpu.data<int>(), out->mutable_data<T>(ctx.GetPlace()),
|
|
|
|
|
argmax->mutable_data<int64_t>(ctx.GetPlace()));
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
@ -159,7 +180,7 @@ class GPUROIPoolGradOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
public:
|
|
|
|
|
void Compute(const framework::ExecutionContext& ctx) const override {
|
|
|
|
|
auto* in = ctx.Input<Tensor>("X");
|
|
|
|
|
auto* rois = ctx.Input<Tensor>("ROIs");
|
|
|
|
|
auto* rois = ctx.Input<LoDTensor>("ROIs");
|
|
|
|
|
auto* argmax = ctx.Input<Tensor>("Argmax");
|
|
|
|
|
|
|
|
|
|
auto* out_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
|
|
|
|
@ -169,12 +190,27 @@ class GPUROIPoolGradOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
auto pooled_width = ctx.Attr<int>("pooled_width");
|
|
|
|
|
auto spatial_scale = ctx.Attr<float>("spatial_scale");
|
|
|
|
|
|
|
|
|
|
size_t rois_num = rois->dims()[0];
|
|
|
|
|
int rois_num = rois->dims()[0];
|
|
|
|
|
int channels = in->dims()[1];
|
|
|
|
|
int height = in->dims()[2];
|
|
|
|
|
int width = in->dims()[3];
|
|
|
|
|
|
|
|
|
|
if (x_grad) {
|
|
|
|
|
framework::Tensor roi_batch_id_list;
|
|
|
|
|
roi_batch_id_list.Resize({rois_num});
|
|
|
|
|
int* roi_batch_id_data =
|
|
|
|
|
roi_batch_id_list.mutable_data<int>(platform::CPUPlace());
|
|
|
|
|
auto rois_lod = rois->lod().back();
|
|
|
|
|
int rois_batch_size = rois_lod.size() - 1;
|
|
|
|
|
for (int n = 0; n < rois_batch_size; ++n) {
|
|
|
|
|
for (size_t i = rois_lod[n]; i < rois_lod[n + 1]; ++i) {
|
|
|
|
|
roi_batch_id_data[i] = n;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
framework::Tensor roi_batch_id_list_gpu;
|
|
|
|
|
framework::TensorCopy(roi_batch_id_list, ctx.GetPlace(),
|
|
|
|
|
ctx.device_context(), &roi_batch_id_list_gpu);
|
|
|
|
|
|
|
|
|
|
x_grad->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
math::SetConstant<Place, T> set_zero;
|
|
|
|
|
set_zero(ctx.cuda_device_context(), x_grad, static_cast<T>(0));
|
|
|
|
@ -189,6 +225,7 @@ class GPUROIPoolGradOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
output_grad_size, rois->data<int64_t>(), out_grad->data<T>(),
|
|
|
|
|
argmax->data<int64_t>(), rois_num, spatial_scale, channels, height,
|
|
|
|
|
width, pooled_height, pooled_width,
|
|
|
|
|
roi_batch_id_list_gpu.data<int>(),
|
|
|
|
|
x_grad->mutable_data<T>(ctx.GetPlace()));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|