|
|
@ -4935,7 +4935,7 @@ class CSVDataset(SourceDataset):
|
|
|
|
columns as string type.
|
|
|
|
columns as string type.
|
|
|
|
column_names (list[str], optional): List of column names of the dataset (default=None). If this
|
|
|
|
column_names (list[str], optional): List of column names of the dataset (default=None). If this
|
|
|
|
is not provided, infers the column_names from the first row of CSV file.
|
|
|
|
is not provided, infers the column_names from the first row of CSV file.
|
|
|
|
num_samples (int, optional): number of samples(rows) to read (default=None, reads the full dataset).
|
|
|
|
num_samples (int, optional): number of samples(rows) to read (default=-1, reads the full dataset).
|
|
|
|
num_parallel_workers (int, optional): number of workers to read the data
|
|
|
|
num_parallel_workers (int, optional): number of workers to read the data
|
|
|
|
(default=None, number set in the config).
|
|
|
|
(default=None, number set in the config).
|
|
|
|
shuffle (Union[bool, Shuffle level], optional): perform reshuffling of the data every epoch
|
|
|
|
shuffle (Union[bool, Shuffle level], optional): perform reshuffling of the data every epoch
|
|
|
@ -4959,7 +4959,7 @@ class CSVDataset(SourceDataset):
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
@check_csvdataset
|
|
|
|
@check_csvdataset
|
|
|
|
def __init__(self, dataset_files, field_delim=',', column_defaults=None, column_names=None, num_samples=None,
|
|
|
|
def __init__(self, dataset_files, field_delim=',', column_defaults=None, column_names=None, num_samples=-1,
|
|
|
|
num_parallel_workers=None, shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None):
|
|
|
|
num_parallel_workers=None, shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None):
|
|
|
|
super().__init__(num_parallel_workers)
|
|
|
|
super().__init__(num_parallel_workers)
|
|
|
|
self.dataset_files = self._find_files(dataset_files)
|
|
|
|
self.dataset_files = self._find_files(dataset_files)
|
|
|
@ -5010,7 +5010,7 @@ class CSVDataset(SourceDataset):
|
|
|
|
if self._dataset_size is None:
|
|
|
|
if self._dataset_size is None:
|
|
|
|
num_rows = CsvOp.get_num_rows(self.dataset_files, self.column_names is None)
|
|
|
|
num_rows = CsvOp.get_num_rows(self.dataset_files, self.column_names is None)
|
|
|
|
num_rows = get_num_rows(num_rows, self.num_shards)
|
|
|
|
num_rows = get_num_rows(num_rows, self.num_shards)
|
|
|
|
if self.num_samples is None:
|
|
|
|
if self.num_samples == -1:
|
|
|
|
return num_rows
|
|
|
|
return num_rows
|
|
|
|
return min(self.num_samples, num_rows)
|
|
|
|
return min(self.num_samples, num_rows)
|
|
|
|
return self._dataset_size
|
|
|
|
return self._dataset_size
|
|
|
|