Enable the convolution/relu6(bounded_relu) fusion for FP32 on Intel platform. (#17130)
* Relu6 is the bottleneck op for Mobilenet-v2. As MKL-DNN supports the conv/relu6 fusion, we implement this fusion via the pass mechanism. Since int8 support for this fusion will only arrive in MKLDNN v0.20, this PR focuses on the fp32 optimization. The table below shows the benchmark (FPS) measured on skx-8180 (28 cores): Batch size | with fusion | without fusion -- | -- | -- 1 | 214.7 | 53.4 50 | 1219.727 | 137.280 test=develop * Fix the format issue test=develop * Add the missing nolint comments. test=develop * Fix the typos. test=develop * Register the conv_brelu_mkldnn_fuse_pass for the MKLDNN engine. test=develop * Adjust the indentation. test=develop * Add the test_conv_brelu_mkldnn_fuse_pass case. test=develop * Slightly update the code per Baidu comments. Let the parameter definition be embedded into the code; that will make the code easier to understand. test=develop
parent
3398f99608
commit
2281ebf0f3
@ -0,0 +1,71 @@
|
||||
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/fluid/framework/ir/mkldnn/conv_brelu_mkldnn_fuse_pass.h"
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "paddle/fluid/platform/enforce.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
namespace ir {
|
||||
|
||||
// Detects conv2d -> relu6 chains in `graph` and folds the bounded relu
// into the conv2d op by setting the fuse_brelu / fuse_brelu_threshold
// attributes, then removes the now-redundant relu6 op and intermediate
// conv output variable. The number of applied fusions is recorded via
// AddStatis for pass statistics.
void ConvBReLUFusePass::ApplyImpl(ir::Graph* graph) const {
  PADDLE_ENFORCE(graph);
  FusePassBase::Init("conv_bounded_relu_mkldnn_fuse", graph);

  GraphPatternDetector gpd;
  auto* conv_input = gpd.mutable_pattern()
                         ->NewNode("conv_bounded_relu_mkldnn_fuse/conv_input")
                         ->AsInput()
                         ->assert_is_op_input("conv2d", "Input");
  patterns::ConvBReLU conv_brelu_pattern(gpd.mutable_pattern(),
                                         "conv_bounded_relu_mkldnn_fuse");
  conv_brelu_pattern(conv_input);

  int found_conv_brelu_count = 0;
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    VLOG(4) << "handle ConvBoundedReLUFusePass fuse";
    // Validate the match BEFORE mutating the graph: the externally
    // supplied conv input node must be part of the matched subgraph.
    // (Previously this check ran after nodes had already been removed.)
    PADDLE_ENFORCE(subgraph.count(conv_input));

    GET_IR_NODE_FROM_SUBGRAPH(conv_weight, conv_weight,
                              conv_brelu_pattern);  // Filter
    GET_IR_NODE_FROM_SUBGRAPH(conv_out, conv_out, conv_brelu_pattern);  // tmp
    GET_IR_NODE_FROM_SUBGRAPH(conv, conv, conv_brelu_pattern);  // CONV op
    GET_IR_NODE_FROM_SUBGRAPH(brelu_out, brelu_out, conv_brelu_pattern);  // Out
    GET_IR_NODE_FROM_SUBGRAPH(brelu, brelu, conv_brelu_pattern);  // ReLU op

    // Transform the Conv node into a Conv+BoundedReLU node: redirect its
    // output past the removed relu6 and carry the relu6 threshold along.
    OpDesc* desc = conv->Op();
    desc->SetOutput("Output", std::vector<std::string>({brelu_out->Name()}));
    desc->SetAttr("fuse_brelu", true);
    desc->SetAttr("fuse_brelu_threshold", brelu->Op()->GetAttr("threshold"));

    // Drop the relu6 op and the intermediate conv output variable.
    GraphSafeRemoveNodes(graph, {brelu, conv_out});

    IR_NODE_LINK_TO(conv, brelu_out);
    found_conv_brelu_count++;
  };

  gpd(graph, handler);

  AddStatis(found_conv_brelu_count);
}
|
||||
|
||||
} // namespace ir
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
||||
|
||||
// Make the pass retrievable by name, e.g. via
// PassRegistry::Instance().Get("conv_brelu_mkldnn_fuse_pass").
REGISTER_PASS(conv_brelu_mkldnn_fuse_pass,
              paddle::framework::ir::ConvBReLUFusePass);
