!11784 Change samplers binding in Python to SamplerObj
From: @mahdirahmanihanzaki Reviewed-by: Signed-off-by:pull/11784/MERGE
commit
5f0f9da6c6
@ -1,93 +0,0 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "minddata/dataset/api/python/pybind_register.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/sampler/pk_sampler.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/sampler/python_sampler.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/sampler/random_sampler.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/sampler/subset_random_sampler.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/sampler/weighted_random_sampler.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
PYBIND_REGISTER(SamplerRT, 0, ([](const py::module *m) {
|
||||
(void)py::class_<SamplerRT, std::shared_ptr<SamplerRT>>(*m, "Sampler")
|
||||
.def("set_num_rows",
|
||||
[](SamplerRT &self, int64_t rows) { THROW_IF_ERROR(self.SetNumRowsInDataset(rows)); })
|
||||
.def("set_num_samples",
|
||||
[](SamplerRT &self, int64_t samples) { THROW_IF_ERROR(self.SetNumSamples(samples)); })
|
||||
.def("initialize", [](SamplerRT &self) { THROW_IF_ERROR(self.InitSampler()); })
|
||||
.def("get_indices",
|
||||
[](SamplerRT &self) {
|
||||
py::array ret;
|
||||
THROW_IF_ERROR(self.GetAllIdsThenReset(&ret));
|
||||
return ret;
|
||||
})
|
||||
.def("add_child", [](std::shared_ptr<SamplerRT> self, std::shared_ptr<SamplerRT> child) {
|
||||
THROW_IF_ERROR(self->AddChild(child));
|
||||
});
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(DistributedSamplerRT, 1, ([](const py::module *m) {
|
||||
(void)py::class_<DistributedSamplerRT, SamplerRT, std::shared_ptr<DistributedSamplerRT>>(
|
||||
*m, "DistributedSampler")
|
||||
.def(py::init<int64_t, int64_t, int64_t, bool, uint32_t, int64_t>());
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(PKSamplerRT, 1, ([](const py::module *m) {
|
||||
(void)py::class_<PKSamplerRT, SamplerRT, std::shared_ptr<PKSamplerRT>>(*m, "PKSampler")
|
||||
.def(py::init<int64_t, int64_t, bool>());
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(PythonSamplerRT, 1, ([](const py::module *m) {
|
||||
(void)py::class_<PythonSamplerRT, SamplerRT, std::shared_ptr<PythonSamplerRT>>(*m, "PythonSampler")
|
||||
.def(py::init<int64_t, py::object>());
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(RandomSamplerRT, 1, ([](const py::module *m) {
|
||||
(void)py::class_<RandomSamplerRT, SamplerRT, std::shared_ptr<RandomSamplerRT>>(*m, "RandomSampler")
|
||||
.def(py::init<int64_t, bool, bool>());
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(SequentialSamplerRT, 1, ([](const py::module *m) {
|
||||
(void)py::class_<SequentialSamplerRT, SamplerRT, std::shared_ptr<SequentialSamplerRT>>(
|
||||
*m, "SequentialSampler")
|
||||
.def(py::init<int64_t, int64_t>());
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(SubsetRandomSamplerRT, 2, ([](const py::module *m) {
|
||||
(void)py::class_<SubsetRandomSamplerRT, SubsetSamplerRT, std::shared_ptr<SubsetRandomSamplerRT>>(
|
||||
*m, "SubsetRandomSampler")
|
||||
.def(py::init<int64_t, std::vector<int64_t>>());
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(SubsetSamplerRT, 1, ([](const py::module *m) {
|
||||
(void)py::class_<SubsetSamplerRT, SamplerRT, std::shared_ptr<SubsetSamplerRT>>(*m, "SubsetSampler")
|
||||
.def(py::init<int64_t, std::vector<int64_t>>());
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(WeightedRandomSamplerRT, 1, ([](const py::module *m) {
|
||||
(void)py::class_<WeightedRandomSamplerRT, SamplerRT, std::shared_ptr<WeightedRandomSamplerRT>>(
|
||||
*m, "WeightedRandomSampler")
|
||||
.def(py::init<int64_t, std::vector<double>, bool>());
|
||||
}));
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
@ -0,0 +1,127 @@
|
||||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "pybind11/pybind11.h"
|
||||
#include "pybind11/stl.h"
|
||||
#include "pybind11/stl_bind.h"
|
||||
|
||||
#include "minddata/dataset/engine/datasetops/source/sampler/python_sampler.h"
|
||||
#include "minddata/dataset/api/python/pybind_conversion.h"
|
||||
#include "minddata/dataset/api/python/pybind_register.h"
|
||||
#include "minddata/dataset/callback/py_ds_callback.h"
|
||||
#include "minddata/dataset/core/constants.h"
|
||||
#include "minddata/dataset/core/global_context.h"
|
||||
#include "minddata/dataset/include/datasets.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
PYBIND_REGISTER(SamplerObj, 1, ([](const py::module *m) {
|
||||
(void)py::class_<SamplerObj, std::shared_ptr<SamplerObj>>(*m, "SamplerObj", "to create a SamplerObj")
|
||||
.def("add_child", [](std::shared_ptr<SamplerObj> self, std::shared_ptr<SamplerObj> child) {
|
||||
THROW_IF_ERROR(self->AddChildSampler(child));
|
||||
});
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(DistributedSamplerObj, 2, ([](const py::module *m) {
|
||||
(void)py::class_<DistributedSamplerObj, SamplerObj, std::shared_ptr<DistributedSamplerObj>>(
|
||||
*m, "DistributedSamplerObj", "to create a DistributedSamplerObj")
|
||||
.def(py::init([](int64_t num_shards, int64_t shard_id, bool shuffle, int64_t num_samples,
|
||||
uint32_t seed, int64_t offset, bool even_dist) {
|
||||
std::shared_ptr<DistributedSamplerObj> sampler = std::make_shared<DistributedSamplerObj>(
|
||||
num_shards, shard_id, shuffle, num_samples, seed, offset, even_dist);
|
||||
THROW_IF_ERROR(sampler->ValidateParams());
|
||||
return sampler;
|
||||
}));
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(PreBuiltSamplerObj, 2, ([](const py::module *m) {
|
||||
(void)py::class_<PreBuiltSamplerObj, SamplerObj, std::shared_ptr<PreBuiltSamplerObj>>(
|
||||
*m, "PreBuiltSamplerObj", "to create a PreBuiltSamplerObj")
|
||||
.def(py::init([](int64_t num_samples, py::object sampler) {
|
||||
auto sampler_rt = std::make_shared<PythonSamplerRT>(num_samples, sampler);
|
||||
auto sampler_obj = std::make_shared<PreBuiltSamplerObj>(std::move(sampler_rt));
|
||||
THROW_IF_ERROR(sampler_obj->ValidateParams());
|
||||
return sampler_obj;
|
||||
}));
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(PKSamplerObj, 2, ([](const py::module *m) {
|
||||
(void)py::class_<PKSamplerObj, SamplerObj, std::shared_ptr<PKSamplerObj>>(*m, "PKSamplerObj",
|
||||
"to create a PKSamplerObj")
|
||||
.def(py::init([](int64_t num_val, bool shuffle, int64_t num_samples) {
|
||||
std::shared_ptr<PKSamplerObj> sampler =
|
||||
std::make_shared<PKSamplerObj>(num_val, shuffle, num_samples);
|
||||
THROW_IF_ERROR(sampler->ValidateParams());
|
||||
return sampler;
|
||||
}));
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(RandomSamplerObj, 2, ([](const py::module *m) {
|
||||
(void)py::class_<RandomSamplerObj, SamplerObj, std::shared_ptr<RandomSamplerObj>>(
|
||||
*m, "RandomSamplerObj", "to create a RandomSamplerObj")
|
||||
.def(py::init([](bool replacement, int64_t num_samples, bool reshuffle_each_epoch) {
|
||||
std::shared_ptr<RandomSamplerObj> sampler =
|
||||
std::make_shared<RandomSamplerObj>(replacement, num_samples, reshuffle_each_epoch);
|
||||
THROW_IF_ERROR(sampler->ValidateParams());
|
||||
return sampler;
|
||||
}));
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(SequentialSamplerObj, 2, ([](const py::module *m) {
|
||||
(void)py::class_<SequentialSamplerObj, SamplerObj, std::shared_ptr<SequentialSamplerObj>>(
|
||||
*m, "SequentialSamplerObj", "to create a SequentialSamplerObj")
|
||||
.def(py::init([](int64_t start_index, int64_t num_samples) {
|
||||
std::shared_ptr<SequentialSamplerObj> sampler =
|
||||
std::make_shared<SequentialSamplerObj>(start_index, num_samples);
|
||||
THROW_IF_ERROR(sampler->ValidateParams());
|
||||
return sampler;
|
||||
}));
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(SubsetSamplerObj, 2, ([](const py::module *m) {
|
||||
(void)py::class_<SubsetSamplerObj, SamplerObj, std::shared_ptr<SubsetSamplerObj>>(
|
||||
*m, "SubsetSamplerObj", "to create a SubsetSamplerObj")
|
||||
.def(py::init([](std::vector<int64_t> indices, int64_t num_samples) {
|
||||
std::shared_ptr<SubsetSamplerObj> sampler =
|
||||
std::make_shared<SubsetSamplerObj>(indices, num_samples);
|
||||
THROW_IF_ERROR(sampler->ValidateParams());
|
||||
return sampler;
|
||||
}));
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(SubsetRandomSamplerObj, 3, ([](const py::module *m) {
|
||||
(void)py::class_<SubsetRandomSamplerObj, SubsetSamplerObj, std::shared_ptr<SubsetRandomSamplerObj>>(
|
||||
*m, "SubsetRandomSamplerObj", "to create a SubsetRandomSamplerObj")
|
||||
.def(py::init([](std::vector<int64_t> indices, int64_t num_samples) {
|
||||
std::shared_ptr<SubsetRandomSamplerObj> sampler =
|
||||
std::make_shared<SubsetRandomSamplerObj>(indices, num_samples);
|
||||
THROW_IF_ERROR(sampler->ValidateParams());
|
||||
return sampler;
|
||||
}));
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(WeightedRandomSamplerObj, 2, ([](const py::module *m) {
|
||||
(void)py::class_<WeightedRandomSamplerObj, SamplerObj, std::shared_ptr<WeightedRandomSamplerObj>>(
|
||||
*m, "WeightedRandomSamplerObj", "to create a WeightedRandomSamplerObj")
|
||||
.def(py::init([](std::vector<double> weights, int64_t num_samples, bool replacement) {
|
||||
std::shared_ptr<WeightedRandomSamplerObj> sampler =
|
||||
std::make_shared<WeightedRandomSamplerObj>(weights, num_samples, replacement);
|
||||
THROW_IF_ERROR(sampler->ValidateParams());
|
||||
return sampler;
|
||||
}));
|
||||
}));
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
File diff suppressed because it is too large
Load Diff
Loading…
Reference in new issue