parent
4cd6588af0
commit
7120c66998
@ -0,0 +1,148 @@
|
||||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "minddata/dataset/engine/datasetops/source/sampler/subset_sampler.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
// Constructor.
|
||||
SubsetSamplerRT::SubsetSamplerRT(int64_t num_samples, const std::vector<int64_t> &indices, int64_t samples_per_buffer)
|
||||
: SamplerRT(num_samples, samples_per_buffer), indices_(indices), sample_id_(0), buffer_id_(0) {}
|
||||
|
||||
// Initialized this Sampler.
|
||||
Status SubsetSamplerRT::InitSampler() {
|
||||
if (is_initialized) {
|
||||
return Status::OK();
|
||||
}
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(
|
||||
num_rows_ > 0, "Invalid parameter, num_rows must be greater than 0, but got " + std::to_string(num_rows_) + ".\n");
|
||||
|
||||
// Special value of 0 for num_samples means that the user wants to sample the entire set of data.
|
||||
// In this case, the id's are provided by the user. Cap the num_samples on the number of id's given.
|
||||
if (num_samples_ == 0 || num_samples_ > static_cast<int64_t>(indices_.size())) {
|
||||
num_samples_ = static_cast<int64_t>(indices_.size());
|
||||
}
|
||||
|
||||
if (samples_per_buffer_ > num_samples_) {
|
||||
samples_per_buffer_ = num_samples_;
|
||||
}
|
||||
|
||||
is_initialized = true;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Reset the internal variable to the initial state.
|
||||
Status SubsetSamplerRT::ResetSampler() {
|
||||
// Reset the internal counters.
|
||||
sample_id_ = 0;
|
||||
buffer_id_ = 0;
|
||||
|
||||
if (HasChildSampler()) {
|
||||
RETURN_IF_NOT_OK(child_[0]->ResetSampler());
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Get the sample ids.
|
||||
Status SubsetSamplerRT::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) {
|
||||
// All samples have been drawn
|
||||
if (sample_id_ == num_samples_) {
|
||||
(*out_buffer) = std::make_unique<DataBuffer>(buffer_id_++, DataBuffer::kDeBFlagEOE);
|
||||
} else {
|
||||
if (HasChildSampler()) {
|
||||
RETURN_IF_NOT_OK(child_[0]->GetNextSample(&child_ids_));
|
||||
}
|
||||
|
||||
(*out_buffer) = std::make_unique<DataBuffer>(buffer_id_++, DataBuffer::kDeBFlagNone);
|
||||
std::shared_ptr<Tensor> outputIds;
|
||||
|
||||
int64_t last_id = sample_id_ + samples_per_buffer_;
|
||||
// Handling the return all samples at once, and when last draw is not a full batch.
|
||||
if (last_id > num_samples_) {
|
||||
last_id = num_samples_;
|
||||
}
|
||||
|
||||
// Allocate tensor
|
||||
RETURN_IF_NOT_OK(CreateSamplerTensor(&outputIds, last_id - sample_id_));
|
||||
|
||||
// Initialize tensor
|
||||
auto id_ptr = outputIds->begin<int64_t>();
|
||||
while (sample_id_ < last_id) {
|
||||
if (indices_[sample_id_] >= num_rows_ || indices_[sample_id_] < 0) {
|
||||
std::string err_msg = "Sample ID (" + std::to_string(indices_[sample_id_]) +
|
||||
") is out of bound, expected range [0, " + std::to_string(num_rows_ - 1) + "]";
|
||||
RETURN_STATUS_UNEXPECTED(err_msg);
|
||||
}
|
||||
|
||||
int64_t sampled_id = ((indices_[sample_id_] % num_rows_) + num_rows_) % num_rows_;
|
||||
if (HasChildSampler()) {
|
||||
RETURN_IF_NOT_OK(GetAssociatedChildId(&sampled_id, sampled_id));
|
||||
}
|
||||
|
||||
*id_ptr = sampled_id;
|
||||
id_ptr++;
|
||||
sample_id_++;
|
||||
}
|
||||
|
||||
// Create a TensorTable from that single tensor and push into DataBuffer
|
||||
(*out_buffer)->set_tensor_table(std::make_unique<TensorQTable>(1, TensorRow(1, outputIds)));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void SubsetSamplerRT::SamplerPrint(std::ostream &out, bool show_all) const {
|
||||
out << "\nSampler: SubsetSampler";
|
||||
if (show_all) {
|
||||
// Call the super class for displaying any common detailed info
|
||||
SamplerRT::SamplerPrint(out, show_all);
|
||||
// Then add our own info if any
|
||||
}
|
||||
}
|
||||
|
||||
// Serialize this sampler into JSON.
// The output holds "sampler_name", "indices" and "num_samples"; when a child sampler is present,
// "child_sampler" holds the serialized form of every entry in child_.
// \param[out] out_json Receives the resulting JSON object.
// \return Status::OK(), or the first error returned by a child's to_json().
Status SubsetSamplerRT::to_json(nlohmann::json *out_json) {
  nlohmann::json args;
  args["sampler_name"] = "SubsetSampler";
  args["indices"] = indices_;
  args["num_samples"] = num_samples_;
  if (this->HasChildSampler()) {
    std::vector<nlohmann::json> children_args;
    // Bind by reference: copying each child handle per iteration is unnecessary work.
    for (const auto &child : child_) {
      nlohmann::json child_arg;
      RETURN_IF_NOT_OK(child->to_json(&child_arg));
      children_args.push_back(child_arg);
    }
    args["child_sampler"] = children_args;
  }
  *out_json = args;
  return Status::OK();
}
|
||||
|
||||
int64_t SubsetSamplerRT::CalculateNumSamples(int64_t num_rows) {
|
||||
int64_t child_num_rows = num_rows;
|
||||
if (!child_.empty()) {
|
||||
child_num_rows = child_[0]->CalculateNumSamples(num_rows);
|
||||
}
|
||||
int64_t res = (num_samples_ > 0) ? std::min(child_num_rows, num_samples_) : child_num_rows;
|
||||
res = std::min(res, static_cast<int64_t>(indices_.size()));
|
||||
return res;
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
@ -0,0 +1,84 @@
|
||||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SUBSET_SAMPLER_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SUBSET_SAMPLER_H_
|
||||
|
||||
#include <limits>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
/// Samples elements from a given list of indices.
|
||||
class SubsetSamplerRT : public SamplerRT {
|
||||
public:
|
||||
/// Constructor.
|
||||
/// \param num_samples The number of elements to sample. 0 for the full amount.
|
||||
/// \param indices List of indices.
|
||||
/// \param samples_per_buffer The number of ids we draw on each call to GetNextBuffer().
|
||||
/// When samples_per_buffer=0, GetNextBuffer() will draw all the sample ids and return them at once.
|
||||
SubsetSamplerRT(int64_t num_samples, const std::vector<int64_t> &indices,
|
||||
std::int64_t samples_per_buffer = std::numeric_limits<int64_t>::max());
|
||||
|
||||
/// Destructor.
|
||||
~SubsetSamplerRT() = default;
|
||||
|
||||
/// Initialize the sampler.
|
||||
/// \return Status
|
||||
Status InitSampler() override;
|
||||
|
||||
/// Reset the internal variable to the initial state and reshuffle the indices.
|
||||
/// \return Status
|
||||
Status ResetSampler() override;
|
||||
|
||||
/// Get the sample ids.
|
||||
/// \param[out] out_buffer The address of a unique_ptr to DataBuffer where the sample ids will be placed.
|
||||
/// @note the sample ids (int64_t) will be placed in one Tensor and be placed into pBuffer.
|
||||
Status GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) override;
|
||||
|
||||
/// Printer for debugging purposes.
|
||||
/// \param out - output stream to write to
|
||||
/// \param show_all - bool to show detailed vs summary
|
||||
void SamplerPrint(std::ostream &out, bool show_all) const override;
|
||||
|
||||
/// \brief Get the arguments of node
|
||||
/// \param[out] out_json JSON string of all attributes
|
||||
/// \return Status of the function
|
||||
Status to_json(nlohmann::json *out_json) override;
|
||||
|
||||
/// Calculate num samples. Unlike GetNumSamples, it is not a getter and doesn't necessarily return the value of
|
||||
/// num_samples_
|
||||
/// \param num_rows the size of the dataset this sampler will be applied to.
|
||||
/// \return number of samples
|
||||
int64_t CalculateNumSamples(int64_t num_rows) override;
|
||||
|
||||
protected:
|
||||
/// A list of indices (already randomized in constructor).
|
||||
std::vector<int64_t> indices_;
|
||||
|
||||
private:
|
||||
/// Current sample id.
|
||||
int64_t sample_id_;
|
||||
|
||||
/// Current buffer id.
|
||||
int64_t buffer_id_;
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SUBSET_SAMPLER_H_
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,144 @@
|
||||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "common/common.h"
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
#include "minddata/dataset/core/constants.h"
|
||||
#include "minddata/dataset/core/tensor.h"
|
||||
#include "minddata/dataset/engine/data_buffer.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/sampler/subset_sampler.h"
|
||||
|
||||
#include <vector>
|
||||
#include <unordered_set>
|
||||
|
||||
using namespace mindspore::dataset;
|
||||
|
||||
// Test fixture for SubsetSamplerRT.
class MindDataTestSubsetSampler : public UT::Common {
 public:
  // Minimal RandomAccessOp stand-in whose only job is to report a fixed row count to the sampler.
  class DummyRandomAccessOp : public RandomAccessOp {
   public:
    // explicit: a bare integer must not convert silently into an op.
    explicit DummyRandomAccessOp(int64_t num_rows) {
      num_rows_ = num_rows;  // base class
    }
  };
};
|
||||
|
||||
// With the default samples_per_buffer, a single GetNextSample() call must return every index in
// the order it was supplied (duplicates included), and the following call must report end of epoch.
TEST_F(MindDataTestSubsetSampler, TestAllAtOnce) {
  std::vector<int64_t> in({3, 1, 4, 0, 1});
  int64_t num_samples = 0;
  SubsetSamplerRT sampler(num_samples, in);

  DummyRandomAccessOp dummyRandomAccessOp(5);
  sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp);

  std::unique_ptr<DataBuffer> db;
  TensorRow row;
  std::vector<int64_t> out;
  ASSERT_EQ(sampler.GetNextSample(&db), Status::OK());
  db->PopRow(&row);
  for (const auto &t : row) {
    for (auto it = t->begin<int64_t>(); it != t->end<int64_t>(); it++) {
      out.push_back(*it);
    }
  }

  // Size first for a clear message, then element by element with an unsigned index.
  ASSERT_EQ(in.size(), out.size());
  for (size_t i = 0; i < in.size(); i++) {
    ASSERT_EQ(in[i], out[i]);
  }

  ASSERT_EQ(sampler.GetNextSample(&db), Status::OK());
  ASSERT_TRUE(db->eoe());
}
|
||||
|
||||
// Drains the sampler ten ids at a time and checks both the number of buffers produced
// (ceil(total / per_buffer)) and the total number of ids received.
TEST_F(MindDataTestSubsetSampler, TestGetNextBuffer) {
  int64_t samples_per_buffer = 10;
  int64_t total_samples = 100000 - 5;  // deliberately not a multiple of samples_per_buffer
  int64_t num_samples = 0;
  std::vector<int64_t> input(total_samples, 1);
  SubsetSamplerRT sampler(num_samples, input, samples_per_buffer);

  DummyRandomAccessOp dummyRandomAccessOp(total_samples);
  sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp);

  std::unique_ptr<DataBuffer> db;
  TensorRow row;
  std::vector<int64_t> out;

  // Counts the non-EOE buffers handed out by the sampler.
  int epoch = 0;
  ASSERT_EQ(sampler.GetNextSample(&db), Status::OK());
  for (; !db->eoe(); epoch++) {
    db->PopRow(&row);
    for (const auto &tensor : row) {
      auto cursor = tensor->begin<int64_t>();
      while (cursor != tensor->end<int64_t>()) {
        out.push_back(*cursor);
        cursor++;
      }
    }
    db.reset();

    ASSERT_EQ(sampler.GetNextSample(&db), Status::OK());
  }

  ASSERT_EQ(epoch, (total_samples + samples_per_buffer - 1) / samples_per_buffer);
  ASSERT_EQ(input.size(), out.size());
}
|
||||
|
||||
// After ResetSampler() the sampler must replay the same ids from the beginning and only then
// report end of epoch.
TEST_F(MindDataTestSubsetSampler, TestReset) {
  std::vector<int64_t> in({0, 1, 2, 3, 4});
  int64_t num_samples = 0;
  SubsetSamplerRT sampler(num_samples, in);

  DummyRandomAccessOp dummyRandomAccessOp(5);
  sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp);

  std::unique_ptr<DataBuffer> db;
  TensorRow row;
  std::vector<int64_t> out;

  // First pass: everything comes back in one buffer.
  ASSERT_EQ(sampler.GetNextSample(&db), Status::OK());
  db->PopRow(&row);
  for (const auto &t : row) {
    for (auto it = t->begin<int64_t>(); it != t->end<int64_t>(); it++) {
      out.push_back(*it);
    }
  }
  ASSERT_EQ(in.size(), out.size());
  for (size_t i = 0; i < in.size(); i++) {
    ASSERT_EQ(in[i], out[i]);
  }

  // The reset itself can fail (it returns a Status), so check it rather than discarding it.
  ASSERT_EQ(sampler.ResetSampler(), Status::OK());

  // Second pass: identical output, not an EOE buffer.
  ASSERT_EQ(sampler.GetNextSample(&db), Status::OK());
  ASSERT_FALSE(db->eoe());
  db->PopRow(&row);
  out.clear();
  for (const auto &t : row) {
    for (auto it = t->begin<int64_t>(); it != t->end<int64_t>(); it++) {
      out.push_back(*it);
    }
  }
  ASSERT_EQ(in.size(), out.size());
  for (size_t i = 0; i < in.size(); i++) {
    ASSERT_EQ(in[i], out[i]);
  }

  ASSERT_EQ(sampler.GetNextSample(&db), Status::OK());
  ASSERT_TRUE(db->eoe());
}
|
Loading…
Reference in new issue