|
|
|
@ -27,7 +27,7 @@ def test_imagefolder_shardings(print_res=False):
|
|
|
|
|
res = []
|
|
|
|
|
for item in data1.create_dict_iterator(): # each data is a dictionary
|
|
|
|
|
res.append(item["label"].item())
|
|
|
|
|
if (print_res):
|
|
|
|
|
if print_res:
|
|
|
|
|
logger.info("labels of dataset: {}".format(res))
|
|
|
|
|
return res
|
|
|
|
|
|
|
|
|
@ -39,12 +39,12 @@ def test_imagefolder_shardings(print_res=False):
|
|
|
|
|
assert (sharding_config(2, 0, 55, False, dict()) == [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3]) # 22 rows
|
|
|
|
|
assert (sharding_config(2, 1, 55, False, dict()) == [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3]) # 22 rows
|
|
|
|
|
# the dataset has 22 rows in total because class indexing takes only 2 folders
|
|
|
|
|
assert (len(sharding_config(4, 0, None, True, {"class1": 111, "class2": 999})) == 6)
|
|
|
|
|
assert (len(sharding_config(4, 2, 3, True, {"class1": 111, "class2": 999})) == 3)
|
|
|
|
|
assert len(sharding_config(4, 0, None, True, {"class1": 111, "class2": 999})) == 6
|
|
|
|
|
assert len(sharding_config(4, 2, 3, True, {"class1": 111, "class2": 999})) == 3
|
|
|
|
|
# test with repeat
|
|
|
|
|
assert (sharding_config(4, 0, 12, False, dict(), 3) == [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3] * 3)
|
|
|
|
|
assert (sharding_config(4, 0, 5, False, dict(), 5) == [0, 0, 0, 1, 1] * 5)
|
|
|
|
|
assert (len(sharding_config(5, 1, None, True, {"class1": 111, "class2": 999}, 4)) == 20)
|
|
|
|
|
assert len(sharding_config(5, 1, None, True, {"class1": 111, "class2": 999}, 4)) == 20
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_tfrecord_shardings1(print_res=False):
|
|
|
|
@ -176,8 +176,8 @@ def test_voc_shardings(print_res=False):
|
|
|
|
|
# then takes the first 2 because num_samples = 2
|
|
|
|
|
assert (sharding_config(3, 2, 2, False, 4) == [2268, 607] * 4)
|
|
|
|
|
# test that each epoch, each shard_worker returns a different sample
|
|
|
|
|
assert (len(sharding_config(2, 0, None, True, 1)) == 5)
|
|
|
|
|
assert (len(set(sharding_config(11, 0, None, True, 10))) > 1)
|
|
|
|
|
assert len(sharding_config(2, 0, None, True, 1)) == 5
|
|
|
|
|
assert len(set(sharding_config(11, 0, None, True, 10))) > 1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_cifar10_shardings(print_res=False):
|
|
|
|
@ -196,8 +196,8 @@ def test_cifar10_shardings(print_res=False):
|
|
|
|
|
|
|
|
|
|
# 60000 rows in total. CIFAR reads everything in memory which would make each test case very slow
|
|
|
|
|
# therefore, only 2 test cases for now.
|
|
|
|
|
assert (sharding_config(10000, 9999, 7, False, 1) == [9])
|
|
|
|
|
assert (sharding_config(10000, 0, 4, False, 3) == [0, 0, 0])
|
|
|
|
|
assert sharding_config(10000, 9999, 7, False, 1) == [9]
|
|
|
|
|
assert sharding_config(10000, 0, 4, False, 3) == [0, 0, 0]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_cifar100_shardings(print_res=False):
|
|
|
|
|