@ -89,9 +89,63 @@ struct MultiHeadMatmulPattern : public PatternBase {
PATTERN_DECL_NODE ( matmul_qkv ) ;
PATTERN_DECL_NODE ( matmul_qkv_out ) ;
} ;
struct MultiHeadMatmulV3Pattern : public PatternBase {
MultiHeadMatmulV3Pattern ( PDPattern * pattern , const std : : string & name_scope )
: PatternBase ( pattern , name_scope , " multihead_matmul_v3 " ) { }
PDNode * operator ( ) ( ) ;
// declare operator node's name
PATTERN_DECL_NODE ( input0 ) ;
PATTERN_DECL_NODE ( mul0 ) ;
PATTERN_DECL_NODE ( mul1 ) ;
PATTERN_DECL_NODE ( mul2 ) ;
PATTERN_DECL_NODE ( mul0_w ) ;
PATTERN_DECL_NODE ( mul1_w ) ;
PATTERN_DECL_NODE ( mul2_w ) ;
PATTERN_DECL_NODE ( mul0_out ) ;
PATTERN_DECL_NODE ( mul1_out ) ;
PATTERN_DECL_NODE ( mul2_out ) ;
PATTERN_DECL_NODE ( eltadd0 ) ; // ELEMENTWISE_ADD
PATTERN_DECL_NODE ( eltadd1 ) ; // ELEMENTWISE_ADD
PATTERN_DECL_NODE ( eltadd2 ) ; // ELEMENTWISE_ADD
PATTERN_DECL_NODE ( eltadd0_b ) ; // ELEMENTWISE_ADD
PATTERN_DECL_NODE ( eltadd1_b ) ; // ELEMENTWISE_ADD
PATTERN_DECL_NODE ( eltadd2_b ) ; // ELEMENTWISE_ADD
PATTERN_DECL_NODE ( eltadd0_out ) ;
PATTERN_DECL_NODE ( eltadd1_out ) ;
PATTERN_DECL_NODE ( eltadd2_out ) ;
PATTERN_DECL_NODE ( reshape2_0 ) ;
PATTERN_DECL_NODE ( reshape2_1 ) ;
PATTERN_DECL_NODE ( reshape2_2 ) ;
PATTERN_DECL_NODE ( reshape2_qkv ) ;
PATTERN_DECL_NODE ( reshape2_0_out ) ;
PATTERN_DECL_NODE ( reshape2_1_out ) ;
PATTERN_DECL_NODE ( reshape2_2_out ) ;
PATTERN_DECL_NODE ( reshape2_qkv_out ) ;
PATTERN_DECL_NODE ( transpose2_0 ) ;
PATTERN_DECL_NODE ( transpose2_1 ) ;
PATTERN_DECL_NODE ( transpose2_2 ) ;
PATTERN_DECL_NODE ( transpose2_qkv ) ;
PATTERN_DECL_NODE ( transpose2_0_out ) ;
PATTERN_DECL_NODE ( transpose2_1_out ) ;
PATTERN_DECL_NODE ( transpose2_2_out ) ;
PATTERN_DECL_NODE ( transpose2_qkv_out ) ;
PATTERN_DECL_NODE ( matmul_qk ) ;
PATTERN_DECL_NODE ( matmul_qk_out ) ;
PATTERN_DECL_NODE ( eltadd_qk ) ;
PATTERN_DECL_NODE ( eltadd_qk_b ) ;
PATTERN_DECL_NODE ( eltadd_qk_out ) ;
PATTERN_DECL_NODE ( softmax_qk ) ;
PATTERN_DECL_NODE ( softmax_qk_out ) ;
PATTERN_DECL_NODE ( matmul_qkv ) ;
PATTERN_DECL_NODE ( matmul_qkv_out ) ;
} ;
} // namespace patterns
// The MulGRUFusePass and MulGRUFusePass will fuse to the same FusionGRU op.
class MultiHeadMatmulFusePass : public FusePassBase {
public :
virtual ~ MultiHeadMatmulFusePass ( ) { }
@ -112,6 +166,16 @@ class MultiHeadMatmulV2FusePass : public FusePassBase {
const std : : string name_scope_ { " multihead_matmul_fuse_v2 " } ;
} ;
class MultiHeadMatmulV3FusePass : public FusePassBase {
public :
virtual ~ MultiHeadMatmulV3FusePass ( ) { }
protected :
void ApplyImpl ( Graph * graph ) const ;
const std : : string name_scope_ { " multihead_matmul_fuse_v3 " } ;
} ;
} // namespace ir
} // namespace framework
} // namespace paddle