You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
mindspore/tests/ut/python/dataset/test_datasets_voc.py

194 lines
6.2 KiB

# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as vision
# Location of the VOC2012-style test fixture. The path is relative, so it is
# resolved against the current working directory the tests are launched from.
DATA_DIR = "../data/dataset/testVOC2012"

# Expected shape[0] of each decoded sample, in un-shuffled iteration order:
# IMAGE_SHAPE for the "image" column, TARGET_SHAPE for the "target" column.
IMAGE_SHAPE = [2268] * 4 + [642, 607, 561, 596, 612, 2268]
TARGET_SHAPE = [680] * 4 + [642, 607, 561, 596, 612, 680]
5 years ago
def test_voc_segmentation():
    """Iterate the un-shuffled Segmentation/train split and check sample sizes.

    For every sample, shape[0] of "image" must match IMAGE_SHAPE and shape[0]
    of "target" must match TARGET_SHAPE at the same position. Exactly 10
    samples must be produced.
    """
    dataset = ds.VOCDataset(DATA_DIR, task="Segmentation", usage="train", shuffle=False, decode=True)
    seen = 0
    for row in dataset.create_dict_iterator(num_epochs=1):
        assert row["image"].shape[0] == IMAGE_SHAPE[seen]
        assert row["target"].shape[0] == TARGET_SHAPE[seen]
        seen += 1
    assert seen == 10
def test_voc_detection():
    """Iterate the un-shuffled Detection/train split and tally class ids.

    Each sample's "image" shape[0] must match IMAGE_SHAPE at its position.
    Every entry of the "label" column contributes its first element as a
    class id to a 6-slot histogram. Expects 9 samples and the histogram
    [3, 2, 1, 2, 4, 3].
    """
    dataset = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
    label_hist = [0] * 6
    sample_count = 0
    for sample in dataset.create_dict_iterator(num_epochs=1):
        assert sample["image"].shape[0] == IMAGE_SHAPE[sample_count]
        for entry in sample["label"]:
            label_hist[entry[0]] += 1
        sample_count += 1
    assert sample_count == 9
    assert label_hist == [3, 2, 1, 2, 4, 3]
def test_voc_class_index():
    """A user-supplied class_indexing is honored by VOCDataset.

    Checks that:
    - get_class_indexing() returns the requested mapping, both before and
      after shuffle(4);
    - every class id in the "label" column is one of the requested ids;
    - 6 samples are produced;
    - the per-id tally over ids 0-5 is [3, 2, 0, 0, 0, 3].
    """
    class_index = {'car': 0, 'cat': 1, 'train': 5}
    data1 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", class_indexing=class_index, decode=True)
    assert data1.get_class_indexing() == {'car': 0, 'cat': 1, 'train': 5}

    # The mapping must still be reported after shuffling the pipeline.
    data1 = data1.shuffle(4)
    assert data1.get_class_indexing() == {'car': 0, 'cat': 1, 'train': 5}

    allowed_ids = (0, 1, 5)
    count = [0] * 6
    num = 0
    for item in data1.create_dict_iterator(num_epochs=1):
        for label in item["label"]:
            class_id = label[0]
            # Check membership BEFORE indexing `count`. An unexpected id >= 6
            # then fails this assertion instead of surfacing as an IndexError.
            assert class_id in allowed_ids, "unexpected class id {}".format(class_id)
            count[class_id] += 1
        num += 1
    assert num == 6
    assert count == [3, 2, 0, 0, 0, 3]
def test_voc_get_class_indexing():
    """Without class_indexing, VOCDataset reports the fixture's default mapping.

    Checks that:
    - get_class_indexing() returns the six class names mapped to ids 0-5,
      both before and after shuffle(4);
    - every class id in the "label" column falls in that id set;
    - 9 samples are produced;
    - the per-id tally is [3, 2, 1, 2, 4, 3].
    """
    expected_mapping = {'car': 0, 'cat': 1, 'chair': 2, 'dog': 3, 'person': 4, 'train': 5}
    data1 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", decode=True)
    assert data1.get_class_indexing() == expected_mapping

    # The mapping must still be reported after shuffling the pipeline.
    data1 = data1.shuffle(4)
    assert data1.get_class_indexing() == expected_mapping

    allowed_ids = (0, 1, 2, 3, 4, 5)
    count = [0] * 6
    num = 0
    for item in data1.create_dict_iterator(num_epochs=1):
        for label in item["label"]:
            class_id = label[0]
            # Check membership BEFORE indexing `count`. An out-of-range id
            # then fails this assertion instead of surfacing as an IndexError.
            assert class_id in allowed_ids, "unexpected class id {}".format(class_id)
            count[class_id] += 1
        num += 1
    assert num == 9
    assert count == [3, 2, 1, 2, 4, 3]
