!4872 Add C++ API support for TFRecordDataset

Merge pull request !4872 from TinaMengtingZhang/cpp-api-tfrecord-dataset
pull/4872/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit 2f3684d4fa

@ -32,6 +32,7 @@
#include "minddata/dataset/engine/datasetops/source/mnist_op.h" #include "minddata/dataset/engine/datasetops/source/mnist_op.h"
#include "minddata/dataset/engine/datasetops/source/random_data_op.h" #include "minddata/dataset/engine/datasetops/source/random_data_op.h"
#include "minddata/dataset/engine/datasetops/source/text_file_op.h" #include "minddata/dataset/engine/datasetops/source/text_file_op.h"
#include "minddata/dataset/engine/datasetops/source/tf_reader_op.h"
#ifndef ENABLE_ANDROID #ifndef ENABLE_ANDROID
#include "minddata/dataset/engine/datasetops/source/voc_op.h" #include "minddata/dataset/engine/datasetops/source/voc_op.h"
#endif #endif
@ -1508,6 +1509,56 @@ std::vector<std::shared_ptr<DatasetOp>> TextFileDataset::Build() {
return node_ops; return node_ops;
} }
// Validator for TFRecordDataset
// NOTE(review): intentionally always succeeds — all parameter checks (non-empty
// file list, file/schema existence, num_samples >= 0, shard bounds) are done up
// front in the TFRecord() factory function, which returns nullptr on invalid
// input before this node is ever constructed.
bool TFRecordDataset::ValidateParams() { return true; }
// Function to build TFRecordDataset
std::vector<std::shared_ptr<DatasetOp>> TFRecordDataset::Build() {
// A vector containing shared pointer to the Dataset Ops that this object will create
std::vector<std::shared_ptr<DatasetOp>> node_ops;
// Sort the datasets file in a lexicographical order
std::vector<std::string> sorted_dir_files = dataset_files_;
std::sort(sorted_dir_files.begin(), sorted_dir_files.end());
// Create Schema Object
std::unique_ptr<DataSchema> data_schema = std::make_unique<DataSchema>();
if (!schema_path_.empty()) {
RETURN_EMPTY_IF_ERROR(data_schema->LoadSchemaFile(schema_path_, columns_list_));
} else if (schema_obj_ != nullptr) {
std::string schema_json_string = schema_obj_->to_json();
RETURN_EMPTY_IF_ERROR(data_schema->LoadSchemaString(schema_json_string, columns_list_));
}
bool shuffle_files = (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles);
// Create and initalize TFReaderOp
std::shared_ptr<TFReaderOp> tf_reader_op = std::make_shared<TFReaderOp>(
num_workers_, worker_connector_size_, rows_per_buffer_, num_samples_, sorted_dir_files, std::move(data_schema),
connector_que_size_, columns_list_, shuffle_files, num_shards_, shard_id_, shard_equal_rows_, nullptr);
RETURN_EMPTY_IF_ERROR(tf_reader_op->Init());
if (shuffle_ == ShuffleMode::kGlobal) {
// Inject ShuffleOp
std::shared_ptr<DatasetOp> shuffle_op = nullptr;
int64_t num_rows = 0;
// First, get the number of rows in the dataset
RETURN_EMPTY_IF_ERROR(TFReaderOp::CountTotalRows(&num_rows, sorted_dir_files));
// Add the shuffle op after this op
RETURN_EMPTY_IF_ERROR(AddShuffleOp(sorted_dir_files.size(), num_shards_, num_rows, 0, connector_que_size_,
rows_per_buffer_, &shuffle_op));
node_ops.push_back(shuffle_op);
}
// Add TFReaderOp
node_ops.push_back(tf_reader_op);
return node_ops;
}
#ifndef ENABLE_ANDROID #ifndef ENABLE_ANDROID
// Constructor for VOCDataset // Constructor for VOCDataset
VOCDataset::VOCDataset(const std::string &dataset_dir, const std::string &task, const std::string &mode, VOCDataset::VOCDataset(const std::string &dataset_dir, const std::string &task, const std::string &mode,

@ -32,6 +32,7 @@
#include "minddata/dataset/include/type_id.h" #include "minddata/dataset/include/type_id.h"
#include "minddata/dataset/kernels/c_func_op.h" #include "minddata/dataset/kernels/c_func_op.h"
#include "minddata/dataset/kernels/tensor_op.h" #include "minddata/dataset/kernels/tensor_op.h"
#include "minddata/dataset/util/path.h"
#ifndef ENABLE_ANDROID #ifndef ENABLE_ANDROID
#include "minddata/dataset/text/vocab.h" #include "minddata/dataset/text/vocab.h"
#endif #endif
@ -69,6 +70,7 @@ class ManifestDataset;
class MnistDataset; class MnistDataset;
class RandomDataset; class RandomDataset;
class TextFileDataset; class TextFileDataset;
class TFRecordDataset;
#ifndef ENABLE_ANDROID #ifndef ENABLE_ANDROID
class VOCDataset; class VOCDataset;
#endif #endif
@ -320,6 +322,80 @@ std::shared_ptr<TextFileDataset> TextFile(const std::vector<std::string> &datase
ShuffleMode shuffle = ShuffleMode::kGlobal, int32_t num_shards = 1, ShuffleMode shuffle = ShuffleMode::kGlobal, int32_t num_shards = 1,
int32_t shard_id = 0); int32_t shard_id = 0);
/// \brief Function to create a TFRecordDataset
/// \param[in] dataset_files List of files to be read to search for a pattern of files. The list
/// will be sorted in a lexicographical order.
/// \param[in] schema SchemaObj or string to schema path. (Default = nullptr, which means that the
/// meta data from the TFData file is considered the schema.)
/// \param[in] columns_list List of columns to be read. (Default = {}, read all columns)
/// \param[in] num_samples The number of samples to be included in the dataset.
/// (Default = 0 means all samples.)
/// If num_samples is 0 and numRows(parsed from schema) does not exist, read the full dataset;
/// If num_samples is 0 and numRows(parsed from schema) is greater than 0, read numRows rows;
/// If both num_samples and numRows(parsed from schema) are greater than 0, read num_samples rows.
/// \param[in] shuffle The mode for shuffling data every epoch. (Default = ShuffleMode::kGlobal)
/// Can be any of:
/// ShuffleMode::kFalse - No shuffling is performed.
/// ShuffleMode::kFiles - Shuffle files only.
/// ShuffleMode::kGlobal - Shuffle both the files and samples.
/// \param[in] num_shards Number of shards that the dataset should be divided into. (Default = 1)
/// \param[in] shard_id The shard ID within num_shards. This argument should be specified only
/// when num_shards is also specified. (Default = 0)
/// \param[in] shard_equal_rows Get equal rows for all shards. (Default = False, number of rows of
/// each shard may be not equal)
/// \return Shared pointer to the current TFRecordDataset
template <typename T = std::shared_ptr<SchemaObj>>
std::shared_ptr<TFRecordDataset> TFRecord(const std::vector<std::string> &dataset_files, const T &schema = nullptr,
const std::vector<std::string> &columns_list = {}, int64_t num_samples = 0,
ShuffleMode shuffle = ShuffleMode::kGlobal, int32_t num_shards = 1,
int32_t shard_id = 0, bool shard_equal_rows = false) {
if (dataset_files.empty()) {
MS_LOG(ERROR) << "TFRecordDataset: dataset_files is not specified.";
return nullptr;
}
for (auto f : dataset_files) {
Path dataset_file(f);
if (!dataset_file.Exists()) {
MS_LOG(ERROR) << "TFRecordDataset: dataset file: [" << f << "] is invalid or does not exist.";
return nullptr;
}
}
if (num_samples < 0) {
MS_LOG(ERROR) << "TFRecordDataset: Invalid number of samples: " << num_samples;
return nullptr;
}
if (num_shards <= 0) {
MS_LOG(ERROR) << "TFRecordDataset: Invalid num_shards: " << num_shards;
return nullptr;
}
if (shard_id < 0 || shard_id >= num_shards) {
MS_LOG(ERROR) << "TFRecordDataset: Invalid input, shard_id: " << shard_id << ", num_shards: " << num_shards;
return nullptr;
}
std::shared_ptr<TFRecordDataset> ds = nullptr;
if constexpr (std::is_same<T, std::nullptr_t>::value || std::is_same<T, std::shared_ptr<SchemaObj>>::value) {
std::shared_ptr<SchemaObj> schema_obj = schema;
ds = std::make_shared<TFRecordDataset>(dataset_files, schema_obj, columns_list, num_samples, shuffle, num_shards,
shard_id, shard_equal_rows);
} else {
std::string schema_path = schema;
if (!schema_path.empty()) {
Path schema_file(schema_path);
if (!schema_file.Exists()) {
MS_LOG(ERROR) << "TFRecordDataset: schema path [" << schema_path << "] is invalid or does not exist.";
return nullptr;
}
}
ds = std::make_shared<TFRecordDataset>(dataset_files, schema_path, columns_list, num_samples, shuffle, num_shards,
shard_id, shard_equal_rows);
}
return ds;
}
#ifndef ENABLE_ANDROID #ifndef ENABLE_ANDROID
/// \brief Function to create a VOCDataset /// \brief Function to create a VOCDataset
/// \notes The generated dataset has multi-columns : /// \notes The generated dataset has multi-columns :
@ -952,6 +1028,61 @@ class TextFileDataset : public Dataset {
ShuffleMode shuffle_; ShuffleMode shuffle_;
}; };
/// \class TFRecordDataset
/// \brief A Dataset derived class to represent TFRecord dataset
class TFRecordDataset : public Dataset {
 public:
  /// \brief Constructor
  /// \note Parameter 'schema' is the path to the schema file. It is taken by
  ///    value and moved into the member to avoid an extra string copy.
  TFRecordDataset(const std::vector<std::string> &dataset_files, std::string schema,
                  const std::vector<std::string> &columns_list, int64_t num_samples, ShuffleMode shuffle,
                  int32_t num_shards, int32_t shard_id, bool shard_equal_rows)
      : dataset_files_(dataset_files),
        schema_path_(std::move(schema)),
        columns_list_(columns_list),
        num_samples_(num_samples),
        shuffle_(shuffle),
        num_shards_(num_shards),
        shard_id_(shard_id),
        shard_equal_rows_(shard_equal_rows) {}

