add linear interpolate operator (#23357)

* test=develop add linear interpolate operator
release/2.0-alpha
FDInSky 5 years ago committed by GitHub
parent e24575c8ba
commit 7fedf26b87
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

@ -17,10 +17,12 @@
from .layer import norm from .layer import norm
from .functional import extension from .functional import extension
from .layer import common
__all__ = [] __all__ = []
__all__ += norm.__all__ __all__ += norm.__all__
__all__ += extension.__all__ __all__ += extension.__all__
__all__ += common.__all__
# TODO: define alias in nn directory # TODO: define alias in nn directory
# from .clip import ErrorClipByValue #DEFINE_ALIAS # from .clip import ErrorClipByValue #DEFINE_ALIAS
@ -64,7 +66,7 @@ from .layer.common import BilinearTensorProduct #DEFINE_ALIAS
from .layer.common import Pool2D #DEFINE_ALIAS from .layer.common import Pool2D #DEFINE_ALIAS
from .layer.common import Embedding #DEFINE_ALIAS from .layer.common import Embedding #DEFINE_ALIAS
from .layer.common import Linear #DEFINE_ALIAS from .layer.common import Linear #DEFINE_ALIAS
# from .layer.common import UpSample #DEFINE_ALIAS from .layer.common import UpSample #DEFINE_ALIAS
from .layer.conv import Conv2D #DEFINE_ALIAS from .layer.conv import Conv2D #DEFINE_ALIAS
from .layer.conv import Conv2DTranspose #DEFINE_ALIAS from .layer.conv import Conv2DTranspose #DEFINE_ALIAS
from .layer.conv import Conv3D #DEFINE_ALIAS from .layer.conv import Conv3D #DEFINE_ALIAS

@ -284,6 +284,7 @@ def interpolate(input,
# [2L, 3L, 12L, 12L] # [2L, 3L, 12L, 12L]
""" """
resample_methods = { resample_methods = {
'LINEAR': 'linear',
'BILINEAR': 'bilinear', 'BILINEAR': 'bilinear',
'TRILINEAR': 'trilinear', 'TRILINEAR': 'trilinear',
'NEAREST': 'nearest', 'NEAREST': 'nearest',
@ -291,10 +292,13 @@ def interpolate(input,
} }
if resample not in resample_methods: if resample not in resample_methods:
raise ValueError( raise ValueError(
"The 'resample' of image_resize can only be 'BILINEAR', 'TRILINEAR', " "The 'resample' of image_resize can only be 'LINEAR', 'BILINEAR', 'TRILINEAR', "
" 'BICUBIC' or 'NEAREST' currently.") " 'BICUBIC' or 'NEAREST' currently.")
resample_type = resample_methods[resample] resample_type = resample_methods[resample]
if resample in ['LINEAR'] and len(input.shape) != 3:
raise ValueError("'LINEAR' only support 3-D tensor.")
if resample in ['BILINEAR', 'NEAREST', 'BICUBIC'] and len(input.shape) != 4: if resample in ['BILINEAR', 'NEAREST', 'BICUBIC'] and len(input.shape) != 4:
raise ValueError( raise ValueError(
"'BILINEAR', 'BICUBIC' and 'NEAREST' only support 4-D tensor.") "'BILINEAR', 'BICUBIC' and 'NEAREST' only support 4-D tensor.")
@ -311,7 +315,11 @@ def interpolate(input,
helper = LayerHelper('{}_interp'.format(resample_type), **locals()) helper = LayerHelper('{}_interp'.format(resample_type), **locals())
dtype = helper.input_dtype() dtype = helper.input_dtype()
if len(input.shape) == 4 and data_format not in ['NCHW', 'NHWC']: if len(input.shape) == 3 and data_format not in ['NCHW', 'NHWC']:
raise ValueError(
"Got wrong value for param `data_format`: " + data_format +
" received but only `NCHW` or `NHWC` supported for 3-D input.")
elif len(input.shape) == 4 and data_format not in ['NCHW', 'NHWC']:
raise ValueError( raise ValueError(
"Got wrong value for param `data_format`: " + data_format + "Got wrong value for param `data_format`: " + data_format +
" received but only `NCHW` or `NHWC` supported for 4-D input.") " received but only `NCHW` or `NHWC` supported for 4-D input.")
@ -375,6 +383,15 @@ def interpolate(input,
size_list.append(dim) size_list.append(dim)
inputs['SizeTensor'] = new_size_tensor inputs['SizeTensor'] = new_size_tensor
if len(input.shape) == 3:
if len(out_shape) != 1:
raise ValueError(
"out_shape length should be 2 for input 3-D tensor")
if contain_var:
attrs['out_w'] = size_list[0]
else:
out_shape = list(map(int, out_shape))
attrs['out_w'] = out_shape[0]
if len(input.shape) == 4: if len(input.shape) == 4:
if len(out_shape) != 2: if len(out_shape) != 2:
raise ValueError("out_shape length should be 2 for " raise ValueError("out_shape length should be 2 for "

@ -37,7 +37,7 @@ from .common import BilinearTensorProduct #DEFINE_ALIAS
from .common import Pool2D #DEFINE_ALIAS from .common import Pool2D #DEFINE_ALIAS
from .common import Embedding #DEFINE_ALIAS from .common import Embedding #DEFINE_ALIAS
from .common import Linear #DEFINE_ALIAS from .common import Linear #DEFINE_ALIAS
# from .common import UpSample #DEFINE_ALIAS from .common import UpSample #DEFINE_ALIAS
from .conv import Conv2D #DEFINE_ALIAS from .conv import Conv2D #DEFINE_ALIAS
from .conv import Conv2DTranspose #DEFINE_ALIAS from .conv import Conv2DTranspose #DEFINE_ALIAS
from .conv import Conv3D #DEFINE_ALIAS from .conv import Conv3D #DEFINE_ALIAS

@ -17,11 +17,232 @@ from ...fluid.dygraph import BilinearTensorProduct #DEFINE_ALIAS
from ...fluid.dygraph import Pool2D #DEFINE_ALIAS from ...fluid.dygraph import Pool2D #DEFINE_ALIAS
from ...fluid.dygraph import Embedding #DEFINE_ALIAS from ...fluid.dygraph import Embedding #DEFINE_ALIAS
from ...fluid.dygraph import Linear #DEFINE_ALIAS from ...fluid.dygraph import Linear #DEFINE_ALIAS
from ...fluid.dygraph import layers
from .. import functional as F
__all__ = [ __all__ = ['BilinearTensorProduct', 'Pool2D', 'Embedding', 'Linear', 'UpSample']
'BilinearTensorProduct',
'Pool2D',
'Embedding', class UpSample(layers.Layer):
'Linear', """
# 'UpSample' This op resizes a batch of images.
] The input must be a 3-D Tensor of the shape (num_batches, channels, in_w)
or 4-D (num_batches, channels, in_h, in_w), or a 5-D Tensor of the shape
(num_batches, channels, in_d, in_h, in_w) or (num_batches, in_d, in_h, in_w, channels),
and the resizing only applies on the three dimensions(depth, height and width).
**Warning:** the parameter :attr:`actual_shape` will be deprecated in the
future and only use :attr:`out_shape` instead.
Supporting resample methods:
'LINEAR' : linear interpolation
'BILINEAR' : Bilinear interpolation
'TRILINEAR' : Trilinear interpolation
'NEAREST' : Nearest neighbor interpolation
'BICUBIC' : Bicubic interpolation
Linear interpolation is the method of using a line connecting two known quantities
to determine the value of an unknown quantity between the two known quantities.
Nearest neighbor interpolation is to perform nearest neighbor interpolation
in both the 3rd dimension(in height direction) and the 4th dimension(in width
direction) on input tensor.
Bilinear interpolation is an extension of linear interpolation for
interpolating functions of two variables (e.g. H-direction and
W-direction in this op) on a rectilinear 2D grid. The key idea is
to perform linear interpolation first in one direction, and then
again in the other direction.
Trilinear interpolation is an extension of linear interpolation for
interpolating functions of three variables (e.g. D-direction,
H-direction and W-direction in this op) on a rectilinear 3D grid.
The linear interpolation is performed on three directions.
Align_corners and align_mode are optional parameters,the calculation method
of interpolation can be selected by them.
Bicubic interpolation is an extension of cubic interpolation for interpolating
data points on a two-dimensional regular grid. The interpolated surface is
smoother than corresponding surfaces obtained by bilinear interpolation or
nearest-neighbor interpolation.
Example:
.. code-block:: text
For scale:
if align_corners = True && out_size > 1 :
scale_factor = (in_size-1.0)/(out_size-1.0)
else:
scale_factor = float(in_size/out_size)
Nearest neighbor interpolation:
if:
align_corners = False
input : (N,C,H_in,W_in)
output: (N,C,H_out,W_out) where:
H_out = floor (H_{in} * scale_{factor})
W_out = floor (W_{in} * scale_{factor})
else:
align_corners = True
input : (N,C,H_in,W_in)
output: (N,C,H_out,W_out) where:
H_out = round(H_{in} * scale_{factor})
W_out = round(W_{in} * scale_{factor})
Linear interpolation:
if:
align_corners = False , align_mode = 0
input : (N,C,W_in)
output: (N,C,W_out) where:
W_out = (W_{in}+0.5) * scale_{factor} - 0.5
else:
input : (N,C,W_in)
output: (N,C,W_out) where:
W_out = W_{in} * scale_{factor}
Bilinear interpolation:
if:
align_corners = False , align_mode = 0
input : (N,C,H_in,W_in)
output: (N,C,H_out,W_out) where:
H_out = (H_{in}+0.5) * scale_{factor} - 0.5
W_out = (W_{in}+0.5) * scale_{factor} - 0.5
else:
input : (N,C,H_in,W_in)
output: (N,C,H_out,W_out) where:
H_out = H_{in} * scale_{factor}
W_out = W_{in} * scale_{factor}
Bicubic interpolation:
if:
align_corners = False
input : (N,C,H_in,W_in)
output: (N,C,H_out,W_out) where:
H_out = (H_{in}+0.5) * scale_{factor} - 0.5
W_out = (W_{in}+0.5) * scale_{factor} - 0.5
else:
input : (N,C,H_in,W_in)
output: (N,C,H_out,W_out) where:
H_out = H_{in} * scale_{factor}
W_out = W_{in} * scale_{factor}
Trilinear interpolation:
if:
align_corners = False , align_mode = 0
input : (N,C,D_in,H_in,W_in)
output: (N,C,D_out,H_out,W_out) where:
D_out = (D_{in}+0.5) * scale_{factor} - 0.5
H_out = (H_{in}+0.5) * scale_{factor} - 0.5
W_out = (W_{in}+0.5) * scale_{factor} - 0.5
else:
input : (N,C,D_in,H_in,W_in)
output: (N,C,D_out,H_out,W_out) where:
D_out = D_{in} * scale_{factor}
H_out = H_{in} * scale_{factor}
W_out = W_{in} * scale_{factor}
For details of nearest neighbor interpolation, please refer to Wikipedia:
https://en.wikipedia.org/wiki/Nearest-neighbor_interpolation.
For details of bilinear interpolation, please refer to Wikipedia:
https://en.wikipedia.org/wiki/Linear_interpolation.
For details of bilinear interpolation, please refer to Wikipedia:
https://en.wikipedia.org/wiki/Bilinear_interpolation.
For details of trilinear interpolation, please refer to Wikipedia:
https://en.wikipedia.org/wiki/Trilinear_interpolation.
For details of bicubic interpolation, please refer to Wikipedia:
https://en.wikipedia.org/wiki/Bicubic_interpolation
Parameters:
input (Variable): 3-D, 4-D or 5-D Tensor, its data type is float32, float64, or uint8,
its data format is specified by :attr:`data_format`.
out_shape(list|tuple|Variable|None): Output shape of image resize
layer, the shape is (out_w, ) when input is 3-D Tensor ,
the shape is (out_h, out_w) when input is a 4-D Tensor and is
(out_d, out_h, out_w) when input is a 5-D Tensor. Default: None. If
a list, each element can be an integer or a Tensor Variable of shape: [1].
If a Tensor Variable, its dimensions size should be a 1.
scale(float|Variable|None): The multiplier for the input height or width. At
least one of :attr:`out_shape` or :attr:`scale` must be set.
And :attr:`out_shape` has a higher priority than :attr:`scale`.
Default: None.
name(str|None): A name for this layer(optional). If set None, the layer
will be named automatically.
resample(str): The resample method. It supports 'LINEAR', 'BILINEAR', 'TRILINEAR' ,
'BICUBIC' and 'NEAREST' currently. Default: 'BILINEAR'
align_corners(bool) : An optional bool, If True, the centers of the 4 corner pixels of the
input and output tensors are aligned, preserving the values at the
corner pixels.
Default: True
align_mode(int) : An optional for bilinear interpolation. can be \'0\'
for src_idx = scale*(dst_indx+0.5)-0.5 , can be \'1\' for
src_idx = scale*dst_index.
data_format (str, optional): Specify the data format of the input, and the data format of the output
will be consistent with that of the input. An optional string from:'NCW', `"NCHW"`, `"NHWC"`, `"NCDHW"`,
`"NDHWC"`. The default is `"NCHW"`. When it is `"NCHW"`, the data is stored in the order of:
`[batch_size, input_channels, input_height, input_width]`. When it is `"NCHW"`, the data is stored
in the order of: `[batch_size, input_channels, input_depth, input_height, input_width]`.
Returns:
A 3-D Tensor of the shape (num_batches, channels, out_w) or (num_batches, out_w, channels),
A 4-D Tensor of the shape (num_batches, channels, out_h, out_w) or (num_batches, out_h, out_w, channels),
or 5-D Tensor of the shape (num_batches, channels, out_d, out_h, out_w) or (num_batches, out_d, out_h, out_w, channels).
Raises:
TypeError: out_shape should be a list or tuple or Variable.
TypeError: actual_shape should either be Variable or None.
ValueError: The 'resample' of image_resize can only be 'BILINEAR',
'TRILINEAR', 'BICUBIC', or 'NEAREST' currently.
ValueError: 'BILINEAR', 'BICUBIC' and 'NEAREST' only support 4-D tensor.
ValueError: 'TRILINEAR' only support 5-D tensor.
ValueError: One of out_shape and scale must not be None.
ValueError: out_shape length should be 2 for input 4-D tensor.
ValueError: out_shape length should be 3 for input 5-D tensor.
ValueError: scale should be greater than zero.
TypeError: align_corners should be a bool value
ValueError: align_mode can only be '0' or '1'
ValueError: data_format can only be 'NCW', 'NCHW', 'NHWC', 'NCDHW' or 'NDHWC'.
Examples:
.. code-block:: python
import paddle
import numpy as np
import paddle.fluid.dygraph as dg
upsample_op = paddle.nn.UpSample(out_shape=[12,12])
input_data = np.random.rand(2,3,6,10).astype("float32")
place = paddle.fluid.CPUPlace()
with dg.guard(place) as g:
input = dg.to_variable(input_data)
output = upsample_op(input=input)
print(output.shape)
# [2L, 3L, 12L, 12L]
"""
def __init__(self,
out_shape=None,
scale=None,
resample='BILINEAR',
align_corners=True,
align_mode=1,
data_format='NCHW'):
super(UpSample, self).__init__()
self.out_shape = out_shape
self.scale = scale
self.resample = resample
self.align_corners = align_corners
self.align_mode = align_mode
self.data_format = data_format
def forward(self, input):
out = F.interpolate(
input,
out_shape=self.out_shape,
scale=self.scale,
resample=self.resample,
align_corners=self.align_corners,
align_mode=self.align_mode,
data_format=self.data_format)
return out

Loading…
Cancel
Save