diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py index b156ac0003..c43fe8e611 100644 --- a/mindspore/dataset/engine/datasets.py +++ b/mindspore/dataset/engine/datasets.py @@ -4940,6 +4940,12 @@ class CelebADataset(MappableDataset): self.shard_id = shard_id self.shuffle_level = shuffle + if usage != "all": + dir = os.path.realpath(self.dataset_dir) + partition_file = os.path.join(dir, "list_eval_partition.txt") + if os.path.exists(partition_file) is False: + raise RuntimeError("Partition file can not be found when usage is not 'all'.") + def get_args(self): args = super().get_args() args["dataset_dir"] = self.dataset_dir