|
|
|
|
@ -88,6 +88,11 @@ bool TransOpInsertPass::CanFusion(schema::MetaGraphT *graph, const std::unique_p
|
|
|
|
|
}
|
|
|
|
|
auto total_node_count = input_node_indexes.size() + output_node_indexes.size();
|
|
|
|
|
size_t half_count = total_node_count / 2;
|
|
|
|
|
if (GetCNodeTType(*node) == schema::PrimitiveType_Activation) {
|
|
|
|
|
if (node->primitive->value.AsActivation()->type == schema::ActivationType_LEAKY_RELU) {
|
|
|
|
|
return has_trans_count >= half_count;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
if (total_node_count % 2 == 0) {
|
|
|
|
|
can_fusion = has_trans_count > half_count;
|
|
|
|
|
} else {
|
|
|
|
|
|