|
|
@ -21,6 +21,7 @@
|
|
|
|
#include "minddata/dataset/include/transforms.h"
|
|
|
|
#include "minddata/dataset/include/transforms.h"
|
|
|
|
#include "minddata/dataset/engine/dataset_iterator.h"
|
|
|
|
#include "minddata/dataset/engine/dataset_iterator.h"
|
|
|
|
// Source dataset headers (in alphabetical order)
|
|
|
|
// Source dataset headers (in alphabetical order)
|
|
|
|
|
|
|
|
#include "minddata/dataset/engine/datasetops/source/celeba_op.h"
|
|
|
|
#include "minddata/dataset/engine/datasetops/source/cifar_op.h"
|
|
|
|
#include "minddata/dataset/engine/datasetops/source/cifar_op.h"
|
|
|
|
#include "minddata/dataset/engine/datasetops/source/coco_op.h"
|
|
|
|
#include "minddata/dataset/engine/datasetops/source/coco_op.h"
|
|
|
|
#include "minddata/dataset/engine/datasetops/source/image_folder_op.h"
|
|
|
|
#include "minddata/dataset/engine/datasetops/source/image_folder_op.h"
|
|
|
@ -91,6 +92,16 @@ Dataset::Dataset() {
|
|
|
|
// FUNCTIONS TO CREATE DATASETS FOR LEAF-NODE DATASETS
|
|
|
|
// FUNCTIONS TO CREATE DATASETS FOR LEAF-NODE DATASETS
|
|
|
|
// (In alphabetical order)
|
|
|
|
// (In alphabetical order)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Function to create a CelebADataset.
|
|
|
|
|
|
|
|
std::shared_ptr<CelebADataset> CelebA(const std::string &dataset_dir, const std::string &dataset_type,
|
|
|
|
|
|
|
|
const std::shared_ptr<SamplerObj> &sampler, const bool &decode,
|
|
|
|
|
|
|
|
const std::set<std::string> &extensions) {
|
|
|
|
|
|
|
|
auto ds = std::make_shared<CelebADataset>(dataset_dir, dataset_type, sampler, decode, extensions);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Call derived class validation method.
|
|
|
|
|
|
|
|
return ds->ValidateParams() ? ds : nullptr;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// Function to create a Cifar10Dataset.
|
|
|
|
// Function to create a Cifar10Dataset.
|
|
|
|
std::shared_ptr<Cifar10Dataset> Cifar10(const std::string &dataset_dir, std::shared_ptr<SamplerObj> sampler) {
|
|
|
|
std::shared_ptr<Cifar10Dataset> Cifar10(const std::string &dataset_dir, std::shared_ptr<SamplerObj> sampler) {
|
|
|
|
auto ds = std::make_shared<Cifar10Dataset>(dataset_dir, sampler);
|
|
|
|
auto ds = std::make_shared<Cifar10Dataset>(dataset_dir, sampler);
|
|
|
@ -109,7 +120,8 @@ std::shared_ptr<Cifar100Dataset> Cifar100(const std::string &dataset_dir, std::s
|
|
|
|
|
|
|
|
|
|
|
|
// Function to create a CocoDataset.
|
|
|
|
// Function to create a CocoDataset.
|
|
|
|
std::shared_ptr<CocoDataset> Coco(const std::string &dataset_dir, const std::string &annotation_file,
|
|
|
|
std::shared_ptr<CocoDataset> Coco(const std::string &dataset_dir, const std::string &annotation_file,
|
|
|
|
const std::string &task, bool decode, std::shared_ptr<SamplerObj> sampler) {
|
|
|
|
const std::string &task, const bool &decode,
|
|
|
|
|
|
|
|
const std::shared_ptr<SamplerObj> &sampler) {
|
|
|
|
auto ds = std::make_shared<CocoDataset>(dataset_dir, annotation_file, task, decode, sampler);
|
|
|
|
auto ds = std::make_shared<CocoDataset>(dataset_dir, annotation_file, task, decode, sampler);
|
|
|
|
|
|
|
|
|
|
|
|
// Call derived class validation method.
|
|
|
|
// Call derived class validation method.
|
|
|
@ -334,6 +346,53 @@ bool ValidateCommonDatasetParams(std::string dataset_dir) {
|
|
|
|
// DERIVED DATASET CLASSES LEAF-NODE DATASETS
|
|
|
|
// DERIVED DATASET CLASSES LEAF-NODE DATASETS
|
|
|
|
// (In alphabetical order)
|
|
|
|
// (In alphabetical order)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Constructor for CelebADataset
|
|
|
|
|
|
|
|
CelebADataset::CelebADataset(const std::string &dataset_dir, const std::string &dataset_type,
|
|
|
|
|
|
|
|
const std::shared_ptr<SamplerObj> &sampler, const bool &decode,
|
|
|
|
|
|
|
|
const std::set<std::string> &extensions)
|
|
|
|
|
|
|
|
: dataset_dir_(dataset_dir),
|
|
|
|
|
|
|
|
dataset_type_(dataset_type),
|
|
|
|
|
|
|
|
sampler_(sampler),
|
|
|
|
|
|
|
|
decode_(decode),
|
|
|
|
|
|
|
|
extensions_(extensions) {}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
bool CelebADataset::ValidateParams() {
|
|
|
|
|
|
|
|
Path dir(dataset_dir_);
|
|
|
|
|
|
|
|
if (!dir.IsDirectory()) {
|
|
|
|
|
|
|
|
MS_LOG(ERROR) << "Invalid dataset path or no dataset path is specified.";
|
|
|
|
|
|
|
|
return false;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
std::set<std::string> dataset_type_list = {"all", "train", "valid", "test"};
|
|
|
|
|
|
|
|
auto iter = dataset_type_list.find(dataset_type_);
|
|
|
|
|
|
|
|
if (iter == dataset_type_list.end()) {
|
|
|
|
|
|
|
|
MS_LOG(ERROR) << "dataset_type should be one of 'all', 'train', 'valid' or 'test'.";
|
|
|
|
|
|
|
|
return false;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
return true;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Function to build CelebADataset
|
|
|
|
|
|
|
|
std::vector<std::shared_ptr<DatasetOp>> CelebADataset::Build() {
|
|
|
|
|
|
|
|
// A vector containing shared pointer to the Dataset Ops that this object will create
|
|
|
|
|
|
|
|
std::vector<std::shared_ptr<DatasetOp>> node_ops;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// If user does not specify Sampler, create a default sampler based on the shuffle variable.
|
|
|
|
|
|
|
|
if (sampler_ == nullptr) {
|
|
|
|
|
|
|
|
sampler_ = CreateDefaultSampler();
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
std::unique_ptr<DataSchema> schema = std::make_unique<DataSchema>();
|
|
|
|
|
|
|
|
RETURN_EMPTY_IF_ERROR(
|
|
|
|
|
|
|
|
schema->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1)));
|
|
|
|
|
|
|
|
// label is like this:0 1 0 0 1......
|
|
|
|
|
|
|
|
RETURN_EMPTY_IF_ERROR(
|
|
|
|
|
|
|
|
schema->AddColumn(ColDescriptor("attr", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1)));
|
|
|
|
|
|
|
|
node_ops.push_back(std::make_shared<CelebAOp>(num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_,
|
|
|
|
|
|
|
|
decode_, dataset_type_, extensions_, std::move(schema),
|
|
|
|
|
|
|
|
std::move(sampler_->Build())));
|
|
|
|
|
|
|
|
return node_ops;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// Constructor for Cifar10Dataset
|
|
|
|
// Constructor for Cifar10Dataset
|
|
|
|
Cifar10Dataset::Cifar10Dataset(const std::string &dataset_dir, std::shared_ptr<SamplerObj> sampler)
|
|
|
|
Cifar10Dataset::Cifar10Dataset(const std::string &dataset_dir, std::shared_ptr<SamplerObj> sampler)
|
|
|
|
: dataset_dir_(dataset_dir), sampler_(sampler) {}
|
|
|
|
: dataset_dir_(dataset_dir), sampler_(sampler) {}
|
|
|
@ -396,7 +455,7 @@ std::vector<std::shared_ptr<DatasetOp>> Cifar100Dataset::Build() {
|
|
|
|
|
|
|
|
|
|
|
|
// Constructor for CocoDataset
|
|
|
|
// Constructor for CocoDataset
|
|
|
|
CocoDataset::CocoDataset(const std::string &dataset_dir, const std::string &annotation_file, const std::string &task,
|
|
|
|
CocoDataset::CocoDataset(const std::string &dataset_dir, const std::string &annotation_file, const std::string &task,
|
|
|
|
bool decode, std::shared_ptr<SamplerObj> sampler)
|
|
|
|
const bool &decode, const std::shared_ptr<SamplerObj> &sampler)
|
|
|
|
: dataset_dir_(dataset_dir), annotation_file_(annotation_file), task_(task), decode_(decode), sampler_(sampler) {}
|
|
|
|
: dataset_dir_(dataset_dir), annotation_file_(annotation_file), task_(task), decode_(decode), sampler_(sampler) {}
|
|
|
|
|
|
|
|
|
|
|
|
bool CocoDataset::ValidateParams() {
|
|
|
|
bool CocoDataset::ValidateParams() {
|
|
|
|