def test_case_0():
    """Resize "image" and "target" to 224x224, then repeat and batch.

    The pipeline repeats 4 times and batches by 2 with drop_remainder=True.
    Expects 20 batches.
    """
    pipeline = ds.VOCDataset(DATA_DIR, task="Segmentation", usage="train", decode=True)
    resize = vision.Resize((224, 224))
    # The same Resize op is mapped separately onto each column.
    for column in ("image", "target"):
        pipeline = pipeline.map(input_columns=[column], operations=resize)
    pipeline = pipeline.repeat(4).batch(2, drop_remainder=True)
    batch_count = sum(1 for _ in pipeline.create_dict_iterator(num_epochs=1))
    assert batch_count == 20
def test_case_1():
    """Resize Detection images to 224x224, then repeat and batch.

    The pipeline repeats 4 times and batches by 2 with drop_remainder=True
    and an empty pad_info. Expects 18 batches.
    """
    pipeline = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", decode=True)
    pipeline = pipeline.map(input_columns=["image"], operations=vision.Resize((224, 224)))
    # NOTE(review): pad_info={} presumably enables padding of the
    # variable-length annotation columns so they can be batched -- confirm
    # against the Dataset.batch documentation.
    pipeline = pipeline.repeat(4).batch(2, drop_remainder=True, pad_info={})
    batch_count = sum(1 for _ in pipeline.create_dict_iterator(num_epochs=1))
    assert batch_count == 18
def test_case_2():
    """Split the Segmentation/train split 50/50 without randomization.

    Each half must yield exactly 5 rows. The first half is fully iterated
    and checked before the second.
    """
    full = ds.VOCDataset(DATA_DIR, task="Segmentation", usage="train", decode=True)
    first_half, second_half = full.split(sizes=[0.5, 0.5], randomize=False)
    for part in (first_half, second_half):
        rows = 0
        for _ in part.create_dict_iterator(num_epochs=1):
            rows += 1
        assert rows == 5
def _assert_voc_raises(expected_error, **voc_kwargs):
    """Build a VOCDataset over DATA_DIR, iterate it, and require `expected_error`.

    Both construction and iteration happen inside the `try`, so the error may
    come from either step. Any other exception type propagates unchanged.
    If nothing is raised, an explicit AssertionError is raised. A bare
    `assert False` would be stripped under `python -O`, and the check would
    then pass vacuously.
    """
    try:
        data = ds.VOCDataset(DATA_DIR, decode=True, **voc_kwargs)
        for _ in data.create_dict_iterator(num_epochs=1):
            pass
    except expected_error:
        return
    raise AssertionError(
        "expected {} for VOCDataset({!r})".format(expected_error.__name__, voc_kwargs))


def test_voc_exception():
    """Invalid VOCDataset configurations must raise the expected error type.

    ValueError is expected for:
    - an unknown task;
    - class_indexing combined with task="Segmentation";
    - usage="notexist".
    RuntimeError is expected for the "xmlnotexist", "invalidxml" and
    "xmlnoobject" usages.
    """
    _assert_voc_raises(ValueError, task="InvalidTask", usage="train")
    _assert_voc_raises(ValueError, task="Segmentation", usage="train", class_indexing={"cat": 0})
    _assert_voc_raises(ValueError, task="Detection", usage="notexist")
    # NOTE(review): these usage names presumably select fixture splits whose
    # XML annotation is missing, malformed, or has no <object> -- confirm
    # against the testVOC2012 fixture.
    _assert_voc_raises(RuntimeError, task="Detection", usage="xmlnotexist")
    _assert_voc_raises(RuntimeError, task="Detection", usage="invalidxml")
    _assert_voc_raises(RuntimeError, task="Detection", usage="xmlnoobject")
if __name__ == '__main__':
    # When executed as a script, run every test in declaration order.
    for _test_fn in (
            test_voc_segmentation,
            test_voc_detection,
            test_voc_class_index,
            test_voc_get_class_indexing,
            test_case_0,
            test_case_1,
            test_case_2,
            test_voc_exception,
    ):
        _test_fn()