|
|
|
# Copyright 2019 Huawei Technologies Co., Ltd
|
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
# ==============================================================================
|
|
|
|
import mindspore.dataset.transforms.vision.c_transforms as vision
|
|
|
|
|
|
|
|
import mindspore.dataset as ds
|
|
|
|
|
|
|
|
# Dataset root passed as the first argument to every ds.VOCDataset call in
# this file. The path is relative, so it presumably depends on the directory
# the tests are launched from -- TODO confirm.
DATA_DIR = "../data/dataset/testVOC2012"
|
|
|
|
# Expected shape[0] of the "image" column for each sample, indexed by the
# sample's position when the dataset is iterated with shuffle=False
# (checked by test_voc_segmentation and test_voc_detection).
IMAGE_SHAPE = [2268, 2268, 2268, 2268, 642, 607, 561, 596, 612, 2268]
|
|
|
|
# Expected shape[0] of the "target" column for each sample, indexed by the
# sample's position when the dataset is iterated with shuffle=False
# (checked by test_voc_segmentation).
TARGET_SHAPE = [680, 680, 680, 680, 642, 607, 561, 596, 612, 680]
|
|
|
|
|
|
|
|
|
|
|
|
def test_voc_segmentation():
    """Check the leading dimension of every Segmentation sample.

    Reads the "train" split with decoding on and shuffling off, then compares
    shape[0] of the "image" and "target" columns against IMAGE_SHAPE and
    TARGET_SHAPE at the sample's position. Exactly 10 samples are expected.
    """
    dataset = ds.VOCDataset(DATA_DIR, task="Segmentation", mode="train", decode=True, shuffle=False)
    seen = 0
    for position, sample in enumerate(dataset.create_dict_iterator()):
        assert sample["image"].shape[0] == IMAGE_SHAPE[position]
        assert sample["target"].shape[0] == TARGET_SHAPE[position]
        seen = position + 1
    assert seen == 10
|
|
|
|
|
|
|
|
|
|
|
|
def test_voc_detection():
    """Check image sizes and per-class annotation counts for Detection.

    Reads the "train" split with decoding on and shuffling off. For each
    sample, shape[0] of the "image" column must match IMAGE_SHAPE at the
    sample's position, and element 0 of every "annotation" row is tallied.
    Exactly 9 samples and the tally [3, 2, 1, 2, 4, 3] are expected.
    """
    dataset = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
    per_class = [0] * 6
    seen = 0
    for position, sample in enumerate(dataset.create_dict_iterator()):
        assert sample["image"].shape[0] == IMAGE_SHAPE[position]
        for box in sample["annotation"]:
            # Element 0 of each annotation row indexes the tally (used here as a class id).
            per_class[box[0]] += 1
        seen = position + 1
    assert seen == 9
    assert per_class == [3, 2, 1, 2, 4, 3]
|
|
|
|
|
|
|
|
|
|
|
|
def test_voc_class_index():
    """Check Detection output when a custom class_indexing is supplied.

    The mapping {'car': 0, 'cat': 1, 'train': 5} is passed to VOCDataset and
    must be returned unchanged by get_class_indexing(), both before and after
    a shuffle(4) step. Every annotation row must then carry one of those
    three ids in element 0; 6 samples and the tally [3, 2, 0, 0, 0, 3] are
    expected.
    """
    class_index = {'car': 0, 'cat': 1, 'train': 5}
    dataset = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", class_indexing=class_index, decode=True)
    assert dataset.get_class_indexing() == {'car': 0, 'cat': 1, 'train': 5}

    # The mapping must still be reported after another op is stacked on top.
    dataset = dataset.shuffle(4)
    assert dataset.get_class_indexing() == {'car': 0, 'cat': 1, 'train': 5}

    allowed_ids = (0, 1, 5)
    per_class = [0] * 6
    rows = 0
    for sample in dataset.create_dict_iterator():
        for box in sample["annotation"]:
            assert box[0] in allowed_ids
            per_class[box[0]] += 1
        rows += 1
    assert rows == 6
    assert per_class == [3, 2, 0, 0, 0, 3]
