|
|
|
@ -1,4 +1,20 @@
|
|
|
|
|
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
|
|
|
|
//
|
|
|
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
|
// you may not use this file except in compliance with the License.
|
|
|
|
|
// You may obtain a copy of the License at
|
|
|
|
|
//
|
|
|
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
|
//
|
|
|
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
|
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
|
// See the License for the specific language governing permissions and
|
|
|
|
|
// limitations under the License.
|
|
|
|
|
|
|
|
|
|
#include "paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.h"
|
|
|
|
|
#include <functional>
|
|
|
|
|
|
|
|
|
|
#include "paddle/fluid/framework/ir/graph_traits.h"
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
@ -8,8 +24,7 @@ namespace patterns {
|
|
|
|
|
|
|
|
|
|
struct Pattern : public PatternBase {
|
|
|
|
|
Pattern(PDPattern* pattern, const std::string& name_scope)
|
|
|
|
|
: PatternBase{pattern, name_scope, ""}
|
|
|
|
|
{ }
|
|
|
|
|
: PatternBase{pattern, name_scope, ""} {}
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
std::string name_scope() { return name_scope_; }
|
|
|
|
@ -39,20 +54,16 @@ struct Conv {
|
|
|
|
|
|
|
|
|
|
std::function<PDNode*()> operator()(std::shared_ptr<Pattern> pattern) {
|
|
|
|
|
return [&]() -> PDNode* {
|
|
|
|
|
auto conv_op = pattern->new_node(op_name())
|
|
|
|
|
->assert_is_op("conv2d");
|
|
|
|
|
auto conv_op = pattern->new_node(op_name())->assert_is_op("conv2d");
|
|
|
|
|
|
|
|
|
|
auto input_var = pattern->new_node(input_name())
|
|
|
|
|
->assert_is_op_input(op_name(),
|
|
|
|
|
input_name());
|
|
|
|
|
->assert_is_op_input(op_name(), input_name());
|
|
|
|
|
|
|
|
|
|
auto filter_var = pattern->new_node(filter_name())
|
|
|
|
|
->assert_is_op_input(op_name(),
|
|
|
|
|
filter_name());
|
|
|
|
|
->assert_is_op_input(op_name(), filter_name());
|
|
|
|
|
|
|
|
|
|
auto output_var = pattern->new_node(output_name())
|
|
|
|
|
->assert_is_op_output(op_name(),
|
|
|
|
|
output_name());
|
|
|
|
|
->assert_is_op_output(op_name(), output_name());
|
|
|
|
|
|
|
|
|
|
conv_op->LinksFrom({input_var, filter_var});
|
|
|
|
|
conv_op->LinksTo({output_var});
|
|
|
|
@ -70,20 +81,17 @@ struct ElementwiseAdd {
|
|
|
|
|
|
|
|
|
|
std::function<PDNode*(PDNode*)> operator()(std::shared_ptr<Pattern> pattern) {
|
|
|
|
|
return [&](PDNode* conv_output) -> PDNode* {
|
|
|
|
|
auto elementwise_add_op = pattern->new_node(op_name())
|
|
|
|
|
->assert_is_op("elementwise_add");
|
|
|
|
|
auto elementwise_add_op =
|
|
|
|
|
pattern->new_node(op_name())->assert_is_op("elementwise_add");
|
|
|
|
|
|
|
|
|
|
auto x_var = pattern->new_node(x_name())
|
|
|
|
|
->assert_is_op_input(op_name(),
|
|
|
|
|
x_name());
|
|
|
|
|
auto x_var =
|
|
|
|
|
pattern->new_node(x_name())->assert_is_op_input(op_name(), x_name());
|
|
|
|
|
|
|
|
|
|
conv_output->assert_is_op_input(op_name(),
|
|
|
|
|
y_name());
|
|
|
|
|
conv_output->assert_is_op_input(op_name(), y_name());
|
|
|
|
|
|
|
|
|
|
auto out_var = pattern->new_node(out_name())
|
|
|
|
|
->AsOutput()
|
|
|
|
|
->assert_is_op_output(op_name(),
|
|
|
|
|
out_name());
|
|
|
|
|
->assert_is_op_output(op_name(), out_name());
|
|
|
|
|
|
|
|
|
|
elementwise_add_op->LinksFrom({x_var, conv_output});
|
|
|
|
|
elementwise_add_op->LinksTo({out_var});
|
|
|
|
@ -111,8 +119,7 @@ void LinkNodes(Node* from, Node* to) {
|
|
|
|
|
|
|
|
|
|
template <typename IT, typename FindFunc, typename ReplaceFunc>
|
|
|
|
|
void ReplaceAllOccurances(IT s, IT e, FindFunc f, ReplaceFunc r) {
|
|
|
|
|
if (s == e)
|
|
|
|
|
return;
|
|
|
|
|
if (s == e) return;
|
|
|
|
|
|
|
|
|
|
auto it = std::find_if(s, e, f);
|
|
|
|
|
|
|
|
|
@ -126,8 +133,7 @@ void ReplaceAllOccurances(IT s, IT e, FindFunc f, ReplaceFunc r) {
|
|
|
|
|
|
|
|
|
|
void CorrectGraphEdges(Graph* graph, Node* from, Node* to) {
|
|
|
|
|
for (auto& node : GraphTraits::DFS(*graph)) {
|
|
|
|
|
auto same = std::find_if(std::begin(node.inputs),
|
|
|
|
|
std::end(node.inputs),
|
|
|
|
|
auto same = std::find_if(std::begin(node.inputs), std::end(node.inputs),
|
|
|
|
|
[from](Node* n) { return n == from; });
|
|
|
|
|
|
|
|
|
|
if (same != std::end(node.inputs)) {
|
|
|
|
@ -137,10 +143,12 @@ void CorrectGraphEdges(Graph* graph, Node* from, Node* to) {
|
|
|
|
|
|
|
|
|
|
using input_type = VariableNameMap::value_type;
|
|
|
|
|
|
|
|
|
|
ReplaceAllOccurances(std::begin(inputs), std::end(inputs),
|
|
|
|
|
ReplaceAllOccurances(
|
|
|
|
|
std::begin(inputs), std::end(inputs),
|
|
|
|
|
[from](const input_type& i) -> bool {
|
|
|
|
|
auto params = i.second;
|
|
|
|
|
auto pi = std::find_if(std::begin(params), std::end(params),
|
|
|
|
|
auto pi =
|
|
|
|
|
std::find_if(std::begin(params), std::end(params),
|
|
|
|
|
std::bind(std::equal_to<std::string>(),
|
|
|
|
|
from->Name(), std::placeholders::_1));
|
|
|
|
|
return pi != std::end(params);
|
|
|
|
@ -169,7 +177,8 @@ graph_ptr ConvElementwiseAddMKLDNNFusePass::ApplyImpl(graph_ptr graph) const {
|
|
|
|
|
|
|
|
|
|
conv_output->AsIntermediate();
|
|
|
|
|
|
|
|
|
|
auto fuse_conv = [](Graph* g, Node* conv_input, Node* conv_filter, Node* conv_output, Node* elementwise_add_x) {
|
|
|
|
|
auto fuse_conv = [](Graph* g, Node* conv_input, Node* conv_filter,
|
|
|
|
|
Node* conv_output, Node* elementwise_add_x) {
|
|
|
|
|
OpDesc op_desc;
|
|
|
|
|
op_desc.SetType("conv2d");
|
|
|
|
|
|
|
|
|
@ -189,22 +198,23 @@ graph_ptr ConvElementwiseAddMKLDNNFusePass::ApplyImpl(graph_ptr graph) const {
|
|
|
|
|
patterns::LinkNodes(fused_conv_op, conv_output);
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* g) {
|
|
|
|
|
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
|
|
|
|
|
Graph* g) {
|
|
|
|
|
auto conv_op = patterns::GetNodeFromSubgraph(subgraph, pattern_ptr,
|
|
|
|
|
conv_pattern.op_name());
|
|
|
|
|
auto conv_input = patterns::GetNodeFromSubgraph(subgraph, pattern_ptr,
|
|
|
|
|
conv_pattern.input_name());
|
|
|
|
|
auto conv_filter = patterns::GetNodeFromSubgraph(subgraph, pattern_ptr,
|
|
|
|
|
conv_pattern.filter_name());
|
|
|
|
|
auto conv_output = patterns::GetNodeFromSubgraph(subgraph, pattern_ptr,
|
|
|
|
|
conv_pattern.output_name());
|
|
|
|
|
|
|
|
|
|
auto elementwise_add_op = patterns::GetNodeFromSubgraph(subgraph, pattern_ptr,
|
|
|
|
|
elementwise_add_pattern.op_name());
|
|
|
|
|
auto elementwise_add_x = patterns::GetNodeFromSubgraph(subgraph, pattern_ptr,
|
|
|
|
|
elementwise_add_pattern.x_name());
|
|
|
|
|
auto elementwise_add_out = patterns::GetNodeFromSubgraph(subgraph, pattern_ptr,
|
|
|
|
|
elementwise_add_pattern.out_name());
|
|
|
|
|
auto conv_filter = patterns::GetNodeFromSubgraph(
|
|
|
|
|
subgraph, pattern_ptr, conv_pattern.filter_name());
|
|
|
|
|
auto conv_output = patterns::GetNodeFromSubgraph(
|
|
|
|
|
subgraph, pattern_ptr, conv_pattern.output_name());
|
|
|
|
|
|
|
|
|
|
auto elementwise_add_op = patterns::GetNodeFromSubgraph(
|
|
|
|
|
subgraph, pattern_ptr, elementwise_add_pattern.op_name());
|
|
|
|
|
auto elementwise_add_x = patterns::GetNodeFromSubgraph(
|
|
|
|
|
subgraph, pattern_ptr, elementwise_add_pattern.x_name());
|
|
|
|
|
auto elementwise_add_out = patterns::GetNodeFromSubgraph(
|
|
|
|
|
subgraph, pattern_ptr, elementwise_add_pattern.out_name());
|
|
|
|
|
|
|
|
|
|
fuse_conv(g, conv_input, conv_filter, conv_output, elementwise_add_x);
|
|
|
|
|
patterns::CorrectGraphEdges(g, elementwise_add_out, conv_output);
|
|
|
|
@ -219,4 +229,5 @@ graph_ptr ConvElementwiseAddMKLDNNFusePass::ApplyImpl(graph_ptr graph) const {
|
|
|
|
|
} // namespace framework
|
|
|
|
|
} // namespace paddle
|
|
|
|
|
|
|
|
|
|
REGISTER_PASS(conv_elementwise_add_mkldnn_fuse_pass, paddle::framework::ir::ConvElementwiseAddMKLDNNFusePass);
|
|
|
|
|
REGISTER_PASS(conv_elementwise_add_mkldnn_fuse_pass,
|
|
|
|
|
paddle::framework::ir::ConvElementwiseAddMKLDNNFusePass);
|
|
|
|
|