[Ernie GPU Optim]: Fuse three fc to multihead matmul (#22486)

* 1. optim multihead matmul: fuse three fc to multihead matmul

test=develop

* fix conflict
test=develop

* fix comments
test=develop

Zhaolong Xing 5 years ago committed by GitHub
parent a8dd425aa3
commit 8acd745c25
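The idea behind the fusion: the Q, K, and V projections are three separate fc (mul + elementwise_add) ops over the same input tensor, so their weight matrices can be concatenated and the three GEMM launches collapsed into one wider GEMM. A minimal NumPy sketch of the equivalence (shapes and names here are illustrative, not the operator's API):

import numpy as np

batch, seq_len, hidden = 2, 128, 768

x = np.random.rand(batch, seq_len, hidden).astype("float32")
wq = np.random.rand(hidden, hidden).astype("float32")
wk = np.random.rand(hidden, hidden).astype("float32")
wv = np.random.rand(hidden, hidden).astype("float32")

# Three separate projections: three GEMM launches.
q, k, v = x @ wq, x @ wk, x @ wv

# Fused: one wide GEMM over the concatenated weight, then split.
w_qkv = np.concatenate([wq, wk, wv], axis=1)  # (hidden, 3 * hidden)
qkv = x @ w_qkv                               # a single GEMM
q2, k2, v2 = np.split(qkv, 3, axis=-1)

assert np.allclose(q, q2, rtol=1e-4)
assert np.allclose(k, k2, rtol=1e-4)
assert np.allclose(v, v2, rtol=1e-4)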

[File diff suppressed because it is too large]

@@ -32,8 +32,6 @@ struct MultiHeadMatmulPattern : public PatternBase {
   PDNode* operator()(PDNode* x);

   // declare operator node's name
-  // PATTERN_DECL_NODE(dropout);
-  // PATTERN_DECL_NODE(dropout_out);
   PATTERN_DECL_NODE(layer_norm);
   PATTERN_DECL_NODE(layer_norm_out);
   PATTERN_DECL_NODE(mul0);
@@ -79,8 +77,6 @@ struct MultiHeadMatmulPattern : public PatternBase {
   PATTERN_DECL_NODE(eltadd_qk_out);
   PATTERN_DECL_NODE(softmax_qk);
   PATTERN_DECL_NODE(softmax_qk_out);
-  // PATTERN_DECL_NODE(dropout_qk);
-  // PATTERN_DECL_NODE(dropout_qk_out);
   PATTERN_DECL_NODE(matmul_qkv);
   PATTERN_DECL_NODE(matmul_qkv_out);
@@ -98,6 +94,16 @@ class MultiHeadMatmulFusePass : public FusePassBase {
   const std::string name_scope_{"multihead_matmul_fuse"};
 };

+class MultiHeadMatmulV2FusePass : public FusePassBase {
+ public:
+  virtual ~MultiHeadMatmulV2FusePass() {}
+
+ protected:
+  void ApplyImpl(Graph* graph) const;
+
+  const std::string name_scope_{"multihead_matmul_fuse_v2"};
+};
+
 }  // namespace ir
 }  // namespace framework
 }  // namespace paddle

@@ -17,6 +17,27 @@ namespace paddle {
 namespace framework {
 namespace ir {

+void AddVarToScope(Scope* param_scope, const std::string& name,
+                   const DDim& dims) {
+  auto* tensor = param_scope->Var(name)->GetMutable<LoDTensor>();
+  tensor->Resize(dims);
+  tensor->mutable_data<float>(platform::CPUPlace());
+}
+
+Scope* CreateParamScope() {
+  auto param_scope = new Scope();
+  AddVarToScope(param_scope, "weights0", {768, 768});
+  AddVarToScope(param_scope, "weights1", {768, 768});
+  AddVarToScope(param_scope, "weights2", {768, 768});
+  AddVarToScope(param_scope, "bias_0", {768});
+  AddVarToScope(param_scope, "bias_1", {768});
+  AddVarToScope(param_scope, "bias_2", {768});
+  AddVarToScope(param_scope, "biasqk", {768});
+  AddVarToScope(param_scope, "weightsl", {768, 768});
+  return param_scope;
+}
+
 TEST(MultiHeadMatmulFusePass, basic) {
   // inputs                          operator             output
   // --------------------------------------------------------------------
@@ -87,7 +108,10 @@ TEST(MultiHeadMatmulFusePass, basic) {
   layers.mul(reshape_qkv_out, weights_l);

   std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
-  auto pass = PassRegistry::Instance().Get("multihead_matmul_fuse_pass");
+  graph->Set("__param_scope__", CreateParamScope());
+  auto pass = PassRegistry::Instance().Get("multihead_matmul_fuse_pass_v2");
+  if (pass.get() == nullptr) LOG(INFO) << "multihead_matmul_fuse_pass_v2 not found";
   int num_nodes_before = graph->Nodes().size();
   VLOG(3) << DebugString(graph);
@@ -96,8 +120,17 @@ TEST(MultiHeadMatmulFusePass, basic) {
   int num_fused_nodes_after = GetNumOpNodes(graph, "multihead_matmul");
   VLOG(3) << DebugString(graph);

-  PADDLE_ENFORCE_EQ(num_nodes_before, num_nodes_after + 29);
-  PADDLE_ENFORCE_EQ(num_fused_nodes_after, 1);
+  PADDLE_ENFORCE_EQ(
+      num_nodes_before, num_nodes_after + 39,
+      platform::errors::InvalidArgument(
+          "After the multihead_matmul pass, the node num in the graph "
+          "should be %d, but the result is %d",
+          num_nodes_before - 39, num_nodes_after));
+  PADDLE_ENFORCE_EQ(num_fused_nodes_after, 1,
+                    platform::errors::InvalidArgument(
+                        "After the multihead_matmul pass, there should be one "
+                        "multihead_matmul op, but the result is %d",
+                        num_fused_nodes_after));
 }

 }  // namespace ir
@@ -105,3 +138,4 @@ TEST(MultiHeadMatmulFusePass, basic) {
 }  // namespace paddle

 USE_PASS(multihead_matmul_fuse_pass);
+USE_PASS(multihead_matmul_fuse_pass_v2);

@@ -107,7 +107,7 @@ GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) {
         "conv_eltwiseadd_affine_channel_fuse_pass",  //
         "conv_bn_fuse_pass",                         //
         "conv_eltwiseadd_bn_fuse_pass",              //
-        "multihead_matmul_fuse_pass",
+        "multihead_matmul_fuse_pass_v2",
         "fc_fuse_pass",                              //
         "fc_elementwise_layernorm_fuse_pass",        //
 #if CUDNN_VERSION >= 7100  // To run conv_fusion, the version of cudnn must be

@@ -15,126 +15,80 @@ limitations under the License. */
 #include <vector>
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/detail/safe_ref.h"
+#include "paddle/fluid/platform/errors.h"

 namespace paddle {
 namespace operators {

-class MultiHeadMatMulOp : public framework::OperatorWithKernel {
+class MultiHeadMatMulV2Op : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;

  protected:
   void InferShape(framework::InferShapeContext *context) const override {
-    PADDLE_ENFORCE_EQ(context->HasInput("Q"), true,
-                      "Input(Q) of MultiheadOp should not be null.");
-    PADDLE_ENFORCE_EQ(context->HasInput("K"), true,
-                      "Input(K) of MultiheadOp should not be null.");
-    PADDLE_ENFORCE_EQ(context->HasInput("V"), true,
-                      "Input(V) of MultiheadOp should not be null.");
-    PADDLE_ENFORCE_EQ(context->HasInput("BiasQ"), true,
-                      "Input(BiasQ) of MultiheadOp should not be null.");
-    PADDLE_ENFORCE_EQ(context->HasInput("BiasK"), true,
-                      "Input(BiasQ) of MultiheadOp should not be null.");
-    PADDLE_ENFORCE_EQ(context->HasInput("BiasV"), true,
-                      "Input(BiasQ) of MultiheadOp should not be null.");
-    PADDLE_ENFORCE_EQ(context->HasInput("BiasQK"), true,
-                      "Input(BiasQK) of MultiheadOp should not be null.");
-    PADDLE_ENFORCE_EQ(context->HasOutput("Out"), true,
-                      "Output(Out) of MatMulOp should not be null.");
-
-    auto dim_q = context->GetInputDim("Q");
-    PADDLE_ENFORCE_GT(dim_q.size(), 2,
-                      "Multihead input should be at least 3-D tensor.");
-    auto dim_k = context->GetInputDim("K");
-    PADDLE_ENFORCE_GT(dim_q.size(), 2,
-                      "Multihead input should be at least 3-D tensor.");
-    auto dim_v = context->GetInputDim("V");
-    PADDLE_ENFORCE_GT(dim_q.size(), 2,
-                      "Multihead input should be at least 3-D tensor.");
-
-    PADDLE_ENFORCE_EQ(dim_q[0], dim_k[0],
-                      "Multihead input should have same batch size");
-    PADDLE_ENFORCE_EQ(dim_q[0], dim_v[0],
-                      "Multihead input should have same batch size");
-    PADDLE_ENFORCE_EQ(dim_q[1], dim_k[1],
-                      "Multihead input should have same size");
-    PADDLE_ENFORCE_EQ(dim_q[1], dim_v[1],
-                      "Multihead input should have same size");
-    PADDLE_ENFORCE_EQ(dim_q[2], dim_k[2],
-                      "Multihead input should have same size");
-    PADDLE_ENFORCE_EQ(dim_q[2], dim_v[2],
-                      "Multihead input should have same size");
-
-    auto dim_bias_q = context->GetInputDim("BiasQ");
-    PADDLE_ENFORCE_GT(dim_bias_q.size(), 0,
-                      "Multihead input should be at least 1-D tensor.");
-    auto dim_bias_k = context->GetInputDim("BiasK");
-    PADDLE_ENFORCE_GT(dim_bias_k.size(), 0,
-                      "Multihead input should be at least 1-D tensor.");
-    auto dim_bias_v = context->GetInputDim("BiasV");
-    PADDLE_ENFORCE_GT(dim_bias_v.size(), 0,
-                      "Multihead input should be at least 1-D tensor.");
-
-    PADDLE_ENFORCE_EQ(dim_bias_q[0], dim_bias_k[0],
-                      "Multihead input bias should have same batch size");
-    PADDLE_ENFORCE_EQ(dim_bias_q[0], dim_bias_v[0],
-                      "Multihead input bias should have same batch size");
-
-    auto dim_bias_qk = context->GetInputDim("BiasQK");
-    PADDLE_ENFORCE_GT(dim_bias_qk.size(), 3,
-                      "Multihead input bias qk should be at least 4-D tensor.");
-
-    int b_indx = dim_bias_q.size() - 1;
-    int indx = dim_q.size() - 1;
-
-    PADDLE_ENFORCE_EQ(
-        dim_bias_q[b_indx], dim_q[indx],
-        platform::errors::InvalidArgument(
-            "bias_q's last dim size should equal to"
-            " q last dim size, but received bias_q's size is:%d q is:%d",
-            dim_bias_q[b_indx], dim_q[indx]));
-    PADDLE_ENFORCE_EQ(
-        dim_bias_k[b_indx], dim_k[indx],
-        platform::errors::InvalidArgument(
-            "bias_k's last dim size should equal to"
-            " k last dim size, but received bias_k's size is:%d k is:%d",
-            dim_bias_k[b_indx], dim_k[indx]));
-    PADDLE_ENFORCE_EQ(
-        dim_bias_v[b_indx], dim_v[indx],
-        platform::errors::InvalidArgument(
-            "bias_v's last dim size should equal to"
-            " v last dim size, but received bias_v's size is:%d v is:%d",
-            dim_bias_v[b_indx], dim_v[indx]));
-
-    PADDLE_ENFORCE_EQ(dim_q[0], dim_bias_qk[0],
-                      platform::errors::InvalidArgument(
-                          "q should have same batch size"
-                          "with bias_qk, but received q's batch size is:%d "
-                          "bias_qk's batch size is:%d",
-                          dim_q[0], dim_bias_qk[0]));
-
-    int head_number = context->Attrs().Get<int>("head_number");
-    PADDLE_ENFORCE_GT(head_number, 1,
-                      "Multihead input head number should be at least 1.");
-
-    context->SetOutputDim("Out", dim_q);
-    context->ShareLoD("Q", /*->*/ "Out");
+    PADDLE_ENFORCE_EQ(
+        context->HasInput("Input"), true,
+        platform::errors::InvalidArgument(
+            "Input(Input) of MultiHeadMatMul should not be null."));
+    PADDLE_ENFORCE_EQ(context->HasInput("W"), true,
+                      platform::errors::InvalidArgument(
+                          "Input(W) of MultiHeadMatMul should not be null."));
+    PADDLE_ENFORCE_EQ(
+        context->HasInput("Bias"), true,
+        platform::errors::InvalidArgument(
+            "Input(Bias) of MultiHeadMatMul should not be null."));
+    PADDLE_ENFORCE_EQ(
+        context->HasInput("BiasQK"), true,
+        platform::errors::InvalidArgument(
+            "Input(BiasQK) of MultiHeadMatMul should not be null."));
+    PADDLE_ENFORCE_EQ(
+        context->HasOutput("Out"), true,
+        platform::errors::InvalidArgument(
+            "Output(Out) of MultiHeadMatMul should not be null."));
+
+    auto dim_w = context->GetInputDim("W");
+    PADDLE_ENFORCE_GT(
+        dim_w.size(), 2,
+        platform::errors::InvalidArgument(
+            "Multihead input is expected at least a 3-D tensor, but "
+            "it's %d-D tensor now.",
+            dim_w.size()));
+
+    auto dim_bias_q = context->GetInputDim("Bias");
+    PADDLE_ENFORCE_GT(
+        dim_bias_q.size(), 1,
+        platform::errors::InvalidArgument(
+            "Multihead input should be at least 2-D tensor, but it's "
+            "%d-D tensor now.",
+            dim_bias_q.size()));
+
+    auto dim_bias_qk = context->GetInputDim("BiasQK");
+    PADDLE_ENFORCE_GT(
+        dim_bias_qk.size(), 3,
+        platform::errors::InvalidArgument(
+            "Multihead input bias qk should be at least 4-D tensor, "
+            "but it's %d-D tensor now.",
+            dim_bias_qk.size()));
+
+    int head_number = context->Attrs().Get<int>("head_number");
+    PADDLE_ENFORCE_GT(
+        head_number, 1,
+        platform::errors::InvalidArgument(
+            "Multihead input head number should be at least 1, "
+            "but it is %d now.",
+            head_number));
+
+    // The output keeps the shape (and LoD) of Input.
+    auto dim_input = context->GetInputDim("Input");
+    context->SetOutputDim("Out", dim_input);
+    context->ShareLoD("Input", /*->*/ "Out");
   }
 };

-class MultiHeadMatMulOpMaker : public framework::OpProtoAndCheckerMaker {
+class MultiHeadMatMulV2OpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   void Make() override {
-    AddInput("Q", "The first input of MultiHeadMatMul op");
-    AddInput("K", "The second input of MMultiHeadMatMul op");
-    AddInput("V", "The third input of MultiHeadMatMul op");
-    AddInput("BiasQ", "The first bias input of MultiHeadMatMul op");
-    AddInput("BiasK", "The second bias input of MultiHeadMatMul op");
-    AddInput("BiasV", "The third bias input of MultiHeadMatMul op");
+    AddInput("Input", "The input of MultiHeadMatMul op");
+    AddInput("W", "The weight input of MultiHeadMatMul op");
+    AddInput("Bias", "The bias input of MultiHeadMatMul op");
     AddInput("BiasQK", "The QK bias input of MultiHeadMatMul op");
     AddOutput("Out", "The output of MultiHeadMatMul op");
     AddAttr<bool>("transpose_Q",
@@ -161,10 +115,6 @@ Not suggest to use in other case except has same structure as ernie.
 Example of matrix multiplication with head_number of B
 - X: [B, M, K], Y: [B, K, N] => Out: [B, M, N]

-Both the input `Q` and `K` can carry the LoD (Level of Details) information,
-or not. But the output only shares the LoD information with input `Q`, because
-they are the same.
-
 )DOC");
   }
 };
@@ -173,5 +123,5 @@ they are the same.
 }  // namespace paddle

 namespace ops = paddle::operators;
-REGISTER_OP_WITHOUT_GRADIENT(multihead_matmul, ops::MultiHeadMatMulOp,
-                             ops::MultiHeadMatMulOpMaker);
+REGISTER_OP_WITHOUT_GRADIENT(multihead_matmul, ops::MultiHeadMatMulV2Op,
+                             ops::MultiHeadMatMulV2OpMaker);
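For reference, here is a hedged NumPy sketch of the computation the fused V2 op replaces, as far as it can be read off this diff and the unit test below: Input is (batch, seq_len, hidden), W packs the Q/K/V weights as (hidden, 3, hidden), Bias packs the biases as (3, hidden), and BiasQK is added to the scaled QK scores before the softmax. The helper name and the alpha scaling parameter are assumptions for illustration, not the op's exact attribute surface:

import numpy as np

def multihead_matmul_ref(inp, w, bias, bias_qk, head_number, alpha):
    # inp:     (batch, seq_len, hidden)
    # w:       (hidden, 3, hidden)  combined Q/K/V projection weights
    # bias:    (3, hidden)          combined Q/K/V biases
    # bias_qk: (batch, head_number, seq_len, seq_len)
    batch, seq_len, hidden = inp.shape
    size_per_head = hidden // head_number

    # One wide GEMM over the combined weight, then split into Q, K, V.
    qkv = np.einsum("bsh,hcj->bcsj", inp, w) + bias[None, :, None, :]
    q, k, v = qkv[:, 0], qkv[:, 1], qkv[:, 2]  # each (batch, seq_len, hidden)

    def to_heads(t):
        # (batch, seq_len, hidden) -> (batch, heads, seq_len, size_per_head)
        return t.reshape(batch, seq_len, head_number,
                         size_per_head).transpose(0, 2, 1, 3)

    q, k, v = map(to_heads, (q, k, v))

    # Scaled attention scores with the additive QK bias, then softmax.
    scores = alpha * (q @ k.transpose(0, 1, 3, 2)) + bias_qk
    scores = np.exp(scores - scores.max(axis=-1, keepdims=True))
    attn = scores / scores.sum(axis=-1, keepdims=True)

    # Weighted sum over V, then merge heads back into the hidden dim.
    ctx = attn @ v  # (batch, heads, seq_len, size_per_head)
    return ctx.transpose(0, 2, 1, 3).reshape(batch, seq_len, hidden)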

[File diff suppressed because it is too large]

@@ -47,12 +47,21 @@ class TestFusedMultiheadMatmulOp(OpTest):
         self.config()
         h = self.seq_len
         w = self.head_number * self.size_per_head
-        self.Q = np.random.random((self.batch_size, h, w)).astype("float32")
-        self.K = np.random.random((self.batch_size, h, w)).astype("float32")
-        self.V = np.random.random((self.batch_size, h, w)).astype("float32")
+        self.Input = np.random.random(
+            (self.batch_size, h, w)).astype("float32") - 0.5
+        self.WQ = np.random.random((w, w)).astype("float32")
+        self.KQ = np.random.random((w, w)).astype("float32")
+        self.VQ = np.random.random((w, w)).astype("float32")
+        self.CombinedW = np.hstack((self.WQ, self.KQ, self.VQ)).reshape(
+            (w, 3, w))
+        self.Q = np.dot(self.Input, self.WQ)
+        self.K = np.dot(self.Input, self.KQ)
+        self.V = np.dot(self.Input, self.VQ)
         self.BiasQ = np.random.random((1, w)).astype("float32")
         self.BiasK = np.random.random((1, w)).astype("float32")
         self.BiasV = np.random.random((1, w)).astype("float32")
+        self.CombinedB = np.vstack((self.BiasQ, self.BiasK, self.BiasV))
         self.BiasQK = np.random.random(
             (self.batch_size, self.head_number, self.seq_len,
              self.seq_len)).astype("float32")
@@ -84,12 +93,9 @@ class TestFusedMultiheadMatmulOp(OpTest):
         reshape_qkv = np.reshape(transpose_qkv, (self.batch_size, h, w))

         self.inputs = {
-            "Q": self.Q,
-            "K": self.K,
-            "V": self.V,
-            "BiasQ": self.BiasQ,
-            "BiasK": self.BiasK,
-            "BiasV": self.BiasV,
+            "Input": self.Input,
+            "W": self.CombinedW,
+            "Bias": self.CombinedB,
             "BiasQK": self.BiasQK
         }
         self.attrs = {
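A quick check of the CombinedW layout the test builds above: np.hstack yields a (w, 3w) matrix, and the reshape to (w, 3, w) slices it so that each index along the middle axis recovers one projection matrix. A small standalone verification (w = 8 is illustrative only):

import numpy as np

w = 8
wq, wk, wv = (np.random.rand(w, w).astype("float32") for _ in range(3))

# Same layout as the test: (w, 3w) -> (w, 3, w), so combined[:, i, :]
# is the i-th projection matrix.
combined = np.hstack((wq, wk, wv)).reshape((w, 3, w))

assert np.array_equal(combined[:, 0, :], wq)
assert np.array_equal(combined[:, 1, :], wk)
assert np.array_equal(combined[:, 2, :], wv)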
