|
|
|
@ -26,6 +26,7 @@
|
|
|
|
|
#include "minddata/dataset/core/tensor.h"
|
|
|
|
|
#include "minddata/dataset/engine/data_schema.h"
|
|
|
|
|
#include "minddata/dataset/engine/datasetops/pipeline_op.h"
|
|
|
|
|
#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
|
|
|
|
|
#include "minddata/dataset/util/wait_post.h"
|
|
|
|
|
#include "pybind11/pybind11.h"
|
|
|
|
|
|
|
|
|
@ -35,47 +36,47 @@ namespace mindspore {
|
|
|
|
|
namespace dataset {
|
|
|
|
|
#pragma GCC visibility push(hidden)
|
|
|
|
|
|
|
|
|
|
class GeneratorOp : public PipelineOp {
|
|
|
|
|
class GeneratorOp : public PipelineOp, public RandomAccessOp {
|
|
|
|
|
public:
|
|
|
|
|
class Builder {
|
|
|
|
|
public:
|
|
|
|
|
// Builder constructor. Creates the builder object.
|
|
|
|
|
// @note No default args
|
|
|
|
|
// @return This is a constructor.
|
|
|
|
|
/// Builder constructor. Creates the builder object.
|
|
|
|
|
/// \note No default args
|
|
|
|
|
/// \return This is a constructor.
|
|
|
|
|
Builder();
|
|
|
|
|
|
|
|
|
|
~Builder() = default;
|
|
|
|
|
|
|
|
|
|
// Setter method.
|
|
|
|
|
// @return Builder setter method returns reference to the builder.
|
|
|
|
|
/// Setter method.
|
|
|
|
|
/// \return Builder setter method returns reference to the builder.
|
|
|
|
|
Builder &SetGeneratorFunction(py::function generator_function) {
  // Copy the Python callable handle into the builder's pending configuration.
  this->build_generator_function_ = generator_function;
  // Hand back the builder itself so setter calls can be chained.
  return *this;
}
|
|
|
|
|
|
|
|
|
|
// Setter method.
|
|
|
|
|
// @return Builder setter method returns reference to the builder.
|
|
|
|
|
/// Setter method.
|
|
|
|
|
/// \return Builder setter method returns reference to the builder.
|
|
|
|
|
Builder &SetColumnNames(const std::vector<std::string> &column_names) {
  // Copy the caller's list of column names into the builder's pending configuration.
  this->build_column_names_ = column_names;
  // Hand back the builder itself so setter calls can be chained.
  return *this;
}
|
|
|
|
|
|
|
|
|
|
// Setter method.
|
|
|
|
|
// @return Builder setter method returns reference to the builder.
|
|
|
|
|
/// Setter method.
|
|
|
|
|
/// \return Builder setter method returns reference to the builder.
|
|
|
|
|
Builder &SetColumnTypes(const std::vector<DataType> &column_types) {
  // Copy the caller's list of column data types into the builder's pending configuration.
  this->build_column_types_ = column_types;
  // Hand back the builder itself so setter calls can be chained.
  return *this;
}
|
|
|
|
|
|
|
|
|
|
// Setter method.
|
|
|
|
|
// @return Builder setter method returns reference to the builder.
|
|
|
|
|
/// Setter method.
|
|
|
|
|
/// \return Builder setter method returns reference to the builder.
|
|
|
|
|
Builder &SetPrefetchSize(int32_t prefetch_size) {
  // Record the requested prefetch size in the builder's pending configuration.
  this->build_prefetch_size_ = prefetch_size;
  // Hand back the builder itself so setter calls can be chained.
  return *this;
}
|
|
|
|
|
|
|
|
|
|
// The builder "build" method creates the final object.
|
|
|
|
|
// @return shared_ptr to the new GeneratorOp object
|
|
|
|
|
/// The builder "build" method creates the final object.
|
|
|
|
|
/// \return shared_ptr to the new GeneratorOp object
|
|
|
|
|
Status Build(std::shared_ptr<GeneratorOp> *);
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
@ -94,56 +95,53 @@ class GeneratorOp : public PipelineOp {
|
|
|
|
|
|
|
|
|
|
GeneratorOp(py::function generator_function, std::vector<std::string> column_names,
|
|
|
|
|
std::vector<DataType> column_types, int32_t prefetch_size, int32_t buffer_size, int32_t connector_size,
|
|
|
|
|
int64_t pre_counter_size = 0);
|
|
|
|
|
std::shared_ptr<SamplerRT> sampler);
|
|
|
|
|
|
|
|
|
|
~GeneratorOp();
|
|
|
|
|
~GeneratorOp() = default;
|
|
|
|
|
|
|
|
|
|
// A print method typically used for debugging
|
|
|
|
|
// @param out - The output stream to write output to
|
|
|
|
|
// @param show_all - A bool to control if you want to show all info or just a summary
|
|
|
|
|
/// A print method typically used for debugging
|
|
|
|
|
/// \param out - The output stream to write output to
|
|
|
|
|
/// \param show_all - A bool to control if you want to show all info or just a summary
|
|
|
|
|
void Print(std::ostream &out, bool show_all) const override;
|
|
|
|
|
|
|
|
|
|
// << Stream output operator overload
|
|
|
|
|
// @notes This allows you to write the debug print info using stream operators
|
|
|
|
|
// @param out - reference to the output stream being overloaded
|
|
|
|
|
// @param generator_op - reference to the GeneratorOp to display
|
|
|
|
|
// @return - the output stream must be returned
|
|
|
|
|
/// << Stream output operator overload
|
|
|
|
|
/// \note This allows you to write the debug print info using stream operators
|
|
|
|
|
/// \param out - reference to the output stream being overloaded
|
|
|
|
|
/// \param generator_op - reference to the GeneratorOp to display
|
|
|
|
|
/// \return - the output stream must be returned
|
|
|
|
|
friend std::ostream &operator<<(std::ostream &out, const GeneratorOp &generator_op) {
  // Streaming always requests the summary form: Print's show_all argument is fixed to false.
  constexpr bool kShowAll = false;
  generator_op.Print(out, kShowAll);
  // Return the same stream so further << calls can be chained.
  return out;
}
|
|
|
|
|
|
|
|
|
|
// Class functor operator () override.
|
|
|
|
|
// All DatasetOps operate by launching a thread (see ExecutionTree). This class functor will
|
|
|
|
|
// provide the master loop that drives the logic for performing the work.
|
|
|
|
|
// @return Status The status code returned
|
|
|
|
|
/// Class functor operator () override.
|
|
|
|
|
/// All DatasetOps operate by launching a thread (see ExecutionTree). This class functor will
|
|
|
|
|
/// provide the master loop that drives the logic for performing the work.
|
|
|
|
|
/// \return Status The status code returned
|
|
|
|
|
Status operator()() override;
|
|
|
|
|
|
|
|
|
|
// Overrides base class reset method. When an operator does a reset, it cleans up any state
|
|
|
|
|
// info from its previous execution and then initializes itself so that it can be executed
|
|
|
|
|
// again.
|
|
|
|
|
// @return Status The status code returned
|
|
|
|
|
/// Overrides base class reset method. When an operator does a reset, it cleans up any state
|
|
|
|
|
/// info from its previous execution and then initializes itself so that it can be executed
|
|
|
|
|
/// again.
|
|
|
|
|
/// \return Status The status code returned
|
|
|
|
|
Status Reset() override;
|
|
|
|
|
|
|
|
|
|
// Base-class override for NodePass visitor acceptor.
|
|
|
|
|
// @param p - Pointer to the NodePass to be accepted.
|
|
|
|
|
// @param modified - Whether this node visit modified the pipeline.
|
|
|
|
|
// @return - Status of the node visit.
|
|
|
|
|
/// Base-class override for NodePass visitor acceptor.
|
|
|
|
|
/// \param p - Pointer to the NodePass to be accepted.
|
|
|
|
|
/// \param modified - Whether this node visit modified the pipeline.
|
|
|
|
|
/// \return - Status of the node visit.
|
|
|
|
|
Status Accept(NodePass *p, bool *const modified) override;
|
|
|
|
|
|
|
|
|
|
// Op name getter
|
|
|
|
|
// @return Name of the current Op
|
|
|
|
|
/// Op name getter
|
|
|
|
|
/// \return Name of the current Op
|
|
|
|
|
std::string Name() const override {
  // Fixed identifier string for this operator type.
  return std::string("GeneratorOp");
}
|
|
|
|
|
|
|
|
|
|
Status Init();
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
py::function generator_function_;
|
|
|
|
|
std::vector<std::string> column_names_;
|
|
|
|
|
std::vector<DataType> column_types_;
|
|
|
|
|
int32_t prefetch_size_;
|
|
|
|
|
int32_t buffer_size_;
|
|
|
|
|
int64_t pre_counter_size_;
|
|
|
|
|
int64_t generator_counter_;
|
|
|
|
|
|
|
|
|
|
py::object generator_;
|
|
|
|
@ -151,15 +149,25 @@ class GeneratorOp : public PipelineOp {
|
|
|
|
|
|
|
|
|
|
WaitPost wp_;
|
|
|
|
|
|
|
|
|
|
void Dealloc() noexcept;
|
|
|
|
|
|
|
|
|
|
Status PyRowToTensorRow(py::object py_data, TensorRow *tensor_row);
|
|
|
|
|
|
|
|
|
|
Status FillBuffer(TensorQTable *tt);
|
|
|
|
|
|
|
|
|
|
// Private function for computing the assignment of the column name map.
|
|
|
|
|
// @return - Status
|
|
|
|
|
/// Private function for computing the assignment of the column name map.
|
|
|
|
|
/// \return - Status
|
|
|
|
|
Status ComputeColMap() override;
|
|
|
|
|
|
|
|
|
|
/// Initialize Sampler, calls sampler->Init() within
|
|
|
|
|
/// \return Status The status code returned
|
|
|
|
|
Status InitSampler();
|
|
|
|
|
|
|
|
|
|
/// Create new Generator object from the generator function
|
|
|
|
|
/// \return Status The status code returned
|
|
|
|
|
Status CreateGeneratorObject();
|
|
|
|
|
|
|
|
|
|
/// Initialize GeneratorOp
|
|
|
|
|
/// \return Status The status code returned
|
|
|
|
|
Status Init();
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
#pragma GCC visibility pop
|
|
|
|
|