|
|
@ -4940,6 +4940,12 @@ class CelebADataset(MappableDataset):
|
|
|
|
self.shard_id = shard_id
|
|
|
|
self.shard_id = shard_id
|
|
|
|
self.shuffle_level = shuffle
|
|
|
|
self.shuffle_level = shuffle
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if usage != "all":
|
|
|
|
|
|
|
|
dir = os.path.realpath(self.dataset_dir)
|
|
|
|
|
|
|
|
partition_file = os.path.join(dir, "list_eval_partition.txt")
|
|
|
|
|
|
|
|
if os.path.exists(partition_file) is False:
|
|
|
|
|
|
|
|
raise RuntimeError("Partition file can not be found when usage is not 'all'.")
|
|
|
|
|
|
|
|
|
|
|
|
def get_args(self):
|
|
|
|
def get_args(self):
|
|
|
|
args = super().get_args()
|
|
|
|
args = super().get_args()
|
|
|
|
args["dataset_dir"] = self.dataset_dir
|
|
|
|
args["dataset_dir"] = self.dataset_dir
|
|
|
|