@@ -23,6 +23,7 @@
// Source dataset headers (in alphabetical order)
#include "minddata/dataset/engine/datasetops/source/celeba_op.h"
#include "minddata/dataset/engine/datasetops/source/cifar_op.h"
#include "minddata/dataset/engine/datasetops/source/clue_op.h"
#include "minddata/dataset/engine/datasetops/source/coco_op.h"
#include "minddata/dataset/engine/datasetops/source/image_folder_op.h"
#include "minddata/dataset/engine/datasetops/source/mnist_op.h"
@@ -128,6 +129,16 @@ std::shared_ptr<Cifar100Dataset> Cifar100(const std::string &dataset_dir, std::s
  return ds->ValidateParams() ? ds : nullptr;
}

// Function to create a CLUEDataset.
std::shared_ptr<CLUEDataset> CLUE(const std::vector<std::string> &clue_files, const std::string &task,
                                  const std::string &usage, int64_t num_samples, ShuffleMode shuffle, int num_shards,
                                  int shard_id) {
  auto ds = std::make_shared<CLUEDataset>(clue_files, task, usage, num_samples, shuffle, num_shards, shard_id);

  // Call derived class validation method.
  return ds->ValidateParams() ? ds : nullptr;
}
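
// Illustrative example (not part of this change; the path and argument values are hypothetical):
//   std::shared_ptr<CLUEDataset> ds =
//     CLUE({"/path/to/afqmc/train.json"}, "AFQMC", "train", 0, ShuffleMode::kGlobal, 1, 0);
//   // ds is nullptr if any argument fails CLUEDataset::ValidateParams().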

// Function to create a CocoDataset.
std::shared_ptr<CocoDataset> Coco(const std::string &dataset_dir, const std::string &annotation_file,
                                  const std::string &task, const bool &decode,
@@ -501,6 +512,206 @@ std::vector<std::shared_ptr<DatasetOp>> Cifar100Dataset::Build() {
  return node_ops;
}

// Constructor for CLUEDataset
CLUEDataset::CLUEDataset(const std::vector<std::string> clue_files, std::string task, std::string usage,
                         int64_t num_samples, ShuffleMode shuffle, int num_shards, int shard_id)
    : dataset_files_(clue_files),
      task_(task),
      usage_(usage),
      num_samples_(num_samples),
      shuffle_(shuffle),
      num_shards_(num_shards),
      shard_id_(shard_id) {}

bool CLUEDataset::ValidateParams() {
  if (dataset_files_.empty()) {
    MS_LOG(ERROR) << "CLUEDataset: dataset_files is not specified.";
    return false;
  }

  for (const auto &f : dataset_files_) {
    Path clue_file(f);
    if (!clue_file.Exists()) {
      MS_LOG(ERROR) << "dataset file: [" << f << "] is invalid or does not exist";
      return false;
    }
  }

  std::vector<std::string> task_list = {"AFQMC", "TNEWS", "IFLYTEK", "CMNLI", "WSC", "CSL"};
  std::vector<std::string> usage_list = {"train", "test", "eval"};

  if (std::find(task_list.begin(), task_list.end(), task_) == task_list.end()) {
    MS_LOG(ERROR) << "task should be AFQMC, TNEWS, IFLYTEK, CMNLI, WSC or CSL.";
    return false;
  }

  if (std::find(usage_list.begin(), usage_list.end(), usage_) == usage_list.end()) {
    MS_LOG(ERROR) << "usage should be train, test or eval.";
    return false;
  }

  if (num_samples_ < 0) {
    MS_LOG(ERROR) << "CLUEDataset: Invalid number of samples: " << num_samples_;
    return false;
  }

  if (num_shards_ <= 0) {
    MS_LOG(ERROR) << "CLUEDataset: Invalid num_shards: " << num_shards_;
    return false;
  }

  if (shard_id_ < 0 || shard_id_ >= num_shards_) {
    MS_LOG(ERROR) << "CLUEDataset: Invalid input, shard_id: " << shard_id_ << ", num_shards: " << num_shards_;
    return false;
  }

  return true;
}
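
// For example (hypothetical call): CLUE({}, "AFQMC", "train", 0, ShuffleMode::kGlobal, 1, 0)
// fails the empty-file-list check above, so the factory returns nullptr; likewise
// shard_id = 2 with num_shards = 2 fails the final range check.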

// Function to split a string based on a character delimiter
std::vector<std::string> CLUEDataset::split(const std::string &s, char delim) {
  std::vector<std::string> res;
  std::stringstream ss(s);
  std::string item;

  while (std::getline(ss, item, delim)) {
    res.push_back(item);
  }
  return res;
}
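
// For instance, split("target/span1_index", '/') returns {"target", "span1_index"};
// Build() below uses this to turn '/'-separated key paths into nested JSON lookups.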

// Function to build CLUEDataset
std::vector<std::shared_ptr<DatasetOp>> CLUEDataset::Build() {
  // A vector containing shared pointers to the Dataset Ops that this object will create
  std::vector<std::shared_ptr<DatasetOp>> node_ops;
  std::map<std::string, std::string> key_map;
  if (task_ == "AFQMC") {
    if (usage_ == "train") {
      key_map["sentence1"] = "sentence1";
      key_map["sentence2"] = "sentence2";
      key_map["label"] = "label";
    } else if (usage_ == "test") {
      key_map["id"] = "id";
      key_map["sentence1"] = "sentence1";
      key_map["sentence2"] = "sentence2";
    } else if (usage_ == "eval") {
      key_map["sentence1"] = "sentence1";
      key_map["sentence2"] = "sentence2";
      key_map["label"] = "label";
    }
  } else if (task_ == "CMNLI") {
    if (usage_ == "train") {
      key_map["sentence1"] = "sentence1";
      key_map["sentence2"] = "sentence2";
      key_map["label"] = "label";
    } else if (usage_ == "test") {
      key_map["id"] = "id";
      key_map["sentence1"] = "sentence1";
      key_map["sentence2"] = "sentence2";
    } else if (usage_ == "eval") {
      key_map["sentence1"] = "sentence1";
      key_map["sentence2"] = "sentence2";
      key_map["label"] = "label";
    }
  } else if (task_ == "CSL") {
    if (usage_ == "train") {
      key_map["id"] = "id";
      key_map["abst"] = "abst";
      key_map["keyword"] = "keyword";
      key_map["label"] = "label";
    } else if (usage_ == "test") {
      key_map["id"] = "id";
      key_map["abst"] = "abst";
      key_map["keyword"] = "keyword";
    } else if (usage_ == "eval") {
      key_map["id"] = "id";
      key_map["abst"] = "abst";
      key_map["keyword"] = "keyword";
      key_map["label"] = "label";
    }
  } else if (task_ == "IFLYTEK") {
    if (usage_ == "train") {
      key_map["label"] = "label";
      key_map["label_des"] = "label_des";
      key_map["sentence"] = "sentence";
    } else if (usage_ == "test") {
      key_map["id"] = "id";
      key_map["sentence"] = "sentence";
    } else if (usage_ == "eval") {
      key_map["label"] = "label";
      key_map["label_des"] = "label_des";
      key_map["sentence"] = "sentence";
    }
  } else if (task_ == "TNEWS") {
    if (usage_ == "train") {
      key_map["label"] = "label";
      key_map["label_desc"] = "label_desc";
      key_map["sentence"] = "sentence";
      key_map["keywords"] = "keywords";
    } else if (usage_ == "test") {
      key_map["id"] = "id";
      key_map["sentence"] = "sentence";
      key_map["keywords"] = "keywords";
    } else if (usage_ == "eval") {
      key_map["label"] = "label";
      key_map["label_desc"] = "label_desc";
      key_map["sentence"] = "sentence";
      key_map["keywords"] = "keywords";
    }
  } else if (task_ == "WSC") {
    if (usage_ == "train") {
      key_map["span1_index"] = "target/span1_index";
      key_map["span2_index"] = "target/span2_index";
      key_map["span1_text"] = "target/span1_text";
      key_map["span2_text"] = "target/span2_text";
      key_map["idx"] = "idx";
      key_map["label"] = "label";
      key_map["text"] = "text";
    } else if (usage_ == "test") {
      key_map["span1_index"] = "target/span1_index";
      key_map["span2_index"] = "target/span2_index";
      key_map["span1_text"] = "target/span1_text";
      key_map["span2_text"] = "target/span2_text";
      key_map["idx"] = "idx";
      key_map["text"] = "text";
    } else if (usage_ == "eval") {
      key_map["span1_index"] = "target/span1_index";
      key_map["span2_index"] = "target/span2_index";
      key_map["span1_text"] = "target/span1_text";
      key_map["span2_text"] = "target/span2_text";
      key_map["idx"] = "idx";
      key_map["label"] = "label";
      key_map["text"] = "text";
    }
  }
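
  // The selected key_map entries name the JSON fields to read for this task/usage
  // combination. A '/' in a value (e.g. WSC's "target/span1_index") marks a nested
  // field: split() decomposes it into a key path for ClueOp to follow.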

  ColKeyMap ck_map;
  for (auto &p : key_map) {
    ck_map.insert({p.first, split(p.second, '/')});
  }

  bool shuffle_files = (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles);
  std::shared_ptr<ClueOp> clue_op =
    std::make_shared<ClueOp>(num_workers_, rows_per_buffer_, num_samples_, worker_connector_size_, ck_map,
                             dataset_files_, connector_que_size_, shuffle_files, num_shards_, shard_id_);
  RETURN_EMPTY_IF_ERROR(clue_op->Init());
  if (shuffle_ == ShuffleMode::kGlobal) {
    // Inject ShuffleOp
    int64_t shuffle_size = 0;
    int64_t num_rows = 0;

    // First, get the number of rows in the dataset and then compute the shuffle size
    RETURN_EMPTY_IF_ERROR(ClueOp::CountAllFileRows(dataset_files_, &num_rows));
    shuffle_size = ComputeShuffleSize(dataset_files_.size(), num_shards_, num_rows, 0);
    MS_LOG(INFO) << "CLUEDataset::Build - num_rows: " << num_rows << ", shuffle_size: " << shuffle_size;
    std::shared_ptr<DatasetOp> op =
      std::make_shared<ShuffleOp>(shuffle_size, GetSeed(), worker_connector_size_, true, rows_per_buffer_);
    node_ops.push_back(op);
  }
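
  // Note: node_ops appears to be consumed root-first, so the injected ShuffleOp
  // (when present) precedes the ClueOp that feeds it rows.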
  node_ops.push_back(clue_op);
  return node_ops;
}
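
// A minimal usage sketch (assuming Build() is publicly callable here; the path is hypothetical):
//   auto ds = CLUE({"/path/to/wsc/train.json"}, "WSC", "train", 0, ShuffleMode::kFiles, 1, 0);
//   auto ops = ds->Build();  // ClueOp only: with kFiles, file order is shuffled inside
//                            // ClueOp, so no separate ShuffleOp is injected.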

// Constructor for CocoDataset
CocoDataset::CocoDataset(const std::string &dataset_dir, const std::string &annotation_file, const std::string &task,
                         const bool &decode, const std::shared_ptr<SamplerObj> &sampler)