Merge branch 'develop' of https://github.com/PaddlePaddle/paddle into add-reshape-reuse-input
test=developrelease/1.1
commit
6447b69aec
@ -0,0 +1,137 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/fluid/framework/ir/conv_bias_mkldnn_fuse_pass.h"
|
||||
#include <functional>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "paddle/fluid/framework/lod_tensor.h"
|
||||
#include "paddle/fluid/platform/enforce.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
namespace ir {
|
||||
|
||||
template <typename BinaryOperation>
|
||||
LoDTensor tensor_apply_eltwise(const LoDTensor& vec_a, const LoDTensor& vec_b,
|
||||
BinaryOperation f) {
|
||||
PADDLE_ENFORCE_EQ(vec_a.dims(), vec_b.dims());
|
||||
LoDTensor vec_y;
|
||||
vec_y.Resize(vec_a.dims());
|
||||
const float* a = vec_a.data<float>();
|
||||
const float* b = vec_b.data<float>();
|
||||
float* y = vec_y.mutable_data<float>(platform::CPUPlace());
|
||||
for (int i = 0; i < vec_a.numel(); i++) {
|
||||
y[i] = f(a[i], b[i]);
|
||||
}
|
||||
return vec_y;
|
||||
}
|
||||
|
||||
// Fuses conv2d + elementwise_add(bias) subgraphs into a single conv2d with a
// Bias input, for the MKL-DNN execution path. Bias weights are folded into
// the scope's parameter tensors; the graph is rewired in place.
std::unique_ptr<ir::Graph> ConvBiasFusePass::ApplyImpl(
    std::unique_ptr<ir::Graph> graph) const {
  PADDLE_ENFORCE(graph.get());
  FusePassBase::Init(name_scope_, graph.get());

  // Parameter scope holds the weight/bias tensors mutated below.
  auto* scope = param_scope();
  PADDLE_ENFORCE(scope);

  GraphPatternDetector gpd;
  // Anchor node: the variable feeding a conv2d "Input" slot.
  auto* conv_input =
      gpd.mutable_pattern()
          ->NewNode(patterns::PDNodeName(name_scope_, "conv_input"))
          ->AsInput()
          ->assert_is_op_input("conv2d", "Input");
  patterns::ConvBias conv_bias_pattern(gpd.mutable_pattern(), name_scope_);
  conv_bias_pattern(conv_input);
  int found_conv_bias_count = 0;
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    VLOG(4) << "handle ConvBias fuse";
    GET_IR_NODE_FROM_SUBGRAPH(conv_weight, conv_weight,
                              conv_bias_pattern);                      // Filter
    GET_IR_NODE_FROM_SUBGRAPH(conv_out, conv_out, conv_bias_pattern);  // tmp
    GET_IR_NODE_FROM_SUBGRAPH(conv, conv, conv_bias_pattern);  // CONV op
    // bias
    GET_IR_NODE_FROM_SUBGRAPH(eltwise_bias, eltwise_bias, conv_bias_pattern);
    // output
    GET_IR_NODE_FROM_SUBGRAPH(eltwise_out, eltwise_out, conv_bias_pattern);
    // elementwise_add op
    GET_IR_NODE_FROM_SUBGRAPH(eltwise, eltwise, conv_bias_pattern);

    PADDLE_ENFORCE(subgraph.count(conv_input));

    // check if fuse can be done and if MKL-DNN should be used
    FuseOptions fuse_option = FindFuseOption(*conv, *eltwise);
    if (fuse_option == DO_NOT_FUSE || fuse_option == FUSE_NATIVE) {
      VLOG(3) << "do not perform conv+bias fuse";
      return;
    }

    auto* eltwise_bias_tensor =
        scope->FindVar(eltwise_bias->Name())->GetMutable<LoDTensor>();

    auto input_names = conv->Op()->InputNames();
    bool has_bias = std::find(input_names.begin(), input_names.end(), "Bias") !=
                    input_names.end();
    if (has_bias && conv->Op()->Input("Bias").size() > 0) {
      auto conv_bias_names = conv->Op()->Input("Bias");
      // add eltwise bias to existing conv bias
      PADDLE_ENFORCE_EQ(conv_bias_names.size(), 1);
      auto* conv_bias_var = scope->FindVar(conv_bias_names[0]);
      auto* conv_bias_tensor = conv_bias_var->GetMutable<LoDTensor>();
      PADDLE_ENFORCE_EQ(conv_bias_tensor->dims(), eltwise_bias_tensor->dims());
      // Fold the eltwise bias into the existing conv bias tensor so the
      // conv op alone reproduces conv + add.
      *conv_bias_tensor = tensor_apply_eltwise(
          *conv_bias_tensor, *eltwise_bias_tensor, std::plus<float>());

      // Redirect the conv output to the (former) elementwise_add output so
      // downstream consumers stay connected.
      conv->Op()->SetOutput("Output",
                            std::vector<std::string>({eltwise_out->Name()}));

      GraphSafeRemoveNodes(graph.get(), {eltwise, conv_out});

      IR_NODE_LINK_TO(conv, eltwise_out);
    } else {
      // take eltwise bias as conv bias
      OpDesc desc;

      desc.SetInput(
          "Input", std::vector<std::string>({subgraph.at(conv_input)->Name()}));
      desc.SetInput("Filter", std::vector<std::string>({conv_weight->Name()}));
      desc.SetInput("Bias", std::vector<std::string>({eltwise_bias->Name()}));
      desc.SetOutput("Output", std::vector<std::string>({eltwise_out->Name()}));
      desc.SetType("conv2d");

      // Preserve every attribute of the original conv op on the fused op.
      for (auto& attr : conv->Op()->GetAttrMap()) {
        desc.SetAttr(attr.first, attr.second);
      }
      auto conv_bias_node = g->CreateOpNode(&desc);

      IR_NODE_LINK_TO(subgraph.at(conv_input), conv_bias_node);
      IR_NODE_LINK_TO(conv_weight, conv_bias_node);
      IR_NODE_LINK_TO(eltwise_bias, conv_bias_node);
      IR_NODE_LINK_TO(conv_bias_node, eltwise_out);

      GraphSafeRemoveNodes(graph.get(), {conv, eltwise, conv_out});
    }

    found_conv_bias_count++;
  };
  gpd(graph.get(), handler);
  // Record the number of performed fusions for pass statistics.
  AddStatis(found_conv_bias_count);
  return graph;
}
|
||||
} // namespace ir
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
||||
REGISTER_PASS(conv_bias_mkldnn_fuse_pass,
|
||||
paddle::framework::ir::ConvBiasFusePass);
|
@ -0,0 +1,36 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
#pragma once
|
||||
#include <string>
|
||||
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
|
||||
#include "paddle/fluid/framework/ir/graph.h"
|
||||
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
|
||||
#include "paddle/fluid/framework/ir/pass.h"
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
namespace ir {
|
||||
/*
|
||||
* Fuse the Conv and Elementwise_add to a ConvBiasOp.
|
||||
*/
|
||||
class ConvBiasFusePass : public FusePassBase {
|
||||
public:
|
||||
virtual ~ConvBiasFusePass() {}
|
||||
|
||||
protected:
|
||||
std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const;
|
||||
const std::string name_scope_{"conv_bias_mkldnn_fuse"};
|
||||
};
|
||||
} // namespace ir
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
@ -0,0 +1,154 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.h"
|
||||
#include <functional>
|
||||
#include <utility>
|
||||
|
||||
#include "paddle/fluid/framework/ir/graph_traits.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
namespace ir {
|
||||
namespace {
|
||||
|
||||
// The function keeps the graph consistent by replacing
|
||||
// a node 'from' in the set of inputs nodes
|
||||
// of the visited node by a node 'to'.
|
||||
// The function keeps the graph consistent by replacing
// a node 'from' in the set of inputs nodes
// of the visited node by a node 'to' — both in the IR edge list and in the
// consumer op's OpDesc input variable names.
void CorrectGraphEdges(Graph* graph, Node* from, Node* to) {
  for (auto& node : GraphTraits::DFS(*graph)) {
    auto from_in_inputs =
        std::find(std::begin(node.inputs), std::end(node.inputs), from);

    if (from_in_inputs != std::end(node.inputs)) {
      IR_NODE_LINK_TO(to, (&node));

      auto inputs = node.Op()->Inputs();

      using input_type = VariableNameMap::value_type;

      std::for_each(std::begin(inputs), std::end(inputs),
                    [from, to, &node](const input_type& i) -> void {
                      auto param_names = i.second;
                      auto pi = std::find(std::begin(param_names),
                                          std::end(param_names), from->Name());

                      if (pi != std::end(param_names)) {
                        // FIX: replace only the matching entry. The original
                        // called SetInput(i.first, {to->Name()}), which
                        // discarded every other variable name bound to the
                        // same parameter when the input list had >1 entry.
                        std::replace(std::begin(param_names),
                                     std::end(param_names), from->Name(),
                                     to->Name());
                        node.Op()->SetInput(i.first, param_names);
                      }
                    });
    }
  }
}
|
||||
} // namespace
|
||||
using graph_ptr = std::unique_ptr<ir::Graph>;
|
||||
|
||||
// Fuses conv2d + elementwise_add into a single conv2d with a residual
// connection ("ResidualData" input, fuse_residual_connection attr) for the
// MKL-DNN execution path.
graph_ptr ConvElementwiseAddMKLDNNFusePass::ApplyImpl(graph_ptr graph) const {
  FusePassBase::Init(name_scope_, graph.get());

  GraphPatternDetector gpd;
  auto pattern = gpd.mutable_pattern();

  patterns::Conv conv_pattern{pattern, name_scope_};
  auto conv_output = conv_pattern();

  patterns::ElementwiseAdd elementwise_add_pattern{pattern, name_scope_};
  elementwise_add_pattern(conv_output);

  // The conv output only feeds the elementwise_add, so it is intermediate.
  conv_output->AsIntermediate();

  // Returns (true, bias node) when the conv op carries a non-empty "Bias"
  // input whose variable has a matching input node; (false, nullptr)
  // otherwise.
  auto conv_op_has_bias = [](const Node& conv_op) -> std::pair<bool, Node*> {
    auto bias_input_names = conv_op.Op()->Inputs();
    auto bias_it = bias_input_names.find("Bias");

    if (bias_it != std::end(bias_input_names)) {
      bool has_bias = !bias_it->second.empty();

      if (has_bias) {
        auto conv_bias_names = bias_it->second;
        auto conv_bias_names_it =
            std::find_if(std::begin(conv_op.inputs), std::end(conv_op.inputs),
                         [&conv_bias_names](Node* n) -> bool {
                           return n->Name() == conv_bias_names[0];
                         });
        // FIX: the original dereferenced the find_if result unconditionally;
        // if no input node matched the bias name, that dereferenced the end
        // iterator (undefined behavior). Guard it and fall through instead.
        if (conv_bias_names_it != std::end(conv_op.inputs)) {
          return std::make_pair(has_bias, *conv_bias_names_it);
        }
      }
    }

    return std::make_pair(false, nullptr);
  };

  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    GET_IR_NODE_FROM_SUBGRAPH(conv_op, conv_op, conv_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(conv_input, conv_input, conv_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(conv_filter, conv_filter, conv_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(conv_output, conv_output, conv_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op,
                              elementwise_add_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_x, elementwise_add_x,
                              elementwise_add_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_out, elementwise_add_out,
                              elementwise_add_pattern);

    // Fuse only when both ops are scheduled for MKL-DNN.
    if (FindFuseOption(*conv_op, *elementwise_add_op) != FUSE_MKLDNN) return;

    OpDesc op_desc;
    op_desc.SetType("conv2d");

    op_desc.SetInput("Input", {conv_input->Name()});
    op_desc.SetInput("Filter", {conv_filter->Name()});
    // The elementwise_add's other operand becomes the residual data.
    op_desc.SetInput("ResidualData", {elementwise_add_x->Name()});
    op_desc.SetOutput("Output", {conv_output->Name()});

    bool has_bias;
    Node* conv_bias;

    std::tie(has_bias, conv_bias) = conv_op_has_bias(*conv_op);

    if (has_bias) {
      op_desc.SetInput("Bias", {conv_bias->Name()});
    }

    // Preserve all original conv attributes on the fused op.
    for (const auto& attr : conv_op->Op()->GetAttrMap()) {
      op_desc.SetAttr(attr.first, attr.second);
    }

    op_desc.SetAttr("fuse_residual_connection", true);

    auto fused_conv_op = g->CreateOpNode(&op_desc);

    IR_NODE_LINK_TO(conv_input, fused_conv_op);
    IR_NODE_LINK_TO(conv_filter, fused_conv_op);
    IR_NODE_LINK_TO(elementwise_add_x, fused_conv_op);
    IR_NODE_LINK_TO(fused_conv_op, conv_output);

    if (has_bias) {
      IR_NODE_LINK_TO(conv_bias, fused_conv_op);
    }

    // Re-point consumers of the elementwise_add output at the conv output,
    // then drop the now-dead nodes.
    CorrectGraphEdges(g, elementwise_add_out, conv_output);
    GraphSafeRemoveNodes(g, {elementwise_add_out, conv_op, elementwise_add_op});
  };

  gpd(graph.get(), handler);

  return graph;
}
|
||||
} // namespace ir
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
||||
|
||||
REGISTER_PASS(conv_elementwise_add_mkldnn_fuse_pass,
|
||||
paddle::framework::ir::ConvElementwiseAddMKLDNNFusePass);
|
@ -0,0 +1,38 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
|
||||
#include "paddle/fluid/framework/ir/graph.h"
|
||||
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
namespace ir {
|
||||
|
||||
class ConvElementwiseAddMKLDNNFusePass : public FusePassBase {
|
||||
public:
|
||||
virtual ~ConvElementwiseAddMKLDNNFusePass() {}
|
||||
|
||||
protected:
|
||||
std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const;
|
||||
|
||||
const std::string name_scope_{"residual_connections_fuse_pass"};
|
||||
};
|
||||
|
||||
} // namespace ir
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
@ -0,0 +1,247 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <string>
|
||||
|
||||
#include "paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.h"
|
||||
#include "paddle/fluid/framework/ir/graph_traits.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
namespace ir {
|
||||
|
||||
namespace {
|
||||
constexpr int nodes_removed = 3;
|
||||
constexpr int nodes_added = 1;
|
||||
|
||||
// Appends an operator of the given type to block 0 of the program, wires
// up its inputs and single output, and marks it for MKL-DNN execution.
void SetOp(ProgramDesc* prog, const std::string& type,
           const std::vector<std::pair<std::string, std::string>>& inputs,
           const std::pair<std::string, std::string>& output) {
  auto* op_desc = prog->MutableBlock(0)->AppendOp();

  op_desc->SetType(type);
  op_desc->SetAttr("use_mkldnn", true);

  for (const auto& kv : inputs) {
    op_desc->SetInput(kv.first, {kv.second});
  }
  op_desc->SetOutput(output.first, {output.second});
}
|
||||
|
||||
struct IsReachable {
|
||||
using func = std::function<bool(const std::string&, const std::string&)>;
|
||||
|
||||
auto operator()(const std::unique_ptr<ir::Graph>& graph) -> func {
|
||||
auto find_node = [](const std::unique_ptr<ir::Graph>& graph,
|
||||
const std::string& name) -> Node* {
|
||||
for (auto& node : GraphTraits::DFS(*graph)) {
|
||||
if (name == node.Name()) {
|
||||
return &node;
|
||||
}
|
||||
}
|
||||
|
||||
return nullptr;
|
||||
};
|
||||
|
||||
return [&](std::string from, const std::string to) -> bool {
|
||||
if (from == to) return true;
|
||||
|
||||
std::map<std::string, bool> visited;
|
||||
|
||||
for (auto& node : GraphTraits::DFS(*graph)) {
|
||||
visited[node.Name()] = false;
|
||||
}
|
||||
|
||||
visited[from] = true;
|
||||
|
||||
std::list<std::string> queue;
|
||||
queue.push_back(from);
|
||||
|
||||
while (!queue.empty()) {
|
||||
auto cur = find_node(graph, queue.front());
|
||||
queue.pop_front();
|
||||
|
||||
if (cur == nullptr) return false;
|
||||
|
||||
for (auto n : cur->outputs) {
|
||||
if (n->Name() == to) return true;
|
||||
|
||||
if (!visited[n->Name()]) {
|
||||
visited[n->Name()] = true;
|
||||
queue.push_back(n->Name());
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
// After a successful fuse the graph must hold exactly one conv2d op and no
// elementwise_add ops.
void AssertOpsCount(const std::unique_ptr<ir::Graph>& graph) {
  int convs = 0;
  int elementwise_adds = 0;

  for (auto* node : graph->Nodes()) {
    if (!node->IsOp()) continue;

    const auto& op_type = node->Op()->Type();
    if (op_type == "conv2d") {
      ++convs;
    } else if (op_type == "elementwise_add") {
      ++elementwise_adds;
    }
  }
  EXPECT_EQ(convs, 1);
  EXPECT_EQ(elementwise_adds, 0);
}
|
||||
|
||||
// Builds a ProgramDesc whose block 0 declares the given LoDTensor variables;
// names in 'persistent_vars' are additionally marked persistable (i.e. they
// behave like parameters).
ProgramDesc BuildProgramDesc(const std::vector<std::string>& transient_vars,
                             const std::vector<std::string>& persistent_vars) {
  ProgramDesc prog;

  // Declares one LOD_TENSOR variable, optionally persistable.
  auto declare_var = [&prog](const std::string& name, bool persistable) {
    auto* var = prog.MutableBlock(0)->Var(name);
    var->SetType(proto::VarType::LOD_TENSOR);
    if (persistable) {
      var->SetPersistable(true);
    }
  };

  for (const auto& name : transient_vars) {
    declare_var(name, false);
  }

  for (const auto& name : persistent_vars) {
    declare_var(name, true);
  }

  return prog;
}
|
||||
} // namespace
|
||||
|
||||
// conv2d(with bias) -> elementwise_add -> relu: the conv+add pair must fuse
// while "a" stays connected to the relu op.
TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionWithElementwiseAddRelu) {
  auto program =
      BuildProgramDesc({"a", "b", "c", "d", "e", "f"}, {"bias", "weights"});

  SetOp(&program, "conv2d",
        {{"Input", "a"}, {"Bias", "bias"}, {"Filter", "weights"}},
        {"Output", "b"});
  SetOp(&program, "elementwise_add", {{"X", "b"}, {"Y", "c"}}, {"Out", "d"});
  SetOp(&program, "relu", {{"X", "d"}}, {"Out", "e"});

  std::unique_ptr<ir::Graph> graph(new ir::Graph(program));

  IsReachable is_reachable;
  EXPECT_TRUE(is_reachable(graph)("a", "relu"));

  auto pass =
      PassRegistry::Instance().Get("conv_elementwise_add_mkldnn_fuse_pass");
  int nodes_before = graph->Nodes().size();
  graph = pass->Apply(std::move(graph));
  int nodes_after = graph->Nodes().size();

  // Connectivity must survive the fuse.
  EXPECT_TRUE(is_reachable(graph)("a", "relu"));

  EXPECT_EQ(nodes_before - nodes_removed + nodes_added, nodes_after);

  AssertOpsCount(graph);
}
|
||||
|
||||
// Same chain as above but the conv2d has no Bias input — the fuse must still
// happen and must not require a bias.
TEST(ConvElementwiseAddMKLDNNFusePass,
     ConvolutionWithElementwiseAddReluNoBias) {
  auto program = BuildProgramDesc({"a", "b", "c", "d", "e"}, {"weights"});

  SetOp(&program, "conv2d", {{"Input", "a"}, {"Filter", "weights"}},
        {"Output", "b"});
  SetOp(&program, "elementwise_add", {{"X", "b"}, {"Y", "c"}}, {"Out", "d"});
  SetOp(&program, "relu", {{"X", "d"}}, {"Out", "e"});

  std::unique_ptr<ir::Graph> graph(new ir::Graph(program));

  IsReachable is_reachable;
  EXPECT_TRUE(is_reachable(graph)("a", "relu"));

  auto pass =
      PassRegistry::Instance().Get("conv_elementwise_add_mkldnn_fuse_pass");
  int nodes_before = graph->Nodes().size();
  graph = pass->Apply(std::move(graph));
  int nodes_after = graph->Nodes().size();

  // Connectivity must survive the fuse.
  EXPECT_TRUE(is_reachable(graph)("a", "relu"));

  EXPECT_EQ(nodes_before - nodes_removed + nodes_added, nodes_after);

  AssertOpsCount(graph);
}
|
||||
|
||||
// conv2d -> elementwise_add with no trailing consumer: after the fuse the
// former add-output variable "d" is removed, so "a" must no longer reach it.
TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionElementwiseAdd) {
  auto program = BuildProgramDesc({"a", "b", "c", "d"}, {"bias", "weights"});

  SetOp(&program, "conv2d",
        {{"Input", "a"}, {"Bias", "bias"}, {"Filter", "weights"}},
        {"Output", "b"});
  SetOp(&program, "elementwise_add", {{"X", "b"}, {"Y", "c"}}, {"Out", "d"});

  std::unique_ptr<ir::Graph> graph(new ir::Graph(program));

  IsReachable is_reachable;
  EXPECT_TRUE(is_reachable(graph)("a", "d"));

  auto pass =
      PassRegistry::Instance().Get("conv_elementwise_add_mkldnn_fuse_pass");
  int nodes_before = graph->Nodes().size();
  graph = pass->Apply(std::move(graph));
  int nodes_after = graph->Nodes().size();

  // "d" was consumed by the fuse, so it must be gone.
  EXPECT_FALSE(is_reachable(graph)("a", "d"));

  EXPECT_EQ(nodes_before - nodes_removed + nodes_added, nodes_after);
  AssertOpsCount(graph);
}
|
||||
|
||||
// sigmoid -> conv2d -> elementwise_add -> relu: a producer before the conv
// must not prevent the fuse, and end-to-end connectivity must be preserved.
TEST(ConvElementwiseAddMKLDNNFusePass, SigmoidConvolutionAddElementwiseRelu) {
  auto program =
      BuildProgramDesc({"a", "b", "c", "d", "e", "f"}, {"bias", "weights"});

  SetOp(&program, "sigmoid", {{"X", "a"}}, {"Out", "b"});
  SetOp(&program, "conv2d",
        {{"Input", "b"}, {"Bias", "bias"}, {"Filter", "weights"}},
        {"Output", "c"});
  SetOp(&program, "elementwise_add", {{"X", "c"}, {"Y", "d"}}, {"Out", "e"});
  SetOp(&program, "relu", {{"X", "e"}}, {"Out", "f"});

  std::unique_ptr<ir::Graph> graph(new ir::Graph(program));

  IsReachable is_reachable;
  EXPECT_TRUE(is_reachable(graph)("a", "f"));

  auto pass =
      PassRegistry::Instance().Get("conv_elementwise_add_mkldnn_fuse_pass");
  int nodes_before = graph->Nodes().size();
  graph = pass->Apply(std::move(graph));
  int nodes_after = graph->Nodes().size();

  // Connectivity must survive the fuse.
  EXPECT_TRUE(is_reachable(graph)("a", "f"));

  EXPECT_EQ(nodes_before - nodes_removed + nodes_added, nodes_after);
  AssertOpsCount(graph);
}
|
||||
|
||||
} // namespace ir
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
||||
|
||||
USE_PASS(conv_elementwise_add_mkldnn_fuse_pass);
|
@ -0,0 +1,101 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/fluid/framework/ir/seqconv_eltadd_relu_fuse_pass.h"
|
||||
#include <string>
|
||||
#include "paddle/fluid/framework/lod_tensor.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
namespace ir {
|
||||
|
||||
// Detects sequence_conv -> elementwise_add -> relu chains starting from a
// non-persistable variable and replaces each with a single
// fusion_seqconv_eltadd_relu op. Returns the number of fused subgraphs.
int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope) {
  GraphPatternDetector gpd;
  auto* pattern = gpd.mutable_pattern();

  // Anchor: a non-persistable variable feeding a sequence_conv op.
  PDNode* x = pattern->NewNode(patterns::PDNodeName(name_scope, "X"))
                  ->assert_is_op_input("sequence_conv")
                  ->assert_var_not_persistable();
  patterns::SeqConvEltAddRelu fuse_pattern(pattern, name_scope);
  fuse_pattern(x);

  // Create New OpDesc
  auto fuse_creator = [&](Node* seqconv, Node* input, Node* seqconv_weight,
                          Node* eltadd_bias, Node* relu_out) {
    OpDesc op_desc;
    op_desc.SetType("fusion_seqconv_eltadd_relu");
    op_desc.SetInput("X", {input->Name()});
    op_desc.SetInput("Filter", {seqconv_weight->Name()});
    op_desc.SetInput("Bias", {eltadd_bias->Name()});
    // The fused op inherits the original seqconv's context attributes.
    op_desc.SetAttr("contextLength", seqconv->Op()->GetAttr("contextLength"));
    op_desc.SetAttr("contextStart", seqconv->Op()->GetAttr("contextStart"));
    op_desc.SetAttr("contextStride", seqconv->Op()->GetAttr("contextStride"));
    PADDLE_ENFORCE(graph->Has(kParamScopeAttr));
    // NOTE(review): this local re-fetches the scope from the graph and
    // shadows BuildFusion's 'scope' parameter — presumably both refer to the
    // same scope; confirm, or use the parameter directly.
    auto* scope = graph->Get<Scope*>(kParamScopeAttr);
    // Scratch variable for the fused kernel's intermediate column matrix;
    // UniqueKey avoids name collisions across multiple fusions.
    const std::string ColMat = patterns::UniqueKey("SeqConvColMat");
    op_desc.SetOutput("ColMat", {ColMat});
    op_desc.SetOutput("Out", {relu_out->Name()});
    scope->Var(ColMat)->GetMutable<LoDTensor>();

    auto* op = graph->CreateOpNode(&op_desc);
    IR_NODE_LINK_TO(input, op);
    IR_NODE_LINK_TO(seqconv_weight, op);
    IR_NODE_LINK_TO(eltadd_bias, op);
    IR_NODE_LINK_TO(op, relu_out);
    return op;
  };

  int fusion_count{0};

  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    VLOG(4) << "handle SeqConv EltAdd Relu fuse";
    GET_IR_NODE_FROM_SUBGRAPH(seqconv, seqconv, fuse_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(seqconv_weight, seqconv_weight, fuse_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(seqconv_out, seqconv_out, fuse_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(eltadd, eltadd, fuse_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(eltadd_bias, eltadd_bias, fuse_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(eltadd_out, eltadd_out, fuse_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(relu, relu, fuse_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(relu_out, relu_out, fuse_pattern);

    fuse_creator(seqconv, subgraph.at(x), seqconv_weight, eltadd_bias,
                 relu_out);
    // Drop the original ops and their intermediate outputs; the inputs and
    // the final relu output are kept and re-linked to the fused op above.
    std::unordered_set<const Node*> marked_nodes(
        {seqconv, seqconv_out, eltadd, eltadd_out, relu});
    GraphSafeRemoveNodes(graph, marked_nodes);
    ++fusion_count;
  };

  gpd(graph, handler);

  return fusion_count;
}
|
||||
|
||||
std::unique_ptr<ir::Graph> SeqConvEltAddReluFusePass::ApplyImpl(
|
||||
std::unique_ptr<ir::Graph> graph) const {
|
||||
FusePassBase::Init(name_scope_, graph.get());
|
||||
|
||||
int fusion_count = BuildFusion(graph.get(), name_scope_, param_scope());
|
||||
AddStatis(fusion_count);
|
||||
|
||||
return graph;
|
||||
}
|
||||
|
||||
} // namespace ir
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
||||
|
||||
REGISTER_PASS(seqconv_eltadd_relu_fuse_pass,
|
||||
paddle::framework::ir::SeqConvEltAddReluFusePass);
|
@ -0,0 +1,38 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
|
||||
#include "paddle/fluid/framework/ir/graph.h"
|
||||
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
namespace ir {
|
||||
|
||||
class SeqConvEltAddReluFusePass : public FusePassBase {
|
||||
public:
|
||||
virtual ~SeqConvEltAddReluFusePass() {}
|
||||
|
||||
protected:
|
||||
std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const;
|
||||
|
||||
const std::string name_scope_{"seqconv_eltadd_relu_fuse"};
|
||||
};
|
||||
|
||||
} // namespace ir
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in new issue