Use PIL to read image in palette mode

cblas_new
wanghaoshuang 8 years ago
parent c4f301ded7
commit 1ba879bead

@ -21,7 +21,7 @@ class TestVOC(unittest.TestCase):
sum = 0 sum = 0
label = 0 label = 0
for l in reader(): for l in reader():
self.assertEqual(l[0].size, l[1].size) self.assertEqual(l[0].size, 3 * l[1].size)
sum += 1 sum += 1
return sum return sum

@ -20,14 +20,16 @@ with segmentation has been increased from 7,062 to 9,993.
""" """
import tarfile import tarfile
import io
import numpy as np import numpy as np
from common import download from common import download
from paddle.v2.image import * from paddle.v2.image import *
from PIL import Image
__all__ = ['train', 'test', 'val'] __all__ = ['train', 'test', 'val']
VOC_URL = 'http://host.robots.ox.ac.uk/pascal/VOC/voc2012/\ VOC_URL = 'http://host.robots.ox.ac.uk/pascal/VOC/voc2012/\
VOCtrainval_11-May-2012.tar' VOCtrainval_11-May-2012.tar'
VOC_MD5 = '6cd6e144f989b92b3379bac3b3de84fd' VOC_MD5 = '6cd6e144f989b92b3379bac3b3de84fd'
SET_FILE = 'VOCdevkit/VOC2012/ImageSets/Segmentation/{}.txt' SET_FILE = 'VOCdevkit/VOC2012/ImageSets/Segmentation/{}.txt'
@ -51,8 +53,10 @@ def reader_creator(filename, sub_name):
label_file = LABEL_FILE.format(line) label_file = LABEL_FILE.format(line)
data = tarobject.extractfile(name2mem[data_file]).read() data = tarobject.extractfile(name2mem[data_file]).read()
label = tarobject.extractfile(name2mem[label_file]).read() label = tarobject.extractfile(name2mem[label_file]).read()
data = load_image_bytes(data) data = Image.open(io.BytesIO(data))
label = load_image_bytes(label) label = Image.open(io.BytesIO(label))
data = np.array(data)
label = np.array(label)
yield data, label yield data, label
return reader return reader

Loading…
Cancel
Save