!1281 Implementation of SplitOp

Merge pull request !1281 from Peilin/splitOp
pull/1281/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit 2e3d55ed87

@ -364,6 +364,18 @@ Status DEPipeline::ParseShuffleOp(const py::dict &args, std::shared_ptr<DatasetO
std::string err_msg = "Error: Shuffle buffer size is missing"; std::string err_msg = "Error: Shuffle buffer size is missing";
RETURN_STATUS_UNEXPECTED(err_msg); RETURN_STATUS_UNEXPECTED(err_msg);
} }
// Optional arguments
for (auto arg : args) {
std::string key = py::str(arg.first);
py::handle value = arg.second;
if (!value.is_none()) {
if (key == "reshuffle_each_epoch") {
(void)builder->SetReshuffleEachEpoch(ToBool(args["reshuffle_each_epoch"]));
}
}
}
std::shared_ptr<ShuffleOp> op; std::shared_ptr<ShuffleOp> op;
RETURN_IF_NOT_OK(builder->Build(&op)); RETURN_IF_NOT_OK(builder->Build(&op));
*ptr = op; *ptr = op;

@ -51,6 +51,7 @@
#include "dataset/engine/datasetops/source/sampler/pk_sampler.h" #include "dataset/engine/datasetops/source/sampler/pk_sampler.h"
#include "dataset/engine/datasetops/source/sampler/random_sampler.h" #include "dataset/engine/datasetops/source/sampler/random_sampler.h"
#include "dataset/engine/datasetops/source/sampler/sequential_sampler.h" #include "dataset/engine/datasetops/source/sampler/sequential_sampler.h"
#include "dataset/engine/datasetops/source/sampler/subset_sampler.h"
#include "dataset/engine/datasetops/source/sampler/subset_random_sampler.h" #include "dataset/engine/datasetops/source/sampler/subset_random_sampler.h"
#include "dataset/engine/datasetops/source/sampler/weighted_random_sampler.h" #include "dataset/engine/datasetops/source/sampler/weighted_random_sampler.h"
#include "dataset/engine/datasetops/source/sampler/python_sampler.h" #include "dataset/engine/datasetops/source/sampler/python_sampler.h"
@ -425,11 +426,14 @@ void bindSamplerOps(py::module *m) {
.def("set_num_rows", [](Sampler &self, int64_t rows) { THROW_IF_ERROR(self.SetNumRowsInDataset(rows)); }) .def("set_num_rows", [](Sampler &self, int64_t rows) { THROW_IF_ERROR(self.SetNumRowsInDataset(rows)); })
.def("set_num_samples", [](Sampler &self, int64_t samples) { THROW_IF_ERROR(self.SetNumSamples(samples)); }) .def("set_num_samples", [](Sampler &self, int64_t samples) { THROW_IF_ERROR(self.SetNumSamples(samples)); })
.def("initialize", [](Sampler &self) { THROW_IF_ERROR(self.InitSampler()); }) .def("initialize", [](Sampler &self) { THROW_IF_ERROR(self.InitSampler()); })
.def("get_indices", [](Sampler &self) { .def("get_indices",
[](Sampler &self) {
py::array ret; py::array ret;
THROW_IF_ERROR(self.GetAllIdsThenReset(&ret)); THROW_IF_ERROR(self.GetAllIdsThenReset(&ret));
return ret; return ret;
}); })
.def("add_child",
[](std::shared_ptr<Sampler> self, std::shared_ptr<Sampler> child) { THROW_IF_ERROR(self->AddChild(child)); });
(void)py::class_<mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardOperator>>(*m, "ShardOperator"); (void)py::class_<mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardOperator>>(*m, "ShardOperator");
@ -441,12 +445,16 @@ void bindSamplerOps(py::module *m) {
.def(py::init<int64_t, bool>(), py::arg("kVal"), py::arg("shuffle")); .def(py::init<int64_t, bool>(), py::arg("kVal"), py::arg("shuffle"));
(void)py::class_<RandomSampler, Sampler, std::shared_ptr<RandomSampler>>(*m, "RandomSampler") (void)py::class_<RandomSampler, Sampler, std::shared_ptr<RandomSampler>>(*m, "RandomSampler")
.def(py::init<bool, int64_t>(), py::arg("replacement"), py::arg("numSamples")) .def(py::init<bool, bool, int64_t>(), py::arg("replacement"), py::arg("reshuffle_each_epoch"),
.def(py::init<bool>(), py::arg("replacement")); py::arg("num_samples"))
.def(py::init<bool, bool>(), py::arg("replacement"), py::arg("reshuffle_each_epoch"));
(void)py::class_<SequentialSampler, Sampler, std::shared_ptr<SequentialSampler>>(*m, "SequentialSampler") (void)py::class_<SequentialSampler, Sampler, std::shared_ptr<SequentialSampler>>(*m, "SequentialSampler")
.def(py::init<>()); .def(py::init<>());
(void)py::class_<SubsetSampler, Sampler, std::shared_ptr<SubsetSampler>>(*m, "SubsetSampler")
.def(py::init<int64_t, int64_t>(), py::arg("start_index"), py::arg("subset_size"));
(void)py::class_<SubsetRandomSampler, Sampler, std::shared_ptr<SubsetRandomSampler>>(*m, "SubsetRandomSampler") (void)py::class_<SubsetRandomSampler, Sampler, std::shared_ptr<SubsetRandomSampler>>(*m, "SubsetRandomSampler")
.def(py::init<std::vector<int64_t>>(), py::arg("indices")); .def(py::init<std::vector<int64_t>>(), py::arg("indices"));

@ -8,5 +8,6 @@ add_library(engine-datasetops-source-sampler OBJECT
sampler.cc sampler.cc
sequential_sampler.cc sequential_sampler.cc
subset_random_sampler.cc subset_random_sampler.cc
subset_sampler.cc
weighted_random_sampler.cc weighted_random_sampler.cc
) )

@ -55,13 +55,27 @@ Status DistributedSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer
} else if (cnt_ == samples_per_buffer_) { } else if (cnt_ == samples_per_buffer_) {
(*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE); (*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
} else { } else {
if (HasChildSampler()) {
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&child_ids_));
}
(*out_buffer) = std::make_unique<DataBuffer>(cnt_, DataBuffer::kDeBFlagNone); (*out_buffer) = std::make_unique<DataBuffer>(cnt_, DataBuffer::kDeBFlagNone);
std::shared_ptr<Tensor> sample_ids; std::shared_ptr<Tensor> sample_ids;
RETURN_IF_NOT_OK(CreateSamplerTensor(&sample_ids, samples_per_buffer_)); RETURN_IF_NOT_OK(CreateSamplerTensor(&sample_ids, samples_per_buffer_));
int64_t *id_ptr = reinterpret_cast<int64_t *>(sample_ids->GetMutableBuffer()); int64_t *id_ptr = reinterpret_cast<int64_t *>(sample_ids->GetMutableBuffer());
while (cnt_ < samples_per_buffer_) { while (cnt_ < samples_per_buffer_) {
int64_t next_id = (num_devices_ * (cnt_++) + device_id_) % num_rows_; int64_t sampled_id = (num_devices_ * cnt_ + device_id_) % num_rows_;
*(id_ptr++) = shuffle_ ? shuffle_vec_[static_cast<size_t>(next_id)] : next_id; if (shuffle_) {
sampled_id = shuffle_vec_[static_cast<size_t>(sampled_id)];
}
if (HasChildSampler()) {
RETURN_IF_NOT_OK(GetAssociatedChildId(&sampled_id, sampled_id));
}
*id_ptr = sampled_id;
id_ptr++;
cnt_++;
} }
TensorRow row(1, sample_ids); TensorRow row(1, sample_ids);
(*out_buffer)->set_tensor_table(std::make_unique<TensorQTable>(1, row)); (*out_buffer)->set_tensor_table(std::make_unique<TensorQTable>(1, row));
@ -72,11 +86,29 @@ Status DistributedSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer
// Resets the sampler so the next epoch can be generated.
// @return - The error code returned.
Status DistributedSampler::Reset() {
  CHECK_FAIL_RETURN_UNEXPECTED(cnt_ == samples_per_buffer_, "ERROR Reset() called early/late");
  cnt_ = 0;
  if (shuffle_ == true) {
    // Seed before shuffling so each epoch uses a well-defined permutation,
    // then bump the seed so the next epoch gets a different one.
    rnd_.seed(seed_);
    seed_++;
    std::shuffle(shuffle_vec_.begin(), shuffle_vec_.end(), rnd_);
  }
  // Cascade the reset to the child sampler (if any) so its epoch state matches ours.
  if (HasChildSampler()) {
    RETURN_IF_NOT_OK(child_[0]->Reset());
  }
  return Status::OK();
}
// Streams a human-readable description of this sampler; with show_all, also
// dumps the sharding and shuffle configuration.
void DistributedSampler::Print(std::ostream &out, bool show_all) const {
  out << "(sampler): DistributedSampler\n";
  if (!show_all) {
    return;
  }
  out << "seed_: " << seed_ << '\n';
  out << "device_id_: " << device_id_ << '\n';
  out << "num_devices_: " << num_devices_ << '\n';
  out << "shuffle_: " << shuffle_ << '\n';
}
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore

