!9053 [MD] fix bug in dataset deepcopy

From: @liyong126
Reviewed-by: @liucunwei,@heleiwang
Signed-off-by: @liucunwei
pull/9053/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 2ac9eafebb

@ -1889,6 +1889,26 @@ class BatchDataset(Dataset):
for input_dataset in dataset.children:
BatchDataset._update_batch_size_for_syncwait(input_dataset, batch_size)
def __deepcopy__(self, memodict):
if id(self) in memodict:
return memodict[id(self)]
cls = self.__class__
new_op = cls.__new__(cls)
memodict[id(self)] = new_op
new_op.children = copy.deepcopy(self.children, memodict)
new_op.parent = copy.deepcopy(self.parent, memodict)
new_op.num_parallel_workers = self.num_parallel_workers
new_op.batch_size = self.batch_size
new_op.batch_size_func = self.batch_size_func
new_op.drop_remainder = self.drop_remainder
new_op.per_batch_map = self.per_batch_map
new_op.input_columns = copy.deepcopy(self.input_columns, memodict)
new_op.output_columns = copy.deepcopy(self.output_columns, memodict)
new_op.column_order = copy.deepcopy(self.column_order, memodict)
new_op.pad = self.pad
new_op.pad_info = copy.deepcopy(self.pad_info, memodict)
return new_op
class BatchInfo(cde.CBatchInfo):
"""
@ -2753,6 +2773,22 @@ class TransferDataset(Dataset):
if self._to_device is not None:
self._to_device.release()
def __deepcopy__(self, memodict):
if id(self) in memodict:
return memodict[id(self)]
cls = self.__class__
new_op = cls.__new__(cls)
memodict[id(self)] = new_op
new_op.children = copy.deepcopy(self.children, memodict)
new_op.parent = copy.deepcopy(self.parent, memodict)
new_op.num_parallel_workers = self.num_parallel_workers
new_op.queue_name = self.queue_name
new_op.device_type = self.device_type
new_op._send_epoch_end = self._send_epoch_end # pylint: disable=W0212
new_op._create_data_info_queue = self._create_data_info_queue # pylint: disable=W0212
return new_op
class RangeDataset(MappableDataset):
"""

@ -312,6 +312,7 @@ class PKSampler(BuiltinSampler):
Args:
num_val (int): Number of elements to sample for each class.
num_class (int, optional): Number of classes to sample (default=None, all classes).
The parameter does not supported to specify currently.
shuffle (bool, optional): If True, the class IDs are shuffled (default=False).
class_column (str, optional): Name of column with class labels for MindDataset (default='label').
num_samples (int, optional): The number of samples to draw (default=None, all elements).

Loading…
Cancel
Save