add concat op and extend zip support to cpp api

pull/3610/head
tinazhang 5 years ago
parent 6945eb2821
commit e57d849618

@ -27,6 +27,7 @@
#include "minddata/dataset/engine/datasetops/source/voc_op.h"
// Dataset operator headers (in alphabetical order)
#include "minddata/dataset/engine/datasetops/batch_op.h"
#include "minddata/dataset/engine/datasetops/concat_op.h"
#include "minddata/dataset/engine/datasetops/map_op/map_op.h"
#include "minddata/dataset/engine/datasetops/project_op.h"
#include "minddata/dataset/engine/datasetops/rename_op.h"
@ -127,6 +128,14 @@ std::shared_ptr<MnistDataset> Mnist(std::string dataset_dir, std::shared_ptr<Sam
return ds->ValidateParams() ? ds : nullptr;
}
// Function to overload "+" operator to concat two datasets
std::shared_ptr<ConcatDataset> operator+(const std::shared_ptr<Dataset> &datasets1,
const std::shared_ptr<Dataset> &datasets2) {
std::shared_ptr<ConcatDataset> ds = std::make_shared<ConcatDataset>(std::vector({datasets1, datasets2}));
return ds->ValidateParams() ? ds : nullptr;
}
// Function to create a VOCDataset.
std::shared_ptr<VOCDataset> VOC(const std::string &dataset_dir, const std::string &task, const std::string &mode,
const std::map<std::string, int32_t> &class_index, bool decode,
@ -137,6 +146,14 @@ std::shared_ptr<VOCDataset> VOC(const std::string &dataset_dir, const std::strin
return ds->ValidateParams() ? ds : nullptr;
}
// Function to create a ZipDataset.
std::shared_ptr<ZipDataset> Zip(const std::vector<std::shared_ptr<Dataset>> &datasets) {
auto ds = std::make_shared<ZipDataset>(datasets);
// Call derived class validation method.
return ds->ValidateParams() ? ds : nullptr;
}
// FUNCTIONS TO CREATE DATASETS FOR DATASET OPS
// (In alphabetical order)
@ -157,6 +174,14 @@ std::shared_ptr<BatchDataset> Dataset::Batch(int32_t batch_size, bool drop_remai
return ds;
}
// Function to create a Concat dataset
std::shared_ptr<ConcatDataset> Dataset::Concat(const std::vector<std::shared_ptr<Dataset>> &datasets) {
auto ds = std::make_shared<ConcatDataset>(datasets);
ds->children.push_back(shared_from_this());
return ds->ValidateParams() ? ds : nullptr;
}
// Function to create a Map dataset.
std::shared_ptr<MapDataset> Dataset::Map(std::vector<std::shared_ptr<TensorOperation>> operations,
std::vector<std::string> input_columns,
@ -269,16 +294,10 @@ std::shared_ptr<Dataset> Dataset::Take(int32_t count) {
// Function to create a Zip dataset
std::shared_ptr<ZipDataset> Dataset::Zip(const std::vector<std::shared_ptr<Dataset>> &datasets) {
// Default values
auto ds = std::make_shared<ZipDataset>();
if (!ds->ValidateParams()) {
return nullptr;
}
for (auto dataset : datasets) {
ds->children.push_back(dataset);
}
auto ds = std::make_shared<ZipDataset>(datasets);
ds->children.push_back(shared_from_this());
return ds;
return ds->ValidateParams() ? ds : nullptr;
}
// OTHER FUNCTIONS
@ -526,6 +545,27 @@ bool BatchDataset::ValidateParams() {
return true;
}
// Function to build ConcatOp
ConcatDataset::ConcatDataset(const std::vector<std::shared_ptr<Dataset>> &datasets) : datasets_(datasets) {
this->children = datasets_;
}
bool ConcatDataset::ValidateParams() {
if (datasets_.empty()) {
MS_LOG(ERROR) << "Concat: concatenated datasets are not specified.";
return false;
}
return true;
}
std::vector<std::shared_ptr<DatasetOp>> ConcatDataset::Build() {
// A vector containing shared pointer to the Dataset Ops that this object will create
std::vector<std::shared_ptr<DatasetOp>> node_ops;
node_ops.push_back(std::make_shared<ConcatOp>(connector_que_size_));
return node_ops;
}
MapDataset::MapDataset(std::vector<std::shared_ptr<TensorOperation>> operations, std::vector<std::string> input_columns,
std::vector<std::string> output_columns, const std::vector<std::string> &project_columns)
: operations_(operations),
@ -698,9 +738,19 @@ bool TakeDataset::ValidateParams() {
}
// Function to build ZipOp
ZipDataset::ZipDataset() {}
ZipDataset::ZipDataset(const std::vector<std::shared_ptr<Dataset>> &datasets) : datasets_(datasets) {
for (auto dataset : datasets_) {
this->children.push_back(dataset);
}
}
bool ZipDataset::ValidateParams() { return true; }
bool ZipDataset::ValidateParams() {
if (datasets_.empty()) {
MS_LOG(ERROR) << "Zip: dataset to zip are not specified.";
return false;
}
return true;
}
std::vector<std::shared_ptr<DatasetOp>> ZipDataset::Build() {
// A vector containing shared pointer to the Dataset Ops that this object will create

@ -48,6 +48,7 @@ class MnistDataset;
class VOCDataset;
// Dataset Op classes (in alphabetical order)
class BatchDataset;
class ConcatDataset;
class MapDataset;
class ProjectDataset;
class RenameDataset;
@ -98,6 +99,14 @@ std::shared_ptr<ImageFolderDataset> ImageFolder(std::string dataset_dir, bool de
/// \return Shared pointer to the current MnistDataset
std::shared_ptr<MnistDataset> Mnist(std::string dataset_dir, std::shared_ptr<SamplerObj> sampler = nullptr);
/// \brief Function to create a ConcatDataset
/// \notes Reload "+" operator to concat two datasets
/// \param[in] datasets1 Shared pointer to the first dataset to be concatenated
/// \param[in] datasets2 Shared pointer to the second dataset to be concatenated
/// \return Shared pointer to the current ConcatDataset
std::shared_ptr<ConcatDataset> operator+(const std::shared_ptr<Dataset> &datasets1,
const std::shared_ptr<Dataset> &datasets2);
/// \brief Function to create a VOCDataset
/// \notes The generated dataset has multi-columns :
/// - task='Detection', column: [['image', dtype=uint8], ['bbox', dtype=float32], ['label', dtype=uint32],
@ -116,6 +125,12 @@ std::shared_ptr<VOCDataset> VOC(const std::string &dataset_dir, const std::strin
const std::map<std::string, int32_t> &class_index = {}, bool decode = false,
std::shared_ptr<SamplerObj> sampler = nullptr);
/// \brief Function to create a ZipDataset
/// \notes Applies zip to the dataset
/// \param[in] datasets List of shared pointers to the datasets that we want to zip
/// \return Shared pointer to the current Dataset
std::shared_ptr<ZipDataset> Zip(const std::vector<std::shared_ptr<Dataset>> &datasets);
/// \class Dataset datasets.h
/// \brief A base class to represent a dataset in the data pipeline.
class Dataset : public std::enable_shared_from_this<Dataset> {
@ -158,6 +173,12 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
/// \return Shared pointer to the current BatchDataset
std::shared_ptr<BatchDataset> Batch(int32_t batch_size, bool drop_remainder = false);
/// \brief Function to create a ConcatDataset
/// \notes Concat the datasets in the input
/// \param[in] datasets List of shared pointers to the dataset that should be concatenated together
/// \return Shared pointer to the current ConcatDataset
std::shared_ptr<ConcatDataset> Concat(const std::vector<std::shared_ptr<Dataset>> &datasets);
/// \brief Function to create a MapDataset
/// \notes Applies each operation in operations to this dataset
/// \param[in] operations Vector of operations to be applied on the dataset. Operations are
@ -220,7 +241,7 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
/// \brief Function to create a Zip Dataset
/// \notes Applies zip to the dataset
/// \param[in] datasets A list of shared pointer to the datasets that we want to zip
/// \param[in] datasets A list of shared pointers to the datasets that we want to zip
/// \return Shared pointer to the current Dataset
std::shared_ptr<ZipDataset> Zip(const std::vector<std::shared_ptr<Dataset>> &datasets);
@ -377,6 +398,26 @@ class BatchDataset : public Dataset {
std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> pad_map_;
};
class ConcatDataset : public Dataset {
public:
/// \brief Constructor
explicit ConcatDataset(const std::vector<std::shared_ptr<Dataset>> &datasets);
/// \brief Destructor
~ConcatDataset() = default;
/// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return The list of shared pointers to the newly created DatasetOps
std::vector<std::shared_ptr<DatasetOp>> Build() override;
/// \brief Parameters validation
/// \return bool true if all the params are valid
bool ValidateParams() override;
private:
std::vector<std::shared_ptr<Dataset>> datasets_;
};
class MapDataset : public Dataset {
public:
/// \brief Constructor
@ -521,7 +562,7 @@ class TakeDataset : public Dataset {
class ZipDataset : public Dataset {
public:
/// \brief Constructor
ZipDataset();
explicit ZipDataset(const std::vector<std::shared_ptr<Dataset>> &datasets);
/// \brief Destructor
~ZipDataset() = default;
@ -533,6 +574,9 @@ class ZipDataset : public Dataset {
/// \brief Parameters validation
/// \return bool true if all the params are valid
bool ValidateParams() override;
private:
std::vector<std::shared_ptr<Dataset>> datasets_;
};
} // namespace api

File diff suppressed because it is too large Load Diff
Loading…
Cancel
Save