@ -48,6 +48,8 @@ class DistributedSampler : public Sampler {
// @return - The error code return // @return - The error code return
Status Reset() override; Status Reset() override;
void Print(std::ostream &out, bool show_all) const override;
private: private:
int64_t cnt_; // number of samples that have already been filled in to buffer int64_t cnt_; // number of samples that have already been filled in to buffer
uint32_t seed_; uint32_t seed_;

@ -38,6 +38,7 @@ Status PKSampler::InitSampler() {
rnd_.seed(seed_++); rnd_.seed(seed_++);
num_pk_samples_ = samples_per_class_ * static_cast<int64_t>(labels_.size()); num_pk_samples_ = samples_per_class_ * static_cast<int64_t>(labels_.size());
samples_per_buffer_ = (samples_per_buffer_ > num_pk_samples_) ? num_pk_samples_ : samples_per_buffer_; samples_per_buffer_ = (samples_per_buffer_ > num_pk_samples_) ? num_pk_samples_ : samples_per_buffer_;
num_samples_ = num_pk_samples_;
if (shuffle_ == true) { if (shuffle_ == true) {
std::shuffle(labels_.begin(), labels_.end(), rnd_); std::shuffle(labels_.begin(), labels_.end(), rnd_);
} else { } else {
@ -53,6 +54,10 @@ Status PKSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) {
} else if (next_id_ == num_pk_samples_) { } else if (next_id_ == num_pk_samples_) {
(*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE); (*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
} else { } else {
if (HasChildSampler()) {
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&child_ids_));
}
(*out_buffer) = std::make_unique<DataBuffer>(next_id_, DataBuffer::kDeBFlagNone); (*out_buffer) = std::make_unique<DataBuffer>(next_id_, DataBuffer::kDeBFlagNone);
std::shared_ptr<Tensor> sample_ids; std::shared_ptr<Tensor> sample_ids;
int64_t last_id = int64_t last_id =
@ -63,8 +68,16 @@ Status PKSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) {
int64_t cls_id = next_id_++ / samples_per_class_; int64_t cls_id = next_id_++ / samples_per_class_;
const std::vector<int64_t> &samples = label_to_ids_[labels_[cls_id]]; const std::vector<int64_t> &samples = label_to_ids_[labels_[cls_id]];
int64_t rnd_ind = std::uniform_int_distribution<int64_t>(0, samples.size() - 1)(rnd_); int64_t rnd_ind = std::uniform_int_distribution<int64_t>(0, samples.size() - 1)(rnd_);
*(id_ptr++) = samples[rnd_ind]; int64_t sampled_id = samples[rnd_ind];
if (HasChildSampler()) {
RETURN_IF_NOT_OK(GetAssociatedChildId(&sampled_id, sampled_id));
}
*id_ptr = sampled_id;
id_ptr++;
} }
TensorRow row(1, sample_ids); TensorRow row(1, sample_ids);
(*out_buffer)->set_tensor_table(std::make_unique<TensorQTable>(1, row)); (*out_buffer)->set_tensor_table(std::make_unique<TensorQTable>(1, row));
} }
@ -75,6 +88,11 @@ Status PKSampler::Reset() {
CHECK_FAIL_RETURN_UNEXPECTED(next_id_ == num_pk_samples_, "ERROR Reset() called early/late"); CHECK_FAIL_RETURN_UNEXPECTED(next_id_ == num_pk_samples_, "ERROR Reset() called early/late");
next_id_ = 0; next_id_ = 0;
rnd_.seed(seed_++); rnd_.seed(seed_++);
if (HasChildSampler()) {
RETURN_IF_NOT_OK(child_[0]->Reset());
}
return Status::OK(); return Status::OK();
} }

