CSV dataset loader

pull/3016/head
jiangzhiwen 5 years ago
parent 5d42d00161
commit 2f506b7985

@ -31,6 +31,7 @@
#include "minddata/dataset/engine/datasetops/source/celeba_op.h"
#include "minddata/dataset/engine/datasetops/source/cifar_op.h"
#include "minddata/dataset/engine/datasetops/source/clue_op.h"
#include "minddata/dataset/engine/datasetops/source/csv_op.h"
#include "minddata/dataset/engine/datasetops/source/coco_op.h"
#include "minddata/dataset/engine/datasetops/source/image_folder_op.h"
#include "minddata/dataset/engine/datasetops/source/manifest_op.h"
@ -88,6 +89,7 @@ static std::unordered_map<uint32_t, pFunction> g_parse_op_func_ = {
{kBuildVocab, &DEPipeline::ParseBuildVocabOp},
{kClue, &DEPipeline::ParseClueOp},
{kEpochCtrl, &DEPipeline::ParseEpochCtrlOp},
{kCsv, &DEPipeline::ParseCsvOp},
{kSentencePieceVocab, &DEPipeline::ParseBuildSentencePieceVocabOp}};
DEPipeline::DEPipeline() : iterator_(nullptr) {
@ -1838,6 +1840,86 @@ Status DEPipeline::AddCacheOp(std::shared_ptr<CacheClient> cache_client, int num
return Status::OK();
}
// Translate the python-side CSVDataset argument dict into a CsvOp leaf node.
// On success *top points at the node feeding the rest of the tree; when a
// global shuffle is requested a ShuffleOp is injected above the CsvOp and the
// pair is returned through *top / *bottom.
Status DEPipeline::ParseCsvOp(const py::dict &args, std::shared_ptr<DatasetOp> *top,
                              std::shared_ptr<DatasetOp> *bottom) {
  std::vector<std::string> files_list;
  std::shared_ptr<CsvOp::Builder> builder = std::make_shared<CsvOp::Builder>();
  // Required argument: the list of CSV files to read.
  if (!args["dataset_files"].is_none()) {
    files_list = ToStringVector(args["dataset_files"]);
    (void)builder->SetCsvFilesList(files_list);
  } else {
    RETURN_STATUS_UNEXPECTED("Error: dataset_files is missing");
  }

  // Optional arguments
  bool shuffle_required = false;
  int64_t num_devices = 0;
  std::vector<std::string> col_names;
  for (auto arg : args) {
    std::string key = py::str(arg.first);
    py::handle value = arg.second;
    if (!value.is_none()) {
      if (key == "num_parallel_workers") {
        (void)builder->SetNumWorkers(ToInt(value));
      } else if (key == "shuffle_files") {
        (void)builder->SetShuffleFiles(ToBool(value));
      } else if (key == "shuffle_global") {
        shuffle_required = ToBool(value);
      } else if (key == "num_samples") {
        (void)builder->SetNumSamples(ToInt(value));
      } else if (key == "num_shards") {
        num_devices = ToInt(value);
        (void)builder->SetNumDevices(num_devices);
      } else if (key == "shard_id") {
        (void)builder->SetDeviceId(ToInt(value));
      } else if (key == "field_delim") {
        // Guard before taking [0]: indexing an empty std::string is
        // out-of-bounds, so reject an empty delimiter explicitly.
        std::string field_delim = ToString(value);
        if (field_delim.empty()) {
          RETURN_STATUS_UNEXPECTED("Error: field_delim is empty");
        }
        (void)builder->SetFieldDelim(field_delim[0]);
      } else if (key == "column_defaults") {
        // Each python default's runtime type selects the typed Record that
        // the CsvOp will use to parse that column.
        py::list py_object_list = py::reinterpret_borrow<py::list>(value);
        std::vector<std::shared_ptr<CsvOp::BaseRecord>> column_default_list;
        for (auto l : py_object_list) {
          std::string type_s = (std::string)py::str(l.get_type().attr("__name__"));
          if (type_s == "int") {
            column_default_list.push_back(std::make_shared<CsvOp::Record<int>>(CsvOp::INT, ToInt(l)));
          } else if (type_s == "float") {
            column_default_list.push_back(std::make_shared<CsvOp::Record<float>>(CsvOp::FLOAT, ToFloat(l)));
          } else if (type_s == "str") {
            column_default_list.push_back(std::make_shared<CsvOp::Record<std::string>>(CsvOp::STRING, ToString(l)));
          } else {
            RETURN_STATUS_UNEXPECTED("Record type is not allowed");
          }
        }
        (void)builder->SetColumDefault(column_default_list);
      } else if (key == "column_names") {
        col_names = ToStringVector(value);
        (void)builder->SetColumName(col_names);
      }
    }
  }

  std::shared_ptr<CsvOp> csv_op;
  RETURN_IF_NOT_OK(builder->Build(&csv_op));
  RETURN_IF_NOT_OK(tree_->AssociateNode(csv_op));
  *top = csv_op;

  if (shuffle_required) {
    std::shared_ptr<DatasetOp> shuffle_op = nullptr;
    int64_t shuffle_size = 0;
    int64_t num_rows = 0;

    // First, get the number of rows in the dataset and then compute the shuffle size.
    // When no column names were given, the first CSV row is a header
    // (col_names.empty() doubles as the csv_header flag).
    RETURN_IF_NOT_OK(CsvOp::CountAllFileRows(files_list, col_names.empty(), &num_rows));
    RETURN_IF_NOT_OK(ComputeShuffleSize(files_list.size(), num_devices, num_rows, 0, &shuffle_size));

    // Add the shuffle op over top of this op and return the subtree (top/bottom) to caller
    RETURN_IF_NOT_OK(AddShuffleOp(shuffle_size, csv_op, &shuffle_op));
    *top = shuffle_op;
    *bottom = csv_op;
  }
  return Status::OK();
}
// Helper function to inject a shuffle operator over top of the current operation being built.
Status DEPipeline::AddShuffleOp(int64_t shuffle_size, std::shared_ptr<DatasetOp> input_op,
std::shared_ptr<DatasetOp> *shuffle_op) {

@ -73,6 +73,7 @@ enum OpName {
kClue,
kEpochCtrl,
kSentencePieceVocab,
kCsv
};
// The C++ binder class that we expose to the python script.
@ -201,6 +202,8 @@ class DEPipeline {
Status ParseClueOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
Status ParseCsvOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
private:
// Execution tree that links the dataset operators.
std::shared_ptr<ExecutionTree> tree_;

@ -19,6 +19,7 @@
#include "minddata/dataset/engine/cache/cache_client.h"
#include "minddata/dataset/engine/datasetops/source/cifar_op.h"
#include "minddata/dataset/engine/datasetops/source/clue_op.h"
#include "minddata/dataset/engine/datasetops/source/csv_op.h"
#include "minddata/dataset/engine/datasetops/source/coco_op.h"
#include "minddata/dataset/engine/datasetops/source/image_folder_op.h"
#include "minddata/dataset/engine/datasetops/source/io_block.h"
@ -277,6 +278,17 @@ void bindDatasetOps(py::module *m) {
return count;
});
(void)py::class_<CsvOp, DatasetOp, std::shared_ptr<CsvOp>>(*m, "CsvOp")
.def_static("get_num_rows", [](const py::list &files, bool csv_header) {
int64_t count = 0;
std::vector<std::string> filenames;
for (auto file : files) {
file.is_none() ? (void)filenames.emplace_back("") : filenames.push_back(py::str(file));
}
THROW_IF_ERROR(CsvOp::CountAllFileRows(filenames, csv_header, &count));
return count;
});
(void)py::class_<VOCOp, DatasetOp, std::shared_ptr<VOCOp>>(*m, "VOCOp")
.def_static("get_num_rows",
[](const std::string &dir, const std::string &task_type, const std::string &task_mode,
@ -1039,8 +1051,9 @@ PYBIND11_MODULE(_c_dataengine, m) {
.value("SENTENCEPIECEVOCAB", OpName::kSentencePieceVocab)
.value("CELEBA", OpName::kCelebA)
.value("TEXTFILE", OpName::kTextFile)
.value("CLUE", OpName::kClue)
.value("EPOCHCTRL", OpName::kEpochCtrl);
.value("EPOCHCTRL", OpName::kEpochCtrl)
.value("CSV", OpName::kCsv)
.value("CLUE", OpName::kClue);
(void)py::enum_<JiebaMode>(m, "JiebaMode", py::arithmetic())
.value("DE_JIEBA_MIX", JiebaMode::kMix)

@ -12,6 +12,7 @@ set(DATASET_ENGINE_DATASETOPS_SOURCE_SRC_FILES
celeba_op.cc
text_file_op.cc
clue_op.cc
csv_op.cc
)
set(DATASET_ENGINE_DATASETOPS_SOURCE_SRC_FILES
@ -29,4 +30,4 @@ if (ENABLE_PYTHON)
)
endif()
add_library(engine-datasetops-source OBJECT ${DATASET_ENGINE_DATASETOPS_SOURCE_SRC_FILES})
add_library(engine-datasetops-source OBJECT ${DATASET_ENGINE_DATASETOPS_SOURCE_SRC_FILES})

@ -21,7 +21,7 @@ can also create samplers with this module to sample data.
from .core import config
from .engine.datasets import TFRecordDataset, ImageFolderDatasetV2, MnistDataset, MindDataset, NumpySlicesDataset, \
GeneratorDataset, ManifestDataset, Cifar10Dataset, Cifar100Dataset, VOCDataset, CocoDataset, CelebADataset,\
TextFileDataset, CLUEDataset, Schema, Shuffle, zip, RandomDataset
TextFileDataset, CLUEDataset, CSVDataset, Schema, Shuffle, zip, RandomDataset
from .engine.samplers import DistributedSampler, PKSampler, RandomSampler, SequentialSampler, SubsetRandomSampler, \
WeightedRandomSampler, Sampler
from .engine.cache_client import DatasetCache
@ -31,5 +31,5 @@ from .engine.graphdata import GraphData
__all__ = ["config", "ImageFolderDatasetV2", "MnistDataset",
"MindDataset", "GeneratorDataset", "TFRecordDataset",
"ManifestDataset", "Cifar10Dataset", "Cifar100Dataset", "CelebADataset", "NumpySlicesDataset", "VOCDataset",
"CocoDataset", "TextFileDataset", "CLUEDataset", "Schema", "DistributedSampler", "PKSampler",
"CocoDataset", "TextFileDataset", "CLUEDataset", "CSVDataset", "Schema", "DistributedSampler", "PKSampler",
"RandomSampler", "SequentialSampler", "SubsetRandomSampler", "WeightedRandomSampler", "zip", "GraphData"]

@ -29,7 +29,7 @@ from .samplers import *
from ..core import config
__all__ = ["config", "zip", "ImageFolderDatasetV2", "MnistDataset",
"MindDataset", "GeneratorDataset", "TFRecordDataset", "CLUEDataset",
"MindDataset", "GeneratorDataset", "TFRecordDataset", "CLUEDataset", "CSVDataset",
"ManifestDataset", "Cifar10Dataset", "Cifar100Dataset", "CelebADataset",
"VOCDataset", "CocoDataset", "TextFileDataset", "Schema", "DistributedSampler",
"PKSampler", "RandomSampler", "SequentialSampler", "SubsetRandomSampler", "WeightedRandomSampler"]

@ -33,7 +33,7 @@ import copy
import numpy as np
from mindspore._c_dataengine import DataType, TFReaderOp, ImageFolderOp, CifarOp, MnistOp, ManifestOp, \
MindRecordOp, TextFileOp, ClueOp, VOCOp, CocoOp, CBatchInfo
MindRecordOp, TextFileOp, ClueOp, CsvOp, VOCOp, CocoOp, CBatchInfo
from mindspore._c_expression import typing
from mindspore import log as logger
@ -44,7 +44,7 @@ from .validators import check_batch, check_shuffle, check_map, check_filter, che
check_take, check_project, check_imagefolderdatasetv2, check_mnist_cifar_dataset, check_manifestdataset, \
check_tfrecorddataset, check_vocdataset, check_cocodataset, check_celebadataset, check_minddataset, \
check_generatordataset, check_sync_wait, check_zip_dataset, check_add_column, check_textfiledataset, check_concat, \
check_random_dataset, check_split, check_bucket_batch_by_length, check_cluedataset, check_save
check_random_dataset, check_split, check_bucket_batch_by_length, check_cluedataset, check_save, check_csvdataset
from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist
from ..text.utils import DE_C_INTER_SENTENCEPIECE_MODE
@ -1012,7 +1012,7 @@ class Dataset:
if isinstance(sampler, samplers.DistributedSampler):
dev_id = sampler.shard_id
return "", dev_id
if isinstance(output_dataset, (TFRecordDataset, TextFileDataset, CLUEDataset)):
if isinstance(output_dataset, (TFRecordDataset, TextFileDataset, CLUEDataset, CSVDataset)):
if output_dataset.shard_id is not None:
dev_id = output_dataset.shard_id
return "", dev_id
@ -4652,8 +4652,8 @@ class CLUEDataset(SourceDataset):
}
Args:
dataset_files (str or list[str]): String or list of files to be read or glob strings to search for a pattern of
files. The list will be sorted in a lexicographical order.
dataset_files (str or a list of strings): String or list of files to be read or glob strings to search for
a pattern of files. The list will be sorted in a lexicographical order.
task (str, optional): The kind of task, one of 'AFQMC', 'TNEWS', 'IFLYTEK', 'CMNLI', 'WSC' and 'CSL'.
(default=AFQMC).
usage (str, optional): Need train, test or eval data (default="train").
@ -4860,6 +4860,108 @@ class CLUEDataset(SourceDataset):
return False
class CSVDataset(SourceDataset):
    """
    A source dataset that reads and parses CSV datasets.

    Args:
        dataset_files (str or a list of strings): String or list of files to be read or glob strings to search
            for a pattern of files. The list will be sorted in a lexicographical order.
        field_delim (str, optional): A string that indicates the char delimiter to separate fields (default=',').
        column_defaults (list, optional): List of default values for the CSV field (default=None). Each item
            in the list is either a valid type (float, int, or string). If this is not provided, treats all
            columns as string type.
        column_names (list of string, optional): List of column names of the dataset (default=None). If this
            is not provided, infers the column_names from the first row of CSV file.
        num_samples (int, optional): number of samples(rows) to read (default=None, reads the full dataset).
        num_parallel_workers (int, optional): number of workers to read the data
            (default=None, number set in the config).
        shuffle (bool, Shuffle level, optional): perform reshuffling of the data every epoch (default=Shuffle.GLOBAL).
            If shuffle is False, no shuffling will be performed;
            If shuffle is True, the behavior is the same as setting shuffle to be Shuffle.GLOBAL
            Otherwise, there are two levels of shuffling:

            - Shuffle.GLOBAL: Shuffle both the files and samples.
            - Shuffle.FILES: Shuffle files only.

        num_shards (int, optional): Number of shards that the dataset should be divided into (default=None).
        shard_id (int, optional): The shard ID within num_shards (default=None). This
            argument should be specified only when num_shards is also specified.

    Examples:
        >>> import mindspore.dataset as ds
        >>> dataset_files = ["/path/to/1", "/path/to/2"] # contains 1 or multiple text files
        >>> dataset = ds.CSVDataset(dataset_files=dataset_files, column_names=['col1', 'col2', 'col3', 'col4'])
    """

    @check_csvdataset
    def __init__(self, dataset_files, field_delim=',', column_defaults=None, column_names=None, num_samples=None,
                 num_parallel_workers=None, shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None):
        super().__init__(num_parallel_workers)
        # Expand globs and pin a deterministic (lexicographic) file order.
        self.dataset_files = self._find_files(dataset_files)
        self.dataset_files.sort()
        self.field_delim = field_delim
        self.column_defaults = column_defaults
        self.column_names = column_names
        self.num_samples = num_samples

        if not isinstance(shuffle, (bool, Shuffle)):
            raise TypeError("shuffle should be of boolean or enum 'Shuffle'.")
        # Normalize the user-facing shuffle argument into the internal pair
        # (shuffle_level, shuffle_files): True maps to a global shuffle,
        # False disables shuffling entirely.
        if isinstance(shuffle, Shuffle):
            self.shuffle_level = shuffle
            self.shuffle_files = True
        elif shuffle:
            self.shuffle_level = Shuffle.GLOBAL
            self.shuffle_files = True
        else:
            self.shuffle_level = None
            self.shuffle_files = False

        self.num_shards = num_shards
        self.shard_id = shard_id

    def get_args(self):
        """Assemble the argument dict consumed by the C++ ParseCsvOp."""
        args = super().get_args()
        args.update({
            "dataset_files": self.dataset_files,
            "field_delim": self.field_delim,
            "column_defaults": self.column_defaults,
            "column_names": self.column_names,
            "num_samples": self.num_samples,
        })
        if self.shuffle_files is not None:
            args["shuffle_files"] = self.shuffle_files
        args["shuffle_global"] = (self.shuffle_level == Shuffle.GLOBAL)
        args["shuffle"] = self.shuffle_level
        args["num_shards"] = self.num_shards
        args["shard_id"] = self.shard_id
        return args

    def get_dataset_size(self):
        """
        Get the number of batches in an epoch.

        Return:
            Number, number of batches.
        """
        if self._dataset_size is not None:
            return self._dataset_size
        # No column_names means the first CSV row is a header and is not data.
        num_rows = CsvOp.get_num_rows(self.dataset_files, self.column_names is None)
        num_rows = get_num_rows(num_rows, self.num_shards)
        if self.num_samples is None:
            return num_rows
        return min(self.num_samples, num_rows)

    def is_shuffled(self):
        return self.shuffle_files

    def is_sharded(self):
        if self.num_shards is not None:
            return self.num_shards > 1
        return False
class TextFileDataset(SourceDataset):
"""
A source dataset that reads and parses datasets stored on disk in text format.

@ -185,6 +185,8 @@ class Iterator:
op_type = OpName.SENTENCEPIECEVOCAB
elif isinstance(dataset, de.CLUEDataset):
op_type = OpName.CLUE
elif isinstance(dataset, de.CSVDataset):
op_type = OpName.CSV
else:
raise ValueError("Unsupported DatasetOp")

@ -787,6 +787,49 @@ def check_cluedataset(method):
return new_method
def check_csvdataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset(CSVDataset) constructor."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)

        nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']

        # check dataset_files; required argument
        dataset_files = param_dict.get('dataset_files')
        type_check(dataset_files, (str, list), "dataset files")

        # check field_delim: must be exactly one character (an empty string
        # previously slipped through `len(...) > 1` and crashed in the C++
        # layer), and must not be a quote or record separator.
        field_delim = param_dict.get('field_delim')
        type_check(field_delim, (str,), 'field delim')
        if field_delim in ['"', '\r', '\n'] or len(field_delim) != 1:
            raise ValueError("field_delim is not legal.")

        # check column_defaults: optional list of str/int/float defaults
        column_defaults = param_dict.get('column_defaults')
        if column_defaults is not None:
            if not isinstance(column_defaults, list):
                raise TypeError("column_defaults should be type of list.")
            for item in column_defaults:
                if not isinstance(item, (str, int, float)):
                    raise TypeError("column type is not legal.")

        # check column_names: must be list of string.
        column_names = param_dict.get("column_names")
        if column_names is not None:
            all_string = all(isinstance(item, str) for item in column_names)
            if not all_string:
                raise TypeError("column_names should be a list of str.")

        validate_dataset_param_value(nreq_param_int, param_dict, int)
        check_sampler_shuffle_shard_options(param_dict)

        return method(self, *args, **kwargs)

    return new_method
def check_textfiledataset(method):
"""A wrapper that wraps a parameter checker to the original Dataset(TextFileDataset)."""

@ -77,6 +77,7 @@ SET(DE_UT_SRCS
celeba_op_test.cc
take_op_test.cc
clue_op_test.cc
csv_op_test.cc
text_file_op_test.cc
filter_op_test.cc
concat_op_test.cc

@ -0,0 +1,122 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <iostream>
#include <memory>
#include <vector>
#include "minddata/dataset/core/client.h"
#include "common/common.h"
#include "common/utils.h"
#include "gtest/gtest.h"
#include "utils/log_adapter.h"
#include "minddata/dataset/engine/datasetops/source/csv_op.h"
#include "minddata/dataset/util/status.h"
namespace common = mindspore::common;
using namespace mindspore::dataset;
using mindspore::MsLogLevel::INFO;
using mindspore::ExceptionType::NoExceptionType;
using mindspore::LogStream;
// Test fixture for CsvOp unit tests; inherits shared helpers such as
// datasets_root_path_ from UT::DatasetOpTesting.
class MindDataTestCSVOp : public UT::DatasetOpTesting {
};
// Builds a one-node execution tree containing a CsvOp over testCSV/1.csv
// (4 int columns, header-less) and verifies that iterating the pipeline
// yields exactly 3 rows.
TEST_F(MindDataTestCSVOp, TestCSVBasic) {
// Start with an empty execution tree
auto tree = std::make_shared<ExecutionTree>();
std::string dataset_path;
dataset_path = datasets_root_path_ + "/testCSV/1.csv";
// One typed default per column: all four columns parsed as INT with default 0.
std::vector<std::shared_ptr<CsvOp::BaseRecord>> column_default_list;
column_default_list.push_back(std::make_shared<CsvOp::Record<int>>(CsvOp::INT, 0));
column_default_list.push_back(std::make_shared<CsvOp::Record<int>>(CsvOp::INT, 0));
column_default_list.push_back(std::make_shared<CsvOp::Record<int>>(CsvOp::INT, 0));
column_default_list.push_back(std::make_shared<CsvOp::Record<int>>(CsvOp::INT, 0));
std::shared_ptr<CsvOp> op;
CsvOp::Builder builder;
builder.SetCsvFilesList({dataset_path})
.SetRowsPerBuffer(16)
.SetNumWorkers(16)
.SetShuffleFiles(false)
.SetOpConnectorSize(2)
.SetFieldDelim(',')
.SetColumDefault(column_default_list)
.SetColumName({"col1", "col2", "col3", "col4"});
Status rc = builder.Build(&op);
ASSERT_TRUE(rc.IsOk());
// Attach the op to the tree and make it the root; order matters:
// Associate -> AssignRoot -> Prepare -> Launch.
rc = tree->AssociateNode(op);
ASSERT_TRUE(rc.IsOk());
rc = tree->AssignRoot(op);
ASSERT_TRUE(rc.IsOk());
MS_LOG(INFO) << "Launching tree and begin iteration.";
rc = tree->Prepare();
ASSERT_TRUE(rc.IsOk());
rc = tree->Launch();
ASSERT_TRUE(rc.IsOk());
// Start the loop of reading tensors from our pipeline
DatasetIterator di(tree);
TensorRow tensor_list;
rc = di.FetchNextTensorRow(&tensor_list);
ASSERT_TRUE(rc.IsOk());
int row_count = 0;
// An empty TensorRow signals end-of-data.
while (!tensor_list.empty()) {
// Display the tensor by calling the printer on it
for (int i = 0; i < tensor_list.size(); i++) {
std::ostringstream ss;
ss << "(" << tensor_list[i] << "): " << *tensor_list[i] << std::endl;
MS_LOG(INFO) << "Tensor print: " << ss.str() << ".";
}
rc = di.FetchNextTensorRow(&tensor_list);
ASSERT_TRUE(rc.IsOk());
row_count++;
}
// 1.csv contains 3 data rows.
ASSERT_EQ(row_count, 3);
}
// Verifies CsvOp::CountAllFileRows over each file individually and over the
// combined file list (counts are expected to add up: 3 + 5 = 8).
TEST_F(MindDataTestCSVOp, TestTotalRows) {
  const std::string first_file = datasets_root_path_ + "/testCSV/1.csv";
  const std::string second_file = datasets_root_path_ + "/testCSV/size.csv";
  int64_t row_count = 0;
  std::vector<std::string> file_list;

  // First file alone.
  file_list.push_back(first_file);
  CsvOp::CountAllFileRows(file_list, false, &row_count);
  ASSERT_EQ(row_count, 3);

  // Second file alone.
  file_list.clear();
  file_list.push_back(second_file);
  CsvOp::CountAllFileRows(file_list, false, &row_count);
  ASSERT_EQ(row_count, 5);

  // Both files together.
  file_list.clear();
  file_list.push_back(first_file);
  file_list.push_back(second_file);
  CsvOp::CountAllFileRows(file_list, false, &row_count);
  ASSERT_EQ(row_count, 8);
  file_list.clear();
}

@ -0,0 +1,3 @@
1,2,3,4
5,6,7,8
9,10,11,12
1 1 2 3 4
2 5 6 7 8
3 9 10 11 12

@ -0,0 +1,8 @@
,"222",3,"4"""
"5",6,,"8"
9,10,"1""1",12
,,"",
,,,
a,b,c,""
a,b,c,d
1 222 3 4"
2 5 6 8
3 9 10 1"1 12
4
5
6 a b c
7 a b c d

@ -0,0 +1 @@
大家,早上好,中午好,下午好,晚上好
1 大家 早上好 中午好 下午好 晚上好

@ -0,0 +1,2 @@
"a,b","c""d","e
f"," g "
1 a,b c"d e f g

@ -0,0 +1,3 @@
1,2,3,4
5,6,7,8
a,"c",d,"e
Can't render this file because it contains an unexpected character in line 3 and column 12.

@ -0,0 +1,2 @@
col1,col2,col3,col4
a,b,c,d
1 col1 col2 col3 col4
2 a b c d

@ -0,0 +1 @@
3,0.3,4,55.5
1 3 0.3 4 55.5

@ -0,0 +1 @@
"a","b","c","d"
1 a b c d

@ -0,0 +1,10 @@
1,2,3,4
"a","b","c
","d
e"
5,6,7,8
9,10,11,12
a,"b
",c,"d
e"
1 1 2 3 4
2 a b c d e
3 5 6 7 8
4 9 10 11 12
5 a b c d e

@ -0,0 +1,238 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import mindspore.dataset as ds
import numpy as np
import pytest
DATA_FILE = '../data/dataset/testCSV/1.csv'
def test_csv_dataset_basic():
    """Read a 3-row CSV through repeat(2) then skip(2) and count 4 rows."""
    train_file = '../data/dataset/testCSV/1.csv'
    data = ds.CSVDataset(
        train_file,
        column_defaults=["0", 0, 0.0, "0"],
        column_names=['1', '2', '3', '4'],
        shuffle=False)
    data = data.repeat(2).skip(2)
    rows = [row for row in data.create_dict_iterator()]
    assert len(rows) == 4
def test_csv_dataset_one_file():
    """A single CSV file yields its full row count."""
    data = ds.CSVDataset(
        DATA_FILE,
        column_defaults=["1", "2", "3", "4"],
        column_names=['col1', 'col2', 'col3', 'col4'],
        shuffle=False)
    rows = list(data.create_dict_iterator())
    assert len(rows) == 3


def test_csv_dataset_all_file():
    """Two CSV files read together yield the combined row count."""
    append_file = '../data/dataset/testCSV/2.csv'
    data = ds.CSVDataset(
        [DATA_FILE, append_file],
        column_defaults=["1", "2", "3", "4"],
        column_names=['col1', 'col2', 'col3', 'col4'],
        shuffle=False)
    rows = list(data.create_dict_iterator())
    assert len(rows) == 10
def test_csv_dataset_num_samples():
    """num_samples caps the number of rows produced."""
    data = ds.CSVDataset(
        DATA_FILE,
        column_defaults=["1", "2", "3", "4"],
        column_names=['col1', 'col2', 'col3', 'col4'],
        shuffle=False, num_samples=2)
    assert sum(1 for _ in data.create_dict_iterator()) == 2


def test_csv_dataset_distribution():
    """With num_shards=2, shard 0 sees its share of the rows."""
    test_file = '../data/dataset/testCSV/1.csv'
    data = ds.CSVDataset(
        test_file,
        column_defaults=["1", "2", "3", "4"],
        column_names=['col1', 'col2', 'col3', 'col4'],
        shuffle=False, num_shards=2, shard_id=0)
    assert sum(1 for _ in data.create_dict_iterator()) == 2
def _decode_columns(row, names):
    """Decode the named columns of one pipeline row into utf8 strings."""
    return [row[name].item().decode("utf8") for name in names]


def test_csv_dataset_quoted():
    """Quoted fields are unwrapped to their inner text."""
    test_file = '../data/dataset/testCSV/quoted.csv'
    data = ds.CSVDataset(
        test_file,
        column_defaults=["", "", "", ""],
        column_names=['col1', 'col2', 'col3', 'col4'],
        shuffle=False)
    buffer = []
    for d in data.create_dict_iterator():
        buffer.extend(_decode_columns(d, ['col1', 'col2', 'col3', 'col4']))
    assert buffer == ['a', 'b', 'c', 'd']


def test_csv_dataset_separated():
    """A custom field_delim ('|') splits the record correctly."""
    test_file = '../data/dataset/testCSV/separated.csv'
    data = ds.CSVDataset(
        test_file,
        field_delim='|',
        column_defaults=["", "", "", ""],
        column_names=['col1', 'col2', 'col3', 'col4'],
        shuffle=False)
    buffer = []
    for d in data.create_dict_iterator():
        buffer.extend(_decode_columns(d, ['col1', 'col2', 'col3', 'col4']))
    assert buffer == ['a', 'b', 'c', 'd']


def test_csv_dataset_embedded():
    """Embedded delimiters, quotes and newlines inside quoted fields survive parsing."""
    test_file = '../data/dataset/testCSV/embedded.csv'
    data = ds.CSVDataset(
        test_file,
        column_defaults=["", "", "", ""],
        column_names=['col1', 'col2', 'col3', 'col4'],
        shuffle=False)
    buffer = []
    for d in data.create_dict_iterator():
        buffer.extend(_decode_columns(d, ['col1', 'col2', 'col3', 'col4']))
    assert buffer == ['a,b', 'c"d', 'e\nf', ' g ']


def test_csv_dataset_chinese():
    """Multi-byte (Chinese) utf8 content round-trips through the pipeline."""
    test_file = '../data/dataset/testCSV/chinese.csv'
    data = ds.CSVDataset(
        test_file,
        column_defaults=["", "", "", "", ""],
        column_names=['col1', 'col2', 'col3', 'col4', 'col5'],
        shuffle=False)
    buffer = []
    for d in data.create_dict_iterator():
        buffer.extend(_decode_columns(d, ['col1', 'col2', 'col3', 'col4', 'col5']))
    assert buffer == ['大家', '早上好', '中午好', '下午好', '晚上好']
def test_csv_dataset_header():
    """Without column_names, column names are inferred from the header row."""
    test_file = '../data/dataset/testCSV/header.csv'
    data = ds.CSVDataset(
        test_file,
        column_defaults=["", "", "", ""],
        shuffle=False)
    buffer = []
    for d in data.create_dict_iterator():
        buffer.extend([d[name].item().decode("utf8")
                       for name in ('col1', 'col2', 'col3', 'col4')])
    assert buffer == ['a', 'b', 'c', 'd']


def test_csv_dataset_number():
    """Numeric column_defaults parse fields as float/int values."""
    test_file = '../data/dataset/testCSV/number.csv'
    data = ds.CSVDataset(
        test_file,
        column_defaults=[0.0, 0.0, 0, 0.0],
        column_names=['col1', 'col2', 'col3', 'col4'],
        shuffle=False)
    buffer = []
    for d in data.create_dict_iterator():
        buffer.extend([d[name].item()
                       for name in ('col1', 'col2', 'col3', 'col4')])
    assert np.allclose(buffer, [3.0, 0.3, 4, 55.5])


def test_csv_dataset_size():
    """get_dataset_size reports the file's row count without iterating."""
    test_file = '../data/dataset/testCSV/size.csv'
    data = ds.CSVDataset(
        test_file,
        column_defaults=[0.0, 0.0, 0, 0.0],
        column_names=['col1', 'col2', 'col3', 'col4'],
        shuffle=False)
    assert data.get_dataset_size() == 5
def test_csv_dataset_exception():
    """A malformed CSV file surfaces a parse error during iteration."""
    test_file = '../data/dataset/testCSV/exception.csv'
    data = ds.CSVDataset(
        test_file,
        column_defaults=["", "", "", ""],
        column_names=['col1', 'col2', 'col3', 'col4'],
        shuffle=False)
    with pytest.raises(Exception) as err:
        for _ in data.create_dict_iterator():
            pass
    assert "Failed to parse CSV file" in str(err.value)


def test_csv_dataset_type_error():
    """A non-numeric field with an int default raises a conversion error."""
    test_file = '../data/dataset/testCSV/exception.csv'
    data = ds.CSVDataset(
        test_file,
        column_defaults=["", 0, "", ""],
        column_names=['col1', 'col2', 'col3', 'col4'],
        shuffle=False)
    with pytest.raises(Exception) as err:
        for _ in data.create_dict_iterator():
            pass
    assert "invalid argument of stoi" in str(err.value)
if __name__ == "__main__":
    # Run every CSV dataset test case in declaration order.
    for _case in (test_csv_dataset_basic,
                  test_csv_dataset_one_file,
                  test_csv_dataset_all_file,
                  test_csv_dataset_num_samples,
                  test_csv_dataset_distribution,
                  test_csv_dataset_quoted,
                  test_csv_dataset_separated,
                  test_csv_dataset_embedded,
                  test_csv_dataset_chinese,
                  test_csv_dataset_header,
                  test_csv_dataset_number,
                  test_csv_dataset_size,
                  test_csv_dataset_exception,
                  test_csv_dataset_type_error):
        _case()
Loading…
Cancel
Save