|
|
|
@ -42,7 +42,7 @@ void Var::EnsureTag() {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool operator==(const VarPtr& lhs, const VarPtr& rhs) {
|
|
|
|
|
bool operator==(const VarPtr &lhs, const VarPtr &rhs) {
|
|
|
|
|
if (lhs->isa<CondVar>() && rhs->isa<CondVar>()) {
|
|
|
|
|
CondVarPtr v1 = dyn_cast<CondVar>(lhs);
|
|
|
|
|
CondVarPtr v2 = dyn_cast<CondVar>(rhs);
|
|
|
|
@ -63,7 +63,7 @@ std::string SeqVar::ToString() const {
|
|
|
|
|
return buffer.str();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::ostream& operator<<(std::ostream& os, const VarPtr& var) {
|
|
|
|
|
std::ostream &operator<<(std::ostream &os, const VarPtr &var) {
|
|
|
|
|
if (var == nullptr) {
|
|
|
|
|
os << "";
|
|
|
|
|
} else {
|
|
|
|
@ -73,10 +73,10 @@ std::ostream& operator<<(std::ostream& os, const VarPtr& var) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
std::ostream& operator<<<VarPtr, BaseRef>(std::ostream& os, const Equiv& equiv) {
|
|
|
|
|
std::ostream &operator<<<VarPtr, BaseRef>(std::ostream &os, const Equiv &equiv) {
|
|
|
|
|
os << "[Equiv]"
|
|
|
|
|
<< "\n";
|
|
|
|
|
for (auto& equiv_item : equiv) {
|
|
|
|
|
for (auto &equiv_item : equiv) {
|
|
|
|
|
auto k = equiv_item.first;
|
|
|
|
|
os << k << ":";
|
|
|
|
|
BaseRef x = equiv_item.second;
|
|
|
|
@ -104,7 +104,7 @@ std::ostream& operator<<<VarPtr, BaseRef>(std::ostream& os, const Equiv& equiv)
|
|
|
|
|
return os;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static BaseRef GetVar(const BaseRef& x) {
|
|
|
|
|
static BaseRef GetVar(const BaseRef &x) {
|
|
|
|
|
MS_LOG(DEBUG) << "getVar start :%s" + x.ToString();
|
|
|
|
|
if (utils::isa<AnfNodePtr>(x)) {
|
|
|
|
|
auto node = utils::cast<AnfNodePtr>(x);
|
|
|
|
@ -129,7 +129,7 @@ static BaseRef GetVar(const BaseRef& x) {
|
|
|
|
|
return x;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
EquivPtr MatchOnVar(const BaseRef& pattern, const BaseRef& expr, EquivPtr equiv) {
|
|
|
|
|
EquivPtr MatchOnVar(const BaseRef &pattern, const BaseRef &expr, EquivPtr equiv) {
|
|
|
|
|
MS_LOG(DEBUG) << "MatchOnVar pattern " + pattern.ToString() + " expr: " + expr.ToString();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(equiv);
|
|
|
|
|
if (utils::isa<VarPtr>(pattern)) {
|
|
|
|
@ -144,8 +144,8 @@ EquivPtr MatchOnVar(const BaseRef& pattern, const BaseRef& expr, EquivPtr equiv)
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool PatternEngine::ToVector(const VectorRef& pattern_ref, const VectorRef& expr_ref, VectorRef* const values_pattern,
|
|
|
|
|
VectorRef* const values_expr) const {
|
|
|
|
|
bool PatternEngine::ToVector(const VectorRef &pattern_ref, const VectorRef &expr_ref, VectorRef *const values_pattern,
|
|
|
|
|
VectorRef *const values_expr) const {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(values_expr);
|
|
|
|
|
if (utils::isa<SeqPtr>(pattern_ref)) {
|
|
|
|
|
*values_pattern = pattern_ref;
|
|
|
|
@ -155,12 +155,12 @@ bool PatternEngine::ToVector(const VectorRef& pattern_ref, const VectorRef& expr
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool PatternEngine::ToVector(const BaseRef& pattern_ref, const BaseRef& expr_ref, VectorRef* const values_pattern,
|
|
|
|
|
VectorRef* const values_expr) const {
|
|
|
|
|
bool PatternEngine::ToVector(const BaseRef &pattern_ref, const BaseRef &expr_ref, VectorRef *const values_pattern,
|
|
|
|
|
VectorRef *const values_expr) const {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(values_expr);
|
|
|
|
|
// visitor to visite the list
|
|
|
|
|
auto appender_pattern = [](VectorRef& values) {
|
|
|
|
|
std::function<BaseRef(const BaseRef&)> fn = [&](const BaseRef& u) {
|
|
|
|
|
auto appender_pattern = [](VectorRef &values) {
|
|
|
|
|
std::function<BaseRef(const BaseRef &)> fn = [&](const BaseRef &u) {
|
|
|
|
|
values.push_back(GetVar(u));
|
|
|
|
|
return u;
|
|
|
|
|
};
|
|
|
|
@ -174,8 +174,8 @@ bool PatternEngine::ToVector(const BaseRef& pattern_ref, const BaseRef& expr_ref
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto appender_expr = [](VectorRef& values) {
|
|
|
|
|
std::function<BaseRef(const BaseRef&)> fn = [&](const BaseRef& u) {
|
|
|
|
|
auto appender_expr = [](VectorRef &values) {
|
|
|
|
|
std::function<BaseRef(const BaseRef &)> fn = [&](const BaseRef &u) {
|
|
|
|
|
values.push_back(u);
|
|
|
|
|
return u;
|
|
|
|
|
};
|
|
|
|
@ -187,10 +187,10 @@ bool PatternEngine::ToVector(const BaseRef& pattern_ref, const BaseRef& expr_ref
|
|
|
|
|
return visitor_->Visit(expr_ref, nullptr);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static int GetSVarStartIndex(const VectorRef& values) {
|
|
|
|
|
static int GetSVarStartIndex(const VectorRef &values) {
|
|
|
|
|
int index = -1;
|
|
|
|
|
int count = 0;
|
|
|
|
|
for (auto& value : values) {
|
|
|
|
|
for (auto &value : values) {
|
|
|
|
|
if (utils::isa<VarPtr>(value) && utils::cast<VarPtr>(value)->isa<SeqVar>()) {
|
|
|
|
|
if (index != -1) {
|
|
|
|
|
MS_LOG(DEBUG) << "Multiple SVars in sequence";
|
|
|
|
@ -203,7 +203,35 @@ static int GetSVarStartIndex(const VectorRef& values) {
|
|
|
|
|
return index;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
EquivPtr PatternEngine::AlignSVar(const VectorRef& values_pattern, const VectorRef& values_expr, EquivPtr equiv) const {
|
|
|
|
|
void UpdateEquivMap(const VectorRef &values_pattern, const BaseRef &expr_ref, const PrimitiveVarMap &primitive_vars,
|
|
|
|
|
EquivPtr equiv) {
|
|
|
|
|
if (equiv == nullptr || values_pattern.empty() || !utils::isa<AnfNodePtr>(values_pattern[0]) ||
|
|
|
|
|
!utils::isa<AnfNodePtr>(expr_ref)) {
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
auto real_node = utils::cast<AnfNodePtr>(expr_ref);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(real_node);
|
|
|
|
|
if (!real_node->isa<CNode>()) {
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
auto prim_node = utils::cast<AnfNodePtr>(values_pattern[0]);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(prim_node);
|
|
|
|
|
if (!IsValueNode<Primitive>(prim_node)) {
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
ValuePtr value = GetValueNode(prim_node);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(value);
|
|
|
|
|
auto prim = value->cast<PrimitivePtr>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(prim);
|
|
|
|
|
auto iter = primitive_vars.find(prim);
|
|
|
|
|
if (iter == primitive_vars.end()) {
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
(*equiv)[iter->second] = real_node;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
EquivPtr PatternEngine::AlignSVar(const VectorRef &values_pattern, const VectorRef &values_expr,
|
|
|
|
|
const PrimitiveVarMap &primitive_vars, EquivPtr equiv) const {
|
|
|
|
|
int svar_index = GetSVarStartIndex(values_pattern);
|
|
|
|
|
if (svar_index == kInvalidVarIndex) {
|
|
|
|
|
return nullptr;
|
|
|
|
@ -229,12 +257,12 @@ EquivPtr PatternEngine::AlignSVar(const VectorRef& values_pattern, const VectorR
|
|
|
|
|
if (svar_index != -1 && i == IntToSize(svar_index)) {
|
|
|
|
|
auto seq =
|
|
|
|
|
std::vector<BaseRef>(values_expr.begin() + svar_index, values_expr.begin() + svar_index + SizeToInt(diff));
|
|
|
|
|
equiv = Match(values_pattern[svar_index], seq, equiv);
|
|
|
|
|
equiv = Match(values_pattern[svar_index], seq, primitive_vars, equiv);
|
|
|
|
|
} else {
|
|
|
|
|
if (svar_index != -1 && i > IntToSize(svar_index)) {
|
|
|
|
|
expr_i = i + diff - 1;
|
|
|
|
|
}
|
|
|
|
|
equiv = Match(values_pattern[i], values_expr[expr_i], equiv);
|
|
|
|
|
equiv = Match(values_pattern[i], values_expr[expr_i], primitive_vars, equiv);
|
|
|
|
|
}
|
|
|
|
|
if (equiv == nullptr) {
|
|
|
|
|
return nullptr;
|
|
|
|
@ -243,7 +271,8 @@ EquivPtr PatternEngine::AlignSVar(const VectorRef& values_pattern, const VectorR
|
|
|
|
|
return equiv;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
EquivPtr PatternEngine::Match(const BaseRef& pattern, const BaseRef& expr, EquivPtr equiv) const {
|
|
|
|
|
EquivPtr PatternEngine::Match(const BaseRef &pattern, const BaseRef &expr, const PrimitiveVarMap &primitive_vars,
|
|
|
|
|
EquivPtr equiv) const {
|
|
|
|
|
MS_LOG(DEBUG) << "-----[in Match]";
|
|
|
|
|
MS_LOG(DEBUG) << "GetVar w";
|
|
|
|
|
BaseRef pattern_ref = GetVar(pattern);
|
|
|
|
@ -292,10 +321,12 @@ EquivPtr PatternEngine::Match(const BaseRef& pattern, const BaseRef& expr, Equiv
|
|
|
|
|
// 6. if any svar in both side, find the SeqVar index,
|
|
|
|
|
// try to pack the Var s in std::vector to a Seq and match elements one by one.
|
|
|
|
|
// check svar
|
|
|
|
|
return AlignSVar(values_pattern, values_expr, equiv);
|
|
|
|
|
equiv = AlignSVar(values_pattern, values_expr, primitive_vars, equiv);
|
|
|
|
|
UpdateEquivMap(values_pattern, expr_ref, primitive_vars, equiv);
|
|
|
|
|
return equiv;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
BaseRef PatternEngine::Replace(const BaseRef& pattern, const EquivPtr& equiv) const {
|
|
|
|
|
BaseRef PatternEngine::Replace(const BaseRef &pattern, const EquivPtr &equiv) const {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(equiv);
|
|
|
|
|
MS_LOG(DEBUG) << "-----[in Replace]";
|
|
|
|
|
BaseRef ref = GetVar(pattern);
|
|
|
|
@ -304,7 +335,7 @@ BaseRef PatternEngine::Replace(const BaseRef& pattern, const EquivPtr& equiv) co
|
|
|
|
|
|
|
|
|
|
// w is var
|
|
|
|
|
if (utils::isa<VarPtr>(ref)) {
|
|
|
|
|
const VarPtr& var = utils::cast<VarPtr>(ref);
|
|
|
|
|
const VarPtr &var = utils::cast<VarPtr>(ref);
|
|
|
|
|
auto iter = equiv->find(var);
|
|
|
|
|
if (iter != equiv->end()) {
|
|
|
|
|
out = iter->second;
|
|
|
|
@ -316,7 +347,7 @@ BaseRef PatternEngine::Replace(const BaseRef& pattern, const EquivPtr& equiv) co
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// visitor to visit the list
|
|
|
|
|
std::function<BaseRef(BaseRef)> fn = [&, this, equiv](const BaseRef& u) { return Replace(u, equiv); };
|
|
|
|
|
std::function<BaseRef(BaseRef)> fn = [&, this, equiv](const BaseRef &u) { return Replace(u, equiv); };
|
|
|
|
|
|
|
|
|
|
visitor_->SetFn(fn);
|
|
|
|
|
BaseRef visit_out;
|
|
|
|
|