From 6d4c07c886e27312da1588251b1dac303fd439d1 Mon Sep 17 00:00:00 2001 From: BowenK Date: Thu, 30 Jul 2020 13:00:38 +0800 Subject: [PATCH] Update python pattern expression --- mindspore/ccsrc/frontend/optimizer/pattern.cc | 158 ++++++++++++ mindspore/ccsrc/frontend/optimizer/pattern.h | 228 ++++++++++++++++++ mindspore/ccsrc/frontend/optimizer/py_pass.cc | 190 +++++++-------- mindspore/ccsrc/frontend/optimizer/py_pass.h | 11 +- .../frontend/optimizer/py_pass_manager.cc | 2 +- .../frontend/optimizer/py_pass_manager.h | 3 +- mindspore/common/graph_pattern.py | 154 ++++++++++++ mindspore/common/python_pass_register.py | 9 +- tests/ut/python/optimizer/test_python_pass.py | 141 ++++++++++- 9 files changed, 769 insertions(+), 127 deletions(-) create mode 100644 mindspore/ccsrc/frontend/optimizer/pattern.cc create mode 100644 mindspore/ccsrc/frontend/optimizer/pattern.h create mode 100644 mindspore/common/graph_pattern.py diff --git a/mindspore/ccsrc/frontend/optimizer/pattern.cc b/mindspore/ccsrc/frontend/optimizer/pattern.cc new file mode 100644 index 0000000000..412c0bdb46 --- /dev/null +++ b/mindspore/ccsrc/frontend/optimizer/pattern.cc @@ -0,0 +1,158 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "frontend/optimizer/pattern.h" +#include "pybind_api/api_register.h" +#include "pybind_api/export_flags.h" + +namespace mindspore { +namespace opt { +namespace python_pass { +int Pattern::g_id_ = 0; + +MatchResultPtr IsPrimTypeOf::match(const AnfNodePtr &node) { + if (!IsValueNode(node)) { + return nullptr; + } + MatchResultPtr res = std::make_shared(); + if (IsValueNode(node)) { + // iterate over all primitives + for (auto &iter : primitives_) { + if (IsPrimitive(node, iter) || iter->name() == "*") { + matched_prim_ = iter; + res->add_entry(shared_from_base(), node); + return res; + } + } + } + return nullptr; +} + +MatchResultPtr CallWith::match(const AnfNodePtr &node) { + if (!IsPrimitiveCNode(node)) { + return nullptr; + } + MatchResultPtr res = std::make_shared(); + // IsPrimitiveCNode + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + // Check Primitive ValueNode + if (prim_pattern_ != nullptr) { + // Passed in prim_pattern + auto prim_value_res = prim_pattern_->match(cnode->input(0)); + if (prim_value_res == nullptr) { + return nullptr; + } + res->merge(prim_value_res); + } else if (prim_ != nullptr) { + // Passed in primitive/primitive str + if (!IsPrimitive(cnode->input(0), prim_)) { + return nullptr; + } + } else { + MS_LOG(EXCEPTION) << "Uninitialized CallWith pattern."; + } + // Check inputs + auto p_inputs_size = inputs_.size(); + auto node_inputs_size = cnode->size() - 1; + if (p_inputs_size != 0 && p_inputs_size != node_inputs_size) { + return nullptr; + } + // If inputs is not specified, add node without looking into its inputs + if (p_inputs_size == 0) { + res->add_entry(shared_from_base(), cnode->input(0)); + return res; + } + bool failed = false; + for (std::size_t i = 0; i < node_inputs_size; i++) { + auto pattern = inputs_[i]; + auto input = cnode->input(i + 1); + auto input_match_result = pattern->match(input); + if (input_match_result == nullptr) { + failed = true; + break; + } + res->merge(input_match_result); + } + if (!failed) { + res->add_entry(shared_from_base(), cnode->input(0)); + return res; + } + return nullptr; +} + +MatchResultPtr IsIn::match(const AnfNodePtr &node) { + for (auto &iter : patterns_) { + auto res = iter->match(node); + if (res != nullptr) { + return res; + } + } + return nullptr; +} + +MatchResultPtr IsNot::match(const AnfNodePtr &node) { + for (auto &iter : patterns_) { + auto res = iter->match(node); + if (res != nullptr) { + return nullptr; + } + } + auto res = std::make_shared(); + res->add_entry(shared_from_base(), node); + return res; +} + +MatchResultPtr AnyPattern::match(const AnfNodePtr &node) { + MatchResultPtr res = std::make_shared(); + res->add_entry(shared_from_base(), node); + return res; +} + +AnfNodePtr MatchResult::get_node(const PatternPtr &pattern) { + auto entry = match_result_.find(pattern); + if (entry == match_result_.end()) { + return nullptr; + } + return entry->second; +} + +void MatchResult::merge(const MatchResultPtr &other_result) { + auto other_result_map = other_result->_result(); + // add/update entries in other_result + for (auto &iter : other_result_map) { + match_result_[iter.first] = iter.second; + } +} + +REGISTER_PYBIND_DEFINE( + Pattern, ([](const py::module *m) { + (void)py::class_>(*m, "Pattern").def(py::init<>()); + (void)py::class_, Pattern>(*m, "IsIn_").def(py::init>()); + (void)py::class_, Pattern>(*m, "IsPrimTypeOf_", py::dynamic_attr()) + .def(py::init, string, bool>()) + .def(py::init, string, bool>()); + (void)py::class_, Pattern>(*m, "CallWith_") + .def(py::init, bool>()) + .def(py::init, bool>()) + .def(py::init, bool>()); + (void)py::class_, Pattern>(*m, "IsNot_").def(py::init>()); + (void)py::class_, Pattern>(*m, "AnyPattern").def(py::init<>()); + (void)py::class_, Pattern>(*m, "NewTensor_") + .def(py::init()); + })); +} // namespace python_pass +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/optimizer/pattern.h b/mindspore/ccsrc/frontend/optimizer/pattern.h new file mode 100644 index 0000000000..8d567e5ab2 --- /dev/null +++ b/mindspore/ccsrc/frontend/optimizer/pattern.h @@ -0,0 +1,228 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_PATTERN_H_ +#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_PATTERN_H_ +#include +#include +#include +#include + +#include "base/base.h" +#include "ir/anf.h" +#include "ir/tensor.h" +#include "utils/primitive_py.h" +#include "utils/tensor_py.h" + +namespace mindspore { +namespace opt { +namespace python_pass { +using std::string; +using std::vector; + +class MatchResult; +using MatchResultPtr = std::shared_ptr; +class Pattern; +using PatternPtr = std::shared_ptr; +class IsPrimTypeOf; +using IsPrimTypeOfPtr = std::shared_ptr; +class CallWith; +using CallWithPtr = std::shared_ptr; +class NewTensor; +using NewTensorPtr = std::shared_ptr; +struct PatternHasher; +struct PatternEqual; +using PatternNodeMap = std::unordered_map; + +class Pattern : public Base { + public: + Pattern() : unique_name_(std::to_string(g_id_++)) {} + virtual MatchResultPtr match(const AnfNodePtr &node) { return nullptr; } + virtual bool operator==(const Pattern &other) const { return unique_name_ == other.unique_name_; } + string unique_name() const { return unique_name_; } + vector inputs() { return inputs_; } + bool should_replace() { return should_replace_; } + virtual void reset() {} + + protected: + static int g_id_; + // NOTE: To ensure uniqueness of the name, raise g_id_ by 1 every time a pattern got constructed + string unique_name_; + vector inputs_; + bool should_replace_ = true; +}; + +struct PatternEqual { + bool operator()(PatternPtr const &p1, PatternPtr const &p2) const { + MS_EXCEPTION_IF_NULL(p1); + MS_EXCEPTION_IF_NULL(p2); + return p1->unique_name() == p2->unique_name(); + } +}; + +struct PatternHasher { + std::size_t operator()(PatternPtr const &p) const { + MS_EXCEPTION_IF_NULL(p); + return std::hash()(p->unique_name()); + } +}; + +class IsPrimTypeOf : public Pattern { + public: + IsPrimTypeOf() { unique_name_ = std::to_string(g_id_++); } + IsPrimTypeOf(vector prims, string name, bool should_replace) + : primitives_(prims), name_(name), matched_prim_(nullptr) { + unique_name_ = std::to_string(g_id_++) + "_" + name; + should_replace_ = should_replace; + if (!should_replace) { + matched_prim_ = prims[0]; + } + } + IsPrimTypeOf(vector types, string name, bool should_replace) : types_(types), name_(name) { + unique_name_ = std::to_string(g_id_++) + "_" + name; + // Make primitives_ + for (auto &iter : types) { + primitives_.push_back(std::make_shared(iter, py::cast(nullptr))); + } + should_replace_ = should_replace; + if (!should_replace) { + matched_prim_ = primitives_[0]; + } + } + MS_DECLARE_PARENT(IsPrimTypeOf, Pattern); + MatchResultPtr match(const AnfNodePtr &node) override; + PrimitivePyPtr matched_primitive() { return matched_prim_; } + void reset() override { + if (should_replace_) { + matched_prim_ = nullptr; + } + } + + private: + vector types_; + vector primitives_; + string name_; + PrimitivePyPtr matched_prim_; +}; + +class CallWith : public Pattern { + public: + CallWith() { unique_name_ = std::to_string(g_id_++); } + CallWith(PatternPtr prim_pattern, vector inputs, bool should_replace) { + // NOTE: should_replace is ignored in this case, since each sub-pattern has its own setting + prim_pattern_ = prim_pattern; + unique_name_ = std::to_string(g_id_++) + prim_pattern->unique_name(); + inputs_ = inputs; + should_replace_ = should_replace; + } + CallWith(PrimitivePyPtr prim, vector inputs, bool should_replace) { + prim_ = prim; + unique_name_ = std::to_string(g_id_++) + prim_->ToString(); + inputs_ = inputs; + should_replace_ = should_replace; + } + CallWith(string prim_str, vector inputs, bool should_replace) { + prim_ = std::make_shared(prim_str, py::cast(nullptr)); + unique_name_ = std::to_string(g_id_++) + prim_->ToString(); + inputs_ = inputs; + should_replace_ = should_replace; + } + MS_DECLARE_PARENT(CallWith, Pattern); + MatchResultPtr match(const AnfNodePtr &node) override; + PrimitivePtr prim_value() { return prim_; } + PatternPtr prim_pattern() { return prim_pattern_; } + + private: + PatternPtr prim_pattern_ = nullptr; + PrimitivePtr prim_ = nullptr; + vector types_; + string name_; +}; + +class IsIn : public Pattern { + public: + IsIn() { unique_name_ = std::to_string(g_id_++); } + explicit IsIn(vector patterns) : patterns_(patterns) { + unique_name_ = std::to_string(g_id_++); + for (auto &iter : patterns) { + unique_name_ = unique_name_ + "_" + iter->unique_name(); + } + } + MS_DECLARE_PARENT(IsIn, Pattern); + MatchResultPtr match(const AnfNodePtr &node) override; + + private: + vector patterns_; +}; + +class IsNot : public Pattern { + public: + IsNot() { unique_name_ = std::to_string(g_id_++); } + explicit IsNot(vector patterns) : patterns_(patterns) { + unique_name_ = std::to_string(g_id_++); + for (auto &iter : patterns) { + unique_name_ = "IsNot_" + unique_name_ + "_" + iter->unique_name(); + } + } + MS_DECLARE_PARENT(IsNot, Pattern); + MatchResultPtr match(const AnfNodePtr &node) override; + + private: + vector patterns_; +}; + +class AnyPattern : public Pattern { + public: + AnyPattern() { unique_name_ = std::to_string(g_id_++) + "_AnyPattern"; } + MS_DECLARE_PARENT(AnyPattern, Pattern); + MatchResultPtr match(const AnfNodePtr &node) override; +}; + +class NewTensor : public Pattern { + public: + NewTensor() { unique_name_ = std::to_string(g_id_++); } + explicit NewTensor(tensor::TensorPtr input_tensor) : input_tensor_(input_tensor) { should_replace_ = false; } + MS_DECLARE_PARENT(NewTensor, Pattern); + MatchResultPtr match(const AnfNodePtr &node) override { + MS_LOG(EXCEPTION) << "Find NewTensor in pattern, NewTensor should only appear in the target.\n"; + } + tensor::TensorPtr input_tensor() { return input_tensor_; } + + private: + tensor::TensorPtr input_tensor_; +}; + +class MatchResult { + public: + MatchResult() {} + void add_entry(PatternPtr pattern, AnfNodePtr node) { match_result_[pattern] = node; } + PatternNodeMap _result() { return match_result_; } + AnfNodePtr get_node(const PatternPtr &pattern); + void merge(const MatchResultPtr &other_result); + void clear() { match_result_.clear(); } + void dump() { + MS_LOG(DEBUG) << "match_result_.size: " + std::to_string(match_result_.size()) + "\n"; + for (auto &iter : match_result_) { + MS_LOG(DEBUG) << "Pattern : " + iter.first->unique_name() + " , node : " + iter.second->ToString() + "\n"; + } + } + + private: + PatternNodeMap match_result_; +}; +} // namespace python_pass +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_PATTERN_H_ diff --git a/mindspore/ccsrc/frontend/optimizer/py_pass.cc b/mindspore/ccsrc/frontend/optimizer/py_pass.cc index 34c46c6b66..362427d227 100644 --- a/mindspore/ccsrc/frontend/optimizer/py_pass.cc +++ b/mindspore/ccsrc/frontend/optimizer/py_pass.cc @@ -22,6 +22,7 @@ #include "ir/func_graph.h" #include "ir/manager.h" +#include "utils/primitive_py.h" #include "pipeline/jit/parse/parse_base.h" #include "pipeline/jit/resource.h" @@ -29,6 +30,8 @@ namespace mindspore { namespace opt { namespace python_pass { namespace internal { +AnfNodePtr ProcessSinglePattern(const PatternPtr &pattern, const MatchResultPtr &res); + std::string GetNodeRepr(AnfNodePtr node) { if (node != nullptr) { if (node->isa()) { @@ -50,126 +53,104 @@ std::string GetNodeRepr(AnfNodePtr node) { return ""; } -void ResolveFuncGraph_(const FuncGraphPtr &fg) { - auto manager = Manage(fg, false); - auto use_sig = parse::python_adapter::UseSignatureInResolve(); - parse::python_adapter::set_use_signature_in_resolve(false); - parse::ResolveAll(manager); - parse::python_adapter::set_use_signature_in_resolve(use_sig); -} - -bool Match(const AnfNodePtr &pattern, const AnfNodePtr &node, const NodeEquivPtr &equiv_ptr) { +bool IsTraversable(const AnfNodePtr &node) { if (node == nullptr) { return false; } - MS_EXCEPTION_IF_NULL(pattern); - if (pattern->isa()) { - if (!node->isa()) { - return false; - } - if (GetNodeRepr(pattern) == GetNodeRepr(node)) { - // add to equiv_ptr - equiv_ptr->insert(std::make_pair(GetValueNode(pattern)->ToString(), node)); - return true; - } - return false; - } else if (pattern->isa()) { - MS_LOG(DEBUG) << pattern->ToString() + "\n"; - // add to equiv_ptr - equiv_ptr->insert(std::make_pair(pattern->ToString(), node)); + if (node->isa() || node->isa()) { return true; - } else if (pattern->isa()) { - // match every single sub ANode - if (!node->isa()) { - return false; - } - auto pattern_inputs = pattern->cast()->inputs(); - auto node_inputs = node->cast()->inputs(); - if (pattern_inputs.size() != node_inputs.size()) { - return false; - } - for (auto p_item = pattern_inputs.begin(), node_item = node_inputs.begin(); p_item != pattern_inputs.end(); - p_item++, node_item++) { - auto res = Match(*p_item, *node_item, equiv_ptr); - if (!res) { - return false; - } - } + } + if (IsValueNode(node) || IsValueNode(node)) { return true; } - MS_LOG(EXCEPTION) << "Unexpected condition, (" + pattern->ToString() + " , " + node->ToString() + ")\n"; + return false; } -AnfNodePtr BuildTarget(const FuncGraphPtr &func_graph, const AnfNodePtr cur_raw_dst_node_, - const NodeEquivPtr &equiv_ptr) { - if (cur_raw_dst_node_->isa()) { - auto sub_pair = equiv_ptr->find(cur_raw_dst_node_->ToString()); - if (sub_pair != equiv_ptr->end()) { - return sub_pair->second; - } - MS_LOG(EXCEPTION) << "cur_raw_dst_node_ : " + internal::GetNodeRepr(cur_raw_dst_node_) + "\n"; - } else if (cur_raw_dst_node_->isa()) { - // check primitive ValueNode - auto sub_pair = equiv_ptr->find(cur_raw_dst_node_->cast()->value()->ToString()); - if (sub_pair != equiv_ptr->end()) { - return sub_pair->second; - } - return cur_raw_dst_node_; - } else if (cur_raw_dst_node_->isa()) { - std::vector new_inputs; - auto inputs = cur_raw_dst_node_->cast()->inputs(); - for (auto sub_node = inputs.begin(); sub_node != inputs.end(); sub_node++) { - auto subed = internal::BuildTarget(func_graph, *sub_node, equiv_ptr); - new_inputs.push_back(subed); - } - return func_graph->NewCNode(new_inputs); - } - MS_LOG(EXCEPTION) << "Unexpected node type, got : " + internal::GetNodeRepr(cur_raw_dst_node_); +AnfNodePtr BuildPrimitive(const PatternPtr &pattern, const MatchResultPtr &res) { + // Build up AnfNode from primitive + auto prim_pattern = pattern->cast(); + MS_EXCEPTION_IF_NULL(prim_pattern); + PrimitivePyPtr prim = prim_pattern->matched_primitive(); + MS_EXCEPTION_IF_NULL(prim); + // Make value node out of primitives + return std::make_shared(prim); } -bool isTraversable(const AnfNodePtr &node) { - if (node == nullptr) { - return false; - } - if (node->isa() || node->isa()) { - return true; - } - if (IsValueNode(node) || IsValueNode(node)) { - return true; +AnfNodePtr BuildNewTensor(const PatternPtr &pattern, const MatchResultPtr &res) { + // Build a ValueNode from TensorPtr + auto new_tensor_pattern = pattern->cast(); + MS_EXCEPTION_IF_NULL(new_tensor_pattern); + auto input_tensor = new_tensor_pattern->input_tensor(); + MS_EXCEPTION_IF_NULL(input_tensor); + return std::make_shared(input_tensor); +} + +AnfNodePtr BuildPrimitiveValueNode(const PatternPtr &pattern, const MatchResultPtr &res) { + auto call_with_pattern = pattern->cast(); + MS_EXCEPTION_IF_NULL(call_with_pattern); + auto prim = call_with_pattern->prim_value(); + if (prim != nullptr) { + return std::make_shared(prim); } - return false; + auto prim_pattern = call_with_pattern->prim_pattern(); + MS_EXCEPTION_IF_NULL(prim_pattern); + return ProcessSinglePattern(prim_pattern, res); } -} // namespace internal -void PythonPass::Build(const py::function &src, const py::function &dst) { - // 1. get FuncGraph from py::function - auto src_fg_ = parse::ParsePythonCode(src); - auto dst_fg_ = parse::ParsePythonCode(dst); - if (src_fg_ == nullptr || dst_fg_ == nullptr) { - MS_LOG(EXCEPTION) << "Failed to parse python code.\n"; +AnfNodePtr ProcessSinglePattern(const PatternPtr &pattern, const MatchResultPtr &res) { + if (pattern->should_replace()) { + // Find replacement in the MatchResult + auto target_node = res->get_node(pattern); + if (target_node == nullptr) { + MS_LOG(EXCEPTION) << "Cannot find target node in pattern match result, pattern: " + pattern->unique_name() + "\n"; + } + return target_node; } - // 2. Resolve - internal::ResolveFuncGraph_(src_fg_); - internal::ResolveFuncGraph_(dst_fg_); - // 3. from FuncGraphPtr to ValueNode - src_node_ = src_fg_->output(); - dst_node_ = dst_fg_->output(); + // Build up new node from pattern + if (pattern->isa()) { + return BuildPrimitive(pattern, res); + } else if (pattern->isa()) { + return BuildNewTensor(pattern, res); + } else if (pattern->isa()) { + return BuildPrimitiveValueNode(pattern, res); + } + return nullptr; } -PythonPass::PythonPass(const std::string &name, const py::function &src, const py::function &dst, bool run_only_once, - bool multigraph) - : name_(name), run_only_once_(run_only_once), multigraph_(multigraph) { - Build(src, dst); +AnfNodePtr BuildTarget(const PatternPtr &pattern, const FuncGraphPtr &func_graph, const MatchResultPtr &res) { + auto target_inputs = pattern->inputs(); + if (target_inputs.size() == 0) { + return ProcessSinglePattern(pattern, res); + } + // Build up the AnfNode in a recursive manner + std::vector new_inputs; + auto prim_value_node = ProcessSinglePattern(pattern, res); + MS_EXCEPTION_IF_NULL(prim_value_node); + new_inputs.push_back(prim_value_node); + for (auto &iter : target_inputs) { + if (iter == pattern) { + MS_LOG(EXCEPTION) << "Circle references: Pattern takes itself as input. Got pattern: " + pattern->unique_name() + + "\n"; + } + new_inputs.push_back(BuildTarget(iter, func_graph, res)); + } + return func_graph->NewCNode(new_inputs); } +} // namespace internal AnfNodePtr PythonPass::Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) { - auto equiv_ptr = std::make_shared(); - bool is_a_match = internal::Match(src_node_, node, equiv_ptr); - if (is_a_match) { - auto new_node = internal::BuildTarget(func_graph, dst_node_, equiv_ptr); + MS_EXCEPTION_IF_NULL(src_pattern_); + MS_EXCEPTION_IF_NULL(dst_pattern_); + auto res = src_pattern_->match(node); + if (res != nullptr) { + res->dump(); + MS_LOG(WARNING) << "Matched pattern: " + src_pattern_->unique_name(); + auto new_node = internal::BuildTarget(dst_pattern_, func_graph, res); + dst_pattern_->reset(); MS_LOG(DEBUG) << "To be replaced node: " + internal::GetNodeRepr(new_node) + "\n"; return new_node; } + src_pattern_->reset(); return nullptr; } @@ -188,14 +169,12 @@ bool PythonPass::Run(const FuncGraphPtr &func_graph) { while (!todo.empty()) { AnfNodePtr node = todo.front(); todo.pop_front(); - - // check whether this node has been matched. - if (node == nullptr || node->seen_ == seen || !internal::isTraversable(node) || !all_nodes.contains(node)) { + // Check whether this node has been matched. + if (node == nullptr || node->seen_ == seen || !internal::IsTraversable(node) || !all_nodes.contains(node)) { continue; } node->seen_ = seen; - - // select nodes that this transform can be applied. + // Select nodes that this transform can be applied. AnfNodePtr new_node = Run(func_graph, node); bool change = (new_node != nullptr); if (new_node != nullptr && new_node != node) { @@ -206,17 +185,14 @@ bool PythonPass::Run(const FuncGraphPtr &func_graph) { if (run_only_once_) { return change; } - - // find success, and add them to todo list + // Find success, and add them to todo list if (IsValueNode(node)) { todo.push_back(GetValueNode(node)->output()); } - if (node->isa()) { auto &inputs = node->cast()->inputs(); (void)std::copy(inputs.begin(), inputs.end(), std::back_inserter(todo)); } - auto &node_users = manager->node_users(); if (change && node_users.find(node) != node_users.end()) { for (auto &use : node_users[node]) { diff --git a/mindspore/ccsrc/frontend/optimizer/py_pass.h b/mindspore/ccsrc/frontend/optimizer/py_pass.h index fc7fc0dda2..022c16a686 100644 --- a/mindspore/ccsrc/frontend/optimizer/py_pass.h +++ b/mindspore/ccsrc/frontend/optimizer/py_pass.h @@ -20,6 +20,7 @@ #include #include "ir/anf.h" +#include "frontend/optimizer/pattern.h" #include "pybind_api/api_register.h" #include "pybind_api/export_flags.h" @@ -33,17 +34,17 @@ using NodeEquivPtr = std::shared_ptr; class PythonPass { public: - explicit PythonPass(const std::string &name, const py::function &src, const py::function &dst, - bool run_only_once = false, bool multigraph = true); + explicit PythonPass(const std::string &name, const PatternPtr &src, const PatternPtr &dst, bool run_only_once = false, + bool multigraph = true) + : src_pattern_(src), dst_pattern_(dst), name_(name), run_only_once_(run_only_once), multigraph_(multigraph) {} ~PythonPass() = default; bool Run(const FuncGraphPtr &func_graph); std::string name() const { return name_; } AnfNodePtr Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node); private: - void Build(const py::function &src, const py::function &dst); - AnfNodePtr src_node_ = nullptr; - AnfNodePtr dst_node_ = nullptr; + PatternPtr src_pattern_; + PatternPtr dst_pattern_; const std::string name_; bool run_only_once_; bool multigraph_ = true; diff --git a/mindspore/ccsrc/frontend/optimizer/py_pass_manager.cc b/mindspore/ccsrc/frontend/optimizer/py_pass_manager.cc index 86d7067d1c..a269788dfe 100644 --- a/mindspore/ccsrc/frontend/optimizer/py_pass_manager.cc +++ b/mindspore/ccsrc/frontend/optimizer/py_pass_manager.cc @@ -49,7 +49,7 @@ PyPassManager::PyPassManager() { phase_to_group_[Phase::OPT] = std::make_shared(); } -void PyPassManager::Registe(const std::string &pass_name, const py::function &pattern, const py::function &target, +void PyPassManager::Registe(const std::string &pass_name, const PatternPtr &pattern, const PatternPtr &target, Phase phase, bool run_only_once, bool multigraph) { auto cur_pm = GetPassGroup(phase); MS_EXCEPTION_IF_NULL(cur_pm); diff --git a/mindspore/ccsrc/frontend/optimizer/py_pass_manager.h b/mindspore/ccsrc/frontend/optimizer/py_pass_manager.h index 1016b81020..3168b433d5 100644 --- a/mindspore/ccsrc/frontend/optimizer/py_pass_manager.h +++ b/mindspore/ccsrc/frontend/optimizer/py_pass_manager.h @@ -28,6 +28,7 @@ #include "common/utils.h" #include "pipeline/jit/parse/resolve.h" +#include "frontend/optimizer/pattern.h" #include "frontend/optimizer/py_pass.h" #include "frontend/optimizer/pass_group.h" @@ -51,7 +52,7 @@ class PyPassManager { // Access the only global instance static PyPassManagerPtr GetInstance(); virtual ~PyPassManager() = default; - void Registe(const std::string &pass_name, const py::function &pattern, const py::function &target, + void Registe(const std::string &pass_name, const PatternPtr &pattern, const PatternPtr &target, Phase phase = Phase::RESOLVE, bool run_only_once = false, bool multigraph = true); void Unregiste(const std::string &pass_name, Phase phase); PassGroupPtr GetPassGroup(Phase phase); diff --git a/mindspore/common/graph_pattern.py b/mindspore/common/graph_pattern.py new file mode 100644 index 0000000000..487db572f6 --- /dev/null +++ b/mindspore/common/graph_pattern.py @@ -0,0 +1,154 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Patterns for describing graphs""" +from mindspore.ops import Primitive +from mindspore.common.tensor import Tensor +from mindspore._c_expression import Pattern, IsIn_, IsPrimTypeOf_, CallWith_, IsNot_, AnyPattern, NewTensor_ + +__all__ = [ + "IsIn", + "IsPrimTypeOf", + "CallWith", + "IsNot", + "AnyPattern", + "NewTensor", +] + +class IsIn(IsIn_): + """ + Express a pattern which allows a list of patterns. + """ + def __init__(self, patterns=None, should_replace=True): + r""" + Args: + patterns(list/tuple): list of allowed patterns + should_replace(bool): added this for interface consistency. Should only set this in sub-patterns. + """ + if not should_replace: + raise ValueError("IsIn pattern does not have its own should_replace attribute. Set should_replace in \ + its sub-pattern instead.") + self.patterns = patterns + if patterns is None: + IsIn_.__init__(self, ()) + elif isinstance(patterns, Pattern): + IsIn_.__init__(self, [patterns]) + elif isinstance(patterns, (tuple, list)) and all(isinstance(pattern, Pattern) for pattern in patterns): + IsIn_.__init__(self, patterns) + else: + raise TypeError(f"Expect patterns to be a list of Patterns/Pattern, got : {patterns}") + +class IsPrimTypeOf(IsPrimTypeOf_): + r""" + Express a pattern of certain primitive type(s). + NOTE: This pattern will match and only match the primitive value node. If matching primitive CNode is needed, + please refer to CallWith pattern. + """ + def __init__(self, types, name=None, should_replace=True): + r""" + Args: + types (str/(list/tuple of Primitives)): Specify allowed types. + If it is a string, the form could be + 1) a single primitive type, e.g. 'Conv2D' + 2) a set of primitive types separated by '|', e.g. 'MatMul|Conv2D' + It can also be a list of Primitives, e.g. [ops.Conv2D(1, 6)] + name (str): name of the pattern, optional + should_replace + """ + if name is not None and not isinstance(name, str): + raise TypeError(f"Expect string, got : {name}") + self.name = name + if isinstance(types, str): + if self.name is None: + self.name = types + self.types = types.split('|') + elif isinstance(types, Primitive): + if self.name is None: + self.name = types.name + self.types = [types] + elif isinstance(types, (tuple, list)) and all(isinstance(tp, Primitive) for tp in types): + if self.name is None: + self.name = "" + for prim in types: + self.name += prim.name + self.types = types + else: + raise TypeError(f"Expecting a primitive type string or a list of Primitives, got : {types}") + IsPrimTypeOf_.__init__(self, self.types, self.name, should_replace) + +class CallWith(CallWith_): + r""" + Express a primitive CNode. + """ + def __init__(self, prim_pattern, inputs=None, should_replace=False): + r""" + Args: + prim_pattern (Pattern/Primitive/str): Primitive ValueNode in the Primitive CNode. + inputs (list/tuple): Specify inputs pattern for the primitive(s), optional. If None, accepts any inputs; + if specified, input patterns should be of right order. + """ + if not isinstance(prim_pattern, (Pattern, str, Primitive)): + raise TypeError(f"Expect prim_pattern to be Pattern, Primitive or string, got : {prim_pattern}") + self.prim_pattern = prim_pattern + self.inputs = [] + if inputs is None: + pass + elif isinstance(inputs, (tuple, list)) and all(isinstance(input, Pattern) for input in inputs): + self.inputs = inputs + else: + raise TypeError(f"Expect inputs to be a list of Patterns, got : {inputs}") + CallWith_.__init__(self, self.prim_pattern, self.inputs, should_replace) + + +class IsNot(IsNot_): + r""" + Express a pattern which forbids a list of patterns. + NOTE: IsNot pattern should not be the root pattern. + """ + def __init__(self, patterns=None, should_replace=True): + r""" + Args: + patterns(list/tuple): list of forbiden patterns + should_replace(bool): added this for interface consistency. Should only set this in sub-patterns. + """ + if not should_replace: + raise ValueError("IsNot pattern does not have its own should_replace attribute. Set should_replace in \ + its sub-pattern instead.") + self.patterns = patterns + if patterns is None: + IsNot_.__init__(self, ()) + elif isinstance(patterns, Pattern): + IsNot_.__init__(self, [patterns]) + elif isinstance(patterns, (tuple, list)) and all(isinstance(pattern, Pattern) for pattern in patterns): + IsNot_.__init__(self, patterns) + else: + raise TypeError(f"Expect list of Patterns/Pattern, got : {patterns}") + +class NewTensor(NewTensor_): + r""" + New Tensor to be used in the target. + """ + def __init__(self, input_tensor, should_replace=False): + r""" + Args: + input_tensor(Tensor): new tensor to be used in the target + should_replace(bool): added this for interface consistency. NewTensor should only appear in the target. + """ + if should_replace: + raise ValueError("NewTensor should only appear in the target, thus should_replace can onlyu be False.") + self.input_tensor = input_tensor + if isinstance(input_tensor, Tensor): + NewTensor_.__init__(self, input_tensor) + else: + raise TypeError(f"Expect input_tensor to be a Tensor, got : {input_tensor}") diff --git a/mindspore/common/python_pass_register.py b/mindspore/common/python_pass_register.py index 36eb37adc7..ee4f0f0bc8 100644 --- a/mindspore/common/python_pass_register.py +++ b/mindspore/common/python_pass_register.py @@ -14,6 +14,7 @@ # ============================================================================ """Python pass register""" from inspect import isfunction +from mindspore.common.graph_pattern import Pattern from mindspore._c_expression import PyPassManager_ from mindspore._c_expression import phase @@ -46,10 +47,10 @@ class PyPassManager(PyPassManager_): raise TypeError(f"Expecting function pass, got : ({type(py_pass)}){py_pass}") pattern, target = py_pass() pass_name = py_pass.__name__ - if not isfunction(pattern): - raise TypeError(f"Expecting function pattern, got : ({type(pattern)}){pattern}") - if not isfunction(target): - raise TypeError(f"Expecting function target, got : ({type(target)}){target}") + if not isinstance(pattern, Pattern): + raise TypeError(f"Expecting pattern of Pattern type, got : ({type(pattern)}){pattern}") + if not isinstance(target, Pattern): + raise TypeError(f"Expecting target of Pattern type, got : ({type(target)}){target}") super().registe(pass_name, pattern, target, self.phase_, self.run_only_once_, self.multi_graph_) def unregiste(self, py_pass, pipeline_phase=phase.opt): diff --git a/tests/ut/python/optimizer/test_python_pass.py b/tests/ut/python/optimizer/test_python_pass.py index c3ce3d6c4e..8728120ff1 100644 --- a/tests/ut/python/optimizer/test_python_pass.py +++ b/tests/ut/python/optimizer/test_python_pass.py @@ -22,10 +22,11 @@ from mindspore.ops import operations as P from mindspore.common.python_pass_register import registe_pass, PyPassManager from mindspore.common.api import _generate_pip_args from mindspore._c_expression import generate_key, Executor_ +from mindspore.common.graph_pattern import IsIn, IsPrimTypeOf, CallWith, IsNot, AnyPattern, NewTensor context.set_context(mode=context.GRAPH_MODE) -def get_func_graph(obj, *args, phase="predict"): +def get_func_graph(obj, *args, phase="validate"): args_names, args_list = _generate_pip_args(obj, *args) dic = dict(zip(args_names, args_list)) key = generate_key(phase, dic) @@ -47,14 +48,11 @@ def test_softmax_relu(): @registe_pass(run_only_once=True) def softmax_relu_pass(): - softmax = P.Softmax() - relu = P.ReLU() - def pattern(x): - x = softmax(x) - return x - def target(x): - x = relu(x) - return x + x = AnyPattern() + softmax_pattern = IsPrimTypeOf(P.Softmax()) + pattern = CallWith(softmax_pattern, inputs=[x]) + relu_pattern = IsPrimTypeOf(P.ReLU(), should_replace=False) + target = CallWith(relu_pattern, inputs=[x]) return pattern, target transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(2) @@ -62,3 +60,128 @@ def test_softmax_relu(): ppm.unregiste(softmax_relu_pass) assert "ReLU" in transformed_repr assert "Softmax" not in transformed_repr + +def test_isin_pattern(): + """ + Test IsIn pattern which expresses the IsIn/OneOf semantics. + """ + inputs = Tensor(np.ones([42]), mindspore.float16) + softmax_model = nn.Softmax() + + @registe_pass(run_only_once=True) + def softmax_relu_pass(): + x = AnyPattern() + softmax_pattern = IsPrimTypeOf(P.Softmax()) + call_softmax = CallWith(softmax_pattern, inputs=[x]) + relu_pattern = IsPrimTypeOf(P.ReLU()) + call_relu = CallWith(relu_pattern, inputs=[x]) + + pattern = IsIn([call_softmax, call_relu]) + relu6_pattern = IsPrimTypeOf(P.ReLU6(), should_replace=False) + target = CallWith(relu6_pattern, inputs=[x]) + return pattern, target + transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(2) + ppm = PyPassManager() + ppm.unregiste(softmax_relu_pass) + assert "ReLU6" in transformed_repr + assert "Softmax" not in transformed_repr + +def test_isnot_pattern_0(): + """ + Test IsNot pattern which expresses the IsNot semantics. + Case: IsNot pass failed to match + """ + class ConvBN(nn.Cell): + def __init__(self): + super(ConvBN, self).__init__() + self.conv = P.Conv2D(32, 3) + self.conv_weight = Tensor(np.ones([32, 32, 3, 3]), mindspore.float32) + self.scale = Tensor(np.ones([32]), mindspore.float32) + self.bias = Tensor(np.ones([32]), mindspore.float32) + self.mean = Tensor(np.ones([32]), mindspore.float32) + self.variance = Tensor(np.ones([32]), mindspore.float32) + self.bn = P.BatchNorm() + def construct(self, x): + x = self.conv(x, self.conv_weight) + x = self.bn(x, self.scale, self.bias, self.mean, self.variance) + return x + inputs = Tensor(np.random.normal(0, 1, (10, 32, 32, 32)), mindspore.float32) + conv_bn_model = ConvBN() + + @registe_pass(run_only_once=True) + def single_bn_pass(): + """ + Sub a BN which does NOT take Conv as inputs to ReLU6. + """ + conv2d_prim = IsPrimTypeOf("Conv2D") + conv2d = CallWith(conv2d_prim) + pattern_0 = IsNot(conv2d) + pattern = CallWith(P.BatchNorm(), inputs=[pattern_0]) + target = CallWith(P.ReLU6(), inputs=[pattern_0]) + return pattern, target + + @registe_pass(run_only_once=True) + def bn_pass(): + """ + Sub a BN to Softmax. + """ + bn = P.BatchNorm() + pattern = CallWith(bn) + softmax = P.Softmax() + target = CallWith(softmax, should_replace=False) + return pattern, target + + transformed_repr = get_func_graph(conv_bn_model, inputs).get_return().expanded_str(5) + ppm = PyPassManager() + ppm.unregiste(single_bn_pass) + ppm.unregiste(bn_pass) + assert "ReLU6" not in transformed_repr + assert "Softmax" in transformed_repr + +def test_isnot_pattern_1(): + """ + Test IsNot pattern which expresses the IsNot semantics. + Case: IsNot pattern matches with the graph + """ + inputs = Tensor(np.ones([42]), mindspore.float16) + softmax_model = nn.Softmax() + + @registe_pass(run_only_once=True) + def single_bn_pass(): + """ + Sub a BN which does NOT take MatMul as inputs to ReLU6. + """ + matmul = IsPrimTypeOf("MatMul") + pattern_0 = IsNot(matmul) + softmax = P.Softmax() + pattern = CallWith(softmax, inputs=[pattern_0]) + relu6 = P.ReLU6() + target = CallWith(relu6, inputs=[pattern_0], should_replace=False) + return pattern, target + + transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5) + ppm = PyPassManager() + ppm.unregiste(single_bn_pass) + assert "ReLU6" in transformed_repr + assert "Softmax" not in transformed_repr + +def test_newtensor_pattern(): + inputs = Tensor(np.ones([42]), mindspore.float16) + softmax_model = nn.Softmax() + + @registe_pass(run_only_once=True) + def softmax_addn_pass(): + x = AnyPattern() + softmax = P.Softmax() + pattern = CallWith(softmax, inputs=[x]) + + weight_tensor = Tensor(np.zeros([42]), mindspore.float16) + new_weight = NewTensor(weight_tensor) + addn_ops = P.AddN() + target = CallWith(addn_ops, inputs=[x, new_weight], should_replace=False) + return pattern, target + transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(2) + ppm = PyPassManager() + ppm.unregiste(softmax_addn_pass) + assert "AddN" in transformed_repr + assert "Softmax" not in transformed_repr