@@ -158,22 +158,14 @@ class Dataset:
         if len(self.parent) > 1:
             raise ValueError("The data pipeline is not a tree (i.e., one node has 2 consumers)")
         ir_children = [d.parse_tree() for d in self.children]
+        # Bootstrap can only be performed on a copy of the original dataset node.
+        # Bootstrap on original dataset node will make all iterators share the same process pool
+        self.iterator_bootstrap()
         ir_node = self.parse(ir_children)
-        return self._alter_node(ir_node)
-
-    @staticmethod
-    def _alter_node(node):
-        """
-        Internal method to add process pool to copied map node.
-        Returns:
-            DatasetNode. The altered node.
-        """
-        if isinstance(node, MapDataset):
-            if node.python_multiprocessing:
-                # Bootstrap can only be performed on a copy of the original dataset node.
-                # Bootstrap on original dataset node will make all iterators share the same process pool
-                node.iterator_bootstrap()
-        return node
+        return ir_node
+
+    def iterator_bootstrap(self):
+        pass

     @staticmethod
     def _noop_mode():
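The comment in this hunk carries the key design constraint: bootstrapping the original node would make every iterator built from the pipeline share one process pool. A minimal sketch of the intended lifecycle, with the helper name `create_iterator` assumed purely for illustration:

    import copy

    # Hypothetical helper (name assumed): each iterator deep-copies the
    # pipeline first, so parse_tree() bootstraps the copy and each iterator
    # ends up with its own process pool instead of sharing one.
    def create_iterator(dataset):
        pipeline_copy = copy.deepcopy(dataset)  # copy before bootstrapping
        ir_root = pipeline_copy.parse_tree()    # parse_tree() calls iterator_bootstrap()
        return ir_root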
@@ -272,7 +264,7 @@ class Dataset:

     @check_batch
     def batch(self, batch_size, drop_remainder=False, num_parallel_workers=None, per_batch_map=None,
-              input_columns=None, output_columns=None, column_order=None, pad_info=None):
+              input_columns=None, output_columns=None, column_order=None, pad_info=None, python_multiprocessing=False):
         """
         Combine batch_size number of consecutive rows into batches.

@@ -312,6 +304,8 @@ class Dataset:
                 same).
             pad_info (dict, optional): Whether to perform padding on selected columns. pad_info={"col1":([224,224],0)}
                 would pad column with name "col1" to a tensor of size [224,224] and fill the missing with 0.
+            python_multiprocessing (bool, optional): Parallelize the Python function per_batch_map with multiple worker
+                processes. This option can be beneficial if the function is computationally heavy (default=False).

         Returns:
             BatchDataset, dataset batched.
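A usage sketch of the new flag (the `data` variable and the `np_resize` callable are assumed, mirroring the doctest already in this docstring; only `python_multiprocessing` is new):

    import numpy as np

    def np_resize(col, batch_info):
        # Stand-in for a CPU-heavy per-batch transform; returns a tuple of columns.
        return ([np.copy(x) for x in col],)

    # per_batch_map now runs in worker processes instead of the main process.
    data = data.batch(batch_size=8, input_columns=["image"], per_batch_map=np_resize,
                      python_multiprocessing=True, num_parallel_workers=4)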
@@ -339,7 +333,7 @@ class Dataset:
             >>> data = data.batch(batch_size=8, input_columns=["image"], per_batch_map=np_resize)
         """
         return BatchDataset(self, batch_size, drop_remainder, num_parallel_workers, per_batch_map, input_columns,
-                            output_columns, column_order, pad_info)
+                            output_columns, column_order, pad_info, python_multiprocessing)

     @check_sync_wait
     def sync_wait(self, condition_name, num_batch=1, callback=None):
@@ -1835,7 +1829,8 @@ class BatchDataset(Dataset):
     """

     def __init__(self, input_dataset, batch_size, drop_remainder=False, num_parallel_workers=None,
-                 per_batch_map=None, input_columns=None, output_columns=None, column_order=None, pad_info=None):
+                 per_batch_map=None, input_columns=None, output_columns=None, column_order=None, pad_info=None,
+                 python_multiprocessing=False):
         super().__init__(children=input_dataset, num_parallel_workers=num_parallel_workers)

         if BatchDataset._is_ancestor_of_repeat(input_dataset):
@@ -1858,6 +1853,10 @@ class BatchDataset(Dataset):
         self.pad = bool(pad_info is not None)
         self.pad_info = replace_none(pad_info, dict())

+        self.python_multiprocessing = python_multiprocessing
+        self.process_pool = None
+        self.hook = None
+
     def parse(self, children=None):
         return cde.BatchNode(children[0], self.batch_size, self.drop_remainder, self.pad,
                              self.input_columns, self.output_columns,
@@ -1923,9 +1922,32 @@ class BatchDataset(Dataset):
         new_op.output_columns = copy.deepcopy(self.output_columns, memodict)
         new_op.column_order = copy.deepcopy(self.column_order, memodict)
         new_op.pad = self.pad
+        new_op.python_multiprocessing = copy.deepcopy(self.python_multiprocessing, memodict)
+        new_op.hook = copy.deepcopy(self.hook, memodict)
         new_op.pad_info = copy.deepcopy(self.pad_info, memodict)
         return new_op

+    # Iterator bootstrap will be called on iterator construction.
+    # A deep copy of the Dataset object is created prior to iterator_bootstrap.
+    # This method will create a per-iterator process pool and bind pyfunc execution to the pool.
+    def iterator_bootstrap(self):
+        """
+        Per-iterator bootstrap callback.
+        """
+        if self.python_multiprocessing:
+            # Construct the pool with a callable list.
+            # The callable list and _pyfunc_worker_init are used to pass lambda functions into subprocesses.
+            self.process_pool = multiprocessing.Pool(processes=self.num_parallel_workers,
+                                                     initializer=_pyfunc_worker_init,
+                                                     initargs=([self.per_batch_map],))
+            idx = 0
+            # Wrap per_batch_map into a _PythonCallable bound to the pool.
+            self.per_batch_map = _PythonCallable(self.per_batch_map, idx, self.process_pool)
+
+    def __del__(self):
+        if hasattr(self, 'process_pool') and self.process_pool is not None:
+            self.process_pool.close()
+

 class BatchInfo(cde.CBatchInfo):
     """
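The `_PythonCallable` wrapper and `_pyfunc_worker_init` referenced above are defined elsewhere in datasets.py and are not part of this hunk. The sketch below is an assumption-laden outline of the dispatch pattern, not the real implementation; all names ending in `_sketch` are invented for illustration:

    import multiprocessing

    # Sketch only. Relies on the 'fork' start method so the callable list
    # reaches worker processes via inherited memory, without pickling lambdas.
    _worker_callables = []

    def _worker_init_sketch(callables):
        global _worker_callables       # runs once inside each worker process
        _worker_callables = callables

    def _worker_exec_sketch(idx, *args):
        return _worker_callables[idx](*args)   # only the index crosses the pipe

    class _PythonCallableSketch:
        """Wraps a Python callable so calls execute inside a process pool."""
        def __init__(self, py_callable, idx, pool):
            self.py_callable = py_callable
            self.idx = idx
            self.pool = pool

        def __call__(self, *args):
            return self.pool.apply_async(_worker_exec_sketch, [self.idx, *args]).get()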
@@ -2352,7 +2374,6 @@ class MapDataset(Dataset):
                 # CPP ops remain the same
                 iter_specific_operations.append(op)
         self.operations = iter_specific_operations
-        self.hook = _ExceptHookHandler(self.process_pool)

     def __del__(self):
         if hasattr(self, 'process_pool') and self.process_pool is not None:
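A note on the `hasattr` guard in both `__del__` methods: if `__init__` raises before `process_pool` is assigned, CPython still invokes `__del__` on the partially constructed object, so reading the attribute unguarded would raise `AttributeError` during collection. A small self-contained demonstration (class name invented):

    class _Sketch:
        def __init__(self, pool=None):
            if pool is None:
                raise RuntimeError("boom")  # process_pool never assigned
            self.process_pool = pool

        def __del__(self):
            # Without hasattr, this would raise AttributeError while the
            # half-built object is being collected.
            if hasattr(self, 'process_pool') and self.process_pool is not None:
                self.process_pool.close()

    try:
        _Sketch()
    except RuntimeError:
        pass  # __del__ still runs on the partially constructed instance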