|
|
|
@ -51,7 +51,7 @@ PDNode* BuildSquaredMatSubPattern(PDPattern* pattern,
|
|
|
|
|
|
|
|
|
|
auto next_op = [=](Node* x, const std::string& op_type) -> Node* {
|
|
|
|
|
if (!(x && x->IsVar())) {
|
|
|
|
|
return false;
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
for (auto* op : x->outputs) {
|
|
|
|
|
if (op && op->IsOp() && op->Op() && op->Op()->Type() == op_type) {
|
|
|
|
@ -63,7 +63,7 @@ PDNode* BuildSquaredMatSubPattern(PDPattern* pattern,
|
|
|
|
|
|
|
|
|
|
auto get_op_input_var = [=](Node* x, const std::string& arg_name) -> Node* {
|
|
|
|
|
if (!(x && x->IsOp())) {
|
|
|
|
|
return false;
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
for (auto* var : x->inputs) {
|
|
|
|
|
for (auto name : x->Op()->Input(arg_name)) {
|
|
|
|
@ -93,10 +93,10 @@ PDNode* BuildSquaredMatSubPattern(PDPattern* pattern,
|
|
|
|
|
if (!next_is_matmul_from_arg) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
auto* sub_x = squared_x->outputs[0]->outputs[0];
|
|
|
|
|
return var_is_op_input(sub_x, "elementwise_sub", "X") &&
|
|
|
|
|
sub_x->outputs[0]->outputs.size() == 1 &&
|
|
|
|
|
var_is_op_input(sub_x->outputs[0]->outputs[0], "elementwise_mul");
|
|
|
|
|
auto* sub_y_in = squared_x->outputs[0]->outputs[0];
|
|
|
|
|
return var_is_op_input(sub_y_in, "elementwise_sub", "Y") &&
|
|
|
|
|
sub_y_in->outputs[0]->outputs.size() == 1 &&
|
|
|
|
|
var_is_op_input(sub_y_in->outputs[0]->outputs[0], "elementwise_mul");
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
auto is_fusion_first_mul_out = [=](Node* x) -> bool {
|
|
|
|
@ -120,10 +120,10 @@ PDNode* BuildSquaredMatSubPattern(PDPattern* pattern,
|
|
|
|
|
if (!next_is_square) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
auto* sub_y = x->outputs[0]->outputs[0];
|
|
|
|
|
return var_is_op_input(sub_y, "elementwise_sub", "Y") &&
|
|
|
|
|
sub_y->outputs[0]->outputs.size() == 1 &&
|
|
|
|
|
var_is_op_input(sub_y->outputs[0]->outputs[0], "elementwise_mul");
|
|
|
|
|
auto* sub_x_in = x->outputs[0]->outputs[0];
|
|
|
|
|
return var_is_op_input(sub_x_in, "elementwise_sub", "X") &&
|
|
|
|
|
sub_x_in->outputs[0]->outputs.size() == 1 &&
|
|
|
|
|
var_is_op_input(sub_x_in->outputs[0]->outputs[0], "elementwise_mul");
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
auto* x = pattern->NewNode(
|
|
|
|
@ -219,7 +219,7 @@ PDNode* BuildSquaredMatSubPattern(PDPattern* pattern,
|
|
|
|
|
if (!is_sub_op) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
auto* matmul_sqx_sqy_var = get_op_input_var(x, "X");
|
|
|
|
|
auto* matmul_sqx_sqy_var = get_op_input_var(x, "Y");
|
|
|
|
|
return is_fusion_mat_squared_x_y_op_out(matmul_sqx_sqy_var);
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
@ -280,7 +280,7 @@ PDNode* BuildSquaredMatSubPattern(PDPattern* pattern,
|
|
|
|
|
matmul_squared_x_y_op->LinksFrom({squared_x, squared_y})
|
|
|
|
|
.LinksTo({mat_squared_x_y_op_out});
|
|
|
|
|
square_matmuled_xy_op->LinksFrom({matmuled_xy}).LinksTo({squared_xmuly});
|
|
|
|
|
sub_op->LinksFrom({mat_squared_x_y_op_out, squared_xmuly})
|
|
|
|
|
sub_op->LinksFrom({squared_xmuly, mat_squared_x_y_op_out})
|
|
|
|
|
.LinksTo({sub_op_out});
|
|
|
|
|
constant_op->LinksFrom({}).LinksTo({constant_op_out});
|
|
|
|
|
elementmul_op->LinksFrom({constant_op_out, sub_op_out})
|
|
|
|
|