|
|
|
@ -86,9 +86,16 @@ Dataset::Dataset() {
|
|
|
|
|
// (In alphabetical order)
|
|
|
|
|
|
|
|
|
|
// Function to create a Cifar10Dataset.
|
|
|
|
|
std::shared_ptr<Cifar10Dataset> Cifar10(const std::string &dataset_dir, int32_t num_samples,
|
|
|
|
|
std::shared_ptr<SamplerObj> sampler) {
|
|
|
|
|
auto ds = std::make_shared<Cifar10Dataset>(dataset_dir, num_samples, sampler);
|
|
|
|
|
std::shared_ptr<Cifar10Dataset> Cifar10(const std::string &dataset_dir, std::shared_ptr<SamplerObj> sampler) {
|
|
|
|
|
auto ds = std::make_shared<Cifar10Dataset>(dataset_dir, sampler);
|
|
|
|
|
|
|
|
|
|
// Call derived class validation method.
|
|
|
|
|
return ds->ValidateParams() ? ds : nullptr;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Function to create a Cifar100Dataset.
|
|
|
|
|
std::shared_ptr<Cifar100Dataset> Cifar100(const std::string &dataset_dir, std::shared_ptr<SamplerObj> sampler) {
|
|
|
|
|
auto ds = std::make_shared<Cifar100Dataset>(dataset_dir, sampler);
|
|
|
|
|
|
|
|
|
|
// Call derived class validation method.
|
|
|
|
|
return ds->ValidateParams() ? ds : nullptr;
|
|
|
|
@ -250,28 +257,27 @@ std::shared_ptr<SamplerObj> CreateDefaultSampler() {
|
|
|
|
|
return std::make_shared<RandomSamplerObj>(replacement, num_samples);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Helper function to validate dataset params
|
|
|
|
|
bool ValidateCommonDatasetParams(std::string dataset_dir) {
|
|
|
|
|
if (dataset_dir.empty()) {
|
|
|
|
|
MS_LOG(ERROR) << "No dataset path is specified";
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/* ####################################### Derived Dataset classes ################################# */
|
|
|
|
|
|
|
|
|
|
// DERIVED DATASET CLASSES LEAF-NODE DATASETS
|
|
|
|
|
// (In alphabetical order)
|
|
|
|
|
|
|
|
|
|
// Constructor for Cifar10Dataset
|
|
|
|
|
Cifar10Dataset::Cifar10Dataset(const std::string &dataset_dir, int32_t num_samples, std::shared_ptr<SamplerObj> sampler)
|
|
|
|
|
: dataset_dir_(dataset_dir), num_samples_(num_samples), sampler_(sampler) {}
|
|
|
|
|
Cifar10Dataset::Cifar10Dataset(const std::string &dataset_dir, std::shared_ptr<SamplerObj> sampler)
|
|
|
|
|
: dataset_dir_(dataset_dir), sampler_(sampler) {}
|
|
|
|
|
|
|
|
|
|
bool Cifar10Dataset::ValidateParams() {
|
|
|
|
|
if (dataset_dir_.empty()) {
|
|
|
|
|
MS_LOG(ERROR) << "No dataset path is specified.";
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
if (num_samples_ < 0) {
|
|
|
|
|
MS_LOG(ERROR) << "Number of samples cannot be negative";
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
bool Cifar10Dataset::ValidateParams() { return ValidateCommonDatasetParams(dataset_dir_); }
|
|
|
|
|
|
|
|
|
|
// Function to build CifarOp
|
|
|
|
|
// Function to build CifarOp for Cifar10
|
|
|
|
|
std::vector<std::shared_ptr<DatasetOp>> Cifar10Dataset::Build() {
|
|
|
|
|
// A vector containing shared pointer to the Dataset Ops that this object will create
|
|
|
|
|
std::vector<std::shared_ptr<DatasetOp>> node_ops;
|
|
|
|
@ -294,6 +300,37 @@ std::vector<std::shared_ptr<DatasetOp>> Cifar10Dataset::Build() {
|
|
|
|
|
return node_ops;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Constructor for Cifar100Dataset
|
|
|
|
|
Cifar100Dataset::Cifar100Dataset(const std::string &dataset_dir, std::shared_ptr<SamplerObj> sampler)
|
|
|
|
|
: dataset_dir_(dataset_dir), sampler_(sampler) {}
|
|
|
|
|
|
|
|
|
|
bool Cifar100Dataset::ValidateParams() { return ValidateCommonDatasetParams(dataset_dir_); }
|
|
|
|
|
|
|
|
|
|
// Function to build CifarOp for Cifar100
|
|
|
|
|
std::vector<std::shared_ptr<DatasetOp>> Cifar100Dataset::Build() {
|
|
|
|
|
// A vector containing shared pointer to the Dataset Ops that this object will create
|
|
|
|
|
std::vector<std::shared_ptr<DatasetOp>> node_ops;
|
|
|
|
|
|
|
|
|
|
// If user does not specify Sampler, create a default sampler based on the shuffle variable.
|
|
|
|
|
if (sampler_ == nullptr) {
|
|
|
|
|
sampler_ = CreateDefaultSampler();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Do internal Schema generation.
|
|
|
|
|
auto schema = std::make_unique<DataSchema>();
|
|
|
|
|
RETURN_EMPTY_IF_ERROR(schema->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kCv, 1)));
|
|
|
|
|
TensorShape scalar = TensorShape::CreateScalar();
|
|
|
|
|
RETURN_EMPTY_IF_ERROR(
|
|
|
|
|
schema->AddColumn(ColDescriptor("coarse_label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));
|
|
|
|
|
RETURN_EMPTY_IF_ERROR(
|
|
|
|
|
schema->AddColumn(ColDescriptor("fine_label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));
|
|
|
|
|
|
|
|
|
|
node_ops.push_back(std::make_shared<CifarOp>(CifarOp::CifarType::kCifar100, num_workers_, rows_per_buffer_,
|
|
|
|
|
dataset_dir_, connector_que_size_, std::move(schema),
|
|
|
|
|
std::move(sampler_->Build())));
|
|
|
|
|
return node_ops;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
ImageFolderDataset::ImageFolderDataset(std::string dataset_dir, bool decode, std::shared_ptr<SamplerObj> sampler,
|
|
|
|
|
bool recursive, std::set<std::string> extensions,
|
|
|
|
|
std::map<std::string, int32_t> class_indexing)
|
|
|
|
@ -304,14 +341,7 @@ ImageFolderDataset::ImageFolderDataset(std::string dataset_dir, bool decode, std
|
|
|
|
|
class_indexing_(class_indexing),
|
|
|
|
|
exts_(extensions) {}
|
|
|
|
|
|
|
|
|
|
bool ImageFolderDataset::ValidateParams() {
|
|
|
|
|
if (dataset_dir_.empty()) {
|
|
|
|
|
MS_LOG(ERROR) << "No dataset path is specified.";
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
bool ImageFolderDataset::ValidateParams() { return ValidateCommonDatasetParams(dataset_dir_); }
|
|
|
|
|
|
|
|
|
|
std::vector<std::shared_ptr<DatasetOp>> ImageFolderDataset::Build() {
|
|
|
|
|
// A vector containing shared pointer to the Dataset Ops that this object will create
|
|
|
|
@ -339,14 +369,7 @@ std::vector<std::shared_ptr<DatasetOp>> ImageFolderDataset::Build() {
|
|
|
|
|
MnistDataset::MnistDataset(std::string dataset_dir, std::shared_ptr<SamplerObj> sampler)
|
|
|
|
|
: dataset_dir_(dataset_dir), sampler_(sampler) {}
|
|
|
|
|
|
|
|
|
|
bool MnistDataset::ValidateParams() {
|
|
|
|
|
if (dataset_dir_.empty()) {
|
|
|
|
|
MS_LOG(ERROR) << "No dataset path is specified.";
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
bool MnistDataset::ValidateParams() { return ValidateCommonDatasetParams(dataset_dir_); }
|
|
|
|
|
|
|
|
|
|
std::vector<std::shared_ptr<DatasetOp>> MnistDataset::Build() {
|
|
|
|
|
// A vector containing shared pointer to the Dataset Ops that this object will create
|
|
|
|
|