@@ -14,6 +14,7 @@ limitations under the License. */
 
 #include <algorithm>
 #include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/math/math_function.h"
 #include "paddle/fluid/platform/cuda_primitives.h"
 #include "paddle/fluid/platform/float16.h"
 
@@ -115,8 +116,9 @@ __device__ bool in_quad(T x, T y, T roi_x[], T roi_y[]) {
 template <typename T>
 __device__ void bilinear_interpolate(const T* in_data, const int channels,
                                      const int width, const int height,
-                                     int in_n, int in_c, T in_w, T in_h,
-                                     T* val) {
+                                     int in_n, int in_c, T in_w, T in_h, T* val,
+                                     int out_idx, int* out2in_idx,
+                                     T* out2in_w) {
   // Deal with cases that source coords are out of feature map boundary
   if (GT<T>(-0.5, in_w) || GT<T>(in_w, width - 0.5) || GT<T>(-0.5, in_h) ||
       GT<T>(in_h, height - 0.5)) {
@@ -165,6 +167,16 @@ __device__ void bilinear_interpolate(const T* in_data, const int channels,
   T w3 = w_floor * h_floor;
   T w4 = w_floor * h_ceil;
   val[0] = w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4;
+
+  int base_idx = (in_n * channels + in_c) * height * width;
+  out2in_idx[out_idx * 4] = base_idx + in_h_floor * width + in_w_floor;
+  out2in_idx[out_idx * 4 + 1] = base_idx + in_h_ceil * width + in_w_floor;
+  out2in_idx[out_idx * 4 + 2] = base_idx + in_h_ceil * width + in_w_ceil;
+  out2in_idx[out_idx * 4 + 3] = base_idx + in_h_floor * width + in_w_ceil;
+  out2in_w[out_idx * 4] = w1;
+  out2in_w[out_idx * 4 + 1] = w2;
+  out2in_w[out_idx * 4 + 2] = w3;
+  out2in_w[out_idx * 4 + 3] = w4;
 }
 
 /**
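The two hunks above extend `bilinear_interpolate` so that, besides writing the interpolated value, it records for every output element the four flattened input offsets and the four bilinear weights it used. A minimal CPU sketch of how such a cache can be read back; the `reconstruct_output` helper is illustrative only (not part of the operator) and assumes the index buffer uses -1 for entries that were never written, as the forward `Compute` later in the diff initializes it:

```cpp
// Illustration only: with the cache filled by bilinear_interpolate, every
// output element is a weighted sum of at most four input elements:
//   Out[i] = sum_k out2in_w[4 * i + k] * In[out2in_idx[4 * i + k]]
#include <vector>

std::vector<float> reconstruct_output(const std::vector<float>& input,
                                      const std::vector<int>& out2in_idx,
                                      const std::vector<float>& out2in_w,
                                      int out_size) {
  std::vector<float> out(out_size, 0.f);
  for (int i = 0; i < out_size; ++i) {
    for (int k = 0; k < 4; ++k) {
      int in_idx = out2in_idx[4 * i + k];
      if (in_idx >= 0) {  // -1 marks "no contribution recorded"
        out[i] += out2in_w[4 * i + k] * input[in_idx];
      }
    }
  }
  return out;
}
```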
@@ -262,13 +274,11 @@ __device__ void get_transform_matrix(const int transformed_width,
 }
 
 template <typename T>
-__global__ void RoiTransformKernel(const float* input_data,
-                                   const float* rois_data,
-                                   const int* roi2image_data, int num_rois,
-                                   int in_height, int in_width, int channels,
-                                   int transformed_height,
-                                   int transformed_width, float spatial_scale,
-                                   T* output_data) {
+__global__ void RoiTransformKernel(
+    const float* input_data, const float* rois_data, const int* roi2image_data,
+    int num_rois, int in_height, int in_width, int channels,
+    int transformed_height, int transformed_width, float spatial_scale,
+    T* output_data, int* out2in_idx, T* out2in_w) {
   int output_size =
       num_rois * transformed_height * transformed_width * channels;
 
@@ -311,7 +321,8 @@ __global__ void RoiTransformKernel(const float* input_data,
         // Perform bilinear interpolation
         int in_n = roi2image_data[n];
         bilinear_interpolate<T>(input_data, channels, in_width, in_height, in_n,
-                                c, in_w, in_h, output_data + index);
+                                c, in_w, in_h, output_data + index, index,
+                                out2in_idx, out2in_w);
       }
 
     } else {
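In the two hunks above, the forward kernel receives the two cache buffers and hands its flattened output offset `index` to `bilinear_interpolate` as `out_idx`. Because each thread owns exactly one output element, the 4-entry slices it writes into `out2in_idx` and `out2in_w` are disjoint across threads, so the forward pass needs no atomics. A standalone sketch of that ownership pattern, spelling out the grid-stride loop behind Paddle's `CUDA_1D_KERNEL_LOOP` macro; the `CacheInitDemo` kernel is illustrative only:

```cuda
// Illustration only: one thread per output element, each writing its own
// 4-entry slice of the cache. This mirrors how the forward kernel above uses
// `index` as `out_idx` when calling bilinear_interpolate.
__global__ void CacheInitDemo(int out_size, int* out2in_idx, float* out2in_w) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < out_size;
       i += blockDim.x * gridDim.x) {
    for (int k = 0; k < 4; ++k) {
      out2in_idx[i * 4 + k] = -1;  // sentinel: nothing recorded yet
      out2in_w[i * 4 + k] = 0.f;
    }
  }
}
```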
@@ -328,6 +339,16 @@ class CUDAROIPerspectiveTransformOpKernel : public framework::OpKernel<T> {
     auto* in = ctx.Input<framework::Tensor>("X");
     auto* rois = ctx.Input<framework::LoDTensor>("ROIs");
     auto* out = ctx.Output<framework::Tensor>("Out");
+    auto* out2in_idx = ctx.Output<framework::Tensor>("Out2InIdx");
+    auto* out2in_w = ctx.Output<framework::Tensor>("Out2InWeights");
+
+    int* out2in_idx_data =
+        out2in_idx->mutable_data<int>({out->numel(), 4}, ctx.GetPlace());
+    T* out2in_w_data =
+        out2in_w->mutable_data<T>({out->numel(), 4}, ctx.GetPlace());
+
+    math::SetConstant<platform::CUDADeviceContext, int> init;
+    init(ctx.cuda_device_context(), out2in_idx, static_cast<int>(-1));
 
     auto transformed_height = ctx.Attr<int>("transformed_height");
     auto transformed_width = ctx.Attr<int>("transformed_width");
@@ -364,7 +385,7 @@ class CUDAROIPerspectiveTransformOpKernel : public framework::OpKernel<T> {
     RoiTransformKernel<T><<<grid, block, 0, stream>>>(
         input_data, rois_data, roi2image_dev.data<int>(), rois_num, in_height,
         in_width, channels, transformed_height, transformed_width,
-        spatial_scale, output_data);
+        spatial_scale, output_data, out2in_idx_data, out2in_w_data);
   }
 };
 
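The forward `Compute` above allocates the two cache tensors with shape `{out->numel(), 4}`, presets the index buffer to -1 so that entries the kernel never touches (outputs outside the RoI quad or outside the feature map) remain recognizable, and passes the raw pointers to the kernel. The launch configuration here and in the new backward below is plain ceiling division of the element count by the block size; a small worked example with arbitrarily chosen sizes (`ceil_div` is a hypothetical helper, not part of the diff):

```cpp
// Illustration only: ceil-division launch sizing as used by both kernels.
#include <cstdio>

int ceil_div(int n, int block) { return (n + block - 1) / block; }

int main() {
  int block = 512;
  int out_size = 2 * 3 * 7 * 7;  // e.g. 2 RoIs, 3 channels, 7x7 outputs = 294
  std::printf("forward grid  = %d\n", ceil_div(out_size, block));      // 1
  std::printf("backward grid = %d\n", ceil_div(out_size * 4, block));  // 3
  return 0;
}
```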
@@ -420,100 +441,42 @@ __device__ T get_feature_gradient(T xs, T ys, int w, int h, const int width,
 }
 
 template <typename T>
-__global__ void RoiTransformGradKernel(
-    const size_t* lod, const T* rois_data, int batch_size, int num_rois,
-    int in_height, int in_width, int channels, int transformed_height,
-    int transformed_width, float spatial_scale, const T* out_grad_data,
-    T* in_grad_data) {
-  int input_size = batch_size * in_height * in_width * channels;
-
-  CUDA_1D_KERNEL_LOOP(index, input_size) {
-    // (n, c, h, w) coords in input
-    int in_w = idx4_4(index, batch_size, channels, in_height, in_width);
-    int in_h = idx4_3(index, batch_size, channels, in_height, in_width);
-    int c = idx4_2(index, batch_size, channels, in_height, in_width);
-    int n = idx4_1(index, batch_size, channels, in_height, in_width);
-
-    T gradient = 0.0;
-    // Accumulate gradient over all RoIs that interpolated this element
-    for (size_t roi_idx = lod[n]; roi_idx < lod[n + 1]; ++roi_idx) {
-      const T* rois = rois_data + roi_idx * 8;
-      T roi_x[4];
-      T roi_y[4];
-      for (int k = 0; k < 4; ++k) {
-        roi_x[k] = rois[2 * k] * spatial_scale;
-        roi_y[k] = rois[2 * k + 1] * spatial_scale;
-      }
-
-      // Get transform matrix
-      T matrix[9];
-      get_transform_matrix<T>(transformed_width, transformed_height, roi_x,
-                              roi_y, matrix);
-
-      const T* out_grad_ptr =
-          out_grad_data +
-          (roi_idx * channels + c) * transformed_height * transformed_width;
-      for (int out_h = 0; out_h < transformed_height; ++out_h) {
-        for (int out_w = 0; out_w < transformed_width; ++out_w) {
-          T src_w;
-          T src_h;
-          get_source_coords<T>(matrix, out_w, out_h, &src_w, &src_h);
-          if (in_quad<T>(src_w, src_h, roi_x, roi_y)) {
-            if (GT<T>(-0.5, src_w) ||
-                GT<T>(src_w, static_cast<T>(in_width - 0.5)) ||
-                GT<T>(-0.5, src_h) ||
-                GT<T>(src_h, static_cast<T>(in_height - 0.5))) {
-              continue;
-            }
-            T weight = get_feature_gradient<T>(src_w, src_h, in_w, in_h,
                                                in_width, in_height);
-            gradient +=
-                out_grad_ptr[out_h * transformed_width + out_w] * weight;
-          }
-        }
-      }
-    }
-    in_grad_data[index] = gradient;
+__global__ void RoiTransformGradKernel(int out_size, const int* out2in_idx_data,
+                                       const T* out2in_w_data,
+                                       const T* out_grad_data,
+                                       T* in_grad_data) {
+  CUDA_1D_KERNEL_LOOP(index, out_size * 4) {
+    int in_idx = out2in_idx_data[index];
+    if (in_idx >= 0) {
+      int out_idx = index / 4;
+      atomicAdd(in_grad_data + in_idx,
+                out_grad_data[out_idx] * out2in_w_data[index]);
+    }
   }
 }
 
 template <typename T>
 class CUDAROIPerspectiveTransformGradOpKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
-    auto* in = ctx.Input<framework::Tensor>("X");
-    auto* rois = ctx.Input<framework::LoDTensor>("ROIs");
+    auto* out2in_idx = ctx.Input<framework::LoDTensor>("Out2InIdx");
+    auto* out2in_w = ctx.Input<framework::LoDTensor>("Out2InWeights");
     auto* out_grad =
         ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
     auto* in_grad = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
 
-    auto transformed_height = ctx.Attr<int>("transformed_height");
-    auto transformed_width = ctx.Attr<int>("transformed_width");
-    auto spatial_scale = ctx.Attr<float>("spatial_scale");
-
-    auto in_dims = in->dims();
-    int batch_size = in_dims[0];
-    int channels = in_dims[1];
-    int in_height = in_dims[2];
-    int in_width = in_dims[3];
-    int rois_num = rois->dims()[0];
-
     T* in_grad_data = in_grad->mutable_data<T>(ctx.GetPlace());
     const T* out_grad_data = out_grad->data<T>();
-    const T* rois_data = rois->data<T>();
-
-    auto lod = rois->lod().back();
-    auto lod_data = lod.CUDAData(ctx.GetPlace());
+    const int* out2in_idx_data = out2in_idx->data<int>();
+    const T* out2in_w_data = out2in_w->data<T>();
 
-    int in_size = in->numel();
+    int out_size = out_grad->numel();
     auto stream = ctx.cuda_device_context().stream();
     int block = 512;
-    int grid = (in_size + block - 1) / block;
+    int grid = (out_size * 4 + block - 1) / block;
 
     RoiTransformGradKernel<T><<<grid, block, 0, stream>>>(
-        lod_data, rois_data, batch_size, rois_num, in_height, in_width,
-        channels, transformed_height, transformed_width, spatial_scale,
-        out_grad_data, in_grad_data);
+        out_size, out2in_idx_data, out2in_w_data, out_grad_data, in_grad_data);
   }
 };
 
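The last hunk replaces the gather-style backward, which re-derived the perspective transform and looped over every transformed position of every RoI for each input element, with a scatter over the cached entries: one thread per `(output element, neighbor)` pair accumulates `out_grad * weight` into the input gradient with `atomicAdd`. A serial CPU reference of that scatter; `roi_transform_backward_ref` is illustrative only and, like the CUDA kernel, assumes `in_grad` has been zero-initialized before accumulation:

```cpp
// Illustration only: serial reference of the scatter-style backward pass.
#include <vector>

void roi_transform_backward_ref(const std::vector<float>& out_grad,
                                const std::vector<int>& out2in_idx,
                                const std::vector<float>& out2in_w,
                                std::vector<float>* in_grad) {
  const int out_size = static_cast<int>(out_grad.size());
  for (int index = 0; index < out_size * 4; ++index) {
    int in_idx = out2in_idx[index];
    if (in_idx >= 0) {  // skip entries the forward pass never filled
      (*in_grad)[in_idx] += out_grad[index / 4] * out2in_w[index];
    }
  }
}
```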