You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
mindspore/mindspore/ccsrc/frontend/optimizer/pattern.h

229 lines
7.3 KiB

/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_PATTERN_H_
#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_PATTERN_H_
#include <string>
#include <memory>
#include <vector>
#include <unordered_map>
#include "base/base.h"
#include "ir/anf.h"
#include "ir/tensor.h"
#include "utils/primitive_py.h"
#include "utils/tensor_py.h"
namespace mindspore {
namespace opt {
namespace python_pass {
using std::string;
using std::vector;
class MatchResult;
using MatchResultPtr = std::shared_ptr<MatchResult>;
class Pattern;
using PatternPtr = std::shared_ptr<Pattern>;
class IsPrimTypeOf;
using IsPrimTypeOfPtr = std::shared_ptr<IsPrimTypeOf>;
class CallWith;
using CallWithPtr = std::shared_ptr<CallWith>;
class NewTensor;
using NewTensorPtr = std::shared_ptr<NewTensor>;
struct PatternHasher;
struct PatternEqual;
using PatternNodeMap = std::unordered_map<PatternPtr, AnfNodePtr, PatternHasher, PatternEqual>;
class Pattern : public Base {
public:
Pattern() : unique_name_(std::to_string(g_id_++)) {}
virtual MatchResultPtr match(const AnfNodePtr &node) { return nullptr; }
virtual bool operator==(const Pattern &other) const { return unique_name_ == other.unique_name_; }
string unique_name() const { return unique_name_; }
vector<PatternPtr> inputs() { return inputs_; }
bool should_replace() { return should_replace_; }
virtual void reset() {}
protected:
static int g_id_;
// NOTE: To ensure uniqueness of the name, raise g_id_ by 1 every time a pattern got constructed
string unique_name_;
vector<PatternPtr> inputs_;
bool should_replace_ = true;
};
struct PatternEqual {
bool operator()(PatternPtr const &p1, PatternPtr const &p2) const {
MS_EXCEPTION_IF_NULL(p1);
MS_EXCEPTION_IF_NULL(p2);
return p1->unique_name() == p2->unique_name();
}
};
struct PatternHasher {
std::size_t operator()(PatternPtr const &p) const {
MS_EXCEPTION_IF_NULL(p);
return std::hash<string>()(p->unique_name());
}
};
class IsPrimTypeOf : public Pattern {
public:
IsPrimTypeOf() { unique_name_ = std::to_string(g_id_++); }
IsPrimTypeOf(vector<PrimitivePyPtr> prims, string name, bool should_replace)
: primitives_(prims), name_(name), matched_prim_(nullptr) {
unique_name_ = std::to_string(g_id_++) + "_" + name;
should_replace_ = should_replace;
if (!should_replace) {
matched_prim_ = prims[0];
}
}
IsPrimTypeOf(vector<string> types, string name, bool should_replace) : types_(types), name_(name) {
unique_name_ = std::to_string(g_id_++) + "_" + name;
// Make primitives_
for (auto &iter : types) {
primitives_.push_back(std::make_shared<PrimitivePy>(iter, py::cast(nullptr)));
}
should_replace_ = should_replace;
if (!should_replace) {
matched_prim_ = primitives_[0];
}
}
MS_DECLARE_PARENT(IsPrimTypeOf, Pattern);
MatchResultPtr match(const AnfNodePtr &node) override;
PrimitivePyPtr matched_primitive() { return matched_prim_; }
void reset() override {
if (should_replace_) {
matched_prim_ = nullptr;
}
}
private:
vector<string> types_;
vector<PrimitivePyPtr> primitives_;
string name_;
PrimitivePyPtr matched_prim_;
};
class CallWith : public Pattern {
public:
CallWith() { unique_name_ = std::to_string(g_id_++); }
CallWith(PatternPtr prim_pattern, vector<PatternPtr> inputs, bool should_replace) {
// NOTE: should_replace is ignored in this case, since each sub-pattern has its own setting
prim_pattern_ = prim_pattern;
unique_name_ = std::to_string(g_id_++) + prim_pattern->unique_name();
inputs_ = inputs;
should_replace_ = should_replace;
}
CallWith(PrimitivePyPtr prim, vector<PatternPtr> inputs, bool should_replace) {
prim_ = prim;
unique_name_ = std::to_string(g_id_++) + prim_->ToString();
inputs_ = inputs;
should_replace_ = should_replace;
}
CallWith(string prim_str, vector<PatternPtr> inputs, bool should_replace) {
prim_ = std::make_shared<PrimitivePy>(prim_str, py::cast(nullptr));
unique_name_ = std::to_string(g_id_++) + prim_->ToString();
inputs_ = inputs;
should_replace_ = should_replace;
}
MS_DECLARE_PARENT(CallWith, Pattern);
MatchResultPtr match(const AnfNodePtr &node) override;
PrimitivePtr prim_value() { return prim_; }
PatternPtr prim_pattern() { return prim_pattern_; }
private:
PatternPtr prim_pattern_ = nullptr;
PrimitivePtr prim_ = nullptr;
vector<string> types_;
string name_;
};
class IsIn : public Pattern {
public:
IsIn() { unique_name_ = std::to_string(g_id_++); }
explicit IsIn(vector<PatternPtr> patterns) : patterns_(patterns) {
unique_name_ = std::to_string(g_id_++);
for (auto &iter : patterns) {
unique_name_ = unique_name_ + "_" + iter->unique_name();
}
}
MS_DECLARE_PARENT(IsIn, Pattern);
MatchResultPtr match(const AnfNodePtr &node) override;
private:
vector<PatternPtr> patterns_;
};
class IsNot : public Pattern {
public:
IsNot() { unique_name_ = std::to_string(g_id_++); }
explicit IsNot(vector<PatternPtr> patterns) : patterns_(patterns) {
unique_name_ = std::to_string(g_id_++);
for (auto &iter : patterns) {
unique_name_ = "IsNot_" + unique_name_ + "_" + iter->unique_name();
}
}
MS_DECLARE_PARENT(IsNot, Pattern);
MatchResultPtr match(const AnfNodePtr &node) override;
private:
vector<PatternPtr> patterns_;
};
class AnyPattern : public Pattern {
public:
AnyPattern() { unique_name_ = std::to_string(g_id_++) + "_AnyPattern"; }
MS_DECLARE_PARENT(AnyPattern, Pattern);
MatchResultPtr match(const AnfNodePtr &node) override;
};
class NewTensor : public Pattern {
public:
NewTensor() { unique_name_ = std::to_string(g_id_++); }
explicit NewTensor(tensor::TensorPtr input_tensor) : input_tensor_(input_tensor) { should_replace_ = false; }
MS_DECLARE_PARENT(NewTensor, Pattern);
MatchResultPtr match(const AnfNodePtr &node) override {
MS_LOG(EXCEPTION) << "Find NewTensor in pattern, NewTensor should only appear in the target.\n";
}
tensor::TensorPtr input_tensor() { return input_tensor_; }
private:
tensor::TensorPtr input_tensor_;
};
class MatchResult {
public:
MatchResult() {}
void add_entry(PatternPtr pattern, AnfNodePtr node) { match_result_[pattern] = node; }
PatternNodeMap _result() { return match_result_; }
AnfNodePtr get_node(const PatternPtr &pattern);
void merge(const MatchResultPtr &other_result);
void clear() { match_result_.clear(); }
void dump() {
MS_LOG(DEBUG) << "match_result_.size: " + std::to_string(match_result_.size()) + "\n";
for (auto &iter : match_result_) {
MS_LOG(DEBUG) << "Pattern : " + iter.first->unique_name() + " , node : " + iter.second->ToString() + "\n";
}
}
private:
PatternNodeMap match_result_;
};
} // namespace python_pass
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_PATTERN_H_