|
|
|
@ -36,8 +36,9 @@ static PDNode* create_emb_vars(PDPattern* pattern, const std::string& name,
|
|
|
|
|
static PDNode* create_emb_out_vars(PDPattern* pattern, const std::string& name,
|
|
|
|
|
const std::string& arg) {
|
|
|
|
|
PDNode* node = pattern->NewNode(name)
|
|
|
|
|
->assert_is_op_output("lookup_table")
|
|
|
|
|
->assert_is_op_input("elementwise_add", arg);
|
|
|
|
|
->assert_is_only_output_of_op("lookup_table")
|
|
|
|
|
->assert_is_op_input("elementwise_add", arg)
|
|
|
|
|
->AsIntermediate();
|
|
|
|
|
return node;
|
|
|
|
|
}
|
|
|
|
|
void Embedding2Eltwise1Pattern::operator()() {
|
|
|
|
@ -94,7 +95,8 @@ void SkipLayerNorm::operator()() {
|
|
|
|
|
pattern->NewNode(eltwise_add_repr())->assert_is_op("elementwise_add");
|
|
|
|
|
auto* eltwise_add_out = pattern->NewNode(eltwise_add_out_repr())
|
|
|
|
|
->assert_is_op_output("elementwise_add")
|
|
|
|
|
->assert_is_op_input("layer_norm", "X");
|
|
|
|
|
->assert_is_op_input("layer_norm", "X")
|
|
|
|
|
->AsIntermediate();
|
|
|
|
|
auto* layer_norm =
|
|
|
|
|
pattern->NewNode(layer_norm_repr())->assert_is_op("layer_norm");
|
|
|
|
|
auto* layer_norm_out = pattern->NewNode(layer_norm_out_repr())
|
|
|
|
|