@ -27,6 +27,10 @@ Status PythonSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) {
if (need_to_reset_) { if (need_to_reset_) {
(*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE); (*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
} else { } else {
if (HasChildSampler()) {
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&child_ids_));
}
std::shared_ptr<Tensor> sample_ids; std::shared_ptr<Tensor> sample_ids;
{ {
py::gil_scoped_acquire gil_acquire; py::gil_scoped_acquire gil_acquire;
@ -38,6 +42,14 @@ Status PythonSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) {
py::object py_ret = py_sampler_instance.attr("_get_indices")(); py::object py_ret = py_sampler_instance.attr("_get_indices")();
py::array np_sample_ids = py_ret.cast<py::array>(); py::array np_sample_ids = py_ret.cast<py::array>();
Tensor::CreateTensor(&sample_ids, np_sample_ids); // copy numpy to tensor Tensor::CreateTensor(&sample_ids, np_sample_ids); // copy numpy to tensor
if (HasChildSampler()) {
for (auto it = sample_ids->begin<int64_t>(); it != sample_ids->end<int64_t>(); ++it) {
int64_t associated_child_id = 0;
RETURN_IF_NOT_OK(GetAssociatedChildId(&associated_child_id, associated_child_id));
*it = associated_child_id;
}
}
} catch (const py::error_already_set &e) { } catch (const py::error_already_set &e) {
return Status(StatusCode::kPyFuncException, e.what()); return Status(StatusCode::kPyFuncException, e.what());
} catch (const py::cast_error &e) { } catch (const py::cast_error &e) {
@ -79,6 +91,11 @@ Status PythonSampler::Reset() {
} catch (const py::error_already_set &e) { } catch (const py::error_already_set &e) {
return Status(StatusCode::kPyFuncException, e.what()); return Status(StatusCode::kPyFuncException, e.what());
} }
if (HasChildSampler()) {
RETURN_IF_NOT_OK(child_[0]->Reset());
}
return Status::OK(); return Status::OK();
} }
} // namespace dataset } // namespace dataset

@ -14,18 +14,22 @@
* limitations under the License. * limitations under the License.
*/ */
#include "dataset/engine/datasetops/source/sampler/random_sampler.h" #include "dataset/engine/datasetops/source/sampler/random_sampler.h"
#include <algorithm>
#include <limits> #include <limits>
#include <memory> #include <memory>
#include "dataset/util/random.h" #include "dataset/util/random.h"
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
// Constructor.
// @param replacement - whether an id may be drawn more than once.
// @param reshuffle_each_epoch - whether Reset() produces a new permutation each epoch.
// @param num_samples - user-requested cap on the number of ids to produce.
// @param samples_per_buffer - max number of sampler ids per DataBuffer.
RandomSampler::RandomSampler(bool replacement, bool reshuffle_each_epoch, int64_t num_samples,
                             int64_t samples_per_buffer)
    : Sampler(samples_per_buffer),
      seed_(GetSeed()),
      replacement_(replacement),
      user_num_samples_(num_samples),
      next_id_(0),
      reshuffle_each_epoch_(reshuffle_each_epoch),
      dist(nullptr) {}
