commit 329a8c5283
paddle/fluid/framework/ir/conv_relu_mkldnn_fuse_pass.cc
@@ -0,0 +1,90 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/ir/conv_relu_mkldnn_fuse_pass.h"
#include <string>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/platform/enforce.h"

namespace paddle {
namespace framework {
namespace ir {

std::unique_ptr<ir::Graph> ConvReLUFusePass::ApplyImpl(
    std::unique_ptr<ir::Graph> graph) const {
  PADDLE_ENFORCE(graph.get());
  FusePassBase::Init("conv_relu_mkldnn_fuse", graph.get());

  std::unordered_set<Node*> nodes2delete;

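  // Describe the subgraph to match: a conv2d op reading the external
  // conv_input tensor, whose output feeds a relu op (patterns::ConvReLU).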
  GraphPatternDetector gpd;
  auto* conv_input = gpd.mutable_pattern()
                         ->NewNode("conv_relu_mkldnn_fuse/conv_input")
                         ->AsInput()
                         ->assert_is_op_input("conv2d", "Input");
  patterns::ConvReLU conv_relu_pattern(gpd.mutable_pattern(),
                                       "conv_relu_mkldnn_fuse");
  conv_relu_pattern(conv_input);

  int found_conv_relu_count = 0;
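  // For each match, rebuild the conv2d op with fuse_relu = true, rewire it to
  // write straight to relu_out, and drop the now-redundant nodes.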
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    VLOG(4) << "handle ConvReLU fuse";
    GET_IR_NODE_FROM_SUBGRAPH(conv_weight, conv_weight,
                              conv_relu_pattern);                    // Filter
    GET_IR_NODE_FROM_SUBGRAPH(conv_bias, conv_bias, conv_relu_pattern);  // Bias
    GET_IR_NODE_FROM_SUBGRAPH(conv_out, conv_out, conv_relu_pattern);    // tmp
    GET_IR_NODE_FROM_SUBGRAPH(conv, conv, conv_relu_pattern);        // CONV op
    GET_IR_NODE_FROM_SUBGRAPH(relu_out, relu_out, conv_relu_pattern);    // Out
    GET_IR_NODE_FROM_SUBGRAPH(relu, relu, conv_relu_pattern);        // ReLU op

    // Create a ConvReLU Node.
    OpDesc desc;
    std::string conv_relu_i_in = subgraph.at(conv_input)->Name();
    std::string conv_relu_w_in = conv_weight->Name();
    std::string conv_relu_b_in = conv_bias->Name();
    std::string conv_relu_out = relu_out->Name();
    desc.SetInput("Input", std::vector<std::string>({conv_relu_i_in}));
    desc.SetInput("Filter", std::vector<std::string>({conv_relu_w_in}));
    desc.SetInput("Bias", std::vector<std::string>({conv_relu_b_in}));
    desc.SetOutput("Out", std::vector<std::string>({conv_relu_out}));
    desc.SetType("conv2d");
    for (auto& attr : conv->Op()->GetAttrMap()) {
      desc.SetAttr(attr.first, attr.second);
    }
    desc.SetAttr("fuse_relu", true);
    auto conv_relu_node = g->CreateOpNode(&desc);  // OpDesc will be copied.
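    // The fused node replaces the matched pair, so delete the original conv2d
    // and relu ops together with the intermediate conv output tensor.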
    GraphSafeRemoveNodes(graph.get(), {conv, relu, conv_out});

    PADDLE_ENFORCE(subgraph.count(conv_input));
    IR_NODE_LINK_TO(subgraph.at(conv_input), conv_relu_node);
    IR_NODE_LINK_TO(conv_weight, conv_relu_node);
    IR_NODE_LINK_TO(conv_bias, conv_relu_node);
    IR_NODE_LINK_TO(conv_relu_node, relu_out);

    found_conv_relu_count++;
  };

  gpd(graph.get(), handler);

  AddStatis(found_conv_relu_count);
  return graph;
}

}  // namespace ir
}  // namespace framework
}  // namespace paddle

REGISTER_PASS(conv_relu_mkldnn_fuse_pass,
              paddle::framework::ir::ConvReLUFusePass);
paddle/fluid/framework/ir/conv_relu_mkldnn_fuse_pass.h
@@ -0,0 +1,39 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <memory>

#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h"

namespace paddle {
namespace framework {
namespace ir {

/*
 * Fuse the CONV and ReLU to a ConvReLUOp.
 */
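//
// A typical way to exercise this pass (mirroring the accompanying unit test)
// is to fetch the registered instance from the PassRegistry and apply it:
//
//   auto pass = PassRegistry::Instance().Get("conv_relu_mkldnn_fuse_pass");
//   graph = pass->Apply(std::move(graph));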
class ConvReLUFusePass : public FusePassBase {
 public:
  virtual ~ConvReLUFusePass() {}

 protected:
  std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const;
};

}  // namespace ir
}  // namespace framework
}  // namespace paddle
@@ -0,0 +1,108 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/ir/conv_relu_mkldnn_fuse_pass.h"

#include <gtest/gtest.h>

namespace paddle {
namespace framework {
namespace ir {

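// Test helper: append an op of the given type to block 0. conv2d ops get
// Input/Filter/Bias inputs and use_mkldnn = true; relu ops get a single X
// input.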
void SetOp(ProgramDesc* prog, const std::string& type,
           const std::vector<std::string>& inputs,
           const std::vector<std::string>& outputs) {
  auto* op = prog->MutableBlock(0)->AppendOp();
  op->SetType(type);
  if (type == "conv2d") {
    op->SetAttr("use_mkldnn", true);
    op->SetInput("Input", {inputs[0]});
    op->SetInput("Filter", {inputs[1]});
    op->SetInput("Bias", {inputs[2]});
  } else if (type == "relu") {
    op->SetInput("X", inputs);
  }
  op->SetOutput("Out", outputs);
}

// a->OP0->b
// b->OP1->c
// (c, weights, bias)->conv->f
// (f)->relu->g
ProgramDesc BuildProgramDesc() {
  ProgramDesc prog;
  for (auto& v :
       std::vector<std::string>({"a", "b", "c", "weights", "bias", "f", "g"})) {
    auto* var = prog.MutableBlock(0)->Var(v);
    var->SetType(proto::VarType::SELECTED_ROWS);
    if (v == "weights" || v == "bias") {
      var->SetPersistable(true);
    }
  }

  SetOp(&prog, "OP0", std::vector<std::string>({"a"}),
        std::vector<std::string>({"b"}));
  SetOp(&prog, "OP1", std::vector<std::string>({"b"}),
        std::vector<std::string>({"c"}));
  SetOp(&prog, "conv2d", std::vector<std::string>({"c", "weights", "bias"}),
        std::vector<std::string>({"f"}));
  SetOp(&prog, "relu", std::vector<std::string>({"f"}),
        std::vector<std::string>({"g"}));

  return prog;
}

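// Applies the registered pass to the graph built above and checks that
// (a) the total node count drops by two (the conv2d op, the relu op, and the
// intermediate tensor f are removed, one fused conv2d is added), and (b)
// exactly one conv2d op ends up with both use_mkldnn and fuse_relu set.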
TEST(ConvReLUFusePass, basic) {
  auto prog = BuildProgramDesc();

  std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));

  auto pass = PassRegistry::Instance().Get("conv_relu_mkldnn_fuse_pass");

  int original_nodes_num = graph->Nodes().size();

  graph = pass->Apply(std::move(graph));

  int current_nodes_num = graph->Nodes().size();

  // Remove 3 Nodes: CONV, RELU, conv_out
  // Add 1 Node: ConvReLU
  EXPECT_EQ(original_nodes_num - 2, current_nodes_num);

  // Assert conv_relu op in newly generated graph
  int conv_relu_count = 0;

  for (auto* node : graph->Nodes()) {
    if (node->IsOp() && node->Op()->Type() == "conv2d") {
      if (node->Op()->HasAttr("use_mkldnn")) {
        bool use_mkldnn = boost::get<bool>(node->Op()->GetAttr("use_mkldnn"));
        if (use_mkldnn) {
          if (node->Op()->HasAttr("fuse_relu")) {
            bool fuse_relu = boost::get<bool>(node->Op()->GetAttr("fuse_relu"));
            if (fuse_relu) {
              ++conv_relu_count;
            }
          }
        }
      }
    }
  }
  EXPECT_EQ(conv_relu_count, 1);
}

}  // namespace ir
}  // namespace framework
}  // namespace paddle

USE_PASS(conv_relu_mkldnn_fuse_pass);