|
|
|
@ -22,7 +22,10 @@
|
|
|
|
|
#include <vector>
|
|
|
|
|
|
|
|
|
|
#ifndef ENABLE_ANDROID
|
|
|
|
|
#include "minddata/dataset/engine/datasetops/source/mindrecord_op.h"
|
|
|
|
|
#include "minddata/dataset/util/status.h"
|
|
|
|
|
#include "minddata/mindrecord/include/shard_column.h"
|
|
|
|
|
#include "minddata/mindrecord/include/shard_error.h"
|
|
|
|
|
#include "minddata/mindrecord/include/shard_reader.h"
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
namespace mindspore {
|
|
|
|
@ -40,8 +43,8 @@ class SamplerObj : public std::enable_shared_from_this<SamplerObj> {
|
|
|
|
|
// Virtual because SamplerObj is a polymorphic base class (it declares pure virtual functions and is
// handed around as std::shared_ptr<SamplerObj>), so destroying a derived sampler through a base
// pointer must run the derived destructor.
virtual ~SamplerObj() = default;
|
|
|
|
|
|
|
|
|
|
/// \brief Pure virtual function for derived class to implement parameters validation
|
|
|
|
|
/// \return bool true if all the parameters are valid
|
|
|
|
|
virtual bool ValidateParams() = 0;
|
|
|
|
|
/// \return The Status code of the function. It returns OK status if parameters are valid.
|
|
|
|
|
virtual Status ValidateParams() = 0;
|
|
|
|
|
|
|
|
|
|
/// \brief Pure virtual function to convert a SamplerObj class into a runtime sampler object
|
|
|
|
|
/// \return Shared pointers to the newly created Sampler
|
|
|
|
@ -55,12 +58,24 @@ class SamplerObj : public std::enable_shared_from_this<SamplerObj> {
|
|
|
|
|
/// \return The shard id of the derived sampler
|
|
|
|
|
virtual int64_t ShardId() { return 0; }
|
|
|
|
|
|
|
|
|
|
/// \brief Adds a child to the sampler
|
|
|
|
|
/// \param[in] child The sampler to be added as child
|
|
|
|
|
/// \return The Status code of the function
|
|
|
|
|
Status AddChild(std::shared_ptr<SamplerObj> child);
|
|
|
|
|
|
|
|
|
|
#ifndef ENABLE_ANDROID
|
|
|
|
|
/// \brief Virtual function to convert a SamplerObj class into a runtime mindrecord sampler object,
|
|
|
|
|
/// only override by SubsetRandomSampler, PkSampler, RandomSampler, SequentialSampler, DistributedSampler
|
|
|
|
|
/// \return Shared pointers to the newly created Sampler
|
|
|
|
|
virtual std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() { return nullptr; }
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
/// \brief A function that calls build on the children of this sampler
|
|
|
|
|
/// \param[in] sampler The samplerRT object built from this sampler
|
|
|
|
|
void BuildChildren(std::shared_ptr<SamplerRT> sampler);
|
|
|
|
|
|
|
|
|
|
std::vector<std::shared_ptr<SamplerObj>> children_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class DistributedSamplerObj;
|
|
|
|
@ -137,15 +152,19 @@ class DistributedSamplerObj : public SamplerObj {
|
|
|
|
|
std::shared_ptr<SamplerRT> Build() override;
|
|
|
|
|
|
|
|
|
|
std::shared_ptr<SamplerObj> Copy() override {
|
|
|
|
|
return std::make_shared<DistributedSamplerObj>(num_shards_, shard_id_, shuffle_, num_samples_, seed_, offset_,
|
|
|
|
|
even_dist_);
|
|
|
|
|
auto sampler = std::make_shared<DistributedSamplerObj>(num_shards_, shard_id_, shuffle_, num_samples_, seed_,
|
|
|
|
|
offset_, even_dist_);
|
|
|
|
|
for (auto child : children_) {
|
|
|
|
|
sampler->AddChild(child);
|
|
|
|
|
}
|
|
|
|
|
return sampler;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#ifndef ENABLE_ANDROID
|
|
|
|
|
std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
bool ValidateParams() override;
|
|
|
|
|
Status ValidateParams() override;
|
|
|
|
|
|
|
|
|
|
/// \brief Function to get the shard id of sampler
|
|
|
|
|
/// \return The shard id of sampler
|
|
|
|
@ -170,14 +189,18 @@ class PKSamplerObj : public SamplerObj {
|
|
|
|
|
std::shared_ptr<SamplerRT> Build() override;
|
|
|
|
|
|
|
|
|
|
std::shared_ptr<SamplerObj> Copy() override {
|
|
|
|
|
return std::make_shared<PKSamplerObj>(num_val_, shuffle_, num_samples_);
|
|
|
|
|
auto sampler = std::make_shared<PKSamplerObj>(num_val_, shuffle_, num_samples_);
|
|
|
|
|
for (auto child : children_) {
|
|
|
|
|
sampler->AddChild(child);
|
|
|
|
|
}
|
|
|
|
|
return sampler;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#ifndef ENABLE_ANDROID
|
|
|
|
|
std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
bool ValidateParams() override;
|
|
|
|
|
Status ValidateParams() override;
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
int64_t num_val_;
|
|
|
|
@ -202,7 +225,7 @@ class PreBuiltSamplerObj : public SamplerObj {
|
|
|
|
|
|
|
|
|
|
std::shared_ptr<SamplerObj> Copy() override;
|
|
|
|
|
|
|
|
|
|
bool ValidateParams() override;
|
|
|
|
|
Status ValidateParams() override;
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
std::shared_ptr<SamplerRT> sp_;
|
|
|
|
@ -219,13 +242,19 @@ class RandomSamplerObj : public SamplerObj {
|
|
|
|
|
|
|
|
|
|
std::shared_ptr<SamplerRT> Build() override;
|
|
|
|
|
|
|
|
|
|
std::shared_ptr<SamplerObj> Copy() override { return std::make_shared<RandomSamplerObj>(replacement_, num_samples_); }
|
|
|
|
|
std::shared_ptr<SamplerObj> Copy() override {
|
|
|
|
|
auto sampler = std::make_shared<RandomSamplerObj>(replacement_, num_samples_);
|
|
|
|
|
for (auto child : children_) {
|
|
|
|
|
sampler->AddChild(child);
|
|
|
|
|
}
|
|
|
|
|
return sampler;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#ifndef ENABLE_ANDROID
|
|
|
|
|
std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
bool ValidateParams() override;
|
|
|
|
|
Status ValidateParams() override;
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
bool replacement_;
|
|
|
|
@ -241,14 +270,18 @@ class SequentialSamplerObj : public SamplerObj {
|
|
|
|
|
std::shared_ptr<SamplerRT> Build() override;
|
|
|
|
|
|
|
|
|
|
std::shared_ptr<SamplerObj> Copy() override {
|
|
|
|
|
return std::make_shared<SequentialSamplerObj>(start_index_, num_samples_);
|
|
|
|
|
auto sampler = std::make_shared<SequentialSamplerObj>(start_index_, num_samples_);
|
|
|
|
|
for (auto child : children_) {
|
|
|
|
|
sampler->AddChild(child);
|
|
|
|
|
}
|
|
|
|
|
return sampler;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#ifndef ENABLE_ANDROID
|
|
|
|
|
std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
bool ValidateParams() override;
|
|
|
|
|
Status ValidateParams() override;
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
int64_t start_index_;
|
|
|
|
@ -264,14 +297,18 @@ class SubsetRandomSamplerObj : public SamplerObj {
|
|
|
|
|
std::shared_ptr<SamplerRT> Build() override;
|
|
|
|
|
|
|
|
|
|
std::shared_ptr<SamplerObj> Copy() override {
|
|
|
|
|
return std::make_shared<SubsetRandomSamplerObj>(indices_, num_samples_);
|
|
|
|
|
auto sampler = std::make_shared<SubsetRandomSamplerObj>(indices_, num_samples_);
|
|
|
|
|
for (auto child : children_) {
|
|
|
|
|
sampler->AddChild(child);
|
|
|
|
|
}
|
|
|
|
|
return sampler;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#ifndef ENABLE_ANDROID
|
|
|
|
|
std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
bool ValidateParams() override;
|
|
|
|
|
Status ValidateParams() override;
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
const std::vector<int64_t> indices_;
|
|
|
|
@ -287,10 +324,14 @@ class WeightedRandomSamplerObj : public SamplerObj {
|
|
|
|
|
std::shared_ptr<SamplerRT> Build() override;
|
|
|
|
|
|
|
|
|
|
std::shared_ptr<SamplerObj> Copy() override {
|
|
|
|
|
return std::make_shared<WeightedRandomSamplerObj>(weights_, num_samples_, replacement_);
|
|
|
|
|
auto sampler = std::make_shared<WeightedRandomSamplerObj>(weights_, num_samples_, replacement_);
|
|
|
|
|
for (auto child : children_) {
|
|
|
|
|
sampler->AddChild(child);
|
|
|
|
|
}
|
|
|
|
|
return sampler;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool ValidateParams() override;
|
|
|
|
|
Status ValidateParams() override;
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
const std::vector<double> weights_;
|
|
|
|
|