@@ -40,6 +40,11 @@ DEVICE void PrRoIPoolingDistributeDiffCUDA(T* diff, const T top_diff,
   }
 }
 
+template <typename T>
+DEVICE void GPUAccumulateRois(T* offset, T data) {
+  paddle::platform::CudaAtomicAdd(offset, data);
+}
+
 template <typename T>
 __global__ void GPUPRROIPoolForward(
     const int nthreads, const T* input_data, const T* input_rois,
@@ -78,7 +83,7 @@ __global__ void GPUPRROIPoolForward(
     T win_end_h = win_start_h + bin_size_h;
 
     T win_size = max(static_cast<T>(0.0), bin_size_w * bin_size_h);
-    int input_channel = (c * pooled_height + ph) * pooled_width + pw;
+    int input_channel = c;
     const T* offset_input_data =
         input_data +
         (roi_batch_id * input_channels + input_channel) * height * width;
@@ -110,10 +115,12 @@ __global__ void GPUPRROIPoolForward(
 
 template <typename T>
 __global__ void GPUPRROIPoolBackward(
-    const int nthreads, const T* input_rois, const T* output_grad_data,
-    const float spatial_scale, const int input_channels, const int height,
-    const int width, const int output_channels, const int pooled_height,
-    const int pooled_width, const int* rois_batch_id_data, T* input_grad_data) {
+    const int nthreads, const T* in_data, const T* input_rois,
+    const T* output_grad_data, const float spatial_scale,
+    const int input_channels, const int height, const int width,
+    const int output_channels, const int pooled_height, const int pooled_width,
+    const int* rois_batch_id_data, T* input_grad_data, const T* out_data,
+    T* input_roi_grad_data) {
   int index = blockIdx.x * blockDim.x + threadIdx.x;
   int offset = blockDim.x * gridDim.x;
   for (int i = index; i < nthreads; i += offset) {
@@ -125,7 +132,7 @@ __global__ void GPUPRROIPoolBackward(
 
     // set roi_batch_id
     int roi_batch_id = rois_batch_id_data[n];
-    int input_channel = (c * pooled_height + ph) * pooled_width + pw;
+    int input_channel = c;
     int input_offset =
         (roi_batch_id * input_channels + input_channel) * height * width;
     T* offset_input_grad_data = input_grad_data + input_offset;
@@ -137,6 +144,7 @@ __global__ void GPUPRROIPoolBackward(
     T roi_start_h = static_cast<T>(offset_input_rois[1]) * spatial_scale;
     T roi_end_w = static_cast<T>(offset_input_rois[2]) * spatial_scale;
     T roi_end_h = static_cast<T>(offset_input_rois[3]) * spatial_scale;
+    T* offset_input_roi_grad_data = input_roi_grad_data + n * 4;
 
     T roi_width = max(roi_end_w - roi_start_w, static_cast<T>(0.0));
     T roi_height = max(roi_end_h - roi_start_h, static_cast<T>(0.0));
@@ -171,6 +179,16 @@ __global__ void GPUPRROIPoolBackward(
             height, width, PrRoIPoolingDistributeDiffCUDA<T>);
       }
     }
+
+    const T* offset_out_data = out_data + i;
+    const T* offset_in_data = in_data + input_offset;
+    PrRoIPoolingCoorBackward(
+        s_w, e_w, s_h, e_h, width, height, win_start_w, win_start_h, win_end_w,
+        win_end_h, pw, ph, pooled_width, pooled_height, win_size, spatial_scale,
+        offset_in_data, offset_out_data, offset_input_grad_data,
+        offset_input_roi_grad_data, GPUAccumulateRois<T>,
+        [](const T x, const T y) { return max(x, y); },
+        [](const T x, const T y) { return min(x, y); });
   }
 }
 
@@ -184,20 +202,15 @@ class GPUPRROIPoolOpKernel : public framework::OpKernel<T> {
 
     auto pooled_height = ctx.Attr<int>("pooled_height");
     auto pooled_width = ctx.Attr<int>("pooled_width");
-    auto output_channels = ctx.Attr<int>("output_channels");
     auto spatial_scale = ctx.Attr<float>("spatial_scale");
 
     auto in_dims = in->dims();
     int batch_size = in_dims[0];
     int input_channels = in_dims[1];
+    auto output_channels = input_channels;
     int height = in_dims[2];
     int width = in_dims[3];
 
-    PADDLE_ENFORCE_EQ(input_channels,
-                      output_channels * pooled_height * pooled_width,
-                      "the channels of input X should equal the product of "
-                      "output_channels x pooled_height x pooled_width");
-
     int rois_num = rois->dims()[0];
     if (rois_num == 0) return;
 
@@ -245,17 +258,20 @@ class GPUPRROIPoolGradOpKernel : public framework::OpKernel<T> {
   void Compute(const framework::ExecutionContext& ctx) const override {
     auto* in = ctx.Input<Tensor>("X");
     auto* rois = ctx.Input<LoDTensor>("ROIs");
+    auto* out = ctx.Input<framework::Tensor>("Out");
 
     auto* output_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
     auto* input_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
+    auto* input_roi_grad =
+        ctx.Output<LoDTensor>(framework::GradVarName("ROIs"));
 
     auto pooled_height = ctx.Attr<int>("pooled_height");
     auto pooled_width = ctx.Attr<int>("pooled_width");
-    auto output_channels = ctx.Attr<int>("output_channels");
     auto spatial_scale = ctx.Attr<float>("spatial_scale");
 
     int rois_num = rois->dims()[0];
     int input_channels = in->dims()[1];
+    auto output_channels = input_channels;
     int height = in->dims()[2];
     int width = in->dims()[3];
 
@@ -280,6 +296,8 @@ class GPUPRROIPoolGradOpKernel : public framework::OpKernel<T> {
       input_grad->mutable_data<T>(ctx.GetPlace());
       math::SetConstant<DeviceContext, T> set_zero;
       set_zero(ctx.cuda_device_context(), input_grad, static_cast<T>(0));
+      input_roi_grad->mutable_data<T>(ctx.GetPlace());
+      set_zero(ctx.cuda_device_context(), input_roi_grad, static_cast<T>(0));
 
       int output_grad_size = output_grad->numel();
       int blocks = NumBlocks(output_grad_size);
@@ -288,10 +306,12 @@ class GPUPRROIPoolGradOpKernel : public framework::OpKernel<T> {
       if (output_grad_size > 0) {
         GPUPRROIPoolBackward<
             T><<<blocks, threads, 0, ctx.cuda_device_context().stream()>>>(
-            output_grad_size, rois->data<T>(), output_grad->data<T>(),
-            spatial_scale, input_channels, height, width, output_channels,
-            pooled_height, pooled_width, rois_batch_id_list_gpu.data<int>(),
-            input_grad->mutable_data<T>(ctx.GetPlace()));
+            output_grad_size, in->data<T>(), rois->data<T>(),
+            output_grad->data<T>(), spatial_scale, input_channels, height,
+            width, output_channels, pooled_height, pooled_width,
+            rois_batch_id_list_gpu.data<int>(),
+            input_grad->mutable_data<T>(ctx.GetPlace()), out->data<T>(),
+            input_roi_grad->mutable_data<T>(ctx.GetPlace()));
       }
     }
   }