Status RandomSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) { Status RandomSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) {
@ -34,13 +38,29 @@ Status RandomSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) {
} else if (next_id_ == num_samples_) { } else if (next_id_ == num_samples_) {
(*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE); (*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
} else { } else {
if (HasChildSampler()) {
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&child_ids_));
}
(*out_buffer) = std::make_unique<DataBuffer>(next_id_, DataBuffer::kDeBFlagNone); (*out_buffer) = std::make_unique<DataBuffer>(next_id_, DataBuffer::kDeBFlagNone);
std::shared_ptr<Tensor> sampleIds; std::shared_ptr<Tensor> sampleIds;
int64_t last_id = samples_per_buffer_ + next_id_ > num_samples_ ? num_samples_ : samples_per_buffer_ + next_id_; int64_t last_id = std::min(samples_per_buffer_ + next_id_, num_samples_);
RETURN_IF_NOT_OK(CreateSamplerTensor(&sampleIds, last_id - next_id_)); RETURN_IF_NOT_OK(CreateSamplerTensor(&sampleIds, last_id - next_id_));
int64_t *id_ptr = reinterpret_cast<int64_t *>(sampleIds->GetMutableBuffer()); int64_t *id_ptr = reinterpret_cast<int64_t *>(sampleIds->GetMutableBuffer());
for (int64_t i = 0; i < (last_id - next_id_); i++) { for (int64_t i = 0; i < (last_id - next_id_); i++) {
*(id_ptr + i) = replacement_ ? (*dist)(rnd_) : shuffled_ids_[static_cast<size_t>(i + next_id_)]; int64_t sampled_id = 0;
if (replacement_) {
sampled_id = (*dist)(rnd_);
} else {
sampled_id = shuffled_ids_[static_cast<size_t>(i + next_id_)];
}
if (HasChildSampler()) {
RETURN_IF_NOT_OK(GetAssociatedChildId(&sampled_id, sampled_id));
}
*(id_ptr + i) = sampled_id;
} }
next_id_ = last_id; next_id_ = last_id;
TensorRow row(1, sampleIds); TensorRow row(1, sampleIds);
@ -53,7 +73,9 @@ Status RandomSampler::InitSampler() {
num_samples_ = (user_num_samples_ < num_samples_) ? user_num_samples_ : num_samples_; num_samples_ = (user_num_samples_ < num_samples_) ? user_num_samples_ : num_samples_;
CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ > 0 && num_rows_ > 0, "both num_samples & num_rows need to be positive"); CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ > 0 && num_rows_ > 0, "both num_samples & num_rows need to be positive");
samples_per_buffer_ = samples_per_buffer_ > num_samples_ ? num_samples_ : samples_per_buffer_; samples_per_buffer_ = samples_per_buffer_ > num_samples_ ? num_samples_ : samples_per_buffer_;
rnd_.seed(seed_++);
rnd_.seed(seed_);
if (replacement_ == false) { if (replacement_ == false) {
shuffled_ids_.reserve(num_rows_); shuffled_ids_.reserve(num_rows_);
for (int64_t i = 0; i < num_rows_; i++) { for (int64_t i = 0; i < num_rows_; i++) {
@ -69,11 +91,33 @@ Status RandomSampler::InitSampler() {
// Resets random sampling for a new epoch. The shuffled permutation is only
// regenerated when reshuffle_each_epoch_ is true; otherwise every epoch replays
// the same order (seed is not advanced).
// @return - The error code returned.
Status RandomSampler::Reset() {
  CHECK_FAIL_RETURN_UNEXPECTED(next_id_ == num_samples_, "ERROR Reset() called early/late");
  next_id_ = 0;
  if (reshuffle_each_epoch_) {
    seed_++;
  }
  rnd_.seed(seed_);
  // Only the no-replacement path keeps a materialized id list to reshuffle.
  if (replacement_ == false && reshuffle_each_epoch_) {
    std::shuffle(shuffled_ids_.begin(), shuffled_ids_.end(), rnd_);
  }
  if (HasChildSampler()) {
    RETURN_IF_NOT_OK(child_[0]->Reset());
  }
  return Status::OK();
}
// Streams a human-readable description of this sampler; with show_all, also
// dumps the sample counts and iteration cursor.
void RandomSampler::Print(std::ostream &out, bool show_all) const {
  out << "(sampler): RandomSampler\n";
  if (!show_all) {
    return;
  }
  out << "user_num_samples_: " << user_num_samples_ << '\n';
  out << "num_samples_: " << num_samples_ << '\n';
  out << "next_id_: " << next_id_ << '\n';
}
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore

@ -30,7 +30,8 @@ class RandomSampler : public Sampler {
// @param bool replacement - put he id back / or not after a sample // @param bool replacement - put he id back / or not after a sample
// @param int64_t numSamples - number samples to draw // @param int64_t numSamples - number samples to draw
// @param int64_t samplesPerBuffer - Num of Sampler Ids to fetch via 1 GetNextBuffer call // @param int64_t samplesPerBuffer - Num of Sampler Ids to fetch via 1 GetNextBuffer call
explicit RandomSampler(bool replacement = false, int64_t num_samples = std::numeric_limits<int64_t>::max(), explicit RandomSampler(bool replacement = false, bool reshuffle_each_epoch = true,
int64_t num_samples = std::numeric_limits<int64_t>::max(),
int64_t samples_per_buffer = std::numeric_limits<int64_t>::max()); int64_t samples_per_buffer = std::numeric_limits<int64_t>::max());
// Destructor. // Destructor.
@ -49,6 +50,8 @@ class RandomSampler : public Sampler {
// @return - The error code return // @return - The error code return
Status Reset() override; Status Reset() override;
virtual void Print(std::ostream &out, bool show_all) const;
private: private:
uint32_t seed_; uint32_t seed_;
bool replacement_; bool replacement_;
@ -57,6 +60,7 @@ class RandomSampler : public Sampler {
int64_t next_id_; int64_t next_id_;
std::mt19937 rnd_; std::mt19937 rnd_;
std::unique_ptr<std::uniform_int_distribution<int64_t>> dist; std::unique_ptr<std::uniform_int_distribution<int64_t>> dist;
bool reshuffle_each_epoch_;
}; };
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore

@ -15,18 +15,41 @@
*/ */
#include "dataset/engine/datasetops/source/sampler/sampler.h" #include "dataset/engine/datasetops/source/sampler/sampler.h"
#include <string>
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
// Constructor. A sampler is a DatasetOp with no connector; row/sample counts are
// filled in later during HandshakeRandomAccessOp().
Sampler::Sampler(int64_t samples_per_buffer)
    : DatasetOp(0), num_rows_(0), num_samples_(0), samples_per_buffer_(samples_per_buffer), col_desc_(nullptr) {}
// First handshake between the storage op and this sampler: discovers the sample
// and row counts, handshakes any child sampler first, then runs InitSampler().
// @param op - the RandomAccessOp providing counts; must not be null.
// @return - The error code returned.
Status Sampler::HandshakeRandomAccessOp(const RandomAccessOp *op) {
  // Validate op up front, before it is forwarded to the child handshake.
  CHECK_FAIL_RETURN_UNEXPECTED(op != nullptr, "RandomAccessOp is nullptr\n");
  std::shared_ptr<Sampler> child_sampler;
  if (HasChildSampler()) {
    child_sampler = std::dynamic_pointer_cast<Sampler>(child_[0]);
    if (!child_sampler) {
      std::string err_msg("Cannot handshake, child is not a sampler object.");
      RETURN_STATUS_UNEXPECTED(err_msg);
    }
    // Handshake and init the child first so its num_samples() is valid below.
    RETURN_IF_NOT_OK(child_sampler->HandshakeRandomAccessOp(op));
  }
  RETURN_IF_NOT_OK(op->GetNumSamples(&num_samples_));
  if (HasChildSampler()) {
    // When chained, this sampler draws from the child's output ids, so the
    // child's sample count becomes our effective row count.
    num_rows_ = child_sampler->num_samples();
  } else {
    RETURN_IF_NOT_OK(op->GetNumRowsInDataset(&num_rows_));
  }
  // It's up to the derived class to check the validity of the two args
  // Because some sampler only needs one of the arg (weighted_random_sampler)
  RETURN_IF_NOT_OK(InitSampler());  // init sampler after callback
  return Status::OK();
}
@ -44,6 +67,15 @@ Status Sampler::CreateSamplerTensor(std::shared_ptr<Tensor> *sample_ids, int64_t
return Status::OK(); return Status::OK();
} }
// Streams a human-readable description of the sampler base state; with
// show_all, also dumps the row and sample counts.
void Sampler::Print(std::ostream &out, bool show_all) const {
  out << "(sampler): base\n";
  if (!show_all) {
    return;
  }
  out << "num_rows_: " << num_rows_ << '\n';
  out << "num_samples_: " << num_samples_ << '\n';
}
Status Sampler::GetAllIdsThenReset(py::array *data) { Status Sampler::GetAllIdsThenReset(py::array *data) {
std::unique_ptr<DataBuffer> db; std::unique_ptr<DataBuffer> db;
std::shared_ptr<Tensor> sample_ids; std::shared_ptr<Tensor> sample_ids;
@ -84,5 +116,45 @@ Status Sampler::SetNumRowsInDataset(int64_t num_rows) {
num_rows_ = num_rows; num_rows_ = num_rows;
return Status::OK(); return Status::OK();
} }
// Adds a sampler to become our child.
// @param child - the sampler to attach; a nullptr child is silently ignored.
// @return - error if child is not a Sampler or a child is already attached.
Status Sampler::AddChild(std::shared_ptr<DatasetOp> child) {
  if (child == nullptr) {
    return Status::OK();
  }
  // Only samplers can be added, not any other DatasetOp.
  std::shared_ptr<Sampler> sampler = std::dynamic_pointer_cast<Sampler>(child);
  if (!sampler) {
    std::string err_msg("Cannot add child, child is not a sampler object.");
    RETURN_STATUS_UNEXPECTED(err_msg);
  }
  // Samplers can have at most 1 child.
  if (!child_.empty()) {
    std::string err_msg("Cannot add child sampler, this sampler already has a child.");
    RETURN_STATUS_UNEXPECTED(err_msg);
  }
  child_.push_back(child);
  // NOTE(review): the reverse parent link is intentionally not set here — the
  // DatasetOp parent-registration API is not accessible from this scope, so
  // child-to-parent traversal will not work for chained samplers.
  return Status::OK();
}
// Returns true iff a child sampler has been attached via AddChild().
bool Sampler::HasChildSampler() { return !child_.empty(); }
// Maps an id produced by this sampler into the child sampler's id space: `id`
// is used as an index into the tensor of ids the child most recently generated
// (stored in child_ids_ by the caller's GetNextBuffer).
// @param int64_t* out_associated_id - Out parameter, contains the associated id.
// @param int64_t id - The id used as an index into the child's id tensor.
// @return - error if no child ids buffer exists or the tensor lookup fails.
Status Sampler::GetAssociatedChildId(int64_t *out_associated_id, int64_t id) {
  if (child_ids_ == nullptr) {
    RETURN_STATUS_UNEXPECTED("Trying to get associated child id, but there are no child ids!");
  }
  // NOTE(review): the row is re-fetched from the buffer on every call, and callers
  // invoke this once per sampled id — consider caching the tensor if this shows
  // up in profiles.
  TensorRow sample_row;
  RETURN_IF_NOT_OK(child_ids_->GetRow(0, &sample_row));
  std::shared_ptr<Tensor> sample_ids = sample_row[0];
  RETURN_IF_NOT_OK(sample_ids->GetItemAt<int64_t>(out_associated_id, {id}));
  return Status::OK();
}
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore

@ -90,6 +90,8 @@ class Sampler : public DatasetOp {
// setter function for num_samples_ // setter function for num_samples_
Status SetNumSamples(int64_t num_samples); Status SetNumSamples(int64_t num_samples);
int64_t num_samples() { return num_samples_; }
// first handshake between StorageOp and Sampler. This func will call getNumRows and getNumSamples // first handshake between StorageOp and Sampler. This func will call getNumRows and getNumSamples
// @param op - StorageOp pointer, pass in so Sampler can call getNumSamples() and get ClassIds() // @param op - StorageOp pointer, pass in so Sampler can call getNumSamples() and get ClassIds()
// @return // @return
@ -114,17 +116,48 @@ class Sampler : public DatasetOp {
// @return - The error code return // @return - The error code return
Status operator()() final { RETURN_STATUS_UNEXPECTED("Functor not supported in Sampler"); } Status operator()() final { RETURN_STATUS_UNEXPECTED("Functor not supported in Sampler"); }
// Adds a sampler to become our child.
// @param std::shared_ptr<DatasetOp> - The sampler to add as a child.
// @return - The error code returned.
Status AddChild(std::shared_ptr<DatasetOp> child);
// A helper function to create a int64_t 1-D Tensor specifically used to hold sampleIds for Sampler // A helper function to create a int64_t 1-D Tensor specifically used to hold sampleIds for Sampler
// @param std::shared_ptr<Tensor>* sampleIds // @param std::shared_ptr<Tensor>* sampleIds
// @param int64_t numElements - must be a non 0 number // @param int64_t numElements - must be a non 0 number
// @return // @return - The error code returned.
Status CreateSamplerTensor(std::shared_ptr<Tensor> *sample_ids, int64_t num_elements); Status CreateSamplerTensor(std::shared_ptr<Tensor> *sample_ids, int64_t num_elements);
void Print(std::ostream &out, bool show_all) const override;
friend std::ostream &operator<<(std::ostream &out, const Sampler &sampler) {
sampler.Print(out, false);
return out;
}
// Checks if this sampler has a child sampler.
// @return - tre if there is a child sampler, false otherwise.
bool HasChildSampler();
// Uses id as an index for the list of ids generated by the child sampler, and gets the
// associated id.
// @param int64_t* out_associated_id - Out parameter, contains the associated id.
// @param int64_t id - The id used as an index to get the associated child id.
// @return - The error code returned.
Status GetAssociatedChildId(int64_t *out_associated_id, int64_t id);
protected: protected:
// Number of rows of data from the place this sampler is sampling from. If this sampler
// has a child sampler, num_rows_ is the number of ids the child sampler will
// output. Otherwise, num_rows_ is the number of rows in the dataset.
int64_t num_rows_; int64_t num_rows_;
// Number of ids this sampler will return.
int64_t num_samples_; int64_t num_samples_;
// The max number of ids a DataBuffer returned by this sampler will contain.
int64_t samples_per_buffer_; int64_t samples_per_buffer_;
std::unique_ptr<ColDescriptor> col_desc_; std::unique_ptr<ColDescriptor> col_desc_;
std::unique_ptr<DataBuffer> child_ids_;
}; };
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore

@ -15,6 +15,7 @@
*/ */
#include "dataset/engine/datasetops/source/sampler/sequential_sampler.h" #include "dataset/engine/datasetops/source/sampler/sequential_sampler.h"
#include <algorithm>
#include <memory> #include <memory>
namespace mindspore { namespace mindspore {
@ -27,14 +28,26 @@ Status SequentialSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer)
} else if (next_id_ == num_samples_) { } else if (next_id_ == num_samples_) {
(*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE); (*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
} else { } else {
if (HasChildSampler()) {
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&child_ids_));
}
(*out_buffer) = std::make_unique<DataBuffer>(next_id_, DataBuffer::kDeBFlagNone); (*out_buffer) = std::make_unique<DataBuffer>(next_id_, DataBuffer::kDeBFlagNone);
std::shared_ptr<Tensor> sampleIds; std::shared_ptr<Tensor> sampleIds;
int64_t lastId = (samples_per_buffer_ + next_id_ > num_samples_) ? num_samples_ : samples_per_buffer_ + next_id_; int64_t lastId = (samples_per_buffer_ + next_id_ > num_samples_) ? num_samples_ : samples_per_buffer_ + next_id_;
RETURN_IF_NOT_OK(CreateSamplerTensor(&sampleIds, lastId - next_id_)); RETURN_IF_NOT_OK(CreateSamplerTensor(&sampleIds, lastId - next_id_));
int64_t *idPtr = reinterpret_cast<int64_t *>(sampleIds->GetMutableBuffer()); int64_t *idPtr = reinterpret_cast<int64_t *>(sampleIds->GetMutableBuffer());
while (next_id_ < lastId) { while (next_id_ < lastId) {
*(idPtr++) = next_id_++; int64_t sampled_id = next_id_;
if (HasChildSampler()) {
RETURN_IF_NOT_OK(GetAssociatedChildId(&sampled_id, sampled_id));
}
*idPtr = sampled_id;
next_id_++;
idPtr++;
} }
TensorRow row(1, sampleIds); TensorRow row(1, sampleIds);
(*out_buffer)->set_tensor_table(std::make_unique<TensorQTable>(1, row)); (*out_buffer)->set_tensor_table(std::make_unique<TensorQTable>(1, row));
} }
@ -43,6 +56,10 @@ Status SequentialSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer)
Status SequentialSampler::InitSampler() { Status SequentialSampler::InitSampler() {
num_samples_ = (num_samples_ <= 0) ? num_rows_ : num_samples_; // if num_samples < 0, try if num_rows is set num_samples_ = (num_samples_ <= 0) ? num_rows_ : num_samples_; // if num_samples < 0, try if num_rows is set
if (HasChildSampler()) {
num_samples_ = std::min(num_samples_, num_rows_);
}
CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ > 0 && samples_per_buffer_ > 0, "Fail to init Sequential Sampler"); CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ > 0 && samples_per_buffer_ > 0, "Fail to init Sequential Sampler");
samples_per_buffer_ = samples_per_buffer_ > num_samples_ ? num_samples_ : samples_per_buffer_; samples_per_buffer_ = samples_per_buffer_ > num_samples_ ? num_samples_ : samples_per_buffer_;
return Status::OK(); return Status::OK();
@ -51,7 +68,15 @@ Status SequentialSampler::InitSampler() {
// Resets sequential sampling to the first id and cascades the reset to any
// child sampler.
// @return - The error code returned.
Status SequentialSampler::Reset() {
  CHECK_FAIL_RETURN_UNEXPECTED(next_id_ == num_samples_, "ERROR Reset() called early/late");
  next_id_ = 0;
  if (HasChildSampler()) {
    RETURN_IF_NOT_OK(child_[0]->Reset());
  }
  return Status::OK();
}
void SequentialSampler::Print(std::ostream &out, bool show_all) const { out << "(sampler): SequentialSampler\n"; }
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore

@ -45,6 +45,8 @@ class SequentialSampler : public Sampler {
// @return - The error code return // @return - The error code return
Status GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) override; Status GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) override;
void Print(std::ostream &out, bool show_all) const override;
private: private:
int64_t next_id_; int64_t next_id_;
}; };

@ -34,6 +34,8 @@ SubsetRandomSampler::SubsetRandomSampler(const std::vector<int64_t> &indices, in
Status SubsetRandomSampler::InitSampler() { Status SubsetRandomSampler::InitSampler() {
CHECK_FAIL_RETURN_UNEXPECTED(num_rows_ > 0, "num_rows <= 0\n"); CHECK_FAIL_RETURN_UNEXPECTED(num_rows_ > 0, "num_rows <= 0\n");
num_samples_ = indices_.size();
// Initialize random generator with seed from config manager // Initialize random generator with seed from config manager
rand_gen_.seed(GetSeed()); rand_gen_.seed(GetSeed());
@ -56,6 +58,10 @@ Status SubsetRandomSampler::Reset() {
rand_gen_.seed(GetSeed()); rand_gen_.seed(GetSeed());
std::shuffle(indices_.begin(), indices_.end(), rand_gen_); std::shuffle(indices_.begin(), indices_.end(), rand_gen_);
if (HasChildSampler()) {
RETURN_IF_NOT_OK(child_[0]->Reset());
}
return Status::OK(); return Status::OK();
} }
@ -65,6 +71,10 @@ Status SubsetRandomSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffe
if (sample_id_ == indices_.size()) { if (sample_id_ == indices_.size()) {
(*out_buffer) = std::make_unique<DataBuffer>(buffer_id_++, DataBuffer::kDeBFlagEOE); (*out_buffer) = std::make_unique<DataBuffer>(buffer_id_++, DataBuffer::kDeBFlagEOE);
} else { } else {
if (HasChildSampler()) {
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&child_ids_));
}
(*out_buffer) = std::make_unique<DataBuffer>(buffer_id_++, DataBuffer::kDeBFlagNone); (*out_buffer) = std::make_unique<DataBuffer>(buffer_id_++, DataBuffer::kDeBFlagNone);
std::shared_ptr<Tensor> outputIds; std::shared_ptr<Tensor> outputIds;
@ -87,7 +97,14 @@ Status SubsetRandomSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffe
RETURN_STATUS_UNEXPECTED(err_msg); RETURN_STATUS_UNEXPECTED(err_msg);
} }
*(id_ptr++) = indices_[sample_id_++]; int64_t sampled_id = indices_[sample_id_];
if (HasChildSampler()) {
RETURN_IF_NOT_OK(GetAssociatedChildId(&sampled_id, sampled_id));
}
*id_ptr = sampled_id;
id_ptr++;
sample_id_++;
} }
// Create a TensorTable from that single tensor and push into DataBuffer // Create a TensorTable from that single tensor and push into DataBuffer

