Add some transform apis (#25357)

* add more vision transfrom apis
fix_copy_if_different
LielinJiang 5 years ago committed by GitHub
parent 417b243968
commit 8dea7bed2a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -23,6 +23,7 @@ import numpy as np
from paddle.incubate.hapi.datasets import DatasetFolder from paddle.incubate.hapi.datasets import DatasetFolder
from paddle.incubate.hapi.vision.transforms import transforms from paddle.incubate.hapi.vision.transforms import transforms
import paddle.incubate.hapi.vision.transforms.functional as F
class TestTransforms(unittest.TestCase): class TestTransforms(unittest.TestCase):
@ -100,6 +101,78 @@ class TestTransforms(unittest.TestCase):
]) ])
self.do_transform(trans) self.do_transform(trans)
def test_rotate(self):
trans = transforms.Compose([
transforms.RandomRotate(90),
transforms.RandomRotate([-10, 10]),
transforms.RandomRotate(
45, expand=True),
transforms.RandomRotate(
10, expand=True, center=(60, 80)),
])
self.do_transform(trans)
def test_pad(self):
trans = transforms.Compose([transforms.Pad(2)])
self.do_transform(trans)
fake_img = np.random.rand(200, 150, 3).astype('float32')
trans_pad = transforms.Pad(10)
fake_img_padded = trans_pad(fake_img)
np.testing.assert_equal(fake_img_padded.shape, (220, 170, 3))
trans_pad1 = transforms.Pad([1, 2])
trans_pad2 = transforms.Pad([1, 2, 3, 4])
img = trans_pad1(fake_img)
img = trans_pad2(img)
def test_erase(self):
trans = transforms.Compose(
[transforms.RandomErasing(), transforms.RandomErasing(value=0.0)])
self.do_transform(trans)
def test_random_crop(self):
trans = transforms.Compose([
transforms.RandomCrop(200),
transforms.RandomCrop((140, 160)),
])
self.do_transform(trans)
trans_random_crop1 = transforms.RandomCrop(224)
trans_random_crop2 = transforms.RandomCrop((140, 160))
fake_img = np.random.rand(500, 400, 3).astype('float32')
fake_img_crop1 = trans_random_crop1(fake_img)
fake_img_crop2 = trans_random_crop2(fake_img_crop1)
np.testing.assert_equal(fake_img_crop1.shape, (224, 224, 3))
np.testing.assert_equal(fake_img_crop2.shape, (140, 160, 3))
trans_random_crop_same = transforms.RandomCrop((140, 160))
img = trans_random_crop_same(fake_img_crop2)
trans_random_crop_bigger = transforms.RandomCrop((180, 200))
img = trans_random_crop_bigger(img)
trans_random_crop_pad = transforms.RandomCrop((224, 256), 2, True)
img = trans_random_crop_pad(img)
def test_grayscale(self):
trans = transforms.Compose([transforms.Grayscale()])
self.do_transform(trans)
trans_gray = transforms.Grayscale()
fake_img = np.random.rand(500, 400, 3).astype('float32')
fake_img_gray = trans_gray(fake_img)
np.testing.assert_equal(len(fake_img_gray.shape), 2)
np.testing.assert_equal(fake_img_gray.shape[0], 500)
np.testing.assert_equal(fake_img_gray.shape[1], 400)
trans_gray3 = transforms.Grayscale(3)
fake_img = np.random.rand(500, 400, 3).astype('float32')
fake_img_gray = trans_gray3(fake_img)
def test_exception(self): def test_exception(self):
trans = transforms.Compose([transforms.Resize(-1)]) trans = transforms.Compose([transforms.Resize(-1)])
@ -123,6 +196,36 @@ class TestTransforms(unittest.TestCase):
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
transforms.BrightnessTransform(-1.0) transforms.BrightnessTransform(-1.0)
with self.assertRaises(ValueError):
transforms.Pad([1.0, 2.0, 3.0])
with self.assertRaises(TypeError):
fake_img = np.random.rand(100, 120, 3).astype('float32')
F.pad(fake_img, '1')
with self.assertRaises(TypeError):
fake_img = np.random.rand(100, 120, 3).astype('float32')
F.pad(fake_img, 1, {})
with self.assertRaises(TypeError):
fake_img = np.random.rand(100, 120, 3).astype('float32')
F.pad(fake_img, 1, padding_mode=-1)
with self.assertRaises(ValueError):
fake_img = np.random.rand(100, 120, 3).astype('float32')
F.pad(fake_img, [1.0, 2.0, 3.0])
with self.assertRaises(ValueError):
transforms.RandomRotate(-2)
with self.assertRaises(ValueError):
transforms.RandomRotate([1, 2, 3])
with self.assertRaises(ValueError):
trans_gray = transforms.Grayscale(5)
fake_img = np.random.rand(100, 120, 3).astype('float32')
trans_gray(fake_img)
def test_info(self): def test_info(self):
str(transforms.Compose([transforms.Resize((224, 224))])) str(transforms.Compose([transforms.Resize((224, 224))]))
str(transforms.BatchCompose([transforms.Resize((224, 224))])) str(transforms.BatchCompose([transforms.Resize((224, 224))]))

