|
|
|
@ -26,7 +26,8 @@ __global__ void KeNearestNeighborInterpFw(
|
|
|
|
|
const size_t num_channels, const float ratio_h, const float ratio_w) {
|
|
|
|
|
int nthreads = output_h * output_w;
|
|
|
|
|
int tid = blockIdx.x * blockDim.x + threadIdx.x;
|
|
|
|
|
if (tid < nthreads) {
|
|
|
|
|
int stride = blockDim.x * gridDim.x;
|
|
|
|
|
for (; tid < nthreads; tid += stride) {
|
|
|
|
|
int out_id_h = tid / output_w;
|
|
|
|
|
int out_id_w = tid % output_w;
|
|
|
|
|
int in_img_size = input_w / num_channels;
|
|
|
|
@ -52,7 +53,8 @@ __global__ void KeNearestNeighborInterpBw(
|
|
|
|
|
const size_t num_channels, const float ratio_h, const float ratio_w) {
|
|
|
|
|
int nthreads = output_h * output_w;
|
|
|
|
|
int tid = blockIdx.x * blockDim.x + threadIdx.x;
|
|
|
|
|
if (tid < nthreads) {
|
|
|
|
|
int stride = blockDim.x * gridDim.x;
|
|
|
|
|
for (; tid < nthreads; tid += stride) {
|
|
|
|
|
int out_id_h = tid / output_w;
|
|
|
|
|
int out_id_w = tid % output_w;
|
|
|
|
|
int in_img_size = input_w / num_channels;
|
|
|
|
@ -80,7 +82,8 @@ __global__ void KeBilinearInterpFw(
|
|
|
|
|
const size_t num_channels, const float ratio_h, const float ratio_w) {
|
|
|
|
|
int nthreads = output_h * output_w;
|
|
|
|
|
int tid = blockIdx.x * blockDim.x + threadIdx.x;
|
|
|
|
|
if (tid < nthreads) {
|
|
|
|
|
int stride = blockDim.x * gridDim.x;
|
|
|
|
|
for (; tid < nthreads; tid += stride) {
|
|
|
|
|
int out_id_h = tid / output_w;
|
|
|
|
|
int out_id_w = tid % output_w;
|
|
|
|
|
int in_img_size = input_w / num_channels;
|
|
|
|
@ -118,7 +121,8 @@ __global__ void KeBilinearInterpBw(
|
|
|
|
|
const size_t num_channels, const T ratio_h, const T ratio_w) {
|
|
|
|
|
int nthreads = output_h * output_w;
|
|
|
|
|
int tid = blockIdx.x * blockDim.x + threadIdx.x;
|
|
|
|
|
if (tid < nthreads) {
|
|
|
|
|
int stride = blockDim.x * gridDim.x;
|
|
|
|
|
for (; tid < nthreads; tid += stride) {
|
|
|
|
|
int out_id_h = tid / output_w;
|
|
|
|
|
int out_id_w = tid % output_w;
|
|
|
|
|
int in_img_size = input_w / num_channels;
|
|
|
|
@ -194,17 +198,18 @@ class InterpolateOpCUDAKernel : public framework::OpKernel<T> {
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int threadNum = n * out_chw;
|
|
|
|
|
int blocks = (threadNum + 1024 - 1) / 1024;
|
|
|
|
|
int pixelNum = n * out_chw;
|
|
|
|
|
int grid_dim = (pixelNum + 512 - 1) / 512;
|
|
|
|
|
grid_dim = grid_dim > 8 ? 8 : grid_dim;
|
|
|
|
|
|
|
|
|
|
if ("nearest" == interp_method) {
|
|
|
|
|
KeNearestNeighborInterpFw<
|
|
|
|
|
T><<<blocks, 1024, 0, ctx.cuda_device_context().stream()>>>(
|
|
|
|
|
T><<<grid_dim, 512, 0, ctx.cuda_device_context().stream()>>>(
|
|
|
|
|
input_data, in_h, in_w, n, in_chw, output_data, out_h, out_w, n,
|
|
|
|
|
out_chw, c, ratio_h, ratio_w);
|
|
|
|
|
} else if ("bilinear" == interp_method) {
|
|
|
|
|
KeBilinearInterpFw<
|
|
|
|
|
T><<<blocks, 1024, 0, ctx.cuda_device_context().stream()>>>(
|
|
|
|
|
T><<<grid_dim, 512, 0, ctx.cuda_device_context().stream()>>>(
|
|
|
|
|
input_data, in_h, in_w, n, in_chw, output_data, out_h, out_w, n,
|
|
|
|
|
out_chw, c, ratio_h, ratio_w);
|
|
|
|
|
}
|
|
|
|
@ -257,17 +262,18 @@ class InterpolateGradOpCUDAKernel : public framework::OpKernel<T> {
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int threadNum = n * out_chw;
|
|
|
|
|
int blocks = (threadNum + 1024 - 1) / 1024;
|
|
|
|
|
int pixelNum = n * out_chw;
|
|
|
|
|
int grid_dim = (pixelNum + 512 - 1) / 512;
|
|
|
|
|
grid_dim = grid_dim > 8 ? 8 : grid_dim;
|
|
|
|
|
|
|
|
|
|
if ("nearest" == interp_method) {
|
|
|
|
|
KeNearestNeighborInterpBw<
|
|
|
|
|
T><<<blocks, 1024, 0, ctx.cuda_device_context().stream()>>>(
|
|
|
|
|
T><<<grid_dim, 512, 0, ctx.cuda_device_context().stream()>>>(
|
|
|
|
|
input_grad_data, in_h, in_w, n, in_chw, output_grad_data, out_h,
|
|
|
|
|
out_w, n, out_chw, c, ratio_h, ratio_w);
|
|
|
|
|
} else if ("bilinear" == interp_method) {
|
|
|
|
|
KeBilinearInterpBw<
|
|
|
|
|
T><<<blocks, 1024, 0, ctx.cuda_device_context().stream()>>>(
|
|
|
|
|
T><<<grid_dim, 512, 0, ctx.cuda_device_context().stream()>>>(
|
|
|
|
|
input_grad_data, in_h, in_w, n, in_chw, output_grad_data, out_h,
|
|
|
|
|
out_w, n, out_chw, c, ratio_h, ratio_w);
|
|
|
|
|
}
|
|
|
|
|