|
|
|
@ -27,6 +27,7 @@
|
|
|
|
|
#include "minddata/dataset/engine/datasetops/map_op.h"
|
|
|
|
|
#include "minddata/dataset/engine/datasetops/repeat_op.h"
|
|
|
|
|
#include "minddata/dataset/engine/datasetops/shuffle_op.h"
|
|
|
|
|
#include "minddata/dataset/engine/datasetops/skip_op.h"
|
|
|
|
|
#include "minddata/dataset/engine/datasetops/project_op.h"
|
|
|
|
|
#include "minddata/dataset/engine/datasetops/zip_op.h"
|
|
|
|
|
#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
|
|
|
|
@ -173,6 +174,20 @@ std::shared_ptr<ShuffleDataset> Dataset::Shuffle(int32_t shuffle_size) {
|
|
|
|
|
return ds;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Function to create a SkipDataset.
|
|
|
|
|
std::shared_ptr<SkipDataset> Dataset::Skip(int32_t count) {
|
|
|
|
|
auto ds = std::make_shared<SkipDataset>(count);
|
|
|
|
|
|
|
|
|
|
// Call derived class validation method.
|
|
|
|
|
if (!ds->ValidateParams()) {
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
ds->children.push_back(shared_from_this());
|
|
|
|
|
|
|
|
|
|
return ds;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Function to create a ProjectDataset.
|
|
|
|
|
std::shared_ptr<ProjectDataset> Dataset::Project(const std::vector<std::string> &columns) {
|
|
|
|
|
auto ds = std::make_shared<ProjectDataset>(columns);
|
|
|
|
@ -400,6 +415,28 @@ bool ShuffleDataset::ValidateParams() {
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Constructor for SkipDataset
|
|
|
|
|
SkipDataset::SkipDataset(int32_t count) : skip_count_(count) {}
|
|
|
|
|
|
|
|
|
|
// Function to build the SkipOp
|
|
|
|
|
std::shared_ptr<std::vector<std::shared_ptr<DatasetOp>>> SkipDataset::Build() {
|
|
|
|
|
// A vector containing shared pointer to the Dataset Ops that this object will create
|
|
|
|
|
std::vector<std::shared_ptr<DatasetOp>> node_ops;
|
|
|
|
|
|
|
|
|
|
node_ops.push_back(std::make_shared<SkipOp>(skip_count_, connector_que_size_));
|
|
|
|
|
return std::make_shared<std::vector<std::shared_ptr<DatasetOp>>>(node_ops);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Function to validate the parameters for SkipDataset
|
|
|
|
|
bool SkipDataset::ValidateParams() {
|
|
|
|
|
if (skip_count_ <= -1) {
|
|
|
|
|
MS_LOG(ERROR) << "Skip: Invalid input, skip_count: " << skip_count_;
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Constructor for Cifar10Dataset
|
|
|
|
|
Cifar10Dataset::Cifar10Dataset(const std::string &dataset_dir, int32_t num_samples, std::shared_ptr<SamplerObj> sampler)
|
|
|
|
|
: dataset_dir_(dataset_dir), num_samples_(num_samples), sampler_(sampler) {}
|
|
|
|
|