You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
mindspore/mindspore/ccsrc/frontend/optimizer/pattern.cc

177 lines
5.7 KiB

/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "frontend/optimizer/pattern.h"
#include "pybind_api/api_register.h"
namespace mindspore {
namespace opt {
namespace python_pass {
// Out-of-line definition of the static member Pattern::g_id_, starting at 0.
// NOTE(review): presumably a counter used to hand out unique ids to Pattern instances;
// its uses are not visible in this file — confirm against pattern.h.
int Pattern::g_id_ = 0;
// Matches `node` against this Prim pattern.
//
// `node` must be a Primitive value node. The candidates in primitives_ are scanned in
// order; a candidate fits when IsPrimitive(node, candidate) holds or when the
// candidate's name is the wildcard "*". The first fitting candidate is stored in
// matched_prim_, and a result mapping this pattern to `node` is returned.
// Returns nullptr when `node` is not a Primitive value node or no candidate fits.
MatchResultPtr Prim::match(const AnfNodePtr &node) {
  if (IsValueNode<Primitive>(node)) {
    for (auto &candidate : primitives_) {
      // Skip candidates that are neither an exact primitive match nor the wildcard.
      if (!IsPrimitive(node, candidate) && candidate->name() != "*") {
        continue;
      }
      matched_prim_ = candidate;
      MatchResultPtr result = std::make_shared<MatchResult>();
      result->add_entry(shared_from_base<Prim>(), node);
      return result;
    }
  }
  return nullptr;
}
// Matches `node` against this Call pattern.
//
// `node` must satisfy IsPrimitiveCNode. Its callee, cnode->input(0), is checked in one
// of two ways:
//   * If prim_pattern_ is set, that sub-pattern must match the callee, and its entries
//     are merged into the result.
//   * Otherwise the callee must satisfy IsPrimitive(callee, prim_).
// If neither prim_pattern_ nor prim_ is set, an error is reported via MS_LOG(EXCEPTION).
//
// Arguments (inputs 1..n) are checked as follows:
//   * An empty inputs_ accepts any arguments without inspecting them.
//   * Otherwise inputs_ must have exactly as many entries as the node has arguments.
//     Each sub-pattern must match the argument at the same position, and all of their
//     entries are merged into the result.
// A null sub-pattern in inputs_ is rejected via MS_EXCEPTION_IF_NULL instead of being
// dereferenced.
//
// On success the result maps this Call pattern to cnode->input(0), the callee node, not
// the CNode itself. Returns nullptr on any mismatch.
MatchResultPtr Call::match(const AnfNodePtr &node) {
  if (!IsPrimitiveCNode(node)) {
    return nullptr;
  }
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  MatchResultPtr res = std::make_shared<MatchResult>();

  // Check the callee (input 0).
  if (prim_pattern_ != nullptr) {
    // Callee is described by a sub-pattern.
    auto prim_value_res = prim_pattern_->match(cnode->input(0));
    if (prim_value_res == nullptr) {
      return nullptr;
    }
    res->merge(prim_value_res);
  } else if (prim_ != nullptr) {
    // Callee is described by a concrete primitive (or primitive name).
    if (!IsPrimitive(cnode->input(0), prim_)) {
      return nullptr;
    }
  } else {
    MS_LOG(EXCEPTION) << "Uninitialized CallWith pattern.";
  }

  // Check the arguments (inputs 1..n).
  const auto p_inputs_size = inputs_.size();
  // NOTE(review): assumes IsPrimitiveCNode implies input(0) exists, so size() >= 1 and
  // this subtraction cannot wrap.
  const auto node_inputs_size = cnode->size() - 1;
  if (p_inputs_size != 0 && p_inputs_size != node_inputs_size) {
    return nullptr;
  }
  // When inputs_ is empty this loop does nothing, so the node's arguments are not examined.
  for (std::size_t i = 0; i < p_inputs_size; ++i) {
    const auto &pattern = inputs_[i];
    MS_EXCEPTION_IF_NULL(pattern);
    auto input_match_result = pattern->match(cnode->input(i + 1));
    if (input_match_result == nullptr) {
      return nullptr;
    }
    res->merge(input_match_result);
  }
  res->add_entry(shared_from_base<Call>(), cnode->input(0));
  return res;
}
// Matches `node` when at least one alternative in patterns_ matches it.
//
// Alternatives are tried in order, and probing stops at the first success. That
// alternative's result is extended with an entry mapping this OneOf pattern to `node`
// and then returned. Returns nullptr when every alternative fails.
MatchResultPtr OneOf::match(const AnfNodePtr &node) {
  MatchResultPtr found = nullptr;
  for (auto it = patterns_.begin(); found == nullptr && it != patterns_.end(); ++it) {
    found = (*it)->match(node);
  }
  if (found == nullptr) {
    return nullptr;
  }
  found->add_entry(shared_from_base<OneOf>(), node);
  return found;
}
// Matches `node` only when every pattern in patterns_ rejects it.
//
// Patterns are tried in order, and probing stops at the first one that matches; in that
// case nullptr is returned. The sub-patterns' own match results are discarded either way.
// On success the result holds a single entry mapping this NoneOf pattern to `node`.
MatchResultPtr NoneOf::match(const AnfNodePtr &node) {
  bool excluded = false;
  for (auto it = patterns_.begin(); !excluded && it != patterns_.end(); ++it) {
    excluded = ((*it)->match(node) != nullptr);
  }
  if (excluded) {
    return nullptr;
  }
  MatchResultPtr accepted = std::make_shared<MatchResult>();
  accepted->add_entry(shared_from_base<NoneOf>(), node);
  return accepted;
}
// Wildcard pattern: always succeeds. Returns a fresh result whose only entry maps this
// Any pattern to `node`.
MatchResultPtr Any::match(const AnfNodePtr &node) {
  auto wildcard_result = std::make_shared<MatchResult>();
  wildcard_result->add_entry(shared_from_base<Any>(), node);
  return wildcard_result;
}
// Matches `node` when it is an Int32Imm value node whose value equals value_.
//
// On success the result holds a single entry mapping this Imm pattern to `node`.
// Returns nullptr for any other node or for a different value.
// NOTE(review): only Int32Imm is recognized here; immediates of other scalar types
// (e.g. 64-bit integers) never match, even when numerically equal.
MatchResultPtr Imm::match(const AnfNodePtr &node) {
  if (!IsValueNode<Int32Imm>(node)) {
    return nullptr;
  }
  auto value_node = node->cast<ValueNodePtr>();
  MS_EXCEPTION_IF_NULL(value_node);
  auto value_ptr = value_node->value()->cast<Int32ImmPtr>();
  MS_EXCEPTION_IF_NULL(value_ptr);
  if (static_cast<int32_t>(value_ptr->value()) != value_) {
    return nullptr;
  }
  MatchResultPtr res = std::make_shared<MatchResult>();
  res->add_entry(shared_from_base<Imm>(), node);
  return res;
}
// Returns the node recorded for `pattern` in this match result, or nullptr if
// `pattern` has no entry.
AnfNodePtr MatchResult::get_node(const PatternPtr &pattern) {
  const auto found = match_result_.find(pattern);
  return found != match_result_.end() ? found->second : nullptr;
}
// Copies every (pattern, node) entry of `other_result` into this result.
// When a pattern is already present here, its node is overwritten by the one from
// `other_result`. `other_result` must be non-null.
void MatchResult::merge(const MatchResultPtr &other_result) {
  // Bind by const reference so the entry map is not copied on every merge when result()
  // returns a reference. If it returns by value, the temporary's lifetime is extended.
  const auto &other_result_map = other_result->result();
  for (const auto &entry : other_result_map) {
    match_result_[entry.first] = entry.second;
  }
}
// Python bindings for the pattern classes defined in this file.
// Each pattern type is exposed as a pybind11 class; all except Pattern itself list
// Pattern as their base. Most are exported under a trailing-underscore name (e.g.
// "OneOf_", "Call_"), while "Pattern", "Any" and "Imm" are not.
// NOTE(review): presumably the underscored names are wrapped by Python-side classes —
// confirm in the Python package.
// Order matters: pybind11 requires a base class to be registered before any class that
// names it as a base, so Pattern must stay first.
REGISTER_PYBIND_DEFINE(
Pattern, ([](const py::module *m) {
// Base class, default-constructible from Python.
(void)py::class_<Pattern, std::shared_ptr<Pattern>>(*m, "Pattern").def(py::init<>());
// OneOf_: built from a list of alternative sub-patterns.
(void)py::class_<OneOf, std::shared_ptr<OneOf>, Pattern>(*m, "OneOf_").def(py::init<vector<PatternPtr>>());
// Prim_: built from either a list of PrimitivePy objects or a list of primitive name
// strings, plus a string argument whose role is not visible here. py::dynamic_attr()
// lets Python code attach arbitrary attributes to instances.
(void)py::class_<Prim, std::shared_ptr<Prim>, Pattern>(*m, "Prim_", py::dynamic_attr())
.def(py::init<vector<PrimitivePyPtr>, string>())
.def(py::init<vector<string>, string>());
// Call_: the first argument describes the callee (a sub-pattern, a PrimitivePy, or a
// primitive name string); the second is the list of argument sub-patterns.
// NOTE(review): presumably these populate prim_pattern_ / prim_ and inputs_ as used by
// Call::match — confirm in pattern.h.
(void)py::class_<Call, std::shared_ptr<Call>, Pattern>(*m, "Call_")
.def(py::init<PatternPtr, vector<PatternPtr>>())
.def(py::init<PrimitivePyPtr, vector<PatternPtr>>())
.def(py::init<string, vector<PatternPtr>>());
// NoneOf_: built from a list of sub-patterns that must all fail to match.
(void)py::class_<NoneOf, std::shared_ptr<NoneOf>, Pattern>(*m, "NoneOf_").def(py::init<vector<PatternPtr>>());
// Any: wildcard pattern with no constructor arguments.
(void)py::class_<Any, std::shared_ptr<Any>, Pattern>(*m, "Any").def(py::init<>());
// NewTensor_: built from a tensor.
(void)py::class_<NewTensor, std::shared_ptr<NewTensor>, Pattern>(*m, "NewTensor_")
.def(py::init<tensor::TensorPtr>());
// NewParameter_: built from (string, tensor, bool, bool); the meaning of each argument
// is not visible in this file.
(void)py::class_<NewParameter, std::shared_ptr<NewParameter>, Pattern>(*m, "NewParameter_")
.def(py::init<string, tensor::TensorPtr, bool, bool>());
// Imm: built from a Python int. Imm::match only compares it against Int32Imm nodes.
(void)py::class_<Imm, std::shared_ptr<Imm>, Pattern>(*m, "Imm").def(py::init<int>());
}));
} // namespace python_pass
} // namespace opt
} // namespace mindspore