Add `nn.interpolate ` (#23434)

* add nn.interpolate, test=develop

* fix interpolate typo, test=develop

* formate code, test=develop

* fix unitest, test=develop

* add test layers, test=develop

* add test layers, test=develop

* extract common function, test=develop

* reduce the threads for cuda10, test=develop

* update unitest, test=develop

* polish unitest, test=develop

* add dygraph test, test=develop

* format description, test=develop

* add 5D input check, test=develop

* fix doc, test=develop
revert-22778-infer_var_type
xiaoting 5 years ago committed by GitHub
parent 5fe3b63824
commit 7de0a25b5e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -26,7 +26,8 @@ static void Interpolate2DInferShapeCheck(framework::InferShapeContext* ctx) {
auto interp_method = ctx->Attrs().Get<std::string>("interp_method");
PADDLE_ENFORCE(
"bilinear" == interp_method || "nearest" == interp_method,
"bilinear" == interp_method || "nearest" == interp_method ||
"bicubic" == interp_method,
"Interpolation method can only be \"bilinear\" or \"nearest\" when "
"Input(X) dimension is 4");
const DataLayout data_layout = framework::StringToDataLayout(
@ -264,7 +265,8 @@ class InterpolateOpMaker : public framework::OpProtoAndCheckerMaker {
"method, can be \"bilinear\" for "
"bilinear interpolation, \"trilinear\" for trilinear "
"interpolation and \"nearest\" for nearest "
"neighbor interpolation.")
"neighbor interpolation, and \"bicubic\" for bicubic"
"interpolation.")
.SetDefault("bilinear");
AddAttr<bool>(
"align_corners",
@ -299,6 +301,11 @@ class InterpolateOpMaker : public framework::OpProtoAndCheckerMaker {
H-direction and W-direction in this op) on a rectilinear 3D grid.
The linear interpolation is performed on three directions.
Bicubic interpolation is an extension of cubic interpolation for interpolating
data points on a two-dimensional regular grid. The interpolated surface is
smoother than corresponding surfaces obtained by bilinear interpolation or
nearest-neighbor interpolation.
Align_corners and align_mode are optional parameters,the calculation method
of interpolation can be selected by them.
@ -376,7 +383,20 @@ class InterpolateOpMaker : public framework::OpProtoAndCheckerMaker {
D_out = D_{in} * scale_{factor}
H_out = H_{in} * scale_{factor}
W_out = W_{in} * scale_{factor}
Bicubic interpolation:
if:
align_corners = False
input : (N,C,H_in,W_in)
output: (N,C,H_out,W_out) where:
H_out = (H_{in}+0.5) * scale_{factor} - 0.5
W_out = (W_{in}+0.5) * scale_{factor} - 0.5
else:
input : (N,C,H_in,W_in)
output: (N,C,H_out,W_out) where:
H_out = H_{in} * scale_{factor}
W_out = W_{in} * scale_{factor}
For details of nearest neighbor interpolation, please refer to Wikipedia:
https://en.wikipedia.org/wiki/Nearest-neighbor_interpolation
@ -386,6 +406,9 @@ class InterpolateOpMaker : public framework::OpProtoAndCheckerMaker {
For details of trilinear interpolation, please refer to Wikipedia:
https://en.wikipedia.org/wiki/Trilinear_interpolation
For details of bicubic interpolation, please refer to Wikipedia:
https://en.wikipedia.org/wiki/Bicubic_interpolation
)DOC");
}
};
@ -469,6 +492,11 @@ REGISTER_OPERATOR(trilinear_interp, ops::InterpolateOp, ops::InterpolateOpMaker,
ops::InterpolateGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(trilinear_interp_grad, ops::InterpolateOpGrad,
ops::InterpolateGradNoNeedBufferVarsInference);
REGISTER_OPERATOR(bicubic_interp, ops::InterpolateOp, ops::InterpolateOpMaker,
ops::InterpolateGradMaker<paddle::framework::OpDesc>,
ops::InterpolateGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(bicubic_interp_grad, ops::InterpolateOpGrad,
ops::InterpolateGradNoNeedBufferVarsInference);
REGISTER_OP_CPU_KERNEL(bilinear_interp, ops::InterpolateKernel<float>,
ops::InterpolateKernel<double>,
ops::InterpolateKernel<uint8_t>);
@ -484,3 +512,7 @@ REGISTER_OP_CPU_KERNEL(trilinear_interp, ops::InterpolateKernel<float>,
ops::InterpolateKernel<uint8_t>);
REGISTER_OP_CPU_KERNEL(trilinear_interp_grad, ops::InterpolateGradKernel<float>,
ops::InterpolateGradKernel<double>);
REGISTER_OP_CPU_KERNEL(bicubic_interp, ops::InterpolateKernel<float>,
ops::InterpolateKernel<double>);
REGISTER_OP_CPU_KERNEL(bicubic_interp_grad, ops::InterpolateGradKernel<float>,
ops::InterpolateGradKernel<double>);

@ -506,6 +506,206 @@ __global__ void KeTrilinearInterpBw(
}
}
template <typename T>
__device__ __forceinline__ static T Kecubic_interp(const T x0, const T x1,
const T x2, const T x3,
T t) {
T coeffs[4];
T a = -0.75;
T x_1 = t;
T x_2 = 1.0 - t;
coeffs[0] = cubic_convolution2<T>(x_1 + 1.0, a);
coeffs[1] = cubic_convolution1<T>(x_1, a);
coeffs[2] = cubic_convolution1<T>(x_2, a);
coeffs[3] = cubic_convolution2<T>(x_2 + 1.0, a);
return x0 * coeffs[0] + x1 * coeffs[1] + x2 * coeffs[2] + x3 * coeffs[3];
}
template <typename T>
__global__ void KeBicubicInterpFw(
const T* in, const size_t in_img_h, const size_t in_img_w,
const size_t input_h, const size_t input_w, T* out, const size_t out_img_h,
const size_t out_img_w, const size_t output_h, const size_t output_w,
const size_t num_channels, const float ratio_h, const float ratio_w,
const bool align_corners, const DataLayout data_layout) {
int nthreads = output_h * output_w;
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
for (; tid < nthreads; tid += stride) {
int out_id_h = tid / output_w;
int out_id_w = tid % output_w;
int in_img_size = input_w / num_channels;
int out_img_size = output_w / num_channels;
int channel_id, out_img_idy, out_img_idx;
if (data_layout == DataLayout::kNCHW) {
channel_id = out_id_w / out_img_size;
out_img_idy = (out_id_w % out_img_size) / out_img_w;
out_img_idx = tid % out_img_w;
} else {
out_img_idy = out_id_w / (out_img_w * num_channels);
out_img_idx = out_id_w % (out_img_w * num_channels) / num_channels;
channel_id = tid % num_channels;
}
T in_img_idy = align_corners
? static_cast<T>(ratio_h * out_img_idy)
: static_cast<T>(ratio_h * (out_img_idy + 0.5) - 0.5);
int input_y = static_cast<int>(in_img_idy);
const T y_t = in_img_idy - input_y;
T in_img_idx = align_corners
? static_cast<T>(ratio_w * out_img_idx)
: static_cast<T>(ratio_w * (out_img_idx + 0.5) - 0.5);
int input_x = static_cast<int>(in_img_idx);
const T x_t = in_img_idx - input_x;
T coefficients[4];
const T* in_pos_0;
const T* in_pos_1;
const T* in_pos_2;
const T* in_pos_3;
int access_x_0;
if (data_layout == DataLayout::kNCHW) {
for (int k = 0; k < 4; k++) {
int access_y =
max(min(input_y - 1 + k, static_cast<int>(in_img_h - 1)), 0);
access_x_0 = max(min(input_x - 1, static_cast<int>(in_img_w - 1)), 0);
int access_x_1 =
max(min(input_x + 0, static_cast<int>(in_img_w - 1)), 0);
int access_x_2 =
max(min(input_x + 1, static_cast<int>(in_img_w - 1)), 0);
int access_x_3 =
max(min(input_x + 2, static_cast<int>(in_img_w - 1)), 0);
in_pos_0 = &in[out_id_h * input_w + channel_id * in_img_size +
access_y * in_img_w + access_x_0];
in_pos_1 = &in[out_id_h * input_w + channel_id * in_img_size +
access_y * in_img_w + access_x_1];
in_pos_2 = &in[out_id_h * input_w + channel_id * in_img_size +
access_y * in_img_w + access_x_2];
in_pos_3 = &in[out_id_h * input_w + channel_id * in_img_size +
access_y * in_img_w + access_x_3];
coefficients[k] = Kecubic_interp<T>(in_pos_0[0], in_pos_1[0],
in_pos_2[0], in_pos_3[0], x_t);
}
out[out_id_h * output_w + out_id_w] =
Kecubic_interp<T>(coefficients[0], coefficients[1], coefficients[2],
coefficients[3], y_t);
} else {
for (int k = 0; k < 4; k++) {
int access_y =
max(min(input_y - 1 + k, static_cast<int>((in_img_h - 1))), 0);
int access_x_0 =
max(min(input_x - 1, static_cast<int>((in_img_w - 1))), 0);
int access_x_1 =
max(min(input_x + 0, static_cast<int>((in_img_w - 1))), 0);
int access_x_2 =
max(min(input_x + 1, static_cast<int>((in_img_w - 1))), 0);
int access_x_3 =
max(min(input_x + 2, static_cast<int>((in_img_w - 1))), 0);
const T* in_pos_0 =
&in[out_id_h * input_w + access_y * in_img_w * num_channels +
access_x_0 * num_channels + channel_id];
const T* in_pos_1 =
&in[out_id_h * input_w + access_y * in_img_w * num_channels +
access_x_1 * num_channels + channel_id];
const T* in_pos_2 =
&in[out_id_h * input_w + access_y * in_img_w * num_channels +
access_x_2 * num_channels + channel_id];
const T* in_pos_3 =
&in[out_id_h * input_w + access_y * in_img_w * num_channels +
access_x_3 * num_channels + channel_id];
coefficients[k] = Kecubic_interp(in_pos_0[0], in_pos_1[0], in_pos_2[0],
in_pos_3[0], x_t);
}
out[out_id_h * output_w + out_id_w] =
static_cast<T>(Kecubic_interp(coefficients[0], coefficients[1],
coefficients[2], coefficients[3], y_t));
}
}
}
template <typename T>
__global__ void KeBicubicInterpBw(
T* in, const size_t in_img_h, const size_t in_img_w, const size_t input_h,
const size_t input_w, const T* out, const size_t out_img_h,
const size_t out_img_w, const size_t output_h, const size_t output_w,
const size_t num_channels, const float ratio_h, const float ratio_w,
const bool align_corners, const DataLayout data_layout) {
int nthreads = output_h * output_w;
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
for (; tid < nthreads; tid += stride) {
int out_id_h = tid / output_w;
int out_id_w = tid % output_w;
int in_img_size = input_w / num_channels;
int out_img_size = output_w / num_channels;
int channel_id, out_img_idy, out_img_idx;
if (data_layout == DataLayout::kNCHW) {
channel_id = out_id_w / out_img_size;
out_img_idy = (out_id_w % out_img_size) / out_img_w;
out_img_idx = tid % out_img_w;
} else {
out_img_idy = out_id_w / (out_img_w * num_channels);
out_img_idx = out_id_w % (out_img_w * num_channels) / num_channels;
channel_id = tid % num_channels;
}
T in_img_idy = align_corners
? static_cast<T>(ratio_h * out_img_idy)
: static_cast<T>(ratio_h * (out_img_idy + 0.5) - 0.5);
int input_y = static_cast<int>(in_img_idy);
const T y_t = in_img_idy - input_y;
T in_img_idx = align_corners
? static_cast<T>(ratio_w * out_img_idx)
: static_cast<T>(ratio_w * (out_img_idx + 0.5) - 0.5);
int input_x = static_cast<int>(in_img_idx);
const T x_t = in_img_idx - input_x;
T x_coeffs[4];
T y_coeffs[4];
get_cubic_upsample_coefficients(x_coeffs, x_t);
get_cubic_upsample_coefficients(y_coeffs, y_t);
const T* out_pos = &out[out_id_h * output_w + out_id_w];
T* in_pos;
for (int i = 0; i < 4; i++) {
for (int j = 0; j < 4; j++) {
int access_y = max(min(static_cast<int>(input_y - 1 + j),
static_cast<int>(in_img_h - 1)),
0);
int access_x = max(min(static_cast<int>(input_x - 1 + i),
static_cast<int>(in_img_w - 1)),
0);
if (data_layout == DataLayout::kNCHW) {
in_pos = &in[out_id_h * input_w + channel_id * in_img_size +
access_y * in_img_w + access_x];
} else {
in_pos = &in[out_id_h * input_w + access_y * in_img_w * num_channels +
access_x * num_channels + channel_id];
}
platform::CudaAtomicAdd(&in_pos[0],
(out_pos[0] * y_coeffs[j] * x_coeffs[i]));
}
}
}
}
template <typename T>
static void Interpolate2DCUDAFwd(const framework::ExecutionContext& ctx,
const Tensor& input, Tensor* output) {
@ -602,6 +802,11 @@ static void Interpolate2DCUDAFwd(const framework::ExecutionContext& ctx,
ctx.cuda_device_context().stream()>>>(
input_data, in_h, in_w, n, in_chw, output_data, out_h, out_w, n,
out_chw, c, ratio_h, ratio_w, align_corners, align_mode, data_layout);
} else if ("bicubic" == interp_method) {
KeBicubicInterpFw<
T><<<config.blocks, 512, 0, ctx.cuda_device_context().stream()>>>(
input_data, in_h, in_w, n, in_chw, output_data, out_h, out_w, n,
out_chw, c, ratio_h, ratio_w, align_corners, data_layout);
}
}
@ -806,6 +1011,11 @@ static void Interpolate2DCUDABwd(const framework::ExecutionContext& ctx,
input_grad_data, in_h, in_w, n, in_chw, output_grad_data, out_h, out_w,
n, out_chw, c, ratio_h, ratio_w, align_corners, align_mode,
data_layout);
} else if ("bicubic" == interp_method) {
KeBicubicInterpBw<
T><<<config.blocks, 512, 0, ctx.cuda_device_context().stream()>>>(
input_grad_data, in_h, in_w, n, in_chw, output_grad_data, out_h, out_w,
n, out_chw, c, ratio_h, ratio_w, align_corners, data_layout);
}
}
@ -968,3 +1178,9 @@ REGISTER_OP_CUDA_KERNEL(trilinear_interp, ops::InterpolateOpCUDAKernel<float>,
REGISTER_OP_CUDA_KERNEL(trilinear_interp_grad,
ops::InterpolateGradOpCUDAKernel<float>,
ops::InterpolateGradOpCUDAKernel<double>);
REGISTER_OP_CUDA_KERNEL(bicubic_interp, ops::InterpolateOpCUDAKernel<float>,
ops::InterpolateOpCUDAKernel<double>,
ops::InterpolateOpCUDAKernel<int>);
REGISTER_OP_CUDA_KERNEL(bicubic_interp_grad,
ops::InterpolateGradOpCUDAKernel<float>,
ops::InterpolateGradOpCUDAKernel<double>);

@ -10,10 +10,12 @@
limitations under the License. */
#pragma once
#include <algorithm>
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/hostdevice.h"
namespace paddle {
namespace operators {
@ -342,6 +344,106 @@ static void TrilinearInterpolation(
}
}
template <typename T>
HOSTDEVICE inline T cubic_convolution1(T x, T A) {
return ((A + 2) * x - (A + 3)) * x * x + 1;
}
template <typename T>
HOSTDEVICE inline T cubic_convolution2(T x, T A) {
return ((A * x - 5 * A) * x + 8 * A) * x - 4 * A;
}
template <typename T>
HOSTDEVICE inline void get_cubic_upsample_coefficients(T coeffs[4], T t) {
T A = -0.75;
T x1 = t;
coeffs[0] = cubic_convolution2<T>(x1 + 1.0, A);
coeffs[1] = cubic_convolution1<T>(x1, A);
// opposite coefficients
T x2 = 1.0 - t;
coeffs[2] = cubic_convolution1<T>(x2, A);
coeffs[3] = cubic_convolution2<T>(x2 + 1.0, A);
}
template <typename T>
static inline T cubic_interp(T x0, T x1, T x2, T x3, T t) {
T coeffs[4];
get_cubic_upsample_coefficients<T>(coeffs, t);
return x0 * coeffs[0] + x1 * coeffs[1] + x2 * coeffs[2] + x3 * coeffs[3];
}
template <typename T>
static void BicubicInterpolation(const Tensor& input, Tensor* output,
const float ratio_h, const float ratio_w,
const int in_h, const int in_w, const int n,
const int c, const int out_h, const int out_w,
const bool align_corners,
const DataLayout data_layout) {
auto input_t = EigenTensor<T, 4>::From(input);
auto output_t = EigenTensor<T, 4>::From(*output);
for (int k = 0; k < out_h; k++) { // loop for images
T y_n = align_corners ? static_cast<T>(ratio_h * k)
: static_cast<T>(ratio_h * (k + 0.5) - 0.5);
int input_y = static_cast<int>(y_n);
const T y_t = y_n - input_y;
for (int l = 0; l < out_w; l++) {
T x_n = align_corners ? static_cast<T>(ratio_w * l)
: static_cast<T>(ratio_w * (l + 0.5) - 0.5);
int input_x = static_cast<int>(x_n);
const T x_t = x_n - input_x;
for (int i = 0; i < n; i++) { // loop for batches
for (int j = 0; j < c; j++) { // loop for channels
T coefficients[4];
// interp 4 times in x direction
for (int ii = 0; ii < 4; ii++) {
int access_y = std::max(std::min(input_y - 1 + ii, in_h - 1),
static_cast<int>(0));
int access_x_0 =
std::max(std::min(input_x - 1, in_w - 1), static_cast<int>(0));
int access_x_1 =
std::max(std::min(input_x + 0, in_w - 1), static_cast<int>(0));
int access_x_2 =
std::max(std::min(input_x + 1, in_w - 1), static_cast<int>(0));
int access_x_3 =
std::max(std::min(input_x + 2, in_w - 1), static_cast<int>(0));
if (data_layout == DataLayout::kNCHW) {
coefficients[ii] =
cubic_interp<T>(input_t(i, j, access_y, access_x_0),
input_t(i, j, access_y, access_x_1),
input_t(i, j, access_y, access_x_2),
input_t(i, j, access_y, access_x_3), x_t);
} else {
coefficients[ii] =
cubic_interp<T>(input_t(i, access_y, access_x_0, j),
input_t(i, access_y, access_x_1, j),
input_t(i, access_y, access_x_2, j),
input_t(i, access_y, access_x_3, j), x_t);
}
}
// interp y direction
if (data_layout == DataLayout::kNCHW) {
output_t(i, j, k, l) =
cubic_interp<T>(coefficients[0], coefficients[1],
coefficients[2], coefficients[3], y_t);
} else {
output_t(i, k, l, j) =
cubic_interp<T>(coefficients[0], coefficients[1],
coefficients[2], coefficients[3], y_t);
}
}
}
}
}
}
template <typename T>
static void NearestNeighborInterpolateGrad(
const Tensor& output_grad, Tensor* input_grad, const float ratio_h,
@ -509,6 +611,61 @@ static void TrilinearInterpolationGrad(
}
}
template <typename T>
static void BicubicInterpolationGrad(const Tensor& output_grad,
Tensor* input_grad, const float ratio_h,
const float ratio_w, const int in_h,
const int in_w, const int n, const int c,
const int out_h, const int out_w,
const bool align_corners,
const DataLayout data_layout) {
auto input_grad_t = EigenTensor<T, 4>::From(*input_grad);
auto output_grad_t = EigenTensor<T, 4>::From(output_grad);
for (int k = 0; k < out_h; k++) { // loop for images
T y_n = align_corners ? static_cast<T>(ratio_h * k)
: static_cast<T>(ratio_h * (k + 0.5) - 0.5);
int input_y = static_cast<int>(y_n);
T y_t = y_n - input_y;
for (int l = 0; l < out_w; l++) {
T x_n = align_corners ? static_cast<T>(ratio_w * l)
: static_cast<T>(ratio_w * (l + 0.5) - 0.5);
int input_x = static_cast<int>(x_n);
T x_t = x_n - input_x;
T x_coeffs[4];
T y_coeffs[4];
get_cubic_upsample_coefficients<T>(x_coeffs, x_t);
get_cubic_upsample_coefficients<T>(y_coeffs, y_t);
for (int i = 0; i < n; i++) { // loop for batches
for (int j = 0; j < c; j++) { // loop for channels
// bicubic interpolation grad
for (int ii = 0; ii < 4; ii++) {
for (int jj = 0; jj < 4; jj++) {
int access_x = std::max(std::min(input_x - 1 + ii, in_w - 1),
static_cast<int>(0));
int access_y = std::max(std::min(input_y - 1 + jj, in_h - 1),
static_cast<int>(0));
if (data_layout == DataLayout::kNCHW) {
T grad = output_grad_t(i, j, k, l);
input_grad_t(i, j, access_y, access_x) +=
grad * y_coeffs[jj] * x_coeffs[ii];
} else {
T grad = output_grad_t(i, k, l, j);
input_grad_t(i, access_y, access_x, j) +=
grad * y_coeffs[jj] * x_coeffs[ii];
}
}
}
}
}
}
}
}
template <typename T>
static void Interpolate2DCPUFwd(const framework::ExecutionContext& ctx,
const Tensor& input, Tensor* output) {
@ -587,6 +744,9 @@ static void Interpolate2DCPUFwd(const framework::ExecutionContext& ctx,
} else if ("nearest" == interp_method) {
NearestNeighborInterpolate<T>(input, output, ratio_h, ratio_w, n, c, out_h,
out_w, align_corners, data_layout);
} else if ("bicubic" == interp_method) {
BicubicInterpolation<T>(input, output, ratio_h, ratio_w, in_h, in_w, n, c,
out_h, out_w, align_corners, data_layout);
}
}
@ -759,6 +919,10 @@ static void Interpolate2DCPUBwd(const framework::ExecutionContext& ctx,
NearestNeighborInterpolateGrad<T>(output_grad, input_grad, ratio_h, ratio_w,
n, c, out_h, out_w, align_corners,
data_layout);
} else if ("bicubic" == interp_method) {
BicubicInterpolationGrad<T>(output_grad, input_grad, ratio_h, ratio_w, in_h,
in_w, n, c, out_h, out_w, align_corners,
data_layout);
}
}

@ -19,6 +19,7 @@ import numpy as np
from op_test import OpTest
import paddle.fluid.core as core
import paddle.fluid as fluid
from paddle.nn.functional import *
def trilinear_interp_np(input,
@ -586,6 +587,15 @@ class TestTrilinearInterpAPI(unittest.TestCase):
out4 = fluid.layers.resize_trilinear(
x, out_shape=[4, 4, 8], actual_shape=actual_size)
out5 = fluid.layers.resize_trilinear(x, scale=scale_tensor)
out6 = interpolate(
x, scale=scale_tensor, resample='TRILINEAR', data_format="NCDHW")
out7 = interpolate(
x, out_shape=[4, 4, 8], resample='TRILINEAR', data_format="NCDHW")
out8 = interpolate(
x,
out_shape=shape_tensor,
resample='TRILINEAR',
data_format="NCDHW")
x_data = np.random.random((2, 3, 6, 9, 4)).astype("float32")
dim_data = np.array([18]).astype("int32")

@ -187,4 +187,4 @@ from .extension import row_conv #DEFINE_ALIAS
# from .common import unfold #DEFINE_ALIAS
# from .common import bilinear_tensor_product #DEFINE_ALIAS
# from .common import assign #DEFINE_ALIAS
# from .common import interpolate #DEFINE_ALIAS
from .common import interpolate #DEFINE_ALIAS

File diff suppressed because it is too large Load Diff
Loading…
Cancel
Save