Migrate 3 pre passes to IR optimizer, namely, cache_error_pass, epoch_injection, and removal_pass

pull/9188/head
Nat Sutyanyong 4 years ago
parent 73c91e05b1
commit d69a29a44e

@ -574,7 +574,7 @@ Status DatasetSizeGetter::DryRun(std::shared_ptr<DatasetNode> ir_node, int64_t *
std::make_unique<GetterPass>(static_cast<GetterPass::GetterType>(GetterPass::GetterType::kDatasetSize)));
return pre;
});
RETURN_IF_NOT_OK(tree_adapter->Compile(std::move(ir_node), 1));
RETURN_IF_NOT_OK(tree_adapter->Compile(ir_node, 1));
TensorRow row;
RETURN_IF_NOT_OK(GetRow(tree_adapter, &row));
int64_t row_cnt = 0;

@ -214,7 +214,7 @@ Status ExecutionTree::LaunchWorkers(int32_t num_workers, std::function<Status(ui
// The driver of the prepare phase of the execution tree.
// Prepare phase consists of three sub phases
//
// 1. PrepareTreePreAction()
// 1. PreAction()
// Compulsory transformation/action pre optimization.
// For example, CacheOp Insertion
//
@ -222,41 +222,44 @@ Status ExecutionTree::LaunchWorkers(int32_t num_workers, std::function<Status(ui
// Optimization transformation/action, optional
// For example, MapOp Fusion
//
// 3. PrepareTreePostAction()
// 3. PostAction()
// Compulsory transformation/action post optimization.
// For example, repeatOp inlining
//
// @return Status - The error code return
Status ExecutionTree::Prepare(int32_t num_epochs) {
Status ExecutionTree::Prepare(int32_t num_epochs, bool partial) {
num_epochs_ = num_epochs;
partially_prepare_ = partial;
// Pre optimization compulsory transformation
RETURN_IF_NOT_OK(this->PrepareTreePreAction());
RETURN_IF_NOT_OK(this->PreAction());
// If optional optimizations are enabled
if (optimize_) {
RETURN_IF_NOT_OK(this->Optimize());
}
// Post optimization compulsory transformation
RETURN_IF_NOT_OK(this->PrepareTreePostAction());
RETURN_IF_NOT_OK(this->PostAction());
// The tree is ready to be prepared.
tree_state_ = kDeTStatePrepare;
// Existing transformation implementation, will be removed later
RETURN_IF_NOT_OK(this->PrepareDeprecated());
return Status::OK();
}
Status ExecutionTree::PrepareTreePreAction() {
Status ExecutionTree::PreAction() {
bool modified = false;
std::vector<std::unique_ptr<Pass>> pre_actions;
// Construct pre actions
if (!partially_prepare_) {
#ifndef ENABLE_ANDROID
pre_actions.push_back(std::make_unique<CacheErrorPass>());
#endif
pre_actions.push_back(std::make_unique<EpochInjectionPass>());
pre_actions.push_back(std::make_unique<RemovalPass>());
#ifndef ENABLE_ANDROID
pre_actions.push_back(std::make_unique<CacheTransformPass>());
pre_actions.push_back(std::make_unique<CacheErrorPass>());
#endif
pre_actions.push_back(std::make_unique<EpochInjectionPass>());
pre_actions.push_back(std::make_unique<RemovalPass>());
}
// this offers a way to override the preset optimization pass with customized ones
// this is used when certain nodes are removed for tree getters
@ -276,15 +279,17 @@ Status ExecutionTree::PrepareTreePreAction() {
return Status::OK();
}
Status ExecutionTree::PrepareTreePostAction() {
// The tree is ready to be prepared.
tree_state_ = kDeTStatePrepare;
Status ExecutionTree::PostAction() {
bool modified = false;
OptPass post_actions;
// Construct pre actions
MS_LOG(INFO) << "Running post pass loops.";
#ifndef ENABLE_ANDROID
// Calling CacheErrorPass again. This is a temporary fix until the TensorOperation is properly done in Pybind.
// The IR version cannot detect an invalid case of a cache on Map with random tensor operation from Python API.
// This is because Python API binding to TensorOperation is still in progress.
post_actions.push_back(std::make_unique<CacheErrorPass>());
post_actions.push_back(std::make_unique<CacheTransformPass>());
post_actions.push_back(std::make_unique<RepeatPass>());
#endif
@ -340,9 +345,6 @@ Status ExecutionTree::PrepareDeprecated() {
// Recursive function used during prepare phase to visit a node and drive any pre- and post-
// node actions during a tree walk.
Status ExecutionTree::PrepareNode(const std::shared_ptr<DatasetOp> &dataset_op) {
// execute PreAction
RETURN_IF_NOT_OK(dataset_op->PrepareNodePreAction());
// Before going down into children, make any prepare flags updates based on this operator.
uint32_t op_prep_flags = dataset_op->PrepareFlags();
BitSet(&prepare_flags_, op_prep_flags);

@ -169,7 +169,7 @@ class ExecutionTree {
// The driver of the prepare phase of the execution tree.
// Prepare phase consists of three sub phases
//
// 1. PrepareTreePreAction()
// 1. PreAction()
// Compulsory transformation/action pre optimization.
// For example, CacheOp Insertion
//
@ -177,20 +177,20 @@ class ExecutionTree {
// Optimization transformation/action, optional
// For example, MapOp Fusion
//
// 3. PrepareTreePostAction()
// 3. PostAction()
// Compulsory transformation/action post optimization.
// For example, repeatOp inlining
//
// @return Status - The error code return
Status Prepare(int num_epochs = -1);
Status Prepare(int num_epochs = -1, bool partial = false);
// Compulsory transformation/action pre optimization.
// @return Status - The error code return
Status PrepareTreePreAction();
Status PreAction();
// Compulsory transformation/action post optimization.
// @return Status - The error code return
Status PrepareTreePostAction();
Status PostAction();
// Optimization transformation/action, optional.
// @return Status - The error code return
@ -281,6 +281,7 @@ class ExecutionTree {
std::unique_ptr<ProfilingManager> profiling_manager_; // Profiling manager
bool optimize_; // Flag to enable optional optimizations
std::function<OptPass(OptPass)> pre_pass_override_; // function ptr that overrides pre pass, called in PrePrepare()
bool partially_prepare_; // Temp: during migration to IR, if true, run remaining passes.
};
} // namespace dataset
} // namespace mindspore

@ -23,6 +23,7 @@
#include <vector>
#include "minddata/dataset/engine/datasetops/batch_op.h"
#include "minddata/dataset/engine/opt/pass.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
@ -139,5 +140,16 @@ Status BatchNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_
return Status::OK();
}
// Visitor accepting method for IRNodePass
Status BatchNode::Accept(IRNodePass *p, bool *modified) {
// Downcast shared pointer then call visitor
return p->Visit(shared_from_base<BatchNode>(), modified);
}
// Visitor accepting method for IRNodePass
Status BatchNode::AcceptAfter(IRNodePass *p, bool *modified) {
// Downcast shared pointer then call visitor
return p->VisitAfter(shared_from_base<BatchNode>(), modified);
}
} // namespace dataset
} // namespace mindspore

@ -74,6 +74,18 @@ class BatchNode : public DatasetNode {
Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
int64_t *dataset_size) override;
/// \brief Base-class override for accepting IRNodePass visitor
/// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status Accept(IRNodePass *p, bool *modified) override;
/// \brief Base-class override for accepting IRNodePass visitor
/// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status AcceptAfter(IRNodePass *p, bool *modified) override;
private:
int32_t batch_size_;
bool drop_remainder_;

@ -46,12 +46,40 @@ BucketBatchByLengthNode::BucketBatchByLengthNode(
std::shared_ptr<DatasetNode> BucketBatchByLengthNode::Copy() {
auto node = std::make_shared<BucketBatchByLengthNode>(nullptr, column_names_, bucket_boundaries_, bucket_batch_sizes_,
element_length_function_, pad_info_, pad_to_bucket_boundary_);
element_length_function_, pad_info_, pad_to_bucket_boundary_,
drop_remainder_);
return node;
}
void BucketBatchByLengthNode::Print(std::ostream &out) const {
out << Name() + "(columns:" + PrintColumns(column_names_) + ",...)";
out << Name() + "(columns:" + PrintColumns(column_names_);
int i = 0;
for (auto it : bucket_boundaries_) {
if (i == 0) {
out << ",bucket_boundaries:{";
}
out << it;
if (i < bucket_boundaries_.size() - 1) {
out << ",";
} else {
out << "}";
}
i++;
}
i = 0;
for (auto it : bucket_batch_sizes_) {
if (i == 0) {
out << ",bucket_batch_sizes:{";
}
out << it;
if (i < bucket_batch_sizes_.size() - 1) {
out << ",";
} else {
out << "}";
}
i++;
}
out << ")";
}
Status BucketBatchByLengthNode::Build(std::vector<std::shared_ptr<DatasetOp>> *node_ops) {

@ -90,14 +90,14 @@ Status BuildSentenceVocabNode::ValidateParams() {
return Status::OK();
}
// Visitor accepting method for NodePass
Status BuildSentenceVocabNode::Accept(NodePass *p, bool *modified) {
// Visitor accepting method for IRNodePass
Status BuildSentenceVocabNode::Accept(IRNodePass *p, bool *modified) {
// Downcast shared pointer then call visitor
return p->Visit(shared_from_base<BuildSentenceVocabNode>(), modified);
}
// Visitor accepting method for NodePass
Status BuildSentenceVocabNode::AcceptAfter(NodePass *p, bool *modified) {
// Visitor accepting method for IRNodePass
Status BuildSentenceVocabNode::AcceptAfter(IRNodePass *p, bool *modified) {
// Downcast shared pointer then call visitor
return p->VisitAfter(shared_from_base<BuildSentenceVocabNode>(), modified);
}

@ -59,17 +59,17 @@ class BuildSentenceVocabNode : public DatasetNode {
/// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override;
/// \brief Base-class override for accepting NodePass visitor
/// \brief Base-class override for accepting IRNodePass visitor
/// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status Accept(NodePass *p, bool *modified) override;
Status Accept(IRNodePass *p, bool *modified) override;
/// \brief Base-class override for accepting NodePass visitor
/// \brief Base-class override for accepting IRNodePass visitor
/// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status AcceptAfter(NodePass *p, bool *modified) override;
Status AcceptAfter(IRNodePass *p, bool *modified) override;
private:
std::shared_ptr<SentencePieceVocab> vocab_;

@ -85,14 +85,14 @@ Status BuildVocabNode::ValidateParams() {
return Status::OK();
}
// Visitor accepting method for NodePass
Status BuildVocabNode::Accept(NodePass *p, bool *modified) {
// Visitor accepting method for IRNodePass
Status BuildVocabNode::Accept(IRNodePass *p, bool *modified) {
// Downcast shared pointer then call visitor
return p->Visit(shared_from_base<BuildVocabNode>(), modified);
}
// Visitor accepting method for NodePass
Status BuildVocabNode::AcceptAfter(NodePass *p, bool *modified) {
// Visitor accepting method for IRNodePass
Status BuildVocabNode::AcceptAfter(IRNodePass *p, bool *modified) {
// Downcast shared pointer then call visitor
return p->VisitAfter(shared_from_base<BuildVocabNode>(), modified);
}

@ -58,17 +58,17 @@ class BuildVocabNode : public DatasetNode {
/// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override;
/// \brief Base-class override for accepting NodePass visitor
/// \brief Base-class override for accepting IRNodePass visitor
/// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status Accept(NodePass *p, bool *modified) override;
Status Accept(IRNodePass *p, bool *modified) override;
/// \brief Base-class override for accepting NodePass visitor
/// \brief Base-class override for accepting IRNodePass visitor
/// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status AcceptAfter(NodePass *p, bool *modified) override;
Status AcceptAfter(IRNodePass *p, bool *modified) override;
private:
std::shared_ptr<Vocab> vocab_;

@ -39,8 +39,10 @@ ConcatNode::ConcatNode(const std::vector<std::shared_ptr<DatasetNode>> &datasets
}
std::shared_ptr<DatasetNode> ConcatNode::Copy() {
std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->Copy();
// create an empty vector to copy a concat
auto node = std::make_shared<ConcatNode>(std::vector<std::shared_ptr<DatasetNode>>());
auto node = std::make_shared<ConcatNode>(std::vector<std::shared_ptr<DatasetNode>>(), sampler,
children_flag_and_nums_, children_start_end_index_);
return node;
}
@ -80,14 +82,14 @@ Status ConcatNode::Build(std::vector<std::shared_ptr<DatasetOp>> *node_ops) {
return Status::OK();
}
// Visitor accepting method for NodePass
Status ConcatNode::Accept(NodePass *p, bool *modified) {
// Visitor accepting method for IRNodePass
Status ConcatNode::Accept(IRNodePass *p, bool *modified) {
// Downcast shared pointer then call visitor
return p->Visit(shared_from_base<ConcatNode>(), modified);
}
// Visitor accepting method for NodePass
Status ConcatNode::AcceptAfter(NodePass *p, bool *modified) {
// Visitor accepting method for IRNodePass
Status ConcatNode::AcceptAfter(IRNodePass *p, bool *modified) {
// Downcast shared pointer then call visitor
return p->VisitAfter(shared_from_base<ConcatNode>(), modified);
}

@ -66,17 +66,17 @@ class ConcatNode : public DatasetNode {
std::vector<std::pair<int, int>> children_flag_and_nums_;
std::vector<std::pair<int, int>> children_start_end_index_;
/// \brief Base-class override for accepting NodePass visitor
/// \brief Base-class override for accepting IRNodePass visitor
/// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status Accept(NodePass *p, bool *modified) override;
Status Accept(IRNodePass *p, bool *modified) override;
/// \brief Base-class override for accepting NodePass visitor
/// \brief Base-class override for accepting IRNodePass visitor
/// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status AcceptAfter(NodePass *p, bool *modified) override;
Status AcceptAfter(IRNodePass *p, bool *modified) override;
};
} // namespace dataset

@ -242,9 +242,27 @@ DatasetNode::DatasetNode() : cache_(nullptr), parent_({}), children_({}) {
worker_connector_size_ = cfg->worker_connector_size();
}
const bool DatasetNode::IsTree() const {
bool is_tree = true;
if (this->parent_.size() > 1) {
MS_LOG(WARNING) << Name() << " has more than one parent.";
return false;
}
for (const auto &child : children_) {
is_tree = child->IsTree();
if (!is_tree) {
MS_LOG(WARNING) << Name() << " has more than one parent.";
break;
}
}
return is_tree;
}
// this function will preform a deep copy of current node (and its descendants), the parent* pointer will not be copied
std::shared_ptr<DatasetNode> DatasetNode::DeepCopy() {
std::shared_ptr<DatasetNode> new_node = this->Copy();
// temporary fix to set the num_workers to the new node.
new_node->SetNumWorkers(this->num_workers_);
for (const auto &child : children_) {
new_node->AddChild(child->DeepCopy());
}
@ -298,12 +316,31 @@ void DatasetNode::AddChild(std::shared_ptr<DatasetNode> child) {
children_.push_back(child);
child->parent_.push_back(this);
} else if (child != nullptr) {
MS_LOG(WARNING) << "DatasetNode::AddChild() failed: " + child->Name() + "'s parent isn't a nullptr.";
MS_LOG(WARNING) << "Adding " + child->Name() + " to " + Name() + " but it already has a parent";
children_.push_back(child);
child->parent_.push_back(this);
}
}
// Insert a node as a child of this node. This node's children becomes the children of the inserted node.
Status DatasetNode::InsertBelow(std::shared_ptr<DatasetNode> node) {
CHECK_FAIL_RETURN_UNEXPECTED(node != nullptr, "Inserted node must not be a null pointer.");
CHECK_FAIL_RETURN_UNEXPECTED(node->children_.empty(), "Inserted node must not have any children.");
CHECK_FAIL_RETURN_UNEXPECTED(node->parent_.empty(), "Inserted node must not have a parent.");
for (auto child : children_) {
node->children_.push_back(child);
child->parent_.clear();
child->parent_.push_back(node.get());
}
// Then establish the new parent-child relationship with the new parent.
children_.clear();
children_.push_back(node);
node->parent_.clear();
node->parent_.push_back(this);
return Status::OK();
}
// Remove this node from its parent. Add the child of this node to its parent.
// for now, this remove is limited to node with a single child or no child
Status DatasetNode::Remove() {
@ -325,14 +362,14 @@ Status DatasetNode::Remove() {
}
// In DFS tree traversal, each node is visited twice. Accept is called on the first visit.
Status DatasetNode::Accept(NodePass *p, bool *modified) {
Status DatasetNode::Accept(IRNodePass *p, bool *modified) {
// This method will only be called if its derived class does not implement one.
return p->Visit(shared_from_this(), modified);
}
// In DFS tree traversal, each node is visited twice. AcceptAfter is called on the second visit
// after all child nodes are visited.
Status DatasetNode::AcceptAfter(NodePass *p, bool *modified) {
Status DatasetNode::AcceptAfter(IRNodePass *p, bool *modified) {
// This method will only be called if its derived class does not implement one.
return p->VisitAfter(shared_from_this(), modified);
}
@ -369,17 +406,5 @@ Status DatasetNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &siz
RETURN_STATUS_UNEXPECTED("Trying to get dataset size from leaf node, missing override");
}
}
// Visitor accepting method for NodePass
Status SourceNode::Accept(NodePass *p, bool *modified) {
// Downcast shared pointer then call visitor
return p->Visit(shared_from_base<SourceNode>(), modified);
}
// Visitor accepting method for NodePass
Status SourceNode::AcceptAfter(NodePass *p, bool *modified) {
// Downcast shared pointer then call visitor
return p->VisitAfter(shared_from_base<SourceNode>(), modified);
}
} // namespace dataset
} // namespace mindspore

@ -32,7 +32,7 @@ namespace dataset {
class Dataset;
class SamplerObj;
class NodePass;
class IRNodePass;
class DatasetSizeGetter;
// Names for non-leaf IR node
@ -182,6 +182,9 @@ class DatasetNode : public std::enable_shared_from_this<DatasetNode> {
/// \brief Establish the parent-child relationship between this node and its child.
void AddChild(std::shared_ptr<DatasetNode> child);
/// \brief Insert the input node below this node. This node's children becomes the children of the inserted node.
Status InsertBelow(std::shared_ptr<DatasetNode> node);
/// \brief detach this node from its parent, add its child (if any) to its parent
/// \return error code, return error if node has more than 1 children
Status Remove();
@ -190,6 +193,25 @@ class DatasetNode : public std::enable_shared_from_this<DatasetNode> {
/// \return True if the data of this node will be cached
const bool IsCached() const { return (cache_ != nullptr); }
/// \brief Check if this node is a tree
/// \return True if the structure is indeed a tree, i.e., no node has more than one parent
const bool IsTree() const;
/// \brief Check if this node is a leaf node.
/// \return True if this is a leaf node.
const bool IsLeaf() const { return children_.empty(); }
/// \brief Check if this node is a mappable dataset. Only applicable to leaf nodes
/// \return True if the dataset represented by this node is a mappable dataset
const bool IsMappable() const { return mappable_; }
/// \brief Check if this node is a descendant of an operator with cache. Currently used in leaf nodes
/// \return True if a cache-enabled operator is an ancestor of this node
const bool IsDescendantOfCache() const { return descendant_of_cache_; }
/// \brief Mark to indicate this node is a descendant of an operator with cache. Currently used in leaf nodes
void HasCacheAbove() { descendant_of_cache_ = true; }
/// \brief Setter function for runtime number of workers
/// \param[in] num_workers The number of threads in this operator
/// \return Shared pointer to the original object
@ -203,7 +225,7 @@ class DatasetNode : public std::enable_shared_from_this<DatasetNode> {
return std::static_pointer_cast<Derived>(shared_from_this());
}
/// \brief Base method for NodePass visit. A tree walk consists of walking down the tree and also walking back up
/// \brief Base method for IRNodePass visit. A tree walk consists of walking down the tree and also walking back up
/// in a depth-first order. Accept is the node visit on the way down, whereas AcceptAfter is the node
/// visit on the way back up the tree after its descendants are visited.
/// \notes Subclass needs to override this if it requires special node visit access.
@ -211,15 +233,15 @@ class DatasetNode : public std::enable_shared_from_this<DatasetNode> {
/// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
virtual Status Accept(NodePass *p, bool *modified);
virtual Status Accept(IRNodePass *p, bool *modified);
/// \brief Base method for NodePass visit on the way back up the tree after its descendants are visited.
/// \brief Base method for IRNodePass visit on the way back up the tree after its descendants are visited.
/// \notes Subclass needs to override this if it requires special node visit access.
/// Check "dataset/engine/opt/pass.h" for more details.
/// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
virtual Status AcceptAfter(NodePass *p, bool *modified);
virtual Status AcceptAfter(IRNodePass *p, bool *modified);
virtual bool IsSizeDefined() { return true; }
@ -235,55 +257,22 @@ class DatasetNode : public std::enable_shared_from_this<DatasetNode> {
std::string PrintColumns(const std::vector<std::string> &columns) const;
Status AddCacheOp(std::vector<std::shared_ptr<DatasetOp>> *node_ops);
void PrintNode(std::ostream &out, int *level) const;
};
// SourceNode represents the leaf nodes of a pipeline where the data is pulled into.
class SourceNode : public DatasetNode {
public:
/// \brief Constructor
SourceNode() : DatasetNode() {}
/// \brief Constructor that initializes the cache
/// \param dataset_cache DatasetCache
explicit SourceNode(const std::shared_ptr<DatasetCache> &dataset_cache) : DatasetNode(dataset_cache) {}
/// \brief Destructor
~SourceNode() = default;
/// \brief Node name getter
/// \return Name of the current node
virtual std::string Name() const = 0;
/// \brief Base-class override for accepting NodePass visitor
/// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status Accept(NodePass *p, bool *modified) override;
/// \brief Base-class override for accepting NodePass visitor
/// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status AcceptAfter(NodePass *p, bool *modified) override;
/// \brief Check if this node is a mappable dataset. Only applicable to leaf nodes
/// \return True if the dataset represented by this node is a mappable dataset
const bool IsMappable() const { return mappable_; }
protected:
bool mappable_;
bool descendant_of_cache_;
};
// MappableSourceNode represents the leaf nodes that can be randomly accessed with indexes.
class MappableSourceNode : public SourceNode {
class MappableSourceNode : public DatasetNode {
public:
/// \brief Constructor
MappableSourceNode() : SourceNode() { mappable_ = true; }
MappableSourceNode() : DatasetNode() { mappable_ = true; }
/// \brief Constructor that initializes the cache
/// \param dataset_cache DatasetCache
explicit MappableSourceNode(const std::shared_ptr<DatasetCache> &dataset_cache) : SourceNode(dataset_cache) {
explicit MappableSourceNode(const std::shared_ptr<DatasetCache> &dataset_cache) : DatasetNode(dataset_cache) {
mappable_ = true;
// Initially set to false, and set to true by the optimizer when conditions are met.
descendant_of_cache_ = false;
}
/// \brief Destructor
@ -295,15 +284,17 @@ class MappableSourceNode : public SourceNode {
};
// NonMappableSourceNode represents the leaf nodes that can not be randomly accessed.
class NonMappableSourceNode : public SourceNode {
class NonMappableSourceNode : public DatasetNode {
public:
/// \brief Constructor
NonMappableSourceNode() : SourceNode() { mappable_ = false; }
NonMappableSourceNode() : DatasetNode() { mappable_ = false; }
/// \brief Constructor that initializes the cache
/// \param dataset_cache DatasetCache
explicit NonMappableSourceNode(const std::shared_ptr<DatasetCache> &dataset_cache) : SourceNode(dataset_cache) {
explicit NonMappableSourceNode(const std::shared_ptr<DatasetCache> &dataset_cache) : DatasetNode(dataset_cache) {
mappable_ = false;
// Initially set to false, and set to true by the optimizer when conditions are met.
descendant_of_cache_ = false;
}
/// \brief Destructor
@ -313,34 +304,6 @@ class NonMappableSourceNode : public SourceNode {
/// \return Name of the current node
virtual std::string Name() const = 0;
};
// NonLeafNode represents operations over data in a pipeline.
class NonLeafNode : public DatasetNode {
public:
/// \brief Constructor
NonLeafNode() = default;
/// \brief Destructor
~NonLeafNode() = default;
/// \brief Node name getter
/// \return Name of the current node
virtual std::string Name() const = 0;
};
// SinkNode represents the end node of a pipeline where the data is pushed out
class SinkNode : public DatasetNode {
public:
/// \brief Constructor
SinkNode() = default;
/// \brief Destructor
~SinkNode() = default;
/// \brief Node name getter
/// \return Name of the current node
virtual std::string Name() const = 0;
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_DATASET_NODE_H_

@ -32,8 +32,9 @@ EpochCtrlNode::EpochCtrlNode(std::shared_ptr<DatasetNode> child, int32_t num_epo
// The root node's parent must set to null pointer.
this->AddChild(child);
}
std::shared_ptr<DatasetNode> EpochCtrlNode::Copy() {
auto node = std::make_shared<EpochCtrlNode>(nullptr, this->num_epochs_);
auto node = std::make_shared<EpochCtrlNode>(num_epochs_);
return node;
}

@ -29,7 +29,10 @@ namespace dataset {
class EpochCtrlNode : public DatasetNode {
public:
/// \brief Constructor
explicit EpochCtrlNode(std::shared_ptr<DatasetNode> child, int32_t num_epochs);
explicit EpochCtrlNode(int32_t num_epochs) : num_epochs_(num_epochs) {}
/// \brief Constructor
EpochCtrlNode(std::shared_ptr<DatasetNode> child, int32_t num_epochs);
/// \brief Destructor
~EpochCtrlNode() = default;

@ -60,14 +60,14 @@ Status FilterNode::ValidateParams() {
return Status::OK();
}
// Visitor accepting method for NodePass
Status FilterNode::Accept(NodePass *p, bool *modified) {
// Visitor accepting method for IRNodePass
Status FilterNode::Accept(IRNodePass *p, bool *modified) {
// Downcast shared pointer then call visitor
return p->Visit(shared_from_base<FilterNode>(), modified);
}
// Visitor accepting method for NodePass
Status FilterNode::AcceptAfter(NodePass *p, bool *modified) {
// Visitor accepting method for IRNodePass
Status FilterNode::AcceptAfter(IRNodePass *p, bool *modified) {
// Downcast shared pointer then call visitor
return p->VisitAfter(shared_from_base<FilterNode>(), modified);
}

@ -58,17 +58,17 @@ class FilterNode : public DatasetNode {
bool IsSizeDefined() override { return false; };
/// \brief Base-class override for accepting NodePass visitor
/// \brief Base-class override for accepting IRNodePass visitor
/// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status Accept(NodePass *p, bool *modified) override;
Status Accept(IRNodePass *p, bool *modified) override;
/// \brief Base-class override for accepting NodePass visitor
/// \brief Base-class override for accepting IRNodePass visitor
/// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status AcceptAfter(NodePass *p, bool *modified) override;
Status AcceptAfter(IRNodePass *p, bool *modified) override;
private:
std::shared_ptr<TensorOp> predicate_;

@ -42,14 +42,16 @@ MapNode::MapNode(std::shared_ptr<DatasetNode> child, std::vector<std::shared_ptr
}
std::shared_ptr<DatasetNode> MapNode::Copy() {
auto node = std::make_shared<MapNode>(nullptr, operations_, input_columns_, output_columns_, project_columns_, cache_,
std::vector<std::shared_ptr<TensorOperation>> operations = operations_;
auto node = std::make_shared<MapNode>(nullptr, operations, input_columns_, output_columns_, project_columns_, cache_,
callbacks_);
return node;
}
void MapNode::Print(std::ostream &out) const {
out << Name() + "(<ops>" + ",input:" + PrintColumns(input_columns_) + ",output:" + PrintColumns(output_columns_) +
",<project_cols>" + ",...)";
",<project_cols>" + ",num_tensor_ops:"
<< operations_.size() << ",...)";
}
Status MapNode::Build(std::vector<std::shared_ptr<DatasetOp>> *node_ops) {
@ -101,14 +103,14 @@ Status MapNode::ValidateParams() {
return Status::OK();
}
// Visitor accepting method for NodePass
Status MapNode::Accept(NodePass *p, bool *modified) {
// Visitor accepting method for IRNodePass
Status MapNode::Accept(IRNodePass *p, bool *modified) {
// Downcast shared pointer then call visitor
return p->Visit(shared_from_base<MapNode>(), modified);
}
// Visitor accepting method for NodePass
Status MapNode::AcceptAfter(NodePass *p, bool *modified) {
// Visitor accepting method for IRNodePass
Status MapNode::AcceptAfter(IRNodePass *p, bool *modified) {
// Downcast shared pointer then call visitor
return p->VisitAfter(shared_from_base<MapNode>(), modified);
}

@ -63,17 +63,17 @@ class MapNode : public DatasetNode {
const auto &TensorOperations() const { return operations_; }
auto &TensorOperations() { return operations_; }
/// \brief Base-class override for accepting NodePass visitor
/// \brief Base-class override for accepting IRNodePass visitor
/// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status Accept(NodePass *p, bool *modified) override;
Status Accept(IRNodePass *p, bool *modified) override;
/// \brief Base-class override for accepting NodePass visitor
/// \brief Base-class override for accepting IRNodePass visitor
/// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status AcceptAfter(NodePass *p, bool *modified) override;
Status AcceptAfter(IRNodePass *p, bool *modified) override;
private:
std::vector<std::shared_ptr<TensorOperation>> operations_;

@ -70,14 +70,14 @@ Status RepeatNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size
return Status::OK();
}
// Visitor accepting method for NodePass
Status RepeatNode::Accept(NodePass *p, bool *modified) {
// Visitor accepting method for IRNodePass
Status RepeatNode::Accept(IRNodePass *p, bool *modified) {
// Downcast shared pointer then call visitor
return p->Visit(shared_from_base<RepeatNode>(), modified);
}
// Visitor accepting method for NodePass
Status RepeatNode::AcceptAfter(NodePass *p, bool *modified) {
// Visitor accepting method for IRNodePass
Status RepeatNode::AcceptAfter(IRNodePass *p, bool *modified) {
// Downcast shared pointer then call visitor
return p->VisitAfter(shared_from_base<RepeatNode>(), modified);
}

@ -66,17 +66,17 @@ class RepeatNode : public DatasetNode {
Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
int64_t *dataset_size) override;
/// \brief Base-class override for accepting NodePass visitor
/// \brief Base-class override for accepting IRNodePass visitor
/// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status Accept(NodePass *p, bool *modified) override;
Status Accept(IRNodePass *p, bool *modified) override;
/// \brief Base-class override for accepting NodePass visitor
/// \brief Base-class override for accepting IRNodePass visitor
/// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status AcceptAfter(NodePass *p, bool *modified) override;
Status AcceptAfter(IRNodePass *p, bool *modified) override;
private:
int32_t repeat_count_;

@ -72,14 +72,14 @@ Status RootNode::ValidateParams() {
return Status::OK();
}
// Visitor accepting method for NodePass
Status RootNode::Accept(NodePass *p, bool *modified) {
// Visitor accepting method for IRNodePass
Status RootNode::Accept(IRNodePass *p, bool *modified) {
// Downcast shared pointer then call visitor
return p->Visit(shared_from_base<RootNode>(), modified);
}
// Visitor accepting method for NodePass
Status RootNode::AcceptAfter(NodePass *p, bool *modified) {
// Visitor accepting method for IRNodePass
Status RootNode::AcceptAfter(IRNodePass *p, bool *modified) {
// Downcast shared pointer then call visitor
return p->VisitAfter(shared_from_base<RootNode>(), modified);
}

@ -58,17 +58,17 @@ class RootNode : public DatasetNode {
/// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override;
/// \brief Base-class override for accepting NodePass visitor
/// \brief Base-class override for accepting IRNodePass visitor
/// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status Accept(NodePass *p, bool *modified) override;
Status Accept(IRNodePass *p, bool *modified) override;
/// \brief Base-class override for accepting NodePass visitor
/// \brief Base-class override for accepting IRNodePass visitor
/// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status AcceptAfter(NodePass *p, bool *modified) override;
Status AcceptAfter(IRNodePass *p, bool *modified) override;
private:
int32_t num_epochs_;

@ -21,6 +21,7 @@
#include <vector>
#include "minddata/dataset/engine/datasetops/skip_op.h"
#include "minddata/dataset/engine/opt/pass.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
@ -70,5 +71,16 @@ Status SkipNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_g
return Status::OK();
}
// Visitor accepting method for IRNodePass
Status SkipNode::Accept(IRNodePass *p, bool *modified) {
// Downcast shared pointer then call visitor
return p->Visit(shared_from_base<SkipNode>(), modified);
}
// Visitor accepting method for IRNodePass
Status SkipNode::AcceptAfter(IRNodePass *p, bool *modified) {
// Downcast shared pointer then call visitor
return p->VisitAfter(shared_from_base<SkipNode>(), modified);
}
} // namespace dataset
} // namespace mindspore

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save