|
|
|
@ -32,7 +32,7 @@ namespace dataset {
|
|
|
|
|
|
|
|
|
|
class Dataset;
|
|
|
|
|
class SamplerObj;
|
|
|
|
|
class NodePass;
|
|
|
|
|
class IRNodePass;
|
|
|
|
|
class DatasetSizeGetter;
|
|
|
|
|
|
|
|
|
|
// Names for non-leaf IR node
|
|
|
|
@ -182,6 +182,9 @@ class DatasetNode : public std::enable_shared_from_this<DatasetNode> {
|
|
|
|
|
/// \brief Establish the parent-child relationship between this node and its child.
|
|
|
|
|
void AddChild(std::shared_ptr<DatasetNode> child);
|
|
|
|
|
|
|
|
|
|
/// \brief Insert the input node below this node. This node's children becomes the children of the inserted node.
|
|
|
|
|
Status InsertBelow(std::shared_ptr<DatasetNode> node);
|
|
|
|
|
|
|
|
|
|
/// \brief detach this node from its parent, add its child (if any) to its parent
|
|
|
|
|
/// \return error code, return error if node has more than 1 children
|
|
|
|
|
Status Remove();
|
|
|
|
@ -190,6 +193,25 @@ class DatasetNode : public std::enable_shared_from_this<DatasetNode> {
|
|
|
|
|
/// \return True if the data of this node will be cached
|
|
|
|
|
const bool IsCached() const { return (cache_ != nullptr); }
|
|
|
|
|
|
|
|
|
|
/// \brief Check if this node is a tree
|
|
|
|
|
/// \return True if the structure is indeed a tree, i.e., no node has more than one parent
|
|
|
|
|
const bool IsTree() const;
|
|
|
|
|
|
|
|
|
|
/// \brief Check if this node is a leaf node.
|
|
|
|
|
/// \return True if this is a leaf node.
|
|
|
|
|
const bool IsLeaf() const { return children_.empty(); }
|
|
|
|
|
|
|
|
|
|
/// \brief Check if this node is a mappable dataset. Only applicable to leaf nodes
|
|
|
|
|
/// \return True if the dataset represented by this node is a mappable dataset
|
|
|
|
|
const bool IsMappable() const { return mappable_; }
|
|
|
|
|
|
|
|
|
|
/// \brief Check if this node is a descendant of an operator with cache. Currently used in leaf nodes
|
|
|
|
|
/// \return True if a cache-enabled operator is an ancestor of this node
|
|
|
|
|
const bool IsDescendantOfCache() const { return descendant_of_cache_; }
|
|
|
|
|
|
|
|
|
|
/// \brief Mark to indicate this node is a descendant of an operator with cache. Currently used in leaf nodes
|
|
|
|
|
void HasCacheAbove() { descendant_of_cache_ = true; }
|
|
|
|
|
|
|
|
|
|
/// \brief Setter function for runtime number of workers
|
|
|
|
|
/// \param[in] num_workers The number of threads in this operator
|
|
|
|
|
/// \return Shared pointer to the original object
|
|
|
|
@ -203,7 +225,7 @@ class DatasetNode : public std::enable_shared_from_this<DatasetNode> {
|
|
|
|
|
return std::static_pointer_cast<Derived>(shared_from_this());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// \brief Base method for NodePass visit. A tree walk consists of walking down the tree and also walking back up
|
|
|
|
|
/// \brief Base method for IRNodePass visit. A tree walk consists of walking down the tree and also walking back up
|
|
|
|
|
/// in a depth-first order. Accept is the node visit on the way down, whereas AcceptAfter is the node
|
|
|
|
|
/// visit on the way back up the tree after its descendants are visited.
|
|
|
|
|
/// \notes Subclass needs to override this if it requires special node visit access.
|
|
|
|
@ -211,15 +233,15 @@ class DatasetNode : public std::enable_shared_from_this<DatasetNode> {
|
|
|
|
|
/// \param[in] p The node to visit
|
|
|
|
|
/// \param[out] modified Indicator if the node was modified
|
|
|
|
|
/// \return Status of the node visit
|
|
|
|
|
virtual Status Accept(NodePass *p, bool *modified);
|
|
|
|
|
virtual Status Accept(IRNodePass *p, bool *modified);
|
|
|
|
|
|
|
|
|
|
/// \brief Base method for NodePass visit on the way back up the tree after its descendants are visited.
|
|
|
|
|
/// \brief Base method for IRNodePass visit on the way back up the tree after its descendants are visited.
|
|
|
|
|
/// \notes Subclass needs to override this if it requires special node visit access.
|
|
|
|
|
/// Check "dataset/engine/opt/pass.h" for more details.
|
|
|
|
|
/// \param[in] p The node to visit
|
|
|
|
|
/// \param[out] modified Indicator if the node was modified
|
|
|
|
|
/// \return Status of the node visit
|
|
|
|
|
virtual Status AcceptAfter(NodePass *p, bool *modified);
|
|
|
|
|
virtual Status AcceptAfter(IRNodePass *p, bool *modified);
|
|
|
|
|
|
|
|
|
|
virtual bool IsSizeDefined() { return true; }
|
|
|
|
|
|
|
|
|
@ -235,55 +257,22 @@ class DatasetNode : public std::enable_shared_from_this<DatasetNode> {
|
|
|
|
|
std::string PrintColumns(const std::vector<std::string> &columns) const;
|
|
|
|
|
Status AddCacheOp(std::vector<std::shared_ptr<DatasetOp>> *node_ops);
|
|
|
|
|
void PrintNode(std::ostream &out, int *level) const;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
// SourceNode represents the leaf nodes of a pipeline where the data is pulled into.
|
|
|
|
|
class SourceNode : public DatasetNode {
|
|
|
|
|
public:
|
|
|
|
|
/// \brief Constructor
|
|
|
|
|
SourceNode() : DatasetNode() {}
|
|
|
|
|
|
|
|
|
|
/// \brief Constructor that initializes the cache
|
|
|
|
|
/// \param dataset_cache DatasetCache
|
|
|
|
|
explicit SourceNode(const std::shared_ptr<DatasetCache> &dataset_cache) : DatasetNode(dataset_cache) {}
|
|
|
|
|
|
|
|
|
|
/// \brief Destructor
|
|
|
|
|
~SourceNode() = default;
|
|
|
|
|
|
|
|
|
|
/// \brief Node name getter
|
|
|
|
|
/// \return Name of the current node
|
|
|
|
|
virtual std::string Name() const = 0;
|
|
|
|
|
|
|
|
|
|
/// \brief Base-class override for accepting NodePass visitor
|
|
|
|
|
/// \param[in] p The node to visit
|
|
|
|
|
/// \param[out] modified Indicator if the node was modified
|
|
|
|
|
/// \return Status of the node visit
|
|
|
|
|
Status Accept(NodePass *p, bool *modified) override;
|
|
|
|
|
|
|
|
|
|
/// \brief Base-class override for accepting NodePass visitor
|
|
|
|
|
/// \param[in] p The node to visit
|
|
|
|
|
/// \param[out] modified Indicator if the node was modified
|
|
|
|
|
/// \return Status of the node visit
|
|
|
|
|
Status AcceptAfter(NodePass *p, bool *modified) override;
|
|
|
|
|
|
|
|
|
|
/// \brief Check if this node is a mappable dataset. Only applicable to leaf nodes
|
|
|
|
|
/// \return True if the dataset represented by this node is a mappable dataset
|
|
|
|
|
const bool IsMappable() const { return mappable_; }
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
bool mappable_;
|
|
|
|
|
bool descendant_of_cache_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
// MappableSourceNode represents the leaf nodes that can be randomly accessed with indexes.
|
|
|
|
|
class MappableSourceNode : public SourceNode {
|
|
|
|
|
class MappableSourceNode : public DatasetNode {
|
|
|
|
|
public:
|
|
|
|
|
/// \brief Constructor
|
|
|
|
|
MappableSourceNode() : SourceNode() { mappable_ = true; }
|
|
|
|
|
MappableSourceNode() : DatasetNode() { mappable_ = true; }
|
|
|
|
|
|
|
|
|
|
/// \brief Constructor that initializes the cache
|
|
|
|
|
/// \param dataset_cache DatasetCache
|
|
|
|
|
explicit MappableSourceNode(const std::shared_ptr<DatasetCache> &dataset_cache) : SourceNode(dataset_cache) {
|
|
|
|
|
explicit MappableSourceNode(const std::shared_ptr<DatasetCache> &dataset_cache) : DatasetNode(dataset_cache) {
|
|
|
|
|
mappable_ = true;
|
|
|
|
|
// Initially set to false, and set to true by the optimizer when conditions are met.
|
|
|
|
|
descendant_of_cache_ = false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// \brief Destructor
|
|
|
|
@ -295,15 +284,17 @@ class MappableSourceNode : public SourceNode {
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
// NonMappableSourceNode represents the leaf nodes that can not be randomly accessed.
|
|
|
|
|
class NonMappableSourceNode : public SourceNode {
|
|
|
|
|
class NonMappableSourceNode : public DatasetNode {
|
|
|
|
|
public:
|
|
|
|
|
/// \brief Constructor
|
|
|
|
|
NonMappableSourceNode() : SourceNode() { mappable_ = false; }
|
|
|
|
|
NonMappableSourceNode() : DatasetNode() { mappable_ = false; }
|
|
|
|
|
|
|
|
|
|
/// \brief Constructor that initializes the cache
|
|
|
|
|
/// \param dataset_cache DatasetCache
|
|
|
|
|
explicit NonMappableSourceNode(const std::shared_ptr<DatasetCache> &dataset_cache) : SourceNode(dataset_cache) {
|
|
|
|
|
explicit NonMappableSourceNode(const std::shared_ptr<DatasetCache> &dataset_cache) : DatasetNode(dataset_cache) {
|
|
|
|
|
mappable_ = false;
|
|
|
|
|
// Initially set to false, and set to true by the optimizer when conditions are met.
|
|
|
|
|
descendant_of_cache_ = false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// \brief Destructor
|
|
|
|
@ -313,34 +304,6 @@ class NonMappableSourceNode : public SourceNode {
|
|
|
|
|
/// \return Name of the current node
|
|
|
|
|
virtual std::string Name() const = 0;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
// NonLeafNode represents operations over data in a pipeline.
|
|
|
|
|
class NonLeafNode : public DatasetNode {
|
|
|
|
|
public:
|
|
|
|
|
/// \brief Constructor
|
|
|
|
|
NonLeafNode() = default;
|
|
|
|
|
|
|
|
|
|
/// \brief Destructor
|
|
|
|
|
~NonLeafNode() = default;
|
|
|
|
|
|
|
|
|
|
/// \brief Node name getter
|
|
|
|
|
/// \return Name of the current node
|
|
|
|
|
virtual std::string Name() const = 0;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
// SinkNode represents the end node of a pipeline where the data is pushed out
|
|
|
|
|
class SinkNode : public DatasetNode {
|
|
|
|
|
public:
|
|
|
|
|
/// \brief Constructor
|
|
|
|
|
SinkNode() = default;
|
|
|
|
|
|
|
|
|
|
/// \brief Destructor
|
|
|
|
|
~SinkNode() = default;
|
|
|
|
|
|
|
|
|
|
/// \brief Node name getter
|
|
|
|
|
/// \return Name of the current node
|
|
|
|
|
virtual std::string Name() const = 0;
|
|
|
|
|
};
|
|
|
|
|
} // namespace dataset
|
|
|
|
|
} // namespace mindspore
|
|
|
|
|
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_DATASET_NODE_H_
|
|
|
|
|