From 8eaea74407ba0372b57c76ffff3b894e6c3b49e4 Mon Sep 17 00:00:00 2001 From: Giancarlo Colmenares Date: Wed, 3 Jun 2020 14:32:37 -0400 Subject: [PATCH] Added a Pattern Matcher class to help with future optimization implementations. Includes changes to barnch_culling to show how to use the new Pattern Matcher infrastructure. --- mindspore/ccsrc/ir/pattern_matcher.h | 306 ++++++++++++++++++ .../ccsrc/optimizer/irpass/branch_culling.h | 237 ++++---------- 2 files changed, 377 insertions(+), 166 deletions(-) create mode 100644 mindspore/ccsrc/ir/pattern_matcher.h diff --git a/mindspore/ccsrc/ir/pattern_matcher.h b/mindspore/ccsrc/ir/pattern_matcher.h new file mode 100644 index 0000000000..e955c0afd4 --- /dev/null +++ b/mindspore/ccsrc/ir/pattern_matcher.h @@ -0,0 +1,306 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_IR_PATTERN_MATCHER_H_ +#define MINDSPORE_CCSRC_IR_PATTERN_MATCHER_H_ + +#include +#include + +#include "ir/anf.h" +#include "operator/ops.h" + +namespace mindspore { + +/// +/// Base class for all recognizable patterns. +/// We implement an Expression Template approach using static polymorphism based on +/// the Curiously Recurring Template Pattern (CRTP) which "achieves a similar effect +/// to the use of virtual functions without the costs..." as described in: +/// https://en.wikipedia.org/wiki/Expression_templates and +/// https://en.wikipedia.org/wiki/Curiously_recurring_template_pattern +/// The TryCapture function tries to capture the pattern with the given node. +/// The GetNode function builds a new node using the captured values. +/// + +template +class PBase { + public: + const T &get_object() const { return *static_cast(this); } + + template + bool TryCapture(const TN &value) const { + get_object().Reset(); + return get_object().TryCapture_(value); + } + + using Internal = T; +}; + +template +class PIsEqual { + public: + bool operator()(const T &lhs, const T &rhs) const { return lhs == rhs; } +}; + +template +class PatternNode : public PBase > { + public: + T GetNode(const AnfNodePtr &node) const { + if (!captured_) { + MS_EXCEPTION(ValueError) << "A Pattern wasn't captured for this Token before the call to GetNode."; + } + return captured_node_; + } + + bool TryCapture_(const T &node) const { + if (!captured_) { + captured_node_ = node; + captured_ = true; + return true; + } + return PIsEqual()(captured_node_, node); + } + + void Reset() const { captured_ = false; } + using Internal = const PatternNode &; + + protected: + mutable T captured_node_; + mutable bool captured_{false}; +}; + +template +class PBinOperation : public PBase > { + public: + PBinOperation(const PrimitivePtr &prim, const T &x, const T2 &y) : prim_(prim), x_(x), y_(y) {} + + AnfNodePtr GetNode(const AnfNodePtr &node) const { + AnfNodePtr lhs = x_.GetNode(node->func_graph()); + AnfNodePtr rhs = y_.GetNode(node->func_graph()); + AnfNodePtrList list = {prim_->cast(), lhs, rhs}; + return NewCNode(list, node->func_graph()); + } + + bool TryCapture_(const AnfNodePtr &node) const { + if (IsPrimitiveCNode(node, prim_)) { + auto cnode = node->cast(); + auto inputs = cnode->inputs(); + if (inputs.size() == 3) { + // Binary Prim assumes only two inputs + if (!x_.TryCapture_(inputs[1]) || !y_.TryCapture_(inputs[2])) { + return false; + } + return true; + } + } + return false; + } + + void Reset() const { + x_.Reset(); + y_.Reset(); + } + + private: + const PrimitivePtr prim_; + typename T::Internal x_; + typename T2::Internal y_; +}; + +/// +/// Helper functions to apply a pattern function on all elements of a tuple +/// +namespace tuple_utils { +template +struct apply_func_tuple_item { + template + static void apply(Func *func, const TTuple &tuple) { + (*func)(Index, std::get(tuple)); + apply_func_tuple_item<(Index + 1) == std::tuple_size::value, (Index + 1), Func>::apply(func, tuple); + } +}; + +template +struct apply_func_tuple_item { + template + static void apply(Func *func, const TTuple &tuple) {} +}; + +template +inline void apply_func_tuple(Func *func, const TTuple &tuple) { + apply_func_tuple_item::value == 0, 0, Func>::apply(func, tuple); +} + +struct PTupleResetCapture { + template + void operator()(size_t i, const T &pattern) const { + pattern.Reset(); + } +}; + +struct PTupleCapture { + explicit PTupleCapture(const AnfNodePtrList tuple) : tuple_(tuple) {} + + template + void operator()(size_t i, const TPattern &pattern) { + // Check if the first node is a Primitive + if (i == 0 && tuple_[i]->isa()) { + auto prim = tuple_[i]->cast(); + if (tuple_[i] != pattern.GetNode(tuple_[i])) { + captured_ = false; + } + } else { + captured_ = captured_ && pattern.TryCapture_(tuple_[i]); + } + } + + const AnfNodePtrList tuple_; + bool captured_{true}; +}; + +struct PTupleGetNode { + explicit PTupleGetNode(const AnfNodePtr &node) : node_(node) {} + + template + void operator()(size_t, const TPattern &pattern) { + args_.push_back(pattern.GetNode(node_)); + } + + const AnfNodePtr &node_; + std::vector args_; +}; +} // namespace tuple_utils + +template +class PCNode : public PBase > { + public: + explicit PCNode(const TArgs &... args) : args_(args...) {} + + AnfNodePtr GetNode(const AnfNodePtr &node) const { + tuple_utils::PTupleGetNode get_node(node); + tuple_utils::apply_func_tuple(&get_node, args_); + return NewCNode(get_node.args_, node->func_graph()); + } + + bool TryCapture_(const AnfNodePtr &node) const { + if (node->isa()) { + auto cnode = node->cast(); + auto inputs = cnode->inputs(); + if (inputs.size() != sizeof...(TArgs)) { + return false; + } + tuple_utils::PTupleCapture capture_func(inputs); + tuple_utils::apply_func_tuple(&capture_func, args_); + return capture_func.captured_; + } + + return false; + } + + void Reset() const { + tuple_utils::PTupleResetCapture reset; + tuple_utils::apply_func_tuple(&reset, args_); + } + + private: + std::tuple args_; +}; + +template +class PPrimitive : public PBase > { + public: + explicit PPrimitive(const PrimitivePtr &prim, const TArgs &... args) : prim_(prim), args_(args...) {} + + AnfNodePtr GetNode(const AnfNodePtr &node) const { + tuple_utils::PTupleGetNode get_node(node); + tuple_utils::apply_func_tuple(&get_node, args_); + auto prim_cnode = get_node.args_; + prim_cnode.insert(prim_cnode.begin(), NewValueNode(prim_)); + return NewCNode(prim_cnode, node->func_graph()); + } + + bool TryCapture_(const AnfNodePtr &node) const { + if (IsPrimitiveCNode(node, prim_)) { + auto cnode = node->cast(); + auto inputs = cnode->inputs(); + if ((inputs.size() - 1) != sizeof...(TArgs)) { + return false; + } + + AnfNodePtrList rest(inputs.begin() + 1, inputs.end()); + tuple_utils::PTupleCapture capture_func(rest); + tuple_utils::apply_func_tuple(&capture_func, args_); + + return capture_func.captured_; + } + + return false; + } + + void Reset() const { + tuple_utils::PTupleResetCapture reset; + tuple_utils::apply_func_tuple(&reset, args_); + } + + private: + const PrimitivePtr prim_; + std::tuple args_; +}; + +// Macro for binary operation functions +#define BIN_OPERATION_PATTERN(Operator, MSPrimitive) \ + template \ + inline PBinOperation Operator(const PBase &x, const PBase &y) { \ + return PBinOperation(MSPrimitive, x.get_object(), y.get_object()); \ + } + +// Arithmetic operations +BIN_OPERATION_PATTERN(operator+, prim::kPrimTensorAdd); +BIN_OPERATION_PATTERN(operator*, prim::kPrimMul); + +// Macros for match and replace +#define MATCH_REPLACE(OrigNode, CaptureNode, ReplaceWith) \ + if ((CaptureNode).TryCapture(OrigNode)) { \ + return (ReplaceWith).GetNode(OrigNode); \ + } + +#define MATCH_REPLACE_IF(OrigNode, CaptureNode, ReplaceWith, Condition) \ + if ((CaptureNode).TryCapture(OrigNode) && (Condition)) { \ + return (ReplaceWith).GetNode(OrigNode); \ + } + +#define MATCH_REPLACE_IF_ELSE(OrigNode, CaptureNode, ReplaceWith, Condition, ElseNode) \ + if ((CaptureNode).TryCapture(OrigNode)) { \ + if ((Condition)) { \ + return (ReplaceWith).GetNode(OrigNode); \ + } \ + return (ElseNode).GetNode(OrigNode); \ + } + +#define MATCH_REPLACE_LAMBDA(OrigNode, CaptureNode, Lambda) \ + if ((CaptureNode).TryCapture(OrigNode)) { \ + return (Lambda)(); \ + } + +#define MATCH_REPLACE_LAMBDA_IF(OrigNode, CaptureNode, Lambda, Condition) \ + if ((CaptureNode).TryCapture(OrigNode) && (Condition)) { \ + return (Lambda)(); \ + } + +} // namespace mindspore + +#endif // #ifndef MINDSPORE_CCSRC_IR_PATTERN_MATCHER_H_ diff --git a/mindspore/ccsrc/optimizer/irpass/branch_culling.h b/mindspore/ccsrc/optimizer/irpass/branch_culling.h index b2d6718857..736f67b5dd 100644 --- a/mindspore/ccsrc/optimizer/irpass/branch_culling.h +++ b/mindspore/ccsrc/optimizer/irpass/branch_culling.h @@ -26,141 +26,61 @@ #include "ir/func_graph.h" #include "ir/func_graph_cloner.h" #include "operator/ops.h" +#include "ir/pattern_matcher.h" namespace mindspore { namespace opt { namespace irpass { // {prim::kPrimSwitch, true, X, Y} // {prim::kPrimSwitch, false, X, Y} -class SwitchSimplify : public AnfVisitor { +class SwitchSimplify { public: - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { - Reset(); - auto getx = [this](const AnfNodePtr &node) -> bool { - this->x_ = node; - return true; - }; - auto gety = [this](const AnfNodePtr &node) -> bool { - this->y_ = node; - return true; + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) { + PatternNode cond, true_br, false_br; + auto SwitchSimplLambda = [&node, &cond, &true_br, &false_br]() -> AnfNodePtr { + auto cond_value_ = GetValue(GetValueNode(cond.GetNode(node))); + if (cond_value_) { + return true_br.GetNode(node); + } + return false_br.GetNode(node); }; - AnfVisitor::Match(prim::kPrimSwitch, {IsValueNode, getx, gety})(node); - // simplify the switch - if (is_match_) { - if (cond_) { - return x_; - } - return y_; - } + MATCH_REPLACE_LAMBDA_IF(node, PPrimitive(prim::kPrimSwitch, cond, true_br, false_br), SwitchSimplLambda, + IsValueNode(cond.GetNode(node))); return nullptr; } - - void Visit(const AnfNodePtr &node) override { - if (!is_match_ && IsValueNode(node)) { - cond_ = GetValue(GetValueNode(node)); - is_match_ = true; - } - } - - void Reset() { - x_ = nullptr; - y_ = nullptr; - cond_ = false; - is_match_ = false; - } - - private: - bool is_match_{false}, cond_{false}; - AnfNodePtr x_{nullptr}, y_{nullptr}; }; // {prim::kPrimTupleGetItem, {prim::kPrimSwith, X0, X1, X2}, C} => // {prim::kPrimSwith, X0, {prim::kPrimTupleGetItem, X1, C}, {prim::kPrimTupleGetItem, X2, C}} -class FloatTupleGetItemSwitch : public AnfVisitor { +class FloatTupleGetItemSwitch { public: - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { - Reset(); - AnfVisitor::Match(prim::kPrimTupleGetItem, {IsCNode, IsVNode})(node); - - auto fg = node->func_graph(); - if (Xs_.empty() || c_ == nullptr || fg == nullptr) { - return nullptr; - } - - auto true_node = fg->NewCNode({NewValueNode(prim::kPrimTupleGetItem), Xs_[1], c_}); - auto false_node = fg->NewCNode({NewValueNode(prim::kPrimTupleGetItem), Xs_[2], c_}); - - return fg->NewCNode({NewValueNode(prim::kPrimSwitch), Xs_[0], true_node, false_node}); - } - - void Visit(const CNodePtr &cnode) override { - // {prim::kPrimSwith, X1, X2, X3} - if (!IsPrimitiveCNode(cnode, prim::kPrimSwitch) || cnode->size() != 4) { - return; - } - - // copy X1, X2, X3 - auto &inputs = cnode->inputs(); - (void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(Xs_)); - } - - void Visit(const ValueNodePtr &vnode) override { c_ = vnode; } - - void Reset() { - Xs_.clear(); - c_ = nullptr; + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) { + PatternNode cond, true_br, false_br, x; + MATCH_REPLACE_IF(node, + PPrimitive(prim::kPrimTupleGetItem, PPrimitive(prim::kPrimSwitch, cond, true_br, false_br), x), + PPrimitive(prim::kPrimSwitch, cond, PPrimitive(prim::kPrimTupleGetItem, true_br, x), + PPrimitive(prim::kPrimTupleGetItem, false_br, x)), + IsVNode(x.GetNode(node))); + return nullptr; } - - private: - AnfNodePtr c_{nullptr}; - std::vector Xs_{}; }; // {prim::kPrimEnvGetItem, {prim::kPrimSwitch, X1, X2, X3}, X4, X5} => // {prim::kPrimSwitch, X1, {prim::kPrimEnvGetItem, X2, X4, X5}, {prim::kPrimEnvGetItem, X3, X4, X5}} -class FloatEnvGetItemSwitch : public AnfVisitor { +class FloatEnvGetItemSwitch { public: - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { - is_match_ = false; - AnfVisitor::Match(prim::kPrimEnvGetItem, {IsCNode, IsNode, IsNode})(node); - if (!is_match_) { - return nullptr; - } - - // {prim::kPrimEnvGetItem, {...}, X4, X5} - auto cnode = node->cast(); - auto sw_node = cnode->input(1)->cast(); - auto x4 = cnode->input(2); - auto x5 = cnode->input(3); + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) { + PatternNode cond, true_br, false_br, x, x2; + MATCH_REPLACE_IF(node, + PPrimitive(prim::kPrimEnvGetItem, PPrimitive(prim::kPrimSwitch, cond, true_br, false_br), x, x2), + PPrimitive(prim::kPrimSwitch, cond, PPrimitive(prim::kPrimEnvGetItem, true_br, x, x2), + PPrimitive(prim::kPrimEnvGetItem, false_br, x, x2)), + IsNode(x.GetNode(node)) && IsNode(x2.GetNode(node))); - is_match_ = false; - AnfVisitor::Match(prim::kPrimSwitch, {IsNode, IsNode, IsNode})(sw_node); - if (!is_match_) { - return nullptr; - } - - // {prim::kPrimSwitch, X1, X2, X3} - auto x1 = sw_node->input(1); - auto x2 = sw_node->input(2); - auto x3 = sw_node->input(3); - - auto fg = node->func_graph(); - if (fg == nullptr) { - return nullptr; - } - - auto true_node = fg->NewCNode({NewValueNode(prim::kPrimEnvGetItem), x2, x4, x5}); - auto false_node = fg->NewCNode({NewValueNode(prim::kPrimEnvGetItem), x3, x4, x5}); - - return fg->NewCNode({NewValueNode(prim::kPrimSwitch), x1, true_node, false_node}); + return nullptr; } - - void Visit(const AnfNodePtr &) override { is_match_ = true; } - - private: - bool is_match_{false}; }; namespace internal { @@ -173,79 +93,64 @@ AnfNodePtr TransformMergeBranches(const AnfNodePtr &true_output_node, const AnfN } // namespace internal // {{prim::kPrimSwitch, X, G1, G2}, Xs} -class ConvertSwitchReplacement : public AnfVisitor { +class ConvertSwitchReplacement { public: - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) { if (!node->isa() || node->func_graph() == nullptr) { return nullptr; } - Reset(); - auto cnode = node->cast(); - if (cnode->size() < 1) { + auto cnode_ = node->cast(); + if (cnode_->size() < 1) { return nullptr; } - // {prim::kPrimSwitch, X, G1, G2} - AnfVisitor::Match(prim::kPrimSwitch, {IsNode, IsValueNode, IsValueNode})(cnode->input(0)); - if (g2_ == nullptr || g1_->output() == nullptr || g2_->output() == nullptr) { - return nullptr; - } - // for switch replace method, only graphs without graph inside can be replaced - for (auto &item : g1_->value_nodes()) { - auto value_node = item.first; - if (IsValueNode(value_node)) { - return nullptr; + auto node_ = cnode_->input(0); + + PatternNode cond, true_br, false_br; + + auto ConvertSwitchLambda = [&node_, &cond, &true_br, &false_br]() -> AnfNodePtr { + auto g1_ = GetValueNode(true_br.GetNode(node_)); + auto g2_ = GetValueNode(false_br.GetNode(node_)); + auto x_ = cond.GetNode(node_); + + // for switch replace method, only graphs without graph inside can be replaced + for (auto &item : g1_->value_nodes()) { + auto value_node = item.first; + if (IsValueNode(value_node)) { + return nullptr; + } } - } - for (auto &item : g2_->value_nodes()) { - auto value_node = item.first; - if (IsValueNode(value_node)) { - return nullptr; + for (auto &item : g2_->value_nodes()) { + auto value_node = item.first; + if (IsValueNode(value_node)) { + return nullptr; + } } - } - auto true_output = g1_->output()->abstract(); - auto false_output = g2_->output()->abstract(); - auto trans_g1 = internal::TransformGraphCondTrueBranchNodes(g1_, x_); - auto trans_g2 = internal::TransformGraphCondFalseBranchNodes(g2_, x_); - - std::vector params; - auto fg = node->func_graph(); - auto cloned_g1 = InlineClone(trans_g1, fg, params); - auto cloned_g2 = InlineClone(trans_g2, fg, params); - auto nnode = internal::TransformMergeBranches(cloned_g1, cloned_g2, true_output, false_output, x_, fg); - return nnode; - } + auto true_output = g1_->output()->abstract(); + auto false_output = g2_->output()->abstract(); + auto trans_g1 = internal::TransformGraphCondTrueBranchNodes(g1_, x_); + auto trans_g2 = internal::TransformGraphCondFalseBranchNodes(g2_, x_); - void Visit(const AnfNodePtr &node) override { - if (x_ == nullptr) { - x_ = node; - return; - } - AnfVisitor::Visit(node); - } + std::vector params; + auto fg = node_->func_graph(); + auto cloned_g1 = InlineClone(trans_g1, fg, params); + auto cloned_g2 = InlineClone(trans_g2, fg, params); + auto nnode = internal::TransformMergeBranches(cloned_g1, cloned_g2, true_output, false_output, x_, fg); - void Visit(const ValueNodePtr &vnode) override { - auto g = GetValueNode(vnode); - if (g1_ == nullptr) { - g1_ = g; - } else { - g2_ = g; - } - } + return nnode; + }; - void Reset() { - x_ = nullptr; - g1_ = nullptr; - g2_ = nullptr; - } + MATCH_REPLACE_LAMBDA_IF(node_, PPrimitive(prim::kPrimSwitch, cond, true_br, false_br), ConvertSwitchLambda, + IsNode(cond.GetNode(node_)) && IsValueNode(true_br.GetNode(node_)) && + IsValueNode(false_br.GetNode(node_))); - private: - AnfNodePtr x_{nullptr}; - FuncGraphPtr g1_{nullptr}, g2_{nullptr}; + return nullptr; + } }; + } // namespace irpass } // namespace opt } // namespace mindspore