parent
c700fc5515
commit
6d4c07c886
@ -0,0 +1,158 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "frontend/optimizer/pattern.h"
|
||||
#include "pybind_api/api_register.h"
|
||||
#include "pybind_api/export_flags.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace python_pass {
|
||||
int Pattern::g_id_ = 0;
|
||||
|
||||
MatchResultPtr IsPrimTypeOf::match(const AnfNodePtr &node) {
|
||||
if (!IsValueNode<Primitive>(node)) {
|
||||
return nullptr;
|
||||
}
|
||||
MatchResultPtr res = std::make_shared<MatchResult>();
|
||||
if (IsValueNode<Primitive>(node)) {
|
||||
// iterate over all primitives
|
||||
for (auto &iter : primitives_) {
|
||||
if (IsPrimitive(node, iter) || iter->name() == "*") {
|
||||
matched_prim_ = iter;
|
||||
res->add_entry(shared_from_base<IsPrimTypeOf>(), node);
|
||||
return res;
|
||||
}
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
MatchResultPtr CallWith::match(const AnfNodePtr &node) {
|
||||
if (!IsPrimitiveCNode(node)) {
|
||||
return nullptr;
|
||||
}
|
||||
MatchResultPtr res = std::make_shared<MatchResult>();
|
||||
// IsPrimitiveCNode
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
// Check Primitive ValueNode
|
||||
if (prim_pattern_ != nullptr) {
|
||||
// Passed in prim_pattern
|
||||
auto prim_value_res = prim_pattern_->match(cnode->input(0));
|
||||
if (prim_value_res == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
res->merge(prim_value_res);
|
||||
} else if (prim_ != nullptr) {
|
||||
// Passed in primitive/primitive str
|
||||
if (!IsPrimitive(cnode->input(0), prim_)) {
|
||||
return nullptr;
|
||||
}
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Uninitialized CallWith pattern.";
|
||||
}
|
||||
// Check inputs
|
||||
auto p_inputs_size = inputs_.size();
|
||||
auto node_inputs_size = cnode->size() - 1;
|
||||
if (p_inputs_size != 0 && p_inputs_size != node_inputs_size) {
|
||||
return nullptr;
|
||||
}
|
||||
// If inputs is not specified, add node without looking into its inputs
|
||||
if (p_inputs_size == 0) {
|
||||
res->add_entry(shared_from_base<CallWith>(), cnode->input(0));
|
||||
return res;
|
||||
}
|
||||
bool failed = false;
|
||||
for (std::size_t i = 0; i < node_inputs_size; i++) {
|
||||
auto pattern = inputs_[i];
|
||||
auto input = cnode->input(i + 1);
|
||||
auto input_match_result = pattern->match(input);
|
||||
if (input_match_result == nullptr) {
|
||||
failed = true;
|
||||
break;
|
||||
}
|
||||
res->merge(input_match_result);
|
||||
}
|
||||
if (!failed) {
|
||||
res->add_entry(shared_from_base<CallWith>(), cnode->input(0));
|
||||
return res;
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
MatchResultPtr IsIn::match(const AnfNodePtr &node) {
|
||||
for (auto &iter : patterns_) {
|
||||
auto res = iter->match(node);
|
||||
if (res != nullptr) {
|
||||
return res;
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
MatchResultPtr IsNot::match(const AnfNodePtr &node) {
|
||||
for (auto &iter : patterns_) {
|
||||
auto res = iter->match(node);
|
||||
if (res != nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
auto res = std::make_shared<MatchResult>();
|
||||
res->add_entry(shared_from_base<IsNot>(), node);
|
||||
return res;
|
||||
}
|
||||
|
||||
MatchResultPtr AnyPattern::match(const AnfNodePtr &node) {
|
||||
MatchResultPtr res = std::make_shared<MatchResult>();
|
||||
res->add_entry(shared_from_base<AnyPattern>(), node);
|
||||
return res;
|
||||
}
|
||||
|
||||
AnfNodePtr MatchResult::get_node(const PatternPtr &pattern) {
|
||||
auto entry = match_result_.find(pattern);
|
||||
if (entry == match_result_.end()) {
|
||||
return nullptr;
|
||||
}
|
||||
return entry->second;
|
||||
}
|
||||
|
||||
void MatchResult::merge(const MatchResultPtr &other_result) {
|
||||
auto other_result_map = other_result->_result();
|
||||
// add/update entries in other_result
|
||||
for (auto &iter : other_result_map) {
|
||||
match_result_[iter.first] = iter.second;
|
||||
}
|
||||
}
|
||||
|
||||
REGISTER_PYBIND_DEFINE(
|
||||
Pattern, ([](const py::module *m) {
|
||||
(void)py::class_<Pattern, std::shared_ptr<Pattern>>(*m, "Pattern").def(py::init<>());
|
||||
(void)py::class_<IsIn, std::shared_ptr<IsIn>, Pattern>(*m, "IsIn_").def(py::init<vector<PatternPtr>>());
|
||||
(void)py::class_<IsPrimTypeOf, std::shared_ptr<IsPrimTypeOf>, Pattern>(*m, "IsPrimTypeOf_", py::dynamic_attr())
|
||||
.def(py::init<vector<PrimitivePyPtr>, string, bool>())
|
||||
.def(py::init<vector<string>, string, bool>());
|
||||
(void)py::class_<CallWith, std::shared_ptr<CallWith>, Pattern>(*m, "CallWith_")
|
||||
.def(py::init<PatternPtr, vector<PatternPtr>, bool>())
|
||||
.def(py::init<PrimitivePyPtr, vector<PatternPtr>, bool>())
|
||||
.def(py::init<string, vector<PatternPtr>, bool>());
|
||||
(void)py::class_<IsNot, std::shared_ptr<IsNot>, Pattern>(*m, "IsNot_").def(py::init<vector<PatternPtr>>());
|
||||
(void)py::class_<AnyPattern, std::shared_ptr<AnyPattern>, Pattern>(*m, "AnyPattern").def(py::init<>());
|
||||
(void)py::class_<NewTensor, std::shared_ptr<NewTensor>, Pattern>(*m, "NewTensor_")
|
||||
.def(py::init<tensor::TensorPtr>());
|
||||
}));
|
||||
} // namespace python_pass
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
@ -0,0 +1,228 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_PATTERN_H_
|
||||
#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_PATTERN_H_
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include <unordered_map>
|
||||
|
||||
#include "base/base.h"
|
||||
#include "ir/anf.h"
|
||||
#include "ir/tensor.h"
|
||||
#include "utils/primitive_py.h"
|
||||
#include "utils/tensor_py.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace python_pass {
|
||||
using std::string;
|
||||
using std::vector;
|
||||
|
||||
class MatchResult;
|
||||
using MatchResultPtr = std::shared_ptr<MatchResult>;
|
||||
class Pattern;
|
||||
using PatternPtr = std::shared_ptr<Pattern>;
|
||||
class IsPrimTypeOf;
|
||||
using IsPrimTypeOfPtr = std::shared_ptr<IsPrimTypeOf>;
|
||||
class CallWith;
|
||||
using CallWithPtr = std::shared_ptr<CallWith>;
|
||||
class NewTensor;
|
||||
using NewTensorPtr = std::shared_ptr<NewTensor>;
|
||||
struct PatternHasher;
|
||||
struct PatternEqual;
|
||||
using PatternNodeMap = std::unordered_map<PatternPtr, AnfNodePtr, PatternHasher, PatternEqual>;
|
||||
|
||||
class Pattern : public Base {
|
||||
public:
|
||||
Pattern() : unique_name_(std::to_string(g_id_++)) {}
|
||||
virtual MatchResultPtr match(const AnfNodePtr &node) { return nullptr; }
|
||||
virtual bool operator==(const Pattern &other) const { return unique_name_ == other.unique_name_; }
|
||||
string unique_name() const { return unique_name_; }
|
||||
vector<PatternPtr> inputs() { return inputs_; }
|
||||
bool should_replace() { return should_replace_; }
|
||||
virtual void reset() {}
|
||||
|
||||
protected:
|
||||
static int g_id_;
|
||||
// NOTE: To ensure uniqueness of the name, raise g_id_ by 1 every time a pattern got constructed
|
||||
string unique_name_;
|
||||
vector<PatternPtr> inputs_;
|
||||
bool should_replace_ = true;
|
||||
};
|
||||
|
||||
struct PatternEqual {
|
||||
bool operator()(PatternPtr const &p1, PatternPtr const &p2) const {
|
||||
MS_EXCEPTION_IF_NULL(p1);
|
||||
MS_EXCEPTION_IF_NULL(p2);
|
||||
return p1->unique_name() == p2->unique_name();
|
||||
}
|
||||
};
|
||||
|
||||
struct PatternHasher {
|
||||
std::size_t operator()(PatternPtr const &p) const {
|
||||
MS_EXCEPTION_IF_NULL(p);
|
||||
return std::hash<string>()(p->unique_name());
|
||||
}
|
||||
};
|
||||
|
||||
class IsPrimTypeOf : public Pattern {
|
||||
public:
|
||||
IsPrimTypeOf() { unique_name_ = std::to_string(g_id_++); }
|
||||
IsPrimTypeOf(vector<PrimitivePyPtr> prims, string name, bool should_replace)
|
||||
: primitives_(prims), name_(name), matched_prim_(nullptr) {
|
||||
unique_name_ = std::to_string(g_id_++) + "_" + name;
|
||||
should_replace_ = should_replace;
|
||||
if (!should_replace) {
|
||||
matched_prim_ = prims[0];
|
||||
}
|
||||
}
|
||||
IsPrimTypeOf(vector<string> types, string name, bool should_replace) : types_(types), name_(name) {
|
||||
unique_name_ = std::to_string(g_id_++) + "_" + name;
|
||||
// Make primitives_
|
||||
for (auto &iter : types) {
|
||||
primitives_.push_back(std::make_shared<PrimitivePy>(iter, py::cast(nullptr)));
|
||||
}
|
||||
should_replace_ = should_replace;
|
||||
if (!should_replace) {
|
||||
matched_prim_ = primitives_[0];
|
||||
}
|
||||
}
|
||||
MS_DECLARE_PARENT(IsPrimTypeOf, Pattern);
|
||||
MatchResultPtr match(const AnfNodePtr &node) override;
|
||||
PrimitivePyPtr matched_primitive() { return matched_prim_; }
|
||||
void reset() override {
|
||||
if (should_replace_) {
|
||||
matched_prim_ = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
vector<string> types_;
|
||||
vector<PrimitivePyPtr> primitives_;
|
||||
string name_;
|
||||
PrimitivePyPtr matched_prim_;
|
||||
};
|
||||
|
||||
class CallWith : public Pattern {
|
||||
public:
|
||||
CallWith() { unique_name_ = std::to_string(g_id_++); }
|
||||
CallWith(PatternPtr prim_pattern, vector<PatternPtr> inputs, bool should_replace) {
|
||||
// NOTE: should_replace is ignored in this case, since each sub-pattern has its own setting
|
||||
prim_pattern_ = prim_pattern;
|
||||
unique_name_ = std::to_string(g_id_++) + prim_pattern->unique_name();
|
||||
inputs_ = inputs;
|
||||
should_replace_ = should_replace;
|
||||
}
|
||||
CallWith(PrimitivePyPtr prim, vector<PatternPtr> inputs, bool should_replace) {
|
||||
prim_ = prim;
|
||||
unique_name_ = std::to_string(g_id_++) + prim_->ToString();
|
||||
inputs_ = inputs;
|
||||
should_replace_ = should_replace;
|
||||
}
|
||||
CallWith(string prim_str, vector<PatternPtr> inputs, bool should_replace) {
|
||||
prim_ = std::make_shared<PrimitivePy>(prim_str, py::cast(nullptr));
|
||||
unique_name_ = std::to_string(g_id_++) + prim_->ToString();
|
||||
inputs_ = inputs;
|
||||
should_replace_ = should_replace;
|
||||
}
|
||||
MS_DECLARE_PARENT(CallWith, Pattern);
|
||||
MatchResultPtr match(const AnfNodePtr &node) override;
|
||||
PrimitivePtr prim_value() { return prim_; }
|
||||
PatternPtr prim_pattern() { return prim_pattern_; }
|
||||
|
||||
private:
|
||||
PatternPtr prim_pattern_ = nullptr;
|
||||
PrimitivePtr prim_ = nullptr;
|
||||
vector<string> types_;
|
||||
string name_;
|
||||
};
|
||||
|
||||
class IsIn : public Pattern {
|
||||
public:
|
||||
IsIn() { unique_name_ = std::to_string(g_id_++); }
|
||||
explicit IsIn(vector<PatternPtr> patterns) : patterns_(patterns) {
|
||||
unique_name_ = std::to_string(g_id_++);
|
||||
for (auto &iter : patterns) {
|
||||
unique_name_ = unique_name_ + "_" + iter->unique_name();
|
||||
}
|
||||
}
|
||||
MS_DECLARE_PARENT(IsIn, Pattern);
|
||||
MatchResultPtr match(const AnfNodePtr &node) override;
|
||||
|
||||
private:
|
||||
vector<PatternPtr> patterns_;
|
||||
};
|
||||
|
||||
class IsNot : public Pattern {
|
||||
public:
|
||||
IsNot() { unique_name_ = std::to_string(g_id_++); }
|
||||
explicit IsNot(vector<PatternPtr> patterns) : patterns_(patterns) {
|
||||
unique_name_ = std::to_string(g_id_++);
|
||||
for (auto &iter : patterns) {
|
||||
unique_name_ = "IsNot_" + unique_name_ + "_" + iter->unique_name();
|
||||
}
|
||||
}
|
||||
MS_DECLARE_PARENT(IsNot, Pattern);
|
||||
MatchResultPtr match(const AnfNodePtr &node) override;
|
||||
|
||||
private:
|
||||
vector<PatternPtr> patterns_;
|
||||
};
|
||||
|
||||
class AnyPattern : public Pattern {
|
||||
public:
|
||||
AnyPattern() { unique_name_ = std::to_string(g_id_++) + "_AnyPattern"; }
|
||||
MS_DECLARE_PARENT(AnyPattern, Pattern);
|
||||
MatchResultPtr match(const AnfNodePtr &node) override;
|
||||
};
|
||||
|
||||
class NewTensor : public Pattern {
|
||||
public:
|
||||
NewTensor() { unique_name_ = std::to_string(g_id_++); }
|
||||
explicit NewTensor(tensor::TensorPtr input_tensor) : input_tensor_(input_tensor) { should_replace_ = false; }
|
||||
MS_DECLARE_PARENT(NewTensor, Pattern);
|
||||
MatchResultPtr match(const AnfNodePtr &node) override {
|
||||
MS_LOG(EXCEPTION) << "Find NewTensor in pattern, NewTensor should only appear in the target.\n";
|
||||
}
|
||||
tensor::TensorPtr input_tensor() { return input_tensor_; }
|
||||
|
||||
private:
|
||||
tensor::TensorPtr input_tensor_;
|
||||
};
|
||||
|
||||
class MatchResult {
|
||||
public:
|
||||
MatchResult() {}
|
||||
void add_entry(PatternPtr pattern, AnfNodePtr node) { match_result_[pattern] = node; }
|
||||
PatternNodeMap _result() { return match_result_; }
|
||||
AnfNodePtr get_node(const PatternPtr &pattern);
|
||||
void merge(const MatchResultPtr &other_result);
|
||||
void clear() { match_result_.clear(); }
|
||||
void dump() {
|
||||
MS_LOG(DEBUG) << "match_result_.size: " + std::to_string(match_result_.size()) + "\n";
|
||||
for (auto &iter : match_result_) {
|
||||
MS_LOG(DEBUG) << "Pattern : " + iter.first->unique_name() + " , node : " + iter.second->ToString() + "\n";
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
PatternNodeMap match_result_;
|
||||
};
|
||||
} // namespace python_pass
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_PATTERN_H_
|
File diff suppressed because it is too large
Load Diff
Loading…
Reference in new issue