You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
mindspore/tests/ut/python/dataset/test_epoch_ctrl.py

715 lines
22 KiB

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Testing Epoch Control op in DE
"""
import itertools
import numpy as np
import pytest
import cv2
import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as vision
from mindspore import log as logger
# TFRecord data file(s) passed as the first argument to ds.TFRecordDataset in test_decode_op.
DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
# Schema JSON passed as the second argument to ds.TFRecordDataset in test_decode_op.
SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
def diff_mse(in1, in2):
    """
    Return the mean squared error between two inputs, scaled by 100.

    Both inputs are converted to float and divided by 255 before the
    element-wise difference is squared and averaged.

    Args:
        in1: numpy array or array-like (e.g. nested lists) of values.
        in2: numpy array or array-like that can be subtracted from in1.

    Returns:
        The mean of the squared difference of in1 / 255 and in2 / 255, times 100.
    """
    # np.asarray accepts plain sequences as well as ndarrays; for ndarrays the
    # float conversion yields the same values as astype(float).
    lhs = np.asarray(in1, dtype=float) / 255
    rhs = np.asarray(in2, dtype=float) / 255
    mse = np.square(lhs - rhs).mean()
    return mse * 100
def test_cifar10():
    """
    Iterate a repeated, batched Cifar10 pipeline for several epochs with one tuple iterator.

    Checks that every epoch yields the expected number of rows and that the
    epoch count and the accumulated row count add up at the end.
    """
    logger.info("Test dataset parameter")
    data_dir_10 = "../data/dataset/testCifar10Data"
    num_repeat = 2
    batch_size = 32
    limit_dataset = 100
    num_epoch = 5
    # Rows expected from one pass of the iterator; int() truncates the partial batch.
    rows_per_epoch = int(limit_dataset * num_repeat / batch_size)

    # apply dataset operations
    data1 = ds.Cifar10Dataset(data_dir_10, num_samples=limit_dataset)
    data1 = data1.repeat(num_repeat)
    # NOTE(review): second positional argument is presumably drop_remainder - confirm.
    data1 = data1.batch(batch_size, True)

    # iter1 will always assume there is a next epoch and never shutdown.
    iter1 = data1.create_tuple_iterator()
    epoch_count = 0
    sample_count = 0
    for _ in range(num_epoch):
        row_count = 0
        for _ in iter1:
            row_count += 1
        assert row_count == rows_per_epoch
        # The extra argument needs a %-placeholder, otherwise the value is never rendered.
        logger.debug("row_count: %s", row_count)
        epoch_count += 1
        sample_count += row_count

    assert epoch_count == num_epoch
    logger.debug("total epochs: %s", epoch_count)
    assert sample_count == rows_per_epoch * num_epoch
    logger.debug("total sample: %s", sample_count)
def test_decode_op():
    """
    Compare vision.Decode output with an OpenCV decode of the same records over several epochs.

    iter1 is created without an epoch limit and is stopped manually; iter2 is
    created with num_epoch and must raise once that many epochs were consumed.
    """
    logger.info("test_decode_op")

    # Decode with rgb format set to True
    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    # Serialize and Load dataset requires using vision.Decode instead of vision.Decode().
    data1 = data1.map(operations=[vision.Decode(True)], input_columns=["image"])

    # Second dataset: same records, left undecoded so OpenCV can decode them.
    data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)

    num_epoch = 5
    # iter1 will always assume there is a next epoch and never shutdown.
    iter1 = data1.create_dict_iterator(output_numpy=True)
    # iter2 will stop and shutdown pipeline after num_epoch
    iter2 = data2.create_dict_iterator(num_epoch, output_numpy=True)
    for _ in range(num_epoch):
        i = 0
        for item1, item2 in itertools.zip_longest(iter1, iter2):
            actual = item1["image"]
            expected = cv2.imdecode(item2["image"], cv2.IMREAD_COLOR)
            expected = cv2.cvtColor(expected, cv2.COLOR_BGR2RGB)
            assert actual.shape == expected.shape
            # Widen before subtracting and squaring: with narrow integer pixels
            # (presumably uint8 - confirm) the arithmetic wraps, e.g. a difference
            # of 16 squares to 256 == 0, which would hide a real mismatch.
            diff = actual.astype(np.int64) - expected.astype(np.int64)
            mse = np.sum(np.power(diff, 2))
            assert mse == 0
            i = i + 1
        assert i == 3

    # Users have the option to manually stop the iterator, or rely on garbage collector.
    iter1.stop()
    # Expect a AttributeError since iter1 has been stopped.
    with pytest.raises(AttributeError) as info:
        iter1.__next__()
    assert "object has no attribute '_runtime_context'" in str(info.value)

    # iter2 was limited to num_epoch epochs, so one more fetch must fail.
    with pytest.raises(RuntimeError) as info:
        iter2.__next__()
    err_msg = "EOF buffer encountered. Users try to fetch data beyond the specified number of epochs."
    assert err_msg in str(info.value)
# Generate 1d int numpy array from 0 - 63
def generator_1d():
    """
    Yield 64 one-element tuples, each holding a 1-D numpy array [n] for n = 0..63.
    """
    yield from ((np.array([value]),) for value in range(64))
def test_generator_dict_0():
    """
    Read one epoch of the 1D generator through a dict iterator created in the loop header.

    Each row's "data" column must equal its running index, and exactly 64 rows
    must be produced.
    """
    logger.info("Test 1D Generator : 0 - 63")

    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])

    i = 0
    # create the iterator inside the loop declaration
    for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):  # each data is a dictionary
        golden = np.array([i])
        np.testing.assert_array_equal(item["data"], golden)
        i = i + 1
    # Without this check the test would pass even if the pipeline produced no rows.
    assert i == 64
def test_generator_dict_1():
    """
    Run ten passes over the 1D generator, building a new dict iterator for every pass.
    """
    logger.info("Test 1D Generator : 0 - 63")

    dataset = ds.GeneratorDataset(generator_1d, ["data"])

    for _ in range(10):
        rows_seen = 0
        # BAD. Do not create iterator every time inside.
        # Create iterator outside the epoch for loop.
        epoch_iter = dataset.create_dict_iterator(num_epochs=1, output_numpy=True)
        for idx, row in enumerate(epoch_iter):  # every row is a dictionary
            np.testing.assert_array_equal(row["data"], np.array([idx]))
            rows_seen = idx + 1
        assert rows_seen == 64
def test_generator_dict_2():
    """
    Reuse one dict iterator, created without num_epochs, across ten epochs.

    Each "data" value is converted with asnumpy() before comparison, and the
    iterator must still return a row after the tenth epoch.
    """
    logger.info("Test 1D Generator : 0 - 63")

    dataset = ds.GeneratorDataset(generator_1d, ["data"])
    epoch_iter = dataset.create_dict_iterator()
    for _ in range(10):
        rows_seen = 0
        for idx, row in enumerate(epoch_iter):  # every row is a dictionary
            np.testing.assert_array_equal(row["data"].asnumpy(), np.array([idx]))
            rows_seen = idx + 1
        assert rows_seen == 64

    # The iterator is still alive and running, so one more fetch succeeds.
    extra_row = epoch_iter.__next__()
    assert extra_row
    # No explicit stop(): the garbage collector is left to destroy the iterator.
def test_generator_dict_3():
    """
    Reuse one dict iterator across ten epochs, stop it, then check that fetching fails.
    """
    logger.info("Test 1D Generator : 0 - 63")

    dataset = ds.GeneratorDataset(generator_1d, ["data"])
    epoch_iter = dataset.create_dict_iterator()
    for _ in range(10):
        rows_seen = 0
        for idx, row in enumerate(epoch_iter):  # every row is a dictionary
            np.testing.assert_array_equal(row["data"].asnumpy(), np.array([idx]))
            rows_seen = idx + 1
        assert rows_seen == 64

    # Stopping manually is optional.
    epoch_iter.stop()
    # A stopped iterator is expected to raise AttributeError on the next fetch.
    with pytest.raises(AttributeError) as excinfo:
        epoch_iter.__next__()
    assert "object has no attribute '_runtime_context'" in str(excinfo.value)
def test_generator_dict_4():
    """
    Consume exactly the ten epochs a dict iterator was created for, then expect a RuntimeError.
    """
    logger.info("Test 1D Generator : 0 - 63")

    dataset = ds.GeneratorDataset(generator_1d, ["data"])
    epoch_iter = dataset.create_dict_iterator(num_epochs=10)
    for _ in range(10):
        rows_seen = 0
        for idx, row in enumerate(epoch_iter):  # every row is a dictionary
            np.testing.assert_array_equal(row["data"].asnumpy(), np.array([idx]))
            rows_seen = idx + 1
        assert rows_seen == 64

    # All requested epochs are used up, so another fetch must fail.
    with pytest.raises(RuntimeError) as excinfo:
        epoch_iter.__next__()
    expected_msg = "EOF buffer encountered. Users try to fetch data beyond the specified number of epochs."
    assert expected_msg in str(excinfo.value)
def test_generator_dict_4_1():
    """
    Consume the single epoch of a dict iterator created with num_epochs=1, then expect a RuntimeError.
    """
    logger.info("Test 1D Generator : 0 - 63")

    dataset = ds.GeneratorDataset(generator_1d, ["data"])
    # epoch ctrl op will not be injected if num_epochs is 1.
    epoch_iter = dataset.create_dict_iterator(num_epochs=1, output_numpy=True)

    # A single pass; the original one-iteration epoch loop is unrolled.
    rows_seen = 0
    for idx, row in enumerate(epoch_iter):  # every row is a dictionary
        np.testing.assert_array_equal(row["data"], np.array([idx]))
        rows_seen = idx + 1
    assert rows_seen == 64

    # The only epoch is used up, so another fetch must fail.
    with pytest.raises(RuntimeError) as excinfo:
        epoch_iter.__next__()
    expected_msg = "EOF buffer encountered. Users try to fetch data beyond the specified number of epochs."
    assert expected_msg in str(excinfo.value)
def test_generator_dict_4_2():
    """
    Same as the num_epochs=1 dict test, but with repeat(1) applied to the dataset first.
    """
    logger.info("Test 1D Generator : 0 - 63")

    dataset = ds.GeneratorDataset(generator_1d, ["data"])
    # repeat will not be injected when num repeat is 1.
    dataset = dataset.repeat(1)
    # epoch ctrl op will not be injected if num_epochs is 1.
    epoch_iter = dataset.create_dict_iterator(num_epochs=1, output_numpy=True)

    # A single pass; the original one-iteration epoch loop is unrolled.
    rows_seen = 0
    for idx, row in enumerate(epoch_iter):  # every row is a dictionary
        np.testing.assert_array_equal(row["data"], np.array([idx]))
        rows_seen = idx + 1
    assert rows_seen == 64

    # The only epoch is used up, so another fetch must fail.
    with pytest.raises(RuntimeError) as excinfo:
        epoch_iter.__next__()
    expected_msg = "EOF buffer encountered. Users try to fetch data beyond the specified number of epochs."
    assert expected_msg in str(excinfo.value)
def test_generator_dict_5():
    """
    Create a dict iterator for eleven epochs, read ten, then the last one, then expect a RuntimeError.
    """
    logger.info("Test 1D Generator : 0 - 63")

    dataset = ds.GeneratorDataset(generator_1d, ["data"])
    epoch_iter = dataset.create_dict_iterator(num_epochs=11, output_numpy=True)
    for _ in range(10):
        rows_seen = 0
        for idx, row in enumerate(epoch_iter):  # every row is a dictionary
            np.testing.assert_array_equal(row["data"], np.array([idx]))
            rows_seen = idx + 1
        assert rows_seen == 64

    # One epoch is still left in the iterator.
    rows_seen = 0
    for idx, row in enumerate(epoch_iter):  # every row is a dictionary
        np.testing.assert_array_equal(row["data"], np.array([idx]))
        rows_seen = idx + 1
    assert rows_seen == 64

    # now the iterator has been exhausted, c++ pipeline has been shut down.
    with pytest.raises(RuntimeError) as excinfo:
        epoch_iter.__next__()
    expected_msg = "EOF buffer encountered. Users try to fetch data beyond the specified number of epochs."
    assert expected_msg in str(excinfo.value)
# Test tuple iterator
def test_generator_tuple_0():
    """
    Read one epoch of the 1D generator through a tuple iterator created in the loop header.

    The first element of each row must equal its running index, and exactly
    64 rows must be produced.
    """
    logger.info("Test 1D Generator : 0 - 63")

    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])

    i = 0
    # create the iterator inside the loop declaration
    for item in data1.create_tuple_iterator(num_epochs=1, output_numpy=True):  # each row is indexed by position
        golden = np.array([i])
        np.testing.assert_array_equal(item[0], golden)
        i = i + 1
    # Without this check the test would pass even if the pipeline produced no rows.
    assert i == 64
def test_generator_tuple_1():
    """
    Run ten passes over the 1D generator, building a new tuple iterator for every pass.
    """
    logger.info("Test 1D Generator : 0 - 63")

    dataset = ds.GeneratorDataset(generator_1d, ["data"])

    for _ in range(10):
        rows_seen = 0
        # BAD. Do not create iterator every time inside.
        # Create iterator outside the epoch for loop.
        epoch_iter = dataset.create_tuple_iterator(num_epochs=1, output_numpy=True)
        for idx, row in enumerate(epoch_iter):  # rows are indexed by position
            np.testing.assert_array_equal(row[0], np.array([idx]))
            rows_seen = idx + 1
        assert rows_seen == 64
def test_generator_tuple_2():
    """
    Reuse one tuple iterator, created without num_epochs, across ten epochs.

    The iterator must still return a row after the tenth epoch.
    """
    logger.info("Test 1D Generator : 0 - 63")

    dataset = ds.GeneratorDataset(generator_1d, ["data"])
    epoch_iter = dataset.create_tuple_iterator(output_numpy=True)
    for _ in range(10):
        rows_seen = 0
        for idx, row in enumerate(epoch_iter):  # rows are indexed by position
            np.testing.assert_array_equal(row[0], np.array([idx]))
            rows_seen = idx + 1
        assert rows_seen == 64

    # The iterator is still alive and running, so one more fetch succeeds.
    extra_row = epoch_iter.__next__()
    assert extra_row
    # No explicit stop(): the garbage collector is left to destroy the iterator.
def test_generator_tuple_3():
    """
    Reuse one tuple iterator across ten epochs, stop it, then check that fetching fails.
    """
    logger.info("Test 1D Generator : 0 - 63")

    dataset = ds.GeneratorDataset(generator_1d, ["data"])
    epoch_iter = dataset.create_tuple_iterator(output_numpy=True)
    for _ in range(10):
        rows_seen = 0
        for idx, row in enumerate(epoch_iter):  # rows are indexed by position
            np.testing.assert_array_equal(row[0], np.array([idx]))
            rows_seen = idx + 1
        assert rows_seen == 64

    # Stopping manually is optional.
    epoch_iter.stop()
    # A stopped iterator is expected to raise AttributeError on the next fetch.
    with pytest.raises(AttributeError) as excinfo:
        epoch_iter.__next__()
    assert "object has no attribute '_runtime_context'" in str(excinfo.value)
def test_generator_tuple_4():
    """
    Consume exactly the ten epochs a tuple iterator was created for, then expect a RuntimeError.
    """
    logger.info("Test 1D Generator : 0 - 63")

    dataset = ds.GeneratorDataset(generator_1d, ["data"])
    epoch_iter = dataset.create_tuple_iterator(num_epochs=10, output_numpy=True)
    for _ in range(10):
        rows_seen = 0
        for idx, row in enumerate(epoch_iter):  # rows are indexed by position
            np.testing.assert_array_equal(row[0], np.array([idx]))
            rows_seen = idx + 1
        assert rows_seen == 64

    # All requested epochs are used up, so another fetch must fail.
    with pytest.raises(RuntimeError) as excinfo:
        epoch_iter.__next__()
    expected_msg = "EOF buffer encountered. Users try to fetch data beyond the specified number of epochs."
    assert expected_msg in str(excinfo.value)
def test_generator_tuple_5():
    """
    Create a tuple iterator for eleven epochs, read ten, then the last one, then expect a RuntimeError.
    """
    logger.info("Test 1D Generator : 0 - 63")

    dataset = ds.GeneratorDataset(generator_1d, ["data"])
    epoch_iter = dataset.create_tuple_iterator(num_epochs=11, output_numpy=True)
    for _ in range(10):
        rows_seen = 0
        for idx, row in enumerate(epoch_iter):  # rows are indexed by position
            np.testing.assert_array_equal(row[0], np.array([idx]))
            rows_seen = idx + 1
        assert rows_seen == 64

    # One epoch is still left in the iterator.
    rows_seen = 0
    for idx, row in enumerate(epoch_iter):  # rows are indexed by position
        np.testing.assert_array_equal(row[0], np.array([idx]))
        rows_seen = idx + 1
    assert rows_seen == 64

    # now the iterator has been exhausted, c++ pipeline has been shut down.
    with pytest.raises(RuntimeError) as excinfo:
        epoch_iter.__next__()
    expected_msg = "EOF buffer encountered. Users try to fetch data beyond the specified number of epochs."
    assert expected_msg in str(excinfo.value)
# Test with repeat
def test_generator_tuple_repeat_1():
    """
    Apply repeat(2), read eleven epochs of 128 rows each, then expect a RuntimeError.

    Row values are compared against the running index modulo 64.
    """
    logger.info("Test 1D Generator : 0 - 63")

    dataset = ds.GeneratorDataset(generator_1d, ["data"])
    dataset = dataset.repeat(2)
    epoch_iter = dataset.create_tuple_iterator(num_epochs=11, output_numpy=True)
    for _ in range(10):
        rows_seen = 0
        for idx, row in enumerate(epoch_iter):  # rows are indexed by position
            np.testing.assert_array_equal(row[0], np.array([idx % 64]))
            rows_seen = idx + 1
        assert rows_seen == 64 * 2

    # One epoch is still left in the iterator.
    rows_seen = 0
    for idx, row in enumerate(epoch_iter):  # rows are indexed by position
        np.testing.assert_array_equal(row[0], np.array([idx % 64]))
        rows_seen = idx + 1
    assert rows_seen == 64 * 2

    # now the iterator has been exhausted, c++ pipeline has been shut down.
    with pytest.raises(RuntimeError) as excinfo:
        epoch_iter.__next__()
    expected_msg = "EOF buffer encountered. Users try to fetch data beyond the specified number of epochs."
    assert expected_msg in str(excinfo.value)
# Test with repeat
def test_generator_tuple_repeat_repeat_1():
    """
    Chain repeat(2) and repeat(3), read eleven epochs of 384 rows each, then expect a RuntimeError.

    Row values are compared against the running index modulo 64.
    """
    logger.info("Test 1D Generator : 0 - 63")

    dataset = ds.GeneratorDataset(generator_1d, ["data"])
    dataset = dataset.repeat(2)
    dataset = dataset.repeat(3)
    epoch_iter = dataset.create_tuple_iterator(num_epochs=11, output_numpy=True)
    for _ in range(10):
        rows_seen = 0
        for idx, row in enumerate(epoch_iter):  # rows are indexed by position
            np.testing.assert_array_equal(row[0], np.array([idx % 64]))
            rows_seen = idx + 1
        assert rows_seen == 64 * 2 * 3

    # One epoch is still left in the iterator.
    rows_seen = 0
    for idx, row in enumerate(epoch_iter):  # rows are indexed by position
        np.testing.assert_array_equal(row[0], np.array([idx % 64]))
        rows_seen = idx + 1
    assert rows_seen == 64 * 2 * 3

    # now the iterator has been exhausted, c++ pipeline has been shut down.
    with pytest.raises(RuntimeError) as excinfo:
        epoch_iter.__next__()
    expected_msg = "EOF buffer encountered. Users try to fetch data beyond the specified number of epochs."
    assert expected_msg in str(excinfo.value)
def test_generator_tuple_repeat_repeat_2():
    """
    Chain repeat(2) and repeat(3), read ten epochs with an unbounded tuple iterator, then stop it.

    After stop() the next fetch is expected to raise AttributeError.
    """
    logger.info("Test 1D Generator : 0 - 63")

    dataset = ds.GeneratorDataset(generator_1d, ["data"])
    dataset = dataset.repeat(2)
    dataset = dataset.repeat(3)
    epoch_iter = dataset.create_tuple_iterator(output_numpy=True)
    for _ in range(10):
        rows_seen = 0
        for idx, row in enumerate(epoch_iter):  # rows are indexed by position
            np.testing.assert_array_equal(row[0], np.array([idx % 64]))
            rows_seen = idx + 1
        assert rows_seen == 64 * 2 * 3

    # Stopping manually is optional.
    epoch_iter.stop()
    # A stopped iterator is expected to raise AttributeError on the next fetch.
    with pytest.raises(AttributeError) as excinfo:
        epoch_iter.__next__()
    assert "object has no attribute '_runtime_context'" in str(excinfo.value)
def test_generator_tuple_repeat_repeat_3():
    """
    Chain repeat(2) and repeat(3) and read ten, then five more, epochs from one tuple iterator.
    """
    logger.info("Test 1D Generator : 0 - 63")

    dataset = ds.GeneratorDataset(generator_1d, ["data"])
    dataset = dataset.repeat(2)
    dataset = dataset.repeat(3)
    epoch_iter = dataset.create_tuple_iterator(output_numpy=True)

    # First batch of epochs.
    for _ in range(10):
        rows_seen = 0
        for idx, row in enumerate(epoch_iter):  # rows are indexed by position
            np.testing.assert_array_equal(row[0], np.array([idx % 64]))
            rows_seen = idx + 1
        assert rows_seen == 64 * 2 * 3

    # The same iterator keeps serving further epochs.
    for _ in range(5):
        rows_seen = 0
        for idx, row in enumerate(epoch_iter):  # rows are indexed by position
            np.testing.assert_array_equal(row[0], np.array([idx % 64]))
            rows_seen = idx + 1
        assert rows_seen == 64 * 2 * 3
    # No explicit stop(): the garbage collector is left to destroy the iterator.
def test_generator_tuple_infinite_repeat_repeat_1():
    """
    Read the first 100 rows of a pipeline built with repeat() followed by repeat(3).

    Row values are compared against the running index modulo 64; the loop
    leaves the iterator after 100 rows.
    """
    logger.info("Test 1D Generator : 0 - 63")

    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])
    data1 = data1.repeat()
    data1 = data1.repeat(3)
    iter1 = data1.create_tuple_iterator(num_epochs=11, output_numpy=True)

    i = 0
    for item in iter1:  # each row is indexed by position
        golden = np.array([i % 64])
        np.testing.assert_array_equal(item[0], golden)
        i = i + 1
        if i == 100:
            break
    # If iteration ended before the break, the test would otherwise pass silently.
    assert i == 100

    # rely on garbage collector to destroy iter1
def test_generator_tuple_infinite_repeat_repeat_2():
    """
    Read the first 100 rows of a pipeline built with repeat(3) followed by repeat().

    Row values are compared against the running index modulo 64; the loop
    leaves the iterator after 100 rows.
    """
    logger.info("Test 1D Generator : 0 - 63")

    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])
    data1 = data1.repeat(3)
    data1 = data1.repeat()
    iter1 = data1.create_tuple_iterator(num_epochs=11, output_numpy=True)

    i = 0
    for item in iter1:  # each row is indexed by position
        golden = np.array([i % 64])
        np.testing.assert_array_equal(item[0], golden)
        i = i + 1
        if i == 100:
            break
    # If iteration ended before the break, the test would otherwise pass silently.
    assert i == 100

    # rely on garbage collector to destroy iter1
def test_generator_tuple_infinite_repeat_repeat_3():
    """
    Read the first 100 rows of a pipeline built with two chained repeat() calls.

    The iterator is created with num_epochs=11. Row values are compared against
    the running index modulo 64; the loop leaves the iterator after 100 rows.
    """
    logger.info("Test 1D Generator : 0 - 63")

    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])
    data1 = data1.repeat()
    data1 = data1.repeat()
    iter1 = data1.create_tuple_iterator(num_epochs=11, output_numpy=True)

    i = 0
    for item in iter1:  # each row is indexed by position
        golden = np.array([i % 64])
        np.testing.assert_array_equal(item[0], golden)
        i = i + 1
        if i == 100:
            break
    # If iteration ended before the break, the test would otherwise pass silently.
    assert i == 100

    # rely on garbage collector to destroy iter1
def test_generator_tuple_infinite_repeat_repeat_4():
    """
    Read the first 100 rows of a pipeline built with two chained repeat() calls.

    The iterator is created without num_epochs. Row values are compared against
    the running index modulo 64; the loop leaves the iterator after 100 rows.
    """
    logger.info("Test 1D Generator : 0 - 63")

    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])
    data1 = data1.repeat()
    data1 = data1.repeat()
    iter1 = data1.create_tuple_iterator(output_numpy=True)

    i = 0
    for item in iter1:  # each row is indexed by position
        golden = np.array([i % 64])
        np.testing.assert_array_equal(item[0], golden)
        i = i + 1
        if i == 100:
            break
    # If iteration ended before the break, the test would otherwise pass silently.
    assert i == 100

    # rely on garbage collector to destroy iter1
def test_generator_reusedataset():
    """
    Keep extending one dataset object and create a new iterator after each extension.

    Stages: repeat(2) read for ten epochs, then repeat(3) added and read for
    five epochs, then batch(2) added and read for five epochs via a dict iterator.
    """
    logger.info("Test 1D Generator : 0 - 63")

    # Stage 1: repeat(2), tuple iterator.
    dataset = ds.GeneratorDataset(generator_1d, ["data"])
    dataset = dataset.repeat(2)
    epoch_iter = dataset.create_tuple_iterator(output_numpy=True)
    for _ in range(10):
        rows_seen = 0
        for idx, row in enumerate(epoch_iter):  # rows are indexed by position
            np.testing.assert_array_equal(row[0], np.array([idx % 64]))
            rows_seen = idx + 1
        assert rows_seen == 64 * 2

    # Stage 2: add repeat(3) on top and build a new tuple iterator.
    dataset = dataset.repeat(3)
    epoch_iter = dataset.create_tuple_iterator(output_numpy=True)
    for _ in range(5):
        rows_seen = 0
        for idx, row in enumerate(epoch_iter):  # rows are indexed by position
            np.testing.assert_array_equal(row[0], np.array([idx % 64]))
            rows_seen = idx + 1
        assert rows_seen == 64 * 2 * 3

    # Stage 3: add batch(2) on top and switch to a dict iterator.
    dataset = dataset.batch(2)
    epoch_iter = dataset.create_dict_iterator(output_numpy=True)
    for _ in range(5):
        batches_seen = 0
        for batch_idx, row in enumerate(epoch_iter):  # every row is a dictionary
            first = 2 * batch_idx  # index of the first sample inside this batch
            golden = np.array([[first % 64], [(first + 1) % 64]])
            np.testing.assert_array_equal(row["data"], golden)
            batches_seen = batch_idx + 1
        assert batches_seen == 64 * 3
    # No explicit stop(): the garbage collector is left to destroy the iterator.