!3438 C++ API support for TakeDatasetOp and VOCDatasetOp

Merge pull request !3438 from luoyang/pylint
pull/3438/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit f2401cd0f9

@ -24,21 +24,25 @@
#include "minddata/dataset/engine/datasetops/source/cifar_op.h" #include "minddata/dataset/engine/datasetops/source/cifar_op.h"
#include "minddata/dataset/engine/datasetops/source/image_folder_op.h" #include "minddata/dataset/engine/datasetops/source/image_folder_op.h"
#include "minddata/dataset/engine/datasetops/source/mnist_op.h" #include "minddata/dataset/engine/datasetops/source/mnist_op.h"
#include "minddata/dataset/engine/datasetops/source/voc_op.h"
// Dataset operator headers (in alphabetical order) // Dataset operator headers (in alphabetical order)
#include "minddata/dataset/engine/datasetops/batch_op.h" #include "minddata/dataset/engine/datasetops/batch_op.h"
#include "minddata/dataset/engine/datasetops/map_op.h" #include "minddata/dataset/engine/datasetops/map_op.h"
#include "minddata/dataset/engine/datasetops/project_op.h"
#include "minddata/dataset/engine/datasetops/rename_op.h"
#include "minddata/dataset/engine/datasetops/repeat_op.h" #include "minddata/dataset/engine/datasetops/repeat_op.h"
#include "minddata/dataset/engine/datasetops/shuffle_op.h" #include "minddata/dataset/engine/datasetops/shuffle_op.h"
#include "minddata/dataset/engine/datasetops/skip_op.h" #include "minddata/dataset/engine/datasetops/skip_op.h"
#include "minddata/dataset/engine/datasetops/project_op.h" #include "minddata/dataset/engine/datasetops/take_op.h"
#include "minddata/dataset/engine/datasetops/zip_op.h" #include "minddata/dataset/engine/datasetops/zip_op.h"
#include "minddata/dataset/engine/datasetops/rename_op.h"
// Sampler headers (in alphabetical order) // Sampler headers (in alphabetical order)
#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" #include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
#include "minddata/dataset/engine/datasetops/source/sampler/random_sampler.h" #include "minddata/dataset/engine/datasetops/source/sampler/random_sampler.h"
#include "minddata/dataset/core/config_manager.h" #include "minddata/dataset/core/config_manager.h"
#include "minddata/dataset/util/random.h" #include "minddata/dataset/util/random.h"
#include "minddata/dataset/util/path.h"
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
@ -123,6 +127,16 @@ std::shared_ptr<MnistDataset> Mnist(std::string dataset_dir, std::shared_ptr<Sam
return ds->ValidateParams() ? ds : nullptr; return ds->ValidateParams() ? ds : nullptr;
} }
// Function to create a VOCDataset.
std::shared_ptr<VOCDataset> VOC(const std::string &dataset_dir, const std::string &task, const std::string &mode,
const std::map<std::string, int32_t> &class_index, bool decode,
std::shared_ptr<SamplerObj> sampler) {
auto ds = std::make_shared<VOCDataset>(dataset_dir, task, mode, class_index, decode, sampler);
// Call derived class validation method.
return ds->ValidateParams() ? ds : nullptr;
}
// FUNCTIONS TO CREATE DATASETS FOR DATASET OPS // FUNCTIONS TO CREATE DATASETS FOR DATASET OPS
// (In alphabetical order) // (In alphabetical order)
@ -232,6 +246,26 @@ std::shared_ptr<SkipDataset> Dataset::Skip(int32_t count) {
return ds; return ds;
} }
// Function to create a TakeDataset.
std::shared_ptr<Dataset> Dataset::Take(int32_t count) {
// If count is greater than the number of element in dataset or equal to -1,
// all the element in dataset will be taken
if (count == -1) {
return shared_from_this();
}
auto ds = std::make_shared<TakeDataset>(count);
// Call derived class validation method.
if (!ds->ValidateParams()) {
return nullptr;
}
ds->children.push_back(shared_from_this());
return ds;
}
// Function to create a Zip dataset // Function to create a Zip dataset
std::shared_ptr<ZipDataset> Dataset::Zip(const std::vector<std::shared_ptr<Dataset>> &datasets) { std::shared_ptr<ZipDataset> Dataset::Zip(const std::vector<std::shared_ptr<Dataset>> &datasets) {
// Default values // Default values
@ -392,6 +426,71 @@ std::vector<std::shared_ptr<DatasetOp>> MnistDataset::Build() {
return node_ops; return node_ops;
} }
// Constructor for VOCDataset
// Stores the user supplied settings; nothing is checked here, see ValidateParams().
// \param[in] dataset_dir Root directory of the dataset
// \param[in] task Task name ("Segmentation" or "Detection" are the values ValidateParams() accepts)
// \param[in] mode Name of the image-set list, used as "<mode>.txt" by ValidateParams()
// \param[in] class_index Mapping from label name to index
// \param[in] decode Whether images are decoded after reading
// \param[in] sampler Sampler object, may be nullptr (Build() then creates a default one)
VOCDataset::VOCDataset(const std::string &dataset_dir, const std::string &task, const std::string &mode,
                       const std::map<std::string, int32_t> &class_index, bool decode,
                       std::shared_ptr<SamplerObj> sampler)
    : dataset_dir_(dataset_dir),
      task_(task),
      mode_(mode),
      class_index_(class_index),
      decode_(decode),
      // sampler is a by-value parameter: move it into the member instead of copying,
      // which avoids a needless atomic reference count increment/decrement.
      sampler_(std::move(sampler)) {}
// Function to validate the parameters for VOCDataset.
// Checks, in order: the dataset directory, the task name, and the image-set list
// file "<mode>.txt" that belongs to the selected task.
// \return true if all the params are valid, false (after logging an error) otherwise
bool VOCDataset::ValidateParams() {
  Path dir(dataset_dir_);
  if (!dir.IsDirectory()) {
    MS_LOG(ERROR) << "Invalid dataset path or no dataset path is specified.";
    return false;
  }

  const bool is_segmentation = (task_ == "Segmentation");
  const bool is_detection = (task_ == "Detection");
  if (!is_segmentation && !is_detection) {
    MS_LOG(ERROR) << "Invalid task: " << task_;
    return false;
  }

  if (is_detection) {
    // Detection image-set lists are looked up under ImageSets/Main.
    Path imagesets_file = dir / "ImageSets" / "Main" / mode_ + ".txt";
    if (imagesets_file.Exists()) {
      return true;
    }
    MS_LOG(ERROR) << "[Detection] imagesets_file is invalid or not exist.";
    return false;
  }

  // Segmentation: a class index mapping is rejected for this task.
  if (!class_index_.empty()) {
    MS_LOG(ERROR) << "class_indexing is invalid in Segmentation task.";
    return false;
  }

  // Segmentation image-set lists are looked up under ImageSets/Segmentation.
  Path imagesets_file = dir / "ImageSets" / "Segmentation" / mode_ + ".txt";
  if (imagesets_file.Exists()) {
    return true;
  }
  MS_LOG(ERROR) << "[Segmentation] imagesets_file is invalid or not exist";
  return false;
}
// Function to build VOCDataset
// Configures a VOCOp::Builder from the stored settings and returns the resulting op.
// \return vector holding the newly created VOCOp
std::vector<std::shared_ptr<DatasetOp>> VOCDataset::Build() {
  // A vector containing shared pointer to the Dataset Ops that this object will create
  std::vector<std::shared_ptr<DatasetOp>> node_ops;

  // If user does not specify Sampler, fall back to the default sampler.
  // Note that this stores the default sampler in the member.
  if (sampler_ == nullptr) {
    sampler_ = CreateDefaultSampler();
  }

  // Forward every stored setting to the VOCOp builder.
  std::shared_ptr<VOCOp::Builder> builder = std::make_shared<VOCOp::Builder>();
  (void)builder->SetDir(dataset_dir_);
  (void)builder->SetTask(task_);
  (void)builder->SetMode(mode_);
  (void)builder->SetNumWorkers(num_workers_);
  // The result of sampler_->Build() is passed straight through; wrapping it in
  // std::move is redundant and has been dropped.
  (void)builder->SetSampler(sampler_->Build());
  (void)builder->SetDecode(decode_);
  (void)builder->SetClassIndex(class_index_);

  std::shared_ptr<VOCOp> op;
  // NOTE(review): presumably this macro returns an empty vector when the build fails - confirm.
  RETURN_EMPTY_IF_ERROR(builder->Build(&op));
  node_ops.push_back(op);
  return node_ops;
}
// DERIVED DATASET CLASSES LEAF-NODE DATASETS // DERIVED DATASET CLASSES LEAF-NODE DATASETS
// (In alphabetical order) // (In alphabetical order)
@ -580,6 +679,28 @@ bool SkipDataset::ValidateParams() {
return true; return true;
} }
// Constructor for TakeDataset
TakeDataset::TakeDataset(int32_t count) : take_count_(count) {}
// Function to build the TakeOp
std::vector<std::shared_ptr<DatasetOp>> TakeDataset::Build() {
// A vector containing shared pointer to the Dataset Ops that this object will create
std::vector<std::shared_ptr<DatasetOp>> node_ops;
node_ops.push_back(std::make_shared<TakeOp>(take_count_, connector_que_size_));
return node_ops;
}
// Function to validate the parameters for TakeDataset
// The count must be either -1 or a positive integer. Zero and values below -1 are rejected.
// \return true if take_count_ is valid, false (after logging an error) otherwise
bool TakeDataset::ValidateParams() {
  // -1 is the only non-positive value that is accepted.
  if (take_count_ <= 0 && take_count_ != -1) {
    MS_LOG(ERROR) << "Take: Invalid input, take_count: " << take_count_;
    return false;
  }
  return true;
}
// Function to build ZipOp // Function to build ZipOp
ZipDataset::ZipDataset() {} ZipDataset::ZipDataset() {}

