parent
91e9a25e0b
commit
338dd13542
@ -0,0 +1,42 @@
|
|||||||
|
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import paddle.v2.dataset.voc_seg
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
|
||||||
|
class TestVOC(unittest.TestCase):
    """Checks the sample counts produced by the ``voc_seg`` dataset readers."""

    def check_reader(self, reader):
        """Consume ``reader`` and return how many samples it yielded.

        Args:
            reader: a zero-argument callable returning an iterable of
                samples; each sample is indexed as ``sample[0]`` (image)
                and ``sample[1]`` (label).

        Returns:
            The number of samples produced by ``reader()``.

        Every sample's image and label must report the same ``size``.
        """
        # Named ``count`` rather than ``sum`` so the builtin is not shadowed.
        count = 0
        for sample in reader():
            self.assertEqual(sample[0].size, sample[1].size)
            count += 1
        return count

    def test_train(self):
        """The train reader is expected to yield 2913 samples."""
        count = self.check_reader(paddle.v2.dataset.voc_seg.train())
        self.assertEqual(count, 2913)

    def test_test(self):
        """The test reader is expected to yield 1464 samples."""
        count = self.check_reader(paddle.v2.dataset.voc_seg.test())
        self.assertEqual(count, 1464)

    def test_val(self):
        """The val reader is expected to yield 1449 samples."""
        count = self.check_reader(paddle.v2.dataset.voc_seg.val())
        self.assertEqual(count, 1449)
|
||||||
|
|
||||||
|
|
||||||
|
# Run every test case in this file when it is executed as a script.
if __name__ == "__main__":
    unittest.main()
|
@ -0,0 +1,74 @@
|
|||||||
|
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""
|
||||||
|
Image dataset for segmentation.
|
||||||
|
The PASCAL VOC 2012 dataset contains images from 2008-2011 for which additional segmentations have been prepared. As in previous years, the assignment to training/test sets has been maintained. The total number of images with segmentation has been increased from 7,062 to 9,993.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import tarfile
|
||||||
|
import numpy as np
|
||||||
|
from common import download
|
||||||
|
from paddle.v2.image import *
|
||||||
|
|
||||||
|
__all__ = ['train', 'test', 'val']
|
||||||
|
|
||||||
|
VOC_URL = 'http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar'
|
||||||
|
VOC_MD5 = '6cd6e144f989b92b3379bac3b3de84fd'
|
||||||
|
SET_FILE = 'VOCdevkit/VOC2012/ImageSets/Segmentation/{}.txt'
|
||||||
|
DATA_FILE = 'VOCdevkit/VOC2012/JPEGImages/{}.jpg'
|
||||||
|
LABEL_FILE = 'VOCdevkit/VOC2012/SegmentationClass/{}.png'
|
||||||
|
|
||||||
|
|
||||||
|
def reader_creator(filename, sub_name):
    """
    Create a reader over one image set stored inside a tar archive.

    :param filename: path of the tar archive; it is opened once, when the
        reader is created, and stays open so the reader can be re-run.
    :param sub_name: name of the image set; it is substituted into
        ``SET_FILE`` to locate the text file listing the sample names.
    :return: a zero-argument generator function yielding ``(data, label)``
        pairs, both decoded with ``load_image_bytes``.
    :raises KeyError: while iterating, if the set file or a listed
        image/label file is not a member of the archive.
    """
    tarobject = tarfile.open(filename)
    # Index members by name once so each lookup in the loop is O(1).
    name2mem = dict((ele.name, ele) for ele in tarobject.getmembers())

    def reader():
        set_file = SET_FILE.format(sub_name)
        sets = tarobject.extractfile(name2mem[set_file])
        try:
            for line in sets:
                name = line.strip()
                if not name:
                    # A blank line would otherwise format into a path that
                    # is not in the archive and raise KeyError.
                    continue
                if not isinstance(name, str):
                    # Tar members are read as bytes on Python 3; formatting
                    # bytes would embed "b'...'" in the member path.
                    name = name.decode('utf-8')
                data_file = DATA_FILE.format(name)
                label_file = LABEL_FILE.format(name)
                data = tarobject.extractfile(name2mem[data_file]).read()
                label = tarobject.extractfile(name2mem[label_file]).read()
                yield load_image_bytes(data), load_image_bytes(label)
        finally:
            # Release the set-file handle even if iteration stops early.
            sets.close()

    return reader
|
||||||
|
|
||||||
|
|
||||||
|
def train():
    """
    Build the training reader.

    Fetches the VOC archive through ``download`` and returns a reader over
    its 'trainval' image set, which contains 2913 images.
    """
    archive_path = download(VOC_URL, 'voc_seg', VOC_MD5)
    return reader_creator(archive_path, 'trainval')
|
||||||
|
|
||||||
|
|
||||||
|
def test():
    """
    Build the test reader.

    Fetches the VOC archive through ``download`` and returns a reader over
    its 'train' image set, which contains 1464 images.
    """
    archive_path = download(VOC_URL, 'voc_seg', VOC_MD5)
    return reader_creator(archive_path, 'train')
|
||||||
|
|
||||||
|
|
||||||
|
def val():
    """
    Build the validation reader.

    Fetches the VOC archive through ``download`` and returns a reader over
    its 'val' image set, which contains 1449 images.
    """
    archive_path = download(VOC_URL, 'voc_seg', VOC_MD5)
    return reader_creator(archive_path, 'val')
|
Loading…
Reference in new issue