|
|
|
@ -454,6 +454,21 @@ def test_clue_padded_and_skip_with_0_samples():
|
|
|
|
|
count += 1
|
|
|
|
|
assert count == 2
|
|
|
|
|
|
|
|
|
|
def test_celeba_padded():
|
|
|
|
|
data = ds.CelebADataset("../data/dataset/testCelebAData/")
|
|
|
|
|
|
|
|
|
|
padded_samples = [{'image': np.zeros(1, np.uint8), 'attr': np.zeros(1, np.uint32)}]
|
|
|
|
|
padded_ds = ds.PaddedDataset(padded_samples)
|
|
|
|
|
data = data + padded_ds
|
|
|
|
|
dis_sampler = ds.DistributedSampler(num_shards=2, shard_id=1, shuffle=False, num_samples=None)
|
|
|
|
|
data.use_sampler(dis_sampler)
|
|
|
|
|
data = data.repeat(2)
|
|
|
|
|
|
|
|
|
|
count = 0
|
|
|
|
|
for _ in data.create_dict_iterator():
|
|
|
|
|
count = count + 1
|
|
|
|
|
assert count == 2
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
|
|
test_TFRecord_Padded()
|
|
|
|
|
test_GeneratorDataSet_Padded()
|
|
|
|
|