|
|
|
@ -27,82 +27,12 @@
|
|
|
|
|
|
|
|
|
|
namespace mindspore {
|
|
|
|
|
namespace opt {
|
|
|
|
|
namespace {
|
|
|
|
|
std::tuple<CNodePtr, CNodePtr, AnfNodePtr> GetSharedNodesByPattern(const AnfNodePtr &node) {
|
|
|
|
|
auto add3_cnode = CheckAnfNodeIfCNodeAndInputSize(node, kAddInputNum);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(add3_cnode);
|
|
|
|
|
auto real_div2_cnode = CheckAnfNodeIfCNodeAndInputSize(add3_cnode->input(1), kMulInputNum);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(real_div2_cnode);
|
|
|
|
|
auto real_div0_cnode = CheckAnfNodeIfCNodeAndInputSize(real_div2_cnode->input(1), kRealDivInputNum);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(real_div0_cnode);
|
|
|
|
|
auto sqrt0_cnode = CheckAnfNodeIfCNodeAndInputSize(real_div2_cnode->input(2), kSqrtInputNum);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(sqrt0_cnode);
|
|
|
|
|
auto add2_cnode = CheckAnfNodeIfCNodeAndInputSize(sqrt0_cnode->input(1), kAddInputNum);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(add2_cnode);
|
|
|
|
|
auto real_div1_cnode = CheckAnfNodeIfCNodeAndInputSize(add2_cnode->input(1), kRealDivInputNum);
|
|
|
|
|
auto constant_add2_y = add2_cnode->input(2);
|
|
|
|
|
|
|
|
|
|
return std::make_tuple(real_div0_cnode, real_div1_cnode, constant_add2_y);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool MatchRealDiv4(const AnfNodePtr &real_div4, const AnfNodePtr &real_div1, const AnfNodePtr &constant_add2_y) {
|
|
|
|
|
if (real_div4 == nullptr || !real_div4->isa<CNode>()) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
auto real_div4_cnode = real_div4->cast<CNodePtr>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(real_div4_cnode);
|
|
|
|
|
if (AnfAlgo::GetCNodeName(real_div4_cnode) != kRealDivOpName || real_div4_cnode->inputs().size() < kRealDivInputNum) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
CNodePtr add4_cnode = nullptr;
|
|
|
|
|
if (!CheckIfCNodeAndInputSize(real_div4_cnode->input(2), kAddInputNum, &add4_cnode) ||
|
|
|
|
|
AnfAlgo::GetCNodeName(add4_cnode) != prim::kPrimTensorAdd->name()) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
CNodePtr sqrt1_cnode = nullptr;
|
|
|
|
|
if (!CheckIfCNodeAndInputSize(add4_cnode->input(1), kSqrtInputNum, &sqrt1_cnode) ||
|
|
|
|
|
AnfAlgo::GetCNodeName(sqrt1_cnode) != kSqrtOpName) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
MS_EXCEPTION_IF_NULL(add4_cnode->input(2));
|
|
|
|
|
MS_EXCEPTION_IF_NULL(constant_add2_y);
|
|
|
|
|
return sqrt1_cnode->input(1) == real_div1 && *(add4_cnode->input(2)) == *constant_add2_y;
|
|
|
|
|
}
|
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
|
|
// Builds the match pattern from the positional pattern variables in input_varptr_ (written v[i]):
//
//   add0   = TensorAdd(Mul(v[7], v[4]),  Mul(v[8],  v[3]))
//   add1   = TensorAdd(Mul(v[9], v[1]),  Mul(v[10], v[0]))
//   result = TensorAdd(Mul(RealDiv(add0, v[5]),
//                          Rsqrt(TensorAdd(RealDiv(add1, v[2]), v[12]))),
//                      Mul(v[11], v[6]))
//
// RealDiv / Rsqrt are fresh Primitive objects named kRealDivOpName / kRsqrtOpName; Mul and
// TensorAdd are prim::kPrimMul and prim::kPrimTensorAdd.
const BaseRef LambNextMVRule::DefinePattern() const {
  const auto prim_rsqrt = std::make_shared<Primitive>(kRsqrtOpName);
  MS_EXCEPTION_IF_NULL(prim_rsqrt);
  const auto prim_deal_div = std::make_shared<Primitive>(kRealDivOpName);
  MS_EXCEPTION_IF_NULL(prim_deal_div);

  // First quotient: (v[7] * v[4] + v[8] * v[3]) / v[5]
  VectorRef lhs_prod0({prim::kPrimMul, input_varptr_[7], input_varptr_[4]});
  VectorRef lhs_prod1({prim::kPrimMul, input_varptr_[8], input_varptr_[3]});
  VectorRef lhs_sum({prim::kPrimTensorAdd, lhs_prod0, lhs_prod1});
  VectorRef lhs_quotient({prim_deal_div, lhs_sum, input_varptr_[5]});

  // Second quotient: (v[9] * v[1] + v[10] * v[0]) / v[2]
  VectorRef rhs_prod0({prim::kPrimMul, input_varptr_[9], input_varptr_[1]});
  VectorRef rhs_prod1({prim::kPrimMul, input_varptr_[10], input_varptr_[0]});
  VectorRef rhs_sum({prim::kPrimTensorAdd, rhs_prod0, rhs_prod1});
  VectorRef rhs_quotient({prim_deal_div, rhs_sum, input_varptr_[2]});

  // Rsqrt(second quotient + v[12]), multiplied onto the first quotient.
  VectorRef shifted({prim::kPrimTensorAdd, rhs_quotient, input_varptr_[12]});
  VectorRef inv_root({prim_rsqrt, shifted});
  VectorRef scaled({prim::kPrimMul, lhs_quotient, inv_root});

  // Trailing product added at the root of the pattern.
  VectorRef tail_prod({prim::kPrimMul, input_varptr_[11], input_varptr_[6]});

  return VectorRef({prim::kPrimTensorAdd, scaled, tail_prod});
}
|
|
|
|
|
|
|
|
|
|
bool LambNextMVRule::IsRuleMatched(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
|
|
|
|
|
bool LambNextMVRule::IsRuleMatched(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &equiv,
|
|
|
|
|
std::vector<AnfNodePtr> *old_pattern_outputs) const {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(func_graph);
|
|
|
|
|
CNodePtr real_div0 = nullptr;
|
|
|
|
|
CNodePtr real_div1 = nullptr;
|
|
|
|
|
AnfNodePtr constant_add2_y = nullptr;
|
|
|
|
|
std::tie(real_div0, real_div1, constant_add2_y) = GetSharedNodesByPattern(node);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(equiv);
|
|
|
|
|
auto real_div0 = GetAnfNodeByVar(equiv, real_div0_var_);
|
|
|
|
|
auto real_div2 = GetAnfNodeByVar(equiv, real_div2_var_);
|
|
|
|
|
|
|
|
|
|
auto manager = func_graph->manager();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(manager);
|
|
|
|
@ -112,19 +42,17 @@ bool LambNextMVRule::IsRuleMatched(const FuncGraphPtr &func_graph, const AnfNode
|
|
|
|
|
}
|
|
|
|
|
AnfNodeIndexSet real_div0_outputs = users[real_div0];
|
|
|
|
|
auto iter = std::find_if(real_div0_outputs.begin(), real_div0_outputs.end(),
|
|
|
|
|
[&node, &real_div1, &constant_add2_y](const std::pair<AnfNodePtr, int> &node_index) {
|
|
|
|
|
return node_index.first != node && node_index.second == 1 &&
|
|
|
|
|
MatchRealDiv4(node_index.first, real_div1, constant_add2_y);
|
|
|
|
|
[&real_div2, &equiv, this](const std::pair<AnfNodePtr, int> &node_index) {
|
|
|
|
|
return node_index.first != real_div2 && node_index.second == 1 &&
|
|
|
|
|
MatchAnotherPattern(node_index.first, equiv);
|
|
|
|
|
});
|
|
|
|
|
if (iter == real_div0_outputs.end()) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto add0_cnode = CheckAnfNodeIfCNodeAndInputSize(real_div0->input(1), kAddInputNum);
|
|
|
|
|
auto add1_cnode = CheckAnfNodeIfCNodeAndInputSize(real_div1->input(1), kAddInputNum);
|
|
|
|
|
(*old_pattern_outputs).push_back(node);
|
|
|
|
|
(*old_pattern_outputs).push_back(add0_cnode);
|
|
|
|
|
(*old_pattern_outputs).push_back(add1_cnode);
|
|
|
|
|
(*old_pattern_outputs).push_back(GetAnfNodeByVar(equiv, add0_var_));
|
|
|
|
|
(*old_pattern_outputs).push_back(GetAnfNodeByVar(equiv, add1_var_));
|
|
|
|
|
(*old_pattern_outputs).push_back(iter->first);
|
|
|
|
|
|
|
|
|
|
return true;
|
|
|
|
@ -136,8 +64,19 @@ AnfNodePtr LambNextMVRule::CreateLambNextMVNode(const FuncGraphPtr &func_graph,
|
|
|
|
|
MS_EXCEPTION_IF_NULL(func_graph);
|
|
|
|
|
auto prim = std::make_shared<Primitive>(kLambNextMVOpName);
|
|
|
|
|
std::vector<AnfNodePtr> lamb_next_mv_rule_inputs = {NewValueNode(prim)};
|
|
|
|
|
(void)std::transform(input_varptr_.begin(), input_varptr_.end(), std::back_inserter(lamb_next_mv_rule_inputs),
|
|
|
|
|
[&equiv](const VarPtr &in) { return utils::cast<AnfNodePtr>((*equiv)[in]); });
|
|
|
|
|
lamb_next_mv_rule_inputs.push_back(utils::cast<AnfNodePtr>((*equiv)[input0_]));
|
|
|
|
|
lamb_next_mv_rule_inputs.push_back(utils::cast<AnfNodePtr>((*equiv)[input1_]));
|
|
|
|
|
lamb_next_mv_rule_inputs.push_back(utils::cast<AnfNodePtr>((*equiv)[input2_]));
|
|
|
|
|
lamb_next_mv_rule_inputs.push_back(utils::cast<AnfNodePtr>((*equiv)[input3_]));
|
|
|
|
|
lamb_next_mv_rule_inputs.push_back(utils::cast<AnfNodePtr>((*equiv)[input4_]));
|
|
|
|
|
lamb_next_mv_rule_inputs.push_back(utils::cast<AnfNodePtr>((*equiv)[input5_]));
|
|
|
|
|
lamb_next_mv_rule_inputs.push_back(utils::cast<AnfNodePtr>((*equiv)[input6_]));
|
|
|
|
|
lamb_next_mv_rule_inputs.push_back(utils::cast<AnfNodePtr>((*equiv)[mul0_x_]));
|
|
|
|
|
lamb_next_mv_rule_inputs.push_back(utils::cast<AnfNodePtr>((*equiv)[mul1_sub_]));
|
|
|
|
|
lamb_next_mv_rule_inputs.push_back(utils::cast<AnfNodePtr>((*equiv)[mul2_x_]));
|
|
|
|
|
lamb_next_mv_rule_inputs.push_back(utils::cast<AnfNodePtr>((*equiv)[mul3_sub1_]));
|
|
|
|
|
lamb_next_mv_rule_inputs.push_back(utils::cast<AnfNodePtr>((*equiv)[mul4_x_]));
|
|
|
|
|
lamb_next_mv_rule_inputs.push_back(utils::cast<AnfNodePtr>((*equiv)[add2_y_]));
|
|
|
|
|
auto lamb_next_mv_rule = func_graph->NewCNode(lamb_next_mv_rule_inputs);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(lamb_next_mv_rule);
|
|
|
|
|
|
|
|
|
@ -162,14 +101,60 @@ AnfNodePtr LambNextMVRule::CreateLambNextMVNode(const FuncGraphPtr &func_graph,
|
|
|
|
|
return lamb_next_mv_rule_outputs[0];
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool LambNextMVRule::IsShareNodes(const EquivPtr &equiv1, const EquivPtr &equiv2) const {
|
|
|
|
|
return IsSameNode(equiv1, equiv2, real_div0_var_) && IsSameNode(equiv1, equiv2, real_div1_var_) &&
|
|
|
|
|
IsSameNode(equiv1, equiv2, add2_y_);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Runs the extra rule check on a matched `node` and, on success, builds the fused replacement.
//
// @param func_graph  graph that owns `node`; forwarded to IsRuleMatched and CreateLambNextMVNode.
// @param node        root node that matched the pattern.
// @param equiv       variable bindings produced by the match.
// @return nullptr when IsRuleMatched rejects the match, otherwise whatever
//         CreateLambNextMVNode returns for the collected old pattern outputs.
const AnfNodePtr LambNextMVRule::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                         const EquivPtr &equiv) const {
  std::vector<AnfNodePtr> old_pattern_outputs;
  // Single guard using the equiv-aware IsRuleMatched overload; the stale three-argument call
  // that used to sit in front of it opened a second, never-closed `if` block.
  if (!IsRuleMatched(func_graph, node, equiv, &old_pattern_outputs)) {
    return nullptr;
  }

  return CreateLambNextMVNode(func_graph, old_pattern_outputs, equiv);
}
|
|
|
|
|
|
|
|
|
|
// Builds the primary pattern for this rule variant from the named pattern variables:
//
//   add0   = add0_var_(Mul(mul0_x_, input4_), Mul(mul1_sub_, input3_))
//   add1   = add1_var_(Mul(mul2_x_, input1_), Mul(mul3_sub1_, input0_))
//   result = TensorAdd(real_div2_var_(real_div0_var_(add0, input5_),
//                                     Rsqrt(TensorAdd(real_div1_var_(add1, input2_), add2_y_))),
//                      Mul(mul4_x_, input6_))
//
// Rsqrt is a fresh Primitive named kRsqrtOpName; Mul and TensorAdd are prim::kPrimMul and
// prim::kPrimTensorAdd. The *_var_ heads are pattern variables rather than fixed primitives.
const BaseRef LambNextMVRuleCond4::DefinePattern() const {
  const auto prim_rsqrt = std::make_shared<Primitive>(kRsqrtOpName);
  MS_EXCEPTION_IF_NULL(prim_rsqrt);

  // Branch rooted at real_div0_var_.
  VectorRef lhs_prod0({prim::kPrimMul, mul0_x_, input4_});
  VectorRef lhs_prod1({prim::kPrimMul, mul1_sub_, input3_});
  VectorRef lhs_sum({add0_var_, lhs_prod0, lhs_prod1});
  VectorRef lhs_quotient({real_div0_var_, lhs_sum, input5_});

  // Branch rooted at real_div1_var_.
  VectorRef rhs_prod0({prim::kPrimMul, mul2_x_, input1_});
  VectorRef rhs_prod1({prim::kPrimMul, mul3_sub1_, input0_});
  VectorRef rhs_sum({add1_var_, rhs_prod0, rhs_prod1});
  VectorRef rhs_quotient({real_div1_var_, rhs_sum, input2_});

  // Rsqrt(real_div1 branch + add2_y_), combined with the real_div0 branch under real_div2_var_.
  VectorRef shifted({prim::kPrimTensorAdd, rhs_quotient, add2_y_});
  VectorRef inv_root({prim_rsqrt, shifted});
  VectorRef combined({real_div2_var_, lhs_quotient, inv_root});

  // Trailing product added at the root of the pattern.
  VectorRef tail_prod({prim::kPrimMul, mul4_x_, input6_});

  return VectorRef({prim::kPrimTensorAdd, combined, tail_prod});
}
|
|
|
|
|
|
|
|
|
|
// Builds the secondary pattern that has to coexist with the primary one:
//
//   RealDiv(real_div0_var_(Xs...), TensorAdd(Sqrt(real_div1_var_(Ys...)), add2_y_))
//
// Xs and Ys are SeqVar placeholders standing in for the inputs of the two shared nodes.
// The pattern refers to real_div0_var_, real_div1_var_ and add2_y_, the same variables used
// by the primary pattern. RealDiv / Sqrt are fresh Primitive objects named kRealDivOpName /
// kSqrtOpName.
BaseRef LambNextMVRuleCond4::DefineAnotherPattern() const {
  const auto prim_sqrt = std::make_shared<Primitive>(kSqrtOpName);
  MS_EXCEPTION_IF_NULL(prim_sqrt);
  const auto prim_real_div = std::make_shared<Primitive>(kRealDivOpName);
  MS_EXCEPTION_IF_NULL(prim_real_div);

  VarPtr Xs = std::make_shared<SeqVar>();
  VarPtr Ys = std::make_shared<SeqVar>();
  MS_EXCEPTION_IF_NULL(Xs);
  MS_EXCEPTION_IF_NULL(Ys);

  // Nodes shared with the primary pattern; their inputs are left open via the sequence vars.
  auto shared_div0 = VectorRef({real_div0_var_, Xs});
  auto shared_div1 = VectorRef({real_div1_var_, Ys});

  // Sqrt(shared_div1) + add2_y_ forms the divisor of the outer RealDiv.
  auto root = VectorRef({prim_sqrt, shared_div1});
  auto divisor = VectorRef({prim::kPrimTensorAdd, root, add2_y_});

  auto real_div4 = VectorRef({prim_real_div, shared_div0, divisor});
  return real_div4;
}
|
|
|
|
|
} // namespace opt
|
|
|
|
|
} // namespace mindspore
|
|
|
|
|