Changed SamplerObj::ValidateParams to return Status and added AddChild support to SamplerObj

pull/9687/head
Mahdi 4 years ago
parent 6b5626634c
commit 0f2b5d8cac

File diff suppressed because it is too large Load Diff

@ -37,6 +37,9 @@ DistributedSamplerRT::DistributedSamplerRT(int64_t num_samples, int64_t num_dev,
non_empty_(true) {}
Status DistributedSamplerRT::InitSampler() {
if (is_initialized) {
return Status::OK();
}
// Special value of 0 for num_samples means that the user wants to sample the entire set of data.
// If the user asked to sample more rows than exists in the dataset, adjust the num_samples accordingly.
if (num_samples_ == 0 || num_samples_ > num_rows_) {
@ -72,6 +75,7 @@ Status DistributedSamplerRT::InitSampler() {
}
if (!samples_per_buffer_) non_empty_ = false;
is_initialized = true;
return Status::OK();
}

@ -28,6 +28,9 @@ PKSamplerRT::PKSamplerRT(int64_t num_samples, int64_t val, bool shuffle, int64_t
samples_per_class_(val) {}
Status PKSamplerRT::InitSampler() {
if (is_initialized) {
return Status::OK();
}
labels_.reserve(label_to_ids_.size());
for (const auto &pair : label_to_ids_) {
if (!pair.second.empty()) {
@ -58,6 +61,7 @@ Status PKSamplerRT::InitSampler() {
CHECK_FAIL_RETURN_UNEXPECTED(
num_samples_ > 0, "Invalid parameter, num_class or K (num samples per class) must be greater than 0, but got " +
std::to_string(num_samples_));
is_initialized = true;
return Status::OK();
}

@ -65,6 +65,9 @@ Status PythonSamplerRT::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) {
}
Status PythonSamplerRT::InitSampler() {
if (is_initialized) {
return Status::OK();
}
CHECK_FAIL_RETURN_UNEXPECTED(
num_rows_ > 0, "Invalid parameter, num_rows must be greater than 0, but got " + std::to_string(num_rows_));
// Special value of 0 for num_samples means that the user wants to sample the entire set of data.
@ -83,6 +86,8 @@ Status PythonSamplerRT::InitSampler() {
return Status(StatusCode::kPyFuncException, e.what());
}
}
is_initialized = true;
return Status::OK();
}

@ -69,6 +69,9 @@ Status RandomSamplerRT::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) {
}
Status RandomSamplerRT::InitSampler() {
if (is_initialized) {
return Status::OK();
}
// Special value of 0 for num_samples means that the user wants to sample the entire set of data.
// If the user asked to sample more rows than exists in the dataset, adjust the num_samples accordingly.
if (num_samples_ == 0 || num_samples_ > num_rows_) {
@ -91,6 +94,7 @@ Status RandomSamplerRT::InitSampler() {
dist = std::make_unique<std::uniform_int_distribution<int64_t>>(0, num_rows_ - 1);
}
is_initialized = true;
return Status::OK();
}

@ -34,7 +34,11 @@ Status RandomAccessOp::GetNumRowsInDataset(int64_t *num) const {
}
SamplerRT::SamplerRT(int64_t num_samples, int64_t samples_per_buffer)
: num_rows_(0), num_samples_(num_samples), samples_per_buffer_(samples_per_buffer), col_desc_(nullptr) {}
: num_rows_(0),
num_samples_(num_samples),
samples_per_buffer_(samples_per_buffer),
col_desc_(nullptr),
is_initialized(false) {}
Status SamplerRT::HandshakeRandomAccessOp(const RandomAccessOp *op) {
std::shared_ptr<SamplerRT> child_sampler;

@ -160,6 +160,7 @@ class SamplerRT {
// amount.
int64_t num_samples_;
bool is_initialized;
int64_t samples_per_buffer_;
std::unique_ptr<ColDescriptor> col_desc_;
std::vector<std::shared_ptr<SamplerRT>> child_; // Child nodes

@ -63,6 +63,9 @@ Status SequentialSamplerRT::GetNextSample(std::unique_ptr<DataBuffer> *out_buffe
}
Status SequentialSamplerRT::InitSampler() {
if (is_initialized) {
return Status::OK();
}
CHECK_FAIL_RETURN_UNEXPECTED(start_index_ >= 0,
"Invalid parameter, start_index must be greater than or equal to 0, but got " +
std::to_string(start_index_) + ".\n");
@ -82,6 +85,8 @@ Status SequentialSamplerRT::InitSampler() {
num_samples_ > 0 && samples_per_buffer_ > 0,
"Invalid parameter, samples_per_buffer must be greater than 0, but got " + std::to_string(samples_per_buffer_));
samples_per_buffer_ = samples_per_buffer_ > num_samples_ ? num_samples_ : samples_per_buffer_;
is_initialized = true;
return Status::OK();
}

@ -32,6 +32,9 @@ SubsetRandomSamplerRT::SubsetRandomSamplerRT(int64_t num_samples, const std::vec
// Initializes this Sampler.
Status SubsetRandomSamplerRT::InitSampler() {
if (is_initialized) {
return Status::OK();
}
CHECK_FAIL_RETURN_UNEXPECTED(
num_rows_ > 0, "Invalid parameter, num_rows must be greater than 0, but got " + std::to_string(num_rows_) + ".\n");
@ -51,6 +54,7 @@ Status SubsetRandomSamplerRT::InitSampler() {
// We will shuffle the full set of id's, but only select the first num_samples_ of them later.
std::shuffle(indices_.begin(), indices_.end(), rand_gen_);
is_initialized = true;
return Status::OK();
}

@ -37,6 +37,9 @@ WeightedRandomSamplerRT::WeightedRandomSamplerRT(int64_t num_samples, const std:
// Initializes this Sampler.
Status WeightedRandomSamplerRT::InitSampler() {
if (is_initialized) {
return Status::OK();
}
// Special value of 0 for num_samples means that the user wants to sample the entire set of data.
// If the user asked to sample more rows than exists in the dataset, adjust the num_samples accordingly.
if (num_samples_ == 0 || num_samples_ > num_rows_) {
@ -75,6 +78,7 @@ Status WeightedRandomSamplerRT::InitSampler() {
discrete_dist_ = std::make_unique<std::discrete_distribution<int64_t>>(weights_.begin(), weights_.end());
}
is_initialized = true;
return Status::OK();
}

@ -22,7 +22,10 @@
#include <vector>
#ifndef ENABLE_ANDROID
#include "minddata/dataset/engine/datasetops/source/mindrecord_op.h"
#include "minddata/dataset/util/status.h"
#include "minddata/mindrecord/include/shard_column.h"
#include "minddata/mindrecord/include/shard_error.h"
#include "minddata/mindrecord/include/shard_reader.h"
#endif
namespace mindspore {
@ -40,8 +43,8 @@ class SamplerObj : public std::enable_shared_from_this<SamplerObj> {
~SamplerObj() = default;
/// \brief Pure virtual function for derived class to implement parameters validation
/// \return bool true if all the parameters are valid
virtual bool ValidateParams() = 0;
/// \return The Status code of the function. It returns OK status if parameters are valid.
virtual Status ValidateParams() = 0;
/// \brief Pure virtual function to convert a SamplerObj class into a runtime sampler object
/// \return Shared pointers to the newly created Sampler
@ -55,12 +58,24 @@ class SamplerObj : public std::enable_shared_from_this<SamplerObj> {
/// \return The shard id of the derived sampler
virtual int64_t ShardId() { return 0; }
/// \brief Adds a child to the sampler
/// \param[in] child The sampler to be added as child
/// \return the Status code returned
Status AddChild(std::shared_ptr<SamplerObj> child);
#ifndef ENABLE_ANDROID
/// \brief Virtual function to convert a SamplerObj class into a runtime mindrecord sampler object,
/// only override by SubsetRandomSampler, PkSampler, RandomSampler, SequentialSampler, DistributedSampler
/// \return Shared pointers to the newly created Sampler
virtual std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() { return nullptr; }
#endif
protected:
/// \brief A function that calls build on the children of this sampler
/// \param[in] sampler The samplerRT object built from this sampler
void BuildChildren(std::shared_ptr<SamplerRT> sampler);
std::vector<std::shared_ptr<SamplerObj>> children_;
};
class DistributedSamplerObj;
@ -137,15 +152,19 @@ class DistributedSamplerObj : public SamplerObj {
std::shared_ptr<SamplerRT> Build() override;
std::shared_ptr<SamplerObj> Copy() override {
return std::make_shared<DistributedSamplerObj>(num_shards_, shard_id_, shuffle_, num_samples_, seed_, offset_,
even_dist_);
auto sampler = std::make_shared<DistributedSamplerObj>(num_shards_, shard_id_, shuffle_, num_samples_, seed_,
offset_, even_dist_);
for (auto child : children_) {
sampler->AddChild(child);
}
return sampler;
}
#ifndef ENABLE_ANDROID
std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
#endif
bool ValidateParams() override;
Status ValidateParams() override;
/// \brief Function to get the shard id of sampler
/// \return The shard id of sampler
@ -170,14 +189,18 @@ class PKSamplerObj : public SamplerObj {
std::shared_ptr<SamplerRT> Build() override;
std::shared_ptr<SamplerObj> Copy() override {
return std::make_shared<PKSamplerObj>(num_val_, shuffle_, num_samples_);
auto sampler = std::make_shared<PKSamplerObj>(num_val_, shuffle_, num_samples_);
for (auto child : children_) {
sampler->AddChild(child);
}
return sampler;
}
#ifndef ENABLE_ANDROID
std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
#endif
bool ValidateParams() override;
Status ValidateParams() override;
private:
int64_t num_val_;
@ -202,7 +225,7 @@ class PreBuiltSamplerObj : public SamplerObj {
std::shared_ptr<SamplerObj> Copy() override;
bool ValidateParams() override;
Status ValidateParams() override;
private:
std::shared_ptr<SamplerRT> sp_;
@ -219,13 +242,19 @@ class RandomSamplerObj : public SamplerObj {
std::shared_ptr<SamplerRT> Build() override;
std::shared_ptr<SamplerObj> Copy() override { return std::make_shared<RandomSamplerObj>(replacement_, num_samples_); }
std::shared_ptr<SamplerObj> Copy() override {
auto sampler = std::make_shared<RandomSamplerObj>(replacement_, num_samples_);
for (auto child : children_) {
sampler->AddChild(child);
}
return sampler;
}
#ifndef ENABLE_ANDROID
std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
#endif
bool ValidateParams() override;
Status ValidateParams() override;
private:
bool replacement_;
@ -241,14 +270,18 @@ class SequentialSamplerObj : public SamplerObj {
std::shared_ptr<SamplerRT> Build() override;
std::shared_ptr<SamplerObj> Copy() override {
return std::make_shared<SequentialSamplerObj>(start_index_, num_samples_);
auto sampler = std::make_shared<SequentialSamplerObj>(start_index_, num_samples_);
for (auto child : children_) {
sampler->AddChild(child);
}
return sampler;
}
#ifndef ENABLE_ANDROID
std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
#endif
bool ValidateParams() override;
Status ValidateParams() override;
private:
int64_t start_index_;
@ -264,14 +297,18 @@ class SubsetRandomSamplerObj : public SamplerObj {
std::shared_ptr<SamplerRT> Build() override;
std::shared_ptr<SamplerObj> Copy() override {
return std::make_shared<SubsetRandomSamplerObj>(indices_, num_samples_);
auto sampler = std::make_shared<SubsetRandomSamplerObj>(indices_, num_samples_);
for (auto child : children_) {
sampler->AddChild(child);
}
return sampler;
}
#ifndef ENABLE_ANDROID
std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
#endif
bool ValidateParams() override;
Status ValidateParams() override;
private:
const std::vector<int64_t> indices_;
@ -287,10 +324,14 @@ class WeightedRandomSamplerObj : public SamplerObj {
std::shared_ptr<SamplerRT> Build() override;
std::shared_ptr<SamplerObj> Copy() override {
return std::make_shared<WeightedRandomSamplerObj>(weights_, num_samples_, replacement_);
auto sampler = std::make_shared<WeightedRandomSamplerObj>(weights_, num_samples_, replacement_);
for (auto child : children_) {
sampler->AddChild(child);
}
return sampler;
}
bool ValidateParams() override;
Status ValidateParams() override;
private:
const std::vector<double> weights_;

@ -208,6 +208,37 @@ TEST_F(MindDataTestPipeline, TestDistributedSamplerSuccess) {
iter->Stop();
}
TEST_F(MindDataTestPipeline, TestSamplerAddChild) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestSamplerAddChild.";
auto sampler = DistributedSampler(1, 0, false, 5, 0, -1, true);
EXPECT_NE(sampler, nullptr);
auto child_sampler = SequentialSampler();
sampler->AddChild(child_sampler);
EXPECT_NE(child_sampler, nullptr);
// Create an ImageFolder Dataset
std::string folder_path = datasets_root_path_ + "/testPK/data/";
std::shared_ptr<Dataset> ds = ImageFolder(folder_path, false, sampler);
EXPECT_NE(ds, nullptr);
// Iterate the dataset and get each row
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_NE(iter, nullptr);
std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
iter->GetNextRow(&row);
uint64_t i = 0;
while (row.size() != 0) {
i++;
iter->GetNextRow(&row);
}
EXPECT_EQ(ds->GetDatasetSize(), 5);
iter->Stop();
}
TEST_F(MindDataTestPipeline, TestDistributedSamplerFail) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestDistributedSamplerFail.";
// Test invalid offset setting of distributed_sampler

Loading…
Cancel
Save