diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py
index f3703b3850..ca88597713 100644
--- a/mindspore/dataset/engine/datasets.py
+++ b/mindspore/dataset/engine/datasets.py
@@ -4009,6 +4009,41 @@ class CelebADataset(MappableDataset):
         args["shard_id"] = self.shard_id
         return args
 
+    def get_dataset_size(self):
+        """
+        Get the number of batches in an epoch.
+
+        When no size is stored in self._dataset_size, the size is derived from
+        the integer on the first line of "list_attr_celeba.txt" in the dataset
+        directory, passed through get_num_rows together with num_shards, and
+        then capped by num_samples and by the sampler size when either is set.
+        The derived value is returned without being stored.
+
+        Return:
+            Number, number of batches.
+
+        Raises:
+            RuntimeError: If the attribute file cannot be opened or its first
+                line cannot be parsed as an integer.
+        """
+        if self._dataset_size is None:
+            dataset_dir = os.path.realpath(self.dataset_dir)
+            attr_file = os.path.join(dataset_dir, "list_attr_celeba.txt")
+            try:
+                with open(attr_file, 'r') as f:
+                    num_rows = int(f.readline())
+            except (OSError, ValueError) as err:
+                # Chain the cause so the underlying I/O or parse error stays visible.
+                raise RuntimeError("Get dataset size failed from attribution file.") from err
+            rows_per_shard = get_num_rows(num_rows, self.num_shards)
+            if self.num_samples is not None:
+                rows_per_shard = min(self.num_samples, rows_per_shard)
+            rows_from_sampler = self._get_sampler_dataset_size()
+            if rows_from_sampler is None:
+                return rows_per_shard
+            return min(rows_from_sampler, rows_per_shard)
+        return self._dataset_size
+
     def is_shuffled(self):
         if self.shuffle_level is None:
             return True
diff --git a/tests/ut/python/dataset/test_datasets_celeba.py b/tests/ut/python/dataset/test_datasets_celeba.py
index accf9730f3..26f18e8772 100644
--- a/tests/ut/python/dataset/test_datasets_celeba.py
+++ b/tests/ut/python/dataset/test_datasets_celeba.py
@@ -85,9 +85,14 @@ def test_celeba_dataset_distribute():
         count = count + 1
     assert count == 1
 
+def test_celeba_get_dataset_size():
+    data = ds.CelebADataset(DATA_DIR, decode=True, shuffle=False)
+    size = data.get_dataset_size()
+    assert size == 2
 
 if __name__ == '__main__':
     test_celeba_dataset_label()
     test_celeba_dataset_op()
     test_celeba_dataset_ext()
     test_celeba_dataset_distribute()
+    test_celeba_get_dataset_size()