@@ -129,7 +129,9 @@ static bool IsEqual(const std::vector<T> &x, const std::vector<T> &y) {
 }
 
 void FCElementwiseLayerNormFusePass::ApplyImpl(ir::Graph *graph) const {
-  PADDLE_ENFORCE_NOT_NULL(graph);
+  PADDLE_ENFORCE_NOT_NULL(graph,
+                          platform::errors::InvalidArgument(
+                              "Pointer to graph argument should not be NULL."));
   FusePassBase::Init("fc_elementwise_layernorm_fuse", graph);
   int found_subgraph_count = 0;
 
@@ -203,12 +205,14 @@ void FCElementwiseLayerNormFusePass::ApplyImpl(ir::Graph *graph) const {
 
     // outputs
     new_desc.SetOutput("Out", {layer_norm_out->Name()});
-    if (layer_norm_mean->outputs.size() > 0U) {
+    bool lnm_has_output = layer_norm_mean->outputs.size() > 0U;
+    if (lnm_has_output) {
       new_desc.SetOutput("Mean", {layer_norm_mean->Name()});
     } else {
       del_node_set.insert(layer_norm_mean);
     }
-    if (layer_norm_variance->outputs.size() > 0U) {
+    bool lnv_has_output = layer_norm_variance->outputs.size() > 0U;
+    if (lnv_has_output) {
       new_desc.SetOutput("Variance", {layer_norm_variance->Name()});
     } else {
       del_node_set.insert(layer_norm_variance);
@@ -237,10 +241,10 @@ void FCElementwiseLayerNormFusePass::ApplyImpl(ir::Graph *graph) const {
     IR_NODE_LINK_TO(layer_norm_scale, fused_node);
     IR_NODE_LINK_TO(layer_norm_bias, fused_node);
     IR_NODE_LINK_TO(fused_node, layer_norm_out);
-    if (layer_norm_mean->outputs.size() > 0U) {
+    if (lnm_has_output) {
      IR_NODE_LINK_TO(fused_node, layer_norm_mean);
     }
-    if (layer_norm_variance->outputs.size() > 0U) {
+    if (lnv_has_output) {
      IR_NODE_LINK_TO(fused_node, layer_norm_variance);
     }
 