From e0e167a000d5a338b98a5874c9a3ddde566eae35 Mon Sep 17 00:00:00 2001 From: jiangzhiwen Date: Thu, 28 May 2020 19:30:24 +0800 Subject: [PATCH] add CLUE dataset --- mindspore/ccsrc/dataset/api/de_pipeline.cc | 46 +- mindspore/ccsrc/dataset/api/de_pipeline.h | 5 +- .../ccsrc/dataset/api/python_bindings.cc | 16 +- .../engine/datasetops/source/CMakeLists.txt | 1 + .../engine/datasetops/source/clue_op.cc | 551 ++++++++++++++++++ .../engine/datasetops/source/clue_op.h | 270 +++++++++ .../engine/datasetops/source/text_file_op.cc | 2 +- mindspore/dataset/__init__.py | 8 +- mindspore/dataset/engine/__init__.py | 2 +- mindspore/dataset/engine/datasets.py | 220 ++++++- mindspore/dataset/engine/iterators.py | 5 +- mindspore/dataset/engine/validators.py | 35 ++ tests/ut/cpp/dataset/CMakeLists.txt | 1 + tests/ut/cpp/dataset/clue_op_test.cc | 117 ++++ tests/ut/data/dataset/testCLUE/afqmc/dev.json | 3 + .../ut/data/dataset/testCLUE/afqmc/test.json | 3 + .../ut/data/dataset/testCLUE/afqmc/train.json | 3 + tests/ut/data/dataset/testCLUE/cmnli/dev.json | 3 + .../ut/data/dataset/testCLUE/cmnli/test.json | 3 + .../ut/data/dataset/testCLUE/cmnli/train.json | 3 + tests/ut/data/dataset/testCLUE/csl/dev.json | 3 + tests/ut/data/dataset/testCLUE/csl/test.json | 3 + tests/ut/data/dataset/testCLUE/csl/train.json | 3 + .../ut/data/dataset/testCLUE/iflytek/dev.json | 3 + .../data/dataset/testCLUE/iflytek/test.json | 3 + .../data/dataset/testCLUE/iflytek/train.json | 3 + tests/ut/data/dataset/testCLUE/tnews/dev.json | 3 + .../ut/data/dataset/testCLUE/tnews/test.json | 3 + .../ut/data/dataset/testCLUE/tnews/train.json | 3 + tests/ut/data/dataset/testCLUE/wsc/dev.json | 3 + tests/ut/data/dataset/testCLUE/wsc/test.json | 3 + tests/ut/data/dataset/testCLUE/wsc/train.json | 3 + tests/ut/python/dataset/test_datasets_clue.py | 355 +++++++++++ 33 files changed, 1676 insertions(+), 12 deletions(-) create mode 100644 mindspore/ccsrc/dataset/engine/datasetops/source/clue_op.cc create mode 100644 mindspore/ccsrc/dataset/engine/datasetops/source/clue_op.h create mode 100644 tests/ut/cpp/dataset/clue_op_test.cc create mode 100644 tests/ut/data/dataset/testCLUE/afqmc/dev.json create mode 100644 tests/ut/data/dataset/testCLUE/afqmc/test.json create mode 100644 tests/ut/data/dataset/testCLUE/afqmc/train.json create mode 100644 tests/ut/data/dataset/testCLUE/cmnli/dev.json create mode 100644 tests/ut/data/dataset/testCLUE/cmnli/test.json create mode 100644 tests/ut/data/dataset/testCLUE/cmnli/train.json create mode 100644 tests/ut/data/dataset/testCLUE/csl/dev.json create mode 100644 tests/ut/data/dataset/testCLUE/csl/test.json create mode 100644 tests/ut/data/dataset/testCLUE/csl/train.json create mode 100644 tests/ut/data/dataset/testCLUE/iflytek/dev.json create mode 100644 tests/ut/data/dataset/testCLUE/iflytek/test.json create mode 100644 tests/ut/data/dataset/testCLUE/iflytek/train.json create mode 100644 tests/ut/data/dataset/testCLUE/tnews/dev.json create mode 100644 tests/ut/data/dataset/testCLUE/tnews/test.json create mode 100644 tests/ut/data/dataset/testCLUE/tnews/train.json create mode 100755 tests/ut/data/dataset/testCLUE/wsc/dev.json create mode 100755 tests/ut/data/dataset/testCLUE/wsc/test.json create mode 100755 tests/ut/data/dataset/testCLUE/wsc/train.json create mode 100644 tests/ut/python/dataset/test_datasets_clue.py diff --git a/mindspore/ccsrc/dataset/api/de_pipeline.cc b/mindspore/ccsrc/dataset/api/de_pipeline.cc index c7fb9955a9..a596d339ec 100644 --- a/mindspore/ccsrc/dataset/api/de_pipeline.cc +++ 
b/mindspore/ccsrc/dataset/api/de_pipeline.cc @@ -31,6 +31,7 @@ #include "dataset/engine/datasetops/source/celeba_op.h" #include "dataset/engine/datasetops/source/random_data_op.h" #include "dataset/engine/datasetops/source/text_file_op.h" +#include "dataset/engine/datasetops/source/clue_op.h" #include "dataset/engine/datasetops/filter_op.h" #include "mindrecord/include/shard_category.h" #include "mindrecord/include/shard_distributed_sample.h" @@ -72,7 +73,8 @@ static std::unordered_map<uint32_t, pFunction> g_parse_op_func_ = {{kStorage, &D {kCelebA, &DEPipeline::ParseCelebAOp}, {kRandomData, &DEPipeline::ParseRandomDataOp}, {kTextFile, &DEPipeline::ParseTextFileOp}, - {kBuildVocab, &DEPipeline::ParseBuildVocabOp}}; + {kBuildVocab, &DEPipeline::ParseBuildVocabOp}, + {kClue, &DEPipeline::ParseClueOp}}; DEPipeline::DEPipeline() : iterator_(nullptr) { try { @@ -1210,6 +1212,7 @@ Status DEPipeline::ParseTextFileOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) { for (auto p : py::reinterpret_borrow<py::dict>(value)) { if (!p.second.is_none()) { @@ -1236,6 +1239,7 @@ Status DEPipeline::ParsePadInfo(py::handle value, PadInfo *pad_info) { } return Status::OK(); } + Status DEPipeline::ParseBuildVocabOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) { std::shared_ptr<BuildVocabOp::Builder> builder = std::make_shared<BuildVocabOp::Builder>(); for (auto arg : args) { @@ -1267,5 +1271,45 @@ Status DEPipeline::ParseBuildVocabOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) {
+Status DEPipeline::ParseClueOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) {
+  std::shared_ptr<ClueOp::Builder> builder = std::make_shared<ClueOp::Builder>();
+  if (!args["dataset_files"].is_none()) {
+    (void)builder->SetClueFilesList(ToStringVector(args["dataset_files"]));
+  } else {
+    RETURN_STATUS_UNEXPECTED("Error: dataset_files is missing");
+  }
+  // Optional arguments
+  for (auto arg : args) {
+    std::string key = py::str(arg.first);
+    py::handle value = arg.second;
+    if (!value.is_none()) {
+      if (key == "num_parallel_workers") {
+        (void)builder->SetNumWorkers(ToInt(value));
+      } else if (key == "shuffle_files") {
+        (void)builder->SetShuffleFiles(ToBool(value));
+      } else if (key == "num_samples") {
+        (void)builder->SetNumSamples(ToInt(value));
+      } else if (key == "num_shards") {
+        (void)builder->SetNumDevices(ToInt(value));
+      } else if (key == "shard_id") {
+        (void)builder->SetDeviceId(ToInt(value));
+      } else if (key == "cols_to_keyword") {
+        std::map<std::string, std::string> map_dict;
+        for (auto p : py::reinterpret_borrow<py::dict>(value)) {
+          if (!p.second.is_none()) {
+            map_dict.insert({ToString(p.first), ToString(p.second)});
+          } else {
+            map_dict.insert({ToString(p.first), ToString(p.first)});
+          }
+        }
+        (void)builder->SetColsKeyMap(map_dict);
+      }
+    }
+  }
+  std::shared_ptr<ClueOp> op;
+  RETURN_IF_NOT_OK(builder->Build(&op));
+  *ptr = op;
+  return Status::OK();
+}
 } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/dataset/api/de_pipeline.h b/mindspore/ccsrc/dataset/api/de_pipeline.h index 40133fc7b7..f856b3b2ca 100644 --- a/mindspore/ccsrc/dataset/api/de_pipeline.h +++ b/mindspore/ccsrc/dataset/api/de_pipeline.h @@ -64,7 +64,8 @@ enum OpName { kCelebA, kRandomData, kTextFile, - kBuildVocab + kBuildVocab, + kClue }; // The C++ binder class that we expose to the python script. @@ -166,6 +167,8 @@ class DEPipeline { Status ParseBuildVocabOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr); + Status ParseClueOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr); + private: // Execution tree that links the dataset operators.
std::shared_ptr<ExecutionTree> tree_; diff --git a/mindspore/ccsrc/dataset/api/python_bindings.cc b/mindspore/ccsrc/dataset/api/python_bindings.cc index 5c574844b9..d2a69f2f7f 100644 --- a/mindspore/ccsrc/dataset/api/python_bindings.cc +++ b/mindspore/ccsrc/dataset/api/python_bindings.cc @@ -55,6 +55,7 @@ #include "dataset/engine/datasetops/source/tf_reader_op.h" #include "dataset/engine/jagged_connector.h" #include "dataset/engine/datasetops/source/text_file_op.h" +#include "dataset/engine/datasetops/source/clue_op.h" #include "dataset/engine/datasetops/source/voc_op.h" #include "dataset/engine/datasetops/source/coco_op.h" #include "dataset/engine/gnn/graph.h" @@ -201,6 +202,18 @@ void bindDatasetOps(py::module *m) { THROW_IF_ERROR(TextFileOp::CountAllFileRows(filenames, &count)); return count; });
+
+  (void)py::class_<ClueOp, DatasetOp, std::shared_ptr<ClueOp>>(*m, "ClueOp")
+    .def_static("get_num_rows", [](const py::list &files) {
+      int64_t count = 0;
+      std::vector<std::string> filenames;
+      for (auto file : files) {
+        file.is_none() ? (void)filenames.emplace_back("") : filenames.push_back(py::str(file));
+      }
+      THROW_IF_ERROR(ClueOp::CountAllFileRows(filenames, &count));
+      return count;
+    });
+
 (void)py::class_<VOCOp, DatasetOp, std::shared_ptr<VOCOp>>(*m, "VOCOp") .def_static("get_num_rows", [](const std::string &dir, const std::string &task_type, const std::string &task_mode, @@ -629,7 +642,8 @@ PYBIND11_MODULE(_c_dataengine, m) { .value("RANDOMDATA", OpName::kRandomData) .value("BUILDVOCAB", OpName::kBuildVocab) .value("CELEBA", OpName::kCelebA) - .value("TEXTFILE", OpName::kTextFile); + .value("TEXTFILE", OpName::kTextFile) + .value("CLUE", OpName::kClue); (void)py::enum_<JiebaMode>(m, "JiebaMode", py::arithmetic()) .value("DE_JIEBA_MIX", JiebaMode::kMix) diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/CMakeLists.txt b/mindspore/ccsrc/dataset/engine/datasetops/source/CMakeLists.txt index a4370a9a48..42269735f5 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/CMakeLists.txt +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/CMakeLists.txt @@ -19,4 +19,5 @@ add_library(engine-datasetops-source OBJECT random_data_op.cc celeba_op.cc text_file_op.cc + clue_op.cc ) \ No newline at end of file diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/clue_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/clue_op.cc new file mode 100644 index 0000000000..0081461d88 --- /dev/null +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/clue_op.cc @@ -0,0 +1,551 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */
+#include "dataset/engine/datasetops/source/clue_op.h"
+
+#include <algorithm>
+#include <cmath>
+#include <fstream>
+#include <iomanip>
+#include <random>
+#include <sstream>
+#include <utility>
+
+#include "dataset/core/config_manager.h"
+#include "dataset/util/task_manager.h"
+#include "dataset/engine/jagged_connector.h"
+#include "dataset/engine/execution_tree.h"
+#include "dataset/engine/datasetops/source/io_block.h"
+#include "dataset/util/random.h"
+
+namespace mindspore {
+namespace dataset {
+
+ClueOp::Builder::Builder()
+    : builder_device_id_(0), builder_num_devices_(1), builder_num_samples_(0), builder_shuffle_files_(false) {
+  std::shared_ptr<ConfigManager> config_manager = GlobalContext::config_manager();
+  builder_num_workers_ = config_manager->num_parallel_workers();
+  builder_op_connector_size_ = config_manager->op_connector_size();
+  builder_rows_per_buffer_ = config_manager->rows_per_buffer();
+  builder_worker_connector_size_ = config_manager->worker_connector_size();
+}
+
+Status ClueOp::Builder::ValidateInputs() const {
+  std::string err;
+  err += builder_num_workers_ <= 0 ? "Number of parallel workers should be greater than 0\n" : "";
+  err += (builder_device_id_ >= builder_num_devices_ || builder_num_devices_ < 1) ? "Wrong sharding configs\n" : "";
+  return err.empty() ? Status::OK() : Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err);
+}
+
+Status ClueOp::Builder::Build(std::shared_ptr<ClueOp> *op) {
+  RETURN_IF_NOT_OK(ValidateInputs());
+
+  // Throttle the number of workers if we have more workers than files!
+  if (static_cast<size_t>(builder_num_workers_) > builder_clue_files_list_.size()) {
+    builder_num_workers_ = static_cast<int32_t>(builder_clue_files_list_.size());
+    MS_LOG(WARNING) << "ClueOp operator parallelism reduced to " << builder_num_workers_ << " workers.";
+  }
+
+  ColKeyMap ck_map;
+  for (auto &p : builder_cols_to_keyword_) {
+    ck_map.insert({p.first, split(p.second, '/')});
+  }
+
+  std::shared_ptr<ClueOp> clue_op = std::make_shared<ClueOp>(
+    builder_num_workers_, builder_rows_per_buffer_, builder_num_samples_, builder_worker_connector_size_, ck_map,
+    builder_clue_files_list_, builder_op_connector_size_, builder_shuffle_files_, builder_num_devices_,
+    builder_device_id_);
+  RETURN_IF_NOT_OK(clue_op->Init());
+  *op = std::move(clue_op);
+
+  return Status::OK();
+}
+
+std::vector<std::string> ClueOp::Builder::split(const std::string &s, char delim) {
+  std::vector<std::string> res;
+  std::stringstream ss(s);
+  std::string item;
+
+  while (getline(ss, item, delim)) {
+    res.push_back(item);
+  }
+  return res;
+}
+
+ClueOp::ClueOp(int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples, int32_t worker_connector_size,
+               ColKeyMap cols_to_keyword, std::vector<std::string> clue_files_list, int32_t op_connector_size,
+               bool shuffle_files, int32_t num_device, int32_t device_id)
+    : ParallelOp(num_workers, op_connector_size),
+      rows_per_buffer_(rows_per_buffer),
+      num_rows_per_shard_(0),
+      all_num_rows_(0),
+      num_samples_(num_samples),
+      filename_index_(std::make_unique<StringIndex>()),
+      clue_files_list_(std::move(clue_files_list)),
+      load_jagged_connector_(true),
+      cols_to_keyword_(cols_to_keyword),
+      shuffle_files_(shuffle_files),
+      finished_reading_dataset_(false),
+      num_devices_(num_device),
+      device_id_(device_id),
+      load_io_block_queue_(true) {
+  worker_connector_size_ = worker_connector_size;
+}
+
+Status ClueOp::Init() {
+  RETURN_IF_NOT_OK(filename_index_->insert(clue_files_list_));
+
+  int32_t safe_queue_size = static_cast<int32_t>(std::ceil(static_cast<double>(clue_files_list_.size()) / num_workers_) + 1);
+  io_block_queues_.Init(num_workers_, safe_queue_size);
+
+  // Set the column name mapping (base class field)
+  int count = 0;
+  for (auto &p : cols_to_keyword_) {
+    column_name_id_map_[p.first] = count;
+    count++;
+  }
+
+  RETURN_IF_NOT_OK(ParallelOp::CreateWorkerConnector(worker_connector_size_));
+  jagged_buffer_connector_ = std::make_unique<JaggedConnector>(num_workers_, 1, worker_connector_size_);
+
+  return Status::OK();
+}
+
+Status ClueOp::Reset() {
+  load_jagged_connector_ = true;
+  load_io_block_queue_ = true;
+
+  RETURN_IF_NOT_OK(ParallelOp::Reset());
+  NotifyToFillIOBlockQueue();
+  return Status::OK();
+}
+
+Status ClueOp::LoadTensor(const std::string &line, std::unique_ptr<TensorQTable> *tensor_table, int64_t row) {
+  TensorRow t_row(1, nullptr);
+  (*tensor_table)->push_back(std::move(t_row));
+
+  std::shared_ptr<Tensor> tensor;
+  RETURN_IF_NOT_OK(Tensor::CreateTensor(&tensor, {line}, TensorShape::CreateScalar()));
+  (**tensor_table)[row][0] = std::move(tensor);
+  return Status::OK();
+}
+
+Status ClueOp::GetValue(const nlohmann::json &js, std::vector<std::string> key_chain, std::shared_ptr<Tensor> *t) {
+  nlohmann::json cursor = js;
+  for (size_t i = 0; i < key_chain.size(); i++) {
+    if (cursor.find(key_chain[i]) != cursor.end()) {
+      cursor = cursor[key_chain[i]];
+    } else {
+      RETURN_STATUS_UNEXPECTED("Failed to find key: " + key_chain[i]);
+    }
+  }
+  switch (cursor.type()) {
+    case nlohmann::detail::value_t::string:
+      RETURN_IF_NOT_OK(Tensor::CreateTensor(t, {cursor.get<std::string>()}, TensorShape::CreateScalar()));
+      break;
+
+    case nlohmann::detail::value_t::number_integer:
+      RETURN_IF_NOT_OK(
+        Tensor::CreateTensor(t, TensorImpl::kFlexible, TensorShape::CreateScalar(), DataType(DataType::DE_INT32)));
+      (*t)->SetItemAt({0}, cursor.get<int32_t>());
+      break;
+    case nlohmann::detail::value_t::number_unsigned:
+      RETURN_IF_NOT_OK(
+        Tensor::CreateTensor(t, TensorImpl::kFlexible, TensorShape::CreateScalar(), DataType(DataType::DE_INT32)));
+      (*t)->SetItemAt({0}, cursor.get<int32_t>());
+      break;
+    case nlohmann::detail::value_t::number_float:
+      RETURN_IF_NOT_OK(
+        Tensor::CreateTensor(t, TensorImpl::kFlexible, TensorShape::CreateScalar(), DataType(DataType::DE_FLOAT32)));
+      (*t)->SetItemAt({0}, cursor.get<float>());
+      break;
+    case nlohmann::detail::value_t::array:
+      RETURN_IF_NOT_OK(Tensor::CreateTensor(t, {cursor.get<std::vector<std::string>>()}, TensorShape::CreateScalar()));
+      break;
+    default:
+      break;
+  }
+  return Status::OK();
+}
+
+Status ClueOp::LoadFile(const std::string &file, const int64_t start_offset, const int64_t end_offset,
+                        const int32_t worker_id) {
+  std::ifstream handle(file);
+  if (!handle.is_open()) {
+    RETURN_STATUS_UNEXPECTED("Failed to open file " + file);
+  }
+
+  int64_t rows_each_buffer = 0;
+  int64_t rows_total = 0;
+  std::string line;
+  std::unique_ptr<DataBuffer> cur_buffer = std::make_unique<DataBuffer>(0, DataBuffer::BufferFlags::kDeBFlagNone);
+  std::unique_ptr<TensorQTable> tensor_table = std::make_unique<TensorQTable>();
+
+  while (getline(handle, line)) {
+    if (line.empty()) {
+      continue;
+    }
+    // If we have read up to the end offset of this file, break.
+    if (rows_total >= end_offset) {
+      break;
+    }
+    // Skip lines before the start offset.
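+    // Each IOBlock carries the [start_offset, end_offset) row range computed for this shard by
+    // NeedPushFileToBlockQueue, so rows before start_offset are read and then discarded here.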
+    if (rows_total < start_offset) {
+      rows_total++;
+      continue;
+    }
+
+    try {
+      nlohmann::json js = nlohmann::json::parse(line);
+      int cols_count = static_cast<int>(cols_to_keyword_.size());
+      TensorRow t_row(cols_count, nullptr);
+      tensor_table->push_back(std::move(t_row));
+
+      int col_idx = 0;
+      for (auto &p : cols_to_keyword_) {
+        std::shared_ptr<Tensor> tensor;
+        RETURN_IF_NOT_OK(GetValue(js, p.second, &tensor));
+        (*tensor_table)[rows_each_buffer][col_idx] = std::move(tensor);
+        col_idx++;
+      }
+    } catch (const std::exception &err) {
+      // Catch any exception and convert to Status return code
+      RETURN_STATUS_UNEXPECTED(std::string("Failed to parse json line: ") + err.what());
+    }
+
+    rows_each_buffer++;
+    rows_total++;
+    if (rows_each_buffer == rows_per_buffer_) {
+      cur_buffer->set_tensor_table(std::move(tensor_table));
+      RETURN_IF_NOT_OK(jagged_buffer_connector_->Add(worker_id, std::move(cur_buffer)));
+
+      cur_buffer = std::make_unique<DataBuffer>(0, DataBuffer::BufferFlags::kDeBFlagNone);
+      tensor_table = std::make_unique<TensorQTable>();
+      rows_each_buffer = 0;
+    }
+  }
+
+  if (rows_each_buffer > 0) {
+    cur_buffer->set_tensor_table(std::move(tensor_table));
+    RETURN_IF_NOT_OK(jagged_buffer_connector_->Add(worker_id, std::move(cur_buffer)));
+  }
+  return Status::OK();
+}
+
+Status ClueOp::operator()() {
+  RETURN_IF_NOT_OK(CalculateNumRowsPerShard());
+
+  // launch one thread, responsible for filling IoBlockQueue
+  RETURN_IF_NOT_OK(tree_->LaunchWorkers(1, std::bind(&ClueOp::WaitToFillIOBlockQueue, this)));
+
+  RETURN_IF_NOT_OK(tree_->LaunchWorkers(num_workers_, std::bind(&ClueOp::WorkerEntry, this, std::placeholders::_1)));
+
+  // must be called after launching workers.
+  TaskManager::FindMe()->Post();
+  RETURN_IF_NOT_OK(io_block_queue_wait_post_.Register(tree_->AllTasks()));
+  NotifyToFillIOBlockQueue();
+
+  while (!finished_reading_dataset_) {
+    int64_t buffer_id = 0;
+    int32_t workers_done = 0;
+    int64_t rows_read = 0;
+    load_io_block_queue_ = true;
+
+    while (workers_done < num_workers_) {
+      std::unique_ptr<DataBuffer> buffer;
+      RETURN_IF_NOT_OK(jagged_buffer_connector_->Pop(0, &buffer));
+      if (buffer->eoe()) {
+        workers_done++;
+      } else if (num_samples_ == 0 || rows_read < num_samples_) {
+        if ((num_samples_ > 0) && (rows_read + buffer->NumRows() > num_samples_)) {
+          int64_t rows_to_remove = buffer->NumRows() - (num_samples_ - rows_read);
+          RETURN_IF_NOT_OK(buffer->SliceOff(rows_to_remove));
+        }
+        rows_read += buffer->NumRows();
+        buffer->set_id(buffer_id++);
+        RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(buffer)));
+      } else {
+        // end of epoch
+        load_jagged_connector_ = false;
+        load_io_block_queue_ = false;
+      }
+    }
+
+    std::unique_ptr<DataBuffer> eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
+    RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoe_buffer)));
+
+    if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) {
+      finished_reading_dataset_ = true;
+      NotifyToFillIOBlockQueue();
+    } else {
+      jagged_buffer_connector_->DoReset();
+      buffer_id = 0;
+    }
+  }
+  std::unique_ptr<DataBuffer> eof_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF);
+  RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eof_buffer)));
+
+  RETURN_IF_NOT_OK(PostEndOfData());
+  return Status::OK();
+}
+
+Status ClueOp::WorkerEntry(int32_t worker_id) {
+  TaskManager::FindMe()->Post();
+  std::unique_ptr<FilenameBlock> io_block;
+  RETURN_IF_NOT_OK(PopIoBlockQueue(worker_id, &io_block));
+  while (!io_block->eof()) {
+    if (!io_block->eoe()) {
+      if (load_jagged_connector_) {
+        std::string filename;
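+        // Look up the file assigned to this IOBlock in the filename index, then stream its row range.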
+        RETURN_IF_NOT_OK(io_block->GetFilename(&filename, *filename_index_));
+        int64_t start_offset = io_block->GetStartOffset();
+        int64_t end_offset = io_block->GetEndOffset();
+        RETURN_IF_NOT_OK(LoadFile(filename, start_offset, end_offset, worker_id));
+      }
+    } else {
+      std::unique_ptr<DataBuffer> eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
+      RETURN_IF_NOT_OK(jagged_buffer_connector_->Add(worker_id, std::move(eoe_buffer)));
+    }
+
+    RETURN_IF_NOT_OK(PopIoBlockQueue(worker_id, &io_block));
+  }
+  return Status::OK();
+}
+
+// A print method typically used for debugging
+void ClueOp::Print(std::ostream &out, bool show_all) const {
+  // Always show the id and name as first line regardless if this is a summary or detailed print
+  out << "(" << std::setw(2) << operator_id_ << ") :";
+  if (!show_all) {
+    // Call the super class for displaying any common 1-liner info
+    ParallelOp::Print(out, show_all);
+    // Then show any custom derived-internal 1-liner info for this op
+    out << "\n";
+  } else {
+    // Call the super class for displaying any common detailed info
+    ParallelOp::Print(out, show_all);
+    // Then show any custom derived-internal stuff
+    out << "\nRows per buffer: " << rows_per_buffer_ << "\nSample count: " << num_samples_
+        << "\nDevice id: " << device_id_ << "\nNumber of devices: " << num_devices_
+        << "\nShuffle files: " << ((shuffle_files_) ? "yes" : "no") << "\nClue files list:\n";
+    for (size_t i = 0; i < clue_files_list_.size(); ++i) {
+      out << " " << clue_files_list_[i];
+    }
+    out << "\n\n";
+  }
+}
+
+// Pops an element from a queue in io_block_queues
+Status ClueOp::PopIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> *out_block) {
+  RETURN_IF_NOT_OK(io_block_queues_[index]->PopFront(out_block));
+
+  return Status::OK();
+}
+
+// Pushes an element to a queue in io_block_queues
+Status ClueOp::PushIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> &&io_block) {
+  RETURN_IF_NOT_OK(io_block_queues_[index]->Add(std::move(io_block)));
+
+  return Status::OK();
+}
+
+static void ShuffleKeys(std::vector<int64_t> *i_keys, uint32_t seed) {
+  std::mt19937 rng(seed);
+  std::shuffle(i_keys->begin(), i_keys->end(), rng);
+}
+
+Status ClueOp::WaitToFillIOBlockQueue() {
+  // must be called first if run by a worker spawned by the task group
+  TaskManager::FindMe()->Post();
+
+  std::vector<int64_t> i_keys;
+  if (shuffle_files_) {
+    for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) {
+      i_keys.push_back(it.key());
+    }
+  }
+  uint32_t seed = 0;
+  while (true) {
+    RETURN_IF_NOT_OK(io_block_queue_wait_post_.Wait());
+    io_block_queue_wait_post_.Clear();
+
+    if (finished_reading_dataset_) {
+      break;
+    }
+
+    if (shuffle_files_) {
+      ShuffleKeys(&i_keys, num_devices_ == 1 ? GetSeed() : ++seed);
+    }
+    RETURN_IF_NOT_OK(FillIOBlockQueue(i_keys));
+  }
+  return Status::OK();
+}
+
+Status ClueOp::FillIOBlockQueue(const std::vector<int64_t> &i_keys) {
+  int32_t queue_index = 0;
+  int64_t pre_count = 0;
+  int64_t start_offset = 0;
+  int64_t end_offset = 0;
+  bool finish = false;
+  while (!finish) {
+    std::vector<std::pair<std::string, int64_t>> file_index;
+    if (!i_keys.empty()) {
+      for (auto it = i_keys.begin(); it != i_keys.end(); ++it) {
+        if (!load_io_block_queue_) {
+          break;
+        }
+        auto file_it = filename_index_->Search(*it);
+        file_index.emplace_back(file_it.value(), *it);
+      }
+    } else {
+      for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) {
+        if (!load_io_block_queue_) {
+          break;
+        }
+        file_index.emplace_back(it.value(), it.key());
+      }
+    }
+    for (auto &file_info : file_index) {
+      if (NeedPushFileToBlockQueue(file_info.first, &start_offset, &end_offset, pre_count)) {
+        auto io_block =
+          std::make_unique<FilenameBlock>(file_info.second, start_offset, end_offset, IOBlock::kDeIoBlockNone);
+        RETURN_IF_NOT_OK(PushIoBlockQueue(queue_index, std::move(io_block)));
+        queue_index = (queue_index + 1) % num_workers_;
+      }
+
+      pre_count += filename_numrows_[file_info.first];
+    }
+
+    finish = pre_count >= (static_cast<int64_t>(device_id_) + 1) * num_rows_per_shard_;
+  }
+
+  RETURN_IF_NOT_OK(PostEndOfEpoch(queue_index));
+  return Status::OK();
+}
+
+void ClueOp::NotifyToFillIOBlockQueue() { io_block_queue_wait_post_.Set(); }
+
+bool ClueOp::NeedPushFileToBlockQueue(const std::string &file_name, int64_t *start_offset, int64_t *end_offset,
+                                      const int64_t &pre_count) {
+  *start_offset = 0;
+  *end_offset = 0;
+  bool push = false;
+  int64_t start_index = device_id_ * num_rows_per_shard_;
+  if (device_id_ + 1 < 0) {
+    MS_LOG(ERROR) << "Device id is invalid";
+    return false;
+  }
+
+  int64_t end_index = (static_cast<int64_t>(device_id_) + 1) * num_rows_per_shard_;
+  if (pre_count <= start_index && pre_count + filename_numrows_[file_name] > start_index) {
+    *start_offset = start_index - pre_count;
+    push = true;
+    if (pre_count < end_index && pre_count + filename_numrows_[file_name] >= end_index) {
+      *end_offset = end_index - pre_count;
+    } else {
+      *end_offset = filename_numrows_[file_name];
+    }
+  }
+
+  if (pre_count >= start_index && pre_count < end_index) {
+    *start_offset = 0;
+    push = true;
+    if (pre_count + filename_numrows_[file_name] >= end_index) {
+      *end_offset = end_index - pre_count;
+    } else {
+      *end_offset = filename_numrows_[file_name];
+    }
+  }
+
+  return push;
+}
+
+// Pushes a control indicator onto the IOBlockQueue for each worker to consume. When the worker
+// pops this control indicator, it will wait until the next epoch starts and then resume execution.
+Status ClueOp::PostEndOfEpoch(int32_t queue_index) {
+  for (int i = 0; i < num_workers_; ++i) {
+    std::unique_ptr<FilenameBlock> eoe = std::make_unique<FilenameBlock>(IOBlock::kDeIoBlockFlagEoe);
+    RETURN_IF_NOT_OK(PushIoBlockQueue((queue_index + i) % num_workers_, std::move(eoe)));
+  }
+
+  return Status::OK();
+}
+
+Status ClueOp::CalculateNumRowsPerShard() {
+  for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) {
+    int64_t count = CountTotalRows(it.value());
+    filename_numrows_[it.value()] = count;
+    all_num_rows_ += count;
+  }
+  if (all_num_rows_ == 0) {
+    RETURN_STATUS_UNEXPECTED(
+      "There is no valid data matching the dataset API CLUEDataset. Please check file path or dataset API "
+      "validation first.");
+  }
+
+  num_rows_per_shard_ = static_cast<int64_t>(std::ceil(all_num_rows_ * 1.0 / num_devices_));
+  MS_LOG(DEBUG) << "Number of rows per shard is " << num_rows_per_shard_;
+  return Status::OK();
+}
+
+int64_t ClueOp::CountTotalRows(const std::string &file) {
+  std::ifstream handle(file);
+  if (!handle.is_open()) {
+    MS_LOG(ERROR) << "Failed to open file: " << file;
+    return 0;
+  }
+
+  std::string line;
+  int64_t count = 0;
+  while (getline(handle, line)) {
+    if (!line.empty()) {
+      count++;
+    }
+  }
+
+  return count;
+}
+
+// Pushes a control indicator onto the IOBlockQueue for each worker to consume.
+// When the worker pops this control indicator, it will shut itself down gracefully.
+Status ClueOp::PostEndOfData() {
+  for (int i = 0; i < num_workers_; ++i) {
+    std::unique_ptr<FilenameBlock> eof = std::make_unique<FilenameBlock>(IOBlock::kDeIoBlockFlagEof);
+    RETURN_IF_NOT_OK(PushIoBlockQueue(i, std::move(eof)));
+  }
+
+  return Status::OK();
+}
+
+Status ClueOp::CountAllFileRows(const std::vector<std::string> &files, int64_t *count) {
+  std::shared_ptr<ClueOp> op;
+  *count = 0;
+  RETURN_IF_NOT_OK(Builder().SetClueFilesList(files).Build(&op));
+  for (const auto &file : files) {
+    *count += op->CountTotalRows(file);
+  }
+  return Status::OK();
+}
+
+}  // namespace dataset
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/clue_op.h b/mindspore/ccsrc/dataset/engine/datasetops/source/clue_op.h new file mode 100644 index 0000000000..1b8f23c97b --- /dev/null +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/clue_op.h @@ -0,0 +1,270 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_CLUE_OP_H_
+#define DATASET_ENGINE_DATASETOPS_SOURCE_CLUE_OP_H_
+
+#include <map>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+#include <nlohmann/json.hpp>
+
+#include "dataset/util/auto_index.h"
+#include "dataset/engine/datasetops/parallel_op.h"
+#include "dataset/engine/datasetops/source/io_block.h"
+
+namespace mindspore {
+namespace dataset {
+using StringIndex = AutoIndexObj<std::string>;
+using ColKeyMap = std::map<std::string, std::vector<std::string>>;
+
+class JaggedConnector;
+
+class ClueOp : public ParallelOp {
+ public:
+  class Builder {
+   public:
+    // Builder constructor. Creates the builder object.
+    // @note No default args
+    // @return This is a constructor.
+    Builder();
+
+    // Default destructor
+    ~Builder() = default;
+
+    // Checks if the inputs of the builder are valid.
+    // @return Status - the error code returned.
+    Status ValidateInputs() const;
+
+    // Create the final object.
+    // @param op - dataset op.
+    // @return - the error code returned.
+    Status Build(std::shared_ptr<ClueOp> *op);
+
+    // Setter method.
+    // @return Builder - setter method returns reference to the builder.
+    Builder &SetNumWorkers(int32_t num_workers) {
+      builder_num_workers_ = num_workers;
+      return *this;
+    }
+
+    // Setter method.
+    // @return Builder - setter method returns reference to the builder.
+    Builder &SetOpConnectorSize(int32_t op_connector_size) {
+      builder_op_connector_size_ = op_connector_size;
+      return *this;
+    }
+
+    // Setter method.
+    // @return Builder - setter method returns reference to the builder.
+    Builder &SetRowsPerBuffer(int64_t rows_per_buffer) {
+      builder_rows_per_buffer_ = rows_per_buffer;
+      return *this;
+    }
+
+    // Setter method.
+    // @return Builder - setter method returns reference to the builder.
+    Builder &SetNumDevices(int64_t num_dev) {
+      builder_num_devices_ = num_dev;
+      return *this;
+    }
+
+    // Setter method.
+    // @return Builder - setter method returns reference to the builder.
+    Builder &SetDeviceId(int64_t dev_id) {
+      builder_device_id_ = dev_id;
+      return *this;
+    }
+
+    // Setter method.
+    // @return Builder - setter method returns reference to the builder.
+    Builder &SetClueFilesList(const std::vector<std::string> &files_list) {
+      builder_clue_files_list_ = files_list;
+      return *this;
+    }
+
+    // Setter method.
+    // @return Builder - setter method returns reference to the builder.
+    Builder &SetShuffleFiles(bool shuffle_files) {
+      builder_shuffle_files_ = shuffle_files;
+      return *this;
+    }
+
+    // Setter method.
+    // @return Builder - setter method returns reference to the builder.
+    Builder &SetNumSamples(int64_t num_samples) {
+      builder_num_samples_ = num_samples;
+      return *this;
+    }
+
+    // Setter method.
+    // @return Builder - setter method returns reference to the builder.
+    Builder &SetColsKeyMap(const std::map<std::string, std::string> &cols_to_key) {
+      builder_cols_to_keyword_ = cols_to_key;
+      return *this;
+    }
+
+    // Split string based on a character delimiter
+    // @return - a vector of the split substrings
+    std::vector<std::string> split(const std::string &s, char delim);
+
+   private:
+    int32_t builder_device_id_;
+    int32_t builder_num_devices_;
+    int32_t builder_num_workers_;
+    int32_t builder_op_connector_size_;
+    int64_t builder_rows_per_buffer_;
+    int64_t builder_num_samples_;
+    int32_t builder_worker_connector_size_;
+    std::vector<std::string> builder_clue_files_list_;
+    bool builder_shuffle_files_;
+    std::map<std::string, std::string> builder_cols_to_keyword_;
+  };
+
+  // Constructor of ClueOp
+  ClueOp(int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples, int32_t worker_connector_size,
+         ColKeyMap cols_to_keyword, std::vector<std::string> clue_files_list, int32_t op_connector_size,
+         bool shuffle_files, int32_t num_devices, int32_t device_id);
+
+  // Default destructor
+  ~ClueOp() = default;
+
+  // A print method typically used for debugging
+  // @param out - The output stream to write output to
+  // @param show_all - A bool to control if you want to show all info or just a summary
+  void Print(std::ostream &out, bool show_all) const override;
+
+  // Instantiates the internal queues and connectors
+  // @return Status - the error code returned
+  Status Init();
+
+  // Class functor operator () override.
+  // All dataset operators operate by launching a thread (see ExecutionTree). This class functor will
+  // provide the master loop that drives the logic for performing the work
+  // @return Status - the error code returned.
+  Status operator()() override;
+
+  // Overrides base class reset method. Cleans up any state info from its previous execution and
+  // reinitializes itself so that it can be executed again, as if it was just created.
+  // @return Status - the error code returned.
+  Status Reset() override;
+
+  // Get total rows in files.
+  // @param files - all clue files.
+  // @param count - number of rows.
+  // @return Status - the error code returned.
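+  // @note Builds a temporary ClueOp and sums CountTotalRows over the files; a row is any non-empty line.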
+  static Status CountAllFileRows(const std::vector<std::string> &files, int64_t *count);
+
+ private:
+  // The entry point for when workers are launched.
+  // @param worker_id - the id of the worker that is executing this function.
+  // @return Status - the error code returned.
+  Status WorkerEntry(int32_t worker_id) override;
+
+  // Parses a single row and puts the data into a tensor table.
+  // @param line - the content of the row.
+  // @param tensor_table - the tensor table to put the parsed data in.
+  // @param row - the id of the row filled in the tensor table.
+  // @return Status - the error code returned.
+  Status LoadTensor(const std::string &line, std::unique_ptr<TensorQTable> *tensor_table, int64_t row);
+
+  // Reads a clue file and loads the data into multiple buffers.
+  // @param file - the file to read.
+  // @param start_offset - the start offset of file.
+  // @param end_offset - the end offset of file.
+  // @param worker_id - the id of the worker that is executing this function.
+  // @return Status - the error code returned.
+  Status LoadFile(const std::string &file, const int64_t start_offset, const int64_t end_offset,
+                  const int32_t worker_id);
+
+  // Pops an element from a queue in IOBlockQueue.
+  // @param index - the index of the queue to pop from.
+  // @param out_block - the popped element.
+  // @return Status - the error code returned.
+  Status PopIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> *out_block);
+
+  // Pushes an element to a queue in IOBlockQueue.
+  // @param index - the index of the queue to push to.
+  // @param io_block - the element to push onto the queue.
+  // @return Status - the error code returned.
+  Status PushIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> &&io_block);
+
+  // Called asynchronously by another thread. Will wait until notified to fill the IOBlockQueue.
+  // @return Status - the error code returned.
+  Status WaitToFillIOBlockQueue();
+
+  // Fill the IOBlockQueue.
+  // @param i_keys - keys of the files to fill into the IOBlockQueue.
+  // @return Status - the error code returned.
+  Status FillIOBlockQueue(const std::vector<int64_t> &i_keys);
+
+  // Notifies the thread running WaitToFillIOBlockQueue to resume execution.
+  void NotifyToFillIOBlockQueue();
+
+  // Determines whether a file, and which of its rows, belongs to this shard and should be pushed to the block queue.
+  // @param file_name - file name.
+  // @param start_offset - output; the starting row offset of this shard within the file.
+  // @param end_offset - output; the ending row offset of this shard within the file.
+  // @param pre_count - total rows of the previous files.
+  // @return bool - whether the file should be pushed to the block queue.
+  bool NeedPushFileToBlockQueue(const std::string &file_name, int64_t *start_offset, int64_t *end_offset,
+                                const int64_t &pre_count);
+
+  // Pushes a control indicator onto the IOBlockQueue for each worker to consume. When the worker
+  // pops this control indicator, it will wait until the next epoch starts and then resume execution.
+  // @return Status - the error code returned.
+  Status PostEndOfEpoch(int32_t queue_index);
+
+  // Calculate number of rows in each shard.
+  // @return Status - the error code returned.
+  Status CalculateNumRowsPerShard();
+
+  // Count number of rows in each file.
+  // @param file - clue file name.
+  // @return int64_t - the total number of rows in file.
+  int64_t CountTotalRows(const std::string &file);
+
+  // Pushes a control indicator onto the IOBlockQueue for each worker to consume.
+  // When the worker pops this control indicator, it will shut itself down gracefully.
+  // @return Status - the error code returned.
+  Status PostEndOfData();
+
+  // Parses the json value reached through key_chain into a tensor.
+  // @return Status - the error code returned.
+  Status GetValue(const nlohmann::json &js, std::vector<std::string> key_chain, std::shared_ptr<Tensor> *t);
+
+  int32_t device_id_;
+  bool shuffle_files_;
+  bool finished_reading_dataset_;
+  int32_t num_devices_;
+  int64_t rows_per_buffer_;
+  bool load_io_block_queue_;
+  int64_t num_rows_per_shard_;
+  int64_t all_num_rows_;
+  int64_t num_samples_;
+  std::map<std::string, int64_t> filename_numrows_;
+  std::unique_ptr<StringIndex> filename_index_;
+  std::vector<std::string> clue_files_list_;
+  WaitPost io_block_queue_wait_post_;
+  std::unique_ptr<JaggedConnector> jagged_buffer_connector_;
+  QueueList<std::unique_ptr<FilenameBlock>> io_block_queues_;
+  bool load_jagged_connector_;
+  ColKeyMap cols_to_keyword_;
+};
+
+}  // namespace dataset
+}  // namespace mindspore
+#endif  // DATASET_ENGINE_DATASETOPS_SOURCE_CLUE_OP_H_
diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/text_file_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/text_file_op.cc index 40ffe7e9ab..d31495a09b 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/text_file_op.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/text_file_op.cc @@ -43,7 +43,7 @@ TextFileOp::Builder::Builder() Status TextFileOp::Builder::ValidateInputs() const { std::string err_msg; - err_msg += builder_num_workers_ <= 0 ? "Number of parallel workers should be greate than 0\n" : ""; + err_msg += builder_num_workers_ <= 0 ? "Number of parallel workers should be greater than 0\n" : ""; err_msg += builder_device_id_ >= builder_num_devices_ || builder_num_devices_ < 1 ? "Wrong sharding configs\n" : ""; return err_msg.empty() ? Status::OK() : Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg); } diff --git a/mindspore/dataset/__init__.py b/mindspore/dataset/__init__.py index 4016211236..f0070b428d 100644 --- a/mindspore/dataset/__init__.py +++ b/mindspore/dataset/__init__.py @@ -21,7 +21,7 @@ can also create samplers with this module to sample data.
 from .core.configuration import config
 from .engine.datasets import TFRecordDataset, ImageFolderDatasetV2, MnistDataset, MindDataset, NumpySlicesDataset, \
     GeneratorDataset, ManifestDataset, Cifar10Dataset, Cifar100Dataset, VOCDataset, CocoDataset, CelebADataset,\
-    TextFileDataset, Schema, Shuffle, zip, RandomDataset
+    TextFileDataset, CLUEDataset, Schema, Shuffle, zip, RandomDataset
 from .engine.samplers import DistributedSampler, PKSampler, RandomSampler, SequentialSampler, SubsetRandomSampler, \
     WeightedRandomSampler, Sampler
 from .engine.serializer_deserializer import serialize, deserialize, show
@@ -29,6 +29,6 @@ from .engine.graphdata import GraphData
 __all__ = ["config", "ImageFolderDatasetV2", "MnistDataset", "MindDataset", "GeneratorDataset", "TFRecordDataset",
-           "ManifestDataset", "Cifar10Dataset", "Cifar100Dataset", "CelebADataset", "NumpySlicesDataset",
-           "VOCDataset", "CocoDataset", "TextFileDataset", "Schema", "DistributedSampler", "PKSampler", "RandomSampler",
-           "SequentialSampler", "SubsetRandomSampler", "WeightedRandomSampler", "zip", "GraphData"]
+           "ManifestDataset", "Cifar10Dataset", "Cifar100Dataset", "CelebADataset", "NumpySlicesDataset", "VOCDataset",
+           "CocoDataset", "TextFileDataset", "CLUEDataset", "Schema", "DistributedSampler", "PKSampler",
+           "RandomSampler", "SequentialSampler", "SubsetRandomSampler", "WeightedRandomSampler", "zip", "GraphData"]
diff --git a/mindspore/dataset/engine/__init__.py b/mindspore/dataset/engine/__init__.py index 044b639dfa..cc3686eb00 100644 --- a/mindspore/dataset/engine/__init__.py +++ b/mindspore/dataset/engine/__init__.py @@ -30,7 +30,7 @@ from ..core.configuration import config, ConfigurationManager
 __all__ = ["config", "ConfigurationManager", "zip",
            "ImageFolderDatasetV2", "MnistDataset",
-           "MindDataset", "GeneratorDataset", "TFRecordDataset",
+           "MindDataset", "GeneratorDataset", "TFRecordDataset", "CLUEDataset",
            "ManifestDataset", "Cifar10Dataset", "Cifar100Dataset", "CelebADataset",
            "VOCDataset", "CocoDataset", "TextFileDataset", "BuildVocabDataset", "Schema", "Schema", "DistributedSampler", "PKSampler",
diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py index 270349984a..731bc193fe 100644 --- a/mindspore/dataset/engine/datasets.py +++ b/mindspore/dataset/engine/datasets.py @@ -33,7 +33,7 @@ import copy import numpy as np
 from mindspore._c_dataengine import DataType, TFReaderOp, ImageFolderOp, CifarOp, MnistOp, ManifestOp, \
-    MindRecordOp, TextFileOp, VOCOp, CocoOp, CBatchInfo
+    MindRecordOp, TextFileOp, ClueOp, VOCOp, CocoOp, CBatchInfo
 from mindspore._c_expression import typing
 from mindspore import log as logger
@@ -44,7 +44,7 @@ from .validators import check_batch, check_shuffle, check_map, check_filter, che
     check_take, check_project, check_imagefolderdatasetv2, check_mnist_cifar_dataset, check_manifestdataset, \
     check_tfrecorddataset, check_vocdataset, check_cocodataset, check_celebadataset, check_minddataset, \
     check_generatordataset, check_sync_wait, check_zip_dataset, check_add_column, check_textfiledataset, check_concat, \
     check_split, check_cluedataset
 from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist
 try:
@@ -4317,6 +4317,222 @@ class CelebADataset(MappableDataset):
         return self.sampler.is_sharded()
+
+
+class CLUEDataset(SourceDataset):
+    """
+    A source dataset that reads and parses CLUE datasets.
+    CLUE, the Chinese Language Understanding Evaluation Benchmark, is a collection of datasets, baselines, pre-trained
Here we bring in classification task of CLUE, which are AFQMC, TNEWS, IFLYTEK, + CMNLI, WSC and CSL. + + Args: + dataset_files (str or list[str]): String or list of files to be read or glob strings to search for a pattern of + files. The list will be sorted in a lexicographical order. + task (str, optional): The kind of task, one of 'AFQMC', 'TNEWS', 'IFLYTEK', 'CMNLI', 'WSC' and 'CSL'. + (default=AFQMC). + usage (str, optional): Need train, test or eval data (default="train"). + num_samples (int, optional): number of samples(rows) to read (default=None, reads the full dataset). + num_parallel_workers (int, optional): number of workers to read the data + (default=None, number set in the config). + shuffle (bool, Shuffle level, optional): perform reshuffling of the data every epoch (default=Shuffle.GLOBAL). + If shuffle is False, no shuffling will be performed; + If shuffle is True, the behavior is the same as setting shuffle to be Shuffle.GLOBAL + Otherwise, there are two levels of shuffling: + + - Shuffle.GLOBAL: Shuffle both the files and samples. + + - Shuffle.FILES: Shuffle files only. + + num_shards (int, optional): Number of shards that the dataset should be divided into (default=None). + shard_id (int, optional): The shard ID within num_shards (default=None). This + argument should be specified only when num_shards is also specified. + + Examples: + >>> import mindspore.dataset as ds + >>> dataset_files = ["/path/to/1", "/path/to/2"] # contains 1 or multiple text files + >>> dataset = ds.CLUEDataset(dataset_files=dataset_files, task='AFQMC', usage='train') + + """ + + @check_cluedataset + def __init__(self, dataset_files, task='AFQMC', usage='train', num_samples=None, + num_parallel_workers=None, shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None): + super().__init__(num_parallel_workers) + self.dataset_files = self._find_files(dataset_files) + self.dataset_files.sort() + self.num_samples = num_samples + self.task_dict = { + 'AFQMC': { + 'train': { + 'sentence1': 'sentence1', + 'sentence2': 'sentence2', + 'label': 'label' + }, + 'test': { + 'id': 'id', + 'sentence1': 'sentence1', + 'sentence2': 'sentence2' + }, + 'eval': { + 'sentence1': 'sentence1', + 'sentence2': 'sentence2', + 'label': 'label' + } + }, + 'CMNLI': { + 'train': { + 'sentence1': 'sentence1', + 'sentence2': 'sentence2', + 'label': 'label' + }, + 'test': { + 'id': 'id', + 'sentence1': 'sentence1', + 'sentence2': 'sentence2' + }, + 'eval': { + 'sentence1': 'sentence1', + 'sentence2': 'sentence2', + 'label': 'label' + } + }, + 'CSL': { + 'train': { + 'id': 'id', + 'abst': 'abst', + 'keyword': 'keyword', + 'label': 'label' + }, + 'test': { + 'id': 'id', + 'abst': 'abst', + 'keyword': 'keyword' + }, + 'eval': { + 'id': 'id', + 'abst': 'abst', + 'keyword': 'keyword', + 'label': 'label' + } + }, + 'IFLYTEK': { + 'train': { + 'label': 'label', + 'label_des': 'label_des', + 'sentence': 'sentence' + }, + 'test': { + 'id': 'id', + 'sentence': 'sentence', + }, + 'eval': { + 'label': 'label', + 'label_des': 'label_des', + 'sentence': 'sentence' + } + }, + 'TNEWS': { + 'train': { + 'label': 'label', + 'label_desc': 'label_desc', + 'sentence': 'sentence', + 'keywords': 'keywords' + }, + 'test': { + 'id': 'id', + 'sentence': 'sentence', + 'keywords': 'keywords' + }, + 'eval': { + 'label': 'label', + 'label_desc': 'label_desc', + 'sentence': 'sentence', + 'keywords': 'keywords' + } + }, + 'WSC': { + 'train': { + 'span1_index': 'target/span1_index', + 'span2_index': 'target/span2_index', + 'span1_text': 'target/span1_text', + 
'span2_text': 'target/span2_text', + 'idx': 'idx', + 'label': 'label', + 'text': 'text' + }, + 'test': { + 'span1_index': 'target/span1_index', + 'span2_index': 'target/span2_index', + 'span1_text': 'target/span1_text', + 'span2_text': 'target/span2_text', + 'idx': 'idx', + 'text': 'text' + }, + 'eval': { + 'span1_index': 'target/span1_index', + 'span2_index': 'target/span2_index', + 'span1_text': 'target/span1_text', + 'span2_text': 'target/span2_text', + 'idx': 'idx', + 'label': 'label', + 'text': 'text' + } + } + } + self.cols_to_keyword = self.task_dict[task][usage] + + if not isinstance(shuffle, (bool, Shuffle)): + raise TypeError("shuffle should be of boolean or enum 'Shuffle'.") + if not isinstance(shuffle, Shuffle): + if shuffle: + self.shuffle_level = Shuffle.GLOBAL + self.shuffle_files = True + else: + self.shuffle_level = None + self.shuffle_files = False + else: + self.shuffle_level = shuffle + self.shuffle_files = True + + self.num_shards = num_shards + self.shard_id = shard_id + + def get_args(self): + args = super().get_args() + args["dataset_files"] = self.dataset_files + args["num_samples"] = self.num_samples + if self.shuffle_files is not None: + args["shuffle_files"] = self.shuffle_files + args["shuffle"] = self.shuffle_level + args["num_shards"] = self.num_shards + args["shard_id"] = self.shard_id + args["cols_to_keyword"] = self.cols_to_keyword + return args + + def get_dataset_size(self): + """ + Get the number of batches in an epoch. + + Return: + Number, number of batches. + """ + if self._dataset_size is None: + num_rows = ClueOp.get_num_rows(self.dataset_files) + num_rows = get_num_rows(num_rows, self.num_shards) + if self.num_samples is None: + return num_rows + return min(self.num_samples, num_rows) + return self._dataset_size + + def is_shuffled(self): + return self.shuffle_files + + def is_sharded(self): + if self.num_shards is not None: + return self.num_shards > 1 + + return False + + class TextFileDataset(SourceDataset): """ A source dataset that reads and parses datasets stored on disk in text format. diff --git a/mindspore/dataset/engine/iterators.py b/mindspore/dataset/engine/iterators.py index d621f76256..11b082b0e0 100644 --- a/mindspore/dataset/engine/iterators.py +++ b/mindspore/dataset/engine/iterators.py @@ -50,7 +50,8 @@ def alter_tree(node): def _alter_node(node): """Performing some alteration to a dataset node. A common alteration is to insert a node.""" - if isinstance(node, (de.TFRecordDataset, de.TextFileDataset)) and node.shuffle_level == de.Shuffle.GLOBAL: + if isinstance(node, (de.TFRecordDataset, de.TextFileDataset, de.CLUEDataset)) \ + and node.shuffle_level == de.Shuffle.GLOBAL: # Remove the connection between the parent's node to the current node because we are inserting a node. 
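        # The node inserted below is a dataset-level shuffle, so Shuffle.GLOBAL reshuffles rows across
        # files in addition to the file-order shuffle done inside the reader op itself.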
 if node.output: node.output.pop() @@ -179,6 +180,8 @@ class Iterator: op_type = OpName.TEXTFILE elif isinstance(dataset, de.BuildVocabDataset): op_type = OpName.BUILDVOCAB + elif isinstance(dataset, de.CLUEDataset): + op_type = OpName.CLUE else: raise ValueError("Unsupported DatasetOp") diff --git a/mindspore/dataset/engine/validators.py b/mindspore/dataset/engine/validators.py index ff434c718e..b5ffbbdfc0 100644 --- a/mindspore/dataset/engine/validators.py +++ b/mindspore/dataset/engine/validators.py @@ -1075,6 +1075,41 @@ def check_add_column(method): return new_method
+
+
+def check_cluedataset(method):
+    """A wrapper that wraps a parameter checker around the original Dataset (CLUEDataset)."""
+
+    @wraps(method)
+    def new_method(*args, **kwargs):
+        param_dict = make_param_dict(method, args, kwargs)
+
+        nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
+
+        # check dataset_files; required argument
+        dataset_files = param_dict.get('dataset_files')
+        if dataset_files is None:
+            raise ValueError("dataset_files is not provided.")
+        if not isinstance(dataset_files, (str, list)):
+            raise TypeError("dataset_files should be of type str or a list of strings.")
+
+        # check task
+        task_param = param_dict.get('task')
+        if task_param not in ['AFQMC', 'TNEWS', 'IFLYTEK', 'CMNLI', 'WSC', 'CSL']:
+            raise ValueError("task should be 'AFQMC', 'TNEWS', 'IFLYTEK', 'CMNLI', 'WSC' or 'CSL'")
+
+        # check usage
+        usage_param = param_dict.get('usage')
+        if usage_param not in ['train', 'test', 'eval']:
+            raise ValueError("usage should be 'train', 'test' or 'eval'")
+
+        check_param_type(nreq_param_int, param_dict, int)
+
+        check_sampler_shuffle_shard_options(param_dict)
+
+        return method(*args, **kwargs)
+
+    return new_method
+
+
 def check_textfiledataset(method): """A wrapper that wrap a parameter checker to the original Dataset(TextFileDataset).""" diff --git a/tests/ut/cpp/dataset/CMakeLists.txt b/tests/ut/cpp/dataset/CMakeLists.txt index 1691aa3de5..417208102a 100644 --- a/tests/ut/cpp/dataset/CMakeLists.txt +++ b/tests/ut/cpp/dataset/CMakeLists.txt @@ -65,6 +65,7 @@ SET(DE_UT_SRCS cifar_op_test.cc celeba_op_test.cc take_op_test.cc + clue_op_test.cc text_file_op_test.cc filter_op_test.cc concat_op_test.cc diff --git a/tests/ut/cpp/dataset/clue_op_test.cc b/tests/ut/cpp/dataset/clue_op_test.cc new file mode 100644 index 0000000000..ff2f01a9ff --- /dev/null +++ b/tests/ut/cpp/dataset/clue_op_test.cc @@ -0,0 +1,117 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <map>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "dataset/core/client.h"
+#include "common/common.h"
+#include "common/utils.h"
+#include "gtest/gtest.h"
+#include "utils/log_adapter.h"
+#include "dataset/engine/datasetops/source/clue_op.h"
+#include "dataset/util/status.h"
+
+namespace common = mindspore::common;
+
+using namespace mindspore::dataset;
+using mindspore::MsLogLevel::INFO;
+using mindspore::ExceptionType::NoExceptionType;
+using mindspore::LogStream;
+
+class MindDataTestCLUEOp : public UT::DatasetOpTesting {
+};
+
+TEST_F(MindDataTestCLUEOp, TestCLUEBasic) {
+  // Start with an empty execution tree
+  auto tree = std::make_shared<ExecutionTree>();
+
+  std::string dataset_path;
+  dataset_path = datasets_root_path_ + "/testCLUE/afqmc/train.json";
+  std::map<std::string, std::string> key_map;
+  key_map["sentence1"] = "sentence1";
+  key_map["sentence2"] = "sentence2";
+  key_map["label"] = "label";
+
+  std::shared_ptr<ClueOp> op;
+  ClueOp::Builder builder;
+  builder.SetClueFilesList({dataset_path})
+    .SetRowsPerBuffer(16)
+    .SetNumWorkers(16)
+    .SetOpConnectorSize(2)
+    .SetColsKeyMap(key_map);
+
+  Status rc = builder.Build(&op);
+  ASSERT_TRUE(rc.IsOk());
+
+  rc = tree->AssociateNode(op);
+  ASSERT_TRUE(rc.IsOk());
+
+  rc = tree->AssignRoot(op);
+  ASSERT_TRUE(rc.IsOk());
+
+  MS_LOG(INFO) << "Launching tree and begin iteration.";
+  rc = tree->Prepare();
+  ASSERT_TRUE(rc.IsOk());
+
+  rc = tree->Launch();
+  ASSERT_TRUE(rc.IsOk());
+
+  // Start the loop of reading tensors from our pipeline
+  DatasetIterator di(tree);
+  TensorRow tensor_list;
+  rc = di.FetchNextTensorRow(&tensor_list);
+  ASSERT_TRUE(rc.IsOk());
+
+  int row_count = 0;
+  while (!tensor_list.empty()) {
+    // Display the tensor by calling the printer on it
+    for (size_t i = 0; i < tensor_list.size(); i++) {
+      std::ostringstream ss;
+      ss << "(" << tensor_list[i] << "): " << *tensor_list[i] << std::endl;
+      MS_LOG(INFO) << "Tensor print: " << ss.str() << ".";
+    }
+
+    rc = di.FetchNextTensorRow(&tensor_list);
+    ASSERT_TRUE(rc.IsOk());
+    row_count++;
+  }
+
+  ASSERT_EQ(row_count, 3);
+}
+
+TEST_F(MindDataTestCLUEOp, TestTotalRows) {
+  std::string clue_file1 = datasets_root_path_ + "/testCLUE/afqmc/train.json";
+  std::string clue_file2 = datasets_root_path_ + "/testCLUE/afqmc/dev.json";
+  std::vector<std::string> files;
+  files.push_back(clue_file1);
+  int64_t total_rows = 0;
+  ClueOp::CountAllFileRows(files, &total_rows);
+  ASSERT_EQ(total_rows, 3);
+  files.clear();
+
+  files.push_back(clue_file2);
+  ClueOp::CountAllFileRows(files, &total_rows);
+  ASSERT_EQ(total_rows, 3);
+  files.clear();
+
+  files.push_back(clue_file1);
+  files.push_back(clue_file2);
+  ClueOp::CountAllFileRows(files, &total_rows);
+  ASSERT_EQ(total_rows, 6);
+  files.clear();
+}
diff --git a/tests/ut/data/dataset/testCLUE/afqmc/dev.json b/tests/ut/data/dataset/testCLUE/afqmc/dev.json new file mode 100644 index 0000000000..4c3d942e2d --- /dev/null +++ b/tests/ut/data/dataset/testCLUE/afqmc/dev.json @@ -0,0 +1,3 @@
+{"sentence1": "你有花呗吗", "sentence2": "我的花呗没额度了", "label": "0"}
+{"sentence1": "吃饭能用花呗吗", "sentence2": "花呗太方便了", "label": "0"}
+{"sentence1": "蚂蚁花呗支付金额有什么限制", "sentence2": "我到实体店消费用花呗支付受金额限制", "label": "1"}
diff --git a/tests/ut/data/dataset/testCLUE/afqmc/test.json b/tests/ut/data/dataset/testCLUE/afqmc/test.json new file mode 100644 index 0000000000..a7d63132d8 --- /dev/null +++ b/tests/ut/data/dataset/testCLUE/afqmc/test.json @@ -0,0 +1,3 @@
+{"id": 0, "sentence1": "借呗取消的时间", "sentence2": "蚂蚁借呗恢复的月数"}
+{"id": 1, "sentence1": "网商贷用什么方法转变成借呗", "sentence2": "什么手段能将网商贷切换为借呗"}
+{"id": 2, "sentence1": "我的借呗为什么开通不了", "sentence2":
"我为啥没法开通借呗"} diff --git a/tests/ut/data/dataset/testCLUE/afqmc/train.json b/tests/ut/data/dataset/testCLUE/afqmc/train.json new file mode 100644 index 0000000000..f69c29adcf --- /dev/null +++ b/tests/ut/data/dataset/testCLUE/afqmc/train.json @@ -0,0 +1,3 @@ +{"sentence1": "蚂蚁借呗等额还款能否换成先息后本", "sentence2": "借呗可以先息到期还本吗", "label": "0"} +{"sentence1": "蚂蚁花呗说我违约了", "sentence2": "蚂蚁花呗违约行为是啥", "label": "0"} +{"sentence1": "帮我看看本月花呗账单结清了没", "sentence2": "上月的花呗账单", "label": "0"} diff --git a/tests/ut/data/dataset/testCLUE/cmnli/dev.json b/tests/ut/data/dataset/testCLUE/cmnli/dev.json new file mode 100644 index 0000000000..09449683a9 --- /dev/null +++ b/tests/ut/data/dataset/testCLUE/cmnli/dev.json @@ -0,0 +1,3 @@ +{"sentence1": "每个人都有权利", "sentence2": "每个人都有福利", "label": "neutral"} +{"sentence1": "有时候我喜欢他,但我也喜欢看到有人打他", "sentence2": "说实话,我有点喜欢他,但还是喜欢看到有人打他。", "label": "entailment"} +{"sentence1": "我最喜欢的餐馆是离你最近的一家", "sentence2": "我最喜欢的餐馆离你家至少一百英里远。", "label": "contradiction"} diff --git a/tests/ut/data/dataset/testCLUE/cmnli/test.json b/tests/ut/data/dataset/testCLUE/cmnli/test.json new file mode 100644 index 0000000000..ab249f6d24 --- /dev/null +++ b/tests/ut/data/dataset/testCLUE/cmnli/test.json @@ -0,0 +1,3 @@ +{"id": 0, "sentence1": "今天,全球都在看着最新航天飞机的处女航。", "sentence2": "全世界都在看最新的航天飞机发射。"} +{"id": 1, "sentence1": "而我们把竹篮放在一个地方,把玻璃瓶放在另一处,把书放在另一处,满了要把它放到车里", "sentence2": "我们没有分开任何东西,都把它全扔进一个箱子里。"} +{"id": 2, "sentence1": "她占用了我的很多时间,她给我读了很多关于灵异的故事,我觉得很无聊。", "sentence2": "我喜欢和她一起读鬼故事。"} diff --git a/tests/ut/data/dataset/testCLUE/cmnli/train.json b/tests/ut/data/dataset/testCLUE/cmnli/train.json new file mode 100644 index 0000000000..705cc46438 --- /dev/null +++ b/tests/ut/data/dataset/testCLUE/cmnli/train.json @@ -0,0 +1,3 @@ +{"sentence1": "你应该给这件衣服定一个价格。", "sentence2": "不同的衣服有不同的价格。", "label": "neutral"} +{"sentence1": "我怎么知道他要说什么", "sentence2": "他说什么我并不知道。", "label": "entailment"} +{"sentence1": "向左。", "sentence2": "向右。", "label": "contradiction"} diff --git a/tests/ut/data/dataset/testCLUE/csl/dev.json b/tests/ut/data/dataset/testCLUE/csl/dev.json new file mode 100644 index 0000000000..d43621bdca --- /dev/null +++ b/tests/ut/data/dataset/testCLUE/csl/dev.json @@ -0,0 +1,3 @@ +{"id": 1, "abst": "这是第一段很长的文本", "keyword": ["关键词1", "关键词2", "关键词3", "关键词4"], "label": "1"} +{"id": 2, "abst": "这是第二段很长的文本", "keyword": ["关键词1", "关键词2", "关键词3", "关键词4"], "label": "1"} +{"id": 3, "abst": "这是第三段很长的文本", "keyword": ["1", "2", "3"], "label": "0"} diff --git a/tests/ut/data/dataset/testCLUE/csl/test.json b/tests/ut/data/dataset/testCLUE/csl/test.json new file mode 100644 index 0000000000..9459fb3e09 --- /dev/null +++ b/tests/ut/data/dataset/testCLUE/csl/test.json @@ -0,0 +1,3 @@ +{"id": 2415, "abst": "长文本1", "keyword": ["关键词1", "关键词2"]} +{"id": 2565, "abst": "长文本2", "keyword": ["关键词1", "关键词2", "关键词3"]} +{"id": 2625, "abst": "长文本3", "keyword": ["关键词1", "关键词2", "关键词3", "关键词4"]} diff --git a/tests/ut/data/dataset/testCLUE/csl/train.json b/tests/ut/data/dataset/testCLUE/csl/train.json new file mode 100644 index 0000000000..8e16f5b774 --- /dev/null +++ b/tests/ut/data/dataset/testCLUE/csl/train.json @@ -0,0 +1,3 @@ +{"id": 1, "abst": "这是一段长文本", "keyword": ["关键词1", "关键词2", "关键词3", "关键词4"], "label": "0"} +{"id": 2, "abst": "这是一段长文本", "keyword": ["关键词5", "关键词6", "关键词7", "关键词8"], "label": "0"} +{"id": 3, "abst": "这是一段长文本", "keyword": ["关键词9", "关键词10", "关键词11", "关键词12"], "label": "0"} diff --git a/tests/ut/data/dataset/testCLUE/iflytek/dev.json b/tests/ut/data/dataset/testCLUE/iflytek/dev.json new file mode 100644 index 
0000000000..95c8069a8a --- /dev/null +++ b/tests/ut/data/dataset/testCLUE/iflytek/dev.json @@ -0,0 +1,3 @@ +{"label": "110", "label_des": "社区超市", "sentence": "这是第一段文本"} +{"label": "70", "label_des": "工具", "sentence": "这是第二段文本"} +{"label": "10", "label_des": "社区服务", "sentence": "这是第三段文本"} diff --git a/tests/ut/data/dataset/testCLUE/iflytek/test.json b/tests/ut/data/dataset/testCLUE/iflytek/test.json new file mode 100644 index 0000000000..a7bf2bad7a --- /dev/null +++ b/tests/ut/data/dataset/testCLUE/iflytek/test.json @@ -0,0 +1,3 @@ +{"id": 0, "sentence": "文本1"} +{"id": 1, "sentence": "文本2"} +{"id": 2, "sentence": "文本3"} diff --git a/tests/ut/data/dataset/testCLUE/iflytek/train.json b/tests/ut/data/dataset/testCLUE/iflytek/train.json new file mode 100644 index 0000000000..786749bcb6 --- /dev/null +++ b/tests/ut/data/dataset/testCLUE/iflytek/train.json @@ -0,0 +1,3 @@ +{"label": "11", "label_des": "薅羊毛", "sentence": "第一个文本"} +{"label": "95", "label_des": "借贷", "sentence": "第二个文本"} +{"label": "74", "label_des": "违章", "sentence": "第三个文本"} diff --git a/tests/ut/data/dataset/testCLUE/tnews/dev.json b/tests/ut/data/dataset/testCLUE/tnews/dev.json new file mode 100644 index 0000000000..0363cee745 --- /dev/null +++ b/tests/ut/data/dataset/testCLUE/tnews/dev.json @@ -0,0 +1,3 @@ +{"label": "102", "label_desc": "news_entertainment", "sentence": "新闻1", "keywords": "关键词一,关键词二,关键词三,关键词四"} +{"label": "110", "label_desc": "news_military", "sentence": "新闻2", "keywords": "关键词一,关键词二,关键词三,关键词四,关键词五"} +{"label": "104", "label_desc": "news_finance", "sentence": "新闻3", "keywords": "关键词一,关键词二,关键词三,关键词四,关键词五"} diff --git a/tests/ut/data/dataset/testCLUE/tnews/test.json b/tests/ut/data/dataset/testCLUE/tnews/test.json new file mode 100644 index 0000000000..39e36d91e2 --- /dev/null +++ b/tests/ut/data/dataset/testCLUE/tnews/test.json @@ -0,0 +1,3 @@ +{"id": 0, "sentence": "新闻1", "keywords": "关键词1,关键词2,关键词3,关键词4,关键词5"} +{"id": 1, "sentence": "新闻2", "keywords": "关键词1,关键词2,关键词3,关键词4"} +{"id": 2, "sentence": "新闻3", "keywords": ""} diff --git a/tests/ut/data/dataset/testCLUE/tnews/train.json b/tests/ut/data/dataset/testCLUE/tnews/train.json new file mode 100644 index 0000000000..de784f1f82 --- /dev/null +++ b/tests/ut/data/dataset/testCLUE/tnews/train.json @@ -0,0 +1,3 @@ +{"label": "108", "label_desc": "news_edu", "sentence": "新闻1", "keywords": ""} +{"label": "104", "label_desc": "news_finance", "sentence": "新闻2", "keywords": "关键词1,关键词2,关键词3,关键词4,关键词5,关键词6"} +{"label": "106", "label_desc": "news_house", "sentence": "新闻3", "keywords": ""} diff --git a/tests/ut/data/dataset/testCLUE/wsc/dev.json b/tests/ut/data/dataset/testCLUE/wsc/dev.json new file mode 100755 index 0000000000..57203a7fc9 --- /dev/null +++ b/tests/ut/data/dataset/testCLUE/wsc/dev.json @@ -0,0 +1,3 @@ +{"target": {"span1_index": 0, "span1_text": "小明", "span2_index": 4, "span2_text": "他"}, "idx": 0, "text": "小明呢,他在哪?", "label": "true"} +{"target": {"span1_index": 0, "span1_text": "小红", "span2_index": 9, "span2_text": "他"}, "idx": 1, "text": "小红刚刚看到小明,他在操场", "label": "false"} +{"target": {"span1_index": 6, "span1_text": "小张", "span2_index": 8, "span2_text": "你"}, "idx": 2, "text": "等小明回来,小张你叫他交作业", "label": "true"} diff --git a/tests/ut/data/dataset/testCLUE/wsc/test.json b/tests/ut/data/dataset/testCLUE/wsc/test.json new file mode 100755 index 0000000000..c8e17d3e4c --- /dev/null +++ b/tests/ut/data/dataset/testCLUE/wsc/test.json @@ -0,0 +1,3 @@ +{"target": {"span1_index": 0, "span1_text": "小明", "span2_index": 4, "span2_text": "他"}, "idx": 0, "text": 
"小明呢,他在哪?"} +{"target": {"span1_index": 0, "span1_text": "小红", "span2_index": 9, "span2_text": "他"}, "idx": 1, "text": "小红刚刚看到小明,他在操场"} +{"target": {"span1_index": 6, "span1_text": "小张", "span2_index": 8, "span2_text": "你"}, "idx": 2, "text": "等小明回来,小张你叫他交作业"} diff --git a/tests/ut/data/dataset/testCLUE/wsc/train.json b/tests/ut/data/dataset/testCLUE/wsc/train.json new file mode 100755 index 0000000000..57203a7fc9 --- /dev/null +++ b/tests/ut/data/dataset/testCLUE/wsc/train.json @@ -0,0 +1,3 @@ +{"target": {"span1_index": 0, "span1_text": "小明", "span2_index": 4, "span2_text": "他"}, "idx": 0, "text": "小明呢,他在哪?", "label": "true"} +{"target": {"span1_index": 0, "span1_text": "小红", "span2_index": 9, "span2_text": "他"}, "idx": 1, "text": "小红刚刚看到小明,他在操场", "label": "false"} +{"target": {"span1_index": 6, "span1_text": "小张", "span2_index": 8, "span2_text": "你"}, "idx": 2, "text": "等小明回来,小张你叫他交作业", "label": "true"} diff --git a/tests/ut/python/dataset/test_datasets_clue.py b/tests/ut/python/dataset/test_datasets_clue.py new file mode 100644 index 0000000000..c49db45abe --- /dev/null +++ b/tests/ut/python/dataset/test_datasets_clue.py @@ -0,0 +1,355 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import mindspore.dataset as ds + + +def test_clue(): + """ + Test CLUE with repeat, skip and so on + """ + TRAIN_FILE = '../data/dataset/testCLUE/afqmc/train.json' + + buffer = [] + data = ds.CLUEDataset(TRAIN_FILE, task='AFQMC', usage='train', shuffle=False) + data = data.repeat(2) + data = data.skip(3) + for d in data.create_dict_iterator(): + buffer.append({ + 'label': d['label'].item().decode("utf8"), + 'sentence1': d['sentence1'].item().decode("utf8"), + 'sentence2': d['sentence2'].item().decode("utf8") + }) + assert len(buffer) == 3 + + +def test_clue_num_shards(): + """ + Test num_shards param of CLUE dataset + """ + TRAIN_FILE = '../data/dataset/testCLUE/afqmc/train.json' + + buffer = [] + data = ds.CLUEDataset(TRAIN_FILE, task='AFQMC', usage='train', num_shards=3, shard_id=1) + for d in data.create_dict_iterator(): + buffer.append({ + 'label': d['label'].item().decode("utf8"), + 'sentence1': d['sentence1'].item().decode("utf8"), + 'sentence2': d['sentence2'].item().decode("utf8") + }) + assert len(buffer) == 1 + + +def test_clue_num_samples(): + """ + Test num_samples param of CLUE dataset + """ + TRAIN_FILE = '../data/dataset/testCLUE/afqmc/train.json' + + data = ds.CLUEDataset(TRAIN_FILE, task='AFQMC', usage='train', num_samples=2) + count = 0 + for _ in data.create_dict_iterator(): + count += 1 + assert count == 2 + + +def test_textline_dataset_get_datasetsize(): + """ + Test get_dataset_size of CLUE dataset + """ + TRAIN_FILE = '../data/dataset/testCLUE/afqmc/train.json' + + data = ds.TextFileDataset(TRAIN_FILE) + size = data.get_dataset_size() + assert size == 3 + + +def test_clue_afqmc(): + """ + Test AFQMC for train, test and evaluation + """ + TRAIN_FILE = 
'../data/dataset/testCLUE/afqmc/train.json' + TEST_FILE = '../data/dataset/testCLUE/afqmc/test.json' + EVAL_FILE = '../data/dataset/testCLUE/afqmc/dev.json' + + # train + buffer = [] + data = ds.CLUEDataset(TRAIN_FILE, task='AFQMC', usage='train', shuffle=False) + for d in data.create_dict_iterator(): + buffer.append({ + 'label': d['label'].item().decode("utf8"), + 'sentence1': d['sentence1'].item().decode("utf8"), + 'sentence2': d['sentence2'].item().decode("utf8") + }) + assert len(buffer) == 3 + + # test + buffer = [] + data = ds.CLUEDataset(TEST_FILE, task='AFQMC', usage='test', shuffle=False) + for d in data.create_dict_iterator(): + buffer.append({ + 'id': d['id'], + 'sentence1': d['sentence1'].item().decode("utf8"), + 'sentence2': d['sentence2'].item().decode("utf8") + }) + assert len(buffer) == 3 + + # evaluation + buffer = [] + data = ds.CLUEDataset(EVAL_FILE, task='AFQMC', usage='eval', shuffle=False) + for d in data.create_dict_iterator(): + buffer.append({ + 'label': d['label'].item().decode("utf8"), + 'sentence1': d['sentence1'].item().decode("utf8"), + 'sentence2': d['sentence2'].item().decode("utf8") + }) + assert len(buffer) == 3 + + +def test_clue_cmnli(): + """ + Test CMNLI for train, test and evaluation + """ + TRAIN_FILE = '../data/dataset/testCLUE/cmnli/train.json' + TEST_FILE = '../data/dataset/testCLUE/cmnli/test.json' + EVAL_FILE = '../data/dataset/testCLUE/cmnli/dev.json' + + # train + buffer = [] + data = ds.CLUEDataset(TRAIN_FILE, task='CMNLI', usage='train', shuffle=False) + for d in data.create_dict_iterator(): + buffer.append({ + 'label': d['label'].item().decode("utf8"), + 'sentence1': d['sentence1'].item().decode("utf8"), + 'sentence2': d['sentence2'].item().decode("utf8") + }) + assert len(buffer) == 3 + + # test + buffer = [] + data = ds.CLUEDataset(TEST_FILE, task='CMNLI', usage='test', shuffle=False) + for d in data.create_dict_iterator(): + buffer.append({ + 'id': d['id'], + 'sentence1': d['sentence1'], + 'sentence2': d['sentence2'] + }) + assert len(buffer) == 3 + + # eval + buffer = [] + data = ds.CLUEDataset(EVAL_FILE, task='CMNLI', usage='eval', shuffle=False) + for d in data.create_dict_iterator(): + buffer.append({ + 'label': d['label'], + 'sentence1': d['sentence1'], + 'sentence2': d['sentence2'] + }) + assert len(buffer) == 3 + + +def test_clue_csl(): + """ + Test CSL for train, test and evaluation + """ + TRAIN_FILE = '../data/dataset/testCLUE/csl/train.json' + TEST_FILE = '../data/dataset/testCLUE/csl/test.json' + EVAL_FILE = '../data/dataset/testCLUE/csl/dev.json' + + # train + buffer = [] + data = ds.CLUEDataset(TRAIN_FILE, task='CSL', usage='train', shuffle=False) + for d in data.create_dict_iterator(): + buffer.append({ + 'id': d['id'], + 'abst': d['abst'].item().decode("utf8"), + 'keyword': [i.item().decode("utf8") for i in d['keyword']], + 'label': d['label'].item().decode("utf8") + }) + assert len(buffer) == 3 + + # test + buffer = [] + data = ds.CLUEDataset(TEST_FILE, task='CSL', usage='test', shuffle=False) + for d in data.create_dict_iterator(): + buffer.append({ + 'id': d['id'], + 'abst': d['abst'].item().decode("utf8"), + 'keyword': [i.item().decode("utf8") for i in d['keyword']], + }) + assert len(buffer) == 3 + + # eval + buffer = [] + data = ds.CLUEDataset(EVAL_FILE, task='CSL', usage='eval', shuffle=False) + for d in data.create_dict_iterator(): + buffer.append({ + 'id': d['id'], + 'abst': d['abst'].item().decode("utf8"), + 'keyword': [i.item().decode("utf8") for i in d['keyword']], + 'label': d['label'].item().decode("utf8") + 
}) + assert len(buffer) == 3 + + +def test_clue_iflytek(): + """ + Test IFLYTEK for train, test and evaluation + """ + TRAIN_FILE = '../data/dataset/testCLUE/iflytek/train.json' + TEST_FILE = '../data/dataset/testCLUE/iflytek/test.json' + EVAL_FILE = '../data/dataset/testCLUE/iflytek/dev.json' + + # train + buffer = [] + data = ds.CLUEDataset(TRAIN_FILE, task='IFLYTEK', usage='train', shuffle=False) + for d in data.create_dict_iterator(): + buffer.append({ + 'label': d['label'].item().decode("utf8"), + 'label_des': d['label_des'].item().decode("utf8"), + 'sentence': d['sentence'].item().decode("utf8"), + }) + assert len(buffer) == 3 + + # test + buffer = [] + data = ds.CLUEDataset(TEST_FILE, task='IFLYTEK', usage='test', shuffle=False) + for d in data.create_dict_iterator(): + buffer.append({ + 'id': d['id'], + 'sentence': d['sentence'].item().decode("utf8") + }) + assert len(buffer) == 3 + + # eval + buffer = [] + data = ds.CLUEDataset(EVAL_FILE, task='IFLYTEK', usage='eval', shuffle=False) + for d in data.create_dict_iterator(): + buffer.append({ + 'label': d['label'].item().decode("utf8"), + 'label_des': d['label_des'].item().decode("utf8"), + 'sentence': d['sentence'].item().decode("utf8") + }) + assert len(buffer) == 3 + + +def test_clue_tnews(): + """ + Test TNEWS for train, test and evaluation + """ + TRAIN_FILE = '../data/dataset/testCLUE/tnews/train.json' + TEST_FILE = '../data/dataset/testCLUE/tnews/test.json' + EVAL_FILE = '../data/dataset/testCLUE/tnews/dev.json' + + # train + buffer = [] + data = ds.CLUEDataset(TRAIN_FILE, task='TNEWS', usage='train', shuffle=False) + for d in data.create_dict_iterator(): + buffer.append({ + 'label': d['label'].item().decode("utf8"), + 'label_desc': d['label_desc'].item().decode("utf8"), + 'sentence': d['sentence'].item().decode("utf8"), + 'keywords': + d['keywords'].item().decode("utf8") if d['keywords'].size > 0 else d['keywords'] + }) + assert len(buffer) == 3 + + # test + buffer = [] + data = ds.CLUEDataset(TEST_FILE, task='TNEWS', usage='test', shuffle=False) + for d in data.create_dict_iterator(): + buffer.append({ + 'id': d['id'], + 'sentence': d['sentence'].item().decode("utf8"), + 'keywords': + d['keywords'].item().decode("utf8") if d['keywords'].size > 0 else d['keywords'] + }) + assert len(buffer) == 3 + + # eval + buffer = [] + data = ds.CLUEDataset(EVAL_FILE, task='TNEWS', usage='eval', shuffle=False) + for d in data.create_dict_iterator(): + buffer.append({ + 'label': d['label'].item().decode("utf8"), + 'label_desc': d['label_desc'].item().decode("utf8"), + 'sentence': d['sentence'].item().decode("utf8"), + 'keywords': + d['keywords'].item().decode("utf8") if d['keywords'].size > 0 else d['keywords'] + }) + assert len(buffer) == 3 + + +def test_clue_wsc(): + """ + Test WSC for train, test and evaluation + """ + TRAIN_FILE = '../data/dataset/testCLUE/wsc/train.json' + TEST_FILE = '../data/dataset/testCLUE/wsc/test.json' + EVAL_FILE = '../data/dataset/testCLUE/wsc/dev.json' + + # train + buffer = [] + data = ds.CLUEDataset(TRAIN_FILE, task='WSC', usage='train') + for d in data.create_dict_iterator(): + buffer.append({ + 'span1_index': d['span1_index'], + 'span2_index': d['span2_index'], + 'span1_text': d['span1_text'].item().decode("utf8"), + 'span2_text': d['span2_text'].item().decode("utf8"), + 'idx': d['idx'], + 'label': d['label'].item().decode("utf8"), + 'text': d['text'].item().decode("utf8") + }) + assert len(buffer) == 3 + + # test + buffer = [] + data = ds.CLUEDataset(TEST_FILE, task='WSC', usage='test') + for d in 
data.create_dict_iterator():
+        buffer.append({
+            'span1_index': d['span1_index'],
+            'span2_index': d['span2_index'],
+            'span1_text': d['span1_text'].item().decode("utf8"),
+            'span2_text': d['span2_text'].item().decode("utf8"),
+            'idx': d['idx'],
+            'text': d['text'].item().decode("utf8")
+        })
+    assert len(buffer) == 3
+
+    # eval
+    buffer = []
+    data = ds.CLUEDataset(EVAL_FILE, task='WSC', usage='eval')
+    for d in data.create_dict_iterator():
+        buffer.append({
+            'span1_index': d['span1_index'],
+            'span2_index': d['span2_index'],
+            'span1_text': d['span1_text'].item().decode("utf8"),
+            'span2_text': d['span2_text'].item().decode("utf8"),
+            'idx': d['idx'],
+            'label': d['label'].item().decode("utf8"),
+            'text': d['text'].item().decode("utf8")
+        })
+    assert len(buffer) == 3
+
+
+if __name__ == "__main__":
+    test_clue()
+    test_clue_num_shards()
+    test_clue_num_samples()
+    test_clue_dataset_get_datasetsize()
+    test_clue_afqmc()
+    test_clue_cmnli()
+    test_clue_csl()
+    test_clue_iflytek()
+    test_clue_tnews()
+    test_clue_wsc()
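
Note for reviewers: a minimal usage sketch of the Python API this patch adds, mirroring what the unit tests above exercise. The file path below is hypothetical; task, usage, and shuffle are the same parameters the tests pass, and the string fields are decoded exactly as the tests decode them.

import mindspore.dataset as ds

# Hypothetical location of a downloaded CLUE AFQMC split; substitute your own path.
TRAIN_FILE = "./clue/afqmc/train.json"

# Build the dataset the same way test_clue_afqmc() does: one JSON-lines file,
# the AFQMC task, and usage='train' so sentence1/sentence2/label columns are produced.
data = ds.CLUEDataset(TRAIN_FILE, task='AFQMC', usage='train', shuffle=False)

# Rows come back as dicts of numpy arrays; decode the string tensors to UTF-8.
for row in data.create_dict_iterator():
    print(row['sentence1'].item().decode("utf8"),
          row['sentence2'].item().decode("utf8"),
          row['label'].item().decode("utf8"))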