|
|
|
@ -107,6 +107,7 @@ def test_two_sync():
|
|
|
|
|
if count % 2 == 0:
|
|
|
|
|
dataset.sync_update(condition_name="every 2 batches")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_sync_epoch():
|
|
|
|
|
"""
|
|
|
|
|
Test sync wait with epochs: test sync with epochs in dataset pipeline
|
|
|
|
@ -130,6 +131,34 @@ def test_sync_epoch():
|
|
|
|
|
dataset.sync_update(condition_name="policy", data=data)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_multiple_iterators():
|
|
|
|
|
"""
|
|
|
|
|
Test sync wait with multiple iterators: will start multiple
|
|
|
|
|
"""
|
|
|
|
|
logger.info("test_sync_epoch")
|
|
|
|
|
batch_size = 30
|
|
|
|
|
dataset = ds.GeneratorDataset(gen, column_names=["input"])
|
|
|
|
|
|
|
|
|
|
aug = Augment(0)
|
|
|
|
|
dataset = dataset.sync_wait(condition_name="policy", callback=aug.update)
|
|
|
|
|
dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess])
|
|
|
|
|
dataset = dataset.batch(batch_size, drop_remainder=True)
|
|
|
|
|
# 2nd dataset
|
|
|
|
|
dataset2 = ds.GeneratorDataset(gen, column_names=["input"])
|
|
|
|
|
|
|
|
|
|
aug = Augment(0)
|
|
|
|
|
dataset2 = dataset2.sync_wait(condition_name="policy", callback=aug.update)
|
|
|
|
|
dataset2 = dataset2.map(input_columns=["input"], operations=[aug.preprocess])
|
|
|
|
|
dataset2 = dataset2.batch(batch_size, drop_remainder=True)
|
|
|
|
|
|
|
|
|
|
for item1, item2 in zip(dataset.create_dict_iterator(), dataset2.create_dict_iterator()):
|
|
|
|
|
assert (item1["input"][0] == item2["input"][0])
|
|
|
|
|
data1 = {"loss": item1["input"][0]}
|
|
|
|
|
data2 = {"loss": item2["input"][0]}
|
|
|
|
|
dataset.sync_update(condition_name="policy", data=data1)
|
|
|
|
|
dataset2.sync_update(condition_name="policy", data=data2)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_sync_exception_01():
|
|
|
|
|
"""
|
|
|
|
|
Test sync: with shuffle in sync mode
|
|
|
|
@ -179,4 +208,5 @@ if __name__ == "__main__":
|
|
|
|
|
test_two_sync()
|
|
|
|
|
test_sync_exception_01()
|
|
|
|
|
test_sync_exception_02()
|
|
|
|
|
test_sync_epoch()
|
|
|
|
|
test_sync_epoch()
|
|
|
|
|
test_multiple_iterators()
|
|
|
|
|