|
|
|
|
@ -988,7 +988,7 @@ class Dataset:
|
|
|
|
|
return dataset
|
|
|
|
|
|
|
|
|
|
@check_device_send
|
|
|
|
|
def device_que(self, prefetch_size=None, send_epoch_end=True):
|
|
|
|
|
def device_que(self, prefetch_size=None, send_epoch_end=True, create_data_info_queue=False):
|
|
|
|
|
"""
|
|
|
|
|
Return a transferred Dataset that transfers data through a device.
|
|
|
|
|
|
|
|
|
|
@ -996,6 +996,8 @@ class Dataset:
|
|
|
|
|
prefetch_size (int, optional): Prefetch number of records ahead of the
|
|
|
|
|
user's request (default=None).
|
|
|
|
|
send_epoch_end (bool, optional): Whether to send end of sequence to device or not (default=True).
|
|
|
|
|
create_data_info_queue (bool, optional): Whether to create queue which stores
|
|
|
|
|
types and shapes of data or not(default=False).
|
|
|
|
|
|
|
|
|
|
Note:
|
|
|
|
|
If device is Ascend, features of data will be transferred one by one. The limitation
|
|
|
|
|
@ -1004,15 +1006,17 @@ class Dataset:
|
|
|
|
|
Return:
|
|
|
|
|
TransferDataset, dataset for transferring.
|
|
|
|
|
"""
|
|
|
|
|
return self.to_device(send_epoch_end=send_epoch_end)
|
|
|
|
|
return self.to_device(send_epoch_end=send_epoch_end, create_data_info_queue=create_data_info_queue)
|
|
|
|
|
|
|
|
|
|
@check_device_send
|
|
|
|
|
def to_device(self, send_epoch_end=True):
|
|
|
|
|
def to_device(self, send_epoch_end=True, create_data_info_queue=False):
|
|
|
|
|
"""
|
|
|
|
|
Transfer data through CPU, GPU or Ascend devices.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
send_epoch_end (bool, optional): Whether to send end of sequence to device or not (default=True).
|
|
|
|
|
create_data_info_queue (bool, optional): Whether to create queue which stores
|
|
|
|
|
types and shapes of data or not(default=False).
|
|
|
|
|
|
|
|
|
|
Note:
|
|
|
|
|
If device is Ascend, features of data will be transferred one by one. The limitation
|
|
|
|
|
@ -1061,7 +1065,7 @@ class Dataset:
|
|
|
|
|
|
|
|
|
|
distribution_path, device_id = get_distribution(self)
|
|
|
|
|
if distribution_path == "":
|
|
|
|
|
return TransferDataset(self, queue_name, device_id, device_type, send_epoch_end)
|
|
|
|
|
return TransferDataset(self, queue_name, device_id, device_type, send_epoch_end, create_data_info_queue)
|
|
|
|
|
try:
|
|
|
|
|
with open(distribution_path, 'r') as distribution_f:
|
|
|
|
|
dist = json.load(distribution_f)
|
|
|
|
|
@ -1071,7 +1075,7 @@ class Dataset:
|
|
|
|
|
except Exception:
|
|
|
|
|
raise RuntimeError("Distribution file failed to read")
|
|
|
|
|
|
|
|
|
|
return TransferDataset(self, queue_name, device_id, device_type, send_epoch_end)
|
|
|
|
|
return TransferDataset(self, queue_name, device_id, device_type, send_epoch_end, create_data_info_queue)
|
|
|
|
|
|
|
|
|
|
@check_save
|
|
|
|
|
def save(self, file_name, num_files=1, file_type='mindrecord'):
|
|
|
|
|
@ -1775,6 +1779,25 @@ class BatchDataset(DatasetOp):
|
|
|
|
|
for input_dataset in dataset.children:
|
|
|
|
|
BatchDataset._update_batch_size_for_syncwait(input_dataset, batch_size)
|
|
|
|
|
|
|
|
|
|
def __deepcopy__(self, memodict):
|
|
|
|
|
if id(self) in memodict:
|
|
|
|
|
return memodict[id(self)]
|
|
|
|
|
cls = self.__class__
|
|
|
|
|
new_op = cls.__new__(cls)
|
|
|
|
|
memodict[id(self)] = new_op
|
|
|
|
|
new_op.children = copy.deepcopy(self.children, memodict)
|
|
|
|
|
new_op.parent = copy.deepcopy(self.parent, memodict)
|
|
|
|
|
new_op.num_parallel_workers = self.num_parallel_workers
|
|
|
|
|
new_op.batch_size = self.batch_size
|
|
|
|
|
new_op.drop_remainder = self.drop_remainder
|
|
|
|
|
new_op.per_batch_map = self.per_batch_map
|
|
|
|
|
new_op.input_columns = copy.deepcopy(self.input_columns, memodict)
|
|
|
|
|
new_op.output_columns = copy.deepcopy(self.output_columns, memodict)
|
|
|
|
|
new_op.column_order = copy.deepcopy(self.column_order, memodict)
|
|
|
|
|
new_op.pad_info = copy.deepcopy(self.pad_info, memodict)
|
|
|
|
|
new_op._input_indexs = self._input_indexs # pylint: disable=W0212
|
|
|
|
|
return new_op
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class BatchInfo(CBatchInfo):
|
|
|
|
|
"""
|
|
|
|
|
@ -2600,9 +2623,12 @@ class TransferDataset(DatasetOp):
|
|
|
|
|
device_id (int): ID of device.
|
|
|
|
|
device_type (str): Type of device, including "CPU", "GPU", and "Ascend".
|
|
|
|
|
send_epoch_end (bool, optional): Whether to send end of sequence to device or not (default=True).
|
|
|
|
|
create_data_info_queue (bool, optional): Whether to create queue which stores
|
|
|
|
|
types and shapes of data or not(default=False).
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def __init__(self, input_dataset, queue_name, device_id, device_type, send_epoch_end=True):
|
|
|
|
|
def __init__(self, input_dataset, queue_name, device_id, device_type, send_epoch_end=True,
|
|
|
|
|
create_data_info_queue=False):
|
|
|
|
|
super().__init__()
|
|
|
|
|
self.children.append(input_dataset)
|
|
|
|
|
input_dataset.parent.append(self)
|
|
|
|
|
@ -2612,6 +2638,7 @@ class TransferDataset(DatasetOp):
|
|
|
|
|
self._device_id = device_id
|
|
|
|
|
self._send_epoch_end = send_epoch_end
|
|
|
|
|
self.iterator = None
|
|
|
|
|
self._create_data_info_queue = create_data_info_queue
|
|
|
|
|
|
|
|
|
|
def get_args(self):
|
|
|
|
|
args = super().get_args()
|
|
|
|
|
@ -2619,6 +2646,7 @@ class TransferDataset(DatasetOp):
|
|
|
|
|
args["device_type"] = self._device_type
|
|
|
|
|
args["device_id"] = self._device_id
|
|
|
|
|
args["send_epoch_end"] = self._send_epoch_end
|
|
|
|
|
args["create_data_info_queue"] = self._create_data_info_queue
|
|
|
|
|
return args
|
|
|
|
|
|
|
|
|
|
def create_dict_iterator(self, num_epochs=-1, output_numpy=False):
|
|
|
|
|
@ -2650,6 +2678,27 @@ class TransferDataset(DatasetOp):
|
|
|
|
|
def continue_send(self):
|
|
|
|
|
self.iterator.depipeline.ContinueSend()
|
|
|
|
|
|
|
|
|
|
def get_data_info(self):
|
|
|
|
|
return self.iterator.depipeline.GetDataInfo()
|
|
|
|
|
|
|
|
|
|
def __deepcopy__(self, memodict):
|
|
|
|
|
if id(self) in memodict:
|
|
|
|
|
return memodict[id(self)]
|
|
|
|
|
cls = self.__class__
|
|
|
|
|
new_op = cls.__new__(cls)
|
|
|
|
|
memodict[id(self)] = new_op
|
|
|
|
|
new_op.children = copy.deepcopy(self.children, memodict)
|
|
|
|
|
new_op.parent = copy.deepcopy(self.parent, memodict)
|
|
|
|
|
new_op.num_parallel_workers = self.num_parallel_workers
|
|
|
|
|
new_op.queue_name = self.queue_name
|
|
|
|
|
new_op._device_type = self._device_type # pylint: disable=W0212
|
|
|
|
|
new_op._device_id = self._device_id # pylint: disable=W0212
|
|
|
|
|
new_op._input_indexs = self._input_indexs # pylint: disable=W0212
|
|
|
|
|
new_op._send_epoch_end = self._send_epoch_end # pylint: disable=W0212
|
|
|
|
|
new_op._create_data_info_queue = self._create_data_info_queue # pylint: disable=W0212
|
|
|
|
|
|
|
|
|
|
return new_op
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class RangeDataset(MappableDataset):
|
|
|
|
|
"""
|
|
|
|
|
|