|
|
|
@ -1889,6 +1889,26 @@ class BatchDataset(Dataset):
|
|
|
|
|
for input_dataset in dataset.children:
|
|
|
|
|
BatchDataset._update_batch_size_for_syncwait(input_dataset, batch_size)
|
|
|
|
|
|
|
|
|
|
def __deepcopy__(self, memodict):
|
|
|
|
|
if id(self) in memodict:
|
|
|
|
|
return memodict[id(self)]
|
|
|
|
|
cls = self.__class__
|
|
|
|
|
new_op = cls.__new__(cls)
|
|
|
|
|
memodict[id(self)] = new_op
|
|
|
|
|
new_op.children = copy.deepcopy(self.children, memodict)
|
|
|
|
|
new_op.parent = copy.deepcopy(self.parent, memodict)
|
|
|
|
|
new_op.num_parallel_workers = self.num_parallel_workers
|
|
|
|
|
new_op.batch_size = self.batch_size
|
|
|
|
|
new_op.batch_size_func = self.batch_size_func
|
|
|
|
|
new_op.drop_remainder = self.drop_remainder
|
|
|
|
|
new_op.per_batch_map = self.per_batch_map
|
|
|
|
|
new_op.input_columns = copy.deepcopy(self.input_columns, memodict)
|
|
|
|
|
new_op.output_columns = copy.deepcopy(self.output_columns, memodict)
|
|
|
|
|
new_op.column_order = copy.deepcopy(self.column_order, memodict)
|
|
|
|
|
new_op.pad = self.pad
|
|
|
|
|
new_op.pad_info = copy.deepcopy(self.pad_info, memodict)
|
|
|
|
|
return new_op
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class BatchInfo(cde.CBatchInfo):
|
|
|
|
|
"""
|
|
|
|
@ -2753,6 +2773,22 @@ class TransferDataset(Dataset):
|
|
|
|
|
if self._to_device is not None:
|
|
|
|
|
self._to_device.release()
|
|
|
|
|
|
|
|
|
|
def __deepcopy__(self, memodict):
|
|
|
|
|
if id(self) in memodict:
|
|
|
|
|
return memodict[id(self)]
|
|
|
|
|
cls = self.__class__
|
|
|
|
|
new_op = cls.__new__(cls)
|
|
|
|
|
memodict[id(self)] = new_op
|
|
|
|
|
new_op.children = copy.deepcopy(self.children, memodict)
|
|
|
|
|
new_op.parent = copy.deepcopy(self.parent, memodict)
|
|
|
|
|
new_op.num_parallel_workers = self.num_parallel_workers
|
|
|
|
|
new_op.queue_name = self.queue_name
|
|
|
|
|
new_op.device_type = self.device_type
|
|
|
|
|
new_op._send_epoch_end = self._send_epoch_end # pylint: disable=W0212
|
|
|
|
|
new_op._create_data_info_queue = self._create_data_info_queue # pylint: disable=W0212
|
|
|
|
|
|
|
|
|
|
return new_op
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class RangeDataset(MappableDataset):
|
|
|
|
|
"""
|
|
|
|
|