|
|
|
@ -23,7 +23,8 @@ __global__ void KeNearestNeighborInterpFw(
|
|
|
|
|
const T* in, const size_t in_img_h, const size_t in_img_w,
|
|
|
|
|
const size_t input_h, const size_t input_w, T* out, const size_t out_img_h,
|
|
|
|
|
const size_t out_img_w, const size_t output_h, const size_t output_w,
|
|
|
|
|
const size_t num_channels, const float ratio_h, const float ratio_w) {
|
|
|
|
|
const size_t num_channels, const float ratio_h, const float ratio_w,
|
|
|
|
|
const bool align_corners) {
|
|
|
|
|
int nthreads = output_h * output_w;
|
|
|
|
|
int tid = blockIdx.x * blockDim.x + threadIdx.x;
|
|
|
|
|
int stride = blockDim.x * gridDim.x;
|
|
|
|
@ -35,10 +36,14 @@ __global__ void KeNearestNeighborInterpFw(
|
|
|
|
|
int channel_id = out_id_w / out_img_size;
|
|
|
|
|
|
|
|
|
|
int out_img_idy = (out_id_w % out_img_size) / out_img_w;
|
|
|
|
|
int in_img_idy = static_cast<int>(ratio_h * out_img_idy + 0.5);
|
|
|
|
|
int in_img_idy = (align_corners)
|
|
|
|
|
? static_cast<int>(ratio_h * out_img_idy + 0.5)
|
|
|
|
|
: static_cast<int>(ratio_h * out_img_idy);
|
|
|
|
|
|
|
|
|
|
int out_img_idx = tid % out_img_w;
|
|
|
|
|
int in_img_idx = static_cast<int>(ratio_w * out_img_idx + 0.5);
|
|
|
|
|
int in_img_idx = (align_corners)
|
|
|
|
|
? static_cast<int>(ratio_w * out_img_idx + 0.5)
|
|
|
|
|
: static_cast<int>(ratio_w * out_img_idx);
|
|
|
|
|
|
|
|
|
|
out[tid] = in[out_id_h * input_w + channel_id * in_img_size +
|
|
|
|
|
in_img_idy * in_img_w + in_img_idx];
|
|
|
|
@ -50,7 +55,8 @@ __global__ void KeNearestNeighborInterpBw(
|
|
|
|
|
T* in, const size_t in_img_h, const size_t in_img_w, const size_t input_h,
|
|
|
|
|
const size_t input_w, const T* out, const size_t out_img_h,
|
|
|
|
|
const size_t out_img_w, const size_t output_h, const size_t output_w,
|
|
|
|
|
const size_t num_channels, const float ratio_h, const float ratio_w) {
|
|
|
|
|
const size_t num_channels, const float ratio_h, const float ratio_w,
|
|
|
|
|
const bool align_corners) {
|
|
|
|
|
int nthreads = output_h * output_w;
|
|
|
|
|
int tid = blockIdx.x * blockDim.x + threadIdx.x;
|
|
|
|
|
int stride = blockDim.x * gridDim.x;
|
|
|
|
@ -62,10 +68,14 @@ __global__ void KeNearestNeighborInterpBw(
|
|
|
|
|
int channel_id = out_id_w / out_img_size;
|
|
|
|
|
|
|
|
|
|
int out_img_idy = (out_id_w % out_img_size) / out_img_w;
|
|
|
|
|
int in_img_idy = static_cast<int>(ratio_h * out_img_idy + 0.5);
|
|
|
|
|
int in_img_idy = (align_corners)
|
|
|
|
|
? static_cast<int>(ratio_h * out_img_idy + 0.5)
|
|
|
|
|
: static_cast<int>(ratio_h * out_img_idy);
|
|
|
|
|
|
|
|
|
|
int out_img_idx = tid % out_img_w;
|
|
|
|
|
int in_img_idx = static_cast<int>(ratio_w * out_img_idx + 0.5);
|
|
|
|
|
int in_img_idx = (align_corners)
|
|
|
|
|
? static_cast<int>(ratio_w * out_img_idx + 0.5)
|
|
|
|
|
: static_cast<int>(ratio_w * out_img_idx);
|
|
|
|
|
|
|
|
|
|
T* in_pos = &in[out_id_h * input_w + channel_id * in_img_size +
|
|
|
|
|
in_img_idy * in_img_w + in_img_idx];
|
|
|
|
@ -79,7 +89,8 @@ __global__ void KeBilinearInterpFw(
|
|
|
|
|
const T* in, const size_t in_img_h, const size_t in_img_w,
|
|
|
|
|
const size_t input_h, const size_t input_w, T* out, const size_t out_img_h,
|
|
|
|
|
const size_t out_img_w, const size_t output_h, const size_t output_w,
|
|
|
|
|
const size_t num_channels, const float ratio_h, const float ratio_w) {
|
|
|
|
|
const size_t num_channels, const float ratio_h, const float ratio_w,
|
|
|
|
|
const bool align_corners, const int align_mode) {
|
|
|
|
|
int nthreads = output_h * output_w;
|
|
|
|
|
int tid = blockIdx.x * blockDim.x + threadIdx.x;
|
|
|
|
|
int stride = blockDim.x * gridDim.x;
|
|
|
|
@ -91,15 +102,23 @@ __global__ void KeBilinearInterpFw(
|
|
|
|
|
int channel_id = out_id_w / out_img_size;
|
|
|
|
|
|
|
|
|
|
int out_img_idy = (out_id_w % out_img_size) / out_img_w;
|
|
|
|
|
int in_img_idy = ratio_h * out_img_idy;
|
|
|
|
|
int in_img_idy = (align_mode == 0 && !align_corners)
|
|
|
|
|
? static_cast<int>(ratio_h * (out_img_idy + 0.5) - 0.5)
|
|
|
|
|
: static_cast<int>(ratio_h * out_img_idy);
|
|
|
|
|
int h_id = (in_img_idy < in_img_h - 1) ? 1 : 0;
|
|
|
|
|
T h1lambda = ratio_h * out_img_idy - in_img_idy;
|
|
|
|
|
T h1lambda = (align_mode == 0 && !align_corners)
|
|
|
|
|
? ratio_h * (out_img_idy + 0.5) - 0.5 - in_img_idy
|
|
|
|
|
: ratio_h * out_img_idy - in_img_idy;
|
|
|
|
|
T h2lambda = 1.f - h1lambda;
|
|
|
|
|
|
|
|
|
|
int out_img_idx = tid % out_img_w;
|
|
|
|
|
int in_img_idx = ratio_w * out_img_idx;
|
|
|
|
|
int in_img_idx = (align_mode == 0 && !align_corners)
|
|
|
|
|
? static_cast<int>(ratio_w * (out_img_idx + 0.5) - 0.5)
|
|
|
|
|
: static_cast<int>(ratio_w * out_img_idx);
|
|
|
|
|
int w_id = (in_img_idx < in_img_w - 1) ? 1 : 0;
|
|
|
|
|
T w1lambda = ratio_w * out_img_idx - in_img_idx;
|
|
|
|
|
T w1lambda = (align_mode == 0 && !align_corners)
|
|
|
|
|
? ratio_w * (out_img_idx + 0.5) - 0.5 - in_img_idx
|
|
|
|
|
: ratio_w * out_img_idx - in_img_idx;
|
|
|
|
|
T w2lambda = 1.f - w1lambda;
|
|
|
|
|
|
|
|
|
|
const T* in_pos = &in[out_id_h * input_w + channel_id * in_img_size +
|
|
|
|
@ -118,7 +137,8 @@ __global__ void KeBilinearInterpBw(
|
|
|
|
|
T* in, const size_t in_img_h, const size_t in_img_w, const size_t input_h,
|
|
|
|
|
const size_t input_w, const T* out, const size_t out_img_h,
|
|
|
|
|
const size_t out_img_w, const size_t output_h, const size_t output_w,
|
|
|
|
|
const size_t num_channels, const T ratio_h, const T ratio_w) {
|
|
|
|
|
const size_t num_channels, const T ratio_h, const T ratio_w,
|
|
|
|
|
const bool align_corners, const int align_mode) {
|
|
|
|
|
int nthreads = output_h * output_w;
|
|
|
|
|
int tid = blockIdx.x * blockDim.x + threadIdx.x;
|
|
|
|
|
int stride = blockDim.x * gridDim.x;
|
|
|
|
@ -130,15 +150,24 @@ __global__ void KeBilinearInterpBw(
|
|
|
|
|
int channel_id = out_id_w / out_img_size;
|
|
|
|
|
|
|
|
|
|
int out_img_idy = (out_id_w % out_img_size) / out_img_w;
|
|
|
|
|
int in_img_idy = ratio_h * out_img_idy;
|
|
|
|
|
int in_img_idy = (align_mode == 0 && !align_corners)
|
|
|
|
|
? ratio_h * (out_img_idy + 0.5) - 0.5
|
|
|
|
|
: ratio_h * out_img_idy;
|
|
|
|
|
int h_id = (in_img_idy < in_img_h - 1) ? 1 : 0;
|
|
|
|
|
T h1lambda = ratio_h * out_img_idy - in_img_idy;
|
|
|
|
|
T h1lambda = (align_mode == 0 && !align_corners)
|
|
|
|
|
? ratio_h * (out_img_idy + 0.5) - 0.5 - in_img_idy
|
|
|
|
|
: ratio_h * out_img_idy - in_img_idy;
|
|
|
|
|
|
|
|
|
|
T h2lambda = 1.f - h1lambda;
|
|
|
|
|
|
|
|
|
|
int out_img_idx = tid % out_img_w;
|
|
|
|
|
int in_img_idx = ratio_w * out_img_idx;
|
|
|
|
|
int in_img_idx = (align_mode == 0 && !align_corners)
|
|
|
|
|
? ratio_w * (out_img_idx + 0.5) - 0.5
|
|
|
|
|
: ratio_w * out_img_idx;
|
|
|
|
|
int w_id = (in_img_idx < in_img_w - 1) ? 1 : 0;
|
|
|
|
|
T w1lambda = ratio_w * out_img_idx - in_img_idx;
|
|
|
|
|
T w1lambda = (align_mode == 0 && !align_corners)
|
|
|
|
|
? ratio_w * (out_img_idx + 0.5) - 0.5 - in_img_idx
|
|
|
|
|
: ratio_w * out_img_idx - in_img_idx;
|
|
|
|
|
T w2lambda = 1.f - w1lambda;
|
|
|
|
|
|
|
|
|
|
T* in_pos = &in[out_id_h * input_w + channel_id * in_img_size +
|
|
|
|
@ -175,6 +204,9 @@ class InterpolateOpCUDAKernel : public framework::OpKernel<T> {
|
|
|
|
|
out_w = size_data[1];
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool align_corners = ctx.Attr<bool>("align_corners");
|
|
|
|
|
int align_mode = ctx.Attr<int>("align_mode");
|
|
|
|
|
|
|
|
|
|
int n = input->dims()[0];
|
|
|
|
|
int c = input->dims()[1];
|
|
|
|
|
int in_h = input->dims()[2];
|
|
|
|
@ -188,10 +220,12 @@ class InterpolateOpCUDAKernel : public framework::OpKernel<T> {
|
|
|
|
|
int in_chw = c * in_hw;
|
|
|
|
|
int out_chw = c * out_hw;
|
|
|
|
|
|
|
|
|
|
float ratio_h =
|
|
|
|
|
(out_h > 1) ? static_cast<float>(in_h - 1) / (out_h - 1) : 0.f;
|
|
|
|
|
float ratio_w =
|
|
|
|
|
(out_w > 1) ? static_cast<float>(in_w - 1) / (out_w - 1) : 0.f;
|
|
|
|
|
float ratio_h = (align_corners && out_h > 1)
|
|
|
|
|
? static_cast<float>(in_h - 1) / (out_h - 1)
|
|
|
|
|
: static_cast<float>(in_h) / out_h;
|
|
|
|
|
float ratio_w = (align_corners && out_w > 1)
|
|
|
|
|
? static_cast<float>(in_w - 1) / (out_w - 1)
|
|
|
|
|
: static_cast<float>(in_w) / out_w;
|
|
|
|
|
|
|
|
|
|
if (in_h == out_h && in_w == out_w) {
|
|
|
|
|
framework::TensorCopy(*input, ctx.GetPlace(), output);
|
|
|
|
@ -206,12 +240,12 @@ class InterpolateOpCUDAKernel : public framework::OpKernel<T> {
|
|
|
|
|
KeNearestNeighborInterpFw<
|
|
|
|
|
T><<<grid_dim, 512, 0, ctx.cuda_device_context().stream()>>>(
|
|
|
|
|
input_data, in_h, in_w, n, in_chw, output_data, out_h, out_w, n,
|
|
|
|
|
out_chw, c, ratio_h, ratio_w);
|
|
|
|
|
out_chw, c, ratio_h, ratio_w, align_corners);
|
|
|
|
|
} else if ("bilinear" == interp_method) {
|
|
|
|
|
KeBilinearInterpFw<
|
|
|
|
|
T><<<grid_dim, 512, 0, ctx.cuda_device_context().stream()>>>(
|
|
|
|
|
input_data, in_h, in_w, n, in_chw, output_data, out_h, out_w, n,
|
|
|
|
|
out_chw, c, ratio_h, ratio_w);
|
|
|
|
|
out_chw, c, ratio_h, ratio_w, align_corners, align_mode);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
@ -234,6 +268,10 @@ class InterpolateGradOpCUDAKernel : public framework::OpKernel<T> {
|
|
|
|
|
int out_h = ctx.Attr<int>("out_h");
|
|
|
|
|
int out_w = ctx.Attr<int>("out_w");
|
|
|
|
|
auto out_size = ctx.Input<Tensor>("OutSize");
|
|
|
|
|
|
|
|
|
|
bool align_corners = ctx.Attr<bool>("align_corners");
|
|
|
|
|
int align_mode = ctx.Attr<int>("align_mode");
|
|
|
|
|
|
|
|
|
|
if (out_size != nullptr) {
|
|
|
|
|
Tensor sizes;
|
|
|
|
|
framework::TensorCopy(*out_size, platform::CPUPlace(), &sizes);
|
|
|
|
@ -252,10 +290,12 @@ class InterpolateGradOpCUDAKernel : public framework::OpKernel<T> {
|
|
|
|
|
int in_chw = c * in_hw;
|
|
|
|
|
int out_chw = c * out_hw;
|
|
|
|
|
|
|
|
|
|
float ratio_h =
|
|
|
|
|
(out_h > 1) ? static_cast<float>(in_h - 1) / (out_h - 1) : 0.f;
|
|
|
|
|
float ratio_w =
|
|
|
|
|
(out_w > 1) ? static_cast<float>(in_w - 1) / (out_w - 1) : 0.f;
|
|
|
|
|
float ratio_h = (align_corners && out_h > 1)
|
|
|
|
|
? static_cast<float>(in_h - 1) / (out_h - 1)
|
|
|
|
|
: static_cast<float>(in_h) / out_h;
|
|
|
|
|
float ratio_w = (align_corners && out_w > 1)
|
|
|
|
|
? static_cast<float>(in_w - 1) / (out_w - 1)
|
|
|
|
|
: static_cast<float>(in_w) / out_w;
|
|
|
|
|
|
|
|
|
|
if (in_h == out_h && in_w == out_w) {
|
|
|
|
|
framework::TensorCopy(*output_grad, ctx.GetPlace(), input_grad);
|
|
|
|
@ -270,12 +310,12 @@ class InterpolateGradOpCUDAKernel : public framework::OpKernel<T> {
|
|
|
|
|
KeNearestNeighborInterpBw<
|
|
|
|
|
T><<<grid_dim, 512, 0, ctx.cuda_device_context().stream()>>>(
|
|
|
|
|
input_grad_data, in_h, in_w, n, in_chw, output_grad_data, out_h,
|
|
|
|
|
out_w, n, out_chw, c, ratio_h, ratio_w);
|
|
|
|
|
out_w, n, out_chw, c, ratio_h, ratio_w, align_corners);
|
|
|
|
|
} else if ("bilinear" == interp_method) {
|
|
|
|
|
KeBilinearInterpBw<
|
|
|
|
|
T><<<grid_dim, 512, 0, ctx.cuda_device_context().stream()>>>(
|
|
|
|
|
input_grad_data, in_h, in_w, n, in_chw, output_grad_data, out_h,
|
|
|
|
|
out_w, n, out_chw, c, ratio_h, ratio_w);
|
|
|
|
|
out_w, n, out_chw, c, ratio_h, ratio_w, align_corners, align_mode);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|