You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
mindspore/tests/ut/python/dataset/test_tfreader_op.py

315 lines
12 KiB

# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Test TFRecordDataset Ops
"""
import numpy as np
import pytest
import mindspore.common.dtype as mstype
import mindspore.dataset as ds
from mindspore import log as logger
from util import save_and_check_dict
# Single TFRecord data file read by most tests in this module.
FILES = ["../data/dataset/testTFTestAllTypes/test.data"]
# Directory containing the file above; note that it already ends with a slash.
DATASET_ROOT = "../data/dataset/testTFTestAllTypes/"
# Schema JSON passed together with FILES.
SCHEMA_FILE = "../data/dataset/testTFTestAllTypes/datasetSchema.json"
# Four TFRecord data files that test_tfrecord_multi_files reads as one dataset.
DATA_FILES2 = ["../data/dataset/test_tf_file_3_images2/train-0000-of-0001.data",
"../data/dataset/test_tf_file_3_images2/train-0000-of-0002.data",
"../data/dataset/test_tf_file_3_images2/train-0000-of-0003.data",
"../data/dataset/test_tf_file_3_images2/train-0000-of-0004.data"]
# Schema JSON passed together with DATA_FILES2.
SCHEMA_FILE2 = "../data/dataset/test_tf_file_3_images2/datasetSchema.json"
# Forwarded as generate_golden= to save_and_check_dict.
# NOTE(review): presumably True regenerates the golden .npz files -- confirm in util.
GENERATE_GOLDEN = False
def test_tfrecord_shape():
    """Batch rows read with the Rank0 schema and check the last output shape has one dimension."""
    logger.info("test_tfrecord_shape")
    rank0_schema = "../data/dataset/testTFTestAllTypes/datasetSchemaRank0.json"
    batched = ds.TFRecordDataset(FILES, rank0_schema).batch(2)
    # Walk the whole pipeline once, logging each batch as it is produced.
    for batch in batched.create_dict_iterator():
        logger.info(batch)
    shapes = batched.output_shapes()
    assert len(shapes[-1]) == 1
def test_tfrecord_read_all_dataset():
    """Expect 12 rows, both reported and actually iterated, when using the NoRow schema file."""
    logger.info("test_tfrecord_read_all_dataset")
    no_row_schema = "../data/dataset/testTFTestAllTypes/datasetSchemaNoRow.json"
    dataset = ds.TFRecordDataset(FILES, no_row_schema)
    # The reported size is checked before the data is iterated.
    assert dataset.get_dataset_size() == 12
    rows_seen = sum(1 for _ in dataset.create_tuple_iterator())
    assert rows_seen == 12
def test_tfrecord_num_samples():
    """Expect num_samples=8 to determine both the reported and the iterated row count.

    NOTE(review): the schema file name suggests it declares 7 rows; its content
    is not visible from this test, so confirm against the JSON file.
    """
    logger.info("test_tfrecord_num_samples")
    seven_rows_schema = "../data/dataset/testTFTestAllTypes/datasetSchema7Rows.json"
    dataset = ds.TFRecordDataset(FILES, seven_rows_schema, num_samples=8)
    assert dataset.get_dataset_size() == 8
    rows_seen = sum(1 for _ in dataset.create_dict_iterator())
    assert rows_seen == 8
def test_tfrecord_num_samples2():
    """Expect 7 rows, reported and iterated, when only the 7Rows schema file is given."""
    logger.info("test_tfrecord_num_samples2")
    seven_rows_schema = "../data/dataset/testTFTestAllTypes/datasetSchema7Rows.json"
    dataset = ds.TFRecordDataset(FILES, seven_rows_schema)
    assert dataset.get_dataset_size() == 7
    # No num_samples argument here, unlike test_tfrecord_num_samples.
    rows_seen = sum(1 for _ in dataset.create_dict_iterator())
    assert rows_seen == 7
def test_tfrecord_shape2():
    """Batch rows read with SCHEMA_FILE and check the last output shape has two dimensions."""
    logger.info("test_tfrecord_shape2")
    batched = ds.TFRecordDataset(FILES, SCHEMA_FILE).batch(2)
    last_shape = batched.output_shapes()[-1]
    assert len(last_shape) == 2
def test_tfrecord_files_basic():
    """Read FILES with SCHEMA_FILE and hand the dataset to save_and_check_dict."""
    logger.info("test_tfrecord_files_basic")
    dataset = ds.TFRecordDataset(FILES, SCHEMA_FILE, shuffle=ds.Shuffle.FILES)
    save_and_check_dict(dataset, "tfrecord_files_basic.npz", generate_golden=GENERATE_GOLDEN)
def test_tfrecord_no_schema():
    """Read FILES without any schema and hand the dataset to save_and_check_dict."""
    logger.info("test_tfrecord_no_schema")
    dataset = ds.TFRecordDataset(FILES, shuffle=ds.Shuffle.FILES)
    save_and_check_dict(dataset, "tfrecord_no_schema.npz", generate_golden=GENERATE_GOLDEN)
def test_tfrecord_pad():
    """Read FILES with the PadBytes10 schema and hand the dataset to save_and_check_dict."""
    logger.info("test_tfrecord_pad")
    pad_schema = "../data/dataset/testTFTestAllTypes/datasetSchemaPadBytes10.json"
    dataset = ds.TFRecordDataset(FILES, pad_schema, shuffle=ds.Shuffle.FILES)
    save_and_check_dict(dataset, "tfrecord_pad_bytes10.npz", generate_golden=GENERATE_GOLDEN)
def test_tfrecord_read_files():
    """Count rows when files are given as a single path, a glob pattern, or an explicit list."""
    logger.info("test_tfrecord_read_files")
    both_files = [DATASET_ROOT + "/test.data", DATASET_ROOT + "/test2.data"]
    # Each case: (dataset_files argument, extra keyword arguments, expected row count).
    # Cases without num_samples deliberately omit the argument rather than passing a value.
    cases = (
        (DATASET_ROOT + "/test.data", {}, 12),
        (DATASET_ROOT + "/test2.data", {}, 12),
        (DATASET_ROOT + "/*.data", {"num_samples": 24}, 24),
        (DATASET_ROOT + "/*.data", {"num_samples": 3}, 3),
        (both_files, {"num_samples": 24}, 24),
    )
    for dataset_files, extra_kwargs, expected_rows in cases:
        data = ds.TFRecordDataset(dataset_files, SCHEMA_FILE, shuffle=ds.Shuffle.FILES, **extra_kwargs)
        assert sum(1 for _ in data) == expected_rows
def test_tfrecord_multi_files():
    """Read the four DATA_FILES2 files unshuffled, repeat once, and expect 12 rows in total."""
    logger.info("test_tfrecord_multi_files")
    pipeline = ds.TFRecordDataset(DATA_FILES2, SCHEMA_FILE2, shuffle=False).repeat(1)
    rows_seen = sum(1 for _ in pipeline.create_dict_iterator())
    assert rows_seen == 12
def test_tfrecord_schema():
    """Expect rows read with a ds.Schema built in code to equal rows read with SCHEMA_FILE."""
    logger.info("test_tfrecord_schema")
    # (column name, element type, shape), added to the schema in this order.
    column_specs = (
        ('col_1d', mstype.int64, [2]),
        ('col_2d', mstype.int64, [2, 2]),
        ('col_3d', mstype.int64, [2, 2, 2]),
        ('col_binary', mstype.uint8, [1]),
        ('col_float', mstype.float32, [1]),
        ('col_sint16', mstype.int64, [1]),
        ('col_sint32', mstype.int64, [1]),
        ('col_sint64', mstype.int64, [1]),
    )
    schema = ds.Schema()
    for column_name, column_type, column_shape in column_specs:
        schema.add_column(column_name, de_type=column_type, shape=column_shape)

    from_object = ds.TFRecordDataset(FILES, schema=schema, shuffle=ds.Shuffle.FILES)
    from_file = ds.TFRecordDataset(FILES, schema=SCHEMA_FILE, shuffle=ds.Shuffle.FILES)

    # Compare the two datasets row by row, tensor by tensor.
    for row_a, row_b in zip(from_object, from_file):
        for tensor_a, tensor_b in zip(row_a, row_b):
            assert np.array_equal(tensor_a, tensor_b)
def test_tfrecord_shuffle():
    """With the seed set to 1, expect Shuffle.GLOBAL to match Shuffle.FILES followed by shuffle(10000)."""
    logger.info("test_tfrecord_shuffle")
    ds.config.set_seed(1)
    global_shuffled = ds.TFRecordDataset(FILES, schema=SCHEMA_FILE, shuffle=ds.Shuffle.GLOBAL)
    file_shuffled = ds.TFRecordDataset(FILES, schema=SCHEMA_FILE, shuffle=ds.Shuffle.FILES)
    # Add an explicit row-level shuffle on top of the file-level one.
    file_shuffled = file_shuffled.shuffle(10000)

    for row_a, row_b in zip(global_shuffled, file_shuffled):
        for tensor_a, tensor_b in zip(row_a, row_b):
            assert np.array_equal(tensor_a, tensor_b)
def test_tfrecord_shard():
    """Read four files as two shards over 16 repeats and compare what each shard receives."""
    logger.info("test_tfrecord_shard")
    tf_files = [
        "../data/dataset/tf_file_dataset/test1.data",
        "../data/dataset/tf_file_dataset/test2.data",
        "../data/dataset/tf_file_dataset/test3.data",
        "../data/dataset/tf_file_dataset/test4.data",
    ]

    def get_res(shard_id, num_repeats):
        """Return the first "scalars" element of every row seen by shard `shard_id`."""
        pipeline = ds.TFRecordDataset(tf_files, num_shards=2, shard_id=shard_id, num_samples=3,
                                      shuffle=ds.Shuffle.FILES).repeat(num_repeats)
        return [item["scalars"][0] for item in pipeline.create_dict_iterator()]

    # Get separate results from two workers. The two results need to satisfy two criteria:
    # 1. the two workers always give different results in the same epoch
    #    (e.g. wrkr1: f1&f3, wrkr2: f2&f4 in one epoch)
    # 2. with enough epochs, both workers will get the entire dataset
    #    (e.g. ep1_wrkr1: f1&f3, ep2_wrkr1: f2&f4)
    worker1_res = get_res(0, 16)
    worker2_res = get_res(1, 16)

    # Confirm each worker gets 3x16=48 rows.
    assert len(worker1_res) == 48
    assert len(worker1_res) == len(worker2_res)

    # Criterion 1: lengths are equal at this point, so pairing by position covers every row.
    for first, second in zip(worker1_res, worker2_res):
        assert first != second

    # Criterion 2: across all epochs both workers saw the same set of values.
    assert set(worker2_res) == set(worker1_res)
def test_tfrecord_shard_equal_rows():
    """Check that shard_equal_rows=True gives every one of three shards the same number of rows.

    Three shards are each read with two repeats and are expected to return
    28 rows apiece, with neighbouring shards differing at every position.
    A single unsharded read with one repeat is expected to return 40 rows.
    """
    logger.info("test_tfrecord_shard_equal_rows")
    tf_files = [
        "../data/dataset/tf_file_dataset/test1.data",
        "../data/dataset/tf_file_dataset/test2.data",
        "../data/dataset/tf_file_dataset/test3.data",
        "../data/dataset/tf_file_dataset/test4.data",
    ]

    def get_res(num_shards, shard_id, num_repeats):
        """Return the first "scalars" element of every row seen by the given shard."""
        ds1 = ds.TFRecordDataset(tf_files, num_shards=num_shards, shard_id=shard_id, shard_equal_rows=True)
        ds1 = ds1.repeat(num_repeats)
        return [data["scalars"][0] for data in ds1.create_dict_iterator()]

    worker1_res = get_res(3, 0, 2)
    worker2_res = get_res(3, 1, 2)
    worker3_res = get_res(3, 2, 2)

    # Confirm each worker gets the same number of rows. This runs before the
    # element-wise comparison so that a length mismatch is reported by these
    # assertions instead of surfacing as an IndexError further down.
    assert len(worker1_res) == 28
    assert len(worker1_res) == len(worker2_res)
    assert len(worker2_res) == len(worker3_res)

    # Neighbouring workers must not return the same value at the same position.
    for first, second, third in zip(worker1_res, worker2_res, worker3_res):
        assert first != second
        assert second != third

    # A single shard with a single repeat sees the whole dataset.
    worker4_res = get_res(1, 0, 1)
    assert len(worker4_res) == 40
def test_tfrecord_no_schema_columns_list():
    """Without a schema, columns_list=["col_sint16"] should expose only that column in a row."""
    logger.info("test_tfrecord_no_schema_columns_list")
    dataset = ds.TFRecordDataset(FILES, shuffle=False, columns_list=["col_sint16"])
    row_iterator = dataset.create_dict_iterator()
    first_row = row_iterator.get_next()
    assert first_row["col_sint16"] == [-32768]

    # A column that was not requested must be missing from the row dictionary.
    with pytest.raises(KeyError) as excinfo:
        _ = first_row["col_sint32"]
    assert "col_sint32" in str(excinfo.value)
def test_tfrecord_schema_columns_list():
    """With a full schema, columns_list=["col_sint16"] should still expose only that column."""
    logger.info("test_tfrecord_schema_columns_list")
    # (column name, element type, shape), added to the schema in this order.
    column_specs = (
        ('col_1d', mstype.int64, [2]),
        ('col_2d', mstype.int64, [2, 2]),
        ('col_3d', mstype.int64, [2, 2, 2]),
        ('col_binary', mstype.uint8, [1]),
        ('col_float', mstype.float32, [1]),
        ('col_sint16', mstype.int64, [1]),
        ('col_sint32', mstype.int64, [1]),
        ('col_sint64', mstype.int64, [1]),
    )
    schema = ds.Schema()
    for column_name, column_type, column_shape in column_specs:
        schema.add_column(column_name, de_type=column_type, shape=column_shape)

    dataset = ds.TFRecordDataset(FILES, schema=schema, shuffle=False, columns_list=["col_sint16"])
    row_iterator = dataset.create_dict_iterator()
    first_row = row_iterator.get_next()
    assert first_row["col_sint16"] == [-32768]

    # col_sint32 is in the schema but not in columns_list, so it must be absent from the row.
    with pytest.raises(KeyError) as excinfo:
        _ = first_row["col_sint32"]
    assert "col_sint32" in str(excinfo.value)
def test_tfrecord_invalid_files():
    """Check the errors raised for non-TFRecord inputs and for a path that matches no file."""
    logger.info("test_tfrecord_invalid_files")
    valid_file = "../data/dataset/testTFTestAllTypes/test.data"
    invalid_file = "../data/dataset/testTFTestAllTypes/invalidFile.txt"

    # Scenario 1: construction succeeds; the RuntimeError is expected on the first read
    # and should name the two non-TFRecord inputs but not the valid file.
    data = ds.TFRecordDataset([invalid_file, valid_file, SCHEMA_FILE], SCHEMA_FILE, shuffle=ds.Shuffle.FILES)
    with pytest.raises(RuntimeError) as read_error:
        _ = data.create_dict_iterator().get_next()
    message = str(read_error.value)
    assert "cannot be opened" in message
    assert "not valid tfrecord files" in message
    assert valid_file not in message
    assert invalid_file in message
    assert SCHEMA_FILE in message

    # Scenario 2: a path that matches nothing is expected to fail at construction time
    # with a ValueError that names only that path.
    nonexistent_file = "this/file/does/not/exist"
    with pytest.raises(ValueError) as lookup_error:
        ds.TFRecordDataset([invalid_file, valid_file, SCHEMA_FILE, nonexistent_file],
                           SCHEMA_FILE, shuffle=ds.Shuffle.FILES)
    message = str(lookup_error.value)
    assert "did not match any files" in message
    assert valid_file not in message
    assert invalid_file not in message
    assert SCHEMA_FILE not in message
    assert nonexistent_file in message
if __name__ == '__main__':
    # Run every test in declaration order when the module is executed as a script.
    for run_test in (
            test_tfrecord_shape,
            test_tfrecord_read_all_dataset,
            test_tfrecord_num_samples,
            test_tfrecord_num_samples2,
            test_tfrecord_shape2,
            test_tfrecord_files_basic,
            test_tfrecord_no_schema,
            test_tfrecord_pad,
            test_tfrecord_read_files,
            test_tfrecord_multi_files,
            test_tfrecord_schema,
            test_tfrecord_shuffle,
            test_tfrecord_shard,
            test_tfrecord_shard_equal_rows,
            test_tfrecord_no_schema_columns_list,
            test_tfrecord_schema_columns_list,
            test_tfrecord_invalid_files,
    ):
        run_test()