【2.0 API】Add CUDA kernel and enhance options for grid_sample (#26576)

This PR enhances the CPU kernel and adds a new CUDA kernel so that grid_sample supports:

- align_corners: with bool type.
- padding mode: which can be in ['zeros', 'reflect', 'border']
- Interpolation mode: which can be in ['bilinear', 'nearest']

The old CPU and CUDNN versions only supported align_corners=true, padding_mode='zeros' and interpolation_mode='bilinear'.

The behavior of the new version op in default mode is compatible with the old version.
test_feature_precision_test_c
whs 5 years ago committed by GitHub
parent 39fe0d35aa
commit 79539cf198
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -41,13 +41,14 @@ class CUDNNGridSampleOpKernel : public framework::OpKernel<T> {
int n = input->dims()[0];
int c = input->dims()[1];
int h = input->dims()[2];
int w = input->dims()[3];
const int size[4] = {n, c, h, w};
int out_h = grid->dims()[1];
int out_w = grid->dims()[2];
const int size[4] = {n, c, out_h, out_w};
const T* input_data = input->data<T>();
const T* grid_data = grid->data<T>();
T* output_data = output->mutable_data<T>({n, c, h, w}, ctx.GetPlace());
T* output_data =
output->mutable_data<T>({n, c, out_h, out_w}, ctx.GetPlace());
ScopedSpatialTransformerDescriptor st_desc;
cudnnSpatialTransformerDescriptor_t cudnn_st_desc =
@ -97,7 +98,7 @@ class CUDNNGridSampleGradOpKernel : public framework::OpKernel<T> {
const T* grid_data = grid->data<T>();
const T* output_grad_data = output_grad->data<T>();
T* input_grad_data =
input_grad->mutable_data<T>(output_grad_dims, ctx.GetPlace());
input_grad->mutable_data<T>(input->dims(), ctx.GetPlace());
T* grid_grad_data =
grid_grad->mutable_data<T>({n, h, w, 2}, ctx.GetPlace());

@ -14,6 +14,7 @@ limitations under the License. */
#include "paddle/fluid/operators/grid_sampler_op.h"
#include <memory>
#include <string>
#include "paddle/fluid/framework/op_registry.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/cudnn_helper.h"
@ -58,21 +59,10 @@ class GridSampleOp : public framework::OperatorWithKernel {
"Input(X) and Input(Grid) dimension[0] should be equal, but "
"received X dimension[0](%d) != Grid dimension[0](%d)",
x_dims[0], grid_dims[0]));
PADDLE_ENFORCE_EQ(
grid_dims[1], x_dims[2],
platform::errors::InvalidArgument(
"Input(X) dims[2] and Input(Grid) dims[1] should be equal, but "
"received X dimension[2](%d) != Grid dimension[1](%d)",
x_dims[2], grid_dims[1]));
PADDLE_ENFORCE_EQ(
grid_dims[2], x_dims[3],
platform::errors::InvalidArgument(
"Input(X) dims[3] and Input(Grid) dims[2] should be equal, but "
"received X dimension[3](%d) != Grid dimension[2](%d)",
x_dims[3], grid_dims[2]));
}
ctx->SetOutputDim("Output", x_dims);
ctx->SetOutputDim("Output",
{x_dims[0], x_dims[1], grid_dims[1], grid_dims[2]});
ctx->ShareLoD("X", "Output");
}
@ -108,15 +98,37 @@ class GridSampleOpMaker : public framework::OpProtoAndCheckerMaker {
"(bool, default true) Only used in cudnn kernel, need install cudnn")
.SetDefault(true);
AddAttr<bool>(
"align_corners",
"(bool, default true) If align_corners is true, it will project"
"-1 and 1 to the centers of the corner pixels. Otherwise, it will "
"project"
"-1 and 1 to the image edges.")
.SetDefault(true);
AddAttr<std::string>(
"mode",
"(bool, default true) The interpolation method which can be 'bilinear'"
" or 'nearest'.")
.SetDefault("bilinear");
AddAttr<std::string>(
"padding_mode",
"(bool, default true) The padding method used when source"
"index is out of input images. It can be 'zeros', 'reflect' and "
"'border'.")
.SetDefault("zeros");
AddComment(R"DOC(
This operation samples input X by using bilinear interpolation based on
This operation samples input X by using bilinear or nearest interpolation based on
flow field grid, which is usually generated by affine_grid. The grid of
shape [N, H, W, 2] is the concatenation of (grid_x, grid_y) coordinates
with shape [N, H, W] each, where grid_x is indexing the 4th dimension
(in width dimension) of input data x and grid_y is indexing the 3rd
dimension (in height dimension), finally results is the bilinear
interpolation value of 4 nearest corner points.
interpolation value or nearest value of 4 nearest corner points.
For bilinear interpolation mode:
Step 1:
Get (x, y) grid coordinates and scale to [0, H-1/W-1].

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

@ -0,0 +1,131 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle
from paddle import fluid, nn
import paddle.fluid.dygraph as dg
import paddle.nn.functional as F
import unittest
class GridSampleTestCase(unittest.TestCase):
def __init__(self,
methodName='runTest',
x_shape=[2, 2, 3, 3],
grid_shape=[2, 3, 3, 2],
mode="bilinear",
padding_mode="zeros",
align_corners=False):
super(GridSampleTestCase, self).__init__(methodName)
self.padding_mode = padding_mode
self.x_shape = x_shape
self.grid_shape = grid_shape
self.mode = mode
self.padding_mode = padding_mode
self.align_corners = align_corners
self.dtype = "float64"
def setUp(self):
self.x = np.random.randn(*(self.x_shape)).astype(self.dtype)
self.grid = np.random.uniform(-1, 1, self.grid_shape).astype(self.dtype)
def static_functional(self, place):
main = fluid.Program()
start = fluid.Program()
with fluid.unique_name.guard():
with fluid.program_guard(main, start):
x = fluid.data("x", self.x_shape, dtype=self.dtype)
grid = fluid.data("grid", self.grid_shape, dtype=self.dtype)
y_var = F.grid_sample(
x,
grid,
mode=self.mode,
padding_mode=self.padding_mode,
align_corners=self.align_corners)
feed_dict = {"x": self.x, "grid": self.grid}
exe = fluid.Executor(place)
exe.run(start)
y_np, = exe.run(main, feed=feed_dict, fetch_list=[y_var])
return y_np
def dynamic_functional(self):
x_t = paddle.to_tensor(self.x)
grid_t = paddle.to_tensor(self.grid)
y_t = F.grid_sample(
x_t,
grid_t,
mode=self.mode,
padding_mode=self.padding_mode,
align_corners=self.align_corners)
y_np = y_t.numpy()
return y_np
def _test_equivalence(self, place):
result1 = self.static_functional(place)
with dg.guard(place):
result2 = self.dynamic_functional()
np.testing.assert_array_almost_equal(result1, result2)
def runTest(self):
place = fluid.CPUPlace()
self._test_equivalence(place)
if fluid.core.is_compiled_with_cuda():
place = fluid.CUDAPlace(0)
self._test_equivalence(place)
class GridSampleErrorTestCase(GridSampleTestCase):
    """Case whose configuration must be rejected with ``ValueError``.

    Only the static-graph path on CPU is exercised.
    """

    def runTest(self):
        cpu = fluid.CPUPlace()
        # Building the program is expected to fail before anything executes.
        self.assertRaises(ValueError, self.static_functional, cpu)
def add_cases(suite):
    """Append the equivalence tests for valid configurations to ``suite``."""
    # Keyword overrides for each case; the empty dict keeps every default.
    option_sets = (
        {},
        dict(mode='bilinear', padding_mode='reflect', align_corners=True),
        dict(mode='bilinear', padding_mode='zeros', align_corners=True),
    )
    for options in option_sets:
        suite.addTest(GridSampleTestCase(methodName='runTest', **options))
def add_error_cases(suite):
    """Append the invalid-argument tests to ``suite``.

    Each case replaces exactly one option with the string "VALID" and leaves
    the others at their defaults.
    """
    invalid_overrides = (
        {'padding_mode': "VALID"},
        {'align_corners': "VALID"},
        {'mode': "VALID"},
    )
    for override in invalid_overrides:
        suite.addTest(
            GridSampleErrorTestCase(
                methodName='runTest', **override))
def load_tests(loader, standard_tests, pattern):
    """Return a hand-built suite for this module.

    ``loader``, ``standard_tests`` and ``pattern`` are accepted but not used;
    the suite holds only what ``add_cases`` and ``add_error_cases`` add.
    """
    combined = unittest.TestSuite()
    for populate in (add_cases, add_error_cases):
        populate(combined)
    return combined
# Entry point when the file is executed directly: hand over to unittest's
# command-line runner.
if __name__ == '__main__':
    unittest.main()

File diff suppressed because it is too large Load Diff

@ -192,7 +192,7 @@ from .vision import fsp_matrix #DEFINE_ALIAS
from .vision import generate_mask_labels #DEFINE_ALIAS
from .vision import generate_proposal_labels #DEFINE_ALIAS
from .vision import generate_proposals #DEFINE_ALIAS
from .vision import grid_sampler #DEFINE_ALIAS
from .vision import grid_sample #DEFINE_ALIAS
from .vision import image_resize #DEFINE_ALIAS
from .vision import image_resize_short #DEFINE_ALIAS
# from .vision import multi_box_head #DEFINE_ALIAS

@ -28,7 +28,6 @@ from ...fluid.layers import distribute_fpn_proposals #DEFINE_ALIAS
from ...fluid.layers import generate_mask_labels #DEFINE_ALIAS
from ...fluid.layers import generate_proposal_labels #DEFINE_ALIAS
from ...fluid.layers import generate_proposals #DEFINE_ALIAS
from ...fluid.layers import grid_sampler #DEFINE_ALIAS
from ...fluid.layers import image_resize #DEFINE_ALIAS
from ...fluid.layers import prior_box #DEFINE_ALIAS
from ...fluid.layers import prroi_pool #DEFINE_ALIAS
@ -68,7 +67,7 @@ __all__ = [
'generate_mask_labels',
'generate_proposal_labels',
'generate_proposals',
'grid_sampler',
'grid_sample',
'image_resize',
'image_resize_short',
# 'multi_box_head',
@ -89,3 +88,187 @@ __all__ = [
'yolo_box',
'yolov3_loss'
]
from ...fluid.layer_helper import LayerHelper
from ...fluid.data_feeder import check_variable_and_dtype
from ...fluid import core, dygraph_utils
from ...fluid.framework import Variable, in_dygraph_mode
from ...device import get_cudnn_version
import numpy as np
def grid_sample(x,
                grid,
                mode='bilinear',
                padding_mode='zeros',
                align_corners=True,
                name=None):
    """
    This operation samples input X by using bilinear interpolation or
    nearest interpolation based on flow field grid, which is usually
    generated by :code:`affine_grid` . The grid of shape [N, grid_H, grid_W, 2]
    is the concatenation of (x, y) coordinates with shape [N, grid_H, grid_W]
    each, where x is indexing the 4th dimension (in width dimension) of input
    data x and y is indexing the 3rd dimension (in height dimension),
    finally results is the bilinear interpolation or nearest value of 4 nearest
    corner points. The output tensor shape will be [N, C, grid_H, grid_W].

    .. code-block:: text

        Step 1:
        Get (x, y) grid coordinates and scale to [0, H-1/W-1], where H and W
        are the height and width of the input X.

        .. code-block:: text

            grid_x = 0.5 * (grid[:, :, :, 0] + 1) * (W - 1)
            grid_y = 0.5 * (grid[:, :, :, 1] + 1) * (H - 1)

        Step 2:
        Indices input data X with grid (x, y) in each [H, W] area, and bilinear
        interpolate point value by 4 nearest points or nearest interpolate point value
        by nearest point.

              wn ------- y_n ------- en
              |           |           |
              |          d_n          |
              |           |           |
             x_w --d_w-- grid--d_e-- x_e
              |           |           |
              |          d_s          |
              |           |           |
              ws ------- y_s ------- es

        For bilinear interpolation:
        x_w = floor(x)              // west side x coord
        x_e = x_w + 1               // east side x coord
        y_n = floor(y)              // north side y coord
        y_s = y_n + 1               // south side y coord
        d_w = grid_x - x_w          // distance to west side
        d_e = x_e - grid_x          // distance to east side
        d_n = grid_y - y_n          // distance to north side
        d_s = y_s - grid_y          // distance to south side
        wn = X[:, :, y_n, x_w]      // north-west point value
        en = X[:, :, y_n, x_e]      // north-east point value
        ws = X[:, :, y_s, x_w]      // south-west point value
        es = X[:, :, y_s, x_e]      // south-east point value

        output = wn * d_e * d_s + en * d_w * d_s
               + ws * d_e * d_n + es * d_w * d_n

    Args:
        x(Tensor): The input tensor, which is a 4-d tensor with shape
                     [N, C, H, W], N is the batch size, C is the channel
                     number, H and W is the feature height and width.
                     The data type is float32 or float64.
        grid(Tensor): Input grid tensor of shape [N, grid_H, grid_W, 2]. The
                        data type is float32 or float64.
        mode(str, optional): The interpolation method which can be 'bilinear' or 'nearest'.
                         Default: 'bilinear'.
        padding_mode(str, optional): The padding method used when source index
                   is out of input images. It can be 'zeros', 'reflect' and 'border'.
                   Default: zeros.
        align_corners(bool, optional): If `align_corners` is true, it projects
                   -1 and 1 to the centers of the corner pixels. Otherwise, it
                   projects -1 and 1 to the image edges. Default: True.
        name(str, optional): For detailed information, please refer
                             to :ref:`api_guide_Name`. Usually name is no need to set and
                             None by default.

    Returns: Tensor, The shape of output is [N, C, grid_H, grid_W] in which `grid_H` is the height of grid
                 and `grid_W` is the width of grid. The data type is same as input tensor.

    Raises:
        ValueError: If `x` or `grid` is not a Variable, if `mode` or
            `padding_mode` is not one of the supported strings, or if
            `align_corners` is not a bool.

    Examples:

        .. code-block:: python

            import paddle
            import paddle.nn.functional as F
            import numpy as np

            # shape=[1, 1, 3, 3]
            x = np.array([[[[-0.6,  0.8, -0.5],
                            [-0.5,  0.2,  1.2],
                            [ 1.4,  0.3, -0.2]]]]).astype("float64")

            # grid shape = [1, 3, 4, 2]
            grid = np.array(
                         [[[[ 0.2,  0.3],
                            [-0.4, -0.3],
                            [-0.9,  0.3],
                            [-0.9, -0.6]],
                           [[ 0.4,  0.1],
                            [ 0.9, -0.8],
                            [ 0.4,  0.5],
                            [ 0.5, -0.2]],
                           [[ 0.1, -0.8],
                            [-0.3, -1. ],
                            [ 0.7,  0.4],
                            [ 0.2,  0.8]]]]).astype("float64")

            paddle.disable_static()
            x = paddle.to_tensor(x)
            grid = paddle.to_tensor(grid)
            y_t = F.grid_sample(
                x,
                grid,
                mode='bilinear',
                padding_mode='border',
                align_corners=True)
            print(y_t.numpy())

            # output shape = [1, 1, 3, 4]
            # [[[[ 0.34   0.016  0.086 -0.448]
            #    [ 0.55  -0.076  0.35   0.59 ]
            #    [ 0.596  0.38   0.52   0.24 ]]]]
    """
    # Must stay the first statement: locals() has to contain only the
    # function's parameters when it is forwarded to LayerHelper.
    helper = LayerHelper("grid_sample", **locals())
    check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'grid_sample')
    check_variable_and_dtype(grid, 'grid', ['float32', 'float64'],
                             'grid_sample')
    if not isinstance(x, Variable):
        raise ValueError("The x should be a Variable")
    if not isinstance(grid, Variable):
        raise ValueError("The grid should be a Variable")

    _modes = ['bilinear', 'nearest']
    _padding_modes = ['zeros', 'reflect', 'border']
    if mode not in _modes:
        raise ValueError(
            "The mode of grid sample function should be in {}, but got: {}".
            format(_modes, mode))
    if padding_mode not in _padding_modes:
        raise ValueError(
            "The padding mode of grid sample function should be in {}, but got: {}".
            format(_padding_modes, padding_mode))
    if not isinstance(align_corners, bool):
        raise ValueError("The align corners should be bool, but got: {}".format(
            align_corners))

    # The cuDNN kernel is requested only for the one configuration the
    # pre-existing implementation handled: bilinear, zero padding, aligned
    # corners.
    cudnn_version = get_cudnn_version()
    use_cudnn = False
    if (cudnn_version is not None
        ) and align_corners and mode == 'bilinear' and padding_mode == 'zeros':
        use_cudnn = True

    if in_dygraph_mode():
        # The eager op takes attributes as a flat name/value sequence.
        flat_attrs = ('mode', mode, 'padding_mode', padding_mode,
                      'align_corners', align_corners, 'use_cudnn', use_cudnn)
        out = core.ops.grid_sampler(x, grid, *flat_attrs)
    else:
        ipts = {'X': x, 'Grid': grid}
        attrs = {
            'mode': mode,
            'padding_mode': padding_mode,
            'align_corners': align_corners,
            'use_cudnn': use_cudnn
        }
        out = helper.create_variable_for_type_inference(x.dtype)
        helper.append_op(
            type='grid_sampler',
            inputs=ipts,
            attrs=attrs,
            outputs={'Output': out})
    return out

Loading…
Cancel
Save