|
|
|
@ -98,6 +98,25 @@ def test_shuffle_04():
|
|
|
|
|
save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_shuffle_05():
|
|
|
|
|
"""
|
|
|
|
|
Test shuffle: buffer_size > number-of-rows-in-dataset
|
|
|
|
|
"""
|
|
|
|
|
logger.info("test_shuffle_05")
|
|
|
|
|
# define parameters
|
|
|
|
|
buffer_size = 13
|
|
|
|
|
seed = 1
|
|
|
|
|
parameters = {"params": {'buffer_size': buffer_size, "seed": seed}}
|
|
|
|
|
|
|
|
|
|
# apply dataset operations
|
|
|
|
|
data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
|
|
|
|
|
ds.config.set_seed(seed)
|
|
|
|
|
data1 = data1.shuffle(buffer_size=buffer_size)
|
|
|
|
|
|
|
|
|
|
filename = "shuffle_05_result.npz"
|
|
|
|
|
save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_shuffle_exception_01():
|
|
|
|
|
"""
|
|
|
|
|
Test shuffle exception: buffer_size<0
|
|
|
|
@ -152,24 +171,6 @@ def test_shuffle_exception_03():
|
|
|
|
|
assert "buffer_size" in str(e)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_shuffle_exception_04():
|
|
|
|
|
"""
|
|
|
|
|
Test shuffle exception: buffer_size > number-of-rows-in-dataset
|
|
|
|
|
"""
|
|
|
|
|
logger.info("test_shuffle_exception_04")
|
|
|
|
|
|
|
|
|
|
# apply dataset operations
|
|
|
|
|
data1 = ds.TFRecordDataset(DATA_DIR)
|
|
|
|
|
ds.config.set_seed(1)
|
|
|
|
|
try:
|
|
|
|
|
data1 = data1.shuffle(buffer_size=13)
|
|
|
|
|
sum([1 for _ in data1])
|
|
|
|
|
|
|
|
|
|
except BaseException as e:
|
|
|
|
|
logger.info("Got an exception in DE: {}".format(str(e)))
|
|
|
|
|
assert "buffer_size" in str(e)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_shuffle_exception_05():
|
|
|
|
|
"""
|
|
|
|
|
Test shuffle exception: Missing mandatory buffer_size input parameter
|
|
|
|
@ -229,10 +230,10 @@ if __name__ == '__main__':
|
|
|
|
|
test_shuffle_02()
|
|
|
|
|
test_shuffle_03()
|
|
|
|
|
test_shuffle_04()
|
|
|
|
|
test_shuffle_05()
|
|
|
|
|
test_shuffle_exception_01()
|
|
|
|
|
test_shuffle_exception_02()
|
|
|
|
|
test_shuffle_exception_03()
|
|
|
|
|
test_shuffle_exception_04()
|
|
|
|
|
test_shuffle_exception_05()
|
|
|
|
|
test_shuffle_exception_06()
|
|
|
|
|
test_shuffle_exception_07()
|
|
|
|
|