# File: mindspore/tests/ut/python/dataset/test_datasets_voc.py

# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import mindspore.dataset.transforms.vision.c_transforms as vision
import mindspore.dataset as ds
# Root of the VOC2012 test fixture; a relative path, so it resolves against
# the directory the tests are launched from.
DATA_DIR = "../data/dataset/testVOC2012"
# Expected ``shape[0]`` of the "image" column, indexed by row position when the
# "train" split is read with shuffle=False.
IMAGE_SHAPE = [2268, 2268, 2268, 2268, 642, 607, 561, 596, 612, 2268]
# Expected ``shape[0]`` of the "target" column (Segmentation task), indexed the
# same way as IMAGE_SHAPE.
TARGET_SHAPE = [680, 680, 680, 680, 642, 607, 561, 596, 612, 680]
def test_voc_segmentation():
    """Check row sizes and row count of the unshuffled Segmentation train split.

    The dataset is read with ``shuffle=False`` so each row can be matched by
    position: ``shape[0]`` of "image" is compared against IMAGE_SHAPE and
    ``shape[0]`` of "target" against TARGET_SHAPE. Exactly 10 rows are expected.
    """
    data1 = ds.VOCDataset(DATA_DIR, task="Segmentation", mode="train", decode=True, shuffle=False)
    num = 0
    for item in data1.create_dict_iterator():
        assert item["image"].shape[0] == IMAGE_SHAPE[num]
        assert item["target"].shape[0] == TARGET_SHAPE[num]
        num += 1
    assert num == 10
def test_voc_detection():
    """Check image sizes, row count and per-class box counts for Detection.

    The dataset is read with ``shuffle=False`` so ``shape[0]`` of each "image"
    can be compared by position against IMAGE_SHAPE. The first element of every
    entry in "annotation" is used as an index into a six-slot tally. Exactly
    9 rows and the tally ``[3, 2, 1, 2, 4, 3]`` are expected.
    """
    data1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
    num = 0
    count = [0, 0, 0, 0, 0, 0]
    for item in data1.create_dict_iterator():
        assert item["image"].shape[0] == IMAGE_SHAPE[num]
        for bbox in item["annotation"]:
            # bbox[0] is the class index of this annotation entry.
            count[bbox[0]] += 1
        num += 1
    assert num == 9
    assert count == [3, 2, 1, 2, 4, 3]
def test_voc_class_index():
    """Check Detection with an explicit ``class_indexing`` mapping.

    A three-class mapping (car, cat, train) is passed to the dataset. The test
    verifies that ``get_class_indexing()`` returns that mapping both before and
    after a ``shuffle(4)`` op is appended. It then checks that every annotation
    entry carries one of the mapped indices, and that 6 rows are produced with
    the per-index tally ``[3, 2, 0, 0, 0, 3]``.
    """
    class_index = {'car': 0, 'cat': 1, 'train': 5}
    data1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", class_indexing=class_index, decode=True)
    class_index1 = data1.get_class_indexing()
    assert class_index1 == {'car': 0, 'cat': 1, 'train': 5}
    data1 = data1.shuffle(4)
    class_index2 = data1.get_class_indexing()
    assert class_index2 == {'car': 0, 'cat': 1, 'train': 5}
    num = 0
    count = [0, 0, 0, 0, 0, 0]
    for item in data1.create_dict_iterator():
        for bbox in item["annotation"]:
            # Only the indices supplied through class_indexing may appear.
            assert bbox[0] in (0, 1, 5)
            count[bbox[0]] += 1
        num += 1
    assert num == 6
    assert count == [3, 2, 0, 0, 0, 3]
def test_voc_get_class_indexing():
    """Check the class mapping reported when no ``class_indexing`` is given.

    ``get_class_indexing()`` is expected to return the six-class mapping
    car/cat/chair/dog/person/train -> 0..5, both before and after a
    ``shuffle(4)`` op is appended. The test then checks that every annotation
    entry carries an index in 0..5, and that 9 rows are produced with the
    per-index tally ``[3, 2, 1, 2, 4, 3]``.
    """
    data1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True)
    class_index1 = data1.get_class_indexing()
    assert class_index1 == {'car': 0, 'cat': 1, 'chair': 2, 'dog': 3, 'person': 4, 'train': 5}
    data1 = data1.shuffle(4)
    class_index2 = data1.get_class_indexing()
    assert class_index2 == {'car': 0, 'cat': 1, 'chair': 2, 'dog': 3, 'person': 4, 'train': 5}
    num = 0
    count = [0, 0, 0, 0, 0, 0]
    for item in data1.create_dict_iterator():
        for bbox in item["annotation"]:
            assert bbox[0] in (0, 1, 2, 3, 4, 5)
            count[bbox[0]] += 1
        num += 1
    assert num == 9
    assert count == [3, 2, 1, 2, 4, 3]
