add tfrecord dataset to cpp api

fix to support schema=nullptr
pull/4872/head
tinazhang 5 years ago
parent 64ced295c7
commit e430b4056a

@ -32,6 +32,7 @@
#include "minddata/dataset/engine/datasetops/source/mnist_op.h"
#include "minddata/dataset/engine/datasetops/source/random_data_op.h"
#include "minddata/dataset/engine/datasetops/source/text_file_op.h"
#include "minddata/dataset/engine/datasetops/source/tf_reader_op.h"
#ifndef ENABLE_ANDROID
#include "minddata/dataset/engine/datasetops/source/voc_op.h"
#endif
@ -1503,6 +1504,56 @@ std::vector<std::shared_ptr<DatasetOp>> TextFileDataset::Build() {
return node_ops;
}
// Validator for TFRecordDataset
bool TFRecordDataset::ValidateParams() { return true; }
// Function to build TFRecordDataset
std::vector<std::shared_ptr<DatasetOp>> TFRecordDataset::Build() {
// A vector containing shared pointer to the Dataset Ops that this object will create
std::vector<std::shared_ptr<DatasetOp>> node_ops;
// Sort the datasets file in a lexicographical order
std::vector<std::string> sorted_dir_files = dataset_files_;
std::sort(sorted_dir_files.begin(), sorted_dir_files.end());
// Create Schema Object
std::unique_ptr<DataSchema> data_schema = std::make_unique<DataSchema>();
if (!schema_path_.empty()) {
RETURN_EMPTY_IF_ERROR(data_schema->LoadSchemaFile(schema_path_, columns_list_));
} else if (schema_obj_ != nullptr) {
std::string schema_json_string = schema_obj_->to_json();
RETURN_EMPTY_IF_ERROR(data_schema->LoadSchemaString(schema_json_string, columns_list_));
}
bool shuffle_files = (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles);
// Create and initalize TFReaderOp
std::shared_ptr<TFReaderOp> tf_reader_op = std::make_shared<TFReaderOp>(
num_workers_, worker_connector_size_, rows_per_buffer_, num_samples_, sorted_dir_files, std::move(data_schema),
connector_que_size_, columns_list_, shuffle_files, num_shards_, shard_id_, shard_equal_rows_, nullptr);
RETURN_EMPTY_IF_ERROR(tf_reader_op->Init());
if (shuffle_ == ShuffleMode::kGlobal) {
// Inject ShuffleOp
std::shared_ptr<DatasetOp> shuffle_op = nullptr;
int64_t num_rows = 0;
// First, get the number of rows in the dataset
RETURN_EMPTY_IF_ERROR(TFReaderOp::CountTotalRows(&num_rows, sorted_dir_files));
// Add the shuffle op after this op
RETURN_EMPTY_IF_ERROR(AddShuffleOp(sorted_dir_files.size(), num_shards_, num_rows, 0, connector_que_size_,
rows_per_buffer_, &shuffle_op));
node_ops.push_back(shuffle_op);
}
// Add TFReaderOp
node_ops.push_back(tf_reader_op);
return node_ops;
}
#ifndef ENABLE_ANDROID
// Constructor for VOCDataset
VOCDataset::VOCDataset(const std::string &dataset_dir, const std::string &task, const std::string &mode,

@ -32,6 +32,7 @@
#include "minddata/dataset/include/type_id.h"
#include "minddata/dataset/kernels/c_func_op.h"
#include "minddata/dataset/kernels/tensor_op.h"
#include "minddata/dataset/util/path.h"
#ifndef ENABLE_ANDROID
#include "minddata/dataset/text/vocab.h"
#endif
@ -69,6 +70,7 @@ class ManifestDataset;
class MnistDataset;
class RandomDataset;
class TextFileDataset;
class TFRecordDataset;
#ifndef ENABLE_ANDROID
class VOCDataset;
#endif
@ -320,6 +322,80 @@ std::shared_ptr<TextFileDataset> TextFile(const std::vector<std::string> &datase
ShuffleMode shuffle = ShuffleMode::kGlobal, int32_t num_shards = 1,
int32_t shard_id = 0);
/// \brief Function to create a TFRecordDataset
/// \param[in] dataset_files List of files to be read to search for a pattern of files. The list
/// will be sorted in a lexicographical order.
/// \param[in] schema SchemaObj or string to schema path. (Default = nullptr, which means that the
/// meta data from the TFData file is considered the schema.)
/// \param[in] columns_list List of columns to be read. (Default = {}, read all columns)
/// \param[in] num_samples The number of samples to be included in the dataset.
/// (Default = 0 means all samples.)
/// If num_samples is 0 and numRows(parsed from schema) does not exist, read the full dataset;
/// If num_samples is 0 and numRows(parsed from schema) is greater than 0, read numRows rows;
/// If both num_samples and numRows(parsed from schema) are greater than 0, read num_samples rows.
/// \param[in] shuffle The mode for shuffling data every epoch. (Default = ShuffleMode::kGlobal)
/// Can be any of:
/// ShuffleMode::kFalse - No shuffling is performed.
/// ShuffleMode::kFiles - Shuffle files only.
/// ShuffleMode::kGlobal - Shuffle both the files and samples.
/// \param[in] num_shards Number of shards that the dataset should be divided into. (Default = 1)
/// \param[in] shard_id The shard ID within num_shards. This argument should be specified only
/// when num_shards is also specified. (Default = 0)
/// \param[in] shard_equal_rows Get equal rows for all shards. (Default = False, number of rows of
/// each shard may be not equal)
/// \return Shared pointer to the current TFRecordDataset
template <typename T = std::shared_ptr<SchemaObj>>
std::shared_ptr<TFRecordDataset> TFRecord(const std::vector<std::string> &dataset_files, const T &schema = nullptr,
const std::vector<std::string> &columns_list = {}, int64_t num_samples = 0,
ShuffleMode shuffle = ShuffleMode::kGlobal, int32_t num_shards = 1,
int32_t shard_id = 0, bool shard_equal_rows = false) {
if (dataset_files.empty()) {
MS_LOG(ERROR) << "TFRecordDataset: dataset_files is not specified.";
return nullptr;
}
for (auto f : dataset_files) {
Path dataset_file(f);
if (!dataset_file.Exists()) {
MS_LOG(ERROR) << "TFRecordDataset: dataset file: [" << f << "] is invalid or does not exist.";
return nullptr;
}
}
if (num_samples < 0) {
MS_LOG(ERROR) << "TFRecordDataset: Invalid number of samples: " << num_samples;
return nullptr;
}
if (num_shards <= 0) {
MS_LOG(ERROR) << "TFRecordDataset: Invalid num_shards: " << num_shards;
return nullptr;
}
if (shard_id < 0 || shard_id >= num_shards) {
MS_LOG(ERROR) << "TFRecordDataset: Invalid input, shard_id: " << shard_id << ", num_shards: " << num_shards;
return nullptr;
}
std::shared_ptr<TFRecordDataset> ds = nullptr;
if constexpr (std::is_same<T, std::nullptr_t>::value || std::is_same<T, std::shared_ptr<SchemaObj>>::value) {
std::shared_ptr<SchemaObj> schema_obj = schema;
ds = std::make_shared<TFRecordDataset>(dataset_files, schema_obj, columns_list, num_samples, shuffle, num_shards,
shard_id, shard_equal_rows);
} else {
std::string schema_path = schema;
if (!schema_path.empty()) {
Path schema_file(schema_path);
if (!schema_file.Exists()) {
MS_LOG(ERROR) << "TFRecordDataset: schema path [" << schema_path << "] is invalid or does not exist.";
return nullptr;
}
}
ds = std::make_shared<TFRecordDataset>(dataset_files, schema_path, columns_list, num_samples, shuffle, num_shards,
shard_id, shard_equal_rows);
}
return ds;
}
#ifndef ENABLE_ANDROID
/// \brief Function to create a VOCDataset
/// \notes The generated dataset has multi-columns :
@ -952,6 +1028,61 @@ class TextFileDataset : public Dataset {
ShuffleMode shuffle_;
};
/// \class TFRecordDataset
/// \brief A Dataset derived class to represent TFRecord dataset
class TFRecordDataset : public Dataset {
public:
/// \brief Constructor
/// \note Parameter 'schema' is the path to the schema file
TFRecordDataset(const std::vector<std::string> &dataset_files, std::string schema,
const std::vector<std::string> &columns_list, int64_t num_samples, ShuffleMode shuffle,
int32_t num_shards, int32_t shard_id, bool shard_equal_rows)
: dataset_files_(dataset_files),
schema_path_(schema),
columns_list_(columns_list),
num_samples_(num_samples),
shuffle_(shuffle),
num_shards_(num_shards),
shard_id_(shard_id),
shard_equal_rows_(shard_equal_rows) {}
/// \brief Constructor
/// \note Parameter 'schema' is shared pointer to Schema object
TFRecordDataset(const std::vector<std::string> &dataset_files, std::shared_ptr<SchemaObj> schema,
const std::vector<std::string> &columns_list, int64_t num_samples, ShuffleMode shuffle,
int32_t num_shards, int32_t shard_id, bool shard_equal_rows)
: dataset_files_(dataset_files),
schema_obj_(schema),
columns_list_(columns_list),
num_samples_(num_samples),
shuffle_(shuffle),
num_shards_(num_shards),
shard_id_(shard_id),
shard_equal_rows_(shard_equal_rows) {}
/// \brief Destructor
~TFRecordDataset() = default;
/// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return The list of shared pointers to the newly created DatasetOps
std::vector<std::shared_ptr<DatasetOp>> Build() override;
/// \brief Parameters validation
/// \return bool true if all the params are valid
bool ValidateParams() override;
private:
std::vector<std::string> dataset_files_;
std::string schema_path_; // schema_path_ path to schema file. It is set when type of schema parameter is string
std::shared_ptr<SchemaObj> schema_obj_; // schema_obj_ schema object.
std::vector<std::string> columns_list_;
int64_t num_samples_;
ShuffleMode shuffle_;
int32_t num_shards_;
int32_t shard_id_;
bool shard_equal_rows_;
};
#ifndef ENABLE_ANDROID
class VOCDataset : public Dataset {
public:

@ -3541,7 +3541,7 @@ class TFRecordDataset(SourceDataset):
If the schema is not provided, the meta data from the TFData file is considered the schema.
columns_list (list[str], optional): List of columns to be read (default=None, read all columns)
num_samples (int, optional): number of samples(rows) to read (default=None).
If num_samples is None and numRows(parsed from schema) is not exist, read the full dataset;
If num_samples is None and numRows(parsed from schema) does not exist, read the full dataset;
If num_samples is None and numRows(parsed from schema) is greater than 0, read numRows rows;
If both num_samples and numRows(parsed from schema) are greater than 0, read num_samples rows.
num_parallel_workers (int, optional): number of workers to read the data
@ -3560,8 +3560,8 @@ class TFRecordDataset(SourceDataset):
into (default=None).
shard_id (int, optional): The shard ID within num_shards (default=None). This
argument should be specified only when num_shards is also specified.
shard_equal_rows (bool): Get equal rows for all shards(default=False). If shard_equal_rows is false, number
of rows of each shard may be not equal.
shard_equal_rows (bool, optional): Get equal rows for all shards(default=False). If shard_equal_rows
is false, number of rows of each shard may be not equal.
cache (DatasetCache, optional): Tensor cache to use. (default=None which means no cache is used).
The cache feature is under development and is not recommended.
Examples:

@ -107,9 +107,10 @@ SET(DE_UT_SRCS
c_api_dataset_clue_test.cc
c_api_dataset_coco_test.cc
c_api_dataset_csv_test.cc
c_api_dataset_textfile_test.cc
c_api_dataset_manifest_test.cc
c_api_dataset_randomdata_test.cc
c_api_dataset_textfile_test.cc
c_api_dataset_tfrecord_test.cc
c_api_dataset_voc_test.cc
c_api_datasets_test.cc
c_api_dataset_iterator_test.cc

File diff suppressed because it is too large Load Diff
Loading…
Cancel
Save