|
|
|
@ -37,6 +37,7 @@ def test_batch_01():
|
|
|
|
|
data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
|
|
|
|
|
data1 = data1.batch(batch_size, drop_remainder)
|
|
|
|
|
|
|
|
|
|
assert sum([1 for _ in data1]) == 6
|
|
|
|
|
filename = "batch_01_result.npz"
|
|
|
|
|
save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
|
|
|
|
|
|
|
|
|
@ -56,6 +57,7 @@ def test_batch_02():
|
|
|
|
|
data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
|
|
|
|
|
data1 = data1.batch(batch_size, drop_remainder=drop_remainder)
|
|
|
|
|
|
|
|
|
|
assert sum([1 for _ in data1]) == 2
|
|
|
|
|
filename = "batch_02_result.npz"
|
|
|
|
|
save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
|
|
|
|
|
|
|
|
|
@ -75,6 +77,7 @@ def test_batch_03():
|
|
|
|
|
data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
|
|
|
|
|
data1 = data1.batch(batch_size=batch_size, drop_remainder=drop_remainder)
|
|
|
|
|
|
|
|
|
|
assert sum([1 for _ in data1]) == 4
|
|
|
|
|
filename = "batch_03_result.npz"
|
|
|
|
|
save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
|
|
|
|
|
|
|
|
|
@ -94,6 +97,7 @@ def test_batch_04():
|
|
|
|
|
data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
|
|
|
|
|
data1 = data1.batch(batch_size, drop_remainder)
|
|
|
|
|
|
|
|
|
|
assert sum([1 for _ in data1]) == 2
|
|
|
|
|
filename = "batch_04_result.npz"
|
|
|
|
|
save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
|
|
|
|
|
|
|
|
|
@ -111,6 +115,7 @@ def test_batch_05():
|
|
|
|
|
data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
|
|
|
|
|
data1 = data1.batch(batch_size)
|
|
|
|
|
|
|
|
|
|
assert sum([1 for _ in data1]) == 12
|
|
|
|
|
filename = "batch_05_result.npz"
|
|
|
|
|
save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
|
|
|
|
|
|
|
|
|
@ -130,6 +135,7 @@ def test_batch_06():
|
|
|
|
|
data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
|
|
|
|
|
data1 = data1.batch(drop_remainder=drop_remainder, batch_size=batch_size)
|
|
|
|
|
|
|
|
|
|
assert sum([1 for _ in data1]) == 1
|
|
|
|
|
filename = "batch_06_result.npz"
|
|
|
|
|
save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
|
|
|
|
|
|
|
|
|
@ -152,6 +158,7 @@ def test_batch_07():
|
|
|
|
|
data1 = data1.batch(num_parallel_workers=num_parallel_workers, drop_remainder=drop_remainder,
|
|
|
|
|
batch_size=batch_size)
|
|
|
|
|
|
|
|
|
|
assert sum([1 for _ in data1]) == 3
|
|
|
|
|
filename = "batch_07_result.npz"
|
|
|
|
|
save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
|
|
|
|
|
|
|
|
|
@ -171,6 +178,7 @@ def test_batch_08():
|
|
|
|
|
data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
|
|
|
|
|
data1 = data1.batch(batch_size, num_parallel_workers=num_parallel_workers)
|
|
|
|
|
|
|
|
|
|
assert sum([1 for _ in data1]) == 2
|
|
|
|
|
filename = "batch_08_result.npz"
|
|
|
|
|
save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
|
|
|
|
|
|
|
|
|
@ -190,6 +198,7 @@ def test_batch_09():
|
|
|
|
|
data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
|
|
|
|
|
data1 = data1.batch(batch_size, drop_remainder=drop_remainder)
|
|
|
|
|
|
|
|
|
|
assert sum([1 for _ in data1]) == 1
|
|
|
|
|
filename = "batch_09_result.npz"
|
|
|
|
|
save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
|
|
|
|
|
|
|
|
|
@ -209,6 +218,7 @@ def test_batch_10():
|
|
|
|
|
data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
|
|
|
|
|
data1 = data1.batch(batch_size, drop_remainder=drop_remainder)
|
|
|
|
|
|
|
|
|
|
assert sum([1 for _ in data1]) == 0
|
|
|
|
|
filename = "batch_10_result.npz"
|
|
|
|
|
save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
|
|
|
|
|
|
|
|
|
@ -228,10 +238,30 @@ def test_batch_11():
|
|
|
|
|
data1 = ds.TFRecordDataset(DATA_DIR, schema_file)
|
|
|
|
|
data1 = data1.batch(batch_size)
|
|
|
|
|
|
|
|
|
|
assert sum([1 for _ in data1]) == 1
|
|
|
|
|
filename = "batch_11_result.npz"
|
|
|
|
|
save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_batch_12():
|
|
|
|
|
"""
|
|
|
|
|
Test batch: batch_size boolean value True, treated as valid value 1
|
|
|
|
|
"""
|
|
|
|
|
logger.info("test_batch_12")
|
|
|
|
|
# define parameters
|
|
|
|
|
batch_size = True
|
|
|
|
|
parameters = {"params": {'batch_size': batch_size}}
|
|
|
|
|
|
|
|
|
|
# apply dataset operations
|
|
|
|
|
data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
|
|
|
|
|
data1 = data1.batch(batch_size=batch_size)
|
|
|
|
|
|
|
|
|
|
assert sum([1 for _ in data1]) == 12
|
|
|
|
|
filename = "batch_12_result.npz"
|
|
|
|
|
save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_batch_exception_01():
|
|
|
|
|
"""
|
|
|
|
|
Test batch exception: num_parallel_workers=0
|
|
|
|
@ -302,7 +332,7 @@ def test_batch_exception_04():
|
|
|
|
|
|
|
|
|
|
def test_batch_exception_05():
|
|
|
|
|
"""
|
|
|
|
|
Test batch exception: batch_size wrong type, boolean value False
|
|
|
|
|
Test batch exception: batch_size boolean value False, treated as invalid value 0
|
|
|
|
|
"""
|
|
|
|
|
logger.info("test_batch_exception_05")
|
|
|
|
|
|
|
|
|
@ -317,23 +347,6 @@ def test_batch_exception_05():
|
|
|
|
|
assert "batch_size" in str(e)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def skip_test_batch_exception_06():
|
|
|
|
|
"""
|
|
|
|
|
Test batch exception: batch_size wrong type, boolean value True
|
|
|
|
|
"""
|
|
|
|
|
logger.info("test_batch_exception_06")
|
|
|
|
|
|
|
|
|
|
# apply dataset operations
|
|
|
|
|
data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
|
|
|
|
|
try:
|
|
|
|
|
data1 = data1.batch(batch_size=True)
|
|
|
|
|
sum([1 for _ in data1])
|
|
|
|
|
|
|
|
|
|
except BaseException as e:
|
|
|
|
|
logger.info("Got an exception in DE: {}".format(str(e)))
|
|
|
|
|
assert "batch_size" in str(e)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_batch_exception_07():
|
|
|
|
|
"""
|
|
|
|
|
Test batch exception: drop_remainder wrong type
|
|
|
|
@ -473,12 +486,12 @@ if __name__ == '__main__':
|
|
|
|
|
test_batch_09()
|
|
|
|
|
test_batch_10()
|
|
|
|
|
test_batch_11()
|
|
|
|
|
test_batch_12()
|
|
|
|
|
test_batch_exception_01()
|
|
|
|
|
test_batch_exception_02()
|
|
|
|
|
test_batch_exception_03()
|
|
|
|
|
test_batch_exception_04()
|
|
|
|
|
test_batch_exception_05()
|
|
|
|
|
skip_test_batch_exception_06()
|
|
|
|
|
test_batch_exception_07()
|
|
|
|
|
test_batch_exception_08()
|
|
|
|
|
test_batch_exception_09()
|
|
|
|
|