!11784 Change samplers binding in Python to SamplerObj

From: @mahdirahmanihanzaki
Reviewed-by: 
Signed-off-by:
pull/11784/MERGE
mindspore-ci-bot, committed by Gitee
commit 5f0f9da6c6

@ -7,11 +7,11 @@ if(ENABLE_PYTHON)
python/bindings/dataset/engine/cache/bindings.cc
python/bindings/dataset/engine/datasetops/bindings.cc
python/bindings/dataset/engine/datasetops/source/bindings.cc
python/bindings/dataset/engine/datasetops/source/sampler/bindings.cc
python/bindings/dataset/engine/gnn/bindings.cc
python/bindings/dataset/include/datasets_bindings.cc
python/bindings/dataset/include/iterator_bindings.cc
python/bindings/dataset/include/execute_binding.cc
python/bindings/dataset/include/sampler_bindings.cc
python/bindings/dataset/include/schema_bindings.cc
python/bindings/dataset/kernels/bindings.cc
python/bindings/dataset/kernels/data/bindings.cc

@ -1,93 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/api/python/pybind_register.h"
#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
#include "minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.h"
#include "minddata/dataset/engine/datasetops/source/sampler/pk_sampler.h"
#include "minddata/dataset/engine/datasetops/source/sampler/python_sampler.h"
#include "minddata/dataset/engine/datasetops/source/sampler/random_sampler.h"
#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h"
#include "minddata/dataset/engine/datasetops/source/sampler/subset_random_sampler.h"
#include "minddata/dataset/engine/datasetops/source/sampler/weighted_random_sampler.h"
namespace mindspore {
namespace dataset {
PYBIND_REGISTER(SamplerRT, 0, ([](const py::module *m) {
(void)py::class_<SamplerRT, std::shared_ptr<SamplerRT>>(*m, "Sampler")
.def("set_num_rows",
[](SamplerRT &self, int64_t rows) { THROW_IF_ERROR(self.SetNumRowsInDataset(rows)); })
.def("set_num_samples",
[](SamplerRT &self, int64_t samples) { THROW_IF_ERROR(self.SetNumSamples(samples)); })
.def("initialize", [](SamplerRT &self) { THROW_IF_ERROR(self.InitSampler()); })
.def("get_indices",
[](SamplerRT &self) {
py::array ret;
THROW_IF_ERROR(self.GetAllIdsThenReset(&ret));
return ret;
})
.def("add_child", [](std::shared_ptr<SamplerRT> self, std::shared_ptr<SamplerRT> child) {
THROW_IF_ERROR(self->AddChild(child));
});
}));
PYBIND_REGISTER(DistributedSamplerRT, 1, ([](const py::module *m) {
(void)py::class_<DistributedSamplerRT, SamplerRT, std::shared_ptr<DistributedSamplerRT>>(
*m, "DistributedSampler")
.def(py::init<int64_t, int64_t, int64_t, bool, uint32_t, int64_t>());
}));
PYBIND_REGISTER(PKSamplerRT, 1, ([](const py::module *m) {
(void)py::class_<PKSamplerRT, SamplerRT, std::shared_ptr<PKSamplerRT>>(*m, "PKSampler")
.def(py::init<int64_t, int64_t, bool>());
}));
PYBIND_REGISTER(PythonSamplerRT, 1, ([](const py::module *m) {
(void)py::class_<PythonSamplerRT, SamplerRT, std::shared_ptr<PythonSamplerRT>>(*m, "PythonSampler")
.def(py::init<int64_t, py::object>());
}));
PYBIND_REGISTER(RandomSamplerRT, 1, ([](const py::module *m) {
(void)py::class_<RandomSamplerRT, SamplerRT, std::shared_ptr<RandomSamplerRT>>(*m, "RandomSampler")
.def(py::init<int64_t, bool, bool>());
}));
PYBIND_REGISTER(SequentialSamplerRT, 1, ([](const py::module *m) {
(void)py::class_<SequentialSamplerRT, SamplerRT, std::shared_ptr<SequentialSamplerRT>>(
*m, "SequentialSampler")
.def(py::init<int64_t, int64_t>());
}));
PYBIND_REGISTER(SubsetRandomSamplerRT, 2, ([](const py::module *m) {
(void)py::class_<SubsetRandomSamplerRT, SubsetSamplerRT, std::shared_ptr<SubsetRandomSamplerRT>>(
*m, "SubsetRandomSampler")
.def(py::init<int64_t, std::vector<int64_t>>());
}));
PYBIND_REGISTER(SubsetSamplerRT, 1, ([](const py::module *m) {
(void)py::class_<SubsetSamplerRT, SamplerRT, std::shared_ptr<SubsetSamplerRT>>(*m, "SubsetSampler")
.def(py::init<int64_t, std::vector<int64_t>>());
}));
PYBIND_REGISTER(WeightedRandomSamplerRT, 1, ([](const py::module *m) {
(void)py::class_<WeightedRandomSamplerRT, SamplerRT, std::shared_ptr<WeightedRandomSamplerRT>>(
*m, "WeightedRandomSampler")
.def(py::init<int64_t, std::vector<double>, bool>());
}));
} // namespace dataset
} // namespace mindspore

@ -0,0 +1,127 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
#include "pybind11/stl_bind.h"
#include "minddata/dataset/engine/datasetops/source/sampler/python_sampler.h"
#include "minddata/dataset/api/python/pybind_conversion.h"
#include "minddata/dataset/api/python/pybind_register.h"
#include "minddata/dataset/callback/py_ds_callback.h"
#include "minddata/dataset/core/constants.h"
#include "minddata/dataset/core/global_context.h"
#include "minddata/dataset/include/datasets.h"
namespace mindspore {
namespace dataset {
PYBIND_REGISTER(SamplerObj, 1, ([](const py::module *m) {
(void)py::class_<SamplerObj, std::shared_ptr<SamplerObj>>(*m, "SamplerObj", "to create a SamplerObj")
.def("add_child", [](std::shared_ptr<SamplerObj> self, std::shared_ptr<SamplerObj> child) {
THROW_IF_ERROR(self->AddChildSampler(child));
});
}));
PYBIND_REGISTER(DistributedSamplerObj, 2, ([](const py::module *m) {
(void)py::class_<DistributedSamplerObj, SamplerObj, std::shared_ptr<DistributedSamplerObj>>(
*m, "DistributedSamplerObj", "to create a DistributedSamplerObj")
.def(py::init([](int64_t num_shards, int64_t shard_id, bool shuffle, int64_t num_samples,
uint32_t seed, int64_t offset, bool even_dist) {
std::shared_ptr<DistributedSamplerObj> sampler = std::make_shared<DistributedSamplerObj>(
num_shards, shard_id, shuffle, num_samples, seed, offset, even_dist);
THROW_IF_ERROR(sampler->ValidateParams());
return sampler;
}));
}));
PYBIND_REGISTER(PreBuiltSamplerObj, 2, ([](const py::module *m) {
(void)py::class_<PreBuiltSamplerObj, SamplerObj, std::shared_ptr<PreBuiltSamplerObj>>(
*m, "PreBuiltSamplerObj", "to create a PreBuiltSamplerObj")
.def(py::init([](int64_t num_samples, py::object sampler) {
auto sampler_rt = std::make_shared<PythonSamplerRT>(num_samples, sampler);
auto sampler_obj = std::make_shared<PreBuiltSamplerObj>(std::move(sampler_rt));
THROW_IF_ERROR(sampler_obj->ValidateParams());
return sampler_obj;
}));
}));
PYBIND_REGISTER(PKSamplerObj, 2, ([](const py::module *m) {
(void)py::class_<PKSamplerObj, SamplerObj, std::shared_ptr<PKSamplerObj>>(*m, "PKSamplerObj",
"to create a PKSamplerObj")
.def(py::init([](int64_t num_val, bool shuffle, int64_t num_samples) {
std::shared_ptr<PKSamplerObj> sampler =
std::make_shared<PKSamplerObj>(num_val, shuffle, num_samples);
THROW_IF_ERROR(sampler->ValidateParams());
return sampler;
}));
}));
PYBIND_REGISTER(RandomSamplerObj, 2, ([](const py::module *m) {
(void)py::class_<RandomSamplerObj, SamplerObj, std::shared_ptr<RandomSamplerObj>>(
*m, "RandomSamplerObj", "to create a RandomSamplerObj")
.def(py::init([](bool replacement, int64_t num_samples, bool reshuffle_each_epoch) {
std::shared_ptr<RandomSamplerObj> sampler =
std::make_shared<RandomSamplerObj>(replacement, num_samples, reshuffle_each_epoch);
THROW_IF_ERROR(sampler->ValidateParams());
return sampler;
}));
}));
PYBIND_REGISTER(SequentialSamplerObj, 2, ([](const py::module *m) {
(void)py::class_<SequentialSamplerObj, SamplerObj, std::shared_ptr<SequentialSamplerObj>>(
*m, "SequentialSamplerObj", "to create a SequentialSamplerObj")
.def(py::init([](int64_t start_index, int64_t num_samples) {
std::shared_ptr<SequentialSamplerObj> sampler =
std::make_shared<SequentialSamplerObj>(start_index, num_samples);
THROW_IF_ERROR(sampler->ValidateParams());
return sampler;
}));
}));
PYBIND_REGISTER(SubsetSamplerObj, 2, ([](const py::module *m) {
(void)py::class_<SubsetSamplerObj, SamplerObj, std::shared_ptr<SubsetSamplerObj>>(
*m, "SubsetSamplerObj", "to create a SubsetSamplerObj")
.def(py::init([](std::vector<int64_t> indices, int64_t num_samples) {
std::shared_ptr<SubsetSamplerObj> sampler =
std::make_shared<SubsetSamplerObj>(indices, num_samples);
THROW_IF_ERROR(sampler->ValidateParams());
return sampler;
}));
}));
PYBIND_REGISTER(SubsetRandomSamplerObj, 3, ([](const py::module *m) {
(void)py::class_<SubsetRandomSamplerObj, SubsetSamplerObj, std::shared_ptr<SubsetRandomSamplerObj>>(
*m, "SubsetRandomSamplerObj", "to create a SubsetRandomSamplerObj")
.def(py::init([](std::vector<int64_t> indices, int64_t num_samples) {
std::shared_ptr<SubsetRandomSamplerObj> sampler =
std::make_shared<SubsetRandomSamplerObj>(indices, num_samples);
THROW_IF_ERROR(sampler->ValidateParams());
return sampler;
}));
}));
PYBIND_REGISTER(WeightedRandomSamplerObj, 2, ([](const py::module *m) {
(void)py::class_<WeightedRandomSamplerObj, SamplerObj, std::shared_ptr<WeightedRandomSamplerObj>>(
*m, "WeightedRandomSamplerObj", "to create a WeightedRandomSamplerObj")
.def(py::init([](std::vector<double> weights, int64_t num_samples, bool replacement) {
std::shared_ptr<WeightedRandomSamplerObj> sampler =
std::make_shared<WeightedRandomSamplerObj>(weights, num_samples, replacement);
THROW_IF_ERROR(sampler->ValidateParams());
return sampler;
}));
}));
} // namespace dataset
} // namespace mindspore
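For orientation, here is a minimal sketch of how the Python layer could drive these new bindings, assuming they are exposed through the mindspore._c_dataengine module (imported as cde below); everything beyond the registered *SamplerObj class names is illustrative and not taken from this patch:

import mindspore._c_dataengine as cde

# Build an IR-level sampler directly from the new binding: a sequential
# sampler that starts at index 0 and draws 64 samples.
seq = cde.SequentialSamplerObj(0, 64)

# Chain a child sampler onto it; add_child maps to SamplerObj::AddChildSampler.
# Arguments: replacement, num_samples, reshuffle_each_epoch.
child = cde.RandomSamplerObj(False, 32, True)
seq.add_child(child)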

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -150,15 +150,13 @@ std::shared_ptr<SamplerObj> toSamplerObj(py::handle py_sampler, bool isMindDatas
std::shared_ptr<SamplerObj> sampler_obj;
if (!isMindDataset) {
// Common Sampler
std::shared_ptr<SamplerRT> sampler;
auto create = py::reinterpret_borrow<py::object>(py_sampler).attr("create");
sampler = create().cast<std::shared_ptr<SamplerRT>>();
sampler_obj = std::make_shared<PreBuiltSamplerObj>(std::move(sampler));
auto parse = py::reinterpret_borrow<py::object>(py_sampler).attr("parse");
sampler_obj = parse().cast<std::shared_ptr<SamplerObj>>();
} else {
// Mindrecord Sampler
std::shared_ptr<mindrecord::ShardOperator> sampler;
auto create = py::reinterpret_borrow<py::object>(py_sampler).attr("create_for_minddataset");
sampler = create().cast<std::shared_ptr<mindrecord::ShardOperator>>();
auto parse = py::reinterpret_borrow<py::object>(py_sampler).attr("parse_for_minddataset");
sampler = parse().cast<std::shared_ptr<mindrecord::ShardOperator>>();
sampler_obj = std::make_shared<PreBuiltSamplerObj>(std::move(sampler));
}
return sampler_obj;
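In other words, each Python sampler is now expected to expose a parse() method returning the C++ SamplerObj IR node (and parse_for_minddataset() returning a mindrecord ShardOperator), in place of the old create()/create_for_minddataset() pair that produced runtime samplers directly. A rough, hypothetical shape of such a class, for illustration only:

import mindspore._c_dataengine as cde

class SequentialSamplerSketch:
    """Illustrative stand-in for mindspore.dataset's SequentialSampler."""

    def __init__(self, start_index=0, num_samples=None):
        self.start_index = start_index
        self.num_samples = num_samples

    def parse(self):
        # Assumes num_samples=None maps to 0 ("sample everything") on the C++ side.
        num_samples = 0 if self.num_samples is None else self.num_samples
        return cde.SequentialSamplerObj(self.start_index, num_samples)

    def parse_for_minddataset(self):
        # MindRecord pipelines still expect a ShardOperator-based sampler here;
        # omitted in this sketch.
        raise NotImplementedError("sketch only")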

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -211,6 +211,27 @@ std::shared_ptr<mindrecord::ShardOperator> DistributedSamplerObj::BuildForMindDa
}
#endif
Status DistributedSamplerObj::to_json(nlohmann::json *out_json) {
nlohmann::json args;
args["sampler_name"] = "DistributedSampler";
args["num_shards"] = num_shards_;
args["shard_id"] = shard_id_;
args["shuffle"] = shuffle_;
args["num_samples"] = num_samples_;
args["offset"] = offset_;
if (!children_.empty()) {
std::vector<nlohmann::json> children_args;
for (auto child : children_) {
nlohmann::json child_arg;
RETURN_IF_NOT_OK(child->to_json(&child_arg));
children_args.push_back(child_arg);
}
args["child_sampler"] = children_args;
}
*out_json = args;
return Status::OK();
}
// PKSampler
PKSamplerObj::PKSamplerObj(int64_t num_val, bool shuffle, int64_t num_samples)
: num_val_(num_val), shuffle_(shuffle), num_samples_(num_samples) {}
@ -226,6 +247,25 @@ Status PKSamplerObj::ValidateParams() {
return Status::OK();
}
Status PKSamplerObj::to_json(nlohmann::json *out_json) {
nlohmann::json args;
args["sampler_name"] = "PKSampler";
args["num_val"] = num_val_;
args["shuffle"] = shuffle_;
args["num_samples"] = num_samples_;
if (!children_.empty()) {
std::vector<nlohmann::json> children_args;
for (auto child : children_) {
nlohmann::json child_arg;
RETURN_IF_NOT_OK(child->to_json(&child_arg));
children_args.push_back(child_arg);
}
args["child_sampler"] = children_args;
}
*out_json = args;
return Status::OK();
}
std::shared_ptr<SamplerRT> PKSamplerObj::SamplerBuild() {
// runtime sampler object
auto sampler = std::make_shared<dataset::PKSamplerRT>(num_samples_, num_val_, shuffle_);
@ -233,6 +273,21 @@ std::shared_ptr<SamplerRT> PKSamplerObj::SamplerBuild() {
return sampler;
}
#ifndef ENABLE_ANDROID
std::shared_ptr<mindrecord::ShardOperator> PKSamplerObj::BuildForMindDataset() {
// runtime mindrecord sampler object
std::shared_ptr<mindrecord::ShardOperator> mind_sampler;
if (shuffle_ == true) {
mind_sampler = std::make_shared<mindrecord::ShardPkSample>("label", num_val_, std::numeric_limits<int64_t>::max(),
GetSeed(), num_samples_);
} else {
mind_sampler = std::make_shared<mindrecord::ShardPkSample>("label", num_val_, num_samples_);
}
return mind_sampler;
}
#endif
// PreBuiltOperation
PreBuiltSamplerObj::PreBuiltSamplerObj(std::shared_ptr<SamplerRT> sampler) : sp_(std::move(sampler)) {}
@ -274,24 +329,9 @@ Status PreBuiltSamplerObj::to_json(nlohmann::json *out_json) {
return Status::OK();
}
#ifndef ENABLE_ANDROID
std::shared_ptr<mindrecord::ShardOperator> PKSamplerObj::BuildForMindDataset() {
// runtime mindrecord sampler object
std::shared_ptr<mindrecord::ShardOperator> mind_sampler;
if (shuffle_ == true) {
mind_sampler = std::make_shared<mindrecord::ShardPkSample>("label", num_val_, std::numeric_limits<int64_t>::max(),
GetSeed(), num_samples_);
} else {
mind_sampler = std::make_shared<mindrecord::ShardPkSample>("label", num_val_, num_samples_);
}
return mind_sampler;
}
#endif
// RandomSampler
RandomSamplerObj::RandomSamplerObj(bool replacement, int64_t num_samples)
: replacement_(replacement), num_samples_(num_samples) {}
RandomSamplerObj::RandomSamplerObj(bool replacement, int64_t num_samples, bool reshuffle_each_epoch)
: replacement_(replacement), num_samples_(num_samples), reshuffle_each_epoch_(reshuffle_each_epoch) {}
Status RandomSamplerObj::ValidateParams() {
if (num_samples_ < 0) {
@ -300,10 +340,28 @@ Status RandomSamplerObj::ValidateParams() {
return Status::OK();
}
Status RandomSamplerObj::to_json(nlohmann::json *out_json) {
nlohmann::json args;
args["sampler_name"] = "RandomSampler";
args["replacement"] = replacement_;
args["num_samples"] = num_samples_;
args["reshuffle_each_epoch"] = reshuffle_each_epoch_;
if (!children_.empty()) {
std::vector<nlohmann::json> children_args;
for (auto child : children_) {
nlohmann::json child_arg;
RETURN_IF_NOT_OK(child->to_json(&child_arg));
children_args.push_back(child_arg);
}
args["child_sampler"] = children_args;
}
*out_json = args;
return Status::OK();
}
std::shared_ptr<SamplerRT> RandomSamplerObj::SamplerBuild() {
// runtime sampler object
bool reshuffle_each_epoch = true;
auto sampler = std::make_shared<dataset::RandomSamplerRT>(num_samples_, replacement_, reshuffle_each_epoch);
auto sampler = std::make_shared<dataset::RandomSamplerRT>(num_samples_, replacement_, reshuffle_each_epoch_);
BuildChildren(sampler);
return sampler;
}
@ -311,7 +369,6 @@ std::shared_ptr<SamplerRT> RandomSamplerObj::SamplerBuild() {
#ifndef ENABLE_ANDROID
std::shared_ptr<mindrecord::ShardOperator> RandomSamplerObj::BuildForMindDataset() {
// runtime mindrecord sampler object
bool reshuffle_each_epoch_ = true;
auto mind_sampler =
std::make_shared<mindrecord::ShardShuffle>(GetSeed(), num_samples_, replacement_, reshuffle_each_epoch_);
@ -335,6 +392,24 @@ Status SequentialSamplerObj::ValidateParams() {
return Status::OK();
}
Status SequentialSamplerObj::to_json(nlohmann::json *out_json) {
nlohmann::json args;
args["sampler_name"] = "SequentialSampler";
args["start_index"] = start_index_;
args["num_samples"] = num_samples_;
if (!children_.empty()) {
std::vector<nlohmann::json> children_args;
for (auto child : children_) {
nlohmann::json child_arg;
RETURN_IF_NOT_OK(child->to_json(&child_arg));
children_args.push_back(child_arg);
}
args["child_sampler"] = children_args;
}
*out_json = args;
return Status::OK();
}
std::shared_ptr<SamplerRT> SequentialSamplerObj::SamplerBuild() {
// runtime sampler object
auto sampler = std::make_shared<dataset::SequentialSamplerRT>(num_samples_, start_index_);
@ -378,6 +453,23 @@ std::shared_ptr<mindrecord::ShardOperator> SubsetSamplerObj::BuildForMindDataset
return mind_sampler;
}
#endif
Status SubsetSamplerObj::to_json(nlohmann::json *out_json) {
nlohmann::json args;
args["sampler_name"] = "SubsetSampler";
args["indices"] = indices_;
args["num_samples"] = num_samples_;
if (!children_.empty()) {
std::vector<nlohmann::json> children_args;
for (auto child : children_) {
nlohmann::json child_arg;
RETURN_IF_NOT_OK(child->to_json(&child_arg));
children_args.push_back(child_arg);
}
args["child_sampler"] = children_args;
}
*out_json = args;
return Status::OK();
}
// SubsetRandomSampler
SubsetRandomSamplerObj::SubsetRandomSamplerObj(std::vector<int64_t> indices, int64_t num_samples)
@ -399,6 +491,24 @@ std::shared_ptr<mindrecord::ShardOperator> SubsetRandomSamplerObj::BuildForMindD
}
#endif
Status SubsetRandomSamplerObj::to_json(nlohmann::json *out_json) {
nlohmann::json args;
args["sampler_name"] = "SubsetRandomSampler";
args["indices"] = indices_;
args["num_samples"] = num_samples_;
if (!children_.empty()) {
std::vector<nlohmann::json> children_args;
for (auto child : children_) {
nlohmann::json child_arg;
RETURN_IF_NOT_OK(child->to_json(&child_arg));
children_args.push_back(child_arg);
}
args["child_sampler"] = children_args;
}
*out_json = args;
return Status::OK();
}
// WeightedRandomSampler
WeightedRandomSamplerObj::WeightedRandomSamplerObj(std::vector<double> weights, int64_t num_samples, bool replacement)
: weights_(std::move(weights)), num_samples_(num_samples), replacement_(replacement) {}
@ -426,6 +536,25 @@ Status WeightedRandomSamplerObj::ValidateParams() {
return Status::OK();
}
Status WeightedRandomSamplerObj::to_json(nlohmann::json *out_json) {
nlohmann::json args;
args["sampler_name"] = "WeightedRandomSampler";
args["weights"] = weights_;
args["num_samples"] = num_samples_;
args["replacement"] = replacement_;
if (!children_.empty()) {
std::vector<nlohmann::json> children_args;
for (auto child : children_) {
nlohmann::json child_arg;
RETURN_IF_NOT_OK(child->to_json(&child_arg));
children_args.push_back(child_arg);
}
args["child_sampler"] = children_args;
}
*out_json = args;
return Status::OK();
}
std::shared_ptr<SamplerRT> WeightedRandomSamplerObj::SamplerBuild() {
auto sampler = std::make_shared<dataset::WeightedRandomSamplerRT>(num_samples_, weights_, replacement_);
BuildChildren(sampler);
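To make the effect of the new to_json() overrides concrete, here is a small self-checking sketch of the serialized form a DistributedSampler with one child sampler might produce; the field names follow the code above, while the values and the child layout are made up for illustration:

import json

example = json.loads('''
{
  "sampler_name": "DistributedSampler",
  "num_shards": 4,
  "shard_id": 0,
  "shuffle": false,
  "num_samples": 0,
  "offset": -1,
  "child_sampler": [
    {"sampler_name": "SequentialSampler", "start_index": 0, "num_samples": 0}
  ]
}
''')

# Sanity checks on the assumed structure.
assert example["sampler_name"] == "DistributedSampler"
assert example["child_sampler"][0]["sampler_name"] == "SequentialSampler"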

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -66,6 +66,8 @@ class SamplerObj {
virtual Status to_json(nlohmann::json *out_json) { return Status::OK(); }
std::vector<std::shared_ptr<SamplerObj>> GetChild() { return children_; }
#ifndef ENABLE_ANDROID
/// \brief Virtual function to convert a SamplerObj class into a runtime mindrecord sampler object,
/// only override by SubsetRandomSampler, PkSampler, RandomSampler, SequentialSampler, DistributedSampler
@ -175,6 +177,11 @@ class DistributedSamplerObj : public SamplerObj {
std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
#endif
/// \brief Get the arguments of node
/// \param[out] out_json JSON string of all attributes
/// \return Status of the function
Status to_json(nlohmann::json *out_json) override;
Status ValidateParams() override;
/// \brief Function to get the shard id of sampler
@ -211,6 +218,11 @@ class PKSamplerObj : public SamplerObj {
std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
#endif
/// \brief Get the arguments of node
/// \param[out] out_json JSON string of all attributes
/// \return Status of the function
Status to_json(nlohmann::json *out_json) override;
Status ValidateParams() override;
private:
@ -249,14 +261,14 @@ class PreBuiltSamplerObj : public SamplerObj {
class RandomSamplerObj : public SamplerObj {
public:
RandomSamplerObj(bool replacement, int64_t num_samples);
RandomSamplerObj(bool replacement, int64_t num_samples, bool reshuffle_each_epoch = true);
~RandomSamplerObj() = default;
std::shared_ptr<SamplerRT> SamplerBuild() override;
std::shared_ptr<SamplerObj> SamplerCopy() override {
auto sampler = std::make_shared<RandomSamplerObj>(replacement_, num_samples_);
auto sampler = std::make_shared<RandomSamplerObj>(replacement_, num_samples_, reshuffle_each_epoch_);
for (auto child : children_) {
sampler->AddChildSampler(child);
}
@ -267,11 +279,17 @@ class RandomSamplerObj : public SamplerObj {
std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
#endif
/// \brief Get the arguments of node
/// \param[out] out_json JSON string of all attributes
/// \return Status of the function
Status to_json(nlohmann::json *out_json) override;
Status ValidateParams() override;
private:
bool replacement_;
int64_t num_samples_;
bool reshuffle_each_epoch_;
};
class SequentialSamplerObj : public SamplerObj {
@ -294,6 +312,11 @@ class SequentialSamplerObj : public SamplerObj {
std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
#endif
/// \brief Get the arguments of node
/// \param[out] out_json JSON string of all attributes
/// \return Status of the function
Status to_json(nlohmann::json *out_json) override;
Status ValidateParams() override;
private:
@ -321,6 +344,11 @@ class SubsetSamplerObj : public SamplerObj {
std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
#endif
/// \brief Get the arguments of node
/// \param[out] out_json JSON string of all attributes
/// \return Status of the function
Status to_json(nlohmann::json *out_json) override;
Status ValidateParams() override;
protected:
@ -334,6 +362,8 @@ class SubsetRandomSamplerObj : public SubsetSamplerObj {
~SubsetRandomSamplerObj() = default;
Status to_json(nlohmann::json *out_json) override;
std::shared_ptr<SamplerRT> SamplerBuild() override;
std::shared_ptr<SamplerObj> SamplerCopy() override {
@ -367,6 +397,11 @@ class WeightedRandomSamplerObj : public SamplerObj {
return sampler;
}
/// \brief Get the arguments of node
/// \param[out] out_json JSON string of all attributes
/// \return Status of the function
Status to_json(nlohmann::json *out_json) override;
Status ValidateParams() override;
private:

File diff suppressed because it is too large.

@ -1,4 +1,4 @@
# Copyright 2019 Huawei Technologies Co., Ltd
# Copyright 2019-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -401,20 +401,23 @@ def test_weighted_random_sampler_exception():
weights = (0.9, 0.8, 1.1)
ds.WeightedRandomSampler(weights)
error_msg_3 = "weights size should not be 0"
with pytest.raises(ValueError, match=error_msg_3):
error_msg_3 = "WeightedRandomSampler: weights vector must not be empty"
with pytest.raises(RuntimeError, match=error_msg_3):
weights = []
ds.WeightedRandomSampler(weights)
sampler = ds.WeightedRandomSampler(weights)
sampler.parse()
error_msg_4 = "weights should not contain negative numbers"
with pytest.raises(ValueError, match=error_msg_4):
error_msg_4 = "WeightedRandomSampler: weights vector must not contain negative number, got: "
with pytest.raises(RuntimeError, match=error_msg_4):
weights = [1.0, 0.1, 0.02, 0.3, -0.4]
ds.WeightedRandomSampler(weights)
sampler = ds.WeightedRandomSampler(weights)
sampler.parse()
error_msg_5 = "elements of weights should not be all zero"
with pytest.raises(ValueError, match=error_msg_5):
error_msg_5 = "WeightedRandomSampler: elements of weights vector must not be all zero"
with pytest.raises(RuntimeError, match=error_msg_5):
weights = [0, 0, 0, 0, 0]
ds.WeightedRandomSampler(weights)
sampler = ds.WeightedRandomSampler(weights)
sampler.parse()
def test_chained_sampler_01():

@ -1,5 +1,5 @@
#!/usr/bin/env python
# Copyright 2019 Huawei Technologies Co., Ltd
# Copyright 2019-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -273,14 +273,14 @@ def test_cv_minddataset_partition_num_samples_equals_0():
for partition_id in range(num_shards):
data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers,
num_shards=num_shards,
shard_id=partition_id, num_samples=0)
shard_id=partition_id, num_samples=-1)
num_iter = 0
for _ in data_set.create_dict_iterator(num_epochs=1):
num_iter += 1
with pytest.raises(Exception) as error_info:
with pytest.raises(ValueError) as error_info:
partitions(5)
try:
assert 'num_samples should be a positive integer value, but got num_samples: 0.' in str(error_info.value)
assert 'Input num_samples is not within the required interval of (0 to 2147483647).' in str(error_info.value)
except Exception as error:
os.remove(CV_FILE_NAME)
os.remove("{}.db".format(CV_FILE_NAME))

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -91,23 +91,9 @@ def test_random_sampler_multi_iter(print_res=False):
def test_sampler_py_api():
sampler = ds.SequentialSampler().create()
sampler.set_num_rows(128)
sampler.set_num_samples(64)
sampler.initialize()
sampler.get_indices()
sampler = ds.RandomSampler().create()
sampler.set_num_rows(128)
sampler.set_num_samples(64)
sampler.initialize()
sampler.get_indices()
sampler = ds.DistributedSampler(8, 4).create()
sampler.set_num_rows(128)
sampler.set_num_samples(64)
sampler.initialize()
sampler.get_indices()
sampler = ds.SequentialSampler().parse()
sampler1 = ds.RandomSampler().parse()
sampler1.add_child(sampler)
def test_python_sampler():
@ -158,12 +144,6 @@ def test_python_sampler():
assert test_config(6, Sp2(2)) == [0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 0, 0]
test_generator()
sp1 = Sp1().create()
sp1.set_num_rows(5)
sp1.set_num_samples(5)
sp1.initialize()
assert list(sp1.get_indices()) == [0, 1, 2, 3, 4]
def test_sequential_sampler2():
manifest_file = "../data/dataset/testManifestData/test5trainimgs.json"
@ -229,8 +209,8 @@ def test_subset_sampler():
test_config([0, 9, 0, 500], exception_msg="Sample ID (500) is out of bound, expected range [0, 9]")
test_config([0, 9, -6, 2], exception_msg="Sample ID (-6) is out of bound, expected range [0, 9]")
# test_config([], exception_msg="Indices list is empty") # temporary until we check with MindDataset
test_config([0, 9, 3, 2], num_samples=0,
exception_msg="num_samples should be a positive integer value, but got num_samples: 0.")
test_config([0, 9, 3, 2], num_samples=-1,
exception_msg="SubsetRandomSampler: invalid num_samples: -1")
def test_sampler_chain():
@ -280,9 +260,9 @@ def test_add_sampler_invalid_input():
def test_distributed_sampler_invalid_offset():
with pytest.raises(ValueError) as info:
sampler = ds.DistributedSampler(num_shards=4, shard_id=0, shuffle=False, num_samples=None, offset=5)
assert "offset should be no more than num_shards" in str(info.value)
with pytest.raises(RuntimeError) as info:
sampler = ds.DistributedSampler(num_shards=4, shard_id=0, shuffle=False, num_samples=None, offset=5).parse()
assert "DistributedSampler: invalid offset: 5, which should be no more than num_shards: 4" in str(info.value)
if __name__ == '__main__':

@ -1,4 +1,4 @@
# Copyright 2019 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -377,7 +377,7 @@ def test_serdes_exception():
def util_check_serialize_deserialize_file(data_orig, filename, remove_json_files):
"""
Utility function for testing serdes files. It is to check if a json file is indeed created with correct name
after serializing and if it remains the same after repeatly saving and loading.
after serializing and if it remains the same after repeatedly saving and loading.
:param data_orig: original data pipeline to be serialized
:param filename: filename to be saved as json format
:param remove_json_files: whether to remove the json file after testing
