!289 Add cnode mapping after graph match

Merge pull request !289 from YuJianfeng/find_op
pull/289/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit b571fabd77

@ -15,43 +15,9 @@
*/ */
#include "pre_activate/ascend/ir_fusion/adam_apply_one_fusion.h" #include "pre_activate/ascend/ir_fusion/adam_apply_one_fusion.h"
#include "pre_activate/common/helper.h" #include "pre_activate/common/helper.h"
#include "utils/utils.h"
namespace mindspore { namespace mindspore {
namespace opt { namespace opt {
namespace {
void GetAdd0AndAdd1(const AnfNodePtr &sub0, AnfNodePtr *add0, AnfNodePtr *add1) {
MS_EXCEPTION_IF_NULL(sub0);
MS_EXCEPTION_IF_NULL(add0);
MS_EXCEPTION_IF_NULL(add1);
auto sub0_cnode = sub0->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(sub0_cnode);
CheckCNodeInputSize(sub0_cnode, kSubInputNum);
AnfNodePtr mul4 = sub0_cnode->input(2);
MS_EXCEPTION_IF_NULL(mul4);
auto mul4_cnode = mul4->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(mul4_cnode);
CheckCNodeInputSize(mul4_cnode, kMulInputNum);
AnfNodePtr true_div0 = mul4_cnode->input(2);
MS_EXCEPTION_IF_NULL(true_div0);
auto true_div0_cnode = true_div0->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(true_div0_cnode);
CheckCNodeInputSize(true_div0_cnode, kRealDivInputNum);
*add0 = true_div0_cnode->input(1);
AnfNodePtr add2 = true_div0_cnode->input(2);
MS_EXCEPTION_IF_NULL(add2);
auto add2_cnode = add2->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(add2_cnode);
CheckCNodeInputSize(add2_cnode, kAddInputNum);
AnfNodePtr sqrt0 = add2_cnode->input(1);
MS_EXCEPTION_IF_NULL(sqrt0);
auto sqrt0_cnode = sqrt0->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(sqrt0_cnode);
CheckCNodeInputSize(sqrt0_cnode, kSqrtInputNum);
*add1 = sqrt0_cnode->input(1);
}
} // namespace
AnfNodePtr AdamApplyOneFusion::CreateAdamApplyOneNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv) const { AnfNodePtr AdamApplyOneFusion::CreateAdamApplyOneNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv) const {
MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(equiv); MS_EXCEPTION_IF_NULL(equiv);
@ -79,10 +45,10 @@ const BaseRef AdamApplyOneFusion::DefinePattern() const {
const auto prim_deal_div = std::make_shared<Primitive>(kRealDivOpName); const auto prim_deal_div = std::make_shared<Primitive>(kRealDivOpName);
VectorRef mul2 = VectorRef({prim::kPrimMul, mul_x_input_vars_[2], input_vars_[1]}); VectorRef mul2 = VectorRef({prim::kPrimMul, mul_x_input_vars_[2], input_vars_[1]});
VectorRef mul3 = VectorRef({prim::kPrimMul, mul_x_input_vars_[3], VectorRef({prim::kPrimSquare, input_vars_[0]})}); VectorRef mul3 = VectorRef({prim::kPrimMul, mul_x_input_vars_[3], VectorRef({prim::kPrimSquare, input_vars_[0]})});
VectorRef sqrt0 = VectorRef({prim_sqrt, VectorRef({prim::kPrimTensorAdd, mul2, mul3})}); VectorRef sqrt0 = VectorRef({prim_sqrt, VectorRef({add1_var_, mul2, mul3})});
VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[1], input_vars_[0]}); VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[1], input_vars_[0]});
VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[0], input_vars_[2]}); VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[0], input_vars_[2]});
VectorRef add0 = VectorRef({prim::kPrimTensorAdd, mul0, mul1}); VectorRef add0 = VectorRef({add0_var_, mul0, mul1});
VectorRef true_div0 = VectorRef({prim_deal_div, add0, VectorRef({prim::kPrimTensorAdd, sqrt0, add2_y_})}); VectorRef true_div0 = VectorRef({prim_deal_div, add0, VectorRef({prim::kPrimTensorAdd, sqrt0, add2_y_})});
return VectorRef({prim::kPrimSub, input_vars_[3], VectorRef({prim::kPrimMul, input_vars_[4], true_div0})}); return VectorRef({prim::kPrimSub, input_vars_[3], VectorRef({prim::kPrimMul, input_vars_[4], true_div0})});
} }
@ -96,10 +62,17 @@ const AnfNodePtr AdamApplyOneFusion::Process(const FuncGraphPtr &func_graph, con
new_node->set_scope(node->scope()); new_node->set_scope(node->scope());
// Set abstract of new node // Set abstract of new node
AbstractBasePtrList new_node_abstract_list; AbstractBasePtrList new_node_abstract_list;
AnfNodePtr add0 = nullptr; auto iter_add0 = (*equiv).find(add0_var_);
AnfNodePtr add1 = nullptr; if (iter_add0 == (*equiv).end()) {
GetAdd0AndAdd1(node, &add0, &add1); MS_LOG(EXCEPTION) << "The equiv map is expected to contains the add0 var after matched.";
}
auto iter_add1 = (*equiv).find(add1_var_);
if (iter_add1 == (*equiv).end()) {
MS_LOG(EXCEPTION) << "The equiv map is expected to contains the add1 var after matched.";
}
auto add0 = utils::cast<AnfNodePtr>(iter_add0->second);
MS_EXCEPTION_IF_NULL(add0); MS_EXCEPTION_IF_NULL(add0);
auto add1 = utils::cast<AnfNodePtr>(iter_add1->second);
MS_EXCEPTION_IF_NULL(add1); MS_EXCEPTION_IF_NULL(add1);
new_node_abstract_list.push_back(add1->abstract()); new_node_abstract_list.push_back(add1->abstract());
new_node_abstract_list.push_back(add0->abstract()); new_node_abstract_list.push_back(add0->abstract());

@ -19,6 +19,7 @@
#include <vector> #include <vector>
#include <memory> #include <memory>
#include "pre_activate/common/optimizer.h" #include "pre_activate/common/optimizer.h"
#include "utils/utils.h"
namespace mindspore { namespace mindspore {
namespace opt { namespace opt {
@ -35,6 +36,8 @@ class AdamApplyOneFusion : public PatternProcessPass {
mul_x_input_vars_.push_back(std::make_shared<Var>()); mul_x_input_vars_.push_back(std::make_shared<Var>());
} }
add2_y_ = std::make_shared<Var>(); add2_y_ = std::make_shared<Var>();
add0_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimTensorAdd->name()));
add1_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimTensorAdd->name()));
} }
~AdamApplyOneFusion() override = default; ~AdamApplyOneFusion() override = default;
@ -46,6 +49,8 @@ class AdamApplyOneFusion : public PatternProcessPass {
std::vector<VarPtr> input_vars_; std::vector<VarPtr> input_vars_;
std::vector<VarPtr> mul_x_input_vars_; std::vector<VarPtr> mul_x_input_vars_;
VarPtr add2_y_; VarPtr add2_y_;
VarPtr add0_var_;
VarPtr add1_var_;
}; };
} // namespace opt } // namespace opt
} // namespace mindspore } // namespace mindspore

