# Copyright 2020-2021 Huawei Technologies Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
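"""Test cases for CelebADataset in mindspore.dataset."""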
import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as vision
from mindspore import log as logger
from mindspore.dataset.vision import Inter

DATA_DIR = "../data/dataset/testCelebAData/"


def test_celeba_dataset_label():
    """
    Test CelebA dataset with labels
    """
    logger.info("Test CelebA labels")
    data = ds.CelebADataset(DATA_DIR, shuffle=False, decode=True)
    expect_labels = [
        [0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 1,
         0, 0, 1],
        [0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0,
         0, 0, 1],
        [0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0,
         0, 0, 1],
        [0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 1,
         0, 0, 1]]
    count = 0
    for item in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        logger.info("----------image--------")
        logger.info(item["image"])
        logger.info("----------attr--------")
        logger.info(item["attr"])
        # compare all 40 binary CelebA attributes of each sample
        for index in range(len(expect_labels[count])):
            assert item["attr"][index] == expect_labels[count][index]
        count += 1
    assert count == 4


def test_celeba_dataset_op():
    """
    Test CelebA dataset with decode and map operations
    """
    logger.info("Test CelebA with decode")
    data = ds.CelebADataset(DATA_DIR, decode=True, num_shards=1, shard_id=0)
    crop_size = (80, 80)
    resize_size = (24, 24)
    # define map operations
    data = data.repeat(2)
    center_crop = vision.CenterCrop(crop_size)
    resize_op = vision.Resize(resize_size, Inter.LINEAR)  # Bilinear mode
    data = data.map(operations=center_crop, input_columns=["image"])
    data = data.map(operations=resize_op, input_columns=["image"])

    count = 0
    for item in data.create_dict_iterator(num_epochs=1):
        logger.info("----------image--------")
        logger.info(item["image"])
        count += 1
    # repeat(2) doubles the 4 samples in the test data
    assert count == 8


def test_celeba_dataset_ext():
    """
    Test CelebA dataset with the extensions option
    """
    logger.info("Test CelebA extension option")
    ext = [".JPEG"]
    data = ds.CelebADataset(DATA_DIR, decode=True, extensions=ext)
    expect_labels = [
        [0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1,
         0, 1, 0, 1, 0, 0, 1],
        [0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1,
         0, 1, 0, 1, 0, 0, 1]]
    count = 0
    for item in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        logger.info("----------image--------")
        logger.info(item["image"])
        logger.info("----------attr--------")
        logger.info(item["attr"])
        for index in range(len(expect_labels[count])):
            assert item["attr"][index] == expect_labels[count][index]
        count += 1
    # only the two samples stored with the ".JPEG" extension are loaded
    assert count == 2


def test_celeba_dataset_distribute():
    """
    Test CelebA dataset with distributed options
    """
    logger.info("Test CelebA with sharding")
    data = ds.CelebADataset(DATA_DIR, decode=True, num_shards=2, shard_id=0)
    count = 0
    for item in data.create_dict_iterator(num_epochs=1):
        logger.info("----------image--------")
        logger.info(item["image"])
        logger.info("----------attr--------")
        logger.info(item["attr"])
        count += 1
    # shard 0 of 2 sees half of the 4 samples
    assert count == 2


def test_celeba_get_dataset_size():
    """
    Test CelebA dataset get_dataset_size for each usage
    """
    logger.info("Test CelebA get dataset size")
    data = ds.CelebADataset(DATA_DIR, shuffle=False, decode=True)
    size = data.get_dataset_size()
    assert size == 4

    data = ds.CelebADataset(DATA_DIR, shuffle=False, decode=True, usage="train")
    size = data.get_dataset_size()
    assert size == 2

    data = ds.CelebADataset(DATA_DIR, shuffle=False, decode=True, usage="valid")
    size = data.get_dataset_size()
    assert size == 1

    data = ds.CelebADataset(DATA_DIR, shuffle=False, decode=True, usage="test")
    size = data.get_dataset_size()
    assert size == 1


def test_celeba_dataset_exception_file_path():
    """
    Test CelebA dataset with bad file path
    """
    logger.info("Test CelebA with bad file path")

    def exception_func(item):
        # user-defined map function that always fails
        raise Exception("Error occurred!")

    try:
        data = ds.CelebADataset(DATA_DIR, shuffle=False)
        data = data.map(operations=exception_func, input_columns=["image"], num_parallel_workers=1)
        for _ in data.create_dict_iterator():
            pass
        assert False
    except RuntimeError as e:
        assert "map operation: [PyFunc] failed. The corresponding data files" in str(e)

    try:
        data = ds.CelebADataset(DATA_DIR, shuffle=False)
        data = data.map(operations=vision.Decode(), input_columns=["image"], num_parallel_workers=1)
        data = data.map(operations=exception_func, input_columns=["image"], num_parallel_workers=1)
        for _ in data.create_dict_iterator():
            pass
        assert False
    except RuntimeError as e:
        assert "map operation: [PyFunc] failed. The corresponding data files" in str(e)

    try:
        data = ds.CelebADataset(DATA_DIR, shuffle=False)
        data = data.map(operations=exception_func, input_columns=["attr"], num_parallel_workers=1)
        for _ in data.create_dict_iterator():
            pass
        assert False
    except RuntimeError as e:
        assert "map operation: [PyFunc] failed. The corresponding data files" in str(e)


def test_celeba_sampler_exception():
    """
    Test CelebA with bad sampler input
    """
    logger.info("Test CelebA with bad sampler input")
    try:
        data = ds.CelebADataset(DATA_DIR, sampler="")
        for _ in data.create_dict_iterator():
            pass
        assert False
    except TypeError as e:
        assert "Unsupported sampler object of type (<class 'str'>)" in str(e)


if __name__ == '__main__':
    test_celeba_dataset_label()
    test_celeba_dataset_op()
    test_celeba_dataset_ext()
    test_celeba_dataset_distribute()
    test_celeba_get_dataset_size()
    test_celeba_dataset_exception_file_path()
    test_celeba_sampler_exception()