!3819 C++ API Support for CLUE Dataset

Merge pull request !3819 from jiangzhiwen/jzw/c_api_clue
pull/3819/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit 5adba834d0

@ -23,6 +23,7 @@
// Source dataset headers (in alphabetical order)
#include "minddata/dataset/engine/datasetops/source/celeba_op.h"
#include "minddata/dataset/engine/datasetops/source/cifar_op.h"
#include "minddata/dataset/engine/datasetops/source/clue_op.h"
#include "minddata/dataset/engine/datasetops/source/coco_op.h"
#include "minddata/dataset/engine/datasetops/source/image_folder_op.h"
#include "minddata/dataset/engine/datasetops/source/mnist_op.h"
@ -128,6 +129,16 @@ std::shared_ptr<Cifar100Dataset> Cifar100(const std::string &dataset_dir, std::s
return ds->ValidateParams() ? ds : nullptr;
}
// Function to create a CLUEDataset.
std::shared_ptr<CLUEDataset> CLUE(const std::vector<std::string> &clue_files, const std::string &task,
const std::string &usage, int64_t num_samples, ShuffleMode shuffle, int num_shards,
int shard_id) {
auto ds = std::make_shared<CLUEDataset>(clue_files, task, usage, num_samples, shuffle, num_shards, shard_id);
// Call derived class validation method.
return ds->ValidateParams() ? ds : nullptr;
}
// Function to create a CocoDataset.
std::shared_ptr<CocoDataset> Coco(const std::string &dataset_dir, const std::string &annotation_file,
const std::string &task, const bool &decode,
@ -501,6 +512,206 @@ std::vector<std::shared_ptr<DatasetOp>> Cifar100Dataset::Build() {
return node_ops;
}
// Constructor for CLUEDataset
CLUEDataset::CLUEDataset(const std::vector<std::string> clue_files, std::string task, std::string usage,
int64_t num_samples, ShuffleMode shuffle, int num_shards, int shard_id)
: dataset_files_(clue_files),
task_(task),
usage_(usage),
num_samples_(num_samples),
shuffle_(shuffle),
num_shards_(num_shards),
shard_id_(shard_id) {}
bool CLUEDataset::ValidateParams() {
if (dataset_files_.empty()) {
MS_LOG(ERROR) << "CLUEDataset: dataset_files is not specified.";
return false;
}
for (auto f : dataset_files_) {
Path clue_file(f);
if (!clue_file.Exists()) {
MS_LOG(ERROR) << "dataset file: [" << f << "] is invalid or not exist";
return false;
}
}
std::vector<std::string> task_list = {"AFQMC", "TNEWS", "IFLYTEK", "CMNLI", "WSC", "CSL"};
std::vector<std::string> usage_list = {"train", "test", "eval"};
if (find(task_list.begin(), task_list.end(), task_) == task_list.end()) {
MS_LOG(ERROR) << "task should be AFQMC, TNEWS, IFLYTEK, CMNLI, WSC or CSL.";
return false;
}
if (find(usage_list.begin(), usage_list.end(), usage_) == usage_list.end()) {
MS_LOG(ERROR) << "usage should be train, test or eval.";
return false;
}
if (num_samples_ < 0) {
MS_LOG(ERROR) << "CLUEDataset: Invalid number of samples: " << num_samples_;
return false;
}
if (num_shards_ <= 0) {
MS_LOG(ERROR) << "CLUEDataset: Invalid num_shards: " << num_shards_;
return false;
}
if (shard_id_ < 0 || shard_id_ >= num_shards_) {
MS_LOG(ERROR) << "CLUEDataset: Invalid input, shard_id: " << shard_id_ << ", num_shards: " << num_shards_;
return false;
}
return true;
}
// Function to split string based on a character delimiter
std::vector<std::string> CLUEDataset::split(const std::string &s, char delim) {
std::vector<std::string> res;
std::stringstream ss(s);
std::string item;
while (getline(ss, item, delim)) {
res.push_back(item);
}
return res;
}
// Function to build CLUEDataset
std::vector<std::shared_ptr<DatasetOp>> CLUEDataset::Build() {
// A vector containing shared pointer to the Dataset Ops that this object will create
std::vector<std::shared_ptr<DatasetOp>> node_ops;
std::map<std::string, std::string> key_map;
if (task_ == "AFQMC") {
if (usage_ == "train") {
key_map["sentence1"] = "sentence1";
key_map["sentence2"] = "sentence2";
key_map["label"] = "label";
} else if (usage_ == "test") {
key_map["id"] = "id";
key_map["sentence1"] = "sentence1";
key_map["sentence2"] = "sentence2";
} else if (usage_ == "eval") {
key_map["sentence1"] = "sentence1";
key_map["sentence2"] = "sentence2";
key_map["label"] = "label";
}
} else if (task_ == "CMNLI") {
if (usage_ == "train") {
key_map["sentence1"] = "sentence1";
key_map["sentence2"] = "sentence2";
key_map["label"] = "label";
} else if (usage_ == "test") {
key_map["id"] = "id";
key_map["sentence1"] = "sentence1";
key_map["sentence2"] = "sentence2";
} else if (usage_ == "eval") {
key_map["sentence1"] = "sentence1";
key_map["sentence2"] = "sentence2";
key_map["label"] = "label";
}
} else if (task_ == "CSL") {
if (usage_ == "train") {
key_map["id"] = "id";
key_map["abst"] = "abst";
key_map["keyword"] = "keyword";
key_map["label"] = "label";
} else if (usage_ == "test") {
key_map["id"] = "id";
key_map["abst"] = "abst";
key_map["keyword"] = "keyword";
} else if (usage_ == "eval") {
key_map["id"] = "id";
key_map["abst"] = "abst";
key_map["keyword"] = "keyword";
key_map["label"] = "label";
}
} else if (task_ == "IFLYTEK") {
if (usage_ == "train") {
key_map["label"] = "label";
key_map["label_des"] = "label_des";
key_map["sentence"] = "sentence";
} else if (usage_ == "test") {
key_map["id"] = "id";
key_map["sentence"] = "sentence";
} else if (usage_ == "eval") {
key_map["label"] = "label";
key_map["label_des"] = "label_des";
key_map["sentence"] = "sentence";
}
} else if (task_ == "TNEWS") {
if (usage_ == "train") {
key_map["label"] = "label";
key_map["label_desc"] = "label_desc";
key_map["sentence"] = "sentence";
key_map["keywords"] = "keywords";
} else if (usage_ == "test") {
key_map["id"] = "id";
key_map["sentence"] = "sentence";
key_map["keywords"] = "keywords";
} else if (usage_ == "eval") {
key_map["label"] = "label";
key_map["label_desc"] = "label_desc";
key_map["sentence"] = "sentence";
key_map["keywords"] = "keywords";
}
} else if (task_ == "WSC") {
if (usage_ == "train") {
key_map["span1_index"] = "target/span1_index";
key_map["span2_index"] = "target/span2_index";
key_map["span1_text"] = "target/span1_text";
key_map["span2_text"] = "target/span2_text";
key_map["idx"] = "idx";
key_map["label"] = "label";
key_map["text"] = "text";
} else if (usage_ == "test") {
key_map["span1_index"] = "target/span1_index";
key_map["span2_index"] = "target/span2_index";
key_map["span1_text"] = "target/span1_text";
key_map["span2_text"] = "target/span2_text";
key_map["idx"] = "idx";
key_map["text"] = "text";
} else if (usage_ == "eval") {
key_map["span1_index"] = "target/span1_index";
key_map["span2_index"] = "target/span2_index";
key_map["span1_text"] = "target/span1_text";
key_map["span2_text"] = "target/span2_text";
key_map["idx"] = "idx";
key_map["label"] = "label";
key_map["text"] = "text";
}
}
ColKeyMap ck_map;
for (auto &p : key_map) {
ck_map.insert({p.first, split(p.second, '/')});
}
bool shuffle_files = (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles);
std::shared_ptr<ClueOp> clue_op =
std::make_shared<ClueOp>(num_workers_, rows_per_buffer_, num_samples_, worker_connector_size_, ck_map,
dataset_files_, connector_que_size_, shuffle_files, num_shards_, shard_id_);
clue_op->Init();
if (shuffle_ == ShuffleMode::kGlobal) {
// Inject ShuffleOp
int64_t shuffle_size = 0;
int64_t num_rows = 0;
// First, get the number of rows in the datset and then compute the shuffle size
RETURN_EMPTY_IF_ERROR(ClueOp::CountAllFileRows(dataset_files_, &num_rows));
shuffle_size = ComputeShuffleSize(dataset_files_.size(), num_shards_, num_rows, 0);
MS_LOG(INFO) << "CLUEDataset::Build - num_rows: " << num_rows << ", shuffle_size: " << shuffle_size;
std::shared_ptr<DatasetOp> op =
std::make_shared<ShuffleOp>(shuffle_size, GetSeed(), worker_connector_size_, true, rows_per_buffer_);
node_ops.push_back(op);
}
node_ops.push_back(clue_op);
return node_ops;
}
// Constructor for CocoDataset
CocoDataset::CocoDataset(const std::string &dataset_dir, const std::string &annotation_file, const std::string &task,
const bool &decode, const std::shared_ptr<SamplerObj> &sampler)

@ -267,7 +267,7 @@ class ClueOp : public ParallelOp {
std::unique_ptr<StringIndex> filename_index_;
std::vector<std::string> clue_files_list_;
WaitPost io_block_queue_wait_post_;
std::unique_ptr<JaggedConnector> jagged_buffer_connector_;
std::shared_ptr<JaggedConnector> jagged_buffer_connector_;
QueueList<std::unique_ptr<FilenameBlock>> io_block_queues_;
bool load_jagged_connector_;
ColKeyMap cols_to_keyword_;

@ -45,6 +45,7 @@ class SamplerObj;
class CelebADataset;
class Cifar10Dataset;
class Cifar100Dataset;
class CLUEDataset;
class CocoDataset;
class ImageFolderDataset;
class MnistDataset;
@ -93,6 +94,27 @@ std::shared_ptr<Cifar10Dataset> Cifar10(const std::string &dataset_dir, std::sha
std::shared_ptr<Cifar100Dataset> Cifar100(const std::string &dataset_dir,
std::shared_ptr<SamplerObj> sampler = nullptr);
/// \brief Function to create a CLUEDataset
/// \notes The generated dataset has a variable number of columns depending on the task and usage
/// \param[in] dataset_files List of files to be read to search for a pattern of files. The list
/// will be sorted in a lexicographical order.
/// \param[in] task The kind of task, one of "AFQMC", "TNEWS", "IFLYTEK", "CMNLI", "WSC" and "CSL" (default="AFQMC").
/// \param[in] usage Be used to "train", "test" or "eval" data (default="train").
/// \param[in] num_samples The number of samples to be included in the dataset.
/// (Default = 0 means all samples.)
/// \param[in] shuffle The mode for shuffling data every epoch. (Default=ShuffleMode.kGlobal)
/// Can be any of:
/// ShuffleMode.kFalse - No shuffling is performed.
/// ShuffleMode.kFiles - Shuffle files only.
/// ShuffleMode.kGlobal - Shuffle both the files and samples.
/// \param[in] num_shards Number of shards that the dataset should be divided into. (Default = 1)
/// \param[in] shard_id The shard ID within num_shards. This argument should be
/// specified only when num_shards is also specified. (Default = 0)
/// \return Shared pointer to the current CLUEDataset
std::shared_ptr<CLUEDataset> CLUE(const std::vector<std::string> &dataset_files, const std::string &task = "AFQMC",
const std::string &usage = "train", int64_t num_samples = 0,
ShuffleMode shuffle = ShuffleMode::kGlobal, int num_shards = 1, int shard_id = 0);
/// \brief Function to create a CocoDataset
/// \notes The generated dataset has multi-columns :
/// - task='Detection', column: [['image', dtype=uint8], ['bbox', dtype=float32], ['category_id', dtype=uint32],
@ -388,6 +410,39 @@ class Cifar100Dataset : public Dataset {
std::shared_ptr<SamplerObj> sampler_;
};
/// \class CLUEDataset
/// \brief A Dataset derived class to represent CLUE dataset
class CLUEDataset : public Dataset {
public:
/// \brief Constructor
CLUEDataset(const std::vector<std::string> dataset_files, std::string task, std::string usage, int64_t num_samples,
ShuffleMode shuffle, int num_shards, int shard_id);
/// \brief Destructor
~CLUEDataset() = default;
/// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return The list of shared pointers to the newly created DatasetOps
std::vector<std::shared_ptr<DatasetOp>> Build() override;
/// \brief Parameters validation
/// \return bool true if all the params are valid
bool ValidateParams() override;
private:
/// \brief Split string based on a character delimiter
/// \return A string vector
std::vector<std::string> split(const std::string &s, char delim);
std::vector<std::string> dataset_files_;
std::string task_;
std::string usage_;
int64_t num_samples_;
ShuffleMode shuffle_;
int num_shards_;
int shard_id_;
};
class CocoDataset : public Dataset {
public:
/// \brief Constructor

@ -97,6 +97,7 @@ SET(DE_UT_SRCS
c_api_transforms_test.cc
c_api_dataset_ops_test.cc
c_api_dataset_cifar_test.cc
c_api_dataset_clue_test.cc
c_api_dataset_coco_test.cc
c_api_dataset_filetext_test.cc
c_api_dataset_voc_test.cc

File diff suppressed because it is too large Load Diff
Loading…
Cancel
Save