|
|
|
@ -49,6 +49,7 @@ using PatternNodeMap = std::unordered_map<PatternPtr, AnfNodePtr, PatternHasher,
|
|
|
|
|
class Pattern : public Base {
|
|
|
|
|
public:
|
|
|
|
|
Pattern() : unique_name_(std::to_string(g_id_++)) {}
|
|
|
|
|
~Pattern() = default;
|
|
|
|
|
virtual MatchResultPtr match(const AnfNodePtr &node) { return nullptr; }
|
|
|
|
|
virtual bool operator==(const Pattern &other) const { return unique_name_ == other.unique_name_; }
|
|
|
|
|
string unique_name() const { return unique_name_; }
|
|
|
|
@ -82,6 +83,7 @@ struct PatternHasher {
|
|
|
|
|
class IsPrimTypeOf : public Pattern {
|
|
|
|
|
public:
|
|
|
|
|
IsPrimTypeOf() { unique_name_ = std::to_string(g_id_++); }
|
|
|
|
|
~IsPrimTypeOf() = default;
|
|
|
|
|
IsPrimTypeOf(vector<PrimitivePyPtr> prims, string name, bool should_replace)
|
|
|
|
|
: primitives_(prims), name_(name), matched_prim_(nullptr) {
|
|
|
|
|
unique_name_ = std::to_string(g_id_++) + "_" + name;
|
|
|
|
@ -120,6 +122,7 @@ class IsPrimTypeOf : public Pattern {
|
|
|
|
|
class CallWith : public Pattern {
|
|
|
|
|
public:
|
|
|
|
|
CallWith() { unique_name_ = std::to_string(g_id_++); }
|
|
|
|
|
~CallWith() = default;
|
|
|
|
|
CallWith(PatternPtr prim_pattern, vector<PatternPtr> inputs, bool should_replace) {
|
|
|
|
|
// NOTE: should_replace is ignored in this case, since each sub-pattern has its own setting
|
|
|
|
|
prim_pattern_ = prim_pattern;
|
|
|
|
@ -154,6 +157,7 @@ class CallWith : public Pattern {
|
|
|
|
|
class IsIn : public Pattern {
|
|
|
|
|
public:
|
|
|
|
|
IsIn() { unique_name_ = std::to_string(g_id_++); }
|
|
|
|
|
~IsIn() = default;
|
|
|
|
|
explicit IsIn(vector<PatternPtr> patterns) : patterns_(patterns) {
|
|
|
|
|
unique_name_ = std::to_string(g_id_++);
|
|
|
|
|
for (auto &iter : patterns) {
|
|
|
|
@ -170,6 +174,7 @@ class IsIn : public Pattern {
|
|
|
|
|
class IsNot : public Pattern {
|
|
|
|
|
public:
|
|
|
|
|
IsNot() { unique_name_ = std::to_string(g_id_++); }
|
|
|
|
|
~IsNot() = default;
|
|
|
|
|
explicit IsNot(vector<PatternPtr> patterns) : patterns_(patterns) {
|
|
|
|
|
unique_name_ = std::to_string(g_id_++);
|
|
|
|
|
for (auto &iter : patterns) {
|
|
|
|
@ -186,6 +191,7 @@ class IsNot : public Pattern {
|
|
|
|
|
class AnyPattern : public Pattern {
|
|
|
|
|
public:
|
|
|
|
|
AnyPattern() { unique_name_ = std::to_string(g_id_++) + "_AnyPattern"; }
|
|
|
|
|
~AnyPattern() = default;
|
|
|
|
|
MS_DECLARE_PARENT(AnyPattern, Pattern);
|
|
|
|
|
MatchResultPtr match(const AnfNodePtr &node) override;
|
|
|
|
|
};
|
|
|
|
@ -193,6 +199,7 @@ class AnyPattern : public Pattern {
|
|
|
|
|
class NewTensor : public Pattern {
|
|
|
|
|
public:
|
|
|
|
|
NewTensor() { unique_name_ = std::to_string(g_id_++); }
|
|
|
|
|
~NewTensor() = default;
|
|
|
|
|
explicit NewTensor(tensor::TensorPtr input_tensor) : input_tensor_(input_tensor) { should_replace_ = false; }
|
|
|
|
|
MS_DECLARE_PARENT(NewTensor, Pattern);
|
|
|
|
|
MatchResultPtr match(const AnfNodePtr &node) override {
|
|
|
|
@ -207,6 +214,7 @@ class NewTensor : public Pattern {
|
|
|
|
|
class MatchResult {
|
|
|
|
|
public:
|
|
|
|
|
MatchResult() {}
|
|
|
|
|
~MatchResult() = default;
|
|
|
|
|
void add_entry(PatternPtr pattern, AnfNodePtr node) { match_result_[pattern] = node; }
|
|
|
|
|
PatternNodeMap _result() { return match_result_; }
|
|
|
|
|
AnfNodePtr get_node(const PatternPtr &pattern);
|
|
|
|
|