|
|
|
@ -75,7 +75,7 @@ class ZipDataset;
|
|
|
|
|
/// will be used to randomly iterate the entire dataset
|
|
|
|
|
/// \return Shared pointer to the current Dataset
|
|
|
|
|
std::shared_ptr<CelebADataset> CelebA(const std::string &dataset_dir, const std::string &dataset_type = "all",
|
|
|
|
|
const std::shared_ptr<SamplerObj> &sampler = nullptr, const bool &decode = false,
|
|
|
|
|
const std::shared_ptr<SamplerObj> &sampler = nullptr, bool decode = false,
|
|
|
|
|
const std::set<std::string> &extensions = {});
|
|
|
|
|
|
|
|
|
|
/// \brief Function to create a Cifar10 Dataset
|
|
|
|
@ -84,7 +84,8 @@ std::shared_ptr<CelebADataset> CelebA(const std::string &dataset_dir, const std:
|
|
|
|
|
/// \param[in] sampler Object used to choose samples from the dataset. If sampler is `nullptr`, A `RandomSampler`
|
|
|
|
|
/// will be used to randomly iterate the entire dataset
|
|
|
|
|
/// \return Shared pointer to the current Dataset
|
|
|
|
|
std::shared_ptr<Cifar10Dataset> Cifar10(const std::string &dataset_dir, std::shared_ptr<SamplerObj> sampler = nullptr);
|
|
|
|
|
std::shared_ptr<Cifar10Dataset> Cifar10(const std::string &dataset_dir,
|
|
|
|
|
const std::shared_ptr<SamplerObj> &sampler = nullptr);
|
|
|
|
|
|
|
|
|
|
/// \brief Function to create a Cifar100 Dataset
|
|
|
|
|
/// \notes The generated dataset has three columns ['image', 'coarse_label', 'fine_label']
|
|
|
|
@ -93,7 +94,7 @@ std::shared_ptr<Cifar10Dataset> Cifar10(const std::string &dataset_dir, std::sha
|
|
|
|
|
/// will be used to randomly iterate the entire dataset
|
|
|
|
|
/// \return Shared pointer to the current Dataset
|
|
|
|
|
std::shared_ptr<Cifar100Dataset> Cifar100(const std::string &dataset_dir,
|
|
|
|
|
std::shared_ptr<SamplerObj> sampler = nullptr);
|
|
|
|
|
const std::shared_ptr<SamplerObj> &sampler = nullptr);
|
|
|
|
|
|
|
|
|
|
/// \brief Function to create a CLUEDataset
|
|
|
|
|
/// \notes The generated dataset has a variable number of columns depending on the task and usage
|
|
|
|
@ -114,7 +115,8 @@ std::shared_ptr<Cifar100Dataset> Cifar100(const std::string &dataset_dir,
|
|
|
|
|
/// \return Shared pointer to the current CLUEDataset
|
|
|
|
|
std::shared_ptr<CLUEDataset> CLUE(const std::vector<std::string> &dataset_files, const std::string &task = "AFQMC",
|
|
|
|
|
const std::string &usage = "train", int64_t num_samples = 0,
|
|
|
|
|
ShuffleMode shuffle = ShuffleMode::kGlobal, int num_shards = 1, int shard_id = 0);
|
|
|
|
|
ShuffleMode shuffle = ShuffleMode::kGlobal, int32_t num_shards = 1,
|
|
|
|
|
int32_t shard_id = 0);
|
|
|
|
|
|
|
|
|
|
/// \brief Function to create a CocoDataset
|
|
|
|
|
/// \notes The generated dataset has multi-columns :
|
|
|
|
@ -147,10 +149,10 @@ std::shared_ptr<CocoDataset> Coco(const std::string &dataset_dir, const std::str
|
|
|
|
|
/// \param[in] extensions File extensions to be read
|
|
|
|
|
/// \param[in] class_indexing a class name to label map
|
|
|
|
|
/// \return Shared pointer to the current ImageFolderDataset
|
|
|
|
|
std::shared_ptr<ImageFolderDataset> ImageFolder(std::string dataset_dir, bool decode = false,
|
|
|
|
|
std::shared_ptr<SamplerObj> sampler = nullptr,
|
|
|
|
|
std::set<std::string> extensions = {},
|
|
|
|
|
std::map<std::string, int32_t> class_indexing = {});
|
|
|
|
|
std::shared_ptr<ImageFolderDataset> ImageFolder(const std::string &dataset_dir, bool decode = false,
|
|
|
|
|
const std::shared_ptr<SamplerObj> &sampler = nullptr,
|
|
|
|
|
const std::set<std::string> &extensions = {},
|
|
|
|
|
const std::map<std::string, int32_t> &class_indexing = {});
|
|
|
|
|
|
|
|
|
|
/// \brief Function to create a MnistDataset
|
|
|
|
|
/// \notes The generated dataset has two columns ['image', 'label']
|
|
|
|
@ -158,7 +160,8 @@ std::shared_ptr<ImageFolderDataset> ImageFolder(std::string dataset_dir, bool de
|
|
|
|
|
/// \param[in] sampler Object used to choose samples from the dataset. If sampler is `nullptr`,
|
|
|
|
|
/// A `RandomSampler` will be used to randomly iterate the entire dataset
|
|
|
|
|
/// \return Shared pointer to the current MnistDataset
|
|
|
|
|
std::shared_ptr<MnistDataset> Mnist(std::string dataset_dir, std::shared_ptr<SamplerObj> sampler = nullptr);
|
|
|
|
|
std::shared_ptr<MnistDataset> Mnist(const std::string &dataset_dir,
|
|
|
|
|
const std::shared_ptr<SamplerObj> &sampler = nullptr);
|
|
|
|
|
|
|
|
|
|
/// \brief Function to create a ConcatDataset
|
|
|
|
|
/// \notes Reload "+" operator to concat two datasets
|
|
|
|
@ -183,7 +186,7 @@ std::shared_ptr<ConcatDataset> operator+(const std::shared_ptr<Dataset> &dataset
|
|
|
|
|
/// \param[in] shard_id The shard ID within num_shards. This argument should be
|
|
|
|
|
/// specified only when num_shards is also specified. (Default = 0)
|
|
|
|
|
/// \return Shared pointer to the current TextFileDataset
|
|
|
|
|
std::shared_ptr<TextFileDataset> TextFile(std::vector<std::string> dataset_files, int32_t num_samples = 0,
|
|
|
|
|
std::shared_ptr<TextFileDataset> TextFile(const std::vector<std::string> &dataset_files, int32_t num_samples = 0,
|
|
|
|
|
ShuffleMode shuffle = ShuffleMode::kGlobal, int32_t num_shards = 1,
|
|
|
|
|
int32_t shard_id = 0);
|
|
|
|
|
|
|
|
|
@ -202,8 +205,8 @@ std::shared_ptr<TextFileDataset> TextFile(std::vector<std::string> dataset_files
|
|
|
|
|
/// \return Shared pointer to the current Dataset
|
|
|
|
|
std::shared_ptr<VOCDataset> VOC(const std::string &dataset_dir, const std::string &task = "Segmentation",
|
|
|
|
|
const std::string &mode = "train",
|
|
|
|
|
const std::map<std::string, int32_t> &class_index = {}, bool decode = false,
|
|
|
|
|
std::shared_ptr<SamplerObj> sampler = nullptr);
|
|
|
|
|
const std::map<std::string, int32_t> &class_indexing = {}, bool decode = false,
|
|
|
|
|
const std::shared_ptr<SamplerObj> &sampler = nullptr);
|
|
|
|
|
|
|
|
|
|
/// \brief Function to create a ZipDataset
|
|
|
|
|
/// \notes Applies zip to the dataset
|
|
|
|
@ -417,7 +420,7 @@ class CLUEDataset : public Dataset {
|
|
|
|
|
public:
|
|
|
|
|
/// \brief Constructor
|
|
|
|
|
CLUEDataset(const std::vector<std::string> dataset_files, std::string task, std::string usage, int64_t num_samples,
|
|
|
|
|
ShuffleMode shuffle, int num_shards, int shard_id);
|
|
|
|
|
ShuffleMode shuffle, int32_t num_shards, int32_t shard_id);
|
|
|
|
|
|
|
|
|
|
/// \brief Destructor
|
|
|
|
|
~CLUEDataset() = default;
|
|
|
|
@ -440,8 +443,8 @@ class CLUEDataset : public Dataset {
|
|
|
|
|
std::string usage_;
|
|
|
|
|
int64_t num_samples_;
|
|
|
|
|
ShuffleMode shuffle_;
|
|
|
|
|
int num_shards_;
|
|
|
|
|
int shard_id_;
|
|
|
|
|
int32_t num_shards_;
|
|
|
|
|
int32_t shard_id_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class CocoDataset : public Dataset {
|
|
|
|
@ -549,7 +552,7 @@ class VOCDataset : public Dataset {
|
|
|
|
|
public:
|
|
|
|
|
/// \brief Constructor
|
|
|
|
|
VOCDataset(const std::string &dataset_dir, const std::string &task, const std::string &mode,
|
|
|
|
|
const std::map<std::string, int32_t> &class_index, bool decode, std::shared_ptr<SamplerObj> sampler);
|
|
|
|
|
const std::map<std::string, int32_t> &class_indexing, bool decode, std::shared_ptr<SamplerObj> sampler);
|
|
|
|
|
|
|
|
|
|
/// \brief Destructor
|
|
|
|
|
~VOCDataset() = default;
|
|
|
|
|