|
|
|
@ -94,7 +94,6 @@ PDNode* BuildRepeatedFCReluPattern(PDPattern* pattern,
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
auto* relu_op = x->inputs[0];
|
|
|
|
|
// std::cout << "xxxx" << std::endl;
|
|
|
|
|
bool before_is_fc = relu_op->IsOp() && relu_op->inputs.size() == 1 &&
|
|
|
|
|
relu_op->inputs[0]->IsVar() &&
|
|
|
|
|
VarLinksFromOp(relu_op->inputs[0], "fc") &&
|
|
|
|
@ -105,31 +104,18 @@ PDNode* BuildRepeatedFCReluPattern(PDPattern* pattern,
|
|
|
|
|
}
|
|
|
|
|
auto* fc_op = relu_op->inputs[0]->inputs[0];
|
|
|
|
|
bool is_fc = fc_op->IsOp() && fc_op->inputs.size() == 3;
|
|
|
|
|
// std::cout << "*****" << fc_op->inputs.size() << std::endl;
|
|
|
|
|
if (!is_fc) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
for (size_t kkk = 0; kkk < 3; ++kkk) {
|
|
|
|
|
// std::cout << "++++++" << kkk << std::endl;
|
|
|
|
|
if (!fc_op->inputs[kkk]->inputs.empty()) {
|
|
|
|
|
for (auto* fc_i : fc_op->inputs) {
|
|
|
|
|
if (!fc_i->inputs.empty()) {
|
|
|
|
|
if (at_top) {
|
|
|
|
|
return true;
|
|
|
|
|
} else {
|
|
|
|
|
bool res = VarLinksFromOp(fc_op->inputs[kkk], "relu");
|
|
|
|
|
// std::cout << fc_op->inputs[kkk]->Name() << "++++++-----" << kkk <<
|
|
|
|
|
// ":"
|
|
|
|
|
// << res << std::endl;
|
|
|
|
|
return res;
|
|
|
|
|
return VarLinksFromOp(fc_i, "relu");
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
// for (auto* fc_i : fc_op->inputs) {
|
|
|
|
|
// if (!fc_i->inputs.empty()) {
|
|
|
|
|
// std::cout << "++++++" << fc_op->inputs.size()<<std::endl;
|
|
|
|
|
|
|
|
|
|
// return VarLinksFromOp(fc_i, "relu");
|
|
|
|
|
// }
|
|
|
|
|
// }
|
|
|
|
|
return false;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
@ -147,7 +133,6 @@ PDNode* BuildRepeatedFCReluPattern(PDPattern* pattern,
|
|
|
|
|
Node* x, int repeated_times,
|
|
|
|
|
const std::string& act_type = "relu") -> bool {
|
|
|
|
|
for (int i = 0; i < repeated_times; ++i) {
|
|
|
|
|
// std::cout << "----" << i << std::endl;
|
|
|
|
|
if (!var_before_is_fc_act(x, act_type, i == repeated_times - 1)) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
@ -180,17 +165,9 @@ PDNode* BuildRepeatedFCReluPattern(PDPattern* pattern,
|
|
|
|
|
x, std::max(1, num_fc - i - 1), "relu");
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
bool part1 =
|
|
|
|
|
var_next_is_fc_act_repeated_n_times(x, num_fc - i, "relu") &&
|
|
|
|
|
x->inputs.size() > 0;
|
|
|
|
|
if (x->Name() == "fc_0.tmp_1" && x->IsVar() && part1) {
|
|
|
|
|
// std::cout << "testes" << std::endl;
|
|
|
|
|
}
|
|
|
|
|
bool part2 = var_before_is_fc_act_repeated_n_times(x, i, "relu");
|
|
|
|
|
if (x->Name() == "fc_0.tmp_1") {
|
|
|
|
|
// std::cout << "========" << part1 << "," << part2 << std::endl;
|
|
|
|
|
}
|
|
|
|
|
return part1 && part2;
|
|
|
|
|
return var_next_is_fc_act_repeated_n_times(x, num_fc - i, "relu") &&
|
|
|
|
|
x->inputs.size() > 0 &&
|
|
|
|
|
var_before_is_fc_act_repeated_n_times(x, i, "relu");
|
|
|
|
|
}
|
|
|
|
|
},
|
|
|
|
|
name_scope + "/fc_in_" + std::to_string(i));
|
|
|
|
@ -394,7 +371,7 @@ std::unique_ptr<ir::Graph> RepeatedFCReluFusePass::ApplyImpl(
|
|
|
|
|
int fusion_count = 0;
|
|
|
|
|
for (int i = MAX_NUM_FC; i > 1; --i) {
|
|
|
|
|
fusion_count +=
|
|
|
|
|
BuildFusion(graph.get(), name_scope_ + "/" + std::to_string(3), 3);
|
|
|
|
|
BuildFusion(graph.get(), name_scope_ + "/" + std::to_string(i), i);
|
|
|
|
|
}
|
|
|
|
|
AddStatis(fusion_count);
|
|
|
|
|
|
|
|
|
|