@ -45,6 +45,7 @@ class Cifar10Dataset;
class Cifar100Dataset; class Cifar100Dataset;
class ImageFolderDataset; class ImageFolderDataset;
class MnistDataset; class MnistDataset;
class VOCDataset;
// Dataset Op classes (in alphabetical order) // Dataset Op classes (in alphabetical order)
class BatchDataset; class BatchDataset;
class MapDataset; class MapDataset;
@ -53,6 +54,7 @@ class RenameDataset;
class RepeatDataset; class RepeatDataset;
class ShuffleDataset; class ShuffleDataset;
class SkipDataset; class SkipDataset;
class TakeDataset;
class ZipDataset; class ZipDataset;
/// \brief Function to create a Cifar10 Dataset /// \brief Function to create a Cifar10 Dataset
@ -96,6 +98,24 @@ std::shared_ptr<ImageFolderDataset> ImageFolder(std::string dataset_dir, bool de
/// \return Shared pointer to the current MnistDataset /// \return Shared pointer to the current MnistDataset
std::shared_ptr<MnistDataset> Mnist(std::string dataset_dir, std::shared_ptr<SamplerObj> sampler = nullptr); std::shared_ptr<MnistDataset> Mnist(std::string dataset_dir, std::shared_ptr<SamplerObj> sampler = nullptr);
/// \brief Function to create a VOCDataset
/// \notes The generated dataset has multi-columns :
/// - task='Detection', column: [['image', dtype=uint8], ['bbox', dtype=float32], ['label', dtype=uint32],
/// ['difficult', dtype=uint32], ['truncate', dtype=uint32]].
/// - task='Segmentation', column: [['image', dtype=uint8], ['target',dtype=uint8]].
/// \param[in] dataset_dir Path to the root directory that contains the dataset
/// \param[in] task Set the task type of reading voc data, now only support "Segmentation" or "Detection"
/// \param[in] mode Name of the image-set list to be read; the file "<mode>.txt" must exist for the chosen task
/// \param[in] class_index A str-to-int mapping from label name to index; must be empty for the "Segmentation" task
/// \param[in] decode Decode the images after reading
/// \param[in] sampler Object used to choose samples from the dataset. If sampler is `nullptr`, a default sampler
/// is created when the dataset is built
/// \return Shared pointer to the created VOCDataset, or nullptr if the parameters fail validation
std::shared_ptr<VOCDataset> VOC(const std::string &dataset_dir, const std::string &task = "Segmentation",
const std::string &mode = "train",
const std::map<std::string, int32_t> &class_index = {}, bool decode = false,
std::shared_ptr<SamplerObj> sampler = nullptr);
/// \class Dataset datasets.h /// \class Dataset datasets.h
/// \brief A base class to represent a dataset in the data pipeline. /// \brief A base class to represent a dataset in the data pipeline.
class Dataset : public std::enable_shared_from_this<Dataset> { class Dataset : public std::enable_shared_from_this<Dataset> {
@ -192,6 +212,12 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
/// \return Shared pointer to the current SkipDataset /// \return Shared pointer to the current SkipDataset
std::shared_ptr<SkipDataset> Skip(int32_t count); std::shared_ptr<SkipDataset> Skip(int32_t count);
/// \brief Function to create a TakeDataset
/// \notes Takes count elements in this dataset. When count is -1 (the default) no TakeDataset is created
/// and this dataset itself is returned.
/// \param[in] count Number of elements to take from this dataset, or -1 to take all of them
/// \return Shared pointer to the resulting Dataset, or nullptr if count fails validation
std::shared_ptr<Dataset> Take(int32_t count = -1);
/// \brief Function to create a Zip Dataset /// \brief Function to create a Zip Dataset
/// \notes Applies zip to the dataset /// \notes Applies zip to the dataset
/// \param[in] datasets A list of shared pointer to the datasets that we want to zip /// \param[in] datasets A list of shared pointer to the datasets that we want to zip
@ -300,6 +326,32 @@ class MnistDataset : public Dataset {
std::shared_ptr<SamplerObj> sampler_; std::shared_ptr<SamplerObj> sampler_;
}; };
/// \class VOCDataset
/// \brief Dataset node that holds the settings of a VOC source and builds the matching VOCOp
class VOCDataset : public Dataset {
public:
/// \brief Constructor
/// \param[in] dataset_dir Path to the root directory that contains the dataset
/// \param[in] task Task type, "Segmentation" or "Detection"
/// \param[in] mode Name of the image-set list ("<mode>.txt") to be read
/// \param[in] class_index A str-to-int mapping from label name to index
/// \param[in] decode Decode the images after reading
/// \param[in] sampler Object used to choose samples from the dataset, may be nullptr
VOCDataset(const std::string &dataset_dir, const std::string &task, const std::string &mode,
const std::map<std::string, int32_t> &class_index, bool decode, std::shared_ptr<SamplerObj> sampler);
/// \brief Destructor
~VOCDataset() = default;
/// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return shared pointer to the list of newly created DatasetOps
std::vector<std::shared_ptr<DatasetOp>> Build() override;
/// \brief Parameters validation
/// \return bool true if all the params are valid
bool ValidateParams() override;
private:
// Root directory of the dataset; ValidateParams() requires it to be a directory.
std::string dataset_dir_;
// Task name; ValidateParams() accepts "Segmentation" and "Detection".
std::string task_;
// Image-set list name; ValidateParams() looks for "<mode_>.txt" under the task's ImageSets folder.
std::string mode_;
// Label name to index mapping; ValidateParams() requires it to be empty for "Segmentation".
std::map<std::string, int32_t> class_index_;
// Passed to the VOCOp builder via SetDecode().
bool decode_;
// User supplied sampler; Build() replaces a nullptr with a default sampler.
std::shared_ptr<SamplerObj> sampler_;
};
class BatchDataset : public Dataset { class BatchDataset : public Dataset {
public: public:
/// \brief Constructor /// \brief Constructor
@ -446,6 +498,26 @@ class SkipDataset : public Dataset {
int32_t skip_count_; int32_t skip_count_;
}; };
/// \class TakeDataset
/// \brief Dataset node that holds a take count and builds the matching TakeOp
class TakeDataset : public Dataset {
public:
/// \brief Constructor
/// \param[in] count Number of elements to take; checked later by ValidateParams()
explicit TakeDataset(int32_t count);
/// \brief Destructor
~TakeDataset() = default;
/// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return shared pointer to the list of newly created DatasetOps
std::vector<std::shared_ptr<DatasetOp>> Build() override;
/// \brief Parameters validation
/// \return bool true if all the params are valid
bool ValidateParams() override;
private:
// Count handed to TakeOp in Build(); checked by ValidateParams().
int32_t take_count_;
};
class ZipDataset : public Dataset { class ZipDataset : public Dataset {
public: public:
/// \brief Constructor /// \brief Constructor

File diff suppressed because it is too large Load Diff
Loading…
Cancel
Save