@@ -45,13 +45,10 @@ static int BuildFusion(Graph* graph, const std::string& name_scope) {
   // Create pattern.
   MultiHeadMatmulPattern multihead_pattern(pattern, name_scope);

-  PDNode* x =
-      pattern->NewNode(patterns::UniqueKey("X"))->assert_var_not_persistable();
-
-  multihead_pattern(x);
+  multihead_pattern();
   // Create New OpDesc
   auto fuse_creater = [&](
-      Node* x, Node* mul0, Node* mul1, Node* mul2, Node* mul0_out,
+      Node* input0, Node* mul0, Node* mul1, Node* mul2, Node* mul0_out,
       Node* mul1_out, Node* mul2_out, Node* eltadd0_b, Node* eltadd1_b,
       Node* eltadd2_b, Node* eltadd_qk_b, Node* reshape2,
       Node* reshape2_qkv_out, Node* scale, Node* scale_out) {
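Both hunks in BuildFusion follow from one change: MultiHeadMatmulPattern now creates its own entry node, so the caller no longer builds an external `X` PDNode, and the fuse lambda's first parameter is renamed from `x` to `input0`. For orientation, the detector boilerplate surrounding these hunks is the usual Paddle fuse-pass shape; a minimal sketch (context only, not part of the diff, and the exact boilerplate in the file may differ):

GraphPatternDetector gpd;
auto* pattern = gpd.mutable_pattern();
MultiHeadMatmulPattern multihead_pattern(pattern, name_scope);
multihead_pattern();  // the pattern now builds its own entry node (input0)

auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                   Graph* g) {
  // GET_IR_NODE_FROM_SUBGRAPH pulls the matched nodes out of `subgraph`,
  // then fuse_creater(...) rewrites them into one fused op.
};
gpd(graph, handler);  // run the match-and-rewrite over the whole graph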
@@ -115,7 +112,7 @@ static int BuildFusion(Graph* graph, const std::string& name_scope) {
   auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                      Graph* g) {
     // GET_IR_NODE_FROM_SUBGRAPH(dropout_out, dropout_out, multihead_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(layer_norm, layer_norm, multihead_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(input0, input0, multihead_pattern);

     GET_IR_NODE_FROM_SUBGRAPH(mul0, mul0, multihead_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(mul0_out, mul0_out, multihead_pattern);
@@ -185,7 +182,7 @@ static int BuildFusion(Graph* graph, const std::string& name_scope) {
     GET_IR_NODE_FROM_SUBGRAPH(transpose2_qkv_out, transpose2_qkv_out,
                               multihead_pattern);

-    fuse_creater(layer_norm, mul0, mul1, mul2, mul0_out, mul1_out, mul2_out,
+    fuse_creater(input0, mul0, mul1, mul2, mul0_out, mul1_out, mul2_out,
                  eltadd0_b, eltadd1_b, eltadd2_b, eltadd_qk_b, reshape2_0,
                  reshape2_qkv_out, scale, scale_out);

@@ -232,12 +229,9 @@ static int BuildFusion(Graph* graph, const std::string& name_scope) {
   return fusion_count;
 }

-PDNode* MultiHeadMatmulPattern::operator()(paddle::framework::ir::PDNode* x) {
-  // Create shared nodes.
-  auto* layer_norm = pattern->NewNode(layer_norm_repr());
-
-  auto* layer_norm_out_var = pattern->NewNode(layer_norm_out_repr());
-  layer_norm_out_var->assert_is_op_input("mul");
+PDNode* MultiHeadMatmulPattern::operator()() {
+  auto* input0 = pattern->NewNode(input0_repr());
+  input0->assert_is_op_input("mul");

   // First path with scale
   auto* mul0 = pattern->NewNode(mul0_repr())->assert_is_op("mul");
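Net effect of the hunk above: operator() loses its PDNode* parameter. The old pattern was anchored on a layer_norm whose output fed the three muls; the new one starts at input0, asserted only to be a mul input, so the fusion no longer requires a layer_norm producer in front of the Q/K/V projections.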
@@ -390,17 +384,15 @@ PDNode* MultiHeadMatmulPattern::operator()(paddle::framework::ir::PDNode* x) {
   transpose2_2_out_var->AsIntermediate()->assert_is_op_input(
       "matmul");  // link to matmul qkv

   // Link all nodes together
-  layer_norm->LinksFrom({x}).LinksTo({layer_norm_out_var});
   // Q path
-  mul0->LinksFrom({layer_norm_out_var, mul0_w_var}).LinksTo({mul0_out_var});
+  mul0->LinksFrom({input0, mul0_w_var}).LinksTo({mul0_out_var});
   eltadd0->LinksFrom({mul0_out_var, eltadd0_b_var}).LinksTo({eltadd0_out_var});
-
   reshape2_0->LinksFrom({eltadd0_out_var}).LinksTo({reshape2_0_out_var});
   transpose2_0->LinksFrom({reshape2_0_out_var}).LinksTo({transpose2_0_out_var});
   scale->LinksFrom({transpose2_0_out_var}).LinksTo({scale_out_var});
   // K path
-  mul1->LinksFrom({layer_norm_out_var, mul1_w_var}).LinksTo({mul1_out_var});
+  mul1->LinksFrom({input0, mul1_w_var}).LinksTo({mul1_out_var});
   eltadd1->LinksFrom({mul1_out_var, eltadd1_b_var}).LinksTo({eltadd1_out_var});
   reshape2_1->LinksFrom({eltadd1_out_var}).LinksTo({reshape2_1_out_var});
   transpose2_1->LinksFrom({reshape2_1_out_var}).LinksTo({transpose2_1_out_var});
@@ -411,7 +403,7 @@ PDNode* MultiHeadMatmulPattern::operator()(paddle::framework::ir::PDNode* x) {
       .LinksTo({eltadd_qk_out_var});
   softmax_qk->LinksFrom({eltadd_qk_out_var}).LinksTo({softmax_qk_out_var});
   // V path
-  mul2->LinksFrom({layer_norm_out_var, mul2_w_var}).LinksTo({mul2_out_var});
+  mul2->LinksFrom({input0, mul2_w_var}).LinksTo({mul2_out_var});
   eltadd2->LinksFrom({mul2_out_var, eltadd2_b_var}).LinksTo({eltadd2_out_var});
   reshape2_2->LinksFrom({eltadd2_out_var}).LinksTo({reshape2_2_out_var});
   transpose2_2->LinksFrom({reshape2_2_out_var}).LinksTo({transpose2_2_out_var});
@@ -434,13 +426,10 @@ static int BuildFusionV2(Graph* graph, const std::string& name_scope,
   // Create pattern.
   MultiHeadMatmulPattern multihead_pattern(pattern, name_scope);

-  PDNode* x =
-      pattern->NewNode(patterns::UniqueKey("X"))->assert_var_not_persistable();
-
-  multihead_pattern(x);
+  multihead_pattern();
   // Create New OpDesc
   auto fuse_creater = [&](
-      Node* layer_norm_out, Node* mul0, Node* mul1, Node* mul2, Node* mul0_out,
+      Node* input0, Node* mul0, Node* mul1, Node* mul2, Node* mul0_out,
       Node* mul1_out, Node* mul2_out, Node* mul0_w, Node* mul1_w, Node* mul2_w,
       Node* eltadd0_b, Node* eltadd1_b, Node* eltadd2_b, Node* eltadd_qk_b,
       Node* reshape2, Node* reshape2_qkv_out, Node* scale, Node* scale_out) {
@@ -471,29 +460,20 @@ static int BuildFusionV2(Graph* graph, const std::string& name_scope,
         framework::make_ddim({wq_tensor->dims()[0], 3, wq_tensor->dims()[1]});
     auto combined_bias_dims = framework::make_ddim({3, bq_tensor->dims()[0]});

-    // create a new var in scope
-    VarDesc combined_w_desc(
-        patterns::PDNodeName(name_scope, "multi_head_combined_weight"));
-    combined_w_desc.SetShape({wq_tensor->dims()[0], 3, wq_tensor->dims()[1]});
-    combined_w_desc.SetDataType(wq_tensor->type());
-    combined_w_desc.SetLoDLevel(mul0_w->Var()->GetLoDLevel());
-    combined_w_desc.SetPersistable(true);
-
-    // create a new var in scope
-    VarDesc combined_bias_desc(
-        patterns::PDNodeName(name_scope, "multi_head_combined_bias"));
-    combined_bias_desc.SetShape({3, bq_tensor->dims()[0]});
-    combined_bias_desc.SetDataType(bq_tensor->type());
-    combined_bias_desc.SetLoDLevel(eltadd0_b->Var()->GetLoDLevel());
-    combined_bias_desc.SetPersistable(true);
-
-    auto* combined_w_node = graph->CreateVarNode(&combined_w_desc);
-    auto* combined_w_tensor =
-        scope->Var(combined_w_node->Name())->GetMutable<LoDTensor>();
-
-    combined_w_tensor->Resize(combined_w_dims);
-    auto* combined_w_data =
-        combined_w_tensor->mutable_data<float>(platform::CPUPlace());
+    // reuse the mul0_w and eltadd_0_b nodes for the combined nodes.
+    auto* combined_w_desc = mul0_w->Var();
+    combined_w_desc->SetShape({wq_tensor->dims()[0], 3, wq_tensor->dims()[1]});
+    combined_w_desc->SetPersistable(true);
+
+    auto* combined_bias_desc = eltadd0_b->Var();
+    combined_bias_desc->SetShape({3, bq_tensor->dims()[0]});
+    combined_bias_desc->SetPersistable(true);
+
+    framework::LoDTensor tmp_combined_w_tensor;
+    tmp_combined_w_tensor.Resize(combined_w_dims);
+    auto* tmp_combined_w_data =
+        tmp_combined_w_tensor.mutable_data<float>(platform::CPUPlace());
+
     std::vector<float*> w_vec = {wq_data, wk_data, wv_data};
     int dims_h = combined_w_dims[0], dims_w = combined_w_dims[2];
     // Combine the three fc weights together.
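The strategy change in this hunk: instead of allocating fresh combined-weight/bias VarDescs plus new graph var nodes, the pass rewrites mul0_w and eltadd0_b in place, reshaping the weight to {wq_tensor->dims()[0], 3, wq_tensor->dims()[1]} and the bias to {3, bq_tensor->dims()[0]}. The combined data is staged in a stack-local LoDTensor first because wq_data (and later bq_data) still point into the very tensors that will be resized to receive the result; resizing them before the copy would invalidate the source pointers mid-read.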
@@ -502,25 +482,38 @@ static int BuildFusionV2(Graph* graph, const std::string& name_scope,
        for (int k = 0; k < dims_w; k++) {
          int out_index = i * (3 * dims_w) + j * dims_w + k;
          int in_index = i * dims_w + k;
-         combined_w_data[out_index] = w_vec[j][in_index];
+         tmp_combined_w_data[out_index] = w_vec[j][in_index];
        }
      }
    }
-    scope->EraseVars({mul0_w->Name(), mul1_w->Name(), mul2_w->Name()});
-    auto* combined_bias_node = graph->CreateVarNode(&combined_bias_desc);
-    auto* combined_bias_tensor =
-        scope->Var(combined_bias_node->Name())->GetMutable<LoDTensor>();
-
-    combined_bias_tensor->Resize(combined_bias_dims);
-    auto* combined_bias_data =
-        combined_bias_tensor->mutable_data<float>(platform::CPUPlace());
+
+    wq_tensor->Resize(combined_w_dims);
+    auto* new_combined_w_data =
+        wq_tensor->mutable_data<float>(platform::CPUPlace());
+    memcpy(new_combined_w_data, tmp_combined_w_data,
+           sizeof(float) * wq_tensor->numel());
+
+    scope->EraseVars({mul1_w->Name(), mul2_w->Name()});
+
+    framework::LoDTensor tmp_combined_bias_tensor;
+    tmp_combined_bias_tensor.Resize(combined_bias_dims);
+    auto* tmp_combined_bias_data =
+        tmp_combined_bias_tensor.mutable_data<float>(platform::CPUPlace());

     size_t bias_size = bq_tensor->numel();
-    memcpy(combined_bias_data, bq_data, sizeof(float) * bias_size);
-    memcpy(combined_bias_data + bias_size, bk_data, sizeof(float) * bias_size);
-    memcpy(combined_bias_data + 2 * bias_size, bv_data,
+    memcpy(tmp_combined_bias_data, bq_data, sizeof(float) * bias_size);
+    memcpy(tmp_combined_bias_data + bias_size, bk_data,
+           sizeof(float) * bias_size);
+    memcpy(tmp_combined_bias_data + 2 * bias_size, bv_data,
            sizeof(float) * bias_size);
-    scope->EraseVars({eltadd0_b->Name(), eltadd1_b->Name(), eltadd2_b->Name()});
+
+    bq_tensor->Resize(combined_bias_dims);
+    auto* new_combined_bias_data =
+        bq_tensor->mutable_data<float>(platform::CPUPlace());
+    memcpy(new_combined_bias_data, tmp_combined_bias_data,
+           sizeof(float) * bq_tensor->numel());
+
+    scope->EraseVars({eltadd1_b->Name(), eltadd2_b->Name()});

     auto reshape_desc = reshape2->Op();
     int head_number =
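A standalone toy of the packing and staging above (hypothetical sizes, std::vector in place of LoDTensor; only the index arithmetic is taken from the pass): row i of weight j lands at slice [i][j][:] of the [dims_h, 3, dims_w] buffer, and the staging copy is what makes it safe for the Q weight to take the combined shape afterwards.

#include <vector>

int main() {
  const int dims_h = 2, dims_w = 3;  // toy stand-ins for the real weight dims
  std::vector<float> wq(dims_h * dims_w, 0.f), wk(dims_h * dims_w, 1.f),
      wv(dims_h * dims_w, 2.f);
  std::vector<float*> w_vec = {wq.data(), wk.data(), wv.data()};

  // Stage into a temporary, as tmp_combined_w_tensor does above; writing
  // straight into wq while still reading from w_vec[0] would corrupt it.
  std::vector<float> tmp(dims_h * 3 * dims_w);
  for (int i = 0; i < dims_h; i++)
    for (int j = 0; j < 3; j++)
      for (int k = 0; k < dims_w; k++)
        tmp[i * (3 * dims_w) + j * dims_w + k] = w_vec[j][i * dims_w + k];

  // Only now let the Q weight take the combined shape and data
  // (mirrors wq_tensor->Resize + memcpy in the pass).
  wq.assign(tmp.begin(), tmp.end());
  return 0;
}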
@@ -529,9 +522,9 @@ static int BuildFusionV2(Graph* graph, const std::string& name_scope,
     OpDesc multihead_op_desc;
     multihead_op_desc.SetType("multihead_matmul");

-    multihead_op_desc.SetInput("Input", {layer_norm_out->Name()});
-    multihead_op_desc.SetInput("W", {combined_w_node->Name()});
-    multihead_op_desc.SetInput("Bias", {combined_bias_node->Name()});
+    multihead_op_desc.SetInput("Input", {input0->Name()});
+    multihead_op_desc.SetInput("W", {mul0_w->Name()});
+    multihead_op_desc.SetInput("Bias", {eltadd0_b->Name()});
     multihead_op_desc.SetInput("BiasQK", {eltadd_qk_b->Name()});

     multihead_op_desc.SetOutput("Out", {reshape2_qkv_out->Name()});
@@ -540,9 +533,9 @@ static int BuildFusionV2(Graph* graph, const std::string& name_scope,

     auto* multihead = graph->CreateOpNode(&multihead_op_desc);

-    IR_NODE_LINK_TO(layer_norm_out, multihead);
-    IR_NODE_LINK_TO(combined_w_node, multihead);
-    IR_NODE_LINK_TO(combined_bias_node, multihead);
+    IR_NODE_LINK_TO(input0, multihead);
+    IR_NODE_LINK_TO(mul0_w, multihead);
+    IR_NODE_LINK_TO(eltadd0_b, multihead);
     IR_NODE_LINK_TO(eltadd_qk_b, multihead);

     IR_NODE_LINK_TO(multihead, reshape2_qkv_out);
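After the two hunks above, the fused op is wired entirely from pre-existing nodes; sketched (arrows added here for orientation, not in the source):

    input0 ──────┐
    mul0_w ──────┼──▶ multihead_matmul ──▶ reshape2_qkv_out
    eltadd0_b ───┤
    eltadd_qk_b ─┘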
@@ -552,9 +545,7 @@ static int BuildFusionV2(Graph* graph, const std::string& name_scope,
   auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                      Graph* g) {
     // GET_IR_NODE_FROM_SUBGRAPH(dropout_out, dropout_out, multihead_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(layer_norm, layer_norm, multihead_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(layer_norm_out, layer_norm_out,
-                              multihead_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(input0, input0, multihead_pattern);

     GET_IR_NODE_FROM_SUBGRAPH(mul0, mul0, multihead_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(mul0_out, mul0_out, multihead_pattern);
@@ -624,14 +615,13 @@ static int BuildFusionV2(Graph* graph, const std::string& name_scope,
     GET_IR_NODE_FROM_SUBGRAPH(transpose2_qkv_out, transpose2_qkv_out,
                               multihead_pattern);

-    fuse_creater(layer_norm_out, mul0, mul1, mul2, mul0_out, mul1_out, mul2_out,
-                 mul0_w, mul1_w, mul2_w, eltadd0_b, eltadd1_b, eltadd2_b,
-                 eltadd_qk_b, reshape2_0, reshape2_qkv_out, scale, scale_out);
+    fuse_creater(input0, mul0, mul1, mul2, mul0_out, mul1_out, mul2_out, mul0_w,
+                 mul1_w, mul2_w, eltadd0_b, eltadd1_b, eltadd2_b, eltadd_qk_b,
+                 reshape2_0, reshape2_qkv_out, scale, scale_out);

     std::unordered_set<const Node*> marked_nodes({eltadd0,
                                                   eltadd1,
                                                   eltadd2,
-                                                  eltadd0_b,
                                                   eltadd1_b,
                                                   eltadd2_b,
                                                   eltadd0_out,
@@ -665,7 +655,6 @@ static int BuildFusionV2(Graph* graph, const std::string& name_scope,
                                                   mul0_out,
                                                   mul1_out,
                                                   mul2_out,
-                                                  mul0_w,
                                                   mul1_w,
                                                   mul2_w,
                                                   reshape2_qkv,
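Consistent with the reuse above, mul0_w drops out of marked_nodes here just as eltadd0_b did in the previous hunk: both nodes survive the fusion as carriers of the combined weight and bias, while mul1_w/mul2_w and eltadd1_b/eltadd2_b are still erased with the rest of the matched subgraph.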