!1808 consistent design for num_samples

Merge pull request !1808 from Jamie/numsamples
pull/1808/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit 769ae609b4

@ -856,9 +856,7 @@ Status DEPipeline::ParseImageFolderOp(const py::dict &args, std::shared_ptr<Data
std::string key = py::str(arg.first);
py::handle value = arg.second;
if (!value.is_none()) {
if (key == "num_samples") {
(void)builder->SetNumSamples(ToInt(value));
} else if (key == "num_parallel_workers") {
if (key == "num_parallel_workers") {
(void)builder->SetNumWorkers(ToInt(value));
} else if (key == "sampler") {
auto create = py::reinterpret_borrow<py::object>(value).attr("create");
@ -893,9 +891,7 @@ Status DEPipeline::ParseManifestOp(const py::dict &args, std::shared_ptr<Dataset
std::string key = py::str(arg.first);
py::handle value = arg.second;
if (!value.is_none()) {
if (key == "num_samples") {
(void)builder->SetNumSamples(ToInt(value));
} else if (key == "num_parallel_workers") {
if (key == "num_parallel_workers") {
(void)builder->SetNumWorkers(ToInt(value));
} else if (key == "sampler") {
auto create = py::reinterpret_borrow<py::object>(value).attr("create");
@ -930,9 +926,7 @@ Status DEPipeline::ParseVOCOp(const py::dict &args, std::shared_ptr<DatasetOp> *
std::string key = py::str(arg.first);
py::handle value = arg.second;
if (!value.is_none()) {
if (key == "num_samples") {
(void)builder->SetNumSamples(ToInt(value));
} else if (key == "num_parallel_workers") {
if (key == "num_parallel_workers") {
(void)builder->SetNumWorkers(ToInt(value));
} else if (key == "sampler") {
auto create = py::reinterpret_borrow<py::object>(value).attr("create");
@ -966,9 +960,7 @@ Status DEPipeline::ParseCifar10Op(const py::dict &args, std::shared_ptr<DatasetO
std::string key = py::str(arg.first);
py::handle value = arg.second;
if (!value.is_none()) {
if (key == "num_samples") {
(void)builder->SetNumSamples(ToInt(value));
} else if (key == "num_parallel_workers") {
if (key == "num_parallel_workers") {
(void)builder->SetNumWorkers(ToInt(value));
} else if (key == "sampler") {
auto create = py::reinterpret_borrow<py::object>(value).attr("create");
@ -1001,9 +993,7 @@ Status DEPipeline::ParseCifar100Op(const py::dict &args, std::shared_ptr<Dataset
std::string key = py::str(arg.first);
py::handle value = arg.second;
if (!value.is_none()) {
if (key == "num_samples") {
(void)builder->SetNumSamples(ToInt(value));
} else if (key == "num_parallel_workers") {
if (key == "num_parallel_workers") {
(void)builder->SetNumWorkers(ToInt(value));
} else if (key == "sampler") {
auto create = py::reinterpret_borrow<py::object>(value).attr("create");
@ -1039,10 +1029,12 @@ Status DEPipeline::ParseRandomDataOp(const py::dict &args, std::shared_ptr<Datas
(void)builder.SetNumWorkers(ToInt(value));
} else if (key == "schema_file_path" || key == "schema_json_string") {
schema_exists = true;
} else if (key == "num_samples") {
(void)builder.SetTotalRows(ToInt(value));
} else if (key == "columns_list") {
columns_to_load = ToStringVector(value);
} else if (key == "num_samples") {
// This is not sampling here. The random data op needs to know how much data to
// generate. It does not currently support sampling.
(void)builder.SetTotalRows(ToInt(value));
}
}
if (schema_exists) {
@ -1077,9 +1069,7 @@ Status DEPipeline::ParseMnistOp(const py::dict &args, std::shared_ptr<DatasetOp>
std::string key = py::str(arg.first);
py::handle value = arg.second;
if (!value.is_none()) {
if (key == "num_samples") {
(void)builder->SetNumSamples(ToInt(value));
} else if (key == "num_parallel_workers") {
if (key == "num_parallel_workers") {
(void)builder->SetNumWorkers(ToInt(value));
} else if (key == "sampler") {
auto create = py::reinterpret_borrow<py::object>(value).attr("create");
@ -1121,8 +1111,6 @@ Status DEPipeline::ParseCelebAOp(const py::dict &args, std::shared_ptr<DatasetOp
(void)builder->SetDecode(ToBool(value));
} else if (key == "extensions") {
(void)builder->SetExtensions(ToStringSet(value));
} else if (key == "num_samples") {
(void)builder->SetNumSamples(ToInt(value));
} else if (key == "dataset_type") {
(void)builder->SetDatasetType(ToString(value));
}
@ -1153,7 +1141,7 @@ Status DEPipeline::ParseTextFileOp(const py::dict &args, std::shared_ptr<Dataset
} else if (key == "shuffle_files") {
(void)builder->SetShuffleFiles(ToBool(value));
} else if (key == "num_samples") {
(void)builder->SetNumSamples(ToInt(value));
(void)builder->SetTotalRows(ToInt(value));
} else if (key == "num_shards") {
(void)builder->SetNumDevices(ToInt(value));
} else if (key == "shard_id") {

@ -49,7 +49,6 @@
#include "dataset/engine/datasetops/source/sampler/pk_sampler.h"
#include "dataset/engine/datasetops/source/sampler/random_sampler.h"
#include "dataset/engine/datasetops/source/sampler/sequential_sampler.h"
#include "dataset/engine/datasetops/source/sampler/subset_sampler.h"
#include "dataset/engine/datasetops/source/sampler/subset_random_sampler.h"
#include "dataset/engine/datasetops/source/sampler/weighted_random_sampler.h"
#include "dataset/engine/datasetops/source/sampler/python_sampler.h"
@ -143,17 +142,16 @@ void bindDatasetOps(py::module *m) {
});
(void)py::class_<CifarOp, DatasetOp, std::shared_ptr<CifarOp>>(*m, "CifarOp")
.def_static("get_num_rows", [](const std::string &dir, int64_t numSamples, bool isCifar10) {
.def_static("get_num_rows", [](const std::string &dir, bool isCifar10) {
int64_t count = 0;
THROW_IF_ERROR(CifarOp::CountTotalRows(dir, numSamples, isCifar10, &count));
THROW_IF_ERROR(CifarOp::CountTotalRows(dir, isCifar10, &count));
return count;
});
(void)py::class_<ImageFolderOp, DatasetOp, std::shared_ptr<ImageFolderOp>>(*m, "ImageFolderOp")
.def_static("get_num_rows_and_classes", [](const std::string &path, int64_t numSamples) {
.def_static("get_num_rows_and_classes", [](const std::string &path) {
int64_t count = 0, num_classes = 0;
THROW_IF_ERROR(
ImageFolderOp::CountRowsAndClasses(path, numSamples, std::set<std::string>{}, &count, &num_classes));
THROW_IF_ERROR(ImageFolderOp::CountRowsAndClasses(path, std::set<std::string>{}, &count, &num_classes));
return py::make_tuple(count, num_classes);
});
@ -172,22 +170,21 @@ void bindDatasetOps(py::module *m) {
(void)py::class_<ManifestOp, DatasetOp, std::shared_ptr<ManifestOp>>(*m, "ManifestOp")
.def_static("get_num_rows_and_classes",
[](const std::string &file, int64_t numSamples, const py::dict &dict, const std::string &usage) {
[](const std::string &file, const py::dict &dict, const std::string &usage) {
int64_t count = 0, num_classes = 0;
THROW_IF_ERROR(ManifestOp::CountTotalRows(file, numSamples, dict, usage, &count, &num_classes));
THROW_IF_ERROR(ManifestOp::CountTotalRows(file, dict, usage, &count, &num_classes));
return py::make_tuple(count, num_classes);
})
.def_static("get_class_indexing",
[](const std::string &file, int64_t numSamples, const py::dict &dict, const std::string &usage) {
std::map<std::string, int32_t> output_class_indexing;
THROW_IF_ERROR(ManifestOp::GetClassIndexing(file, numSamples, dict, usage, &output_class_indexing));
return output_class_indexing;
});
.def_static("get_class_indexing", [](const std::string &file, const py::dict &dict, const std::string &usage) {
std::map<std::string, int32_t> output_class_indexing;
THROW_IF_ERROR(ManifestOp::GetClassIndexing(file, dict, usage, &output_class_indexing));
return output_class_indexing;
});
(void)py::class_<MnistOp, DatasetOp, std::shared_ptr<MnistOp>>(*m, "MnistOp")
.def_static("get_num_rows", [](const std::string &dir, int64_t numSamples) {
.def_static("get_num_rows", [](const std::string &dir) {
int64_t count = 0;
THROW_IF_ERROR(MnistOp::CountTotalRows(dir, numSamples, &count));
THROW_IF_ERROR(MnistOp::CountTotalRows(dir, &count));
return count;
});
@ -206,13 +203,13 @@ void bindDatasetOps(py::module *m) {
[](const std::string &dir, const std::string &task_type, const std::string &task_mode,
const py::dict &dict, int64_t numSamples) {
int64_t count = 0;
THROW_IF_ERROR(VOCOp::CountTotalRows(dir, task_type, task_mode, dict, numSamples, &count));
THROW_IF_ERROR(VOCOp::CountTotalRows(dir, task_type, task_mode, dict, &count));
return count;
})
.def_static("get_class_indexing", [](const std::string &dir, const std::string &task_type,
const std::string &task_mode, const py::dict &dict, int64_t numSamples) {
const std::string &task_mode, const py::dict &dict) {
std::map<std::string, int32_t> output_class_indexing;
THROW_IF_ERROR(VOCOp::GetClassIndexing(dir, task_type, task_mode, dict, numSamples, &output_class_indexing));
THROW_IF_ERROR(VOCOp::GetClassIndexing(dir, task_type, task_mode, dict, &output_class_indexing));
return output_class_indexing;
});
}
@ -452,25 +449,19 @@ void bindSamplerOps(py::module *m) {
(void)py::class_<mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardOperator>>(*m, "ShardOperator");
(void)py::class_<DistributedSampler, Sampler, std::shared_ptr<DistributedSampler>>(*m, "DistributedSampler")
.def(py::init<int64_t, int64_t, bool, uint32_t>(), py::arg("numDev"), py::arg("devId"), py::arg("shuffle"),
py::arg("seed"));
.def(py::init<int64_t, int64_t, int64_t, bool, uint32_t>());
(void)py::class_<PKSampler, Sampler, std::shared_ptr<PKSampler>>(*m, "PKSampler")
.def(py::init<int64_t, bool>(), py::arg("kVal"), py::arg("shuffle"));
.def(py::init<int64_t, int64_t, bool>());
(void)py::class_<RandomSampler, Sampler, std::shared_ptr<RandomSampler>>(*m, "RandomSampler")
.def(py::init<bool, bool, int64_t>(), py::arg("replacement"), py::arg("reshuffle_each_epoch"),
py::arg("num_samples"))
.def(py::init<bool, bool>(), py::arg("replacement"), py::arg("reshuffle_each_epoch"));
.def(py::init<int64_t, bool, bool>());
(void)py::class_<SequentialSampler, Sampler, std::shared_ptr<SequentialSampler>>(*m, "SequentialSampler")
.def(py::init<>());
(void)py::class_<SubsetSampler, Sampler, std::shared_ptr<SubsetSampler>>(*m, "SubsetSampler")
.def(py::init<int64_t, int64_t>(), py::arg("start_index"), py::arg("subset_size"));
.def(py::init<int64_t, int64_t>());
(void)py::class_<SubsetRandomSampler, Sampler, std::shared_ptr<SubsetRandomSampler>>(*m, "SubsetRandomSampler")
.def(py::init<std::vector<int64_t>>(), py::arg("indices"));
.def(py::init<int64_t, std::vector<int64_t>>());
(void)py::class_<mindrecord::ShardSample, mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardSample>>(
*m, "MindrecordSubsetRandomSampler")
@ -487,11 +478,10 @@ void bindSamplerOps(py::module *m) {
}));
(void)py::class_<WeightedRandomSampler, Sampler, std::shared_ptr<WeightedRandomSampler>>(*m, "WeightedRandomSampler")
.def(py::init<std::vector<double>, int64_t, bool>(), py::arg("weights"), py::arg("numSamples"),
py::arg("replacement"));
.def(py::init<int64_t, std::vector<double>, bool>());
(void)py::class_<PythonSampler, Sampler, std::shared_ptr<PythonSampler>>(*m, "PythonSampler")
.def(py::init<py::object>(), py::arg("pySampler"));
.def(py::init<int64_t, py::object>());
}
void bindInfoObjects(py::module *m) {

@ -26,7 +26,7 @@
namespace mindspore {
namespace dataset {
CelebAOp::Builder::Builder() : builder_decode_(false), builder_sampler_(nullptr), builder_num_samples_(0) {
CelebAOp::Builder::Builder() : builder_decode_(false), builder_sampler_(nullptr) {
std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
builder_num_workers_ = cfg->num_parallel_workers();
builder_rows_per_buffer_ = cfg->rows_per_buffer();
@ -38,7 +38,9 @@ Status CelebAOp::Builder::Build(std::shared_ptr<CelebAOp> *op) {
MS_LOG(DEBUG) << "Celeba dataset type is " << builder_dataset_type_.c_str() << ".";
RETURN_IF_NOT_OK(SanityCheck());
if (builder_sampler_ == nullptr) {
builder_sampler_ = std::make_shared<SequentialSampler>();
int64_t num_samples = 0;
int64_t start_index = 0;
builder_sampler_ = std::make_shared<SequentialSampler>(start_index, num_samples);
}
builder_schema_ = std::make_unique<DataSchema>();
@ -47,10 +49,9 @@ Status CelebAOp::Builder::Build(std::shared_ptr<CelebAOp> *op) {
// label is like this:0 1 0 0 1......
RETURN_IF_NOT_OK(
builder_schema_->AddColumn(ColDescriptor("attr", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1)));
*op =
std::make_shared<CelebAOp>(builder_num_workers_, builder_rows_per_buffer_, builder_dir_, builder_op_connector_size_,
builder_decode_, builder_dataset_type_, builder_extensions_, std::move(builder_schema_),
std::move(builder_sampler_), builder_num_samples_);
*op = std::make_shared<CelebAOp>(builder_num_workers_, builder_rows_per_buffer_, builder_dir_,
builder_op_connector_size_, builder_decode_, builder_dataset_type_,
builder_extensions_, std::move(builder_schema_), std::move(builder_sampler_));
if (*op == nullptr) {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "CelebAOp is null");
}
@ -68,7 +69,7 @@ Status CelebAOp::Builder::SanityCheck() {
CelebAOp::CelebAOp(int32_t num_workers, int32_t rows_per_buffer, const std::string &dir, int32_t queue_size,
bool decode, const std::string &dataset_type, const std::set<std::string> &exts,
std::unique_ptr<DataSchema> schema, std::shared_ptr<Sampler> sampler, int64_t num_samples)
std::unique_ptr<DataSchema> schema, std::shared_ptr<Sampler> sampler)
: ParallelOp(num_workers, queue_size),
rows_per_buffer_(rows_per_buffer),
folder_path_(dir),
@ -77,8 +78,6 @@ CelebAOp::CelebAOp(int32_t num_workers, int32_t rows_per_buffer, const std::stri
data_schema_(std::move(schema)),
sampler_(std::move(sampler)),
num_rows_in_attr_file_(0),
num_rows_exact_(0),
num_samples_(num_samples),
dataset_type_(dataset_type) {
// Set the column name map (base class field)
for (int32_t index = 0; index < data_schema_->NumColumns(); index++) {
@ -202,13 +201,6 @@ Status CelebAOp::ParseImageAttrInfo() {
RETURN_IF_NOT_OK(attr_info_queue_->PopFront(&image_infos));
while (!image_infos.empty() && needMoreData) {
for (uint32_t index = 0; index < image_infos.size(); index++) {
if (num_samples_ != 0 && image_labels_vec_.size() >= num_samples_) {
MS_LOG(WARNING) << "Image number(" << image_labels_vec_.size() << " is more than"
<< " rows num eval attr file(" << num_rows_in_attr_file_ << ") or num samples(" << num_samples_
<< ").";
needMoreData = false;
break;
}
std::string image_info = image_infos[index];
std::vector<std::string> split = Split(image_info);
std::pair<std::string, std::vector<int32_t>> image_labels;
@ -239,14 +231,13 @@ Status CelebAOp::ParseImageAttrInfo() {
RETURN_IF_NOT_OK(attr_info_queue_->PopFront(&image_infos));
}
num_rows_exact_ = image_labels_vec_.size();
num_samples_ = (num_samples_ == 0 || num_samples_ > num_rows_exact_) ? num_rows_exact_ : num_samples_;
if (num_rows_exact_ == 0) {
num_rows_ = image_labels_vec_.size();
if (num_rows_ == 0) {
RETURN_STATUS_UNEXPECTED(
"There is no valid data matching the dataset API CelebADataset.Please check file path or dataset API "
"validation first.");
}
MS_LOG(DEBUG) << "Celeba dataset rows number is " << num_rows_exact_ << ".";
MS_LOG(DEBUG) << "Celeba dataset rows number is " << num_rows_ << ".";
return Status::OK();
}
@ -268,28 +259,6 @@ std::vector<std::string> CelebAOp::Split(const std::string &line) {
return split;
}
// Derived from RandomAccessOp
Status CelebAOp::GetNumSamples(int64_t *num) const {
if (num == nullptr || num_samples_ == 0) {
RETURN_STATUS_UNEXPECTED(
"There is no valid data matching the dataset API CelebADataset.Please check file path or dataset API "
"validation first.");
}
(*num) = num_samples_;
return Status::OK();
}
Status CelebAOp::GetNumRowsInDataset(int64_t *num) const {
if (num == nullptr || num_rows_exact_ == 0) {
RETURN_STATUS_UNEXPECTED(
"There is no valid data matching the dataset API CelebADataset.Please check file path or dataset API "
"validation first.");
}
*num = num_rows_exact_;
return Status::OK();
}
// Main logic, Register Queue with TaskGroup, launch all threads and do the functor's work
Status CelebAOp::operator()() {
RETURN_IF_NOT_OK(LaunchThreadsAndInitOp());
@ -310,9 +279,8 @@ Status CelebAOp::AddIOBlock(std::unique_ptr<DataBuffer> *data_buffer) {
RETURN_IF_NOT_OK((*data_buffer)->PopRow(&sample_row));
std::shared_ptr<Tensor> sample_ids = sample_row[0];
for (auto itr = sample_ids->begin<int64_t>(); itr != sample_ids->end<int64_t>(); ++itr) {
if ((*itr) >= num_rows_exact_) {
MS_LOG(WARNING) << "Sample Id (" << *itr << ") is out of bounds, skipping. Max id is " << num_rows_exact_
<< ".";
if ((*itr) >= num_rows_) {
MS_LOG(WARNING) << "Sample Id (" << *itr << ") is out of bounds, skipping. Max id is " << num_rows_ << ".";
continue;
}
keys.push_back(*itr);
@ -446,7 +414,7 @@ void CelebAOp::Print(std::ostream &out, bool show_all) const {
// Call the super class for displaying any common detailed info
ParallelOp::Print(out, show_all);
// Then show any custom derived-internal stuff
out << "\nNumber of rows:" << num_rows_exact_ << "\nceleba dir: " << folder_path_ << "\n\n";
out << "\nNumber of rows:" << num_rows_ << "\nceleba dir: " << folder_path_ << "\n\n";
}
}

@ -108,14 +108,6 @@ class CelebAOp : public ParallelOp, RandomAccessOp {
return *this;
}
// Setter method
// @param int64_t num_samples
// @return Builder setter method returns reference to the builder.
Builder &SetNumSamples(int64_t num_samples) {
builder_num_samples_ = num_samples;
return *this;
}
// Setter method
// @param const std::string dataset_type: type to be read
// @return Builder setter method returns reference to the builder.
@ -141,7 +133,6 @@ class CelebAOp : public ParallelOp, RandomAccessOp {
std::set<std::string> builder_extensions_;
std::shared_ptr<Sampler> builder_sampler_;
std::unique_ptr<DataSchema> builder_schema_;
int64_t builder_num_samples_;
std::string builder_dataset_type_;
};
@ -153,7 +144,7 @@ class CelebAOp : public ParallelOp, RandomAccessOp {
// @param std::unique_ptr<Sampler> sampler - sampler tells CelebAOp what to read
CelebAOp(int32_t num_workers, int32_t rows_per_buffer, const std::string &dir, int32_t queue_size, bool decode,
const std::string &dataset_type, const std::set<std::string> &exts, std::unique_ptr<DataSchema> schema,
std::shared_ptr<Sampler> sampler, int64_t num_samples);
std::shared_ptr<Sampler> sampler);
~CelebAOp() override = default;
@ -163,16 +154,6 @@ class CelebAOp : public ParallelOp, RandomAccessOp {
// @return Status - The error code return
Status operator()() override;
// Method derived from RandomAccess Op, enable Sampler to get numRows
// @param int64_t num - to return numRows
// @return Status - The error code return
Status GetNumSamples(int64_t *num) const override;
// Method derived from RandomAccess Op, enable Sampler to get numRows
// @param int64_t num - to return numRows
// @return Status - The error code return
Status GetNumRowsInDataset(int64_t *num) const override;
// Worker thread pulls a number of IOBlock from IOBlock Queue, make a buffer and push it to Connector
// @param int32_t worker_id - id of each worker
// @return Status - The error code return
@ -233,11 +214,9 @@ class CelebAOp : public ParallelOp, RandomAccessOp {
std::shared_ptr<Sampler> sampler_;
std::unique_ptr<Queue<std::vector<std::string>>> attr_info_queue_;
int64_t num_rows_in_attr_file_; // rows number specified in attr file
int64_t num_rows_exact_; // exact rows number,maybe is less than rows_num_in_attr_file_
QueueList<std::unique_ptr<IOBlock>> io_block_queues_;
WaitPost wp_;
std::vector<std::pair<std::string, std::vector<int32_t>>> image_labels_vec_;
int64_t num_samples_;
std::string dataset_type_;
std::ifstream partition_file_;
};

@ -35,7 +35,7 @@ constexpr uint32_t kCifarImageChannel = 3;
constexpr uint32_t kCifarBlockImageNum = 5;
constexpr uint32_t kCifarImageSize = kCifarImageHeight * kCifarImageWidth * kCifarImageChannel;
CifarOp::Builder::Builder() : num_samples_(0), sampler_(nullptr) {
CifarOp::Builder::Builder() : sampler_(nullptr) {
std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
num_workers_ = cfg->num_parallel_workers();
rows_per_buffer_ = cfg->rows_per_buffer();
@ -46,7 +46,9 @@ CifarOp::Builder::Builder() : num_samples_(0), sampler_(nullptr) {
Status CifarOp::Builder::Build(std::shared_ptr<CifarOp> *ptr) {
RETURN_IF_NOT_OK(SanityCheck());
if (sampler_ == nullptr) {
sampler_ = std::make_shared<SequentialSampler>();
int64_t num_samples = 0;
int64_t start_index = 0;
sampler_ = std::make_shared<SequentialSampler>(start_index, num_samples);
}
schema_ = std::make_unique<DataSchema>();
TensorShape scalar = TensorShape::CreateScalar();
@ -62,7 +64,7 @@ Status CifarOp::Builder::Build(std::shared_ptr<CifarOp> *ptr) {
ColDescriptor("fine_label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &another_scalar)));
}
*ptr = std::make_shared<CifarOp>(cifar_type_, num_workers_, rows_per_buffer_, dir_, op_connect_size_, num_samples_,
*ptr = std::make_shared<CifarOp>(cifar_type_, num_workers_, rows_per_buffer_, dir_, op_connect_size_,
std::move(schema_), std::move(sampler_));
return Status::OK();
}
@ -76,16 +78,13 @@ Status CifarOp::Builder::SanityCheck() {
}
CifarOp::CifarOp(CifarType type, int32_t num_works, int32_t rows_per_buf, const std::string &file_dir,
int32_t queue_size, int64_t num_samples, std::unique_ptr<DataSchema> data_schema,
std::shared_ptr<Sampler> sampler)
int32_t queue_size, std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler)
: ParallelOp(num_works, queue_size),
cifar_type_(type),
rows_per_buffer_(rows_per_buf),
folder_path_(file_dir),
num_samples_(num_samples),
data_schema_(std::move(data_schema)),
sampler_(std::move(sampler)),
num_rows_(0),
row_cnt_(0),
buf_cnt_(0) {
// set the column name map (base class field)
@ -112,8 +111,7 @@ Status CifarOp::operator()() {
for (auto itr = sample_ids->begin<int64_t>(); itr != sample_ids->end<int64_t>(); itr++) {
keys.push_back(*itr);
row_cnt_++;
if ((*itr) >= num_rows_) continue; // index out of bound, skipping
if (row_cnt_ >= num_samples_) break; // enough row read, break for loop
if ((*itr) >= num_rows_) continue; // index out of bound, skipping
if (row_cnt_ % rows_per_buffer_ == 0) {
RETURN_IF_NOT_OK(io_block_queues_[buf_cnt_++ % num_workers_]->Add(
std::make_unique<IOBlock>(IOBlock(keys, IOBlock::kDeIoBlockNone))));
@ -255,30 +253,6 @@ Status CifarOp::InitSampler() {
return Status::OK();
}
// Derived from RandomAccessOp
Status CifarOp::GetNumSamples(int64_t *num) const {
if (num == nullptr || num_rows_ == 0) {
std::string api = cifar_type_ == kCifar10 ? "Cifar10Dataset" : "Cifar100Dataset";
std::string err_msg = "There is no valid data matching the dataset API " + api +
".Please check file path or dataset API validation first.";
RETURN_STATUS_UNEXPECTED(err_msg);
}
(*num) = num_samples_;
return Status::OK();
}
// Derived from RandomAccessOp
Status CifarOp::GetNumRowsInDataset(int64_t *num) const {
if (num == nullptr || num_rows_ == 0) {
std::string api = cifar_type_ == kCifar10 ? "Cifar10Dataset" : "Cifar100Dataset";
std::string err_msg = "There is no valid data matching the dataset API " + api +
".Please check file path or dataset API validation first.";
RETURN_STATUS_UNEXPECTED(err_msg);
}
(*num) = num_rows_;
return Status::OK();
}
Status CifarOp::ReadCifarBlockDataAsync() {
TaskManager::FindMe()->Post();
RETURN_IF_NOT_OK(GetCifarFiles());
@ -404,7 +378,6 @@ Status CifarOp::ParseCifarData() {
}
cifar_image_label_pairs_.shrink_to_fit();
num_rows_ = cifar_image_label_pairs_.size();
num_samples_ = (num_samples_ == 0 || num_samples_ > num_rows_) ? num_rows_ : num_samples_;
if (num_rows_ == 0) {
std::string api = cifar_type_ == kCifar10 ? "Cifar10Dataset" : "Cifar100Dataset";
std::string err_msg = "There is no valid data matching the dataset API " + api +
@ -432,11 +405,11 @@ Status CifarOp::GetClassIds(std::map<int32_t, std::vector<int64_t>> *cls_ids) co
return Status::OK();
}
Status CifarOp::CountTotalRows(const std::string &dir, int64_t numSamples, bool isCIFAR10, int64_t *count) {
Status CifarOp::CountTotalRows(const std::string &dir, bool isCIFAR10, int64_t *count) {
// the logic of counting the number of samples is copied from ReadCifar100Block() and ReadCifar10Block()
std::shared_ptr<CifarOp> op;
*count = 0;
RETURN_IF_NOT_OK(Builder().SetCifarDir(dir).SetNumSamples(numSamples).SetCifarType(isCIFAR10).Build(&op));
RETURN_IF_NOT_OK(Builder().SetCifarDir(dir).SetCifarType(isCIFAR10).Build(&op));
RETURN_IF_NOT_OK(op->GetCifarFiles());
if (op->cifar_type_ == kCifar10) {
constexpr int64_t num_cifar10_records = 10000;
@ -448,7 +421,6 @@ Status CifarOp::CountTotalRows(const std::string &dir, int64_t numSamples, bool
}
*count = *count + num_cifar10_records;
}
*count = *count < numSamples || numSamples == 0 ? *count : numSamples;
return Status::OK();
} else {
int64_t num_cifar100_records = 0;
@ -470,7 +442,7 @@ Status CifarOp::CountTotalRows(const std::string &dir, int64_t numSamples, bool
RETURN_STATUS_UNEXPECTED(err_msg);
}
}
*count = num_cifar100_records < numSamples || numSamples == 0 ? num_cifar100_records : numSamples;
*count = num_cifar100_records;
return Status::OK();
}
}

@ -73,14 +73,6 @@ class CifarOp : public ParallelOp, public RandomAccessOp {
return *this;
}
// Setter method
// @param uint64_t num_samples
// @return Builder setter method returns reference to the builder.
Builder &SetNumSamples(uint64_t num_samples) {
num_samples_ = num_samples;
return *this;
}
// Setter method
// @param std::shared_ptr<Sampler> sampler
// @return Builder setter method returns reference to the builder.
@ -121,7 +113,6 @@ class CifarOp : public ParallelOp, public RandomAccessOp {
private:
std::string dir_;
int32_t num_workers_;
uint64_t num_samples_;
int32_t rows_per_buffer_;
int32_t op_connect_size_;
std::shared_ptr<Sampler> sampler_;
@ -137,7 +128,7 @@ class CifarOp : public ParallelOp, public RandomAccessOp {
// @param uint32_t - queueSize - connector queue size
// @param std::unique_ptr<Sampler> sampler - sampler tells ImageFolderOp what to read
CifarOp(CifarType type, int32_t num_works, int32_t rows_per_buf, const std::string &file_dir, int32_t queue_size,
int64_t num_samples, std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler);
std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler);
// Destructor.
~CifarOp() = default;
@ -152,16 +143,6 @@ class CifarOp : public ParallelOp, public RandomAccessOp {
// @return Status - The error code return
Status operator()() override;
// Method derived from RandomAccess Op, enable Sampler to get numRows
// @param uint64_t num - to return numRows
// @return Status - The error code return
Status GetNumSamples(int64_t *num) const override;
// Method derived from RandomAccess Op, enable Sampler to get total numRows in dataset
// @param uint64_t num - to return numRows
// @return Status - The error code return
Status GetNumRowsInDataset(int64_t *num) const override;
// A print method typically used for debugging
// @param out
// @param show_all
@ -169,11 +150,10 @@ class CifarOp : public ParallelOp, public RandomAccessOp {
// Function to count the number of samples in the CIFAR dataset
// @param dir path to the CIFAR directory
// @param numSamples maximum number of samples requested
// @param isCIFAR10 true if CIFAR10 and false if CIFAR100
// @param count output arg that will hold the minimum of the actual dataset size and numSamples
// @param count output arg that will hold the actual dataset size
// @return
static Status CountTotalRows(const std::string &dir, int64_t numSamples, bool isCIFAR10, int64_t *count);
static Status CountTotalRows(const std::string &dir, bool isCIFAR10, int64_t *count);
private:
// Initialize Sampler, calls sampler->Init() within
@ -227,10 +207,8 @@ class CifarOp : public ParallelOp, public RandomAccessOp {
CifarType cifar_type_;
int32_t rows_per_buffer_;
std::string folder_path_;
int64_t num_samples_;
std::unique_ptr<DataSchema> data_schema_;
std::shared_ptr<Sampler> sampler_;
int64_t num_rows_;
int64_t row_cnt_;
int64_t buf_cnt_;

@ -26,8 +26,7 @@
namespace mindspore {
namespace dataset {
ImageFolderOp::Builder::Builder()
: builder_decode_(false), builder_recursive_(false), builder_num_samples_(0), builder_sampler_(nullptr) {
ImageFolderOp::Builder::Builder() : builder_decode_(false), builder_recursive_(false), builder_sampler_(nullptr) {
std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
builder_num_workers_ = cfg->num_parallel_workers();
builder_rows_per_buffer_ = cfg->rows_per_buffer();
@ -37,7 +36,9 @@ ImageFolderOp::Builder::Builder()
Status ImageFolderOp::Builder::Build(std::shared_ptr<ImageFolderOp> *ptr) {
RETURN_IF_NOT_OK(SanityCheck());
if (builder_sampler_ == nullptr) {
builder_sampler_ = std::make_shared<SequentialSampler>();
int64_t num_samples = 0; // default num samples of 0 means to sample entire set of data
int64_t start_index = 0;
builder_sampler_ = std::make_shared<SequentialSampler>(start_index, num_samples);
}
builder_schema_ = std::make_unique<DataSchema>();
TensorShape scalar = TensorShape::CreateScalar();
@ -46,9 +47,9 @@ Status ImageFolderOp::Builder::Build(std::shared_ptr<ImageFolderOp> *ptr) {
RETURN_IF_NOT_OK(builder_schema_->AddColumn(
ColDescriptor("label", DataType(DataType::DE_INT32), TensorImpl::kFlexible, 0, &scalar)));
*ptr = std::make_shared<ImageFolderOp>(builder_num_workers_, builder_rows_per_buffer_, builder_dir_,
builder_op_connector_size_, builder_num_samples_, builder_recursive_,
builder_decode_, builder_extensions_, builder_labels_to_read_,
std::move(builder_schema_), std::move(builder_sampler_));
builder_op_connector_size_, builder_recursive_, builder_decode_,
builder_extensions_, builder_labels_to_read_, std::move(builder_schema_),
std::move(builder_sampler_));
return Status::OK();
}
@ -61,20 +62,18 @@ Status ImageFolderOp::Builder::SanityCheck() {
}
ImageFolderOp::ImageFolderOp(int32_t num_wkrs, int32_t rows_per_buffer, std::string file_dir, int32_t queue_size,
int64_t num_samples, bool recursive, bool do_decode, const std::set<std::string> &exts,
bool recursive, bool do_decode, const std::set<std::string> &exts,
const std::map<std::string, int32_t> &map, std::unique_ptr<DataSchema> data_schema,
std::shared_ptr<Sampler> sampler)
: ParallelOp(num_wkrs, queue_size),
rows_per_buffer_(rows_per_buffer),
folder_path_(file_dir),
num_samples_(num_samples),
recursive_(recursive),
decode_(do_decode),
extensions_(exts),
class_index_(map),
data_schema_(std::move(data_schema)),
sampler_(std::move(sampler)),
num_rows_(0),
row_cnt_(0),
buf_cnt_(0),
sampler_ind_(0),
@ -117,7 +116,6 @@ Status ImageFolderOp::PrescanMasterEntry(const std::string &filedir) {
}
image_label_pairs_.shrink_to_fit();
num_rows_ = image_label_pairs_.size();
num_samples_ = (num_samples_ == 0 || num_samples_ > num_rows_) ? num_rows_ : num_samples_;
// free memory of two queues used for pre-scan
folder_name_queue_->Reset();
image_name_queue_->Reset();
@ -138,8 +136,7 @@ Status ImageFolderOp::operator()() {
std::shared_ptr<Tensor> sample_ids = sample_row[0];
if (sample_ids->type() != DataType(DataType::DE_INT64)) RETURN_STATUS_UNEXPECTED("Sampler Tensor isn't int64");
for (auto itr = sample_ids->begin<int64_t>(); itr != sample_ids->end<int64_t>(); ++itr) {
if ((*itr) >= num_rows_) continue; // index out of bound, skipping
if (row_cnt_ >= num_samples_) break; // enough row read, break for loop
if ((*itr) >= num_rows_) continue; // index out of bound, skipping
keys.push_back(*itr);
row_cnt_++;
if (row_cnt_ % rows_per_buffer_ == 0) {
@ -272,28 +269,6 @@ Status ImageFolderOp::InitSampler() {
return Status::OK();
}
// Derived from RandomAccessOp
Status ImageFolderOp::GetNumSamples(int64_t *num) const {
if (num == nullptr || num_samples_ == 0) {
RETURN_STATUS_UNEXPECTED(
"There is no valid data matching the dataset API ImageFolderDatasetV2.Please check file path or dataset API "
"validation first.");
}
(*num) = num_samples_;
return Status::OK();
}
// Derived from RandomAccessOp
Status ImageFolderOp::GetNumRowsInDataset(int64_t *num) const {
if (num == nullptr || num_rows_ == 0) {
RETURN_STATUS_UNEXPECTED(
"There is no valid data matching the dataset API ImageFolderDatasetV2.Please check file path or dataset API "
"validation first.");
}
(*num) = num_rows_;
return Status::OK();
}
// Derived from RandomAccessOp
Status ImageFolderOp::GetClassIds(std::map<int32_t, std::vector<int64_t>> *cls_ids) const {
if (cls_ids == nullptr || !cls_ids->empty() || image_label_pairs_.empty()) {
@ -413,16 +388,14 @@ Status ImageFolderOp::LaunchThreadsAndInitOp() {
return Status::OK();
}
Status ImageFolderOp::CountRowsAndClasses(const std::string &path, const int64_t &num_samples,
const std::set<std::string> &exts, int64_t *num_rows, int64_t *num_classes,
int64_t dev_id, int64_t num_dev) {
Status ImageFolderOp::CountRowsAndClasses(const std::string &path, const std::set<std::string> &exts, int64_t *num_rows,
int64_t *num_classes, int64_t dev_id, int64_t num_dev) {
Path dir(path);
std::string err_msg = "";
int64_t row_cnt = 0;
err_msg += (dir.Exists() == false || dir.IsDirectory() == false) ? "unable to open dir " + path : "";
err_msg += (num_classes == nullptr || num_rows == nullptr) ? "num_class/num_rows is null\n" : "";
err_msg += (dev_id >= num_dev || num_dev <= 0) ? "invalid sharding config\n" : "";
err_msg += num_samples < 0 ? "num_samples can't be negative! set it to 0 to use all samples\n" : "";
if (err_msg.empty() == false) {
RETURN_STATUS_UNEXPECTED(err_msg);
}
@ -441,10 +414,6 @@ Status ImageFolderOp::CountRowsAndClasses(const std::string &path, const int64_t
while (dir_itr->hasNext()) {
if (exts.empty() || exts.find(subdir.Extension()) != exts.end()) {
++row_cnt;
if (row_cnt == num_samples * num_dev) {
(*num_rows) = (row_cnt / num_dev) + (row_cnt % num_dev == 0 ? 0 : 1);
return Status::OK();
}
}
}
foldernames.pop();

@ -107,14 +107,6 @@ class ImageFolderOp : public ParallelOp, public RandomAccessOp {
return *this;
}
// Setter method
// @param int64_t num_samples
// @return Builder setter method returns reference to the builder.
Builder &SetNumSamples(int64_t num_samples) {
builder_num_samples_ = num_samples;
return *this;
}
// Setter method
// @param std::shared_ptr<Sampler> sampler
// @return Builder setter method returns reference to the builder.
@ -153,7 +145,6 @@ class ImageFolderOp : public ParallelOp, public RandomAccessOp {
bool builder_recursive_;
std::string builder_dir_;
int32_t builder_num_workers_;
int64_t builder_num_samples_;
int32_t builder_rows_per_buffer_;
int32_t builder_op_connector_size_;
std::set<std::string> builder_extensions_;
@ -169,10 +160,9 @@ class ImageFolderOp : public ParallelOp, public RandomAccessOp {
// @param int32_t queue_size - connector queue size
// @param std::set<std::string> exts - set of file extensions to read, if empty, read everything under the dir
// @param std::shared_ptr<Sampler> sampler - sampler tells ImageFolderOp what to read
ImageFolderOp(int32_t num_wkrs, int32_t rows_per_buffer, std::string file_dir, int32_t queue_size,
int64_t num_samples, bool recursive, bool do_decode, const std::set<std::string> &exts,
const std::map<std::string, int32_t> &map, std::unique_ptr<DataSchema>,
std::shared_ptr<Sampler> sampler);
ImageFolderOp(int32_t num_wkrs, int32_t rows_per_buffer, std::string file_dir, int32_t queue_size, bool recursive,
bool do_decode, const std::set<std::string> &exts, const std::map<std::string, int32_t> &map,
std::unique_ptr<DataSchema>, std::shared_ptr<Sampler> sampler);
// Destructor.
~ImageFolderOp() = default;
@ -198,16 +188,6 @@ class ImageFolderOp : public ParallelOp, public RandomAccessOp {
// @return Status - The error code return
Status operator()() override;
// Method derived from RandomAccess Op, enable Sampler to get numRows
// @param int64_t num - to return numRows
// @return Status - The error code return
Status GetNumSamples(int64_t *num) const override;
// Method derived from RandomAccess Op, enable Sampler to get total numRows in dataset
// @param int64_t num - to return numRows
// @return Status - The error code return
Status GetNumRowsInDataset(int64_t *num) const override;
// Method derived from RandomAccess Op, enable Sampler to get all ids for each class
// @param (std::map<int64_t, std::vector<int64_t >> * map - key label, val all ids for this class
// @return Status - The error code return
@ -221,9 +201,8 @@ class ImageFolderOp : public ParallelOp, public RandomAccessOp {
// This function is a hack! It returns the num_class and num_rows that the old storageOp used to return. The result
// returned by this function may not be consistent with what image_folder_op is going to return,
// so use this at your own risk!
static Status CountRowsAndClasses(const std::string &path, const int64_t &num_samples,
const std::set<std::string> &exts, int64_t *num_rows, int64_t *num_classes,
int64_t dev_id = 0, int64_t num_dev = 1);
static Status CountRowsAndClasses(const std::string &path, const std::set<std::string> &exts, int64_t *num_rows,
int64_t *num_classes, int64_t dev_id = 0, int64_t num_dev = 1);
// Base-class override for NodePass visitor acceptor.
// @param p - Pointer to the NodePass to be accepted.
@ -266,14 +245,12 @@ class ImageFolderOp : public ParallelOp, public RandomAccessOp {
int32_t rows_per_buffer_;
std::string folder_path_; // directory of image folder
int64_t num_samples_;
bool recursive_;
bool decode_;
std::set<std::string> extensions_; // extensions allowed
std::map<std::string, int32_t> class_index_;
std::unique_ptr<DataSchema> data_schema_;
std::shared_ptr<Sampler> sampler_;
int64_t num_rows_; // total number of images in ImageFolder
int64_t row_cnt_;
int64_t buf_cnt_;
int64_t sampler_ind_;

@ -29,7 +29,7 @@
namespace mindspore {
namespace dataset {
ManifestOp::Builder::Builder() : builder_sampler_(nullptr), builder_num_samples_(0), builder_decode_(false) {
ManifestOp::Builder::Builder() : builder_sampler_(nullptr), builder_decode_(false) {
std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
builder_num_workers_ = cfg->num_parallel_workers();
builder_rows_per_buffer_ = cfg->rows_per_buffer();
@ -39,16 +39,18 @@ ManifestOp::Builder::Builder() : builder_sampler_(nullptr), builder_num_samples_
Status ManifestOp::Builder::Build(std::shared_ptr<ManifestOp> *ptr) {
RETURN_IF_NOT_OK(SanityCheck());
if (builder_sampler_ == nullptr) {
builder_sampler_ = std::make_shared<SequentialSampler>();
int64_t num_samples = 0;
int64_t start_index = 0;
builder_sampler_ = std::make_shared<SequentialSampler>(start_index, num_samples);
}
builder_schema_ = std::make_unique<DataSchema>();
RETURN_IF_NOT_OK(
builder_schema_->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1)));
RETURN_IF_NOT_OK(
builder_schema_->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1)));
*ptr = std::make_shared<ManifestOp>(
builder_num_workers_, builder_rows_per_buffer_, builder_file_, builder_op_connector_size_, builder_num_samples_,
builder_decode_, builder_labels_to_read_, std::move(builder_schema_), std::move(builder_sampler_), builder_usage_);
*ptr = std::make_shared<ManifestOp>(builder_num_workers_, builder_rows_per_buffer_, builder_file_,
builder_op_connector_size_, builder_decode_, builder_labels_to_read_,
std::move(builder_schema_), std::move(builder_sampler_), builder_usage_);
return Status::OK();
}
@ -59,9 +61,9 @@ Status ManifestOp::Builder::SanityCheck() {
return err_msg.empty() ? Status::OK() : Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg);
}
ManifestOp::ManifestOp(int32_t num_works, int32_t rows_per_buffer, std::string file, int32_t queue_size,
int64_t num_samples, bool decode, const std::map<std::string, int32_t> &class_index,
std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler, std::string usage)
ManifestOp::ManifestOp(int32_t num_works, int32_t rows_per_buffer, std::string file, int32_t queue_size, bool decode,
const std::map<std::string, int32_t> &class_index, std::unique_ptr<DataSchema> data_schema,
std::shared_ptr<Sampler> sampler, std::string usage)
: ParallelOp(num_works, queue_size),
rows_per_buffer_(rows_per_buffer),
io_block_pushed_(0),
@ -71,8 +73,6 @@ ManifestOp::ManifestOp(int32_t num_works, int32_t rows_per_buffer, std::string f
file_(file),
class_index_(class_index),
sampler_(std::move(sampler)),
num_samples_(num_samples),
num_rows_(0),
decode_(decode),
usage_(usage),
buf_cnt_(0) {
@ -101,8 +101,7 @@ Status ManifestOp::AddIoBlock(std::unique_ptr<DataBuffer> *sampler_buffer) {
RETURN_IF_NOT_OK((*sampler_buffer)->PopRow(&sample_row));
std::shared_ptr<Tensor> sample_ids = sample_row[0];
for (auto itr = sample_ids->begin<int64_t>(); itr != sample_ids->end<int64_t>(); ++itr) {
if ((*itr) >= num_rows_) continue; // index out of bound, skipping
if (row_cnt_ >= num_samples_) break; // enough row read, break for loop
if ((*itr) >= num_rows_) continue; // index out of bound, skipping
keys.push_back(*itr);
row_cnt_++;
if (row_cnt_ % rows_per_buffer_ == 0) {
@ -269,28 +268,6 @@ Status ManifestOp::InitSampler() {
return Status::OK();
}
// Derived from RandomAccessOp
Status ManifestOp::GetNumSamples(int64_t *num) const {
if (num == nullptr || num_rows_ == 0) {
RETURN_STATUS_UNEXPECTED(
"There is no valid data matching the dataset API ManifestDataset.Please check file path or dataset API "
"validation first.");
}
(*num) = num_samples_;
return Status::OK();
}
// Derived from RandomAccessOp
Status ManifestOp::GetNumRowsInDataset(int64_t *num) const {
if (num == nullptr || num_rows_ == 0) {
RETURN_STATUS_UNEXPECTED(
"There is no valid data matching the dataset API ManifestDataset.Please check file path or dataset API "
"validation first.");
}
(*num) = num_rows_;
return Status::OK();
}
// Derived from RandomAccessOp
Status ManifestOp::GetClassIds(std::map<int32_t, std::vector<int64_t>> *cls_ids) const {
if (cls_ids == nullptr || !cls_ids->empty() || image_labelname_.empty()) {
@ -408,7 +385,6 @@ Status ManifestOp::CountDatasetInfo() {
}
num_rows_ = static_cast<int64_t>(image_labelname_.size());
num_samples_ = (num_samples_ == 0 || num_samples_ > num_rows_) ? num_rows_ : num_samples_;
if (num_rows_ == 0) {
RETURN_STATUS_UNEXPECTED(
"There is no valid data matching the dataset API ManifestDataset.Please check file path or dataset API "
@ -417,8 +393,8 @@ Status ManifestOp::CountDatasetInfo() {
return Status::OK();
}
Status ManifestOp::CountTotalRows(const std::string &file, int64_t numSamples, const py::dict &dict,
const std::string &usage, int64_t *count, int64_t *numClasses) {
Status ManifestOp::CountTotalRows(const std::string &file, const py::dict &dict, const std::string &usage,
int64_t *count, int64_t *numClasses) {
// the logic of counting the number of samples is copied from ParseManifestFile()
std::map<std::string, int32_t> map;
for (auto p : dict) {
@ -428,17 +404,15 @@ Status ManifestOp::CountTotalRows(const std::string &file, int64_t numSamples, c
std::shared_ptr<ManifestOp> op;
*count = 0;
RETURN_IF_NOT_OK(
Builder().SetManifestFile(file).SetNumSamples(numSamples).SetClassIndex(map).SetUsage(usage).Build(&op));
RETURN_IF_NOT_OK(Builder().SetManifestFile(file).SetClassIndex(map).SetUsage(usage).Build(&op));
RETURN_IF_NOT_OK(op->ParseManifestFile());
*numClasses = static_cast<int64_t>(op->label_index_.size());
*count = static_cast<int64_t>(op->image_labelname_.size());
*count = (*count < numSamples || numSamples == 0) ? *count : numSamples;
return Status::OK();
}
Status ManifestOp::GetClassIndexing(const std::string &file, int64_t numSamples, const py::dict &dict,
const std::string &usage, std::map<std::string, int32_t> *output_class_indexing) {
Status ManifestOp::GetClassIndexing(const std::string &file, const py::dict &dict, const std::string &usage,
std::map<std::string, int32_t> *output_class_indexing) {
std::map<std::string, int32_t> input_class_indexing;
for (auto p : dict) {
(void)input_class_indexing.insert(std::pair<std::string, int32_t>(py::reinterpret_borrow<py::str>(p.first),
@ -449,12 +423,7 @@ Status ManifestOp::GetClassIndexing(const std::string &file, int64_t numSamples,
*output_class_indexing = input_class_indexing;
} else {
std::shared_ptr<ManifestOp> op;
RETURN_IF_NOT_OK(Builder()
.SetManifestFile(file)
.SetNumSamples(numSamples)
.SetClassIndex(input_class_indexing)
.SetUsage(usage)
.Build(&op));
RETURN_IF_NOT_OK(Builder().SetManifestFile(file).SetClassIndex(input_class_indexing).SetUsage(usage).Build(&op));
RETURN_IF_NOT_OK(op->ParseManifestFile());
RETURN_IF_NOT_OK(op->CountDatasetInfo());
uint32_t count = 0;

@ -86,14 +86,6 @@ class ManifestOp : public ParallelOp, public RandomAccessOp {
return *this;
}
// Setter method
// @param int64_t num_samples
// @return Builder setter method returns reference to the builder.
Builder &SetNumSamples(int64_t num_samples) {
builder_num_samples_ = num_samples;
return *this;
}
// Setter method
// @param std::shared_ptr<Sampler> sampler
// @return Builder setter method returns reference to the builder.
@ -129,7 +121,6 @@ class ManifestOp : public ParallelOp, public RandomAccessOp {
private:
std::shared_ptr<Sampler> builder_sampler_;
int64_t builder_num_samples_;
bool builder_decode_;
std::string builder_file_;
@ -147,8 +138,8 @@ class ManifestOp : public ParallelOp, public RandomAccessOp {
// @param std::string - file list of Manifest
// @param int32_t queue_size - connector queue size
// @param std::shared_ptr<Sampler> sampler - sampler tells ManifestOp what to read
ManifestOp(int32_t num_works, int32_t rows_per_buffer, std::string file, int32_t queue_size, int64_t num_samples,
bool decode, const std::map<std::string, int32_t> &class_index, std::unique_ptr<DataSchema> data_schema,
ManifestOp(int32_t num_works, int32_t rows_per_buffer, std::string file, int32_t queue_size, bool decode,
const std::map<std::string, int32_t> &class_index, std::unique_ptr<DataSchema> data_schema,
std::shared_ptr<Sampler> sampler, std::string usage);
// Destructor.
~ManifestOp() = default;
@ -164,16 +155,6 @@ class ManifestOp : public ParallelOp, public RandomAccessOp {
// @return Status - The error code return
Status operator()() override;
// Method derived from RandomAccess Op, enable Sampler to get numRows
// @param int64_t num - to return numRows
// @return Status - The error code return
Status GetNumSamples(int64_t *num) const override;
// Method derived from RandomAccess Op, enable Sampler to get total number of Rows in dataset
// @param int64_t num - to return numRows
// @return Status - The error code return
Status GetNumRowsInDataset(int64_t *num) const override;
// Method derived from RandomAccess Op, enable Sampler to get all ids for each class
// @param (std::map<int64_t, std::vector<int64_t >> * map - key label, val all ids for this class
// @return Status - The error code return
@ -184,12 +165,12 @@ class ManifestOp : public ParallelOp, public RandomAccessOp {
// @param show_all
void Print(std::ostream &out, bool show_all) const override;
static Status CountTotalRows(const std::string &file, int64_t numSamples, const py::dict &dict,
const std::string &usage, int64_t *count, int64_t *numClasses);
static Status CountTotalRows(const std::string &file, const py::dict &dict, const std::string &usage, int64_t *count,
int64_t *numClasses);
// Get str-to-int mapping from label name to index
static Status GetClassIndexing(const std::string &file, int64_t numSamples, const py::dict &dict,
const std::string &usage, std::map<std::string, int32_t> *output_class_indexing);
static Status GetClassIndexing(const std::string &file, const py::dict &dict, const std::string &usage,
std::map<std::string, int32_t> *output_class_indexing);
private:
// Initialize Sampler, calls sampler->Init() within
@ -240,8 +221,6 @@ class ManifestOp : public ParallelOp, public RandomAccessOp {
std::string file_; // file that store the information of images
std::map<std::string, int32_t> class_index_;
std::shared_ptr<Sampler> sampler_;
int64_t num_samples_;
int64_t num_rows_;
bool decode_;
std::string usage_;
int64_t buf_cnt_;

@ -91,7 +91,6 @@ MindRecordOp::MindRecordOp(int32_t num_mind_record_workers, int32_t rows_per_buf
block_reader_(block_reader),
buffers_needed_(0),
buf_cnt_(0),
num_rows_(0),
ended_worker_(0),
buffer_water_mark_(0) {
io_blk_queues_.Init(num_workers_, op_connector_queue_size);

@ -31,7 +31,7 @@ const int32_t kMnistLabelFileMagicNumber = 2049;
const int32_t kMnistImageRows = 28;
const int32_t kMnistImageCols = 28;
MnistOp::Builder::Builder() : builder_num_samples_(0), builder_sampler_(nullptr) {
MnistOp::Builder::Builder() : builder_sampler_(nullptr) {
std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
builder_num_workers_ = cfg->num_parallel_workers();
builder_rows_per_buffer_ = cfg->rows_per_buffer();
@ -41,7 +41,9 @@ MnistOp::Builder::Builder() : builder_num_samples_(0), builder_sampler_(nullptr)
Status MnistOp::Builder::Build(std::shared_ptr<MnistOp> *ptr) {
RETURN_IF_NOT_OK(SanityCheck());
if (builder_sampler_ == nullptr) {
builder_sampler_ = std::make_shared<SequentialSampler>();
int64_t num_samples = 0;
int64_t start_index = 0;
builder_sampler_ = std::make_shared<SequentialSampler>(start_index, num_samples);
}
builder_schema_ = std::make_unique<DataSchema>();
RETURN_IF_NOT_OK(
@ -49,9 +51,8 @@ Status MnistOp::Builder::Build(std::shared_ptr<MnistOp> *ptr) {
TensorShape scalar = TensorShape::CreateScalar();
RETURN_IF_NOT_OK(builder_schema_->AddColumn(
ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));
*ptr =
std::make_shared<MnistOp>(builder_num_workers_, builder_rows_per_buffer_, builder_dir_, builder_op_connector_size_,
builder_num_samples_, std::move(builder_schema_), std::move(builder_sampler_));
*ptr = std::make_shared<MnistOp>(builder_num_workers_, builder_rows_per_buffer_, builder_dir_,
builder_op_connector_size_, std::move(builder_schema_), std::move(builder_sampler_));
return Status::OK();
}
@ -60,17 +61,14 @@ Status MnistOp::Builder::SanityCheck() {
std::string err_msg;
err_msg += dir.IsDirectory() == false ? "MNIST path is invalid or not set\n" : "";
err_msg += builder_num_workers_ <= 0 ? "Number of parallel workers is set to 0 or negative\n" : "";
err_msg += builder_num_samples_ < 0 ? "Number of samples is set to negative\n" : "";
return err_msg.empty() ? Status::OK() : Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg);
}
MnistOp::MnistOp(int32_t num_workers, int32_t rows_per_buffer, std::string folder_path, int32_t queue_size,
int64_t num_samples, std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler)
std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler)
: ParallelOp(num_workers, queue_size),
buf_cnt_(0),
row_cnt_(0),
num_rows_(0),
num_samples_(num_samples),
folder_path_(folder_path),
rows_per_buffer_(rows_per_buffer),
sampler_(std::move(sampler)),
@ -84,8 +82,7 @@ MnistOp::MnistOp(int32_t num_workers, int32_t rows_per_buffer, std::string folde
Status MnistOp::TraversalSampleIds(const std::shared_ptr<Tensor> &sample_ids, std::vector<int64_t> *keys) {
for (auto itr = sample_ids->begin<int64_t>(); itr != sample_ids->end<int64_t>(); ++itr) {
if ((*itr) >= num_rows_) continue; // index out of bound, skipping
if (row_cnt_ >= num_samples_) break; // enough row read, break for loop
if ((*itr) >= num_rows_) continue; // index out of bound, skipping
keys->push_back(*itr);
row_cnt_++;
if (row_cnt_ % rows_per_buffer_ == 0) {
@ -219,17 +216,6 @@ Status MnistOp::InitSampler() {
return Status::OK();
}
// Derived from RandomAccessOp
Status MnistOp::GetNumSamples(int64_t *num) const {
if (num == nullptr || num_rows_ == 0) {
RETURN_STATUS_UNEXPECTED(
"There is no valid data matching the dataset API MnistDataset.Please check file path or dataset API "
"validation first.");
}
(*num) = num_samples_;
return Status::OK();
}
// Derived from RandomAccessOp
Status MnistOp::GetClassIds(std::map<int32_t, std::vector<int64_t>> *cls_ids) const {
if (cls_ids == nullptr || !cls_ids->empty() || image_label_pairs_.empty()) {
@ -364,7 +350,6 @@ Status MnistOp::ParseMnistData() {
}
image_label_pairs_.shrink_to_fit();
num_rows_ = image_label_pairs_.size();
num_samples_ = (num_samples_ == 0 || num_samples_ > num_rows_) ? num_rows_ : num_samples_;
return Status::OK();
}
@ -414,11 +399,11 @@ Status MnistOp::LaunchThreadsAndInitOp() {
return Status::OK();
}
Status MnistOp::CountTotalRows(const std::string &dir, int64_t numSamples, int64_t *count) {
Status MnistOp::CountTotalRows(const std::string &dir, int64_t *count) {
// the logic of counting the number of samples is copied from ParseMnistData() and uses CheckReader()
std::shared_ptr<MnistOp> op;
*count = 0;
RETURN_IF_NOT_OK(Builder().SetDir(dir).SetNumSamples(numSamples).Build(&op));
RETURN_IF_NOT_OK(Builder().SetDir(dir).Build(&op));
RETURN_IF_NOT_OK(op->WalkAllFiles());
@ -440,19 +425,6 @@ Status MnistOp::CountTotalRows(const std::string &dir, int64_t numSamples, int64
label_reader.close();
}
*count = (numSamples == 0 || *count < numSamples) ? *count : numSamples;
return Status::OK();
}
// Derived from RandomAccessOp
Status MnistOp::GetNumRowsInDataset(int64_t *num) const {
if (num == nullptr || num_rows_ == 0) {
RETURN_STATUS_UNEXPECTED(
"There is no valid data matching the dataset API MnistDataset.Please check file path or dataset API "
"validation first.");
}
(*num) = num_rows_;
return Status::OK();
}
} // namespace dataset

@ -78,14 +78,6 @@ class MnistOp : public ParallelOp, public RandomAccessOp {
return *this;
}
// Setter method
// @param int64_t num_samples
// @return Builder setter method returns reference to the builder.
Builder &SetNumSamples(int64_t num_samples) {
builder_num_samples_ = num_samples;
return *this;
}
// Setter method
// @param std::shared_ptr<Sampler> sampler
// @return Builder setter method returns reference to the builder.
@ -114,7 +106,6 @@ class MnistOp : public ParallelOp, public RandomAccessOp {
private:
std::string builder_dir_;
int32_t builder_num_workers_;
int64_t builder_num_samples_;
int32_t builder_rows_per_buffer_;
int32_t builder_op_connector_size_;
std::shared_ptr<Sampler> builder_sampler_;
@ -126,11 +117,10 @@ class MnistOp : public ParallelOp, public RandomAccessOp {
// @param int32_t rows_per_buffer - number of images (rows) in each buffer
// @param std::string folder_path - dir directory of mnist
// @param int32_t queue_size - connector queue size
// @param int64_t num_samples - number of samples to read
// @param std::unique_ptr<DataSchema> data_schema - the schema of the mnist dataset
// @param std::shared_ptr<Sampler> sampler - sampler tells MnistOp what to read
MnistOp(int32_t num_workers, int32_t rows_per_buffer, std::string folder_path, int32_t queue_size,
int64_t num_samples, std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler);
std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler);
// Destructor.
~MnistOp() = default;
@ -146,16 +136,6 @@ class MnistOp : public ParallelOp, public RandomAccessOp {
// @return Status - The error code return
Status operator()() override;
// Method derived from RandomAccess Op, enable Sampler to get numRows
// @param int64_t num - to return numRows
// @return Status - The error code return
Status GetNumSamples(int64_t *num) const override;
// Method derived from RandomAccess Op, enable Sampler to get total numRows in dataset
// @param int64_t num - to return numRows
// @return Status - The error code return
Status GetNumRowsInDataset(int64_t *num) const override;
// Method derived from RandomAccess Op, enable Sampler to get all ids for each class
// @param (std::map<uint64_t, std::vector<uint64_t >> * map - key label, val all ids for this class
// @return Status - The error code return
@ -167,11 +147,10 @@ class MnistOp : public ParallelOp, public RandomAccessOp {
void Print(std::ostream &out, bool show_all) const override;
// Function to count the number of samples in the MNIST dataset
// @param dir path to the MNSIT directory
// @param numSamples maximum number of samples requested
// @param dir path to the MNIST directory
// @param count output arg that will hold the actual dataset size
// @return
static Status CountTotalRows(const std::string &dir, int64_t numSamples, int64_t *count);
static Status CountTotalRows(const std::string &dir, int64_t *count);
private:
// Initialize Sampler, calls sampler->Init() within
@ -244,9 +223,7 @@ class MnistOp : public ParallelOp, public RandomAccessOp {
int64_t buf_cnt_;
int64_t row_cnt_;
int64_t num_rows_; // total number of images in Mnist
WaitPost wp_;
int64_t num_samples_;
std::string folder_path_; // directory of image folder
int32_t rows_per_buffer_;
std::shared_ptr<Sampler> sampler_;

@ -8,6 +8,5 @@ add_library(engine-datasetops-source-sampler OBJECT
sampler.cc
sequential_sampler.cc
subset_random_sampler.cc
subset_sampler.cc
weighted_random_sampler.cc
)

@ -23,8 +23,9 @@
namespace mindspore {
namespace dataset {
DistributedSampler::DistributedSampler(int64_t num_dev, int64_t dev_id, bool shuffle, uint32_t seed)
: Sampler(),
DistributedSampler::DistributedSampler(int64_t num_samples, int64_t num_dev, int64_t dev_id, bool shuffle,
uint32_t seed)
: Sampler(num_samples, std::numeric_limits<int64_t>::max()),
cnt_(0),
seed_(seed == std::numeric_limits<uint32_t>::max() ? GetSeed() : seed),
device_id_(dev_id),
@ -32,6 +33,11 @@ DistributedSampler::DistributedSampler(int64_t num_dev, int64_t dev_id, bool shu
shuffle_(shuffle) {}
Status DistributedSampler::InitSampler() {
// Special value of 0 for num_samples means that the user wants to sample the entire set of data.
// If the user asked to sample more rows than exists in the dataset, adjust the num_samples accordingly.
if (num_samples_ == 0 || num_samples_ > num_rows_) {
num_samples_ = num_rows_;
}
CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ > 0, "num_samples <= 0\n");
CHECK_FAIL_RETURN_UNEXPECTED(num_rows_ > 0, "num_rows <= 0\n");
CHECK_FAIL_RETURN_UNEXPECTED(device_id_ < num_devices_ && device_id_ >= 0 && num_rows_ > 0 && num_samples_ > 0,

@ -27,10 +27,11 @@ namespace mindspore {
namespace dataset {
class DistributedSampler : public Sampler {
public:
// @param int64_t numDev
// @param int64_t devId
// @param num_samples
// @param int64_t num_dev
// @param int64_t dev_id
// @param bool shuffle
DistributedSampler(int64_t num_dev, int64_t dev_id, bool shuffle = true,
DistributedSampler(int64_t num_samples, int64_t num_dev, int64_t dev_id, bool shuffle,
uint32_t seed = std::numeric_limits<uint32_t>::max());
// default destructor

@ -20,12 +20,11 @@
namespace mindspore {
namespace dataset {
PKSampler::PKSampler(int64_t val, bool shuffle, int64_t samples_per_buffer)
: Sampler(samples_per_buffer),
PKSampler::PKSampler(int64_t num_samples, int64_t val, bool shuffle, int64_t samples_per_buffer)
: Sampler(num_samples, samples_per_buffer),
shuffle_(shuffle),
seed_(GetSeed()),
next_id_(0),
num_pk_samples_(0),
samples_per_class_(val) {}
Status PKSampler::InitSampler() {
@ -36,22 +35,34 @@ Status PKSampler::InitSampler() {
}
}
rnd_.seed(seed_++);
num_pk_samples_ = samples_per_class_ * static_cast<int64_t>(labels_.size());
samples_per_buffer_ = (samples_per_buffer_ > num_pk_samples_) ? num_pk_samples_ : samples_per_buffer_;
num_samples_ = num_pk_samples_;
// The special handshake gives the list of classes and id's, but it did not set the num_rows_ to
// capture the total number of possible sample ids.
// Compute that here for this case to find the total number of samples that are available to return.
// (in this case, samples per class * total classes).
num_rows_ = samples_per_class_ * static_cast<int64_t>(labels_.size());
// The user may have chosen to sample less than the total amount.
// Special value of 0 for num_samples means that the user wants to sample the entire set of data.
// If the user asked to sample more rows than exists in the dataset, adjust the num_samples accordingly.
if (num_samples_ == 0 || num_samples_ > num_rows_) {
num_samples_ = num_rows_;
}
samples_per_buffer_ = (samples_per_buffer_ > num_samples_) ? num_samples_ : samples_per_buffer_;
if (shuffle_ == true) {
std::shuffle(labels_.begin(), labels_.end(), rnd_);
} else {
std::sort(labels_.begin(), labels_.end());
}
CHECK_FAIL_RETURN_UNEXPECTED(num_pk_samples_ > 0, "num_class or K (num samples per class) is not positive");
CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ > 0, "num_class or K (num samples per class) is not positive");
return Status::OK();
}
Status PKSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) {
if (next_id_ > num_pk_samples_ || num_pk_samples_ == 0) {
if (next_id_ > num_samples_ || num_samples_ == 0) {
RETURN_STATUS_UNEXPECTED("Index out of bound in PKSampler");
} else if (next_id_ == num_pk_samples_) {
} else if (next_id_ == num_samples_) {
(*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
} else {
if (HasChildSampler()) {
@ -60,8 +71,7 @@ Status PKSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) {
(*out_buffer) = std::make_unique<DataBuffer>(next_id_, DataBuffer::kDeBFlagNone);
std::shared_ptr<Tensor> sample_ids;
int64_t last_id =
(samples_per_buffer_ + next_id_ > num_pk_samples_) ? num_pk_samples_ : samples_per_buffer_ + next_id_;
int64_t last_id = (samples_per_buffer_ + next_id_ > num_samples_) ? num_samples_ : samples_per_buffer_ + next_id_;
RETURN_IF_NOT_OK(CreateSamplerTensor(&sample_ids, last_id - next_id_));
int64_t *id_ptr = reinterpret_cast<int64_t *>(sample_ids->GetMutableBuffer());
while (next_id_ < last_id) {
@ -85,7 +95,7 @@ Status PKSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) {
}
Status PKSampler::Reset() {
CHECK_FAIL_RETURN_UNEXPECTED(next_id_ == num_pk_samples_, "ERROR Reset() called early/late");
CHECK_FAIL_RETURN_UNEXPECTED(next_id_ == num_samples_, "ERROR Reset() called early/late");
next_id_ = 0;
rnd_.seed(seed_++);

@ -28,10 +28,11 @@ namespace mindspore {
namespace dataset {
class PKSampler : public Sampler { // NOT YET FINISHED
public:
// @param int64_t kVal
// @param num_samples - the number of samples to draw. value of 0 means to take the full amount
// @param int64_t val
// @param bool shuffle - shuffle all classIds or not, if true, classes may be 5,1,4,3,2
// @param int64_t samplesPerBuffer - Num of Sampler Ids to fetch via 1 GetNextBuffer call
explicit PKSampler(int64_t val, bool shuffle = false,
explicit PKSampler(int64_t num_samples, int64_t val, bool shuffle,
int64_t samples_per_buffer = std::numeric_limits<int64_t>::max());
// default destructor
@ -42,8 +43,9 @@ class PKSampler : public Sampler { // NOT YET FINISHED
// @return - The error code return
Status GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) override;
// first handshake between StorageOp and Sampler
// @param op - StorageOp pointer, pass in so Sampler can call GetNumSamples() and get ClassIds()
// first handshake between leaf source op and Sampler. This func will determine the amount of data
// in the dataset that we can sample from.
// @param op - leaf op pointer, pass in so Sampler can ask it about how much data there is
// @return
Status HandshakeRandomAccessOp(const RandomAccessOp *op) override;
@ -58,7 +60,6 @@ class PKSampler : public Sampler { // NOT YET FINISHED
bool shuffle_;
uint32_t seed_;
int64_t next_id_;
int64_t num_pk_samples_;
int64_t samples_per_class_;
std::mt19937 rnd_;
std::vector<int64_t> labels_;

@ -20,8 +20,8 @@
namespace mindspore {
namespace dataset {
PythonSampler::PythonSampler(py::object py_sampler_instance, int64_t samples_per_buffer)
: Sampler(samples_per_buffer), py_sampler_instance(py_sampler_instance), need_to_reset_(false) {}
PythonSampler::PythonSampler(int64_t num_samples, py::object py_sampler_instance, int64_t samples_per_buffer)
: Sampler(num_samples, samples_per_buffer), py_sampler_instance(py_sampler_instance), need_to_reset_(false) {}
Status PythonSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) {
if (need_to_reset_) {
@ -65,6 +65,11 @@ Status PythonSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) {
Status PythonSampler::InitSampler() {
CHECK_FAIL_RETURN_UNEXPECTED(num_rows_ > 0, "ERROR num_rows_ should be greater than 0");
// Special value of 0 for num_samples means that the user wants to sample the entire set of data.
// If the user asked to sample more rows than exists in the dataset, adjust the num_samples accordingly.
if (num_samples_ == 0 || num_samples_ > num_rows_) {
num_samples_ = num_rows_;
}
{
py::gil_scoped_acquire gil_acquire;
if (Py_IsInitialized() == 0) {

@ -26,8 +26,11 @@ namespace dataset {
class PythonSampler : public Sampler {
public:
// Constructor
// @param int64_t samplesPerBuffer - Num of Sampler Ids to fetch via 1 GetNextBuffer call
explicit PythonSampler(py::object py_sampler_instance,
// @param num_samples - the number of samples to draw. Value of 0 means to sample all of the
// data from the dataset.
// @param py_sampler_instance - the python instance of the sampler
// @param int64_t samples_per_buffer - Num of Sampler Ids to fetch via 1 GetNextBuffer call
explicit PythonSampler(int64_t num_samples, py::object py_sampler_instance,
int64_t samples_per_buffer = std::numeric_limits<int64_t>::max());
// Destructor.

@ -22,12 +22,11 @@
namespace mindspore {
namespace dataset {
// Constructor.
// @param num_samples - user-requested number of ids to draw (0 means draw from the entire
//                      dataset; see InitSampler, which expands 0 to num_rows_)
// @param replacement - true to put each drawn id back so it may be drawn again
// @param reshuffle_each_epoch - true to reshuffle the id space on each epoch reset
// @param samples_per_buffer - max number of ids packed into one DataBuffer
// NOTE(review): the rendered diff interleaved the old and new constructors, including an
// initializer for the removed user_num_samples_ member; only the post-change form is kept.
RandomSampler::RandomSampler(int64_t num_samples, bool replacement, bool reshuffle_each_epoch,
                             int64_t samples_per_buffer)
    : Sampler(num_samples, samples_per_buffer),
      seed_(GetSeed()),
      replacement_(replacement),
      next_id_(0),
      reshuffle_each_epoch_(reshuffle_each_epoch),
      dist(nullptr) {}
@ -70,27 +69,25 @@ Status RandomSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) {
}
// Performs one-time setup once num_rows_ is known: resolves the effective num_samples_,
// seeds the RNG, and prepares either a pre-shuffled id list (no replacement) or a uniform
// distribution (with replacement).
// @return Status - error if num_rows_/num_samples_ end up non-positive
// NOTE(review): the rendered diff interleaved pre- and post-change bodies here (duplicated
// num_samples checks, duplicated samples_per_buffer clamp, references to the removed
// user_num_samples_ member); this is the reconstructed post-change body.
Status RandomSampler::InitSampler() {
  CHECK_FAIL_RETURN_UNEXPECTED(num_rows_ > 0, "num_rows needs to be positive.");
  // Special value of 0 for num_samples means that the user wants to sample the entire set of data.
  // If the user asked to sample more rows than exists in the dataset, adjust the num_samples accordingly.
  if (num_samples_ == 0 || num_samples_ > num_rows_) {
    num_samples_ = num_rows_;
  }
  rnd_.seed(seed_);
  if (replacement_ == false) {
    // Without replacement: shuffle the full id space once up front, then ids are handed out
    // sequentially from shuffled_ids_.
    shuffled_ids_.reserve(num_rows_);
    for (int64_t i = 0; i < num_rows_; i++) {
      shuffled_ids_.push_back(i);
    }
    std::shuffle(shuffled_ids_.begin(), shuffled_ids_.end(), rnd_);
  } else {
    // With replacement: draw each id independently and uniformly on demand.
    dist = std::make_unique<std::uniform_int_distribution<int64_t>>(0, num_rows_ - 1);
  }
  CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ > 0, "num_samples needs to be positive.");
  // Never hand out more ids per buffer than we will produce in total.
  samples_per_buffer_ = samples_per_buffer_ > num_samples_ ? num_samples_ : samples_per_buffer_;
  return Status::OK();
}
@ -119,7 +116,6 @@ void RandomSampler::Print(std::ostream &out, bool show_all) const {
out << "(sampler): RandomSampler\n";
if (show_all) {
out << "user_num_samples_: " << user_num_samples_ << '\n';
out << "num_samples_: " << num_samples_ << '\n';
out << "next_id_: " << next_id_ << '\n';
}

@ -27,11 +27,11 @@ namespace dataset {
class RandomSampler : public Sampler {
public:
// Constructor
// @param int64_t num_samples - number samples to draw
// @param bool replacement - whether to put the id back after drawing it, so it can be sampled again
// @param int64_t numSamples - number samples to draw
// @param int64_t samplesPerBuffer - Num of Sampler Ids to fetch via 1 GetNextBuffer call
explicit RandomSampler(bool replacement = false, bool reshuffle_each_epoch = true,
int64_t num_samples = std::numeric_limits<int64_t>::max(),
// @param reshuffle_each_epoch - T/F to reshuffle after epoch
// @param int64_t samples_per_buffer - Num of Sampler Ids to fetch via 1 GetNextBuffer call
explicit RandomSampler(int64_t num_samples, bool replacement, bool reshuffle_each_epoch,
int64_t samples_per_buffer = std::numeric_limits<int64_t>::max());
// Destructor.
@ -55,7 +55,6 @@ class RandomSampler : public Sampler {
private:
uint32_t seed_;
bool replacement_;
int64_t user_num_samples_;
std::vector<int64_t> shuffled_ids_; // only used for NO REPLACEMENT
int64_t next_id_;
std::mt19937 rnd_;

@ -19,8 +19,25 @@
namespace mindspore {
namespace dataset {
Sampler::Sampler(int64_t samples_per_buffer)
: DatasetOp(0), num_rows_(0), num_samples_(0), samples_per_buffer_(samples_per_buffer), col_desc_(nullptr) {}
// Getter for the dataset's total (pre-sampling) row count.
// The base class never computes num_rows_ itself; the derived leaf op fills it in during its
// own initialization after talking to its storage layers. A count of zero therefore means
// that setup has not happened yet, which we report as an error along with a null out-param.
// @param num - out: receives the row count
// @return Status - error when num is null or the count was never set
Status RandomAccessOp::GetNumRowsInDataset(int64_t *num) const {
  if (num != nullptr && num_rows_ != 0) {
    *num = num_rows_;
    return Status::OK();
  }
  RETURN_STATUS_UNEXPECTED("RandomAccessOp has not computed it's num rows yet.");
}
// Constructor.
// @param num_samples - user-requested number of sample ids to generate (0 means produce the
//                      complete set of ids)
// @param samples_per_buffer - max number of ids packed into one DataBuffer
// num_rows_ starts at 0 here; it is populated later by the leaf op (via the handshake /
// SetNumRowsInDataset), since only the leaf knows how much data its storage holds.
Sampler::Sampler(int64_t num_samples, int64_t samples_per_buffer)
    : DatasetOp(0),
      num_rows_(0),
      num_samples_(num_samples),
      samples_per_buffer_(samples_per_buffer),
      col_desc_(nullptr) {}
Status Sampler::HandshakeRandomAccessOp(const RandomAccessOp *op) {
std::shared_ptr<Sampler> child_sampler;
@ -36,10 +53,10 @@ Status Sampler::HandshakeRandomAccessOp(const RandomAccessOp *op) {
}
CHECK_FAIL_RETURN_UNEXPECTED(op != nullptr, "RandomAccessOp is nullptr\n");
RETURN_IF_NOT_OK(op->GetNumSamples(&num_samples_));
// If there's a child sampler, set the row count to be it's sample count
if (HasChildSampler()) {
int64_t child_num_samples = child_sampler->num_samples();
num_rows_ = child_num_samples;
num_rows_ = child_sampler->num_samples_;
} else {
RETURN_IF_NOT_OK(op->GetNumRowsInDataset(&num_rows_));
}
@ -105,7 +122,7 @@ Status Sampler::GetAllIdsThenReset(py::array *data) {
}
// Setter for num_samples_.
// @param num_samples - the number of samples to assign; 0 is valid and means "sample the
//                      entire dataset", so only negative values are rejected.
// @return Status - error code
// NOTE(review): the rendered diff left the stale pre-change assertion (num_samples > 0,
// "negative or 0") above the relaxed one; it would reject the documented special value 0,
// so only the post-change check is kept.
Status Sampler::SetNumSamples(int64_t num_samples) {
  CHECK_FAIL_RETURN_UNEXPECTED(num_samples >= 0, "num_samples is negative");
  num_samples_ = num_samples;
  return Status::OK();
}
@ -116,6 +133,16 @@ Status Sampler::SetNumRowsInDataset(int64_t num_rows) {
return Status::OK();
}
// The sampler is an inlined op with no consumers of its own; the count is delegated to the
// parent op when one exists.
int32_t Sampler::num_consumers() const {
  if (!parent_.empty() && parent_[0] != nullptr) {
    return parent_[0]->num_consumers();
  }
  // No parent to delegate to: report zero rather than fail.
  MS_LOG(WARNING) << "Sampler with no parent. num_consumers is 0.";
  return 0;
}
Status Sampler::AddChild(std::shared_ptr<DatasetOp> child) {
if (child == nullptr) {
return Status::OK();
@ -155,5 +182,14 @@ Status Sampler::GetAssociatedChildId(int64_t *out_associated_id, int64_t id) {
return Status::OK();
}
// The sampler is an inlined op with no producers of its own; the count is delegated to the
// child op when one exists.
int32_t Sampler::num_producers() const {
  if (!child_.empty() && child_[0] != nullptr) {
    return child_[0]->num_producers();
  }
  // No child to delegate to: report zero rather than fail.
  MS_LOG(WARNING) << "Sampler with no child, num_producers is 0.";
  return 0;
}
} // namespace dataset
} // namespace mindspore

@ -33,23 +33,10 @@ namespace dataset {
// must inherit from if those leaf operator wish to support sampling.
class RandomAccessOp {
public:
// Sampler get numRows from StorageOp
// @param int64_t num - return number of rows, normally num of samples
// @return - The error code return
virtual Status GetNumSamples(int64_t *num_samples) const {
// CI complains num_samples not used if the following line is not added
CHECK_FAIL_RETURN_UNEXPECTED(num_samples != nullptr, "num_samples == nullptr");
RETURN_STATUS_UNEXPECTED("function GetNumSamples needs to overridden to support this sampler");
}
// Sampler get number of rows in the dataset!
// Sampler get number of rows in the dataset
// @param int64_t num - return number of rows for this dataset
// @return - The error code return
virtual Status GetNumRowsInDataset(int64_t *num_rows) const {
// CI complains num_rows not used if the following line is not added
CHECK_FAIL_RETURN_UNEXPECTED(num_rows != nullptr, "num_rows == nullptr");
RETURN_STATUS_UNEXPECTED("function GetNumRowsInDataset needs to overridden to support this sampler");
}
Status GetNumRowsInDataset(int64_t *num_rows) const;
// sampler gets label , imageIds from storageOp, this function is unique to PK
// @param std::map<int64_t, std::vector<int64_t>> * map
@ -60,12 +47,20 @@ class RandomAccessOp {
// default destructor
virtual ~RandomAccessOp() = default;
protected:
// The amount of rows in the dataset itself. This is the before-sampling value, the
// total count of rows. A sampler may choose to sample less than this amount.
int64_t num_rows_;
};
class Sampler : public DatasetOp {
public:
// Constructor
// @param int64_t num_samples: the user-requested number of samples ids to generate. A value of 0
// indicates that the sampler should produce the complete set of ids.
// @param int64_t samplesPerBuffer: Num of Sampler Ids to fetch via 1 GetNextBuffer call
explicit Sampler(int64_t samples_per_buffer = std::numeric_limits<int64_t>::max());
explicit Sampler(int64_t num_samples, int64_t samples_per_buffer);
// default destructor
~Sampler() = default;
@ -84,33 +79,36 @@ class Sampler : public DatasetOp {
// @return - The error code return
Status Reset() override = 0;
// setter function for num_rows_
Status SetNumRowsInDataset(int64_t num_rows);
// setter function for num_samples_
Status SetNumSamples(int64_t num_samples);
int64_t num_samples() { return num_samples_; }
// first handshake between StorageOp and Sampler. This func will call getNumRows and getNumSamples
// @param op - StorageOp pointer, pass in so Sampler can call getNumSamples() and get ClassIds()
// first handshake between leaf source op and Sampler. This func will determine the amount of data
// in the dataset that we can sample from.
// @param op - leaf op pointer, pass in so Sampler can ask it about how much data there is
// @return
virtual Status HandshakeRandomAccessOp(const RandomAccessOp *op);
// initialize sampler and perform checks on certain vars
virtual Status InitSampler() { return Status::OK(); }
// Not meant to be called
// setter for num samples
// @param num_samples - the number of samples to assign.
// @return status error code
Status SetNumSamples(int64_t num_samples);
// setter for the number of records in the dataset
// @param num_rows - the number of records
// @return status error code
Status SetNumRowsInDataset(int64_t num_rows);
// Sampler is an inlined op and has no workers. Producers and consumers are computed.
// @return
int32_t num_workers() const final { return 0; }
// Not meant to be called
// Identify num consumers (inlined op)
// @return
int32_t num_consumers() const final { return 0; }
int32_t num_consumers() const final;
// Not meant to be called
// Identify num producers (inlined op)
// @return
int32_t num_producers() const final { return 0; }
int32_t num_producers() const final;
// Not meant to be called!
// @return - The error code return
@ -151,10 +149,11 @@ class Sampler : public DatasetOp {
// output. Otherwise, num_rows_ is the number of rows in the dataset.
int64_t num_rows_;
// Number of ids this sampler will return.
// The user may want to sample less than the full amount of data. num_samples_ reduces the number
// of ids returned, as requested by the user. Derived classes will choose how to sample the smaller
// amount.
int64_t num_samples_;
// The max number of ids a DataBuffer returned by this sampler will contain.
int64_t samples_per_buffer_;
std::unique_ptr<ColDescriptor> col_desc_;
std::unique_ptr<DataBuffer> child_ids_;

@ -20,34 +20,42 @@
namespace mindspore {
namespace dataset {
// Constructor.
// @param num_samples - user-requested number of ids to emit (0 means emit every id available;
//                      see InitSampler, which caps this against num_rows_ - start_index)
// @param start_index - first id of the sequence; current_id_ starts here and id_count_
//                      tracks how many ids have been emitted so far
// @param samples_per_buffer - max number of ids packed into one DataBuffer
// NOTE(review): the rendered diff kept both the pre-change (next_id_-based) and post-change
// constructor definitions; only the post-change one is retained.
SequentialSampler::SequentialSampler(int64_t num_samples, int64_t start_index, int64_t samples_per_buffer)
    : Sampler(num_samples, samples_per_buffer), start_index_(start_index), current_id_(start_index), id_count_(0) {}
Status SequentialSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) {
if (next_id_ > num_samples_) {
RETURN_STATUS_UNEXPECTED("Sequential Sampler Internal Error");
} else if (next_id_ == num_samples_) {
if (id_count_ > num_samples_) {
RETURN_STATUS_UNEXPECTED("SequentialSampler Internal Error");
} else if (id_count_ == num_samples_) {
(*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
} else {
if (HasChildSampler()) {
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&child_ids_));
}
(*out_buffer) = std::make_unique<DataBuffer>(next_id_, DataBuffer::kDeBFlagNone);
(*out_buffer) = std::make_unique<DataBuffer>(current_id_, DataBuffer::kDeBFlagNone);
std::shared_ptr<Tensor> sampleIds;
int64_t lastId = (samples_per_buffer_ + next_id_ > num_samples_) ? num_samples_ : samples_per_buffer_ + next_id_;
RETURN_IF_NOT_OK(CreateSamplerTensor(&sampleIds, lastId - next_id_));
// Compute how many ids are left to pack, and pack this amount into a new buffer. Respect the setting for
// samples per buffer though.
int64_t remaining_ids = num_samples_ - id_count_;
int64_t num_elements = std::min(remaining_ids, samples_per_buffer_);
RETURN_IF_NOT_OK(CreateSamplerTensor(&sampleIds, num_elements));
int64_t *idPtr = reinterpret_cast<int64_t *>(sampleIds->GetMutableBuffer());
while (next_id_ < lastId) {
int64_t sampled_id = next_id_;
for (int64_t i = 0; i < num_elements; i++) {
int64_t sampled_id = current_id_;
if (HasChildSampler()) {
RETURN_IF_NOT_OK(GetAssociatedChildId(&sampled_id, sampled_id));
}
*idPtr = sampled_id;
next_id_++;
current_id_++; // Move the current id to the next one in the sequence
idPtr++;
}
id_count_ += num_elements; // Count the packed ids towards our overall sample count
TensorRow row(1, sampleIds);
(*out_buffer)->set_tensor_table(std::make_unique<TensorQTable>(1, row));
}
@ -55,19 +63,24 @@ Status SequentialSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer)
}
// Performs one-time setup once num_rows_ is known: validates the sequencing range and
// resolves the effective num_samples_ against the rows available from start_index_.
// @return Status - error if the range or resulting sample count is invalid
// NOTE(review): the rendered diff interleaved pre- and post-change bodies here, leaving an
// unclosed `if (HasChildSampler()) {` brace and a duplicated num_samples adjustment; this is
// the reconstructed post-change body.
Status SequentialSampler::InitSampler() {
  CHECK_FAIL_RETURN_UNEXPECTED(start_index_ >= 0, "start_index < 0\n");
  CHECK_FAIL_RETURN_UNEXPECTED(start_index_ < num_rows_, "start_index >= num_rows\n");
  CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ >= 0, "num_samples < 0\n");
  // Adjust the num_samples count based on the range of ids we are sequencing. If num_samples is 0, we sample
  // the entire set. If it's non-zero, we will implicitly cap the amount sampled based on available data.
  int64_t available_row_count = num_rows_ - start_index_;
  if (num_samples_ == 0 || num_samples_ > available_row_count) {
    num_samples_ = available_row_count;
  }
  CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ > 0 && samples_per_buffer_ > 0, "Fail to init Sequential Sampler");
  // Never hand out more ids per buffer than we will produce in total.
  samples_per_buffer_ = samples_per_buffer_ > num_samples_ ? num_samples_ : samples_per_buffer_;
  return Status::OK();
}
Status SequentialSampler::Reset() {
CHECK_FAIL_RETURN_UNEXPECTED(next_id_ == num_samples_, "ERROR Reset() called early/late");
next_id_ = 0;
CHECK_FAIL_RETURN_UNEXPECTED(id_count_ == num_samples_, "ERROR Reset() called early/late");
current_id_ = start_index_;
id_count_ = 0;
if (HasChildSampler()) {
RETURN_IF_NOT_OK(child_[0]->Reset());

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save