|
|
|
@ -44,6 +44,7 @@ class SamplerObj;
|
|
|
|
|
class CelebADataset;
|
|
|
|
|
class Cifar10Dataset;
|
|
|
|
|
class Cifar100Dataset;
|
|
|
|
|
class CLUEDataset;
|
|
|
|
|
class CocoDataset;
|
|
|
|
|
class ImageFolderDataset;
|
|
|
|
|
class MnistDataset;
|
|
|
|
@ -91,6 +92,27 @@ std::shared_ptr<Cifar10Dataset> Cifar10(const std::string &dataset_dir, std::sha
|
|
|
|
|
std::shared_ptr<Cifar100Dataset> Cifar100(const std::string &dataset_dir,
|
|
|
|
|
std::shared_ptr<SamplerObj> sampler = nullptr);
|
|
|
|
|
|
|
|
|
|
/// \brief Function to create a CLUEDataset
|
|
|
|
|
/// \notes The generated dataset has a variable number of columns depending on the task and usage
|
|
|
|
|
/// \param[in] dataset_files List of files to be read to search for a pattern of files. The list
|
|
|
|
|
/// will be sorted in a lexicographical order.
|
|
|
|
|
/// \param[in] task The kind of task, one of "AFQMC", "TNEWS", "IFLYTEK", "CMNLI", "WSC" and "CSL" (default="AFQMC").
|
|
|
|
|
/// \param[in] usage Be used to "train", "test" or "eval" data (default="train").
|
|
|
|
|
/// \param[in] num_samples The number of samples to be included in the dataset.
|
|
|
|
|
/// (Default = 0 means all samples.)
|
|
|
|
|
/// \param[in] shuffle The mode for shuffling data every epoch. (Default=ShuffleMode.kGlobal)
|
|
|
|
|
/// Can be any of:
|
|
|
|
|
/// ShuffleMode.kFalse - No shuffling is performed.
|
|
|
|
|
/// ShuffleMode.kFiles - Shuffle files only.
|
|
|
|
|
/// ShuffleMode.kGlobal - Shuffle both the files and samples.
|
|
|
|
|
/// \param[in] num_shards Number of shards that the dataset should be divided into. (Default = 1)
|
|
|
|
|
/// \param[in] shard_id The shard ID within num_shards. This argument should be
|
|
|
|
|
/// specified only when num_shards is also specified. (Default = 0)
|
|
|
|
|
/// \return Shared pointer to the current CLUEDataset
|
|
|
|
|
std::shared_ptr<CLUEDataset> CLUE(const std::vector<std::string> &dataset_files, const std::string &task = "AFQMC",
|
|
|
|
|
const std::string &usage = "train", int64_t num_samples = 0,
|
|
|
|
|
ShuffleMode shuffle = ShuffleMode::kGlobal, int num_shards = 1, int shard_id = 0);
|
|
|
|
|
|
|
|
|
|
/// \brief Function to create a CocoDataset
|
|
|
|
|
/// \notes The generated dataset has multi-columns :
|
|
|
|
|
/// - task='Detection', column: [['image', dtype=uint8], ['bbox', dtype=float32], ['category_id', dtype=uint32],
|
|
|
|
@ -289,6 +311,7 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
|
|
|
|
|
int32_t num_workers_;
|
|
|
|
|
int32_t rows_per_buffer_;
|
|
|
|
|
int32_t connector_que_size_;
|
|
|
|
|
int32_t worker_connector_size_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
/* ####################################### Derived Dataset classes ################################# */
|
|
|
|
@ -361,6 +384,39 @@ class Cifar100Dataset : public Dataset {
|
|
|
|
|
std::shared_ptr<SamplerObj> sampler_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
/// \class CLUEDataset
|
|
|
|
|
/// \brief A Dataset derived class to represent CLUE dataset
|
|
|
|
|
class CLUEDataset : public Dataset {
|
|
|
|
|
public:
|
|
|
|
|
/// \brief Constructor
|
|
|
|
|
CLUEDataset(const std::vector<std::string> dataset_files, std::string task, std::string usage, int64_t num_samples,
|
|
|
|
|
ShuffleMode shuffle, int num_shards, int shard_id);
|
|
|
|
|
|
|
|
|
|
/// \brief Destructor
|
|
|
|
|
~CLUEDataset() = default;
|
|
|
|
|
|
|
|
|
|
/// \brief a base class override function to create the required runtime dataset op objects for this class
|
|
|
|
|
/// \return The list of shared pointers to the newly created DatasetOps
|
|
|
|
|
std::vector<std::shared_ptr<DatasetOp>> Build() override;
|
|
|
|
|
|
|
|
|
|
/// \brief Parameters validation
|
|
|
|
|
/// \return bool true if all the params are valid
|
|
|
|
|
bool ValidateParams() override;
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
/// \brief Split string based on a character delimiter
|
|
|
|
|
/// \return A string vector
|
|
|
|
|
std::vector<std::string> split(const std::string &s, char delim);
|
|
|
|
|
|
|
|
|
|
std::vector<std::string> dataset_files_;
|
|
|
|
|
std::string task_;
|
|
|
|
|
std::string usage_;
|
|
|
|
|
int64_t num_samples_;
|
|
|
|
|
ShuffleMode shuffle_;
|
|
|
|
|
int num_shards_;
|
|
|
|
|
int shard_id_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class CocoDataset : public Dataset {
|
|
|
|
|
public:
|
|
|
|
|
/// \brief Constructor
|
|
|
|
|