|
@ -0,0 +1,39 @@
|
||||
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
|
||||
#include "paddle/fluid/framework/ir/graph.h"
|
||||
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
|
||||
#include "paddle/fluid/framework/ir/pass.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
namespace ir {
|
||||
|
||||
/*
|
||||
* Fuse the CONV and ReLU6 to a ConvReLU6Op.
|
||||
*/
|
||||
class ConvBReLUFusePass : public FusePassBase {
|
||||
public:
|
||||
virtual ~ConvBReLUFusePass() {}
|
||||
|
||||
protected:
|
||||
void ApplyImpl(ir::Graph* graph) const override;
|
||||
};
|
||||
|
||||
} // namespace ir
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
@ -0,0 +1,135 @@
|
||||
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/fluid/framework/ir/mkldnn/conv_brelu_mkldnn_fuse_pass.h"
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include "paddle/fluid/framework/op_proto_maker.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
namespace ir {
|
||||
|
||||
// Appends an operator of the given `type` named `name` to block 0 of `prog`,
// wiring `inputs` to `outputs`. conv2d and relu6 ops also receive the
// use_mkldnn attribute; an mkldnn relu6 additionally gets threshold = 6.0.
void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
           const std::vector<std::string>& inputs,
           const std::vector<std::string>& outputs, bool use_mkldnn = false) {
  auto* op = prog->MutableBlock(0)->AppendOp();
  op->SetType(type);

  const bool is_conv = (type == "conv2d");
  const bool is_relu6 = (type == "relu6");
  if (is_conv || is_relu6) {
    op->SetAttr("use_mkldnn", use_mkldnn);
  }
  if (is_conv) {
    op->SetAttr("name", name);
    op->SetInput("Input", {inputs[0]});
    op->SetInput("Filter", {inputs[1]});
    op->SetInput("Bias", {inputs[2]});
  } else if (is_relu6) {
    if (use_mkldnn) {
      op->SetAttr("threshold", 6.0f);
    }
    op->SetInput("X", inputs);
  }

  op->SetOutput("Out", outputs);
  op->SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
              static_cast<int>(OpRole::kForward));
}
|
||||
|
||||
// a->OP0->b
|
||||
// b->OP1->c
|
||||
// (c, weights, bias)->conv->f
|
||||
// (f)->brelu->g
|
||||
ProgramDesc BuildProgramDesc() {
|
||||
ProgramDesc prog;
|
||||
for (auto& v :
|
||||
std::vector<std::string>({"a", "b", "c", "weights", "bias", "f", "g",
|
||||
"h", "weights2", "bias2", "k", "l"})) {
|
||||
auto* var = prog.MutableBlock(0)->Var(v);
|
||||
var->SetType(proto::VarType::SELECTED_ROWS);
|
||||
if (v == "weights" || v == "bias") {
|
||||
var->SetPersistable(true);
|
||||
}
|
||||
}
|
||||
|
||||
SetOp(&prog, "OP0", "op0", std::vector<std::string>({"a"}),
|
||||
std::vector<std::string>({"b"}));
|
||||
SetOp(&prog, "OP1", "op1", std::vector<std::string>({"b"}),
|
||||
std::vector<std::string>({"c"}));
|
||||
// conv+brelu, both with MKL-DNN
|
||||
SetOp(&prog, "conv2d", "conv1",
|
||||
std::vector<std::string>({"c", "weights", "bias"}),
|
||||
std::vector<std::string>({"f"}), true);
|
||||
SetOp(&prog, "relu6", "relu1", std::vector<std::string>({"f"}),
|
||||
std::vector<std::string>({"g"}), true);
|
||||
SetOp(&prog, "OP3", "op3", std::vector<std::string>({"g"}),
|
||||
std::vector<std::string>({"h"}));
|
||||
// conv+brelu, only one with MKL-DNN
|
||||
SetOp(&prog, "conv2d", "conv2",
|
||||
std::vector<std::string>({"h", "weights2", "bias2"}),
|
||||
std::vector<std::string>({"k"}), true);
|
||||
SetOp(&prog, "relu6", "relu2", std::vector<std::string>({"k"}),
|
||||
std::vector<std::string>({"l"}));
|
||||
|
||||
return prog;
|
||||
}
|
||||
|
||||
TEST(ConvBReLUFusePass, basic) {
  auto prog = BuildProgramDesc();
  std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
  auto pass = PassRegistry::Instance().Get("conv_brelu_mkldnn_fuse_pass");

  int original_nodes_num = graph->Nodes().size();
  graph.reset(pass->Apply(graph.release()));
  int current_nodes_num = graph->Nodes().size();

  // The fusion rewrites conv1 in place and removes exactly two nodes
  // (the relu6 op and the intermediate conv output variable), so the
  // node count drops by 2.
  EXPECT_EQ(original_nodes_num - 2, current_nodes_num);

  // Assert the fused conv op appears in the rewritten graph.
  int conv_brelu_count = 0;

  for (auto* node : graph->Nodes()) {
    if (!node->IsOp() || node->Op()->Type() != "conv2d") {
      continue;
    }
    auto* op = node->Op();
    ASSERT_TRUE(op->HasAttr("use_mkldnn"));
    EXPECT_TRUE(boost::get<bool>(op->GetAttr("use_mkldnn")));
    // Only the "conv1" convolution should have been fused.
    auto op_name = boost::get<std::string>(op->GetAttr("name"));
    if (op_name == "conv1") {
      ASSERT_TRUE(op->HasAttr("fuse_brelu"));
      ASSERT_TRUE(op->HasAttr("fuse_brelu_threshold"));

      bool fuse_brelu = boost::get<bool>(op->GetAttr("fuse_brelu"));
      if (fuse_brelu) {
        ++conv_brelu_count;
        float fuse_brelu_threshold =
            boost::get<float>(op->GetAttr("fuse_brelu_threshold"));
        EXPECT_EQ(fuse_brelu_threshold, 6.0f);
      }
    } else if (op_name == "conv2") {
      // relu2 was not an mkldnn op, so conv2 must carry no fusion attr.
      ASSERT_FALSE(op->HasAttr("fuse_brelu"));
    }
  }
  EXPECT_EQ(conv_brelu_count, 1);
}
|
||||
|
||||
} // namespace ir
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
||||
|
||||
// Pull the pass registration into this test binary so the registry lookup
// in the TEST above succeeds.
USE_PASS(conv_brelu_mkldnn_fuse_pass);
|
Loading…
Reference in new issue