@ -34,15 +34,19 @@ namespace patterns {
static PDNode * create_emb_vars ( PDPattern * pattern , const std : : string & name ,
static PDNode * create_emb_vars ( PDPattern * pattern , const std : : string & name ,
const std : : string & arg ,
const std : : string & arg ,
bool is_persist = false ) {
bool is_persist = false ) {
std : : unordered_set < std : : string > embedding_ops { " lookup_table " ,
" lookup_table_v2 " } ;
PDNode * node =
PDNode * node =
pattern - > NewNode ( name ) - > assert_is_op_input ( " lookup_table " , arg ) ;
pattern - > NewNode ( name ) - > assert_is_op s_input( embedding_ops , arg ) ;
if ( is_persist ) return node - > assert_is_persistable_var ( ) ;
if ( is_persist ) return node - > assert_is_persistable_var ( ) ;
return node ;
return node ;
}
}
static PDNode * create_emb_out_vars ( PDPattern * pattern , const std : : string & name ,
static PDNode * create_emb_out_vars ( PDPattern * pattern , const std : : string & name ,
const std : : string & arg ) {
const std : : string & arg ) {
std : : unordered_set < std : : string > embedding_ops { " lookup_table " ,
" lookup_table_v2 " } ;
PDNode * node = pattern - > NewNode ( name )
PDNode * node = pattern - > NewNode ( name )
- > assert_is_only_output_of_op ( " lookup_table " )
- > assert_is_only_output_of_op s( embedding_ops )
- > assert_is_op_input ( " elementwise_add " , arg )
- > assert_is_op_input ( " elementwise_add " , arg )
- > AsIntermediate ( ) ;
- > AsIntermediate ( ) ;
return node ;
return node ;
@ -56,10 +60,12 @@ void Embedding2Eltwise1Pattern::operator()() {
create_emb_vars ( pattern , lookup_table1_w_repr ( ) , " W " , true ) ;
create_emb_vars ( pattern , lookup_table1_w_repr ( ) , " W " , true ) ;
auto * lookup_table2_w =
auto * lookup_table2_w =
create_emb_vars ( pattern , lookup_table2_w_repr ( ) , " W " , true ) ;
create_emb_vars ( pattern , lookup_table2_w_repr ( ) , " W " , true ) ;
std : : unordered_set < std : : string > embedding_ops { " lookup_table " ,
" lookup_table_v2 " } ;
auto * lookup_table1 =
auto * lookup_table1 =
pattern - > NewNode ( lookup_table1_repr ( ) ) - > assert_is_op ( " lookup_table " ) ;
pattern - > NewNode ( lookup_table1_repr ( ) ) - > assert_is_op s( embedding_ops ) ;
auto * lookup_table2 =
auto * lookup_table2 =
pattern - > NewNode ( lookup_table2_repr ( ) ) - > assert_is_op ( " lookup_table " ) ;
pattern - > NewNode ( lookup_table2_repr ( ) ) - > assert_is_op s( embedding_ops ) ;
auto * lookup_table1_out =
auto * lookup_table1_out =
create_emb_out_vars ( pattern , lookup_table1_out_repr ( ) , " X " ) ;
create_emb_out_vars ( pattern , lookup_table1_out_repr ( ) , " X " ) ;
auto * lookup_table2_out =
auto * lookup_table2_out =
@ -80,8 +86,10 @@ void Embedding1Eltwise1Pattern::operator()() {
create_emb_vars ( pattern , lookup_table1_x_repr ( ) , " Ids " ) ;
create_emb_vars ( pattern , lookup_table1_x_repr ( ) , " Ids " ) ;
auto * lookup_table1_w =
auto * lookup_table1_w =
create_emb_vars ( pattern , lookup_table1_w_repr ( ) , " W " , true ) ;
create_emb_vars ( pattern , lookup_table1_w_repr ( ) , " W " , true ) ;
std : : unordered_set < std : : string > embedding_ops { " lookup_table " ,
" lookup_table_v2 " } ;
auto * lookup_table1 =
auto * lookup_table1 =
pattern - > NewNode ( lookup_table1_repr ( ) ) - > assert_is_op ( " lookup_table " ) ;
pattern - > NewNode ( lookup_table1_repr ( ) ) - > assert_is_op s( embedding_ops ) ;
auto * lookup_table1_out =
auto * lookup_table1_out =
create_emb_out_vars ( pattern , lookup_table1_out_repr ( ) , " Y " ) ;
create_emb_out_vars ( pattern , lookup_table1_out_repr ( ) , " Y " ) ;
auto * eltwise_add =
auto * eltwise_add =
@ -347,4 +355,5 @@ REGISTER_PASS_CAPABILITY(embedding_eltwise_layernorm_fuse_pass)
. AddCombination (
. AddCombination (
paddle : : framework : : compatible : : OpVersionComparatorCombination ( )
paddle : : framework : : compatible : : OpVersionComparatorCombination ( )
. EQ ( " lookup_table " , 0 )
. EQ ( " lookup_table " , 0 )
. LE ( " lookup_table_v2 " , 1 )
. EQ ( " elementweise_add " , 0 ) ) ;
. EQ ( " elementweise_add " , 0 ) ) ;