C++ API: Provide validate param functions

pull/4480/head
Cathy Wong 5 years ago
parent 7d70fb4dc4
commit 9c8af0d1cf

File diff suppressed because it is too large Load Diff

@ -75,7 +75,7 @@ class ZipDataset;
/// will be used to randomly iterate the entire dataset
/// \return Shared pointer to the current Dataset
std::shared_ptr<CelebADataset> CelebA(const std::string &dataset_dir, const std::string &dataset_type = "all",
const std::shared_ptr<SamplerObj> &sampler = nullptr, const bool &decode = false,
const std::shared_ptr<SamplerObj> &sampler = nullptr, bool decode = false,
const std::set<std::string> &extensions = {});
/// \brief Function to create a Cifar10 Dataset
@ -84,7 +84,8 @@ std::shared_ptr<CelebADataset> CelebA(const std::string &dataset_dir, const std:
/// \param[in] sampler Object used to choose samples from the dataset. If sampler is `nullptr`, A `RandomSampler`
/// will be used to randomly iterate the entire dataset
/// \return Shared pointer to the current Dataset
std::shared_ptr<Cifar10Dataset> Cifar10(const std::string &dataset_dir, std::shared_ptr<SamplerObj> sampler = nullptr);
std::shared_ptr<Cifar10Dataset> Cifar10(const std::string &dataset_dir,
const std::shared_ptr<SamplerObj> &sampler = nullptr);
/// \brief Function to create a Cifar100 Dataset
/// \notes The generated dataset has three columns ['image', 'coarse_label', 'fine_label']
@ -93,7 +94,7 @@ std::shared_ptr<Cifar10Dataset> Cifar10(const std::string &dataset_dir, std::sha
/// will be used to randomly iterate the entire dataset
/// \return Shared pointer to the current Dataset
std::shared_ptr<Cifar100Dataset> Cifar100(const std::string &dataset_dir,
std::shared_ptr<SamplerObj> sampler = nullptr);
const std::shared_ptr<SamplerObj> &sampler = nullptr);
/// \brief Function to create a CLUEDataset
/// \notes The generated dataset has a variable number of columns depending on the task and usage
@ -114,7 +115,8 @@ std::shared_ptr<Cifar100Dataset> Cifar100(const std::string &dataset_dir,
/// \return Shared pointer to the current CLUEDataset
std::shared_ptr<CLUEDataset> CLUE(const std::vector<std::string> &dataset_files, const std::string &task = "AFQMC",
const std::string &usage = "train", int64_t num_samples = 0,
ShuffleMode shuffle = ShuffleMode::kGlobal, int num_shards = 1, int shard_id = 0);
ShuffleMode shuffle = ShuffleMode::kGlobal, int32_t num_shards = 1,
int32_t shard_id = 0);
/// \brief Function to create a CocoDataset
/// \notes The generated dataset has multi-columns :
@ -147,10 +149,10 @@ std::shared_ptr<CocoDataset> Coco(const std::string &dataset_dir, const std::str
/// \param[in] extensions File extensions to be read
/// \param[in] class_indexing a class name to label map
/// \return Shared pointer to the current ImageFolderDataset
std::shared_ptr<ImageFolderDataset> ImageFolder(std::string dataset_dir, bool decode = false,
std::shared_ptr<SamplerObj> sampler = nullptr,
std::set<std::string> extensions = {},
std::map<std::string, int32_t> class_indexing = {});
std::shared_ptr<ImageFolderDataset> ImageFolder(const std::string &dataset_dir, bool decode = false,
const std::shared_ptr<SamplerObj> &sampler = nullptr,
const std::set<std::string> &extensions = {},
const std::map<std::string, int32_t> &class_indexing = {});
/// \brief Function to create a MnistDataset
/// \notes The generated dataset has two columns ['image', 'label']
@ -158,7 +160,8 @@ std::shared_ptr<ImageFolderDataset> ImageFolder(std::string dataset_dir, bool de
/// \param[in] sampler Object used to choose samples from the dataset. If sampler is `nullptr`,
/// A `RandomSampler` will be used to randomly iterate the entire dataset
/// \return Shared pointer to the current MnistDataset
std::shared_ptr<MnistDataset> Mnist(std::string dataset_dir, std::shared_ptr<SamplerObj> sampler = nullptr);
std::shared_ptr<MnistDataset> Mnist(const std::string &dataset_dir,
const std::shared_ptr<SamplerObj> &sampler = nullptr);
/// \brief Function to create a ConcatDataset
/// \notes Reload "+" operator to concat two datasets
@ -183,7 +186,7 @@ std::shared_ptr<ConcatDataset> operator+(const std::shared_ptr<Dataset> &dataset
/// \param[in] shard_id The shard ID within num_shards. This argument should be
/// specified only when num_shards is also specified. (Default = 0)
/// \return Shared pointer to the current TextFileDataset
std::shared_ptr<TextFileDataset> TextFile(std::vector<std::string> dataset_files, int32_t num_samples = 0,
std::shared_ptr<TextFileDataset> TextFile(const std::vector<std::string> &dataset_files, int32_t num_samples = 0,
ShuffleMode shuffle = ShuffleMode::kGlobal, int32_t num_shards = 1,
int32_t shard_id = 0);
@ -202,8 +205,8 @@ std::shared_ptr<TextFileDataset> TextFile(std::vector<std::string> dataset_files
/// \return Shared pointer to the current Dataset
std::shared_ptr<VOCDataset> VOC(const std::string &dataset_dir, const std::string &task = "Segmentation",
const std::string &mode = "train",
const std::map<std::string, int32_t> &class_index = {}, bool decode = false,
std::shared_ptr<SamplerObj> sampler = nullptr);
const std::map<std::string, int32_t> &class_indexing = {}, bool decode = false,
const std::shared_ptr<SamplerObj> &sampler = nullptr);
/// \brief Function to create a ZipDataset
/// \notes Applies zip to the dataset
@ -417,7 +420,7 @@ class CLUEDataset : public Dataset {
public:
/// \brief Constructor
CLUEDataset(const std::vector<std::string> dataset_files, std::string task, std::string usage, int64_t num_samples,
ShuffleMode shuffle, int num_shards, int shard_id);
ShuffleMode shuffle, int32_t num_shards, int32_t shard_id);
/// \brief Destructor
~CLUEDataset() = default;
@ -440,8 +443,8 @@ class CLUEDataset : public Dataset {
std::string usage_;
int64_t num_samples_;
ShuffleMode shuffle_;
int num_shards_;
int shard_id_;
int32_t num_shards_;
int32_t shard_id_;
};
class CocoDataset : public Dataset {
@ -549,7 +552,7 @@ class VOCDataset : public Dataset {
public:
/// \brief Constructor
VOCDataset(const std::string &dataset_dir, const std::string &task, const std::string &mode,
const std::map<std::string, int32_t> &class_index, bool decode, std::shared_ptr<SamplerObj> sampler);
const std::map<std::string, int32_t> &class_indexing, bool decode, std::shared_ptr<SamplerObj> sampler);
/// \brief Destructor
~VOCDataset() = default;

@ -111,7 +111,8 @@ TEST_F(MindDataTestPipeline, TestTextFileDatasetFail3) {
// Attempt to create a TextFile Dataset
// with non-existent dataset_files input
std::shared_ptr<Dataset> ds = TextFile({"notexist.txt"}, 0, ShuffleMode::kFalse);
std::string tf_file1 = datasets_root_path_ + "/testTextFileDataset/1.txt";
std::shared_ptr<Dataset> ds = TextFile({tf_file1, "notexist.txt"}, 0, ShuffleMode::kFalse);
// Expect failure: specified dataset_files does not exist
EXPECT_EQ(ds, nullptr);

Loading…
Cancel
Save