|
|
@ -615,6 +615,16 @@ static int BuildFusionV2(Graph* graph, const std::string& name_scope,
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(transpose2_qkv_out, transpose2_qkv_out,
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(transpose2_qkv_out, transpose2_qkv_out,
|
|
|
|
multihead_pattern);
|
|
|
|
multihead_pattern);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// If weights or biases in qkv's fc are shared by multiple multihead_matmul
|
|
|
|
|
|
|
|
// patterns, we do not support this kind of fusion, this pass will not take
|
|
|
|
|
|
|
|
// effect.
|
|
|
|
|
|
|
|
bool is_fc_params_shared =
|
|
|
|
|
|
|
|
mul0_w->outputs.size() > 1 || mul1_w->outputs.size() > 1 ||
|
|
|
|
|
|
|
|
mul2_w->outputs.size() > 1 || eltadd0_b->outputs.size() > 1 ||
|
|
|
|
|
|
|
|
eltadd1_b->outputs.size() > 1 || eltadd2_b->outputs.size() > 1;
|
|
|
|
|
|
|
|
if (is_fc_params_shared) {
|
|
|
|
|
|
|
|
return;
|
|
|
|
|
|
|
|
}
|
|
|
|
fuse_creater(input0, mul0, mul1, mul2, mul0_out, mul1_out, mul2_out, mul0_w,
|
|
|
|
fuse_creater(input0, mul0, mul1, mul2, mul0_out, mul1_out, mul2_out, mul0_w,
|
|
|
|
mul1_w, mul2_w, eltadd0_b, eltadd1_b, eltadd2_b, eltadd_qk_b,
|
|
|
|
mul1_w, mul2_w, eltadd0_b, eltadd1_b, eltadd2_b, eltadd_qk_b,
|
|
|
|
reshape2_0, reshape2_qkv_out, scale, scale_out);
|
|
|
|
reshape2_0, reshape2_qkv_out, scale, scale_out);
|
|
|
|