|
|
|
@ -2671,20 +2671,30 @@ class MindDataset(SourceDataset):
|
|
|
|
|
Return:
|
|
|
|
|
Number, number of batches.
|
|
|
|
|
"""
|
|
|
|
|
if self.load_dataset:
|
|
|
|
|
dataset_file = [self.dataset_file]
|
|
|
|
|
else:
|
|
|
|
|
dataset_file = self.dataset_file
|
|
|
|
|
num_rows = MindRecordOp.get_num_rows(dataset_file, self.load_dataset, self.sampler, self.num_padded)
|
|
|
|
|
if self.partitions is not None and self.partitions[0] > 0:
|
|
|
|
|
if num_rows % self.partitions[0] == 0:
|
|
|
|
|
num_rows = num_rows // self.partitions[0]
|
|
|
|
|
if self._dataset_size is None:
|
|
|
|
|
if self.load_dataset:
|
|
|
|
|
dataset_file = [self.dataset_file]
|
|
|
|
|
else:
|
|
|
|
|
if self.num_padded > 0:
|
|
|
|
|
raise RuntimeError(
|
|
|
|
|
"Dataset size plus number of padded samples is not divisible by number of shards.")
|
|
|
|
|
num_rows = num_rows // self.partitions[0] + 1
|
|
|
|
|
return num_rows
|
|
|
|
|
dataset_file = self.dataset_file
|
|
|
|
|
num_rows = MindRecordOp.get_num_rows(dataset_file, self.load_dataset, self.sampler, self.num_padded)
|
|
|
|
|
if self.partitions is not None and self.partitions[0] > 0:
|
|
|
|
|
if num_rows % self.partitions[0] == 0:
|
|
|
|
|
num_rows = num_rows // self.partitions[0]
|
|
|
|
|
else:
|
|
|
|
|
if self.num_padded > 0:
|
|
|
|
|
raise RuntimeError(
|
|
|
|
|
"Dataset size plus number of padded samples is not divisible by number of shards.")
|
|
|
|
|
num_rows = num_rows // self.partitions[0] + 1
|
|
|
|
|
return num_rows
|
|
|
|
|
return self._dataset_size
|
|
|
|
|
|
|
|
|
|
# manually set dataset_size as a tempoary solution.
|
|
|
|
|
def set_dataset_size(self, value):
|
|
|
|
|
logger.warning("WARN_DEPRECATED: This method is deprecated. Please use get_dataset_size directly.")
|
|
|
|
|
if value >= 0:
|
|
|
|
|
self._dataset_size = value
|
|
|
|
|
else:
|
|
|
|
|
raise ValueError('set dataset_size with negative value {}'.format(value))
|
|
|
|
|
|
|
|
|
|
def is_shuffled(self):
|
|
|
|
|
if self.shuffle_option is None:
|
|
|
|
|