|
|
|
@ -32,6 +32,7 @@
|
|
|
|
|
#include "minddata/dataset/include/type_id.h"
|
|
|
|
|
#include "minddata/dataset/kernels/c_func_op.h"
|
|
|
|
|
#include "minddata/dataset/kernels/tensor_op.h"
|
|
|
|
|
#include "minddata/dataset/util/path.h"
|
|
|
|
|
#ifndef ENABLE_ANDROID
|
|
|
|
|
#include "minddata/dataset/text/vocab.h"
|
|
|
|
|
#endif
|
|
|
|
@ -69,6 +70,7 @@ class ManifestDataset;
|
|
|
|
|
class MnistDataset;
|
|
|
|
|
class RandomDataset;
|
|
|
|
|
class TextFileDataset;
|
|
|
|
|
class TFRecordDataset;
|
|
|
|
|
#ifndef ENABLE_ANDROID
|
|
|
|
|
class VOCDataset;
|
|
|
|
|
#endif
|
|
|
|
@ -320,6 +322,80 @@ std::shared_ptr<TextFileDataset> TextFile(const std::vector<std::string> &datase
|
|
|
|
|
ShuffleMode shuffle = ShuffleMode::kGlobal, int32_t num_shards = 1,
|
|
|
|
|
int32_t shard_id = 0);
|
|
|
|
|
|
|
|
|
|
/// \brief Function to create a TFRecordDataset
|
|
|
|
|
/// \param[in] dataset_files List of files to be read to search for a pattern of files. The list
|
|
|
|
|
/// will be sorted in a lexicographical order.
|
|
|
|
|
/// \param[in] schema SchemaObj or string to schema path. (Default = nullptr, which means that the
|
|
|
|
|
/// meta data from the TFData file is considered the schema.)
|
|
|
|
|
/// \param[in] columns_list List of columns to be read. (Default = {}, read all columns)
|
|
|
|
|
/// \param[in] num_samples The number of samples to be included in the dataset.
|
|
|
|
|
/// (Default = 0 means all samples.)
|
|
|
|
|
/// If num_samples is 0 and numRows(parsed from schema) does not exist, read the full dataset;
|
|
|
|
|
/// If num_samples is 0 and numRows(parsed from schema) is greater than 0, read numRows rows;
|
|
|
|
|
/// If both num_samples and numRows(parsed from schema) are greater than 0, read num_samples rows.
|
|
|
|
|
/// \param[in] shuffle The mode for shuffling data every epoch. (Default = ShuffleMode::kGlobal)
|
|
|
|
|
/// Can be any of:
|
|
|
|
|
/// ShuffleMode::kFalse - No shuffling is performed.
|
|
|
|
|
/// ShuffleMode::kFiles - Shuffle files only.
|
|
|
|
|
/// ShuffleMode::kGlobal - Shuffle both the files and samples.
|
|
|
|
|
/// \param[in] num_shards Number of shards that the dataset should be divided into. (Default = 1)
|
|
|
|
|
/// \param[in] shard_id The shard ID within num_shards. This argument should be specified only
|
|
|
|
|
/// when num_shards is also specified. (Default = 0)
|
|
|
|
|
/// \param[in] shard_equal_rows Get equal rows for all shards. (Default = False, number of rows of
|
|
|
|
|
/// each shard may be not equal)
|
|
|
|
|
/// \return Shared pointer to the current TFRecordDataset
|
|
|
|
|
template <typename T = std::shared_ptr<SchemaObj>>
|
|
|
|
|
std::shared_ptr<TFRecordDataset> TFRecord(const std::vector<std::string> &dataset_files, const T &schema = nullptr,
|
|
|
|
|
const std::vector<std::string> &columns_list = {}, int64_t num_samples = 0,
|
|
|
|
|
ShuffleMode shuffle = ShuffleMode::kGlobal, int32_t num_shards = 1,
|
|
|
|
|
int32_t shard_id = 0, bool shard_equal_rows = false) {
|
|
|
|
|
if (dataset_files.empty()) {
|
|
|
|
|
MS_LOG(ERROR) << "TFRecordDataset: dataset_files is not specified.";
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for (auto f : dataset_files) {
|
|
|
|
|
Path dataset_file(f);
|
|
|
|
|
if (!dataset_file.Exists()) {
|
|
|
|
|
MS_LOG(ERROR) << "TFRecordDataset: dataset file: [" << f << "] is invalid or does not exist.";
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (num_samples < 0) {
|
|
|
|
|
MS_LOG(ERROR) << "TFRecordDataset: Invalid number of samples: " << num_samples;
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (num_shards <= 0) {
|
|
|
|
|
MS_LOG(ERROR) << "TFRecordDataset: Invalid num_shards: " << num_shards;
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (shard_id < 0 || shard_id >= num_shards) {
|
|
|
|
|
MS_LOG(ERROR) << "TFRecordDataset: Invalid input, shard_id: " << shard_id << ", num_shards: " << num_shards;
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
std::shared_ptr<TFRecordDataset> ds = nullptr;
|
|
|
|
|
if constexpr (std::is_same<T, std::nullptr_t>::value || std::is_same<T, std::shared_ptr<SchemaObj>>::value) {
|
|
|
|
|
std::shared_ptr<SchemaObj> schema_obj = schema;
|
|
|
|
|
ds = std::make_shared<TFRecordDataset>(dataset_files, schema_obj, columns_list, num_samples, shuffle, num_shards,
|
|
|
|
|
shard_id, shard_equal_rows);
|
|
|
|
|
} else {
|
|
|
|
|
std::string schema_path = schema;
|
|
|
|
|
if (!schema_path.empty()) {
|
|
|
|
|
Path schema_file(schema_path);
|
|
|
|
|
if (!schema_file.Exists()) {
|
|
|
|
|
MS_LOG(ERROR) << "TFRecordDataset: schema path [" << schema_path << "] is invalid or does not exist.";
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
ds = std::make_shared<TFRecordDataset>(dataset_files, schema_path, columns_list, num_samples, shuffle, num_shards,
|
|
|
|
|
shard_id, shard_equal_rows);
|
|
|
|
|
}
|
|
|
|
|
return ds;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#ifndef ENABLE_ANDROID
|
|
|
|
|
/// \brief Function to create a VOCDataset
|
|
|
|
|
/// \notes The generated dataset has multi-columns :
|
|
|
|
@ -952,6 +1028,61 @@ class TextFileDataset : public Dataset {
|
|
|
|
|
ShuffleMode shuffle_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
/// \class TFRecordDataset
|
|
|
|
|
/// \brief A Dataset derived class to represent TFRecord dataset
|
|
|
|
|
class TFRecordDataset : public Dataset {
|
|
|
|
|
public:
|
|
|
|
|
/// \brief Constructor
|
|
|
|
|
/// \note Parameter 'schema' is the path to the schema file
|
|
|
|
|
TFRecordDataset(const std::vector<std::string> &dataset_files, std::string schema,
|
|
|
|
|
const std::vector<std::string> &columns_list, int64_t num_samples, ShuffleMode shuffle,
|
|
|
|
|
int32_t num_shards, int32_t shard_id, bool shard_equal_rows)
|
|
|
|
|
: dataset_files_(dataset_files),
|
|
|
|
|
schema_path_(schema),
|
|
|
|
|
columns_list_(columns_list),
|
|
|
|
|
num_samples_(num_samples),
|
|
|
|
|
shuffle_(shuffle),
|
|
|
|
|
num_shards_(num_shards),
|
|
|
|
|
shard_id_(shard_id),
|
|
|
|
|
shard_equal_rows_(shard_equal_rows) {}
|
|
|
|
|
|
|
|
|
|
/// \brief Constructor
|
|
|
|
|
/// \note Parameter 'schema' is shared pointer to Schema object
|
|
|
|
|
TFRecordDataset(const std::vector<std::string> &dataset_files, std::shared_ptr<SchemaObj> schema,
|
|
|
|
|
const std::vector<std::string> &columns_list, int64_t num_samples, ShuffleMode shuffle,
|
|
|
|
|
int32_t num_shards, int32_t shard_id, bool shard_equal_rows)
|
|
|
|
|
: dataset_files_(dataset_files),
|
|
|
|
|
schema_obj_(schema),
|
|
|
|
|
columns_list_(columns_list),
|
|
|
|
|
num_samples_(num_samples),
|
|
|
|
|
shuffle_(shuffle),
|
|
|
|
|
num_shards_(num_shards),
|
|
|
|
|
shard_id_(shard_id),
|
|
|
|
|
shard_equal_rows_(shard_equal_rows) {}
|
|
|
|
|
|
|
|
|
|
/// \brief Destructor
|
|
|
|
|
~TFRecordDataset() = default;
|
|
|
|
|
|
|
|
|
|
/// \brief a base class override function to create the required runtime dataset op objects for this class
|
|
|
|
|
/// \return The list of shared pointers to the newly created DatasetOps
|
|
|
|
|
std::vector<std::shared_ptr<DatasetOp>> Build() override;
|
|
|
|
|
|
|
|
|
|
/// \brief Parameters validation
|
|
|
|
|
/// \return bool true if all the params are valid
|
|
|
|
|
bool ValidateParams() override;
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
std::vector<std::string> dataset_files_;
|
|
|
|
|
std::string schema_path_; // schema_path_ path to schema file. It is set when type of schema parameter is string
|
|
|
|
|
std::shared_ptr<SchemaObj> schema_obj_; // schema_obj_ schema object.
|
|
|
|
|
std::vector<std::string> columns_list_;
|
|
|
|
|
int64_t num_samples_;
|
|
|
|
|
ShuffleMode shuffle_;
|
|
|
|
|
int32_t num_shards_;
|
|
|
|
|
int32_t shard_id_;
|
|
|
|
|
bool shard_equal_rows_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
#ifndef ENABLE_ANDROID
|
|
|
|
|
class VOCDataset : public Dataset {
|
|
|
|
|
public:
|
|
|
|
|