You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
163 lines
4.5 KiB
163 lines
4.5 KiB
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
from PIL import Image
|
|
from paddle.utils import try_import
|
|
|
|
# Public API of this module (what ``from ... import *`` exposes).
__all__ = [
    'set_image_backend',
    'get_image_backend',
    'image_load',
]

# Module-wide image loading backend. It is changed by ``set_image_backend``
# and read by ``get_image_backend`` and by ``image_load`` when no explicit
# backend is passed. Starts out as Pillow.
_image_backend = 'pil'
|
|
|
|
|
|
def set_image_backend(backend):
    """
    Select the backend used to load images in ``paddle.vision.datasets.ImageFolder``
    and ``paddle.vision.datasets.DatasetFolder``.

    Two backends are supported: Pillow (``'pil'``) and OpenCV (``'cv2'``).
    Until this function is called, ``'pil'`` is used.

    Args:
        backend (str): Name of the image load backend, must be one of
            {'pil', 'cv2'}.

    Raises:
        ValueError: If ``backend`` is not a supported name. The currently
            selected backend is left unchanged in that case.

    Examples:

        .. code-block:: python

            import os
            import shutil
            import tempfile
            import numpy as np
            from PIL import Image

            from paddle.vision import DatasetFolder
            from paddle.vision import set_image_backend

            set_image_backend('pil')

            def make_fake_dir():
                data_dir = tempfile.mkdtemp()

                for i in range(2):
                    sub_dir = os.path.join(data_dir, 'class_' + str(i))
                    if not os.path.exists(sub_dir):
                        os.makedirs(sub_dir)
                    for j in range(2):
                        fake_img = Image.fromarray((np.random.random((32, 32, 3)) * 255).astype('uint8'))
                        fake_img.save(os.path.join(sub_dir, str(j) + '.png'))
                return data_dir

            temp_dir = make_fake_dir()

            pil_data_folder = DatasetFolder(temp_dir)

            for items in pil_data_folder:
                break

            # should get PIL.Image.Image
            print(type(items[0]))

            # use opencv as backend
            # set_image_backend('cv2')

            # cv2_data_folder = DatasetFolder(temp_dir)

            # for items in cv2_data_folder:
            #     break

            # should get numpy.ndarray
            # print(type(items[0]))

            shutil.rmtree(temp_dir)
    """
    global _image_backend

    supported_backends = ('pil', 'cv2')
    if backend in supported_backends:
        # Accepted: record it as the module-wide default and stop here.
        _image_backend = backend
        return

    raise ValueError(
        "Expected backend are one of ['pil', 'cv2'], but got {}".format(
            backend))
|
|
|
|
|
|
def get_image_backend():
    """
    Return the name of the backend currently used to load images.

    This is the value most recently accepted by ``set_image_backend``
    (``'pil'`` unless it has been changed).

    Returns:
        str: Name of the image load backend, normally ``'pil'`` or ``'cv2'``.

    Examples:

        .. code-block:: python

            from paddle.vision import get_image_backend

            backend = get_image_backend()
            print(backend)
    """
    current = _image_backend
    return current
|
|
|
|
|
|
def image_load(path, backend=None):
    """Load an image.

    Args:
        path (str): Path of the image.
        backend (str, optional): The image decoding backend type. Options are
            `cv2`, `pil`, `None`. If backend is None, the global backend
            (module variable ``_image_backend``) specified by
            ``paddle.vision.set_image_backend`` will be used. Default: None.

    Returns:
        PIL.Image or np.array: Loaded image. With ``'pil'`` this is the object
        returned by ``PIL.Image.open``; with ``'cv2'`` it is the array returned
        by ``cv2.imread`` (per the OpenCV documentation, colour images come
        back in BGR channel order, unlike Pillow).

    Raises:
        ValueError: If the resolved backend is not ``'pil'`` or ``'cv2'``.
        OSError: With the ``'cv2'`` backend, if OpenCV could not read the file
            (``cv2.imread`` signals this by returning ``None``). With the
            ``'pil'`` backend, errors from ``PIL.Image.open`` propagate as-is.

    Examples:

        .. code-block:: python

            import numpy as np
            from PIL import Image
            from paddle.vision import image_load, set_image_backend

            fake_img = Image.fromarray((np.random.random((32, 32, 3)) * 255).astype('uint8'))

            path = 'temp.png'
            fake_img.save(path)

            set_image_backend('pil')

            pil_img = image_load(path).convert('RGB')

            # should be PIL.Image.Image
            print(type(pil_img))

            # use opencv as backend
            # set_image_backend('cv2')

            # np_img = image_load(path)
            # # should get numpy.ndarray
            # print(type(np_img))
    """
    if backend is None:
        backend = _image_backend
    if backend not in ['pil', 'cv2']:
        raise ValueError(
            "Expected backend are one of ['pil', 'cv2'], but got {}"
            .format(backend))

    if backend == 'pil':
        return Image.open(path)

    # Imported here so OpenCV is only needed when the 'cv2' backend is used.
    cv2 = try_import('cv2')
    img = cv2.imread(path)
    if img is None:
        # cv2.imread reports unreadable/missing files by returning None rather
        # than raising; fail loudly here instead of handing None to callers.
        raise OSError("cv2 failed to read image from '{}'".format(path))
    return img
|