C++ API support for CLUE Dataset

pull/3819/head
jiangzhiwen 5 years ago
parent f92735d14a
commit 5842a0212f

File diff suppressed because it is too large Load Diff

@ -35,6 +35,9 @@ enum class DatasetType { kUnknown, kArrow, kTf };
// Possible flavours of Tensor implementations.
// NOTE(review): the individual enumerators are not documented here; kCv / kNP presumably
// refer to OpenCV- and NumPy-backed tensors -- confirm against the Tensor classes.
enum class TensorImpl { kNone, kFlexible, kCv, kNP };
// Possible values for shuffle:
//   kFalse  - No shuffling is performed.
//   kFiles  - Shuffle files only.
//   kGlobal - Shuffle both the files and samples.
// The underlying integer values are fixed explicitly (0, 1, 2).
enum class ShuffleMode { kFalse = 0, kFiles = 1, kGlobal = 2 };
// Possible values for border types. The underlying integer values are fixed explicitly (0..3).
// NOTE(review): presumably these select the padding mode used by image pad operations -- confirm.
enum class BorderType { kConstant = 0, kEdge = 1, kReflect = 2, kSymmetric = 3 };

@ -267,7 +267,7 @@ class ClueOp : public ParallelOp {
std::unique_ptr<StringIndex> filename_index_;
std::vector<std::string> clue_files_list_;
WaitPost io_block_queue_wait_post_;
std::unique_ptr<JaggedConnector> jagged_buffer_connector_;
std::shared_ptr<JaggedConnector> jagged_buffer_connector_;
QueueList<std::unique_ptr<FilenameBlock>> io_block_queues_;
bool load_jagged_connector_;
ColKeyMap cols_to_keyword_;

@ -44,6 +44,7 @@ class SamplerObj;
class CelebADataset;
class Cifar10Dataset;
class Cifar100Dataset;
class CLUEDataset;
class CocoDataset;
class ImageFolderDataset;
class MnistDataset;
@ -91,6 +92,27 @@ std::shared_ptr<Cifar10Dataset> Cifar10(const std::string &dataset_dir, std::sha
std::shared_ptr<Cifar100Dataset> Cifar100(const std::string &dataset_dir,
std::shared_ptr<SamplerObj> sampler = nullptr);
/// \brief Function to create a CLUEDataset
/// \notes The generated dataset has a variable number of columns depending on the task and usage
/// \param[in] dataset_files List of files to be read to search for a pattern of files. The list
/// will be sorted in lexicographical order.
/// \param[in] task The kind of task, one of "AFQMC", "TNEWS", "IFLYTEK", "CMNLI", "WSC" and "CSL" (default="AFQMC").
/// \param[in] usage The data usage, one of "train", "test" or "eval" (default="train").
/// \param[in] num_samples The number of samples to be included in the dataset.
/// (Default = 0 means all samples.)
/// \param[in] shuffle The mode for shuffling data every epoch. (Default=ShuffleMode::kGlobal)
/// Can be any of:
/// ShuffleMode::kFalse - No shuffling is performed.
/// ShuffleMode::kFiles - Shuffle files only.
/// ShuffleMode::kGlobal - Shuffle both the files and samples.
/// \param[in] num_shards Number of shards that the dataset should be divided into. (Default = 1)
/// \param[in] shard_id The shard ID within num_shards. This argument should be
/// specified only when num_shards is also specified. (Default = 0)
/// \return Shared pointer to the current CLUEDataset
std::shared_ptr<CLUEDataset> CLUE(const std::vector<std::string> &dataset_files, const std::string &task = "AFQMC",
const std::string &usage = "train", int64_t num_samples = 0,
ShuffleMode shuffle = ShuffleMode::kGlobal, int num_shards = 1, int shard_id = 0);
/// \brief Function to create a CocoDataset
/// \notes The generated dataset has multi-columns :
/// - task='Detection', column: [['image', dtype=uint8], ['bbox', dtype=float32], ['category_id', dtype=uint32],
@ -289,6 +311,7 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
int32_t num_workers_;
int32_t rows_per_buffer_;
int32_t connector_que_size_;
int32_t worker_connector_size_;
};
/* ####################################### Derived Dataset classes ################################# */
@ -361,6 +384,39 @@ class Cifar100Dataset : public Dataset {
std::shared_ptr<SamplerObj> sampler_;
};
/// \class CLUEDataset
/// \brief A Dataset derived class to represent CLUE dataset
class CLUEDataset : public Dataset {
public:
/// \brief Constructor
/// \notes Unlike the CLUE() factory function, this constructor declares no default arguments,
/// and dataset_files, task and usage are taken by value.
/// NOTE(review): the parameters presumably carry the same meaning as the identically named
/// parameters of CLUE() -- the constructor definition is not visible here; confirm.
/// \param[in] dataset_files List of dataset files
/// \param[in] task The kind of task
/// \param[in] usage The data usage
/// \param[in] num_samples The number of samples to be included in the dataset
/// \param[in] shuffle The mode for shuffling data
/// \param[in] num_shards Number of shards that the dataset should be divided into
/// \param[in] shard_id The shard ID within num_shards
CLUEDataset(const std::vector<std::string> dataset_files, std::string task, std::string usage, int64_t num_samples,
ShuffleMode shuffle, int num_shards, int shard_id);
/// \brief Destructor
~CLUEDataset() = default;
/// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return The list of shared pointers to the newly created DatasetOps
std::vector<std::shared_ptr<DatasetOp>> Build() override;
/// \brief Parameters validation
/// \return bool true if all the params are valid
bool ValidateParams() override;
private:
/// \brief Split string based on a character delimiter
/// \param[in] s The string to split
/// \param[in] delim The delimiter character
/// \return A string vector
std::vector<std::string> split(const std::string &s, char delim);
// Dataset configuration. NOTE(review): presumably populated from the constructor arguments
// of the same names -- the constructor definition is not visible here; confirm.
std::vector<std::string> dataset_files_;
std::string task_;
std::string usage_;
int64_t num_samples_;
ShuffleMode shuffle_;
int num_shards_;
int shard_id_;
};
class CocoDataset : public Dataset {
public:
/// \brief Constructor

@ -96,6 +96,7 @@ SET(DE_UT_SRCS
c_api_transforms_test.cc
c_api_dataset_ops_test.cc
c_api_dataset_cifar_test.cc
c_api_dataset_clue_test.cc
c_api_dataset_coco_test.cc
c_api_dataset_voc_test.cc
c_api_datasets_test.cc

File diff suppressed because it is too large Load Diff
Loading…
Cancel
Save