Added example for multiple iterator

Added new testcase for multi iterator

Addressing review

Fixed typo
pull/671/head
eric 5 years ago
parent ebc3f12b21
commit 2d115cd04e

@ -65,8 +65,8 @@ Status BarrierOp::operator()() {
TaskManager::FindMe()->Post();
// create child iterator, right now this barrier is a pipeline operator
int32_t worker_id = 0;
int32_t child_idx = 0;
const int32_t worker_id = 0;
const int32_t child_idx = 0;
child_iterator_ = std::make_unique<ChildIterator>(this, worker_id, child_idx);
// Loop until eof is true

@ -920,7 +920,7 @@ class Dataset:
def sync_update(self, condition_name, num_batch=None, data=None):
"""
condition_name (str): The condition name that is used to toggle sending next row
step_size (int or None): The number of steps(rows) that are released
num_batch (int or None): The number of batches(rows) that are released
when pass_rows is None, will update the same number as sync_wait specified
data (dict or None): The data passed to the callback
"""

@ -107,6 +107,7 @@ def test_two_sync():
if count % 2 == 0:
dataset.sync_update(condition_name="every 2 batches")
def test_sync_epoch():
"""
Test sync wait with epochs: test sync with epochs in dataset pipeline
@ -130,6 +131,34 @@ def test_sync_epoch():
dataset.sync_update(condition_name="policy", data=data)
def test_multiple_iterators():
"""
Test sync wait with multiple iterators: will start multiple
"""
logger.info("test_sync_epoch")
batch_size = 30
dataset = ds.GeneratorDataset(gen, column_names=["input"])
aug = Augment(0)
dataset = dataset.sync_wait(condition_name="policy", callback=aug.update)
dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess])
dataset = dataset.batch(batch_size, drop_remainder=True)
# 2nd dataset
dataset2 = ds.GeneratorDataset(gen, column_names=["input"])
aug = Augment(0)
dataset2 = dataset2.sync_wait(condition_name="policy", callback=aug.update)
dataset2 = dataset2.map(input_columns=["input"], operations=[aug.preprocess])
dataset2 = dataset2.batch(batch_size, drop_remainder=True)
for item1, item2 in zip(dataset.create_dict_iterator(), dataset2.create_dict_iterator()):
assert (item1["input"][0] == item2["input"][0])
data1 = {"loss": item1["input"][0]}
data2 = {"loss": item2["input"][0]}
dataset.sync_update(condition_name="policy", data=data1)
dataset2.sync_update(condition_name="policy", data=data2)
def test_sync_exception_01():
"""
Test sync: with shuffle in sync mode
@ -179,4 +208,5 @@ if __name__ == "__main__":
test_two_sync()
test_sync_exception_01()
test_sync_exception_02()
test_sync_epoch()
test_sync_epoch()
test_multiple_iterators()

Loading…
Cancel
Save