From 80d02d6dcdfe936f60d7a83762292f11a4c59a89 Mon Sep 17 00:00:00 2001 From: Nat Sutyanyong Date: Tue, 3 Nov 2020 13:41:51 -0500 Subject: [PATCH] Add optimizer to IR tree #1 --- .../engine/ir/datasetops/dataset_node.cc | 14 + .../engine/ir/datasetops/dataset_node.h | 19 + .../ccsrc/minddata/dataset/engine/opt/pass.cc | 432 ++++++++++++++++++ .../ccsrc/minddata/dataset/engine/opt/pass.h | 247 ++++++++++ 4 files changed, 712 insertions(+) diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.cc index 547e2e47fe..06c19bae9a 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.cc @@ -20,6 +20,7 @@ #include #include +#include "minddata/dataset/engine/opt/pass.h" #include "minddata/dataset/util/random.h" namespace mindspore { @@ -227,6 +228,7 @@ std::shared_ptr DatasetNode::SetNumWorkers(int32_t num_workers) { num_workers_ = num_workers; return shared_from_this(); } + DatasetNode::DatasetNode() { // Fetch some default value from config manager std::shared_ptr cfg = GlobalContext::config_manager(); @@ -236,5 +238,17 @@ DatasetNode::DatasetNode() { worker_connector_size_ = cfg->worker_connector_size(); } +// In DFS tree traversal, each node is visited twice. Accept is called on the first visit. +Status DatasetNode::Accept(NodePass *p, bool *modified) { + // This method will only be called if its derived class does not implement one. + return p->Visit(shared_from_this(), modified); +} + +// In DFS tree traversal, each node is visited twice. AcceptAfter is called on the second visit +// after all child nodes are visited. +Status DatasetNode::AcceptAfter(NodePass *p, bool *modified) { + // This method will only be called if its derived class does not implement one. + return p->VisitAfter(shared_from_this(), modified); +} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.h index 0e92b5547d..20d1c26429 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.h @@ -31,6 +31,7 @@ namespace dataset { class Dataset; class SamplerObj; +class NodePass; #define RETURN_EMPTY_IF_ERROR(_s) \ do { \ @@ -107,6 +108,24 @@ class DatasetNode : public std::enable_shared_from_this { /// \return Shared pointer to the original object std::shared_ptr SetNumWorkers(int32_t num_workers); + /// \brief Base method for NodePass visit. A tree walk consists of walking down the tree and also walking back up + /// in a depth-first order. Accept is the node visit on the way down, whereas AcceptAfter is the node + /// visit on the way back up the tree after its descendants are visited. + /// \notes Subclass needs to override this if it requires special node visit access. + /// Check "dataset/engine/opt/pass.h" for more details. + /// \param[in] p The node to visit + /// \param[out] modified Indicator if the node was modified + /// \return Status of the node visit + virtual Status Accept(NodePass *p, bool *modified); + + /// \brief Base method for NodePass visit on the way back up the tree after its descendants are visited. + /// \notes Subclass needs to override this if it requires special node visit access. + /// Check "dataset/engine/opt/pass.h" for more details. + /// \param[in] p The node to visit + /// \param[out] modified Indicator if the node was modified + /// \return Status of the node visit + virtual Status AcceptAfter(NodePass *p, bool *modified); + protected: std::vector> children; std::shared_ptr parent; diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/pass.cc b/mindspore/ccsrc/minddata/dataset/engine/opt/pass.cc index a4bcc7cbec..d05fd04636 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/opt/pass.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/opt/pass.cc @@ -15,6 +15,56 @@ */ #include "minddata/dataset/engine/opt/pass.h" +#include "minddata/dataset/engine/ir/datasetops/batch_node.h" +#include "minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h" +#ifndef ENABLE_ANDROID +#include "minddata/dataset/engine/ir/datasetops/build_sentence_piece_vocab_node.h" +#endif +#include "minddata/dataset/engine/ir/datasetops/build_vocab_node.h" +#include "minddata/dataset/engine/ir/datasetops/concat_node.h" +#include "minddata/dataset/engine/ir/datasetops/map_node.h" +#include "minddata/dataset/engine/ir/datasetops/project_node.h" +#include "minddata/dataset/engine/ir/datasetops/rename_node.h" +#include "minddata/dataset/engine/ir/datasetops/repeat_node.h" +#include "minddata/dataset/engine/ir/datasetops/shuffle_node.h" +#include "minddata/dataset/engine/ir/datasetops/skip_node.h" +#ifdef ENABLE_PYTHON +#include "minddata/dataset/engine/ir/datasetops/sync_wait_node.h" +#endif +#include "minddata/dataset/engine/ir/datasetops/take_node.h" +#include "minddata/dataset/engine/ir/datasetops/transfer_node.h" +#include "minddata/dataset/engine/ir/datasetops/zip_node.h" +#include "minddata/dataset/engine/ir/datasetops/source/album_node.h" +#include "minddata/dataset/engine/ir/datasetops/source/celeba_node.h" +#include "minddata/dataset/engine/ir/datasetops/source/cifar100_node.h" +#include "minddata/dataset/engine/ir/datasetops/source/cifar10_node.h" +#ifndef ENABLE_ANDROID +#include "minddata/dataset/engine/ir/datasetops/source/clue_node.h" +#endif +#include "minddata/dataset/engine/ir/datasetops/source/coco_node.h" +#ifndef ENABLE_ANDROID +#include "minddata/dataset/engine/ir/datasetops/source/csv_node.h" +#endif +#ifdef ENABLE_PYTHON +#include "minddata/dataset/engine/ir/datasetops/source/generator_node.h" +#endif +#include "minddata/dataset/engine/ir/datasetops/source/image_folder_node.h" +#include "minddata/dataset/engine/ir/datasetops/source/manifest_node.h" +#ifndef ENABLE_ANDROID +#include "minddata/dataset/engine/ir/datasetops/source/minddata_node.h" +#endif +#include "minddata/dataset/engine/ir/datasetops/source/mnist_node.h" +#include "minddata/dataset/engine/ir/datasetops/source/random_node.h" +#ifndef ENABLE_ANDROID +#include "minddata/dataset/engine/ir/datasetops/source/text_file_node.h" +#endif +#ifndef ENABLE_ANDROID +#include "minddata/dataset/engine/ir/datasetops/source/tf_record_node.h" +#endif +#include "minddata/dataset/engine/ir/datasetops/source/voc_node.h" + +////////////////////////////////// +// This section of code will be removed once the migration of optimizer from DatasetOp to DatasetNode is done. #include "minddata/dataset/engine/datasetops/batch_op.h" #include "minddata/dataset/engine/datasetops/build_vocab_op.h" #ifndef ENABLE_ANDROID @@ -57,10 +107,391 @@ #include "minddata/dataset/engine/datasetops/source/image_folder_op.h" #include "minddata/dataset/engine/datasetops/take_op.h" #include "minddata/dataset/engine/datasetops/zip_op.h" +////////////////////////////////// namespace mindspore { namespace dataset { +// Driver method for TreePass +Status TreePass::Run(std::shared_ptr root_ir, bool *modified) { return Status::OK(); } + +// Driver method for NodePass +Status NodePass::Run(std::shared_ptr root_ir, bool *modified) { + if (root_ir == nullptr || modified == nullptr) { + return Status(StatusCode::kUnexpectedError, "Null pointer passed to NodePass"); + } + if (traversalOrder_ == Order::DFS) { + // DFS + return DFSNodeVisit(root_ir, modified); + } else if (traversalOrder_ == Order::BFS) { + // BFS + return BFSNodeVisit(root_ir, modified); + } + return Status::OK(); +} + +// Helper function to perform DFS visit +Status NodePass::DFSNodeVisit(std::shared_ptr node_ir, bool *modified) { + RETURN_IF_NOT_OK(node_ir->Accept(this, modified)); + for (const auto &c : node_ir->Children()) { + RETURN_IF_NOT_OK(this->DFSNodeVisit(c, modified)); + } + return node_ir->AcceptAfter(this, modified); +} + +// Helper function to perform BFS visit +Status NodePass::BFSNodeVisit(std::shared_ptr node_ir, bool *modified) { + // Initialize bfs queue with root + std::queue> bfsQueue; + bfsQueue.push(node_ir); + + // BFS loop + while (!bfsQueue.empty()) { + // Pop the front of the bfs queue + auto curNode = bfsQueue.front(); + bfsQueue.pop(); + + // Run node pass + RETURN_IF_NOT_OK(curNode->Accept(this, modified)); + + // Push children into bfs queue + for (const auto &c : curNode->Children()) { + bfsQueue.push(c); + } + } + return Status::OK(); +} + +// For datasetops IR +Status NodePass::Visit(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return Visit(std::static_pointer_cast(node), modified); +} + +Status NodePass::VisitAfter(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return VisitAfter(std::static_pointer_cast(node), modified); +} + +Status NodePass::Visit(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return Visit(std::static_pointer_cast(node), modified); +} + +Status NodePass::VisitAfter(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return VisitAfter(std::static_pointer_cast(node), modified); +} + +#ifndef ENABLE_ANDROID +Status NodePass::Visit(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return Visit(std::static_pointer_cast(node), modified); +} + +Status NodePass::VisitAfter(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return VisitAfter(std::static_pointer_cast(node), modified); +} +#endif + +Status NodePass::Visit(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return Visit(std::static_pointer_cast(node), modified); +} + +Status NodePass::VisitAfter(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return VisitAfter(std::static_pointer_cast(node), modified); +} + +Status NodePass::Visit(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return Visit(std::static_pointer_cast(node), modified); +} + +Status NodePass::VisitAfter(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return VisitAfter(std::static_pointer_cast(node), modified); +} + +Status NodePass::Visit(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return Visit(std::static_pointer_cast(node), modified); +} + +Status NodePass::VisitAfter(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return VisitAfter(std::static_pointer_cast(node), modified); +} + +Status NodePass::Visit(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return Visit(std::static_pointer_cast(node), modified); +} + +Status NodePass::VisitAfter(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return VisitAfter(std::static_pointer_cast(node), modified); +} + +Status NodePass::Visit(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return Visit(std::static_pointer_cast(node), modified); +} + +Status NodePass::VisitAfter(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return VisitAfter(std::static_pointer_cast(node), modified); +} + +Status NodePass::Visit(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return Visit(std::static_pointer_cast(node), modified); +} + +Status NodePass::VisitAfter(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return VisitAfter(std::static_pointer_cast(node), modified); +} + +Status NodePass::Visit(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return Visit(std::static_pointer_cast(node), modified); +} + +Status NodePass::VisitAfter(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return VisitAfter(std::static_pointer_cast(node), modified); +} + +Status NodePass::Visit(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return Visit(std::static_pointer_cast(node), modified); +} + +Status NodePass::VisitAfter(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return VisitAfter(std::static_pointer_cast(node), modified); +} + +#ifdef ENABLE_PYTHON +Status NodePass::Visit(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return Visit(std::static_pointer_cast(node), modified); +} + +Status NodePass::VisitAfter(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return VisitAfter(std::static_pointer_cast(node), modified); +} +#endif + +Status NodePass::Visit(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return Visit(std::static_pointer_cast(node), modified); +} + +Status NodePass::VisitAfter(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return VisitAfter(std::static_pointer_cast(node), modified); +} + +Status NodePass::Visit(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return Visit(std::static_pointer_cast(node), modified); +} + +Status NodePass::VisitAfter(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return VisitAfter(std::static_pointer_cast(node), modified); +} + +Status NodePass::Visit(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return Visit(std::static_pointer_cast(node), modified); +} + +Status NodePass::VisitAfter(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return VisitAfter(std::static_pointer_cast(node), modified); +} + +// For datasetops/source IR +Status NodePass::Visit(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return Visit(std::static_pointer_cast(node), modified); +} + +Status NodePass::VisitAfter(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return VisitAfter(std::static_pointer_cast(node), modified); +} + +Status NodePass::Visit(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return Visit(std::static_pointer_cast(node), modified); +} + +Status NodePass::VisitAfter(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return VisitAfter(std::static_pointer_cast(node), modified); +} + +Status NodePass::Visit(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return Visit(std::static_pointer_cast(node), modified); +} + +Status NodePass::VisitAfter(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return VisitAfter(std::static_pointer_cast(node), modified); +} + +Status NodePass::Visit(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return Visit(std::static_pointer_cast(node), modified); +} + +Status NodePass::VisitAfter(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return VisitAfter(std::static_pointer_cast(node), modified); +} + +#ifndef ENABLE_ANDROID +Status NodePass::Visit(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return Visit(std::static_pointer_cast(node), modified); +} + +Status NodePass::VisitAfter(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return VisitAfter(std::static_pointer_cast(node), modified); +} +#endif + +Status NodePass::Visit(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return Visit(std::static_pointer_cast(node), modified); +} + +Status NodePass::VisitAfter(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return VisitAfter(std::static_pointer_cast(node), modified); +} + +#ifndef ENABLE_ANDROID +Status NodePass::Visit(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return Visit(std::static_pointer_cast(node), modified); +} + +Status NodePass::VisitAfter(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return VisitAfter(std::static_pointer_cast(node), modified); +} +#endif + +#ifdef ENABLE_PYTHON +Status NodePass::Visit(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return Visit(std::static_pointer_cast(node), modified); +} + +Status NodePass::VisitAfter(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return VisitAfter(std::static_pointer_cast(node), modified); +} +#endif + +Status NodePass::Visit(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return Visit(std::static_pointer_cast(node), modified); +} + +Status NodePass::VisitAfter(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return VisitAfter(std::static_pointer_cast(node), modified); +} + +Status NodePass::Visit(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return Visit(std::static_pointer_cast(node), modified); +} + +Status NodePass::VisitAfter(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return VisitAfter(std::static_pointer_cast(node), modified); +} + +#ifndef ENABLE_ANDROID +Status NodePass::Visit(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return Visit(std::static_pointer_cast(node), modified); +} + +Status NodePass::VisitAfter(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return VisitAfter(std::static_pointer_cast(node), modified); +} +#endif + +Status NodePass::Visit(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return Visit(std::static_pointer_cast(node), modified); +} + +Status NodePass::VisitAfter(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return VisitAfter(std::static_pointer_cast(node), modified); +} + +Status NodePass::Visit(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return Visit(std::static_pointer_cast(node), modified); +} + +Status NodePass::VisitAfter(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return VisitAfter(std::static_pointer_cast(node), modified); +} + +#ifndef ENABLE_ANDROID +Status NodePass::Visit(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return Visit(std::static_pointer_cast(node), modified); +} + +Status NodePass::VisitAfter(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return VisitAfter(std::static_pointer_cast(node), modified); +} +#endif + +#ifndef ENABLE_ANDROID +Status NodePass::Visit(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return Visit(std::static_pointer_cast(node), modified); +} + +Status NodePass::VisitAfter(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return VisitAfter(std::static_pointer_cast(node), modified); +} +#endif + +Status NodePass::Visit(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return Visit(std::static_pointer_cast(node), modified); +} + +Status NodePass::VisitAfter(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return VisitAfter(std::static_pointer_cast(node), modified); +} + +////////////////////////////////// +// This section of code will be removed once the migration of optimizer from DatasetOp to DatasetNode is done. // Driver method for TreePass Status TreePass::Run(ExecutionTree *tree, bool *modified) { if (tree == nullptr || modified == nullptr) { @@ -320,5 +751,6 @@ Status NodePass::PreRunOnNode(std::shared_ptr node, bool *modified) { return PreRunOnNode(std::static_pointer_cast(node), modified); } #endif +////////////////////////////////// } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/pass.h b/mindspore/ccsrc/minddata/dataset/engine/opt/pass.h index efb9b6eb04..bebee15519 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/opt/pass.h +++ b/mindspore/ccsrc/minddata/dataset/engine/opt/pass.h @@ -21,10 +21,61 @@ #include #include "minddata/dataset/engine/execution_tree.h" +#include "minddata/dataset/engine/ir/datasetops/dataset_node.h" #include "minddata/dataset/util/status.h" namespace mindspore { namespace dataset { +class BatchNode; +class BucketBatchByLengthNode; +#ifndef ENABLE_ANDROID +class BuildSentenceVocabNode; +#endif +class BuildVocabNode; +class ConcatNode; +class MapNode; +class ProjectNode; +class RenameNode; +class RepeatNode; +class ShuffleNode; +class SkipNode; +#ifdef ENABLE_PYTHON +class SyncWaitNode; +#endif +class TakeNode; +class TransferNode; +class ZipNode; +class AlbumNode; +class CelebANode; +class Cifar100Node; +class Cifar10Node; +#ifndef ENABLE_ANDROID +class CLUENode; +#endif +class CocoNode; +#ifndef ENABLE_ANDROID +class CSVNode; +#endif +#ifdef ENABLE_PYTHON +class GeneratorNode; +#endif +class ImageFolderNode; +class ManifestNode; +#ifndef ENABLE_ANDROID +class MindDataNode; +#endif +class MnistNode; +class RandomNode; +#ifndef ENABLE_ANDROID +class TextFileNode; +#endif +#ifndef ENABLE_ANDROID +class TFRecordNode; +#endif +class VOCNode; + +////////////////////////////////// +// This section of code will be removed once the migration of optimizer from DatasetOp to DatasetNode is done. class BatchOp; class MapOp; @@ -94,15 +145,24 @@ class FilterOp; class GeneratorOp; #endif +////////////////////////////////// // The base class Pass is the basic unit of tree transformation. // The actual implementation of the passes will be derived from here. class Pass : public std::enable_shared_from_this { public: + // Run the transformation pass against the IR tree. + // @param root_ir - Pointer to the IR tree to be transformed. + // @param modified - Pointer to the modified flag, + virtual Status Run(std::shared_ptr root_ir, bool *modified) = 0; + + ////////////////////////////////// + // This section of code will be removed once the migration of optimizer from DatasetOp to DatasetNode is done. // Run the transformation pass against the execution tree. // @param tree - Pointer to the execution tree to be transformed. // @param modified - Pointer to the modified flag, virtual Status Run(ExecutionTree *tree, bool *modified) = 0; + ////////////////////////////////// virtual ~Pass() = default; }; @@ -110,6 +170,13 @@ class Pass : public std::enable_shared_from_this { // TreePass is a basic Pass class which performs transformation on ExecutionTree directly. class TreePass : public Pass { public: + /// \brief Run the transformation pass against the IR tree. + /// \param[inout] root_ir Pointer to the IR tree to be transformed. + /// \param[inout] modified Indicate if the tree was modified + Status Run(std::shared_ptr root_ir, bool *modified) final; + + ////////////////////////////////// + // This section of code will be removed once the migration of optimizer from DatasetOp to DatasetNode is done. /// \brief Run the transformation pass against the execution tree. /// \param[inout] tree Pointer to the execution tree to be transformed. /// \param[inout] modified Indicate if the tree was modified @@ -121,6 +188,7 @@ class TreePass : public Pass { /// \param[inout] Indicate of the tree was modified. /// \return Status The error code return virtual Status RunOnTree(ExecutionTree *tree, bool *modified) { return Status::OK(); } + ////////////////////////////////// }; // NodePass is a basic Pass class which performs transformation on Node visiting. @@ -136,6 +204,175 @@ class NodePass : public Pass { ~NodePass() = default; + /// \brief Run the transformation pass against the IR tree + /// \param[inout] root_ir Pointer to the IR tree to be transformed + /// \param[inout] modified Indicator if the tree was changed + Status Run(std::shared_ptr root_ir, bool *modified) final; + + /// \brief Derived classes may implement the Visit function to implement any initial visit work on the way down + /// a tree traversal. "modified" flag needs to be set to true if node is modified during the pass execution + /// \param[in] node The node being visited + /// \param[out] modified Indicator if the node was changed at all + /// \return Status The error code return + virtual Status Visit(std::shared_ptr node, bool *modified) { return Status::OK(); } + + /// \brief Derived classes may implement the VisitAfter function to implement node level tree transformation + /// "modified" flag needs to be set to true if node is modified during the pass execution + /// \param[in] node The node being visited + /// \param[out] modified Indicator if the node was changed at all. + /// \return Status The error code return + virtual Status VisitAfter(std::shared_ptr node, bool *modified) { return Status::OK(); } + + // For datasetops IR + // Visit method to be overridden. + // Note that member template can not be virtual, any node which wants to work with NodePass + // should declare Visit of its own type and override "Accept" from DatasetNode. + virtual Status Visit(std::shared_ptr node, bool *modified); + + // VisitAfter method to be overridden. + // Note that member template can not be virtual, any node which wants to work with NodePass + // should declare VisitAfter of its own type and override "AcceptAfter" from DatasetNode. + virtual Status VisitAfter(std::shared_ptr node, bool *modified); + + virtual Status Visit(std::shared_ptr node, bool *modified); + + virtual Status VisitAfter(std::shared_ptr node, bool *modified); + +#ifndef ENABLE_ANDROID + virtual Status Visit(std::shared_ptr node, bool *modified); + + virtual Status VisitAfter(std::shared_ptr node, bool *modified); +#endif + + virtual Status Visit(std::shared_ptr node, bool *modified); + + virtual Status VisitAfter(std::shared_ptr node, bool *modified); + + virtual Status Visit(std::shared_ptr node, bool *modified); + + virtual Status VisitAfter(std::shared_ptr node, bool *modified); + + virtual Status Visit(std::shared_ptr node, bool *modified); + + virtual Status VisitAfter(std::shared_ptr node, bool *modified); + + virtual Status Visit(std::shared_ptr node, bool *modified); + + virtual Status VisitAfter(std::shared_ptr node, bool *modified); + + virtual Status Visit(std::shared_ptr node, bool *modified); + + virtual Status VisitAfter(std::shared_ptr node, bool *modified); + + virtual Status Visit(std::shared_ptr node, bool *modified); + + virtual Status VisitAfter(std::shared_ptr node, bool *modified); + + virtual Status Visit(std::shared_ptr node, bool *modified); + + virtual Status VisitAfter(std::shared_ptr node, bool *modified); + + virtual Status Visit(std::shared_ptr node, bool *modified); + + virtual Status VisitAfter(std::shared_ptr node, bool *modified); + +#ifdef ENABLE_PYTHON + virtual Status Visit(std::shared_ptr node, bool *modified); + + virtual Status VisitAfter(std::shared_ptr node, bool *modified); +#endif + + virtual Status Visit(std::shared_ptr node, bool *modified); + + virtual Status VisitAfter(std::shared_ptr node, bool *modified); + + virtual Status Visit(std::shared_ptr node, bool *modified); + + virtual Status VisitAfter(std::shared_ptr node, bool *modified); + + virtual Status Visit(std::shared_ptr node, bool *modified); + + virtual Status VisitAfter(std::shared_ptr node, bool *modified); + + // For datasetops/source IR + virtual Status Visit(std::shared_ptr node, bool *modified); + + virtual Status VisitAfter(std::shared_ptr node, bool *modified); + + virtual Status Visit(std::shared_ptr node, bool *modified); + + virtual Status VisitAfter(std::shared_ptr node, bool *modified); + + virtual Status Visit(std::shared_ptr node, bool *modified); + + virtual Status VisitAfter(std::shared_ptr node, bool *modified); + + virtual Status Visit(std::shared_ptr node, bool *modified); + + virtual Status VisitAfter(std::shared_ptr node, bool *modified); + +#ifndef ENABLE_ANDROID + virtual Status Visit(std::shared_ptr node, bool *modified); + + virtual Status VisitAfter(std::shared_ptr node, bool *modified); +#endif + + virtual Status Visit(std::shared_ptr node, bool *modified); + + virtual Status VisitAfter(std::shared_ptr node, bool *modified); + +#ifndef ENABLE_ANDROID + virtual Status Visit(std::shared_ptr node, bool *modified); + + virtual Status VisitAfter(std::shared_ptr node, bool *modified); +#endif + +#ifdef ENABLE_PYTHON + virtual Status Visit(std::shared_ptr node, bool *modified); + + virtual Status VisitAfter(std::shared_ptr node, bool *modified); +#endif + + virtual Status Visit(std::shared_ptr node, bool *modified); + + virtual Status VisitAfter(std::shared_ptr node, bool *modified); + + virtual Status Visit(std::shared_ptr node, bool *modified); + + virtual Status VisitAfter(std::shared_ptr node, bool *modified); + +#ifndef ENABLE_ANDROID + virtual Status Visit(std::shared_ptr node, bool *modified); + + virtual Status VisitAfter(std::shared_ptr node, bool *modified); +#endif + + virtual Status Visit(std::shared_ptr node, bool *modified); + + virtual Status VisitAfter(std::shared_ptr node, bool *modified); + + virtual Status Visit(std::shared_ptr node, bool *modified); + + virtual Status VisitAfter(std::shared_ptr node, bool *modified); + +#ifndef ENABLE_ANDROID + virtual Status Visit(std::shared_ptr node, bool *modified); + + virtual Status VisitAfter(std::shared_ptr node, bool *modified); +#endif + +#ifndef ENABLE_ANDROID + virtual Status Visit(std::shared_ptr node, bool *modified); + + virtual Status VisitAfter(std::shared_ptr node, bool *modified); +#endif + + virtual Status Visit(std::shared_ptr node, bool *modified); + + virtual Status VisitAfter(std::shared_ptr node, bool *modified); + + ////////////////////////////////// + // This section of code will be removed once the migration of optimizer from DatasetOp to DatasetNode is done. /// \brief Run the transformation pass against the execution tree /// \param[inout] tree Pointer to the execution tree to be transformed /// \param[inout] modified Indicator if the tree was changed @@ -241,13 +478,23 @@ class NodePass : public Pass { virtual Status PreRunOnNode(std::shared_ptr node, bool *modified); #endif + ////////////////////////////////// private: + // Helper function to perform DFS visit + Status DFSNodeVisit(std::shared_ptr node_ir, bool *modified); + + // Helper function to perform BFS visit + Status BFSNodeVisit(std::shared_ptr node_ir, bool *modified); + + ////////////////////////////////// + // This section of code will be removed once the migration of optimizer from DatasetOp to DatasetNode is done. // Helper function to perform DFS visit Status DFSNodeVisit(std::shared_ptr node, bool *modified); // Helper function to perform BFS visit Status BFSNodeVisit(std::shared_ptr root, bool *modified); + ////////////////////////////////// // Tree traversal order of the NodePass Order traversalOrder_;