|
|
|
|
|
|
|
|
|
|
|
|
def test_voc_get_class_indexing():
    """Check the class indexing reported when none is supplied.

    With no class_indexing argument, get_class_indexing() must return the
    six-entry mapping asserted below, both before and after a shuffle(4)
    step. Every annotation row must carry an id from 0 to 5 in element 0;
    9 samples and the tally [3, 2, 1, 2, 4, 3] are expected.
    """
    dataset = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True)
    assert dataset.get_class_indexing() == {'car': 0, 'cat': 1, 'chair': 2, 'dog': 3, 'person': 4, 'train': 5}

    # The mapping must still be reported after another op is stacked on top.
    dataset = dataset.shuffle(4)
    assert dataset.get_class_indexing() == {'car': 0, 'cat': 1, 'chair': 2, 'dog': 3, 'person': 4, 'train': 5}

    allowed_ids = (0, 1, 2, 3, 4, 5)
    per_class = [0] * 6
    rows = 0
    for sample in dataset.create_dict_iterator():
        for box in sample["annotation"]:
            assert box[0] in allowed_ids
            per_class[box[0]] += 1
        rows += 1
    assert rows == 9
    assert per_class == [3, 2, 1, 2, 4, 3]
|
|
|
|
|
|
|
|
|
|
|
|
def test_case_0():
    """Run a Segmentation pipeline through map, repeat and batch.

    Both the "image" and "target" columns are resized to 224x224, the
    dataset is repeated 4 times and batched in pairs with the remainder
    dropped. 20 batches are expected, which matches the 10 samples per epoch
    asserted in test_voc_segmentation (10 * 4 / 2).
    """
    pipeline = ds.VOCDataset(DATA_DIR, task="Segmentation", mode="train", decode=True)

    resize_op = vision.Resize((224, 224))

    # Same resize applied to the image first, then to its target.
    for column in ("image", "target"):
        pipeline = pipeline.map(input_columns=[column], operations=resize_op)
    pipeline = pipeline.repeat(4)
    pipeline = pipeline.batch(2, drop_remainder=True)

    batches = sum(1 for _ in pipeline.create_dict_iterator())
    assert batches == 20
|
|
|
|
|
|
|
|
|
|
|
|
def test_case_1():
    """Run a Detection pipeline through map, repeat and padded batch.

    The "image" column is resized to 224x224, the dataset is repeated 4
    times and batched in pairs with the remainder dropped and an empty
    pad_info. 18 batches are expected, which matches the 9 samples per epoch
    asserted in test_voc_detection (9 * 4 / 2).
    """
    pipeline = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True)

    resize_op = vision.Resize((224, 224))

    pipeline = pipeline.map(input_columns=["image"], operations=resize_op)
    pipeline = pipeline.repeat(4)
    pipeline = pipeline.batch(2, drop_remainder=True, pad_info={})

    batches = sum(1 for _ in pipeline.create_dict_iterator())
    assert batches == 18
|
|
|
|
|
|
|
|
|
|
|
|
def _assert_voc_raises(expected_error, **dataset_kwargs):
    """Fail unless building or iterating a VOCDataset raises ``expected_error``.

    Args:
        expected_error: Exception class that must be raised.
        **dataset_kwargs: Keyword arguments forwarded to ``ds.VOCDataset``
            together with ``DATA_DIR`` and ``decode=True``.

    Raises:
        AssertionError: If the dataset is built and fully iterated without
            ``expected_error`` being raised. Any other exception type
            propagates unchanged.
    """
    try:
        dataset = ds.VOCDataset(DATA_DIR, decode=True, **dataset_kwargs)
        for _ in dataset.create_dict_iterator():
            pass
    except expected_error:
        return
    # An explicit raise, unlike `assert False`, is not stripped by `python -O`
    # and reports which configuration failed to raise.
    raise AssertionError(
        "VOCDataset(%r) did not raise %s" % (dataset_kwargs, expected_error.__name__)
    )


def test_voc_exception():
    """Check that invalid VOCDataset configurations are rejected.

    Each case builds a dataset and drains its dict iterator; the stated
    exception type must be raised somewhere along the way.
    """
    # Task name that is neither "Segmentation" nor "Detection".
    _assert_voc_raises(ValueError, task="InvalidTask", mode="train")

    # class_indexing supplied together with the Segmentation task.
    _assert_voc_raises(ValueError, task="Segmentation", mode="train", class_indexing={"cat": 0})

    # Mode "notexist" -- presumably no matching split list exists; TODO confirm.
    _assert_voc_raises(ValueError, task="Detection", mode="notexist")

    # NOTE(review): the next three modes look like fixtures whose annotation
    # XML is missing, malformed, or has no object -- confirm against the data.
    _assert_voc_raises(RuntimeError, task="Detection", mode="xmlnotexist")
    _assert_voc_raises(RuntimeError, task="Detection", mode="invalidxml")
    _assert_voc_raises(RuntimeError, task="Detection", mode="xmlnoobject")
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: run every test in this file, in definition order.
    for run_test in (
            test_voc_segmentation,
            test_voc_detection,
            test_voc_class_index,
            test_voc_get_class_indexing,
            test_case_0,
            test_case_1,
            test_voc_exception,
    ):
        run_test()
|