@ -0,0 +1,85 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "dataset/engine/datasetops/source/sampler/subset_sampler.h"
#include <memory>
#include <string>
#include "dataset/core/config_manager.h"
#include "dataset/core/global_context.h"
namespace mindspore {
namespace dataset {
// Constructor. Builds a sampler that emits the contiguous id range
// [start_index, start_index + subset_size) in order.
// NOTE(review): subset_size is also forwarded to the Sampler base constructor —
// presumably as its samples-per-buffer/num-samples argument; confirm against Sampler.
SubsetSampler::SubsetSampler(int64_t start_index, int64_t subset_size)
    : Sampler(subset_size), start_index_(start_index), subset_size_(subset_size), current_id_(0) {}
// Validates the configured subset range against the dataset size (num_rows_ must
// already be set by the owning dataset op) and records the number of samples.
// @return Status - error if the subset is empty or falls outside [0, num_rows_).
Status SubsetSampler::InitSampler() {
  CHECK_FAIL_RETURN_UNEXPECTED(subset_size_ > 0, "subset_size_ <= 0\n");
  CHECK_FAIL_RETURN_UNEXPECTED(start_index_ >= 0, "start_index < 0\n");
  // Redundant with the final-index check below (since subset_size_ >= 1),
  // but kept so an out-of-range start yields a clearer error message.
  CHECK_FAIL_RETURN_UNEXPECTED(start_index_ < num_rows_, "start_index >= num_rows_\n");
  CHECK_FAIL_RETURN_UNEXPECTED(start_index_ + subset_size_ - 1 < num_rows_, "Final index out of bounds.\n");
  num_samples_ = subset_size_;
  return Status::OK();
}
// Rewind the sampler to the start of the subset so a new epoch can begin.
// Any child sampler in the chain is reset as well.
// @return Status
Status SubsetSampler::Reset() {
  current_id_ = 0;
  if (!HasChildSampler()) {
    return Status::OK();
  }
  // Propagate the reset down the sampler chain.
  RETURN_IF_NOT_OK(child_[0]->Reset());
  return Status::OK();
}
// Produces the next DataBuffer of sample ids.
// The whole subset is emitted in a single buffer on the first call; the following
// call returns an EOE (end-of-epoch) buffer. Calling again without Reset() is an error.
// @param[out] out_buffer Receives a buffer holding one int64_t Tensor of sample ids.
// @return Status
Status SubsetSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) {
  if (current_id_ > subset_size_) {
    // More calls than states: EOE was already delivered and Reset() was never invoked.
    RETURN_STATUS_UNEXPECTED("SubsetSampler Internal Error");
  } else if (current_id_ == subset_size_) {
    // All ids handed out — signal end of epoch.
    (*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
  } else {
    if (HasChildSampler()) {
      // Fetch the child sampler's ids first so ours can be mapped through them below.
      RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&child_ids_));
    }
    (*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagNone);
    std::shared_ptr<Tensor> sampled_ids;
    RETURN_IF_NOT_OK(CreateSamplerTensor(&sampled_ids, subset_size_));
    int64_t *sampled_ids_start_addr = reinterpret_cast<int64_t *>(sampled_ids->GetMutableBuffer());
    while (current_id_ < subset_size_) {
      int64_t sampled_id = start_index_ + current_id_;
      if (HasChildSampler()) {
        // Sampler chaining: translate our id through the child's sampled ids.
        RETURN_IF_NOT_OK(GetAssociatedChildId(&sampled_id, sampled_id));
      }
      *(sampled_ids_start_addr + current_id_) = sampled_id;
      current_id_++;
    }
    // Wrap the single tensor into a one-row table and attach it to the buffer.
    TensorRow sampled_ids_row(1, sampled_ids);
    (*out_buffer)->set_tensor_table(std::make_unique<TensorQTable>(1, sampled_ids_row));
  }
  return Status::OK();
}
} // namespace dataset
} // namespace mindspore

