|
|
|
@ -1055,6 +1055,11 @@ class Dataset:
|
|
|
|
|
return self.input[0].get_sync_notifiers()
|
|
|
|
|
return {}
|
|
|
|
|
|
|
|
|
|
def disable_sync(self):
    """
    Forward the disable request to the first input of this node.

    Returns:
        Whatever the first input's disable_sync() returns, or an empty dict
        when this node has no input.
    """
    # Nothing upstream to forward to.
    if not self.input:
        return {}

    first_input = self.input[0]
    return first_input.disable_sync()
|
|
|
|
|
|
|
|
|
|
def is_sync(self):
|
|
|
|
|
if self.input:
|
|
|
|
|
return self.input[0].is_sync()
|
|
|
|
@ -1062,16 +1067,23 @@ class Dataset:
|
|
|
|
|
|
|
|
|
|
def sync_update(self, condition_name, num_batch=None, data=None):
|
|
|
|
|
"""
|
|
|
|
|
Release a blocking condition and triger callback with given data.
|
|
|
|
|
Release a blocking condition and trigger callback with given data.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
condition_name (str): The condition name that is used to toggle sending next row.
|
|
|
|
|
num_batch (int or None): The number of batches(rows) that are released.
|
|
|
|
|
When num_batch is None, it will default to the number specified by the sync_wait operator.
|
|
|
|
|
data (dict or None): The data passed to the callback.
|
|
|
|
|
"""
|
|
|
|
|
When num_batch is None, it will default to the number specified by the
|
|
|
|
|
sync_wait operator (default=None).
|
|
|
|
|
data (dict or None): The data passed to the callback (default=None).
|
|
|
|
|
"""
|
|
|
|
|
if isinstance(num_batch, int) and num_batch <= 0:
|
|
|
|
|
# throwing exception, disable all sync_wait in pipeline
|
|
|
|
|
self.disable_sync()
|
|
|
|
|
raise RuntimeError("Sync_update batch size can only be positive, got : {}".format(num_batch))
|
|
|
|
|
notifiers_dict = self.get_sync_notifiers()
|
|
|
|
|
if condition_name not in notifiers_dict:
|
|
|
|
|
# throwing exception, disable all sync_wait in pipeline
|
|
|
|
|
self.disable_sync()
|
|
|
|
|
raise RuntimeError("Condition name not found")
|
|
|
|
|
if num_batch is not None:
|
|
|
|
|
num_batch *= self.get_batch_size()
|
|
|
|
@ -1433,7 +1445,6 @@ class BatchDataset(DatasetOp):
|
|
|
|
|
for input_dataset in dataset.input:
|
|
|
|
|
BatchDataset._update_batch_size_for_syncwait(input_dataset, batch_size)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class BatchInfo(CBatchInfo):
|
|
|
|
|
"""
|
|
|
|
|
The information object associates with the current batch of tensors.
|
|
|
|
@ -1466,10 +1477,13 @@ class BlockReleasePair:
|
|
|
|
|
callback (function): The callback function that will be called when release is called.
|
|
|
|
|
"""
|
|
|
|
|
def __init__(self, init_release_rows, callback=None):
    """
    Set up the counters and the condition variable of the pair.

    Args:
        init_release_rows: Number of rows initially allowed through; stored
            negated in row_count and unchanged in default_rows.
        callback (function): Stored as-is on the instance (default=None).

    Raises:
        ValueError: If init_release_rows is an int that is not positive.
    """
    not_positive = isinstance(init_release_rows, int) and init_release_rows <= 0
    if not_positive:
        raise ValueError("release_rows need to be greater than 0.")

    # block_func lets a row through while row_count is negative, so the
    # counter starts at minus the number of rows to release.
    initial_count = -init_release_rows

    self.row_count = initial_count
    self.cv = threading.Condition()
    self.callback = callback
    self.default_rows = init_release_rows
    self.disable = False
|
|
|
|
|
|
|
|
|
|
def __deepcopy__(self, memodict):
|
|
|
|
|
if id(self) in memodict:
|
|
|
|
@ -1485,13 +1499,18 @@ class BlockReleasePair:
|
|
|
|
|
self.cv.notify_all()
|
|
|
|
|
|
|
|
|
|
def update_batched_size(self, batch_size):
    """
    Scale both row counters by batch_size.

    Args:
        batch_size: Factor applied to row_count and default_rows.

    Raises:
        ValueError: If batch_size is an int that is not positive.
    """
    # Sanity check: reject non-positive integer factors up front.
    invalid = isinstance(batch_size, int) and batch_size <= 0
    if invalid:
        raise ValueError("batch_size need to be greater than 0.")

    # Should only be used before the pipeline is created.
    self.row_count *= batch_size
    self.default_rows *= batch_size
|
|
|
|
|
|
|
|
|
|
def block_func(self):
    """
    Block the caller until a row may pass, then account for that row.

    The wait ends when row_count is negative or when the pair has been
    disabled. A single disable-aware wait is used: an additional wait on
    row_count alone would keep the caller blocked after disable_lock()
    and defeat the purpose of the disable flag.

    Returns:
        bool, always True once the wait is over.
    """
    with self.cv:
        # If disable is true, the predicate always evaluates to true, so a
        # disabled pair never blocks.
        self.cv.wait_for(lambda: self.row_count < 0 or self.disable)
        self.row_count += 1
    return True
|
|
|
|
|
|
|
|
|
@ -1504,6 +1523,12 @@ class BlockReleasePair:
|
|
|
|
|
self.callback(data)
|
|
|
|
|
self.cv.notify_all()
|
|
|
|
|
|
|
|
|
|
def disable_lock(self):
    """
    Mark the pair as disabled and wake every thread waiting on its
    condition variable.
    """
    condition = self.cv
    with condition:
        # Set the flag while holding the condition so waiters observe it
        # when they are notified.
        self.disable = True
        condition.notify_all()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SyncWaitDataset(DatasetOp):
|
|
|
|
|
"""
|
|
|
|
|
The result of adding a blocking condition to the input Dataset.
|
|
|
|
@ -1524,6 +1549,9 @@ class SyncWaitDataset(DatasetOp):
|
|
|
|
|
input_dataset.output.append(self)
|
|
|
|
|
# set to the default value, waiting for the batch to update it
|
|
|
|
|
self._condition_name = condition_name
|
|
|
|
|
if isinstance(num_batch, int) and num_batch <= 0:
|
|
|
|
|
raise ValueError("num_batch need to be greater than 0.")
|
|
|
|
|
|
|
|
|
|
self._pair = BlockReleasePair(num_batch, callback)
|
|
|
|
|
if self._condition_name in self.input[0].get_sync_notifiers():
|
|
|
|
|
raise RuntimeError("Condition name is already in use")
|
|
|
|
@ -1543,8 +1571,14 @@ class SyncWaitDataset(DatasetOp):
|
|
|
|
|
return args
|
|
|
|
|
|
|
|
|
|
def update_sync_batch_size(self, batch_size):
    """
    Validate batch_size and pass it on to the block/release pair.

    Args:
        batch_size: Value forwarded to the pair's update_batched_size().

    Raises:
        ValueError: If batch_size is an int that is not positive.
    """
    if isinstance(batch_size, int) and batch_size <= 0:
        # The message names the parameter actually being validated.
        raise ValueError("batch_size need to be greater than 0.")
    self._pair.update_batched_size(batch_size)
|
|
|
|
|
|
|
|
|
|
def disable_sync(self):
    """
    Log that synchronization is being disabled, then disable the lock of
    this node's block/release pair.
    """
    logger.info("Disabling Sync")
    release_pair = self._pair
    release_pair.disable_lock()
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def _is_ancestor_of_batch(dataset):
|
|
|
|
|
"""
|
|
|
|
|