|
|
|
@ -49,7 +49,6 @@
|
|
|
|
|
#include "dataset/engine/datasetops/source/sampler/pk_sampler.h"
|
|
|
|
|
#include "dataset/engine/datasetops/source/sampler/random_sampler.h"
|
|
|
|
|
#include "dataset/engine/datasetops/source/sampler/sequential_sampler.h"
|
|
|
|
|
#include "dataset/engine/datasetops/source/sampler/subset_sampler.h"
|
|
|
|
|
#include "dataset/engine/datasetops/source/sampler/subset_random_sampler.h"
|
|
|
|
|
#include "dataset/engine/datasetops/source/sampler/weighted_random_sampler.h"
|
|
|
|
|
#include "dataset/engine/datasetops/source/sampler/python_sampler.h"
|
|
|
|
@ -143,17 +142,16 @@ void bindDatasetOps(py::module *m) {
|
|
|
|
|
});
|
|
|
|
|
|
|
|
|
|
(void)py::class_<CifarOp, DatasetOp, std::shared_ptr<CifarOp>>(*m, "CifarOp")
|
|
|
|
|
.def_static("get_num_rows", [](const std::string &dir, int64_t numSamples, bool isCifar10) {
|
|
|
|
|
.def_static("get_num_rows", [](const std::string &dir, bool isCifar10) {
|
|
|
|
|
int64_t count = 0;
|
|
|
|
|
THROW_IF_ERROR(CifarOp::CountTotalRows(dir, numSamples, isCifar10, &count));
|
|
|
|
|
THROW_IF_ERROR(CifarOp::CountTotalRows(dir, isCifar10, &count));
|
|
|
|
|
return count;
|
|
|
|
|
});
|
|
|
|
|
|
|
|
|
|
(void)py::class_<ImageFolderOp, DatasetOp, std::shared_ptr<ImageFolderOp>>(*m, "ImageFolderOp")
|
|
|
|
|
.def_static("get_num_rows_and_classes", [](const std::string &path, int64_t numSamples) {
|
|
|
|
|
.def_static("get_num_rows_and_classes", [](const std::string &path) {
|
|
|
|
|
int64_t count = 0, num_classes = 0;
|
|
|
|
|
THROW_IF_ERROR(
|
|
|
|
|
ImageFolderOp::CountRowsAndClasses(path, numSamples, std::set<std::string>{}, &count, &num_classes));
|
|
|
|
|
THROW_IF_ERROR(ImageFolderOp::CountRowsAndClasses(path, std::set<std::string>{}, &count, &num_classes));
|
|
|
|
|
return py::make_tuple(count, num_classes);
|
|
|
|
|
});
|
|
|
|
|
|
|
|
|
@ -172,22 +170,21 @@ void bindDatasetOps(py::module *m) {
|
|
|
|
|
|
|
|
|
|
(void)py::class_<ManifestOp, DatasetOp, std::shared_ptr<ManifestOp>>(*m, "ManifestOp")
|
|
|
|
|
.def_static("get_num_rows_and_classes",
|
|
|
|
|
[](const std::string &file, int64_t numSamples, const py::dict &dict, const std::string &usage) {
|
|
|
|
|
[](const std::string &file, const py::dict &dict, const std::string &usage) {
|
|
|
|
|
int64_t count = 0, num_classes = 0;
|
|
|
|
|
THROW_IF_ERROR(ManifestOp::CountTotalRows(file, numSamples, dict, usage, &count, &num_classes));
|
|
|
|
|
THROW_IF_ERROR(ManifestOp::CountTotalRows(file, dict, usage, &count, &num_classes));
|
|
|
|
|
return py::make_tuple(count, num_classes);
|
|
|
|
|
})
|
|
|
|
|
.def_static("get_class_indexing",
|
|
|
|
|
[](const std::string &file, int64_t numSamples, const py::dict &dict, const std::string &usage) {
|
|
|
|
|
std::map<std::string, int32_t> output_class_indexing;
|
|
|
|
|
THROW_IF_ERROR(ManifestOp::GetClassIndexing(file, numSamples, dict, usage, &output_class_indexing));
|
|
|
|
|
return output_class_indexing;
|
|
|
|
|
});
|
|
|
|
|
.def_static("get_class_indexing", [](const std::string &file, const py::dict &dict, const std::string &usage) {
|
|
|
|
|
std::map<std::string, int32_t> output_class_indexing;
|
|
|
|
|
THROW_IF_ERROR(ManifestOp::GetClassIndexing(file, dict, usage, &output_class_indexing));
|
|
|
|
|
return output_class_indexing;
|
|
|
|
|
});
|
|
|
|
|
|
|
|
|
|
(void)py::class_<MnistOp, DatasetOp, std::shared_ptr<MnistOp>>(*m, "MnistOp")
|
|
|
|
|
.def_static("get_num_rows", [](const std::string &dir, int64_t numSamples) {
|
|
|
|
|
.def_static("get_num_rows", [](const std::string &dir) {
|
|
|
|
|
int64_t count = 0;
|
|
|
|
|
THROW_IF_ERROR(MnistOp::CountTotalRows(dir, numSamples, &count));
|
|
|
|
|
THROW_IF_ERROR(MnistOp::CountTotalRows(dir, &count));
|
|
|
|
|
return count;
|
|
|
|
|
});
|
|
|
|
|
|
|
|
|
@ -206,13 +203,13 @@ void bindDatasetOps(py::module *m) {
|
|
|
|
|
[](const std::string &dir, const std::string &task_type, const std::string &task_mode,
|
|
|
|
|
const py::dict &dict, int64_t numSamples) {
|
|
|
|
|
int64_t count = 0;
|
|
|
|
|
THROW_IF_ERROR(VOCOp::CountTotalRows(dir, task_type, task_mode, dict, numSamples, &count));
|
|
|
|
|
THROW_IF_ERROR(VOCOp::CountTotalRows(dir, task_type, task_mode, dict, &count));
|
|
|
|
|
return count;
|
|
|
|
|
})
|
|
|
|
|
.def_static("get_class_indexing", [](const std::string &dir, const std::string &task_type,
|
|
|
|
|
const std::string &task_mode, const py::dict &dict, int64_t numSamples) {
|
|
|
|
|
const std::string &task_mode, const py::dict &dict) {
|
|
|
|
|
std::map<std::string, int32_t> output_class_indexing;
|
|
|
|
|
THROW_IF_ERROR(VOCOp::GetClassIndexing(dir, task_type, task_mode, dict, numSamples, &output_class_indexing));
|
|
|
|
|
THROW_IF_ERROR(VOCOp::GetClassIndexing(dir, task_type, task_mode, dict, &output_class_indexing));
|
|
|
|
|
return output_class_indexing;
|
|
|
|
|
});
|
|
|
|
|
}
|
|
|
|
@ -452,25 +449,19 @@ void bindSamplerOps(py::module *m) {
|
|
|
|
|
(void)py::class_<mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardOperator>>(*m, "ShardOperator");
|
|
|
|
|
|
|
|
|
|
(void)py::class_<DistributedSampler, Sampler, std::shared_ptr<DistributedSampler>>(*m, "DistributedSampler")
|
|
|
|
|
.def(py::init<int64_t, int64_t, bool, uint32_t>(), py::arg("numDev"), py::arg("devId"), py::arg("shuffle"),
|
|
|
|
|
py::arg("seed"));
|
|
|
|
|
.def(py::init<int64_t, int64_t, int64_t, bool, uint32_t>());
|
|
|
|
|
|
|
|
|
|
(void)py::class_<PKSampler, Sampler, std::shared_ptr<PKSampler>>(*m, "PKSampler")
|
|
|
|
|
.def(py::init<int64_t, bool>(), py::arg("kVal"), py::arg("shuffle"));
|
|
|
|
|
.def(py::init<int64_t, int64_t, bool>());
|
|
|
|
|
|
|
|
|
|
(void)py::class_<RandomSampler, Sampler, std::shared_ptr<RandomSampler>>(*m, "RandomSampler")
|
|
|
|
|
.def(py::init<bool, bool, int64_t>(), py::arg("replacement"), py::arg("reshuffle_each_epoch"),
|
|
|
|
|
py::arg("num_samples"))
|
|
|
|
|
.def(py::init<bool, bool>(), py::arg("replacement"), py::arg("reshuffle_each_epoch"));
|
|
|
|
|
.def(py::init<int64_t, bool, bool>());
|
|
|
|
|
|
|
|
|
|
(void)py::class_<SequentialSampler, Sampler, std::shared_ptr<SequentialSampler>>(*m, "SequentialSampler")
|
|
|
|
|
.def(py::init<>());
|
|
|
|
|
|
|
|
|
|
(void)py::class_<SubsetSampler, Sampler, std::shared_ptr<SubsetSampler>>(*m, "SubsetSampler")
|
|
|
|
|
.def(py::init<int64_t, int64_t>(), py::arg("start_index"), py::arg("subset_size"));
|
|
|
|
|
.def(py::init<int64_t, int64_t>());
|
|
|
|
|
|
|
|
|
|
(void)py::class_<SubsetRandomSampler, Sampler, std::shared_ptr<SubsetRandomSampler>>(*m, "SubsetRandomSampler")
|
|
|
|
|
.def(py::init<std::vector<int64_t>>(), py::arg("indices"));
|
|
|
|
|
.def(py::init<int64_t, std::vector<int64_t>>());
|
|
|
|
|
|
|
|
|
|
(void)py::class_<mindrecord::ShardSample, mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardSample>>(
|
|
|
|
|
*m, "MindrecordSubsetRandomSampler")
|
|
|
|
@ -487,11 +478,10 @@ void bindSamplerOps(py::module *m) {
|
|
|
|
|
}));
|
|
|
|
|
|
|
|
|
|
(void)py::class_<WeightedRandomSampler, Sampler, std::shared_ptr<WeightedRandomSampler>>(*m, "WeightedRandomSampler")
|
|
|
|
|
.def(py::init<std::vector<double>, int64_t, bool>(), py::arg("weights"), py::arg("numSamples"),
|
|
|
|
|
py::arg("replacement"));
|
|
|
|
|
.def(py::init<int64_t, std::vector<double>, bool>());
|
|
|
|
|
|
|
|
|
|
(void)py::class_<PythonSampler, Sampler, std::shared_ptr<PythonSampler>>(*m, "PythonSampler")
|
|
|
|
|
.def(py::init<py::object>(), py::arg("pySampler"));
|
|
|
|
|
.def(py::init<int64_t, py::object>());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void bindInfoObjects(py::module *m) {
|
|
|
|
|