@ -0,0 +1,58 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SUBSET_SAMPLER_H_
#define DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SUBSET_SAMPLER_H_
#include <memory>
#include <vector>
#include "dataset/engine/datasetops/source/sampler/sampler.h"
namespace mindspore {
namespace dataset {
// Sampler that yields the contiguous dataset id range
// [start_index, start_index + subset_size), in sequential order.
class SubsetSampler : public Sampler {
 public:
  // Constructor.
  // @param start_index The index we start sampling from.
  // @param subset_size The number of consecutive ids to sample.
  explicit SubsetSampler(int64_t start_index, int64_t subset_size);

  // Destructor.
  ~SubsetSampler() = default;

  // Initialize the sampler: validate the configured range against the dataset size.
  // @return Status
  Status InitSampler() override;

  // Reset the internal variables to the initial state so a new epoch can begin.
  // (This sampler is deterministic; nothing is shuffled.)
  // @return Status
  Status Reset() override;

  // Get the sample ids.
  // @param[out] out_buffer The address of a unique_ptr to DataBuffer where the sample ids will be placed.
  // @note the sample ids (int64_t) will be placed in one Tensor.
  Status GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) override;

 private:
  int64_t start_index_;  // First dataset id emitted.
  int64_t subset_size_;  // Number of ids emitted per epoch.
  int64_t current_id_;   // Next offset within the subset to hand out.
};
} // namespace dataset
} // namespace mindspore
#endif // DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SUBSET_SAMPLER_H_

