|
|
@ -4009,6 +4009,31 @@ class CelebADataset(MappableDataset):
|
|
|
|
args["shard_id"] = self.shard_id
|
|
|
|
args["shard_id"] = self.shard_id
|
|
|
|
return args
|
|
|
|
return args
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_dataset_size(self):
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
Get the number of batches in an epoch.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Return:
|
|
|
|
|
|
|
|
Number, number of batches.
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
if self._dataset_size is None:
|
|
|
|
|
|
|
|
dir = os.path.realpath(self.dataset_dir)
|
|
|
|
|
|
|
|
attr_file = os.path.join(dir, "list_attr_celeba.txt")
|
|
|
|
|
|
|
|
num_rows = ''
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
|
|
|
with open(attr_file, 'r') as f:
|
|
|
|
|
|
|
|
num_rows = int(f.readline())
|
|
|
|
|
|
|
|
except Exception:
|
|
|
|
|
|
|
|
raise RuntimeError("Get dataset size failed from attribution file.")
|
|
|
|
|
|
|
|
rows_per_shard = get_num_rows(num_rows, self.num_shards)
|
|
|
|
|
|
|
|
if self.num_samples is not None:
|
|
|
|
|
|
|
|
rows_per_shard = min(self.num_samples, rows_per_shard)
|
|
|
|
|
|
|
|
rows_from_sampler = self._get_sampler_dataset_size()
|
|
|
|
|
|
|
|
if rows_from_sampler is None:
|
|
|
|
|
|
|
|
return rows_per_shard
|
|
|
|
|
|
|
|
return min(rows_from_sampler, rows_per_shard)
|
|
|
|
|
|
|
|
return self._dataset_size
|
|
|
|
|
|
|
|
|
|
|
|
def is_shuffled(self):
|
|
|
|
def is_shuffled(self):
|
|
|
|
if self.shuffle_level is None:
|
|
|
|
if self.shuffle_level is None:
|
|
|
|
return True
|
|
|
|
return True
|
|
|
|