oneDNN BatchNorm + Act fusion pass. (#27912)

revert-27871-prv-conv-grad-opt
Adam Osewski 4 years ago committed by GitHub
parent fb7f85291b
commit 7db747d9e8

@@ -110,6 +110,7 @@ if(WITH_MKLDNN)
pass_library(cpu_quantize_squash_pass inference DIR mkldnn)
pass_library(reshape_transpose_matmul_mkldnn_fuse_pass inference DIR mkldnn)
pass_library(matmul_transpose_reshape_fuse_pass inference DIR mkldnn)
pass_library(batch_norm_act_fuse_pass inference DIR mkldnn)
endif()
cc_library(fuse_bn_act_pass SRCS fuse_bn_act_pass.cc DEPS pass graph_pattern_detector )
@@ -151,6 +152,7 @@ if (WITH_MKLDNN)
cc_test(test_conv_activation_mkldnn_fuse_pass SRCS mkldnn/conv_activation_mkldnn_fuse_pass_tester.cc DEPS conv_activation_mkldnn_fuse_pass)
cc_test(test_conv_concat_relu_mkldnn_fuse_pass SRCS mkldnn/conv_concat_relu_mkldnn_fuse_pass_tester.cc DEPS conv_concat_relu_mkldnn_fuse_pass)
cc_test(test_conv_elementwise_add_mkldnn_fuse_pass SRCS mkldnn/conv_elementwise_add_mkldnn_fuse_pass_tester.cc DEPS conv_elementwise_add_mkldnn_fuse_pass)
cc_test(test_batch_norm_act_fuse_pass SRCS mkldnn/batch_norm_act_fuse_pass_tester.cc DEPS batch_norm_act_fuse_pass)
set(TEST_CONV_BN_PASS_DEPS conv_bn_fuse_pass graph_to_program_pass conv_op conv_transpose_op math_function im2col vol2col batch_norm_op gelu_op activation_op elementwise_add_op concat_and_split naive_executor device_context)
if (WITH_GPU)
set(TEST_CONV_BN_PASS_DEPS ${TEST_CONV_BN_PASS_DEPS} depthwise_conv)

@@ -1188,6 +1188,26 @@ PDNode *patterns::BatchNormActGrad::operator()(
return bn_grad;
}
PDNode *patterns::BatchNormActOneDNN::operator()(const std::string &act_type) {
  auto *bn_x = pattern->NewNode(bn_in_repr())
                   ->AsInput()
                   ->assert_is_op_input("batch_norm", "X");
  auto *bn = pattern->NewNode(batch_norm_repr())->assert_is_op("batch_norm");
  auto *bn_out = pattern->NewNode(bn_out_repr())
                     ->assert_is_op_output("batch_norm", "Y")
                     ->assert_is_op_input(act_type);
  auto *act =
      pattern->NewNode(act_repr())->assert_is_op(act_type)->AsIntermediate();
  auto *act_out = pattern->NewNode(act_out_repr())
                      ->assert_is_op_output(act_type, "Out")
                      ->AsOutput();

  bn->LinksFrom({bn_x}).LinksTo({bn_out});
  act->LinksFrom({bn_out}).LinksTo({act_out});

  return act_out;
}
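In graph terms, the pattern above matches the chain X -> batch_norm -> Y -> relu -> Out, with the activation and the intermediate Y marked for removal. The fuse pass further below then rewires the matched subgraph so that a single oneDNN batch_norm with a fused ReLU post-op produces the activation's output directly (illustration only, derived from the handler code in the pass):

    before: X -> batch_norm -> Y -> relu -> Out
    after:  X -> batch_norm (fuse_with_relu = true) -> Out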
PDNode *patterns::ElewiseAddAct::operator()(
paddle::framework::ir::PDNode *ele_x_var,
std::unordered_set<std::string> act_types) {

@@ -664,6 +664,27 @@ struct BatchNormActGrad : public PatternBase {
PATTERN_DECL_NODE(d_bn_bias);
};
//
// \brief Pattern looking for batch_norm and a directly following activation
// operator.
//
// \note Currently only ReLU is supported as an activation function.
// Formula: act(bn(x))
// Op: batch_norm + act
struct BatchNormActOneDNN : public PatternBase {
  BatchNormActOneDNN(PDPattern* pattern, const std::string& name_scope)
      : PatternBase(pattern, name_scope, "bn_act_onednn") {}

  PDNode* operator()(const std::string& act_type);

  // declare operator and variable node names
  PATTERN_DECL_NODE(bn_in);
  PATTERN_DECL_NODE(batch_norm);
  PATTERN_DECL_NODE(act);
  PATTERN_DECL_NODE(bn_out);
  PATTERN_DECL_NODE(act_out);
};
// The following patterns are used to fuse elewise_add and act
// formula: act(ele_add(x, y))
// op: elementwise_add + act

@@ -0,0 +1,108 @@
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/mkldnn/batch_norm_act_fuse_pass.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/string/pretty_log.h"
namespace paddle {
namespace framework {
namespace ir {
using string::PrettyLogDetail;
void FuseBatchNormActOneDNNPass::ApplyImpl(Graph *graph) const {
  std::string act_type("relu");
  FuseBatchNormAct(graph, act_type);
}
void FuseBatchNormActOneDNNPass::FuseBatchNormAct(
    Graph *graph, const std::string &act_type) const {
  PADDLE_ENFORCE_NOT_NULL(
      graph, platform::errors::InvalidArgument(
                 "The input graph of "
                 "FuseBatchNormActOneDNNPass should not be nullptr."));
  FusePassBase::Init("bn_act", graph);

  GraphPatternDetector gpd;
  patterns::BatchNormActOneDNN bn_act_pattern(gpd.mutable_pattern(), "bn_act");
  bn_act_pattern(act_type);

  int found_bn_act_count = 0;

  auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
                     Graph *g) {
    VLOG(4) << "Fuse BatchNorm with ReLU activation op.";
    // BN output
    GET_IR_NODE_FROM_SUBGRAPH(bn_out, bn_out, bn_act_pattern);
    // ACT output
    GET_IR_NODE_FROM_SUBGRAPH(act_out, act_out, bn_act_pattern);
    // ops
    GET_IR_NODE_FROM_SUBGRAPH(batch_norm, batch_norm, bn_act_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(act, act, bn_act_pattern);

    auto *bn_op = batch_norm->Op();

    if (bn_op->HasAttr("use_mkldnn")) {
      PADDLE_ENFORCE(
          BOOST_GET_CONST(bool, bn_op->GetAttr("use_mkldnn")),
          platform::errors::PreconditionNotMet(
              "The BatchNorm+Act fusion may happen only when the oneDNN "
              "library is used."));
    }
    if (bn_op->HasAttr("trainable_statistics")) {
      PADDLE_ENFORCE(
          !BOOST_GET_CONST(bool, bn_op->GetAttr("trainable_statistics")),
          platform::errors::PreconditionNotMet(
              "The BatchNorm+Act fusion may happen only when mean and "
              "variance are not calculated by the current batch statistics."));
    }
    if (bn_op->HasAttr("is_test")) {
      PADDLE_ENFORCE(
          BOOST_GET_CONST(bool, bn_op->GetAttr("is_test")),
          platform::errors::PreconditionNotMet(
              "The BatchNorm+Act fusion may happen only during inference."));
    }

    bn_op->SetAttr("use_mkldnn", true);
    bn_op->SetAttr("is_test", true);
    bn_op->SetAttr("fuse_with_relu", true);
    bn_op->SetAttr("trainable_statistics", false);
    bn_op->SetOutput("Y", {act_out->Name()});

    IR_OP_VAR_LINK(batch_norm, act_out);
    GraphSafeRemoveNodes(g, {act, bn_out});
    found_bn_act_count++;
  };

  gpd(graph, handler);
  AddStatis(found_bn_act_count);
  PrettyLogDetail("---    fused %d batch norm with relu activation",
                  found_bn_act_count);
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(batch_norm_act_fuse_pass,
              paddle::framework::ir::FuseBatchNormActOneDNNPass);
REGISTER_PASS_CAPABILITY(batch_norm_act_fuse_pass)
    .AddCombination(
        paddle::framework::compatible::OpVersionComparatorCombination()
            .EQ("batch_norm", 0)
            .EQ("relu", 0));

@@ -0,0 +1,44 @@
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
namespace paddle {
namespace framework {
namespace ir {
/*
 * \brief Fuse the BatchNorm and activation operators into a single oneDNN
 * BatchNorm operator with a fused activation post-op.
 *
 * \note Currently only ReLU is supported as an activation function.
 */
class FuseBatchNormActOneDNNPass : public FusePassBase {
 public:
  virtual ~FuseBatchNormActOneDNNPass() {}

 protected:
  void ApplyImpl(ir::Graph *graph) const override;

  void FuseBatchNormAct(ir::Graph *graph, const std::string &act_type) const;
};
} // namespace ir
} // namespace framework
} // namespace paddle

@@ -207,6 +207,7 @@ void CpuPassStrategy::EnableMKLDNN() {
"matmul_transpose_reshape_fuse_pass", //
// Disabled due to topology-dependent speed-up
// "fc_mkldnn_pass",
"batch_norm_act_fuse_pass",
"mkldnn_inplace_pass", // This pass should be activated after
// fuses
})) {
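On the user side no extra call is needed for this particular pass: it joins the CPU pass pipeline whenever oneDNN is switched on for inference. A minimal sketch with the public C++ inference API (the model path is a placeholder, and the include path may vary per installation):

#include "paddle_inference_api.h"  // header location varies per install

int main() {
  paddle::AnalysisConfig config;
  config.SetModel("./model_dir");  // placeholder model directory
  // Enables CpuPassStrategy::EnableMKLDNN(), which now schedules
  // batch_norm_act_fuse_pass among the oneDNN fuse passes.
  config.EnableMKLDNN();
  auto predictor = paddle::CreatePaddlePredictor(config);
  return 0;
}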

@@ -0,0 +1,79 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test for fusion of batch norm and activation."""
from __future__ import print_function
import unittest
import numpy as np
import paddle.fluid as fluid
from inference_pass_test import InferencePassTest
from paddle import enable_static
from paddle.fluid.core import PassVersionChecker
enable_static()
class BnReluOneDnnFusePassTest(InferencePassTest):
    def setUp(self):
        self.set_params()
        with fluid.program_guard(self.main_program, self.startup_program):
            data = fluid.data(
                name="data", shape=[-1, 3, 100, 100], dtype="float32")
            bn_out = fluid.layers.batch_norm(
                input=data, is_test=True, use_global_stats=self.global_stats)
            relu_out = fluid.layers.relu(bn_out)

        self.feeds = {
            "data": np.random.random((1, 3, 100, 100)).astype("float32")
        }
        self.fetch_list = [relu_out]
        self.enable_mkldnn = True

    def set_params(self):
        self.global_stats = False
        self.pass_name = "batch_norm_act_fuse_pass"

    def test_check_output(self):
        self.check_output()
        self.assertTrue(PassVersionChecker.IsCompatible(self.pass_name))


class BnReluGlobalStatsOneDnnFusePassTest(BnReluOneDnnFusePassTest):
    # Same graph and checks; batch_norm uses global statistics instead.
    def set_params(self):
        self.global_stats = True
        self.pass_name = "batch_norm_act_fuse_pass"


if __name__ == "__main__":
    unittest.main()