|
|
|
@ -506,6 +506,206 @@ __global__ void KeTrilinearInterpBw(
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
__device__ __forceinline__ static T Kecubic_interp(const T x0, const T x1,
|
|
|
|
|
const T x2, const T x3,
|
|
|
|
|
T t) {
|
|
|
|
|
T coeffs[4];
|
|
|
|
|
T a = -0.75;
|
|
|
|
|
T x_1 = t;
|
|
|
|
|
T x_2 = 1.0 - t;
|
|
|
|
|
coeffs[0] = cubic_convolution2<T>(x_1 + 1.0, a);
|
|
|
|
|
coeffs[1] = cubic_convolution1<T>(x_1, a);
|
|
|
|
|
coeffs[2] = cubic_convolution1<T>(x_2, a);
|
|
|
|
|
coeffs[3] = cubic_convolution2<T>(x_2 + 1.0, a);
|
|
|
|
|
return x0 * coeffs[0] + x1 * coeffs[1] + x2 * coeffs[2] + x3 * coeffs[3];
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
__global__ void KeBicubicInterpFw(
|
|
|
|
|
const T* in, const size_t in_img_h, const size_t in_img_w,
|
|
|
|
|
const size_t input_h, const size_t input_w, T* out, const size_t out_img_h,
|
|
|
|
|
const size_t out_img_w, const size_t output_h, const size_t output_w,
|
|
|
|
|
const size_t num_channels, const float ratio_h, const float ratio_w,
|
|
|
|
|
const bool align_corners, const DataLayout data_layout) {
|
|
|
|
|
int nthreads = output_h * output_w;
|
|
|
|
|
int tid = blockIdx.x * blockDim.x + threadIdx.x;
|
|
|
|
|
int stride = blockDim.x * gridDim.x;
|
|
|
|
|
|
|
|
|
|
for (; tid < nthreads; tid += stride) {
|
|
|
|
|
int out_id_h = tid / output_w;
|
|
|
|
|
int out_id_w = tid % output_w;
|
|
|
|
|
int in_img_size = input_w / num_channels;
|
|
|
|
|
int out_img_size = output_w / num_channels;
|
|
|
|
|
|
|
|
|
|
int channel_id, out_img_idy, out_img_idx;
|
|
|
|
|
|
|
|
|
|
if (data_layout == DataLayout::kNCHW) {
|
|
|
|
|
channel_id = out_id_w / out_img_size;
|
|
|
|
|
out_img_idy = (out_id_w % out_img_size) / out_img_w;
|
|
|
|
|
out_img_idx = tid % out_img_w;
|
|
|
|
|
} else {
|
|
|
|
|
out_img_idy = out_id_w / (out_img_w * num_channels);
|
|
|
|
|
out_img_idx = out_id_w % (out_img_w * num_channels) / num_channels;
|
|
|
|
|
channel_id = tid % num_channels;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
T in_img_idy = align_corners
|
|
|
|
|
? static_cast<T>(ratio_h * out_img_idy)
|
|
|
|
|
: static_cast<T>(ratio_h * (out_img_idy + 0.5) - 0.5);
|
|
|
|
|
int input_y = static_cast<int>(in_img_idy);
|
|
|
|
|
const T y_t = in_img_idy - input_y;
|
|
|
|
|
|
|
|
|
|
T in_img_idx = align_corners
|
|
|
|
|
? static_cast<T>(ratio_w * out_img_idx)
|
|
|
|
|
: static_cast<T>(ratio_w * (out_img_idx + 0.5) - 0.5);
|
|
|
|
|
int input_x = static_cast<int>(in_img_idx);
|
|
|
|
|
const T x_t = in_img_idx - input_x;
|
|
|
|
|
|
|
|
|
|
T coefficients[4];
|
|
|
|
|
const T* in_pos_0;
|
|
|
|
|
const T* in_pos_1;
|
|
|
|
|
const T* in_pos_2;
|
|
|
|
|
const T* in_pos_3;
|
|
|
|
|
int access_x_0;
|
|
|
|
|
if (data_layout == DataLayout::kNCHW) {
|
|
|
|
|
for (int k = 0; k < 4; k++) {
|
|
|
|
|
int access_y =
|
|
|
|
|
max(min(input_y - 1 + k, static_cast<int>(in_img_h - 1)), 0);
|
|
|
|
|
access_x_0 = max(min(input_x - 1, static_cast<int>(in_img_w - 1)), 0);
|
|
|
|
|
int access_x_1 =
|
|
|
|
|
max(min(input_x + 0, static_cast<int>(in_img_w - 1)), 0);
|
|
|
|
|
int access_x_2 =
|
|
|
|
|
max(min(input_x + 1, static_cast<int>(in_img_w - 1)), 0);
|
|
|
|
|
int access_x_3 =
|
|
|
|
|
max(min(input_x + 2, static_cast<int>(in_img_w - 1)), 0);
|
|
|
|
|
|
|
|
|
|
in_pos_0 = &in[out_id_h * input_w + channel_id * in_img_size +
|
|
|
|
|
access_y * in_img_w + access_x_0];
|
|
|
|
|
in_pos_1 = &in[out_id_h * input_w + channel_id * in_img_size +
|
|
|
|
|
access_y * in_img_w + access_x_1];
|
|
|
|
|
in_pos_2 = &in[out_id_h * input_w + channel_id * in_img_size +
|
|
|
|
|
access_y * in_img_w + access_x_2];
|
|
|
|
|
in_pos_3 = &in[out_id_h * input_w + channel_id * in_img_size +
|
|
|
|
|
access_y * in_img_w + access_x_3];
|
|
|
|
|
|
|
|
|
|
coefficients[k] = Kecubic_interp<T>(in_pos_0[0], in_pos_1[0],
|
|
|
|
|
in_pos_2[0], in_pos_3[0], x_t);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
out[out_id_h * output_w + out_id_w] =
|
|
|
|
|
Kecubic_interp<T>(coefficients[0], coefficients[1], coefficients[2],
|
|
|
|
|
coefficients[3], y_t);
|
|
|
|
|
|
|
|
|
|
} else {
|
|
|
|
|
for (int k = 0; k < 4; k++) {
|
|
|
|
|
int access_y =
|
|
|
|
|
max(min(input_y - 1 + k, static_cast<int>((in_img_h - 1))), 0);
|
|
|
|
|
int access_x_0 =
|
|
|
|
|
max(min(input_x - 1, static_cast<int>((in_img_w - 1))), 0);
|
|
|
|
|
int access_x_1 =
|
|
|
|
|
max(min(input_x + 0, static_cast<int>((in_img_w - 1))), 0);
|
|
|
|
|
int access_x_2 =
|
|
|
|
|
max(min(input_x + 1, static_cast<int>((in_img_w - 1))), 0);
|
|
|
|
|
int access_x_3 =
|
|
|
|
|
max(min(input_x + 2, static_cast<int>((in_img_w - 1))), 0);
|
|
|
|
|
|
|
|
|
|
const T* in_pos_0 =
|
|
|
|
|
&in[out_id_h * input_w + access_y * in_img_w * num_channels +
|
|
|
|
|
access_x_0 * num_channels + channel_id];
|
|
|
|
|
const T* in_pos_1 =
|
|
|
|
|
&in[out_id_h * input_w + access_y * in_img_w * num_channels +
|
|
|
|
|
access_x_1 * num_channels + channel_id];
|
|
|
|
|
const T* in_pos_2 =
|
|
|
|
|
&in[out_id_h * input_w + access_y * in_img_w * num_channels +
|
|
|
|
|
access_x_2 * num_channels + channel_id];
|
|
|
|
|
const T* in_pos_3 =
|
|
|
|
|
&in[out_id_h * input_w + access_y * in_img_w * num_channels +
|
|
|
|
|
access_x_3 * num_channels + channel_id];
|
|
|
|
|
|
|
|
|
|
coefficients[k] = Kecubic_interp(in_pos_0[0], in_pos_1[0], in_pos_2[0],
|
|
|
|
|
in_pos_3[0], x_t);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
out[out_id_h * output_w + out_id_w] =
|
|
|
|
|
static_cast<T>(Kecubic_interp(coefficients[0], coefficients[1],
|
|
|
|
|
coefficients[2], coefficients[3], y_t));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
__global__ void KeBicubicInterpBw(
|
|
|
|
|
T* in, const size_t in_img_h, const size_t in_img_w, const size_t input_h,
|
|
|
|
|
const size_t input_w, const T* out, const size_t out_img_h,
|
|
|
|
|
const size_t out_img_w, const size_t output_h, const size_t output_w,
|
|
|
|
|
const size_t num_channels, const float ratio_h, const float ratio_w,
|
|
|
|
|
const bool align_corners, const DataLayout data_layout) {
|
|
|
|
|
int nthreads = output_h * output_w;
|
|
|
|
|
int tid = blockIdx.x * blockDim.x + threadIdx.x;
|
|
|
|
|
int stride = blockDim.x * gridDim.x;
|
|
|
|
|
|
|
|
|
|
for (; tid < nthreads; tid += stride) {
|
|
|
|
|
int out_id_h = tid / output_w;
|
|
|
|
|
int out_id_w = tid % output_w;
|
|
|
|
|
int in_img_size = input_w / num_channels;
|
|
|
|
|
int out_img_size = output_w / num_channels;
|
|
|
|
|
|
|
|
|
|
int channel_id, out_img_idy, out_img_idx;
|
|
|
|
|
if (data_layout == DataLayout::kNCHW) {
|
|
|
|
|
channel_id = out_id_w / out_img_size;
|
|
|
|
|
out_img_idy = (out_id_w % out_img_size) / out_img_w;
|
|
|
|
|
out_img_idx = tid % out_img_w;
|
|
|
|
|
} else {
|
|
|
|
|
out_img_idy = out_id_w / (out_img_w * num_channels);
|
|
|
|
|
out_img_idx = out_id_w % (out_img_w * num_channels) / num_channels;
|
|
|
|
|
channel_id = tid % num_channels;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
T in_img_idy = align_corners
|
|
|
|
|
? static_cast<T>(ratio_h * out_img_idy)
|
|
|
|
|
: static_cast<T>(ratio_h * (out_img_idy + 0.5) - 0.5);
|
|
|
|
|
int input_y = static_cast<int>(in_img_idy);
|
|
|
|
|
const T y_t = in_img_idy - input_y;
|
|
|
|
|
|
|
|
|
|
T in_img_idx = align_corners
|
|
|
|
|
? static_cast<T>(ratio_w * out_img_idx)
|
|
|
|
|
: static_cast<T>(ratio_w * (out_img_idx + 0.5) - 0.5);
|
|
|
|
|
int input_x = static_cast<int>(in_img_idx);
|
|
|
|
|
|
|
|
|
|
const T x_t = in_img_idx - input_x;
|
|
|
|
|
|
|
|
|
|
T x_coeffs[4];
|
|
|
|
|
T y_coeffs[4];
|
|
|
|
|
|
|
|
|
|
get_cubic_upsample_coefficients(x_coeffs, x_t);
|
|
|
|
|
get_cubic_upsample_coefficients(y_coeffs, y_t);
|
|
|
|
|
|
|
|
|
|
const T* out_pos = &out[out_id_h * output_w + out_id_w];
|
|
|
|
|
T* in_pos;
|
|
|
|
|
|
|
|
|
|
for (int i = 0; i < 4; i++) {
|
|
|
|
|
for (int j = 0; j < 4; j++) {
|
|
|
|
|
int access_y = max(min(static_cast<int>(input_y - 1 + j),
|
|
|
|
|
static_cast<int>(in_img_h - 1)),
|
|
|
|
|
0);
|
|
|
|
|
int access_x = max(min(static_cast<int>(input_x - 1 + i),
|
|
|
|
|
static_cast<int>(in_img_w - 1)),
|
|
|
|
|
0);
|
|
|
|
|
if (data_layout == DataLayout::kNCHW) {
|
|
|
|
|
in_pos = &in[out_id_h * input_w + channel_id * in_img_size +
|
|
|
|
|
access_y * in_img_w + access_x];
|
|
|
|
|
} else {
|
|
|
|
|
in_pos = &in[out_id_h * input_w + access_y * in_img_w * num_channels +
|
|
|
|
|
access_x * num_channels + channel_id];
|
|
|
|
|
}
|
|
|
|
|
platform::CudaAtomicAdd(&in_pos[0],
|
|
|
|
|
(out_pos[0] * y_coeffs[j] * x_coeffs[i]));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
static void Interpolate2DCUDAFwd(const framework::ExecutionContext& ctx,
|
|
|
|
|
const Tensor& input, Tensor* output) {
|
|
|
|
@ -602,6 +802,11 @@ static void Interpolate2DCUDAFwd(const framework::ExecutionContext& ctx,
|
|
|
|
|
ctx.cuda_device_context().stream()>>>(
|
|
|
|
|
input_data, in_h, in_w, n, in_chw, output_data, out_h, out_w, n,
|
|
|
|
|
out_chw, c, ratio_h, ratio_w, align_corners, align_mode, data_layout);
|
|
|
|
|
} else if ("bicubic" == interp_method) {
|
|
|
|
|
KeBicubicInterpFw<
|
|
|
|
|
T><<<config.blocks, 512, 0, ctx.cuda_device_context().stream()>>>(
|
|
|
|
|
input_data, in_h, in_w, n, in_chw, output_data, out_h, out_w, n,
|
|
|
|
|
out_chw, c, ratio_h, ratio_w, align_corners, data_layout);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -806,6 +1011,11 @@ static void Interpolate2DCUDABwd(const framework::ExecutionContext& ctx,
|
|
|
|
|
input_grad_data, in_h, in_w, n, in_chw, output_grad_data, out_h, out_w,
|
|
|
|
|
n, out_chw, c, ratio_h, ratio_w, align_corners, align_mode,
|
|
|
|
|
data_layout);
|
|
|
|
|
} else if ("bicubic" == interp_method) {
|
|
|
|
|
KeBicubicInterpBw<
|
|
|
|
|
T><<<config.blocks, 512, 0, ctx.cuda_device_context().stream()>>>(
|
|
|
|
|
input_grad_data, in_h, in_w, n, in_chw, output_grad_data, out_h, out_w,
|
|
|
|
|
n, out_chw, c, ratio_h, ratio_w, align_corners, data_layout);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -968,3 +1178,9 @@ REGISTER_OP_CUDA_KERNEL(trilinear_interp, ops::InterpolateOpCUDAKernel<float>,
|
|
|
|
|
REGISTER_OP_CUDA_KERNEL(trilinear_interp_grad,
|
|
|
|
|
ops::InterpolateGradOpCUDAKernel<float>,
|
|
|
|
|
ops::InterpolateGradOpCUDAKernel<double>);
|
|
|
|
|
REGISTER_OP_CUDA_KERNEL(bicubic_interp, ops::InterpolateOpCUDAKernel<float>,
|
|
|
|
|
ops::InterpolateOpCUDAKernel<double>,
|
|
|
|
|
ops::InterpolateOpCUDAKernel<int>);
|
|
|
|
|
REGISTER_OP_CUDA_KERNEL(bicubic_interp_grad,
|
|
|
|
|
ops::InterpolateGradOpCUDAKernel<float>,
|
|
|
|
|
ops::InterpolateGradOpCUDAKernel<double>);
|
|
|
|
|