|
|
|
@ -62,6 +62,8 @@ class Sampler {
|
|
|
|
|
// @param int64_t samplesPerBuffer: Num of Sampler Ids to fetch via 1 GetNextBuffer call
|
|
|
|
|
explicit Sampler(int64_t num_samples, int64_t samples_per_buffer);
|
|
|
|
|
|
|
|
|
|
Sampler(const Sampler &s) : Sampler(s.num_samples_, s.samples_per_buffer_) {}
|
|
|
|
|
|
|
|
|
|
// default destructor
|
|
|
|
|
~Sampler() = default;
|
|
|
|
|
|
|
|
|
@ -77,7 +79,7 @@ class Sampler {
|
|
|
|
|
|
|
|
|
|
// for next epoch of sampleIds
|
|
|
|
|
// @return - The error code return
|
|
|
|
|
virtual Status Reset() = 0;
|
|
|
|
|
virtual Status ResetSampler() = 0;
|
|
|
|
|
|
|
|
|
|
// first handshake between leaf source op and Sampler. This func will determine the amount of data
|
|
|
|
|
// in the dataset that we can sample from.
|
|
|
|
@ -109,8 +111,16 @@ class Sampler {
|
|
|
|
|
// @return - The error code returned.
|
|
|
|
|
Status CreateSamplerTensor(std::shared_ptr<Tensor> *sample_ids, int64_t num_elements);
|
|
|
|
|
|
|
|
|
|
// A print method typically used for debugging
|
|
|
|
|
// @param out - The output stream to write output to
|
|
|
|
|
// @param show_all - A bool to control if you want to show all info or just a summary
|
|
|
|
|
virtual void Print(std::ostream &out, bool show_all) const;
|
|
|
|
|
|
|
|
|
|
// << Stream output operator overload
|
|
|
|
|
// @notes This allows you to write the debug print info using stream operators
|
|
|
|
|
// @param out - reference to the output stream being overloaded
|
|
|
|
|
// @param sampler - reference to teh sampler to print
|
|
|
|
|
// @return - the output stream must be returned
|
|
|
|
|
friend std::ostream &operator<<(std::ostream &out, const Sampler &sampler) {
|
|
|
|
|
sampler.Print(out, false);
|
|
|
|
|
return out;
|
|
|
|
|