@ -17,48 +17,13 @@
#include <memory> #include <memory>
#include <vector> #include <vector>
#include <tuple>
#include "session/anf_runtime_algorithm.h" #include "session/anf_runtime_algorithm.h"
#include "ir/primitive.h" #include "ir/primitive.h"
#include "utils/utils.h"
#include "pre_activate/common/helper.h" #include "pre_activate/common/helper.h"
namespace mindspore { namespace mindspore {
namespace opt { namespace opt {
namespace {
std::tuple<AnfNodePtr, AnfNodePtr> GetAdd0Add1Node(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
auto sub0 = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(sub0);
auto mul5_anf = sub0->input(2);
MS_EXCEPTION_IF_NULL(mul5_anf);
auto mul5 = mul5_anf->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(mul5);
auto add3_anf = mul5->input(2);
MS_EXCEPTION_IF_NULL(add3_anf);
auto add3 = add3_anf->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(add3);
auto real_div0_anf = add3->input(1);
MS_EXCEPTION_IF_NULL(real_div0_anf);
auto real_div0 = real_div0_anf->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(real_div0);
auto add0_anf = real_div0->input(1);
MS_EXCEPTION_IF_NULL(add0_anf);
auto add2_anf = real_div0->input(2);
MS_EXCEPTION_IF_NULL(add2_anf);
auto add2 = add2_anf->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(add2);
auto sqrt0_anf = add2->input(1);
MS_EXCEPTION_IF_NULL(sqrt0_anf);
auto sqrt0 = sqrt0_anf->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(sqrt0);
auto add1_anf = sqrt0->input(1);
MS_EXCEPTION_IF_NULL(add1_anf);
return std::make_tuple(add0_anf, add1_anf);
}
} // namespace
std::vector<AnfNodePtr> AdamApplyOneWithDecayRule::GetFusionNodeInputs(const EquivPtr &equiv) const { std::vector<AnfNodePtr> AdamApplyOneWithDecayRule::GetFusionNodeInputs(const EquivPtr &equiv) const {
MS_EXCEPTION_IF_NULL(equiv); MS_EXCEPTION_IF_NULL(equiv);
auto input0 = utils::cast<AnfNodePtr>((*equiv)[input0_]); auto input0 = utils::cast<AnfNodePtr>((*equiv)[input0_]);
@ -82,10 +47,10 @@ const BaseRef AdamApplyOneWithDecayRule::DefinePattern() const {
VectorRef mul0_pattern({prim::kPrimMul, mul0_x_, input2_}); VectorRef mul0_pattern({prim::kPrimMul, mul0_x_, input2_});
VectorRef mul1_pattern({prim::kPrimMul, mul1_x_, input0_}); VectorRef mul1_pattern({prim::kPrimMul, mul1_x_, input0_});
VectorRef square0_pattern({prim::kPrimSquare, input0_}); VectorRef square0_pattern({prim::kPrimSquare, input0_});
VectorRef add0_pattern({prim::kPrimTensorAdd, mul0_pattern, mul1_pattern}); VectorRef add0_pattern({add0_var_, mul0_pattern, mul1_pattern});
VectorRef mul2_pattern({prim::kPrimMul, mul2_x_, input1_}); VectorRef mul2_pattern({prim::kPrimMul, mul2_x_, input1_});
VectorRef mul3_pattern({prim::kPrimMul, mul3_x_, square0_pattern}); VectorRef mul3_pattern({prim::kPrimMul, mul3_x_, square0_pattern});
VectorRef add1_pattern({prim::kPrimTensorAdd, mul2_pattern, mul3_pattern}); VectorRef add1_pattern({add1_var_, mul2_pattern, mul3_pattern});
VectorRef sqrt0_pattern({sqrt, add1_pattern}); VectorRef sqrt0_pattern({sqrt, add1_pattern});
VectorRef add2_pattern({prim::kPrimTensorAdd, sqrt0_pattern, add2_y_}); VectorRef add2_pattern({prim::kPrimTensorAdd, sqrt0_pattern, add2_y_});
VectorRef mul4_pattern({prim::kPrimMul, mul4_x_, input3_}); VectorRef mul4_pattern({prim::kPrimMul, mul4_x_, input3_});
@ -107,9 +72,18 @@ const AnfNodePtr AdamApplyOneWithDecayRule::Process(const FuncGraphPtr &graph, c
MS_EXCEPTION_IF_NULL(fusion_node); MS_EXCEPTION_IF_NULL(fusion_node);
fusion_node->set_scope(node->scope()); fusion_node->set_scope(node->scope());
AnfNodePtr add0 = nullptr; auto iter_add0 = (*equiv).find(add0_var_);
AnfNodePtr add1 = nullptr; if (iter_add0 == (*equiv).end()) {
std::tie(add0, add1) = GetAdd0Add1Node(node); MS_LOG(EXCEPTION) << "The equiv map is expected to contains the add0 var after matched.";
}
auto iter_add1 = (*equiv).find(add1_var_);
if (iter_add1 == (*equiv).end()) {
MS_LOG(EXCEPTION) << "The equiv map is expected to contains the add1 var after matched.";
}
auto add0 = utils::cast<AnfNodePtr>(iter_add0->second);
MS_EXCEPTION_IF_NULL(add0);
auto add1 = utils::cast<AnfNodePtr>(iter_add1->second);
MS_EXCEPTION_IF_NULL(add1);
auto types = {AnfAlgo::GetOutputInferDataType(add1, 0), AnfAlgo::GetOutputInferDataType(add0, 0), auto types = {AnfAlgo::GetOutputInferDataType(add1, 0), AnfAlgo::GetOutputInferDataType(add0, 0),
AnfAlgo::GetOutputInferDataType(node, 0)}; AnfAlgo::GetOutputInferDataType(node, 0)};
auto shapes = {AnfAlgo::GetOutputInferShape(add1, 0), AnfAlgo::GetOutputInferShape(add0, 0), auto shapes = {AnfAlgo::GetOutputInferShape(add1, 0), AnfAlgo::GetOutputInferShape(add0, 0),

@ -19,6 +19,7 @@
#include <vector> #include <vector>
#include <memory> #include <memory>
#include "pre_activate/common/optimizer.h" #include "pre_activate/common/optimizer.h"
#include "utils/utils.h"
namespace mindspore { namespace mindspore {
namespace opt { namespace opt {
class AdamApplyOneWithDecayRule : public PatternProcessPass { class AdamApplyOneWithDecayRule : public PatternProcessPass {
@ -36,6 +37,8 @@ class AdamApplyOneWithDecayRule : public PatternProcessPass {
mul3_x_ = std::make_shared<Var>(); mul3_x_ = std::make_shared<Var>();
mul4_x_ = std::make_shared<Var>(); mul4_x_ = std::make_shared<Var>();
add2_y_ = std::make_shared<Var>(); add2_y_ = std::make_shared<Var>();
add0_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimTensorAdd->name()));
add1_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimTensorAdd->name()));
} }
~AdamApplyOneWithDecayRule() override = default; ~AdamApplyOneWithDecayRule() override = default;
const BaseRef DefinePattern() const override; const BaseRef DefinePattern() const override;
@ -54,6 +57,8 @@ class AdamApplyOneWithDecayRule : public PatternProcessPass {
VarPtr mul3_x_; VarPtr mul3_x_;
VarPtr mul4_x_; VarPtr mul4_x_;
VarPtr add2_y_; VarPtr add2_y_;
VarPtr add0_var_;
VarPtr add1_var_;
}; };
} // namespace opt } // namespace opt
} // namespace mindspore } // namespace mindspore

@ -16,36 +16,9 @@
#include "pre_activate/ascend/ir_fusion/lamb_next_right_rule.h" #include "pre_activate/ascend/ir_fusion/lamb_next_right_rule.h"
#include <vector> #include <vector>
#include "pre_activate/common/helper.h" #include "pre_activate/common/helper.h"
#include "utils/utils.h"
namespace mindspore { namespace mindspore {
namespace opt { namespace opt {
namespace {
AnfNodePtr GetAdd1Node(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
auto add2_cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(add2_cnode);
if (add2_cnode->inputs().size() != kAddInputNum) {
MS_LOG(ERROR) << "The input size of Add2 is not equal to " << kAddInputNum;
}
AnfNodePtr sqrt0 = add2_cnode->input(1);
MS_EXCEPTION_IF_NULL(sqrt0);
auto sqrt0_cnode = sqrt0->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(sqrt0_cnode);
if (sqrt0_cnode->inputs().size() != kSqrtInputNum) {
MS_LOG(ERROR) << "The input size of Sqrt0 is not equal to " << kSqrtInputNum;
}
AnfNodePtr real_div1 = sqrt0_cnode->input(1);
MS_EXCEPTION_IF_NULL(real_div1);
auto real_div1_cnode = real_div1->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(real_div1_cnode);
if (real_div1_cnode->inputs().size() != kMulInputNum) {
MS_LOG(ERROR) << "The input size of RealDiv1 is not equal to " << kMulInputNum;
}
return real_div1_cnode->input(1);
}
} // namespace
AnfNodePtr LambNextRightRule::CreateLambNextRightNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv) const { AnfNodePtr LambNextRightRule::CreateLambNextRightNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv) const {
MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(equiv); MS_EXCEPTION_IF_NULL(equiv);
@ -79,7 +52,7 @@ const BaseRef LambNextRightRule::DefinePattern() const {
const auto prim_sqrt = std::make_shared<Primitive>(kSqrtOpName); const auto prim_sqrt = std::make_shared<Primitive>(kSqrtOpName);
MS_EXCEPTION_IF_NULL(prim_sqrt); MS_EXCEPTION_IF_NULL(prim_sqrt);
VectorRef mul3 = VectorRef({prim::kPrimMul, mul3_x_, VectorRef({prim::kPrimSquare, input0_})}); VectorRef mul3 = VectorRef({prim::kPrimMul, mul3_x_, VectorRef({prim::kPrimSquare, input0_})});
VectorRef add1 = VectorRef({prim::kPrimTensorAdd, VectorRef({prim::kPrimMul, mul2_x_, input1_}), mul3}); VectorRef add1 = VectorRef({add1_var_, VectorRef({prim::kPrimMul, mul2_x_, input1_}), mul3});
return VectorRef( return VectorRef(
{prim::kPrimTensorAdd, VectorRef({prim_sqrt, VectorRef({prim::kPrimMul, add1, true_div1_recip_})}), add2_y_}); {prim::kPrimTensorAdd, VectorRef({prim_sqrt, VectorRef({prim::kPrimMul, add1, true_div1_recip_})}), add2_y_});
} }
@ -91,7 +64,11 @@ const AnfNodePtr LambNextRightRule::Process(const FuncGraphPtr &func_graph, cons
auto new_node = CreateLambNextRightNode(func_graph, equiv); auto new_node = CreateLambNextRightNode(func_graph, equiv);
MS_EXCEPTION_IF_NULL(new_node); MS_EXCEPTION_IF_NULL(new_node);
// Set abstract of new node // Set abstract of new node
AnfNodePtr add1 = GetAdd1Node(node); auto iter_add1 = (*equiv).find(add1_var_);
if (iter_add1 == (*equiv).end()) {
MS_LOG(EXCEPTION) << "The equiv map is expected to contains the add1 var after matched.";
}
auto add1 = utils::cast<AnfNodePtr>(iter_add1->second);
MS_EXCEPTION_IF_NULL(add1); MS_EXCEPTION_IF_NULL(add1);
AbstractBasePtrList new_node_abstract_list; AbstractBasePtrList new_node_abstract_list;
new_node_abstract_list.push_back(add1->abstract()); new_node_abstract_list.push_back(add1->abstract());

@ -18,6 +18,8 @@
#include <memory> #include <memory>
#include "pre_activate/common/optimizer.h" #include "pre_activate/common/optimizer.h"
#include "utils/utils.h"
namespace mindspore { namespace mindspore {
namespace opt { namespace opt {
class LambNextRightRule : public PatternProcessPass { class LambNextRightRule : public PatternProcessPass {
@ -29,7 +31,8 @@ class LambNextRightRule : public PatternProcessPass {
mul2_x_(std::make_shared<Var>()), mul2_x_(std::make_shared<Var>()),
mul3_x_(std::make_shared<Var>()), mul3_x_(std::make_shared<Var>()),
true_div1_recip_(std::make_shared<Var>()), true_div1_recip_(std::make_shared<Var>()),
add2_y_(std::make_shared<Var>()) {} add2_y_(std::make_shared<Var>()),
add1_var_(std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimTensorAdd->name()))) {}
~LambNextRightRule() override = default; ~LambNextRightRule() override = default;
const BaseRef DefinePattern() const override; const BaseRef DefinePattern() const override;
@ -44,6 +47,7 @@ class LambNextRightRule : public PatternProcessPass {
VarPtr mul3_x_; VarPtr mul3_x_;
VarPtr true_div1_recip_; VarPtr true_div1_recip_;
VarPtr add2_y_; VarPtr add2_y_;
VarPtr add1_var_;
}; };
} // namespace opt } // namespace opt
} // namespace mindspore } // namespace mindspore

@ -30,7 +30,8 @@
namespace mindspore { namespace mindspore {
namespace opt { namespace opt {
namespace { namespace {
AnfNodePtr HandleSexpVector(const BaseRef &sexp, const BaseRef &graph, bool multigraph); AnfNodePtr HandleSexpVector(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars,
bool multigraph);
ValueNodePtr CreateValueNodeWithSexp(const BaseRef &sexp) { ValueNodePtr CreateValueNodeWithSexp(const BaseRef &sexp) {
if (utils::isa<int>(sexp)) { if (utils::isa<int>(sexp)) {
@ -71,12 +72,20 @@ VarNodePtr CreateVarNodeWithSexp(const BaseRef &sexp, const BaseRef &graph) {
return nullptr; return nullptr;
} }
AnfNodePtr SexpToNode(const BaseRef &sexp, const BaseRef &graph, bool multigraph = false) { AnfNodePtr SexpToNode(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars,
bool multigraph = false) {
MS_LOG(DEBUG) << "SexpToNode sexp: " + sexp.ToString() + ", graph " + graph.ToString(); MS_LOG(DEBUG) << "SexpToNode sexp: " + sexp.ToString() + ", graph " + graph.ToString();
MS_EXCEPTION_IF_NULL(primitive_vars);
if (utils::isa<VectorRef>(sexp)) { if (utils::isa<VectorRef>(sexp)) {
return HandleSexpVector(sexp, graph, multigraph); return HandleSexpVector(sexp, graph, primitive_vars, multigraph);
} }
if (utils::isa<VarPtr>(sexp)) { if (utils::isa<VarPtr>(sexp)) {
auto var_ptr = utils::cast<VarPtr>(sexp);
MS_EXCEPTION_IF_NULL(var_ptr);
if (var_ptr->primitive()) {
(*primitive_vars)[var_ptr->primitive()] = var_ptr;
return NewValueNode(var_ptr->primitive());
}
return CreateVarNodeWithSexp(sexp, graph); return CreateVarNodeWithSexp(sexp, graph);
} }
if (utils::isa<AnfNodePtr>(sexp)) { if (utils::isa<AnfNodePtr>(sexp)) {
@ -89,13 +98,14 @@ AnfNodePtr SexpToNode(const BaseRef &sexp, const BaseRef &graph, bool multigraph
return value_node; return value_node;
} }
AnfNodePtr HandleSexpVector(const BaseRef &sexp, const BaseRef &graph, bool multigraph) { AnfNodePtr HandleSexpVector(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars,
bool multigraph) {
MS_LOG(DEBUG) << "HandleSexpVector sexp: " + sexp.ToString() + ", graph " + graph.ToString(); MS_LOG(DEBUG) << "HandleSexpVector sexp: " + sexp.ToString() + ", graph " + graph.ToString();
std::vector<AnfNodePtr> input_nodes; std::vector<AnfNodePtr> input_nodes;
const auto &tuple = utils::cast<VectorRef>(sexp); const auto &tuple = utils::cast<VectorRef>(sexp);
if (multigraph && utils::isa<VarPtr>(graph)) { if (multigraph && utils::isa<VarPtr>(graph)) {
for (auto &x : tuple) { for (auto &x : tuple) {
AnfNodePtr node = SexpToNode(x, std::make_shared<Var>("G"), true); AnfNodePtr node = SexpToNode(x, std::make_shared<Var>("G"), primitive_vars, true);
input_nodes.push_back(node); input_nodes.push_back(node);
} }
VarPtr var_ptr = utils::cast<VarPtr>(graph); VarPtr var_ptr = utils::cast<VarPtr>(graph);
@ -103,7 +113,7 @@ AnfNodePtr HandleSexpVector(const BaseRef &sexp, const BaseRef &graph, bool mult
} }
for (auto &x : tuple) { for (auto &x : tuple) {
AnfNodePtr node = SexpToNode(x, graph, multigraph); AnfNodePtr node = SexpToNode(x, graph, primitive_vars, multigraph);
input_nodes.push_back(node); input_nodes.push_back(node);
} }
return CreateCNodeWithGraph(input_nodes, graph); return CreateCNodeWithGraph(input_nodes, graph);
@ -166,7 +176,8 @@ PatternProcessPass::PatternProcessPass(const std::string &name, bool multigraph)
multigraph_(multigraph), multigraph_(multigraph),
pattern_engine_(PatternEngine(std::make_shared<DefaultVisitor>(), pattern_engine_(PatternEngine(std::make_shared<DefaultVisitor>(),
std::function<bool(const BaseRef &, const BaseRef &)>(AnfEqual), std::function<bool(const BaseRef &, const BaseRef &)>(AnfEqual),
std::function<bool(const BaseRef &, const BaseRef &)>(CNodeTypeEqual))) {} std::function<bool(const BaseRef &, const BaseRef &)>(CNodeTypeEqual))),
primitive_vars_(std::make_shared<PrimitiveVarMap>()) {}
const BaseRef PatternProcessPass::DefinePattern() const { const BaseRef PatternProcessPass::DefinePattern() const {
VarPtr X = std::make_shared<Var>(); VarPtr X = std::make_shared<Var>();
@ -176,7 +187,7 @@ const BaseRef PatternProcessPass::DefinePattern() const {
void PatternProcessPass::Build() { void PatternProcessPass::Build() {
VarPtr fg = std::make_shared<Var>("RootG"); VarPtr fg = std::make_shared<Var>("RootG");
BaseRef pattern = std::move(DefinePattern()); BaseRef pattern = std::move(DefinePattern());
pattern_ = SexpToNode(pattern, fg, multigraph_); pattern_ = SexpToNode(pattern, fg, primitive_vars_.get(), multigraph_);
} }
AnfNodePtr PatternProcessPass::Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) { AnfNodePtr PatternProcessPass::Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
@ -185,7 +196,8 @@ AnfNodePtr PatternProcessPass::Run(const FuncGraphPtr &func_graph, const AnfNode
} }
auto empty_equiv = std::make_shared<Equiv>(); auto empty_equiv = std::make_shared<Equiv>();
EquivPtr equiv = pattern_engine_.Match(pattern_, node, empty_equiv); MS_EXCEPTION_IF_NULL(primitive_vars_);
EquivPtr equiv = pattern_engine_.Match(pattern_, node, *primitive_vars_, empty_equiv);
if (equiv != nullptr && !equiv->empty()) { if (equiv != nullptr && !equiv->empty()) {
return Process(func_graph, node, equiv); return Process(func_graph, node, equiv);
} }

@ -19,6 +19,7 @@
#include <memory> #include <memory>
#include <string> #include <string>
#include <vector> #include <vector>
#include <unordered_map>
#include "ir/anf.h" #include "ir/anf.h"
#include "ir/func_graph.h" #include "ir/func_graph.h"
@ -46,6 +47,7 @@ class PatternProcessPass : public NodePass {
AnfNodePtr pattern_ = nullptr; AnfNodePtr pattern_ = nullptr;
bool multigraph_ = true; bool multigraph_ = true;
PatternEngine pattern_engine_; PatternEngine pattern_engine_;
PrimitiveVarMapPtr primitive_vars_;
}; };
class GraphOptimizer { class GraphOptimizer {

@ -42,7 +42,7 @@ void Var::EnsureTag() {
} }
} }
bool operator==(const VarPtr& lhs, const VarPtr& rhs) { bool operator==(const VarPtr &lhs, const VarPtr &rhs) {
if (lhs->isa<CondVar>() && rhs->isa<CondVar>()) { if (lhs->isa<CondVar>() && rhs->isa<CondVar>()) {
CondVarPtr v1 = dyn_cast<CondVar>(lhs); CondVarPtr v1 = dyn_cast<CondVar>(lhs);
CondVarPtr v2 = dyn_cast<CondVar>(rhs); CondVarPtr v2 = dyn_cast<CondVar>(rhs);
@ -63,7 +63,7 @@ std::string SeqVar::ToString() const {
return buffer.str(); return buffer.str();
} }
std::ostream& operator<<(std::ostream& os, const VarPtr& var) { std::ostream &operator<<(std::ostream &os, const VarPtr &var) {
if (var == nullptr) { if (var == nullptr) {
os << ""; os << "";
} else { } else {
@ -73,10 +73,10 @@ std::ostream& operator<<(std::ostream& os, const VarPtr& var) {
} }
template <> template <>
std::ostream& operator<<<VarPtr, BaseRef>(std::ostream& os, const Equiv& equiv) { std::ostream &operator<<<VarPtr, BaseRef>(std::ostream &os, const Equiv &equiv) {
os << "[Equiv]" os << "[Equiv]"
<< "\n"; << "\n";
for (auto& equiv_item : equiv) { for (auto &equiv_item : equiv) {
auto k = equiv_item.first; auto k = equiv_item.first;
os << k << ":"; os << k << ":";
BaseRef x = equiv_item.second; BaseRef x = equiv_item.second;
@ -104,7 +104,7 @@ std::ostream& operator<<<VarPtr, BaseRef>(std::ostream& os, const Equiv& equiv)
return os; return os;
} }
static BaseRef GetVar(const BaseRef& x) { static BaseRef GetVar(const BaseRef &x) {
MS_LOG(DEBUG) << "getVar start :%s" + x.ToString(); MS_LOG(DEBUG) << "getVar start :%s" + x.ToString();
if (utils::isa<AnfNodePtr>(x)) { if (utils::isa<AnfNodePtr>(x)) {
auto node = utils::cast<AnfNodePtr>(x); auto node = utils::cast<AnfNodePtr>(x);
@ -129,7 +129,7 @@ static BaseRef GetVar(const BaseRef& x) {
return x; return x;
} }
EquivPtr MatchOnVar(const BaseRef& pattern, const BaseRef& expr, EquivPtr equiv) { EquivPtr MatchOnVar(const BaseRef &pattern, const BaseRef &expr, EquivPtr equiv) {
MS_LOG(DEBUG) << "MatchOnVar pattern " + pattern.ToString() + " expr: " + expr.ToString(); MS_LOG(DEBUG) << "MatchOnVar pattern " + pattern.ToString() + " expr: " + expr.ToString();
MS_EXCEPTION_IF_NULL(equiv); MS_EXCEPTION_IF_NULL(equiv);
if (utils::isa<VarPtr>(pattern)) { if (utils::isa<VarPtr>(pattern)) {
@ -144,8 +144,8 @@ EquivPtr MatchOnVar(const BaseRef& pattern, const BaseRef& expr, EquivPtr equiv)
return nullptr; return nullptr;
} }
bool PatternEngine::ToVector(const VectorRef& pattern_ref, const VectorRef& expr_ref, VectorRef* const values_pattern, bool PatternEngine::ToVector(const VectorRef &pattern_ref, const VectorRef &expr_ref, VectorRef *const values_pattern,
VectorRef* const values_expr) const { VectorRef *const values_expr) const {
MS_EXCEPTION_IF_NULL(values_expr); MS_EXCEPTION_IF_NULL(values_expr);
if (utils::isa<SeqPtr>(pattern_ref)) { if (utils::isa<SeqPtr>(pattern_ref)) {
*values_pattern = pattern_ref; *values_pattern = pattern_ref;
@ -155,12 +155,12 @@ bool PatternEngine::ToVector(const VectorRef& pattern_ref, const VectorRef& expr
return false; return false;
} }
bool PatternEngine::ToVector(const BaseRef& pattern_ref, const BaseRef& expr_ref, VectorRef* const values_pattern, bool PatternEngine::ToVector(const BaseRef &pattern_ref, const BaseRef &expr_ref, VectorRef *const values_pattern,
VectorRef* const values_expr) const { VectorRef *const values_expr) const {
MS_EXCEPTION_IF_NULL(values_expr); MS_EXCEPTION_IF_NULL(values_expr);
// visitor to visite the list // visitor to visite the list
auto appender_pattern = [](VectorRef& values) { auto appender_pattern = [](VectorRef &values) {
std::function<BaseRef(const BaseRef&)> fn = [&](const BaseRef& u) { std::function<BaseRef(const BaseRef &)> fn = [&](const BaseRef &u) {
values.push_back(GetVar(u)); values.push_back(GetVar(u));
return u; return u;
}; };
@ -174,8 +174,8 @@ bool PatternEngine::ToVector(const BaseRef& pattern_ref, const BaseRef& expr_ref
return false; return false;
} }
auto appender_expr = [](VectorRef& values) { auto appender_expr = [](VectorRef &values) {
std::function<BaseRef(const BaseRef&)> fn = [&](const BaseRef& u) { std::function<BaseRef(const BaseRef &)> fn = [&](const BaseRef &u) {
values.push_back(u); values.push_back(u);
return u; return u;
}; };
@ -187,10 +187,10 @@ bool PatternEngine::ToVector(const BaseRef& pattern_ref, const BaseRef& expr_ref
return visitor_->Visit(expr_ref, nullptr); return visitor_->Visit(expr_ref, nullptr);
} }
static int GetSVarStartIndex(const VectorRef& values) { static int GetSVarStartIndex(const VectorRef &values) {
int index = -1; int index = -1;
int count = 0; int count = 0;
for (auto& value : values) { for (auto &value : values) {
if (utils::isa<VarPtr>(value) && utils::cast<VarPtr>(value)->isa<SeqVar>()) { if (utils::isa<VarPtr>(value) && utils::cast<VarPtr>(value)->isa<SeqVar>()) {
if (index != -1) { if (index != -1) {
MS_LOG(DEBUG) << "Multiple SVars in sequence"; MS_LOG(DEBUG) << "Multiple SVars in sequence";
@ -203,7 +203,35 @@ static int GetSVarStartIndex(const VectorRef& values) {
return index; return index;
} }
EquivPtr PatternEngine::AlignSVar(const VectorRef& values_pattern, const VectorRef& values_expr, EquivPtr equiv) const { void UpdateEquivMap(const VectorRef &values_pattern, const BaseRef &expr_ref, const PrimitiveVarMap &primitive_vars,
EquivPtr equiv) {
if (equiv == nullptr || values_pattern.empty() || !utils::isa<AnfNodePtr>(values_pattern[0]) ||
!utils::isa<AnfNodePtr>(expr_ref)) {
return;
}
auto real_node = utils::cast<AnfNodePtr>(expr_ref);
MS_EXCEPTION_IF_NULL(real_node);
if (!real_node->isa<CNode>()) {
return;
}
auto prim_node = utils::cast<AnfNodePtr>(values_pattern[0]);
MS_EXCEPTION_IF_NULL(prim_node);
if (!IsValueNode<Primitive>(prim_node)) {
return;
}
ValuePtr value = GetValueNode(prim_node);
MS_EXCEPTION_IF_NULL(value);
auto prim = value->cast<PrimitivePtr>();
MS_EXCEPTION_IF_NULL(prim);
auto iter = primitive_vars.find(prim);
if (iter == primitive_vars.end()) {
return;
}
(*equiv)[iter->second] = real_node;
}
EquivPtr PatternEngine::AlignSVar(const VectorRef &values_pattern, const VectorRef &values_expr,
const PrimitiveVarMap &primitive_vars, EquivPtr equiv) const {
int svar_index = GetSVarStartIndex(values_pattern); int svar_index = GetSVarStartIndex(values_pattern);
if (svar_index == kInvalidVarIndex) { if (svar_index == kInvalidVarIndex) {
return nullptr; return nullptr;
@ -229,12 +257,12 @@ EquivPtr PatternEngine::AlignSVar(const VectorRef& values_pattern, const VectorR
if (svar_index != -1 && i == IntToSize(svar_index)) { if (svar_index != -1 && i == IntToSize(svar_index)) {
auto seq = auto seq =
std::vector<BaseRef>(values_expr.begin() + svar_index, values_expr.begin() + svar_index + SizeToInt(diff)); std::vector<BaseRef>(values_expr.begin() + svar_index, values_expr.begin() + svar_index + SizeToInt(diff));
equiv = Match(values_pattern[svar_index], seq, equiv); equiv = Match(values_pattern[svar_index], seq, primitive_vars, equiv);
} else { } else {
if (svar_index != -1 && i > IntToSize(svar_index)) { if (svar_index != -1 && i > IntToSize(svar_index)) {
expr_i = i + diff - 1; expr_i = i + diff - 1;
} }
equiv = Match(values_pattern[i], values_expr[expr_i], equiv); equiv = Match(values_pattern[i], values_expr[expr_i], primitive_vars, equiv);
} }
if (equiv == nullptr) { if (equiv == nullptr) {
return nullptr; return nullptr;
@ -243,7 +271,8 @@ EquivPtr PatternEngine::AlignSVar(const VectorRef& values_pattern, const VectorR
return equiv; return equiv;
} }
EquivPtr PatternEngine::Match(const BaseRef& pattern, const BaseRef& expr, EquivPtr equiv) const { EquivPtr PatternEngine::Match(const BaseRef &pattern, const BaseRef &expr, const PrimitiveVarMap &primitive_vars,
EquivPtr equiv) const {
MS_LOG(DEBUG) << "-----[in Match]"; MS_LOG(DEBUG) << "-----[in Match]";
MS_LOG(DEBUG) << "GetVar w"; MS_LOG(DEBUG) << "GetVar w";
BaseRef pattern_ref = GetVar(pattern); BaseRef pattern_ref = GetVar(pattern);
@ -292,10 +321,12 @@ EquivPtr PatternEngine::Match(const BaseRef& pattern, const BaseRef& expr, Equiv
// 6. if any svar in both side, find the SeqVar index, // 6. if any svar in both side, find the SeqVar index,
// try to pack the Var s in std::vector to a Seq and match elements one by one. // try to pack the Var s in std::vector to a Seq and match elements one by one.
// check svar // check svar
return AlignSVar(values_pattern, values_expr, equiv); equiv = AlignSVar(values_pattern, values_expr, primitive_vars, equiv);
UpdateEquivMap(values_pattern, expr_ref, primitive_vars, equiv);
return equiv;
} }
BaseRef PatternEngine::Replace(const BaseRef& pattern, const EquivPtr& equiv) const { BaseRef PatternEngine::Replace(const BaseRef &pattern, const EquivPtr &equiv) const {
MS_EXCEPTION_IF_NULL(equiv); MS_EXCEPTION_IF_NULL(equiv);
MS_LOG(DEBUG) << "-----[in Replace]"; MS_LOG(DEBUG) << "-----[in Replace]";
BaseRef ref = GetVar(pattern); BaseRef ref = GetVar(pattern);
@ -304,7 +335,7 @@ BaseRef PatternEngine::Replace(const BaseRef& pattern, const EquivPtr& equiv) co
// w is var // w is var
if (utils::isa<VarPtr>(ref)) { if (utils::isa<VarPtr>(ref)) {
const VarPtr& var = utils::cast<VarPtr>(ref); const VarPtr &var = utils::cast<VarPtr>(ref);
auto iter = equiv->find(var); auto iter = equiv->find(var);
if (iter != equiv->end()) { if (iter != equiv->end()) {
out = iter->second; out = iter->second;
@ -316,7 +347,7 @@ BaseRef PatternEngine::Replace(const BaseRef& pattern, const EquivPtr& equiv) co
} }
// visitor to visit the list // visitor to visit the list
std::function<BaseRef(BaseRef)> fn = [&, this, equiv](const BaseRef& u) { return Replace(u, equiv); }; std::function<BaseRef(BaseRef)> fn = [&, this, equiv](const BaseRef &u) { return Replace(u, equiv); };
visitor_->SetFn(fn); visitor_->SetFn(fn);
BaseRef visit_out; BaseRef visit_out;

@ -31,6 +31,7 @@
#include <map> #include <map>
#include <stdexcept> #include <stdexcept>
#include <list> #include <list>
#include <utility>
#include "pre_activate/common/visit.h" #include "pre_activate/common/visit.h"
#include "ir/base.h" #include "ir/base.h"
@ -44,16 +45,19 @@ using CondVarPtr = std::shared_ptr<CondVar>;
using SVarPtr = std::shared_ptr<SeqVar>; using SVarPtr = std::shared_ptr<SeqVar>;
const int kInvalidVarIndex = -2; const int kInvalidVarIndex = -2;
using ConditionFunc = std::function<bool(const BaseRef&)>; using ConditionFunc = std::function<bool(const BaseRef &)>;
// Base wildcard variable which could match any anf node. // Base wildcard variable which could match any anf node.
class Var : public Base { class Var : public Base {
friend class VarHasher; friend class VarHasher;
public: public:
explicit Var(const std::string& tag = "") : tag_(tag) { EnsureTag(); } explicit Var(std::string tag = "") : tag_(std::move(tag)), primitive_(nullptr) { EnsureTag(); }
Var(const Var& other) : Base(other), tag_(other.tag_) {} explicit Var(const PrimitivePtr &primitive, std::string tag = "") : tag_(std::move(tag)), primitive_(primitive) {
virtual Var& operator=(const Var& other) { EnsureTag();
}
Var(const Var &other) : Base(other), tag_(other.tag_) {}
virtual Var &operator=(const Var &other) {
if (&other == this) { if (&other == this) {
return *this; return *this;
} }
@ -63,12 +67,13 @@ class Var : public Base {
~Var() override = default; ~Var() override = default;
MS_DECLARE_PARENT(Var, Base); MS_DECLARE_PARENT(Var, Base);
virtual bool matches(const BaseRef&) { return true; } virtual bool matches(const BaseRef &) { return true; }
virtual bool operator==(const Var& other) const { return tag_ == other.tag_; } virtual bool operator==(const Var &other) const { return tag_ == other.tag_; }
bool operator!=(const Var& other) const { return !(&other == this); } bool operator!=(const Var &other) const { return !(&other == this); }
std::string tag() const { return tag_; } std::string tag() const { return tag_; }
PrimitivePtr primitive() const { return primitive_; }
std::string ToString() const override { std::string ToString() const override {
std::ostringstream buffer; std::ostringstream buffer;
buffer << "Var(" << tag_ << ")"; buffer << "Var(" << tag_ << ")";
@ -80,12 +85,13 @@ class Var : public Base {
void EnsureTag(); void EnsureTag();
std::string tag_; std::string tag_;
PrimitivePtr primitive_;
}; };
// VarNode means variable node, a subclass of AnfNode // VarNode means variable node, a subclass of AnfNode
class VarNode : public AnfNode { class VarNode : public AnfNode {
public: public:
VarNode(const VarPtr& value, const FuncGraphPtr& func_graph) : AnfNode(func_graph), var_(value) {} VarNode(const VarPtr &value, const FuncGraphPtr &func_graph) : AnfNode(func_graph), var_(value) {}
~VarNode() override = default; ~VarNode() override = default;
MS_DECLARE_PARENT(VarNode, AnfNode); MS_DECLARE_PARENT(VarNode, AnfNode);
@ -95,16 +101,16 @@ using VarNodePtr = std::shared_ptr<VarNode>;
class VarHasher { class VarHasher {
public: public:
std::size_t operator()(const Var& var) const { return var.hash(); } std::size_t operator()(const Var &var) const { return var.hash(); }
}; };
// Condition Var, match an anf node when condition function return true. // Condition Var, match an anf node when condition function return true.
class CondVar : public Var { class CondVar : public Var {
public: public:
explicit CondVar(const ConditionFunc& cond) : cond_fn_(cond) {} explicit CondVar(const ConditionFunc &cond) : cond_fn_(cond) {}
~CondVar() override = default; ~CondVar() override = default;
MS_DECLARE_PARENT(CondVar, Var); MS_DECLARE_PARENT(CondVar, Var);
bool matches(const BaseRef& value) override { bool matches(const BaseRef &value) override {
MS_LOG(DEBUG) << "CondVarPtr match: " + value.ToString(); MS_LOG(DEBUG) << "CondVarPtr match: " + value.ToString();
if (utils::isa<Var>(value)) { if (utils::isa<Var>(value)) {
return false; return false;
@ -124,55 +130,60 @@ class SeqVar : public Var {
~SeqVar() override = default; ~SeqVar() override = default;
MS_DECLARE_PARENT(SeqVar, Var); MS_DECLARE_PARENT(SeqVar, Var);
explicit SeqVar(const VarPtr subvar) : subvar_(nullptr) { subvar_ = subvar; } explicit SeqVar(const VarPtr subvar) : subvar_(nullptr) { subvar_ = subvar; }
bool matches(const BaseRef& value) override { bool matches(const BaseRef &value) override {
// match Seq. // match Seq.
if (utils::isa<Seq>(value)) { if (utils::isa<Seq>(value)) {
const Seq& seq = utils::cast<Seq>(value); const Seq &seq = utils::cast<Seq>(value);
return std::all_of(seq.begin(), seq.end(), [this](const BaseRef& v) { return std::all_of(seq.begin(), seq.end(), [this](const BaseRef &v) {
auto eq = subvar_->matches(v); auto eq = subvar_->matches(v);
return eq; return eq;
}); });
} }
return false; return false;
} }
bool operator==(const SeqVar& other) const { return *subvar_ == *other.subvar_; } bool operator==(const SeqVar &other) const { return *subvar_ == *other.subvar_; }
std::string ToString() const override; std::string ToString() const override;
private: private:
VarPtr subvar_; VarPtr subvar_;
}; };
bool operator==(const VarPtr& lhs, const VarPtr& rhs); bool operator==(const VarPtr &lhs, const VarPtr &rhs);
inline bool operator!=(const VarPtr& lhs, const VarPtr& rhs) { return !(lhs == rhs); } inline bool operator!=(const VarPtr &lhs, const VarPtr &rhs) { return !(lhs == rhs); }
std::ostream& operator<<(std::ostream& os, const VarPtr& var); std::ostream &operator<<(std::ostream &os, const VarPtr &var);
using Equiv = std::map<VarPtr, BaseRef>; using Equiv = std::map<VarPtr, BaseRef>;
using EquivPtr = std::shared_ptr<Equiv>; using EquivPtr = std::shared_ptr<Equiv>;
using PrimitiveVarMap = std::unordered_map<PrimitivePtr, VarPtr>;
using PrimitiveVarMapPtr = std::shared_ptr<PrimitiveVarMap>;
inline bool DefaultTypeEq(const BaseRef& x, const BaseRef& y) { return x.type() == y.type(); } inline bool DefaultTypeEq(const BaseRef &x, const BaseRef &y) { return x.type() == y.type(); }
class PatternEngine { class PatternEngine {
public: public:
PatternEngine(const std::shared_ptr<Visitor>& visitor, const std::function<bool(const BaseRef&, const BaseRef&)>& eq, PatternEngine(const std::shared_ptr<Visitor> &visitor,
const std::function<bool(const BaseRef&, const BaseRef&)>& type_eq = DefaultTypeEq) const std::function<bool(const BaseRef &, const BaseRef &)> &eq,
const std::function<bool(const BaseRef &, const BaseRef &)> &type_eq = DefaultTypeEq)
: visitor_(visitor), eq_(eq), type_eq_(type_eq) {} : visitor_(visitor), eq_(eq), type_eq_(type_eq) {}
~PatternEngine() = default; ~PatternEngine() = default;
EquivPtr Match(const BaseRef& pattern, const BaseRef& expr, EquivPtr equiv) const; EquivPtr Match(const BaseRef &pattern, const BaseRef &expr, const PrimitiveVarMap &primitive_vars,
EquivPtr equiv) const;
// Replace pattern with equivalent // Replace pattern with equivalent
BaseRef Replace(const BaseRef& pattern, const EquivPtr& equiv) const; BaseRef Replace(const BaseRef &pattern, const EquivPtr &equiv) const;
private: private:
EquivPtr AlignSVar(const VectorRef& values_pattern, const VectorRef& values_expr, EquivPtr equiv) const; EquivPtr AlignSVar(const VectorRef &values_pattern, const VectorRef &values_expr,
bool ToVector(const BaseRef& pattern, const BaseRef& expr, VectorRef* const values_pattern, const PrimitiveVarMap &primitive_vars, EquivPtr equiv) const;
VectorRef* const values_expr) const; bool ToVector(const BaseRef &pattern, const BaseRef &expr, VectorRef *const values_pattern,
bool ToVector(const VectorRef& pattern_ref, const VectorRef& expr_ref, VectorRef* const values_pattern, VectorRef *const values_expr) const;
VectorRef* const values_expr) const; bool ToVector(const VectorRef &pattern_ref, const VectorRef &expr_ref, VectorRef *const values_pattern,
VectorRef *const values_expr) const;
std::shared_ptr<Visitor> visitor_; std::shared_ptr<Visitor> visitor_;
std::function<bool(const BaseRef&, const BaseRef&)> eq_; std::function<bool(const BaseRef &, const BaseRef &)> eq_;
std::function<bool(const BaseRef&, const BaseRef&)> type_eq_; std::function<bool(const BaseRef &, const BaseRef &)> type_eq_;
}; };
} // namespace mindspore } // namespace mindspore
namespace std { namespace std {

@ -40,6 +40,7 @@ class TestMatchEngine : public UT::Common {
public: public:
PatternEngine TU; PatternEngine TU;
EquivPtr equiv_null; EquivPtr equiv_null;
PrimitiveVarMap primitive_vars_null;
}; };
TEST_F(TestMatchEngine, Var) { TEST_F(TestMatchEngine, Var) {
@ -106,30 +107,30 @@ TEST_F(TestMatchEngine, MatchRaw_Var) {
// common // common
equiv_null->clear(); equiv_null->clear();
d = TU.Match(v1, 1, equiv_null); d = TU.Match(v1, 1, primitive_vars_null, equiv_null);
ASSERT_EQ((*d)[v1], 1); ASSERT_EQ((*d)[v1], 1);
equiv_null->clear(); equiv_null->clear();
(*equiv_null)[v1] = v2; (*equiv_null)[v1] = v2;
d = TU.Match(v1, 1, equiv_null); d = TU.Match(v1, 1, primitive_vars_null, equiv_null);
ASSERT_EQ(d->count(v2), std::size_t(1)); ASSERT_EQ(d->count(v2), std::size_t(1));
ASSERT_EQ((*d)[v2], 1); ASSERT_EQ((*d)[v2], 1);
equiv_null->clear(); equiv_null->clear();
(*equiv_null)[v1] = v2; (*equiv_null)[v1] = v2;
(*equiv_null)[v3] = 1; (*equiv_null)[v3] = 1;
d = TU.Match(v1, 1, equiv_null); d = TU.Match(v1, 1, primitive_vars_null, equiv_null);
ASSERT_EQ(d->count(v2), std::size_t(1)); ASSERT_EQ(d->count(v2), std::size_t(1));
ASSERT_EQ((*d)[v2], 1); ASSERT_EQ((*d)[v2], 1);
equiv_null->clear(); equiv_null->clear();
d = TU.Match(VectorRef({v1}), VectorRef({1}), equiv_null); d = TU.Match(VectorRef({v1}), VectorRef({1}), primitive_vars_null, equiv_null);
ASSERT_EQ(d->size(), std::size_t(1)); ASSERT_EQ(d->size(), std::size_t(1));
ASSERT_EQ(d->count(v1), std::size_t(1)); ASSERT_EQ(d->count(v1), std::size_t(1));
ASSERT_EQ((*d)[v1], 1); ASSERT_EQ((*d)[v1], 1);
equiv_null->clear(); equiv_null->clear();
ASSERT_EQ(TU.Match(1, 2, equiv_null), nullptr); ASSERT_EQ(TU.Match(1, 2, primitive_vars_null, equiv_null), nullptr);
} }
TEST_F(TestMatchEngine, MatchRaw_SVar) { TEST_F(TestMatchEngine, MatchRaw_SVar) {
@ -139,22 +140,22 @@ TEST_F(TestMatchEngine, MatchRaw_SVar) {
EquivPtr d; EquivPtr d;
equiv_null->clear(); equiv_null->clear();
d = TU.Match(VectorRef({sv1}), VectorRef({1, 2}), equiv_null); d = TU.Match(VectorRef({sv1}), VectorRef({1, 2}), primitive_vars_null, equiv_null);
ASSERT_EQ(d->size(), std::size_t(1)); ASSERT_EQ(d->size(), std::size_t(1));
ASSERT_EQ(d->count(sv1), std::size_t(1)); ASSERT_EQ(d->count(sv1), std::size_t(1));
ASSERT_EQ(utils::cast<Seq>((*d)[sv1]), Seq({1, 2})); ASSERT_EQ(utils::cast<Seq>((*d)[sv1]), Seq({1, 2}));
equiv_null->clear(); equiv_null->clear();
d = TU.Match(VectorRef({v1, sv1}), VectorRef({1, 2}), equiv_null); d = TU.Match(VectorRef({v1, sv1}), VectorRef({1, 2}), primitive_vars_null, equiv_null);
ASSERT_EQ(d->size(), std::size_t(2)); ASSERT_EQ(d->size(), std::size_t(2));
ASSERT_EQ(utils::cast<Seq>((*d)[sv1]), Seq({2})); ASSERT_EQ(utils::cast<Seq>((*d)[sv1]), Seq({2}));
equiv_null->clear(); equiv_null->clear();
ASSERT_EQ(TU.Match(VectorRef({sv1, sv2}), VectorRef({1, 2}), equiv_null), nullptr); ASSERT_EQ(TU.Match(VectorRef({sv1, sv2}), VectorRef({1, 2}), primitive_vars_null, equiv_null), nullptr);
equiv_null->clear(); equiv_null->clear();
(*equiv_null)[sv1] = std::make_shared<Seq>(PatternListType{1, 2}); (*equiv_null)[sv1] = std::make_shared<Seq>(PatternListType{1, 2});
d = TU.Match(VectorRef({v1, sv1}), VectorRef({1, 1, 2}), equiv_null); d = TU.Match(VectorRef({v1, sv1}), VectorRef({1, 1, 2}), primitive_vars_null, equiv_null);
ASSERT_EQ(d->size(), std::size_t(2)); ASSERT_EQ(d->size(), std::size_t(2));
ASSERT_EQ((*d)[v1], 1); ASSERT_EQ((*d)[v1], 1);
} }
@ -167,13 +168,13 @@ TEST_F(TestMatchEngine, Match) {
EquivPtr d; EquivPtr d;
equiv_null->clear(); equiv_null->clear();
d = TU.Match(VectorRef({v1, v1, v2}), VectorRef({1, 1, 2}), equiv_null); d = TU.Match(VectorRef({v1, v1, v2}), VectorRef({1, 1, 2}), primitive_vars_null, equiv_null);
ASSERT_EQ(d->size(), std::size_t(2)); ASSERT_EQ(d->size(), std::size_t(2));
ASSERT_EQ((*d)[v1], 1); ASSERT_EQ((*d)[v1], 1);
ASSERT_EQ((*d)[v2], 2); ASSERT_EQ((*d)[v2], 2);
equiv_null->clear(); equiv_null->clear();
d = TU.Match(static_cast<int>(1), static_cast<float>(1), equiv_null); d = TU.Match(static_cast<int>(1), static_cast<float>(1), primitive_vars_null, equiv_null);
ASSERT_EQ(d, nullptr); ASSERT_EQ(d, nullptr);
} }
@ -197,18 +198,19 @@ TEST_F(TestMatchEngine, Match_CondVar) {
EquivPtr d; EquivPtr d;
equiv_null->clear(); equiv_null->clear();
d = TU.Match(VectorRef({vf, vn}), VectorRef({static_cast<float>(1.0), -1}), equiv_null); d = TU.Match(VectorRef({vf, vn}), VectorRef({static_cast<float>(1.0), -1}), primitive_vars_null, equiv_null);
ASSERT_GE(d->size(), std::size_t(0)); ASSERT_GE(d->size(), std::size_t(0));
auto vfn = (*d)[vf]; auto vfn = (*d)[vf];
ASSERT_EQ((*d)[vf], static_cast<float>(1.0)); ASSERT_EQ((*d)[vf], static_cast<float>(1.0));
ASSERT_EQ((*d)[vn], -1); ASSERT_EQ((*d)[vn], -1);
equiv_null->clear(); equiv_null->clear();
d = TU.Match(VectorRef({vf, vn}), VectorRef({1, static_cast<float>(-1.0)}), equiv_null); d = TU.Match(VectorRef({vf, vn}), VectorRef({1, static_cast<float>(-1.0)}), primitive_vars_null, equiv_null);
ASSERT_EQ(d, nullptr); ASSERT_EQ(d, nullptr);
equiv_null->clear(); equiv_null->clear();
d = TU.Match(VectorRef({vf, vn}), VectorRef({static_cast<float>(1.0), static_cast<int>(1)}), equiv_null); d = TU.Match(VectorRef({vf, vn}), VectorRef({static_cast<float>(1.0), static_cast<int>(1)}), primitive_vars_null,
equiv_null);
ASSERT_EQ(d, nullptr); ASSERT_EQ(d, nullptr);
} }

Loading…
Cancel
Save