def test_case_0():
    """Check the batch count of a Segmentation map/repeat/batch pipeline.

    Both "image" and "target" are resized to 224x224. The dataset is then
    repeated 4 times and batched in pairs with ``drop_remainder=True``.
    Exactly 20 batches are expected.
    """
    data1 = ds.VOCDataset(DATA_DIR, task="Segmentation", mode="train", decode=True)

    resize_op = vision.Resize((224, 224))

    data1 = data1.map(input_columns=["image"], operations=resize_op)
    data1 = data1.map(input_columns=["target"], operations=resize_op)
    repeat_num = 4
    data1 = data1.repeat(repeat_num)
    batch_size = 2
    data1 = data1.batch(batch_size, drop_remainder=True)

    num = 0
    # Only the number of batches matters here, not their contents.
    for _ in data1.create_dict_iterator():
        num += 1
    assert num == 20
def test_case_1():
    """Check the batch count of a Detection map/repeat/batch pipeline.

    "image" is resized to 224x224. The dataset is then repeated 4 times and
    batched in pairs with ``drop_remainder=True`` and an empty ``pad_info``.
    Exactly 18 batches are expected.
    """
    data1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True)

    resize_op = vision.Resize((224, 224))

    data1 = data1.map(input_columns=["image"], operations=resize_op)
    repeat_num = 4
    data1 = data1.repeat(repeat_num)
    batch_size = 2
    data1 = data1.batch(batch_size, drop_remainder=True, pad_info={})

    num = 0
    # Only the number of batches matters here, not their contents.
    for _ in data1.create_dict_iterator():
        num += 1
    assert num == 18
def test_voc_exception():
    """Check that invalid VOCDataset configurations raise the expected errors.

    Expected to raise ValueError:
      * an unknown ``task``;
      * ``class_indexing`` supplied together with the Segmentation task;
      * ``mode="notexist"``.

    Expected to raise RuntimeError:
      * ``mode="xmlnotexist"``;
      * ``mode="invalidxml"``;
      * ``mode="xmlnoobject"``.
    """
    _assert_voc_raises(ValueError, task="InvalidTask", mode="train")
    _assert_voc_raises(ValueError, task="Segmentation", mode="train", class_indexing={"cat": 0})
    _assert_voc_raises(ValueError, task="Detection", mode="notexist")

    _assert_voc_raises(RuntimeError, task="Detection", mode="xmlnotexist")
    _assert_voc_raises(RuntimeError, task="Detection", mode="invalidxml")
    _assert_voc_raises(RuntimeError, task="Detection", mode="xmlnoobject")


def _assert_voc_raises(expected_error, task, mode, **extra_kwargs):
    """Build a VOCDataset, iterate it fully, and require *expected_error*.

    Construction and iteration are both inside the ``try``, so the error is
    accepted from whichever step raises it. Any other exception type
    propagates to the caller unchanged.

    Args:
        expected_error: Exception class that must be raised.
        task: Value forwarded as the ``task`` argument of VOCDataset.
        mode: Value forwarded as the ``mode`` argument of VOCDataset.
        **extra_kwargs: Additional keyword arguments forwarded to VOCDataset.

    Raises:
        AssertionError: If the dataset is built and iterated without raising.
    """
    try:
        dataset = ds.VOCDataset(DATA_DIR, task=task, mode=mode, decode=True, **extra_kwargs)
        for _ in dataset.create_dict_iterator():
            pass
    except expected_error:
        return
    # Raised explicitly (rather than `assert False`) so the check is not
    # stripped when Python runs with -O.
    raise AssertionError(
        "VOCDataset(task={!r}, mode={!r}) did not raise {}".format(task, mode, expected_error.__name__))
# Allow running this module directly as a script: execute every test in order.
if __name__ == '__main__':
    test_voc_segmentation()
    test_voc_detection()
    test_voc_class_index()
    test_voc_get_class_indexing()
    test_case_0()
    test_case_1()
    test_voc_exception()