@ -40,6 +40,8 @@ WeightedRandomSampler::WeightedRandomSampler(const std::vector<double> &weights,
Status WeightedRandomSampler::InitSampler() { Status WeightedRandomSampler::InitSampler() {
CHECK_FAIL_RETURN_UNEXPECTED(num_rows_ > 0 && user_num_samples_, "num_samples & num_rows need to be positive"); CHECK_FAIL_RETURN_UNEXPECTED(num_rows_ > 0 && user_num_samples_, "num_samples & num_rows need to be positive");
CHECK_FAIL_RETURN_UNEXPECTED(samples_per_buffer_ > 0, "samples_per_buffer<=0\n"); CHECK_FAIL_RETURN_UNEXPECTED(samples_per_buffer_ > 0, "samples_per_buffer<=0\n");
num_samples_ = user_num_samples_;
// Initialize random generator with seed from config manager // Initialize random generator with seed from config manager
rand_gen_.seed(GetSeed()); rand_gen_.seed(GetSeed());
@ -81,6 +83,11 @@ Status WeightedRandomSampler::Reset() {
} else { } else {
discrete_dist_->reset(); discrete_dist_->reset();
} }
if (HasChildSampler()) {
RETURN_IF_NOT_OK(child_[0]->Reset());
}
return Status::OK(); return Status::OK();
} }
@ -98,6 +105,10 @@ Status WeightedRandomSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buf
if (sample_id_ == user_num_samples_) { if (sample_id_ == user_num_samples_) {
(*out_buffer) = std::make_unique<DataBuffer>(buffer_id_++, DataBuffer::kDeBFlagEOE); (*out_buffer) = std::make_unique<DataBuffer>(buffer_id_++, DataBuffer::kDeBFlagEOE);
} else { } else {
if (HasChildSampler()) {
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&child_ids_));
}
(*out_buffer) = std::make_unique<DataBuffer>(buffer_id_++, DataBuffer::kDeBFlagNone); (*out_buffer) = std::make_unique<DataBuffer>(buffer_id_++, DataBuffer::kDeBFlagNone);
std::shared_ptr<Tensor> outputIds; std::shared_ptr<Tensor> outputIds;
@ -127,7 +138,12 @@ Status WeightedRandomSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buf
RETURN_STATUS_UNEXPECTED("generated id is bigger than numRows (out of bound)."); RETURN_STATUS_UNEXPECTED("generated id is bigger than numRows (out of bound).");
} }
*(id_ptr++) = genId; if (HasChildSampler()) {
RETURN_IF_NOT_OK(GetAssociatedChildId(&genId, genId));
}
*id_ptr = genId;
id_ptr++;
sample_id_++; sample_id_++;
} }

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