@ -15,8 +15,10 @@
import sys import sys
import collections import collections
import random import random
import math
import cv2 import cv2
import numbers
import numpy as np import numpy as np
if sys.version_info < (3, 3): if sys.version_info < (3, 3):
@ -26,7 +28,7 @@ else:
Sequence = collections.abc.Sequence Sequence = collections.abc.Sequence
Iterable = collections.abc.Iterable Iterable = collections.abc.Iterable
__all__ = ['flip', 'resize'] __all__ = ['flip', 'resize', 'pad', 'rotate', 'to_grayscale']
def flip(image, code): def flip(image, code):
@ -99,3 +101,202 @@ def resize(img, size, interpolation=cv2.INTER_LINEAR):
return cv2.resize(img, (ow, oh), interpolation=interpolation) return cv2.resize(img, (ow, oh), interpolation=interpolation)
else: else:
return cv2.resize(img, size[::-1], interpolation=interpolation) return cv2.resize(img, size[::-1], interpolation=interpolation)
def pad(img, padding, fill=(0, 0, 0), padding_mode='constant'):
"""Pads the given CV Image on all sides with speficified padding mode and fill value.
Args:
img (np.ndarray): Image to be padded.
padding (int|tuple): Padding on each border. If a single int is provided this
is used to pad all borders. If tuple of length 2 is provided this is the padding
on left/right and top/bottom respectively. If a tuple of length 4 is provided
this is the padding for the left, top, right and bottom borders
respectively.
fill (int|tuple): Pixel fill value for constant fill. Default is 0. If a tuple of
length 3, it is used to fill R, G, B channels respectively.
This value is only used when the padding_mode is constant
padding_mode: Type of padding. Should be: constant, edge, reflect or symmetric. Default is constant.
``constant`` means padding with a constant value, this value is specified with fill.
``edge`` means padding with the last value at the edge of the image.
``reflect`` means padding with reflection of image (without repeating the last value on the edge)
padding ``[1, 2, 3, 4]`` with 2 elements on both sides in reflect mode
will result in ``[3, 2, 1, 2, 3, 4, 3, 2]``.
``symmetric`` menas pads with reflection of image (repeating the last value on the edge)
padding ``[1, 2, 3, 4]`` with 2 elements on both sides in symmetric mode
will result in ``[2, 1, 1, 2, 3, 4, 4, 3]``.
Returns:
numpy ndarray: Padded image.
Examples:
.. code-block:: python
import numpy as np
from paddle.incubate.hapi.vision.transforms.functional import pad
fake_img = np.random.rand(500, 500, 3).astype('float32')
fake_img = pad(fake_img, 2)
print(fake_img.shape)
"""
if not isinstance(padding, (numbers.Number, list, tuple)):
raise TypeError('Got inappropriate padding arg')
if not isinstance(fill, (numbers.Number, str, list, tuple)):
raise TypeError('Got inappropriate fill arg')
if not isinstance(padding_mode, str):
raise TypeError('Got inappropriate padding_mode arg')
if isinstance(padding, collections.Sequence) and len(padding) not in [2, 4]:
raise ValueError(
"Padding must be an int or a 2, or 4 element tuple, not a " +
"{} element tuple".format(len(padding)))
assert padding_mode in ['constant', 'edge', 'reflect', 'symmetric'], \
'Expected padding mode be either constant, edge, reflect or symmetric, but got {}'.format(padding_mode)
PAD_MOD = {
'constant': cv2.BORDER_CONSTANT,
'edge': cv2.BORDER_REPLICATE,
'reflect': cv2.BORDER_DEFAULT,
'symmetric': cv2.BORDER_REFLECT
}
if isinstance(padding, int):
pad_left = pad_right = pad_top = pad_bottom = padding
if isinstance(padding, collections.Sequence) and len(padding) == 2:
pad_left = pad_right = padding[0]
pad_top = pad_bottom = padding[1]
if isinstance(padding, collections.Sequence) and len(padding) == 4:
pad_left, pad_top, pad_right, pad_bottom = padding
if isinstance(fill, numbers.Number):
fill = (fill, ) * (2 * len(img.shape) - 3)
if padding_mode == 'constant':
assert (len(fill) == 3 and len(img.shape) == 3) or (len(fill) == 1 and len(img.shape) == 2), \
'channel of image is {} but length of fill is {}'.format(img.shape[-1], len(fill))
img = cv2.copyMakeBorder(
src=img,
top=pad_top,
bottom=pad_bottom,
left=pad_left,
right=pad_right,
borderType=PAD_MOD[padding_mode],
value=fill)
return img
def rotate(img,
angle,
interpolation=cv2.INTER_LINEAR,
expand=False,
center=None):
"""Rotates the image by angle.
Args:
img (numpy.ndarray): Image to be rotated.
angle (float|int): In degrees clockwise order.
interpolation (int, optional):
interpolation: Interpolation method.
expand (bool|optional): Optional expansion flag.
If true, expands the output image to make it large enough to hold the entire rotated image.
If false or omitted, make the output image the same size as the input image.
Note that the expand flag assumes rotation around the center and no translation.
center (2-tuple|optional): Optional center of rotation.
Origin is the upper left corner.
Default is the center of the image.
Returns:
numpy ndarray: Rotated image.
Examples:
.. code-block:: python
import numpy as np
from paddle.incubate.hapi.vision.transforms.functional import rotate
fake_img = np.random.rand(500, 500, 3).astype('float32')
fake_img = rotate(fake_img, 10)
print(fake_img.shape)
"""
dtype = img.dtype
h, w, _ = img.shape
point = center or (w / 2, h / 2)
M = cv2.getRotationMatrix2D(point, angle=-angle, scale=1)
if expand:
if center is None:
cos = np.abs(M[0, 0])
sin = np.abs(M[0, 1])
nW = int((h * sin) + (w * cos))
nH = int((h * cos) + (w * sin))
M[0, 2] += (nW / 2) - point[0]
M[1, 2] += (nH / 2) - point[1]
dst = cv2.warpAffine(img, M, (nW, nH))
else:
xx = []
yy = []
for point in (np.array([0, 0, 1]), np.array([w - 1, 0, 1]),
np.array([w - 1, h - 1, 1]), np.array([0, h - 1, 1])):
target = np.dot(M, point)
xx.append(target[0])
yy.append(target[1])
nh = int(math.ceil(max(yy)) - math.floor(min(yy)))
nw = int(math.ceil(max(xx)) - math.floor(min(xx)))
M[0, 2] += (nw - w) / 2
M[1, 2] += (nh - h) / 2
dst = cv2.warpAffine(img, M, (nw, nh), flags=interpolation)
else:
dst = cv2.warpAffine(img, M, (w, h), flags=interpolation)
return dst.astype(dtype)
def to_grayscale(img, num_output_channels=1):
"""Converts image to grayscale version of image.
Args:
img (numpy.ndarray): Image to be converted to grayscale.
Returns:
numpy.ndarray: Grayscale version of the image.
if num_output_channels == 1, returned image is single channel
if num_output_channels == 3, returned image is 3 channel with r == g == b
Examples:
.. code-block:: python
import numpy as np
from paddle.incubate.hapi.vision.transforms.functional import to_grayscale
fake_img = np.random.rand(500, 500, 3).astype('float32')
fake_img = to_grayscale(fake_img)
print(fake_img.shape)
"""
if num_output_channels == 1:
img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
elif num_output_channels == 3:
img = cv2.cvtColor(
cv2.cvtColor(img, cv2.COLOR_RGB2GRAY), cv2.COLOR_GRAY2RGB)
else:
raise ValueError('num_output_channels should be either 1 or 3')
return img

File diff suppressed because it is too large Load Diff
Loading…
Cancel
Save