From 59a714c654dbbab090cc8af0012a6b1fc1e4e3a1 Mon Sep 17 00:00:00 2001 From: Cathy Wong Date: Thu, 9 Apr 2020 15:22:33 -0400 Subject: [PATCH] Correct shuffle UT buffer_size > #dataset-row as valid --- .../data/dataset/golden/shuffle_05_result.npz | Bin 0 -> 1507 bytes tests/ut/python/dataset/test_shuffle.py | 39 +++++++++--------- 2 files changed, 20 insertions(+), 19 deletions(-) create mode 100644 tests/ut/data/dataset/golden/shuffle_05_result.npz diff --git a/tests/ut/data/dataset/golden/shuffle_05_result.npz b/tests/ut/data/dataset/golden/shuffle_05_result.npz new file mode 100644 index 0000000000000000000000000000000000000000..27eb0a470d370fab9d4938e5fb9e3eb50d603bd1 GIT binary patch literal 1507 zcmbW1OH&hB6vw+00!c(c4LUmWtAJtD)+OVZ!lkD_Ooy{uH(5UT@6!HZoH}d``mL*e)pVH>6?PTC89M~ul9TG z?dO$%rm1g8i)dNL$qdDd<(jT_vQs6hpjvN}i*u=EU3;$m94*+Dyc3-;qA%v7F(bNK za;i?Ym?=30Te)Ym8x@=Um9=cyro9-C8N(w3G2<8GyH~jG&lk5g%e8pEZTT|wXz+S zZA(6CbAC()p+SEwLklu2;9t_o`$FzYCXf0b zj85?;sAfj14xJo(1@>V-8T$kd;2;_M1@@qWjEF!B0%UXvv|~3J2Lyr$k#SHUf-W)+ z3ACb(jKcz9>>{IEz(6nhaDp~R7~!TI{mSkI>AIpj#sBTIS#Pf0IggLD*Yq+`gFj^hdG1Z>hvct&~|1=1^6CB15y z{nj4_{h??xX-$6dHm3F0=SJi2X5^Y>M*scfZmwI_c4W#j`#${nxS2L>nE?+lW0^hw zyfm7TS(izJLF*8M$GZ9bVc^%k1_d2?VJqKqtMG(CUo_8FL7cCg&Qz(qkTN zq$&7FZ=;oT0WG9=5GK6~Kj|U@r1zM0pIJ)@I=N2HWoA84Z=Ry9Fzq4J(oB2Av<%Z8 zGcC)s9MkelD{Q}##IPB;`X;DeXyDERm#Z%CI5&yUDWg3FJ>gp=mzNJ$Lv?_;ApHX_g H`G6I{m literal 0 HcmV?d00001 diff --git a/tests/ut/python/dataset/test_shuffle.py b/tests/ut/python/dataset/test_shuffle.py index 2b7a251d2c..4a823c5fb7 100644 --- a/tests/ut/python/dataset/test_shuffle.py +++ b/tests/ut/python/dataset/test_shuffle.py @@ -98,6 +98,25 @@ def test_shuffle_04(): save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN) +def test_shuffle_05(): + """ + Test shuffle: buffer_size > number-of-rows-in-dataset + """ + logger.info("test_shuffle_05") + # define parameters + buffer_size = 13 + seed = 1 + parameters = {"params": {'buffer_size': buffer_size, "seed": seed}} + + # apply dataset operations + data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES) + ds.config.set_seed(seed) + data1 = data1.shuffle(buffer_size=buffer_size) + + filename = "shuffle_05_result.npz" + save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN) + + def test_shuffle_exception_01(): """ Test shuffle exception: buffer_size<0 @@ -152,24 +171,6 @@ def test_shuffle_exception_03(): assert "buffer_size" in str(e) -def test_shuffle_exception_04(): - """ - Test shuffle exception: buffer_size > number-of-rows-in-dataset - """ - logger.info("test_shuffle_exception_04") - - # apply dataset operations - data1 = ds.TFRecordDataset(DATA_DIR) - ds.config.set_seed(1) - try: - data1 = data1.shuffle(buffer_size=13) - sum([1 for _ in data1]) - - except BaseException as e: - logger.info("Got an exception in DE: {}".format(str(e))) - assert "buffer_size" in str(e) - - def test_shuffle_exception_05(): """ Test shuffle exception: Missing mandatory buffer_size input parameter @@ -229,10 +230,10 @@ if __name__ == '__main__': test_shuffle_02() test_shuffle_03() test_shuffle_04() + test_shuffle_05() test_shuffle_exception_01() test_shuffle_exception_02() test_shuffle_exception_03() - test_shuffle_exception_04() test_shuffle_exception_05() test_shuffle_exception_06() test_shuffle_exception_07()