@ -1031,3 +1031,44 @@ def check_textfiledataset(method):
return method(*args, **kwargs) return method(*args, **kwargs)
return new_method return new_method
def check_split(method):
    """Check the input arguments of split.

    Validates that `sizes` is a non-empty list of either all ints (absolute,
    strictly positive split sizes) or all floats (percentages in (0, 1] that
    sum to 1), and that `randomize` is a bool. Raises ValueError on violation.
    """

    @wraps(method)
    def new_method(*args, **kwargs):
        param_dict = make_param_dict(method, args, kwargs)

        nreq_param_list = ['sizes']
        nreq_param_bool = ['randomize']
        check_param_type(nreq_param_list, param_dict, list)
        check_param_type(nreq_param_bool, param_dict, bool)

        # check sizes: must be list of float or list of int (no mixing)
        sizes = param_dict.get('sizes')

        if not sizes:
            raise ValueError("sizes cannot be empty.")
        all_int = all(isinstance(item, int) for item in sizes)
        all_float = all(isinstance(item, float) for item in sizes)

        if not (all_int or all_float):
            raise ValueError("sizes should be list of int or list of float.")

        if all_int:
            # absolute sizes must be strictly positive (zero is rejected too)
            all_positive = all(item > 0 for item in sizes)
            if not all_positive:
                raise ValueError("sizes is a list of int, but all the numbers should be positive.")

        if all_float:
            # percentages must lie in the half-open interval (0, 1]
            all_valid_percentages = all(0 < item <= 1 for item in sizes)
            if not all_valid_percentages:
                raise ValueError("sizes is a list of float, but there should be no numbers outside the range (0, 1].")

            # tolerance so float rounding (e.g. ten 0.1 entries) is not rejected
            epsilon = 0.00001
            if not abs(sum(sizes) - 1) < epsilon:
                raise ValueError("sizes is a list of float, but the percentages do not sum up to 1.")

        return method(*args, **kwargs)

    return new_method

@ -92,7 +92,7 @@ TEST_F(MindDataTestCifarOp, TestSequentialSamplerCifar10) {
TEST_F(MindDataTestCifarOp, TestRandomSamplerCifar10) { TEST_F(MindDataTestCifarOp, TestRandomSamplerCifar10) {
uint32_t original_seed = GlobalContext::config_manager()->seed(); uint32_t original_seed = GlobalContext::config_manager()->seed();
GlobalContext::config_manager()->set_seed(0); GlobalContext::config_manager()->set_seed(0);
std::unique_ptr<Sampler> sampler = std::make_unique<RandomSampler>(true, 12); std::unique_ptr<Sampler> sampler = std::make_unique<RandomSampler>(true, true, 12);
std::string folder_path = datasets_root_path_ + "/testCifar10Data/"; std::string folder_path = datasets_root_path_ + "/testCifar10Data/";
auto tree = Build({Cifarop(16, 2, 32, folder_path, std::move(sampler), 100)}); auto tree = Build({Cifarop(16, 2, 32, folder_path, std::move(sampler), 100)});
tree->Prepare(); tree->Prepare();

@ -138,7 +138,7 @@ TEST_F(MindDataTestImageFolderSampler, TestRandomImageFolder) {
TEST_F(MindDataTestImageFolderSampler, TestRandomSamplerImageFolder) { TEST_F(MindDataTestImageFolderSampler, TestRandomSamplerImageFolder) {
int32_t original_seed = GlobalContext::config_manager()->seed(); int32_t original_seed = GlobalContext::config_manager()->seed();
GlobalContext::config_manager()->set_seed(0); GlobalContext::config_manager()->set_seed(0);
std::unique_ptr<Sampler> sampler = std::make_unique<RandomSampler>(true, 12); std::unique_ptr<Sampler> sampler = std::make_unique<RandomSampler>(true, true, 12);
int32_t res[] = {2, 2, 2, 3, 2, 3, 2, 3, 1, 2, 2, 1}; // ground truth label int32_t res[] = {2, 2, 2, 3, 2, 3, 2, 3, 1, 2, 2, 1}; // ground truth label
std::string folder_path = datasets_root_path_ + "/testPK/data"; std::string folder_path = datasets_root_path_ + "/testPK/data";
auto tree = Build({ImageFolder(16, 2, 32, folder_path, false, std::move(sampler))}); auto tree = Build({ImageFolder(16, 2, 32, folder_path, false, std::move(sampler))});

@ -164,9 +164,36 @@ def test_python_sampler():
assert list(sp1.get_indices()) == [0, 1, 2, 3, 4] assert list(sp1.get_indices()) == [0, 1, 2, 3, 4]
def test_sampler_chain():
    """Chain a SequentialSampler under a DistributedSampler and verify the rows
    each shard receives against golden indices.

    Note: renamed the local `map` (which shadowed the Python builtin) to `golden`.
    """
    manifest_file = "../data/dataset/testManifestData/test5trainimgs.json"
    # keyed by (image.shape[0], label) -> golden sample index for the 5 train images
    golden = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4}

    def test_config(num_shards, shard_id):
        # non-shuffling distributed sampler with a sequential child sampler
        sampler = ds.DistributedSampler(num_shards, shard_id, False)
        child_sampler = ds.SequentialSampler()
        sampler.add_child(child_sampler)

        data1 = ds.ManifestDataset(manifest_file, num_samples=5, sampler=sampler)

        res = []
        for item in data1.create_dict_iterator():
            logger.info("item[image].shape[0]: {}, item[label].item(): {}"
                        .format(item["image"].shape[0], item["label"].item()))
            res.append(golden[(item["image"].shape[0], item["label"].item())])

        return res

    assert test_config(2, 0) == [0, 2, 4]
    assert test_config(2, 1) == [1, 3, 0]
    assert test_config(5, 0) == [0]
    assert test_config(5, 1) == [1]
    assert test_config(5, 2) == [2]
    assert test_config(5, 3) == [3]
    assert test_config(5, 4) == [4]
if __name__ == '__main__': if __name__ == '__main__':
test_sequential_sampler(True) test_sequential_sampler(True)
test_random_sampler(True) test_random_sampler(True)
test_random_sampler_multi_iter(True) test_random_sampler_multi_iter(True)
test_sampler_py_api() test_sampler_py_api()
test_python_sampler() test_python_sampler()
test_sampler_chain()

File diff suppressed because it is too large Load Diff
Loading…
Cancel
Save