Add pil backend for vision transforms (#28035)

* add pil backend
swt-req
LielinJiang 4 years ago committed by GitHub
parent 135b62a4ec
commit 74c8a81127
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -105,7 +105,7 @@ class TestCallbacks(unittest.TestCase):
self.run_callback()
def test_visualdl_callback(self):
# visualdl not support python3
# visualdl not support python2
if sys.version_info < (3, ):
return

File diff suppressed because it is too large Load Diff

@ -21,6 +21,10 @@ from .transforms import *
from . import datasets
from .datasets import *
from . import image
from .image import *
__all__ = models.__all__ \
+ transforms.__all__ \
+ datasets.__all__
+ datasets.__all__ \
+ image.__all__

@ -14,6 +14,7 @@
import os
import sys
from PIL import Image
import paddle
from paddle.io import Dataset
@ -136,7 +137,7 @@ class DatasetFolder(Dataset):
"Found 0 files in subfolders of: " + self.root + "\n"
"Supported extensions are: " + ",".join(extensions)))
self.loader = cv2_loader if loader is None else loader
self.loader = default_loader if loader is None else loader
self.extensions = extensions
self.classes = classes
@ -193,9 +194,23 @@ IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif',
'.tiff', '.webp')
def pil_loader(path):
with open(path, 'rb') as f:
img = Image.open(f)
return img.convert('RGB')
def cv2_loader(path):
cv2 = try_import('cv2')
return cv2.imread(path)
return cv2.cvtColor(cv2.imread(path), cv2.COLOR_BGR2RGB)
def default_loader(path):
from paddle.vision import get_image_backend
if get_image_backend() == 'cv2':
return cv2_loader(path)
else:
return pil_loader(path)
class ImageFolder(Dataset):
@ -280,7 +295,7 @@ class ImageFolder(Dataset):
"Found 0 files in subfolders of: " + self.root + "\n"
"Supported extensions are: " + ",".join(extensions)))
self.loader = cv2_loader if loader is None else loader
self.loader = default_loader if loader is None else loader
self.extensions = extensions
self.samples = samples
self.transform = transform

@ -0,0 +1,162 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from PIL import Image
from paddle.utils import try_import
__all__ = ['set_image_backend', 'get_image_backend', 'image_load']
_image_backend = 'pil'
def set_image_backend(backend):
"""
Specifies the backend used to load images in class ``paddle.vision.datasets.ImageFolder``
and ``paddle.vision.datasets.DatasetFolder`` . Now support backends are pillow and opencv.
If backend not set, will use 'pil' as default.
Args:
backend (str): Name of the image load backend, should be one of {'pil', 'cv2'}.
Examples:
.. code-block:: python
import os
import shutil
import tempfile
import numpy as np
from PIL import Image
from paddle.vision import DatasetFolder
from paddle.vision import set_image_backend
set_image_backend('pil')
def make_fake_dir():
data_dir = tempfile.mkdtemp()
for i in range(2):
sub_dir = os.path.join(data_dir, 'class_' + str(i))
if not os.path.exists(sub_dir):
os.makedirs(sub_dir)
for j in range(2):
fake_img = Image.fromarray((np.random.random((32, 32, 3)) * 255).astype('uint8'))
fake_img.save(os.path.join(sub_dir, str(j) + '.png'))
return data_dir
temp_dir = make_fake_dir()
pil_data_folder = DatasetFolder(temp_dir)
for items in pil_data_folder:
break
# should get PIL.Image.Image
print(type(items[0]))
# use opencv as backend
# set_image_backend('cv2')
# cv2_data_folder = DatasetFolder(temp_dir)
# for items in cv2_data_folder:
# break
# should get numpy.ndarray
# print(type(items[0]))
shutil.rmtree(temp_dir)
"""
global _image_backend
if backend not in ['pil', 'cv2']:
raise ValueError(
"Expected backend are one of ['pil', 'cv2'], but got {}"
.format(backend))
_image_backend = backend
def get_image_backend():
"""
Gets the name of the package used to load images
Returns:
str: backend of image load.
Examples:
.. code-block:: python
from paddle.vision import get_image_backend
backend = get_image_backend()
print(backend)
"""
return _image_backend
def image_load(path, backend=None):
"""Load an image.
Args:
path (str): Path of the image.
backend (str, optional): The image decoding backend type. Options are
`cv2`, `pil`, `None`. If backend is None, the global _imread_backend
specified by ``paddle.vision.set_image_backend`` will be used. Default: None.
Returns:
PIL.Image or np.array: Loaded image.
Examples:
.. code-block:: python
import numpy as np
from PIL import Image
from paddle.vision import image_load, set_image_backend
fake_img = Image.fromarray((np.random.random((32, 32, 3)) * 255).astype('uint8'))
path = 'temp.png'
fake_img.save(path)
set_image_backend('pil')
pil_img = image_load(path).convert('RGB')
# should be PIL.Image.Image
print(type(pil_img))
# use opencv as backend
# set_image_backend('cv2')
# np_img = image_load(path)
# # should get numpy.ndarray
# print(type(np_img))
"""
if backend is None:
backend = _image_backend
if backend not in ['pil', 'cv2']:
raise ValueError(
"Expected backend are one of ['pil', 'cv2'], but got {}"
.format(backend))
if backend == 'pil':
return Image.open(path)
else:
cv2 = try_import('cv2')
return cv2.imread(path)

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

@ -0,0 +1,40 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import division
import paddle
def normalize(img, mean, std, data_format='CHW'):
"""Normalizes a tensor image with mean and standard deviation.
Args:
img (paddle.Tensor): input data to be normalized.
mean (list|tuple): Sequence of means for each channel.
std (list|tuple): Sequence of standard deviations for each channel.
data_format (str, optional): Data format of img, should be 'HWC' or
'CHW'. Default: 'CHW'.
Returns:
Tensor: Normalized mage.
"""
if data_format == 'CHW':
mean = paddle.to_tensor(mean).reshape([-1, 1, 1])
std = paddle.to_tensor(std).reshape([-1, 1, 1])
else:
mean = paddle.to_tensor(mean)
std = paddle.to_tensor(std)
return (img - mean) / std

File diff suppressed because it is too large Load Diff
Loading…
Cancel
Save