  /// \brief Constructor
  /// \note Parameter 'schema' is a shared pointer to a Schema object. It is
  ///    taken by value and moved to avoid an extra atomic refcount bump.
  TFRecordDataset(const std::vector<std::string> &dataset_files, std::shared_ptr<SchemaObj> schema,
                  const std::vector<std::string> &columns_list, int64_t num_samples, ShuffleMode shuffle,
                  int32_t num_shards, int32_t shard_id, bool shard_equal_rows)
      : dataset_files_(dataset_files),
        schema_obj_(std::move(schema)),
        columns_list_(columns_list),
        num_samples_(num_samples),
        shuffle_(shuffle),
        num_shards_(num_shards),
        shard_id_(shard_id),
        shard_equal_rows_(shard_equal_rows) {}

  /// \brief Destructor
  ~TFRecordDataset() = default;

  /// \brief a base class override function to create the required runtime dataset op objects for this class
  /// \return The list of shared pointers to the newly created DatasetOps
  std::vector<std::shared_ptr<DatasetOp>> Build() override;

  /// \brief Parameters validation
  /// \return bool true if all the params are valid
  bool ValidateParams() override;

 private:
  std::vector<std::string> dataset_files_;
  std::string schema_path_;               // schema_path_ path to schema file. It is set when type of schema parameter is string
  std::shared_ptr<SchemaObj> schema_obj_; // schema_obj_ schema object.
  std::vector<std::string> columns_list_;
  int64_t num_samples_;
  ShuffleMode shuffle_;
  int32_t num_shards_;
  int32_t shard_id_;
  bool shard_equal_rows_;
};
#ifndef ENABLE_ANDROID #ifndef ENABLE_ANDROID
class VOCDataset : public Dataset { class VOCDataset : public Dataset {
public: public:

@ -3548,7 +3548,7 @@ class TFRecordDataset(SourceDataset):
If the schema is not provided, the meta data from the TFData file is considered the schema. If the schema is not provided, the meta data from the TFData file is considered the schema.
columns_list (list[str], optional): List of columns to be read (default=None, read all columns) columns_list (list[str], optional): List of columns to be read (default=None, read all columns)
num_samples (int, optional): number of samples(rows) to read (default=None). num_samples (int, optional): number of samples(rows) to read (default=None).
If num_samples is None and numRows(parsed from schema) is not exist, read the full dataset; If num_samples is None and numRows(parsed from schema) does not exist, read the full dataset;
If num_samples is None and numRows(parsed from schema) is greater than 0, read numRows rows; If num_samples is None and numRows(parsed from schema) is greater than 0, read numRows rows;
If both num_samples and numRows(parsed from schema) are greater than 0, read num_samples rows. If both num_samples and numRows(parsed from schema) are greater than 0, read num_samples rows.
num_parallel_workers (int, optional): number of workers to read the data num_parallel_workers (int, optional): number of workers to read the data
@ -3567,8 +3567,8 @@ class TFRecordDataset(SourceDataset):
into (default=None). into (default=None).
shard_id (int, optional): The shard ID within num_shards (default=None). This shard_id (int, optional): The shard ID within num_shards (default=None). This
argument should be specified only when num_shards is also specified. argument should be specified only when num_shards is also specified.
shard_equal_rows (bool): Get equal rows for all shards(default=False). If shard_equal_rows is false, number shard_equal_rows (bool, optional): Get equal rows for all shards(default=False). If shard_equal_rows
of rows of each shard may be not equal. is false, number of rows of each shard may be not equal.
cache (DatasetCache, optional): Tensor cache to use. (default=None which means no cache is used). cache (DatasetCache, optional): Tensor cache to use. (default=None which means no cache is used).
The cache feature is under development and is not recommended. The cache feature is under development and is not recommended.
Examples: Examples:

@ -107,9 +107,10 @@ SET(DE_UT_SRCS
c_api_dataset_clue_test.cc c_api_dataset_clue_test.cc
c_api_dataset_coco_test.cc c_api_dataset_coco_test.cc
c_api_dataset_csv_test.cc c_api_dataset_csv_test.cc
c_api_dataset_textfile_test.cc
c_api_dataset_manifest_test.cc c_api_dataset_manifest_test.cc
c_api_dataset_randomdata_test.cc c_api_dataset_randomdata_test.cc
c_api_dataset_textfile_test.cc
c_api_dataset_tfrecord_test.cc
c_api_dataset_voc_test.cc c_api_dataset_voc_test.cc
c_api_datasets_test.cc c_api_datasets_test.cc
c_api_dataset_iterator_test.cc c_api_dataset_iterator_test.cc

File diff suppressed because it is too large Load Diff
Loading…
Cancel
Save