!4696 C++ API Support for CSV Dataset

Merge pull request !4696 from jiangzhiwen/jzw/c_api_csv
pull/4696/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit d541e261a0

@ -25,6 +25,7 @@
#include "minddata/dataset/engine/datasetops/source/cifar_op.h"
#include "minddata/dataset/engine/datasetops/source/clue_op.h"
#include "minddata/dataset/engine/datasetops/source/coco_op.h"
#include "minddata/dataset/engine/datasetops/source/csv_op.h"
#include "minddata/dataset/engine/datasetops/source/image_folder_op.h"
#include "minddata/dataset/engine/datasetops/source/manifest_op.h"
#include "minddata/dataset/engine/datasetops/source/mnist_op.h"
@ -161,6 +162,18 @@ std::shared_ptr<CocoDataset> Coco(const std::string &dataset_dir, const std::str
return ds->ValidateParams() ? ds : nullptr;
}
// Function to create a CSVDataset.
std::shared_ptr<CSVDataset> CSV(const std::vector<std::string> &dataset_files, char field_delim,
const std::vector<std::shared_ptr<CsvBase>> &column_defaults,
const std::vector<std::string> &column_names, int64_t num_samples, ShuffleMode shuffle,
int32_t num_shards, int32_t shard_id) {
auto ds = std::make_shared<CSVDataset>(dataset_files, field_delim, column_defaults, column_names, num_samples,
shuffle, num_shards, shard_id);
// Call derived class validation method.
return ds->ValidateParams() ? ds : nullptr;
}
// Function to create a ImageFolderDataset.
std::shared_ptr<ImageFolderDataset> ImageFolder(const std::string &dataset_dir, bool decode,
const std::shared_ptr<SamplerObj> &sampler,
@ -1021,6 +1034,84 @@ std::vector<std::shared_ptr<DatasetOp>> CocoDataset::Build() {
return node_ops;
}
// Constructor for CSVDataset
CSVDataset::CSVDataset(const std::vector<std::string> &csv_files, char field_delim,
const std::vector<std::shared_ptr<CsvBase>> &column_defaults,
const std::vector<std::string> &column_names, int64_t num_samples, ShuffleMode shuffle,
int32_t num_shards, int32_t shard_id)
: dataset_files_(csv_files),
field_delim_(field_delim),
column_defaults_(column_defaults),
column_names_(column_names),
num_samples_(num_samples),
shuffle_(shuffle),
num_shards_(num_shards),
shard_id_(shard_id) {}
bool CSVDataset::ValidateParams() {
if (!ValidateDatasetFilesParam("CSVDataset", dataset_files_)) {
return false;
}
if (field_delim_ == '"' || field_delim_ == '\r' || field_delim_ == '\n') {
MS_LOG(ERROR) << "CSVDataset: The field delimiter should not be \", \\r, \\n";
return false;
}
if (num_samples_ < -1) {
MS_LOG(ERROR) << "CSVDataset: Invalid number of samples: " << num_samples_;
return false;
}
if (!ValidateDatasetShardParams("CSVDataset", num_shards_, shard_id_)) {
return false;
}
return true;
}
// Function to build CSVDataset
std::vector<std::shared_ptr<DatasetOp>> CSVDataset::Build() {
// A vector containing shared pointer to the Dataset Ops that this object will create
std::vector<std::shared_ptr<DatasetOp>> node_ops;
bool shuffle_files = (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles);
std::vector<std::shared_ptr<CsvOp::BaseRecord>> column_default_list;
for (auto v : column_defaults_) {
if (v->type == CsvType::INT) {
column_default_list.push_back(
std::make_shared<CsvOp::Record<int>>(CsvOp::INT, std::dynamic_pointer_cast<CsvRecord<int>>(v)->value));
} else if (v->type == CsvType::FLOAT) {
column_default_list.push_back(
std::make_shared<CsvOp::Record<float>>(CsvOp::FLOAT, std::dynamic_pointer_cast<CsvRecord<float>>(v)->value));
} else if (v->type == CsvType::STRING) {
column_default_list.push_back(std::make_shared<CsvOp::Record<std::string>>(
CsvOp::STRING, std::dynamic_pointer_cast<CsvRecord<std::string>>(v)->value));
}
}
std::shared_ptr<CsvOp> csv_op = std::make_shared<CsvOp>(
dataset_files_, field_delim_, column_default_list, column_names_, num_workers_, rows_per_buffer_, num_samples_,
worker_connector_size_, connector_que_size_, shuffle_files, num_shards_, shard_id_);
RETURN_EMPTY_IF_ERROR(csv_op->Init());
if (shuffle_ == ShuffleMode::kGlobal) {
// Inject ShuffleOp
std::shared_ptr<DatasetOp> shuffle_op = nullptr;
int64_t num_rows = 0;
// First, get the number of rows in the dataset
RETURN_EMPTY_IF_ERROR(CsvOp::CountAllFileRows(dataset_files_, column_names_.empty(), &num_rows));
// Add the shuffle op after this op
RETURN_EMPTY_IF_ERROR(AddShuffleOp(dataset_files_.size(), num_shards_, num_rows, 0, connector_que_size_,
rows_per_buffer_, &shuffle_op));
node_ops.push_back(shuffle_op);
}
node_ops.push_back(csv_op);
return node_ops;
}
ImageFolderDataset::ImageFolderDataset(std::string dataset_dir, bool decode, std::shared_ptr<SamplerObj> sampler,
bool recursive, std::set<std::string> extensions,
std::map<std::string, int32_t> class_indexing)

@ -51,6 +51,8 @@ class Cifar10Dataset;
class Cifar100Dataset;
class CLUEDataset;
class CocoDataset;
class CSVDataset;
class CsvBase;
class ImageFolderDataset;
class ManifestDataset;
class MnistDataset;
@ -114,13 +116,13 @@ std::shared_ptr<Cifar100Dataset> Cifar100(const std::string &dataset_dir,
/// \param[in] usage Be used to "train", "test" or "eval" data (default="train").
/// \param[in] num_samples The number of samples to be included in the dataset.
/// (Default = 0 means all samples.)
/// \param[in] shuffle The mode for shuffling data every epoch. (Default=ShuffleMode.kGlobal)
/// \param[in] shuffle The mode for shuffling data every epoch. (Default=ShuffleMode::kGlobal)
/// Can be any of:
/// ShuffleMode.kFalse - No shuffling is performed.
/// ShuffleMode.kFiles - Shuffle files only.
/// ShuffleMode.kGlobal - Shuffle both the files and samples.
/// ShuffleMode::kFalse - No shuffling is performed.
/// ShuffleMode::kFiles - Shuffle files only.
/// ShuffleMode::kGlobal - Shuffle both the files and samples.
/// \param[in] num_shards Number of shards that the dataset should be divided into. (Default = 1)
/// \param[in] shard_id The shard ID within num_shards. This argument should be
/// \param[in] shard_id The shard ID within num_shards. This argument should be
/// specified only when num_shards is also specified. (Default = 0)
/// \return Shared pointer to the current CLUEDataset
std::shared_ptr<CLUEDataset> CLUE(const std::vector<std::string> &dataset_files, const std::string &task = "AFQMC",
@ -148,6 +150,32 @@ std::shared_ptr<CocoDataset> Coco(const std::string &dataset_dir, const std::str
const std::string &task = "Detection", const bool &decode = false,
const std::shared_ptr<SamplerObj> &sampler = nullptr);
/// \brief Function to create a CSVDataset
/// \notes The generated dataset has a variable number of columns
/// \param[in] dataset_files List of files to be read to search for a pattern of files. The list
/// will be sorted in a lexicographical order.
/// \param[in] field_delim A char that indicates the delimiter to separate fields (default=',').
/// \param[in] column_defaults List of default values for the CSV field (default={}). Each item in the list is
/// either a valid type (float, int, or string). If this is not provided, treats all columns as string type.
/// \param[in] column_names List of column names of the dataset (default={}). If this is not provided, infers the
/// column_names from the first row of CSV file.
/// \param[in] num_samples The number of samples to be included in the dataset.
/// (Default = -1 means all samples.)
/// \param[in] shuffle The mode for shuffling data every epoch. (Default=ShuffleMode::kGlobal)
/// Can be any of:
/// ShuffleMode::kFalse - No shuffling is performed.
/// ShuffleMode::kFiles - Shuffle files only.
/// ShuffleMode::kGlobal - Shuffle both the files and samples.
/// \param[in] num_shards Number of shards that the dataset should be divided into. (Default = 1)
/// \param[in] shard_id The shard ID within num_shards. This argument should be
/// specified only when num_shards is also specified. (Default = 0)
/// \return Shared pointer to the current Dataset
std::shared_ptr<CSVDataset> CSV(const std::vector<std::string> &dataset_files, char field_delim = ',',
const std::vector<std::shared_ptr<CsvBase>> &column_defaults = {},
const std::vector<std::string> &column_names = {}, int64_t num_samples = -1,
ShuffleMode shuffle = ShuffleMode::kGlobal, int32_t num_shards = 1,
int32_t shard_id = 0);
/// \brief Function to create an ImageFolderDataset
/// \notes A source dataset that reads images from a tree of directories
/// All images within one folder have the same label
@ -217,13 +245,13 @@ std::shared_ptr<RandomDataset> RandomData(const int32_t &total_rows = 0, T schem
/// will be sorted in a lexicographical order.
/// \param[in] num_samples The number of samples to be included in the dataset.
/// (Default = 0 means all samples.)
/// \param[in] shuffle The mode for shuffling data every epoch. (Default=ShuffleMode.kGlobal)
/// \param[in] shuffle The mode for shuffling data every epoch. (Default=ShuffleMode::kGlobal)
/// Can be any of:
/// ShuffleMode.kFalse - No shuffling is performed.
/// ShuffleMode.kFiles - Shuffle files only.
/// ShuffleMode.kGlobal - Shuffle both the files and samples.
/// ShuffleMode::kFalse - No shuffling is performed.
/// ShuffleMode::kFiles - Shuffle files only.
/// ShuffleMode::kGlobal - Shuffle both the files and samples.
/// \param[in] num_shards Number of shards that the dataset should be divided into. (Default = 1)
/// \param[in] shard_id The shard ID within num_shards. This argument should be
/// \param[in] shard_id The shard ID within num_shards. This argument should be
/// specified only when num_shards is also specified. (Default = 0)
/// \return Shared pointer to the current TextFileDataset
std::shared_ptr<TextFileDataset> TextFile(const std::vector<std::string> &dataset_files, int32_t num_samples = 0,
@ -572,6 +600,57 @@ class CocoDataset : public Dataset {
std::shared_ptr<SamplerObj> sampler_;
};
/// \brief Record type for CSV
enum CsvType : uint8_t { INT = 0, FLOAT, STRING };
/// \brief Base class of CSV Record
struct CsvBase {
public:
CsvBase() = default;
explicit CsvBase(CsvType t) : type(t) {}
virtual ~CsvBase() {}
CsvType type;
};
/// \brief CSV Record that can represent integer, float and string.
template <typename T>
class CsvRecord : public CsvBase {
public:
CsvRecord() = default;
CsvRecord(CsvType t, T v) : CsvBase(t), value(v) {}
~CsvRecord() {}
T value;
};
class CSVDataset : public Dataset {
public:
/// \brief Constructor
CSVDataset(const std::vector<std::string> &dataset_files, char field_delim,
const std::vector<std::shared_ptr<CsvBase>> &column_defaults, const std::vector<std::string> &column_names,
int64_t num_samples, ShuffleMode shuffle, int32_t num_shards, int32_t shard_id);
/// \brief Destructor
~CSVDataset() = default;
/// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return shared pointer to the list of newly created DatasetOps
std::vector<std::shared_ptr<DatasetOp>> Build() override;
/// \brief Parameters validation
/// \return bool true if all the params are valid
bool ValidateParams() override;
private:
std::vector<std::string> dataset_files_;
char field_delim_;
std::vector<std::shared_ptr<CsvBase>> column_defaults_;
std::vector<std::string> column_names_;
int64_t num_samples_;
ShuffleMode shuffle_;
int32_t num_shards_;
int32_t shard_id_;
};
/// \class ImageFolderDataset
/// \brief A Dataset derived class to represent ImageFolder dataset
class ImageFolderDataset : public Dataset {

@ -103,7 +103,9 @@ SET(DE_UT_SRCS
c_api_dataset_cifar_test.cc
c_api_dataset_clue_test.cc
c_api_dataset_coco_test.cc
c_api_dataset_csv_test.cc
c_api_dataset_filetext_test.cc
c_api_dataset_manifest_test.cc
c_api_dataset_randomdata_test.cc
c_api_dataset_voc_test.cc
c_api_datasets_test.cc

File diff suppressed because it is too large Load Diff

@ -0,0 +1,3 @@
13,14,15,16
17,18,19,20
21,22,23,24
1 13 14 15 16
2 17 18 19 20
3 21 22 23 24

@ -0,0 +1,2 @@
,2,3.0,
a,4,5,b
1 2 3.0
2 a 4 5 b
Loading…
Cancel
Save