Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into add_dropout_att_new
commit
049c9c7d2a
@ -0,0 +1,137 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/fluid/framework/ir/conv_bias_mkldnn_fuse_pass.h"
|
||||
#include <functional>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "paddle/fluid/framework/lod_tensor.h"
|
||||
#include "paddle/fluid/platform/enforce.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
namespace ir {
|
||||
|
||||
template <typename BinaryOperation>
|
||||
LoDTensor tensor_apply_eltwise(const LoDTensor& vec_a, const LoDTensor& vec_b,
|
||||
BinaryOperation f) {
|
||||
PADDLE_ENFORCE_EQ(vec_a.dims(), vec_b.dims());
|
||||
LoDTensor vec_y;
|
||||
vec_y.Resize(vec_a.dims());
|
||||
const float* a = vec_a.data<float>();
|
||||
const float* b = vec_b.data<float>();
|
||||
float* y = vec_y.mutable_data<float>(platform::CPUPlace());
|
||||
for (int i = 0; i < vec_a.numel(); i++) {
|
||||
y[i] = f(a[i], b[i]);
|
||||
}
|
||||
return vec_y;
|
||||
}
|
||||
|
||||
// Folds an elementwise_add following a conv2d into the conv itself.
// Two cases:
//  - conv already has a Bias input: sum the eltwise bias into it and rewire
//    the conv output to the eltwise output var;
//  - conv has no Bias: create a replacement conv2d op taking the eltwise
//    bias as its Bias input.
// Fires only when both ops are selected for MKL-DNN (see FindFuseOption).
std::unique_ptr<ir::Graph> ConvBiasFusePass::ApplyImpl(
    std::unique_ptr<ir::Graph> graph) const {
  PADDLE_ENFORCE(graph.get());
  FusePassBase::Init(name_scope_, graph.get());

  // Scope holding the parameter tensors (conv bias / eltwise bias).
  auto* scope = param_scope();
  PADDLE_ENFORCE(scope);

  GraphPatternDetector gpd;
  auto* conv_input =
      gpd.mutable_pattern()
          ->NewNode(patterns::PDNodeName(name_scope_, "conv_input"))
          ->AsInput()
          ->assert_is_op_input("conv2d", "Input");
  patterns::ConvBias conv_bias_pattern(gpd.mutable_pattern(), name_scope_);
  conv_bias_pattern(conv_input);
  int found_conv_bias_count = 0;
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    VLOG(4) << "handle ConvBias fuse";
    GET_IR_NODE_FROM_SUBGRAPH(conv_weight, conv_weight,
                              conv_bias_pattern);  // Filter
    GET_IR_NODE_FROM_SUBGRAPH(conv_out, conv_out, conv_bias_pattern);  // tmp
    GET_IR_NODE_FROM_SUBGRAPH(conv, conv, conv_bias_pattern);  // CONV op
    // bias
    GET_IR_NODE_FROM_SUBGRAPH(eltwise_bias, eltwise_bias, conv_bias_pattern);
    // output
    GET_IR_NODE_FROM_SUBGRAPH(eltwise_out, eltwise_out, conv_bias_pattern);
    // elementwise_add op
    GET_IR_NODE_FROM_SUBGRAPH(eltwise, eltwise, conv_bias_pattern);

    PADDLE_ENFORCE(subgraph.count(conv_input));

    // check if fuse can be done and if MKL-DNN should be used
    FuseOptions fuse_option = FindFuseOption(*conv, *eltwise);
    if (fuse_option == DO_NOT_FUSE || fuse_option == FUSE_NATIVE) {
      VLOG(3) << "do not perform conv+bias fuse";
      return;
    }

    auto* eltwise_bias_tensor =
        scope->FindVar(eltwise_bias->Name())->GetMutable<LoDTensor>();

    auto input_names = conv->Op()->InputNames();
    bool has_bias = std::find(input_names.begin(), input_names.end(), "Bias") !=
                    input_names.end();
    if (has_bias && conv->Op()->Input("Bias").size() > 0) {
      auto conv_bias_names = conv->Op()->Input("Bias");
      // add eltwise bias to existing conv bias
      PADDLE_ENFORCE_EQ(conv_bias_names.size(), 1);
      auto* conv_bias_var = scope->FindVar(conv_bias_names[0]);
      auto* conv_bias_tensor = conv_bias_var->GetMutable<LoDTensor>();
      PADDLE_ENFORCE_EQ(conv_bias_tensor->dims(), eltwise_bias_tensor->dims());
      *conv_bias_tensor = tensor_apply_eltwise(
          *conv_bias_tensor, *eltwise_bias_tensor, std::plus<float>());

      // Redirect the conv output straight to the eltwise output var, then
      // drop the now-redundant add op and the intermediate tmp var.
      conv->Op()->SetOutput("Output",
                            std::vector<std::string>({eltwise_out->Name()}));

      GraphSafeRemoveNodes(graph.get(), {eltwise, conv_out});

      IR_NODE_LINK_TO(conv, eltwise_out);
    } else {
      // take eltwise bias as conv bias
      OpDesc desc;

      desc.SetInput(
          "Input", std::vector<std::string>({subgraph.at(conv_input)->Name()}));
      desc.SetInput("Filter", std::vector<std::string>({conv_weight->Name()}));
      desc.SetInput("Bias", std::vector<std::string>({eltwise_bias->Name()}));
      desc.SetOutput("Output", std::vector<std::string>({eltwise_out->Name()}));
      desc.SetType("conv2d");

      // Preserve all attributes of the original conv on the replacement op.
      for (auto& attr : conv->Op()->GetAttrMap()) {
        desc.SetAttr(attr.first, attr.second);
      }
      auto conv_bias_node = g->CreateOpNode(&desc);

      IR_NODE_LINK_TO(subgraph.at(conv_input), conv_bias_node);
      IR_NODE_LINK_TO(conv_weight, conv_bias_node);
      IR_NODE_LINK_TO(eltwise_bias, conv_bias_node);
      IR_NODE_LINK_TO(conv_bias_node, eltwise_out);

      // The old conv, the add op and the intermediate var are all replaced.
      GraphSafeRemoveNodes(graph.get(), {conv, eltwise, conv_out});
    }

    found_conv_bias_count++;
  };
  gpd(graph.get(), handler);
  AddStatis(found_conv_bias_count);
  return graph;
}
|
||||
} // namespace ir
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
||||
REGISTER_PASS(conv_bias_mkldnn_fuse_pass,
|
||||
paddle::framework::ir::ConvBiasFusePass);
|
@ -0,0 +1,36 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
#pragma once
|
||||
#include <string>
|
||||
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
|
||||
#include "paddle/fluid/framework/ir/graph.h"
|
||||
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
|
||||
#include "paddle/fluid/framework/ir/pass.h"
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
namespace ir {
|
||||
/*
|
||||
* Fuse the Conv and Elementwise_add to a ConvBiasOp.
|
||||
*/
|
||||
class ConvBiasFusePass : public FusePassBase {
|
||||
public:
|
||||
virtual ~ConvBiasFusePass() {}
|
||||
|
||||
protected:
|
||||
std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const;
|
||||
const std::string name_scope_{"conv_bias_mkldnn_fuse"};
|
||||
};
|
||||
} // namespace ir
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
@ -0,0 +1,154 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.h"

#include <algorithm>
#include <functional>
#include <utility>

#include "paddle/fluid/framework/ir/graph_traits.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
namespace ir {
|
||||
namespace {
|
||||
|
||||
// The function keeps the graph consistent by replacing
|
||||
// a node 'from' in the set of inputs nodes
|
||||
// of the visited node by a node 'to'.
|
||||
void CorrectGraphEdges(Graph* graph, Node* from, Node* to) {
|
||||
for (auto& node : GraphTraits::DFS(*graph)) {
|
||||
auto from_in_inputs =
|
||||
std::find(std::begin(node.inputs), std::end(node.inputs), from);
|
||||
|
||||
if (from_in_inputs != std::end(node.inputs)) {
|
||||
IR_NODE_LINK_TO(to, (&node));
|
||||
|
||||
auto inputs = node.Op()->Inputs();
|
||||
|
||||
using input_type = VariableNameMap::value_type;
|
||||
|
||||
std::for_each(std::begin(inputs), std::end(inputs),
|
||||
[from, to, &node](const input_type& i) -> void {
|
||||
auto param_names = i.second;
|
||||
auto pi = std::find(std::begin(param_names),
|
||||
std::end(param_names), from->Name());
|
||||
|
||||
if (pi != std::end(param_names)) {
|
||||
node.Op()->SetInput(i.first, {to->Name()});
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
using graph_ptr = std::unique_ptr<ir::Graph>;

// Fuses conv2d + elementwise_add into a single conv2d whose second summand
// is fed in as "ResidualData" (the MKL-DNN in-place residual connection).
graph_ptr ConvElementwiseAddMKLDNNFusePass::ApplyImpl(graph_ptr graph) const {
  FusePassBase::Init(name_scope_, graph.get());

  GraphPatternDetector gpd;
  auto pattern = gpd.mutable_pattern();

  patterns::Conv conv_pattern{pattern, name_scope_};
  auto conv_output = conv_pattern();

  patterns::ElementwiseAdd elementwise_add_pattern{pattern, name_scope_};
  elementwise_add_pattern(conv_output);

  // The conv output var only feeds the add, so it can be consumed in place.
  conv_output->AsIntermediate();

  // Returns (true, bias var node) when the conv op carries a non-empty
  // "Bias" input, otherwise (false, nullptr).
  // NOTE(review): assumes the bias var node is always present among
  // conv_op.inputs whenever the OpDesc lists a bias name — the find_if
  // result is dereferenced without an end() check; confirm this invariant.
  auto conv_op_has_bias = [](const Node& conv_op) -> std::pair<bool, Node*> {
    auto bias_input_names = conv_op.Op()->Inputs();
    auto bias_it = bias_input_names.find("Bias");

    if (bias_it != std::end(bias_input_names)) {
      bool has_bias = !bias_it->second.empty();

      if (has_bias) {
        auto conv_bias_names = bias_it->second;
        auto conv_bias_names_it =
            std::find_if(std::begin(conv_op.inputs), std::end(conv_op.inputs),
                         [&conv_bias_names](Node* n) -> bool {
                           return n->Name() == conv_bias_names[0];
                         });
        return std::make_pair(has_bias, *conv_bias_names_it);
      }
    }

    return std::make_pair(false, nullptr);
  };

  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    GET_IR_NODE_FROM_SUBGRAPH(conv_op, conv_op, conv_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(conv_input, conv_input, conv_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(conv_filter, conv_filter, conv_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(conv_output, conv_output, conv_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op,
                              elementwise_add_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_x, elementwise_add_x,
                              elementwise_add_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_out, elementwise_add_out,
                              elementwise_add_pattern);

    // Only fuse when both ops are selected for MKL-DNN execution.
    if (FindFuseOption(*conv_op, *elementwise_add_op) != FUSE_MKLDNN) return;

    OpDesc op_desc;
    op_desc.SetType("conv2d");

    op_desc.SetInput("Input", {conv_input->Name()});
    op_desc.SetInput("Filter", {conv_filter->Name()});
    // The add's other operand becomes residual data accumulated in place by
    // the fused conv kernel.
    op_desc.SetInput("ResidualData", {elementwise_add_x->Name()});
    op_desc.SetOutput("Output", {conv_output->Name()});

    bool has_bias;
    Node* conv_bias;

    std::tie(has_bias, conv_bias) = conv_op_has_bias(*conv_op);

    if (has_bias) {
      op_desc.SetInput("Bias", {conv_bias->Name()});
    }

    // Carry over all attributes of the original conv, then flag the
    // residual connection for the MKL-DNN kernel.
    for (const auto& attr : conv_op->Op()->GetAttrMap()) {
      op_desc.SetAttr(attr.first, attr.second);
    }

    op_desc.SetAttr("fuse_residual_connection", true);

    auto fused_conv_op = g->CreateOpNode(&op_desc);

    IR_NODE_LINK_TO(conv_input, fused_conv_op);
    IR_NODE_LINK_TO(conv_filter, fused_conv_op);
    IR_NODE_LINK_TO(elementwise_add_x, fused_conv_op);
    IR_NODE_LINK_TO(fused_conv_op, conv_output);

    if (has_bias) {
      IR_NODE_LINK_TO(conv_bias, fused_conv_op);
    }

    // Everything downstream of the add's old output now reads conv_output.
    CorrectGraphEdges(g, elementwise_add_out, conv_output);
    GraphSafeRemoveNodes(g, {elementwise_add_out, conv_op, elementwise_add_op});
  };

  gpd(graph.get(), handler);

  return graph;
}
|
||||
} // namespace ir
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
||||
|
||||
REGISTER_PASS(conv_elementwise_add_mkldnn_fuse_pass,
|
||||
paddle::framework::ir::ConvElementwiseAddMKLDNNFusePass);
|
@ -0,0 +1,38 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
|
||||
#include "paddle/fluid/framework/ir/graph.h"
|
||||
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
namespace ir {
|
||||
|
||||
class ConvElementwiseAddMKLDNNFusePass : public FusePassBase {
|
||||
public:
|
||||
virtual ~ConvElementwiseAddMKLDNNFusePass() {}
|
||||
|
||||
protected:
|
||||
std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const;
|
||||
|
||||
const std::string name_scope_{"residual_connections_fuse_pass"};
|
||||
};
|
||||
|
||||
} // namespace ir
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
@ -0,0 +1,247 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <string>
|
||||
|
||||
#include "paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.h"
|
||||
#include "paddle/fluid/framework/ir/graph_traits.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
namespace ir {
|
||||
|
||||
namespace {
|
||||
constexpr int nodes_removed = 3;
|
||||
constexpr int nodes_added = 1;
|
||||
|
||||
// Appends an operator of the given `type` to block 0 of `prog`, wiring up
// the provided (parameter, variable) input pairs and the single output
// pair. All ops get use_mkldnn=true so the MKL-DNN fuse pass can fire.
void SetOp(ProgramDesc* prog, const std::string& type,
           const std::vector<std::pair<std::string, std::string>>& inputs,
           const std::pair<std::string, std::string>& output) {
  auto* op = prog->MutableBlock(0)->AppendOp();

  op->SetType(type);
  op->SetAttr("use_mkldnn", true);

  for (const auto& kv : inputs) {
    op->SetInput(kv.first, {kv.second});
  }
  op->SetOutput(output.first, {output.second});
}
|
||||
|
||||
struct IsReachable {
|
||||
using func = std::function<bool(const std::string&, const std::string&)>;
|
||||
|
||||
auto operator()(const std::unique_ptr<ir::Graph>& graph) -> func {
|
||||
auto find_node = [](const std::unique_ptr<ir::Graph>& graph,
|
||||
const std::string& name) -> Node* {
|
||||
for (auto& node : GraphTraits::DFS(*graph)) {
|
||||
if (name == node.Name()) {
|
||||
return &node;
|
||||
}
|
||||
}
|
||||
|
||||
return nullptr;
|
||||
};
|
||||
|
||||
return [&](std::string from, const std::string to) -> bool {
|
||||
if (from == to) return true;
|
||||
|
||||
std::map<std::string, bool> visited;
|
||||
|
||||
for (auto& node : GraphTraits::DFS(*graph)) {
|
||||
visited[node.Name()] = false;
|
||||
}
|
||||
|
||||
visited[from] = true;
|
||||
|
||||
std::list<std::string> queue;
|
||||
queue.push_back(from);
|
||||
|
||||
while (!queue.empty()) {
|
||||
auto cur = find_node(graph, queue.front());
|
||||
queue.pop_front();
|
||||
|
||||
if (cur == nullptr) return false;
|
||||
|
||||
for (auto n : cur->outputs) {
|
||||
if (n->Name() == to) return true;
|
||||
|
||||
if (!visited[n->Name()]) {
|
||||
visited[n->Name()] = true;
|
||||
queue.push_back(n->Name());
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
// Verifies the post-fuse graph holds exactly one conv2d op and no
// elementwise_add ops.
void AssertOpsCount(const std::unique_ptr<ir::Graph>& graph) {
  int convs = 0;
  int adds = 0;

  for (auto* node : graph->Nodes()) {
    if (!node->IsOp()) continue;

    const auto& op_type = node->Op()->Type();
    if (op_type == "conv2d") {
      ++convs;
    } else if (op_type == "elementwise_add") {
      ++adds;
    }
  }
  EXPECT_EQ(convs, 1);
  EXPECT_EQ(adds, 0);
}
|
||||
|
||||
// Builds a ProgramDesc whose block 0 declares every given variable as a
// LOD_TENSOR; `persistent_vars` are additionally marked persistable
// (i.e. weights/biases that live in the parameter scope).
ProgramDesc BuildProgramDesc(const std::vector<std::string>& transient_vars,
                             const std::vector<std::string>& persistent_vars) {
  ProgramDesc prog;

  auto add_var = [&prog](const std::string& name) -> VarDesc* {
    auto* var = prog.MutableBlock(0)->Var(name);
    var->SetType(proto::VarType::LOD_TENSOR);
    return var;
  };

  for (const auto& name : transient_vars) {
    add_var(name);
  }

  for (const auto& name : persistent_vars) {
    add_var(name)->SetPersistable(true);
  }

  return prog;
}
|
||||
} // namespace
|
||||
|
||||
// conv2d(+bias) -> elementwise_add -> relu: the add must be folded into the
// conv as a residual connection while "a" stays connected to "relu".
TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionWithElementwiseAddRelu) {
  auto prog =
      BuildProgramDesc({"a", "b", "c", "d", "e", "f"}, {"bias", "weights"});

  SetOp(&prog, "conv2d",
        {{"Input", "a"}, {"Bias", "bias"}, {"Filter", "weights"}},
        {"Output", "b"});
  SetOp(&prog, "elementwise_add", {{"X", "b"}, {"Y", "c"}}, {"Out", "d"});
  SetOp(&prog, "relu", {{"X", "d"}}, {"Out", "e"});

  std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));

  IsReachable is_reachable;
  EXPECT_TRUE(is_reachable(graph)("a", "relu"));

  auto pass =
      PassRegistry::Instance().Get("conv_elementwise_add_mkldnn_fuse_pass");
  int nodes_before = graph->Nodes().size();
  graph = pass->Apply(std::move(graph));
  int nodes_after = graph->Nodes().size();

  // Connectivity survives the rewrite...
  EXPECT_TRUE(is_reachable(graph)("a", "relu"));

  // ...and the node count reflects 3 nodes removed, 1 added.
  EXPECT_EQ(nodes_before - nodes_removed + nodes_added, nodes_after);

  AssertOpsCount(graph);
}
|
||||
|
||||
// Same chain but the conv carries no Bias input: the fuse must still fire.
TEST(ConvElementwiseAddMKLDNNFusePass,
     ConvolutionWithElementwiseAddReluNoBias) {
  auto prog = BuildProgramDesc({"a", "b", "c", "d", "e"}, {"weights"});
  SetOp(&prog, "conv2d", {{"Input", "a"}, {"Filter", "weights"}},
        {"Output", "b"});
  SetOp(&prog, "elementwise_add", {{"X", "b"}, {"Y", "c"}}, {"Out", "d"});
  SetOp(&prog, "relu", {{"X", "d"}}, {"Out", "e"});

  std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));

  IsReachable is_reachable;

  EXPECT_TRUE(is_reachable(graph)("a", "relu"));

  auto pass =
      PassRegistry::Instance().Get("conv_elementwise_add_mkldnn_fuse_pass");
  int nodes_before = graph->Nodes().size();
  graph = pass->Apply(std::move(graph));
  int nodes_after = graph->Nodes().size();

  EXPECT_TRUE(is_reachable(graph)("a", "relu"));

  EXPECT_EQ(nodes_before - nodes_removed + nodes_added, nodes_after);

  AssertOpsCount(graph);
}
|
||||
|
||||
// conv2d -> elementwise_add with no trailing op: after fusing, "d" (the
// add's old output var) is removed, so "a" must no longer reach it.
TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionElementwiseAdd) {
  auto prog = BuildProgramDesc({"a", "b", "c", "d"}, {"bias", "weights"});
  SetOp(&prog, "conv2d",
        {{"Input", "a"}, {"Bias", "bias"}, {"Filter", "weights"}},
        {"Output", "b"});
  SetOp(&prog, "elementwise_add", {{"X", "b"}, {"Y", "c"}}, {"Out", "d"});

  std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));

  IsReachable is_reachable;
  EXPECT_TRUE(is_reachable(graph)("a", "d"));

  auto pass =
      PassRegistry::Instance().Get("conv_elementwise_add_mkldnn_fuse_pass");
  int nodes_before = graph->Nodes().size();
  graph = pass->Apply(std::move(graph));
  int nodes_after = graph->Nodes().size();

  EXPECT_FALSE(is_reachable(graph)("a", "d"));

  EXPECT_EQ(nodes_before - nodes_removed + nodes_added, nodes_after);
  AssertOpsCount(graph);
}
|
||||
|
||||
// A producer op ahead of the conv: sigmoid -> conv2d -> add -> relu. The
// fuse must leave the upstream sigmoid and downstream relu connected.
TEST(ConvElementwiseAddMKLDNNFusePass, SigmoidConvolutionAddElementwiseRelu) {
  auto prog =
      BuildProgramDesc({"a", "b", "c", "d", "e", "f"}, {"bias", "weights"});
  SetOp(&prog, "sigmoid", {{"X", "a"}}, {"Out", "b"});
  SetOp(&prog, "conv2d",
        {{"Input", "b"}, {"Bias", "bias"}, {"Filter", "weights"}},
        {"Output", "c"});
  SetOp(&prog, "elementwise_add", {{"X", "c"}, {"Y", "d"}}, {"Out", "e"});
  SetOp(&prog, "relu", {{"X", "e"}}, {"Out", "f"});

  std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));

  IsReachable is_reachable;

  EXPECT_TRUE(is_reachable(graph)("a", "f"));

  auto pass =
      PassRegistry::Instance().Get("conv_elementwise_add_mkldnn_fuse_pass");
  int nodes_before = graph->Nodes().size();
  graph = pass->Apply(std::move(graph));
  int nodes_after = graph->Nodes().size();

  EXPECT_TRUE(is_reachable(graph)("a", "f"));

  EXPECT_EQ(nodes_before - nodes_removed + nodes_added, nodes_after);
  AssertOpsCount(graph);
}
|
||||
|
||||
} // namespace ir
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
||||
|
||||
USE_PASS(conv_elementwise_add_mkldnn_fuse_pass);
|
@ -0,0 +1,62 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
namespace ir {
|
||||
|
||||
// Records the pass' printable name and the graph being transformed so that
// param_scope() / AddStatis() can be used later. Must be called at the
// start of every derived ApplyImpl.
void FusePassBase::Init(const std::string& repr, Graph* graph) const {
  repr_ = repr;
  graph_ = graph;
}
|
||||
|
||||
// Returns the parameter Scope attached to the graph (set by the graph's
// owner before running fuse passes). Fails if the attribute is absent.
Scope* FusePassBase::param_scope() const {
  PADDLE_ENFORCE(graph_->Has(kParamScopeAttr));
  return graph_->Get<framework::Scope*>(kParamScopeAttr);
}
|
||||
|
||||
// Stores `count_of_fused` under this pass' repr key in the graph-wide fuse
// statistics map, creating the map attribute on first use. Overwrites any
// previous count for the same pass.
void FusePassBase::AddStatis(int count_of_fused) const {
  PADDLE_ENFORCE(graph_);
  PADDLE_ENFORCE(!repr_.empty());
  if (!graph_->Has(kFuseStatisAttr)) {
    // Graph takes ownership of attributes registered via Set().
    graph_->Set(kFuseStatisAttr, new std::unordered_map<std::string, int>);
  }
  auto& info =
      graph_->Get<std::unordered_map<std::string, int>>(kFuseStatisAttr);
  info[repr_] = count_of_fused;
}
|
||||
|
||||
// Decides how two candidate ops should be fused: FUSE_MKLDNN when both are
// flagged use_mkldnn, FUSE_NATIVE when neither is, DO_NOT_FUSE when mixed.
// When MKL-DNN support is not compiled in, native fusion is always
// reported.
FuseOptions FusePassBase::FindFuseOption(const Node& node1,
                                         const Node& node2) const {
#ifdef PADDLE_WITH_MKLDNN
  bool node1_mkldnn = node1.Op()->HasAttr("use_mkldnn") &&
                      boost::get<bool>(node1.Op()->GetAttr("use_mkldnn"));
  bool node2_mkldnn = node2.Op()->HasAttr("use_mkldnn") &&
                      boost::get<bool>(node2.Op()->GetAttr("use_mkldnn"));
  if (node1_mkldnn && node2_mkldnn)
    return FUSE_MKLDNN;
  else if (!node1_mkldnn && !node2_mkldnn)
    return FUSE_NATIVE;
  else
    return DO_NOT_FUSE;
#else
  return FUSE_NATIVE;
#endif
}
|
||||
|
||||
} // namespace ir
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
@ -0,0 +1,37 @@
|
||||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/fluid/framework/ir/mkldnn_placement_pass.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
namespace ir {
|
||||
|
||||
std::unique_ptr<ir::Graph> MKLDNNPlacementPass::ApplyImpl(
|
||||
std::unique_ptr<ir::Graph> graph) const {
|
||||
VLOG(3) << "Aplies MKL-DNN placement strategy.";
|
||||
for (const Node* n : graph->Nodes()) {
|
||||
if (n->IsOp() && n->Op()->HasAttr("use_mkldnn")) {
|
||||
n->Op()->SetAttr("use_mkldnn", true);
|
||||
}
|
||||
}
|
||||
return graph;
|
||||
}
|
||||
|
||||
} // namespace ir
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
||||
|
||||
REGISTER_PASS(mkldnn_placement_pass,
|
||||
paddle::framework::ir::MKLDNNPlacementPass);
|
@ -0,0 +1,31 @@
|
||||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "paddle/fluid/framework/ir/pass.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
namespace ir {
|
||||
|
||||
// Pass that globally enables MKL-DNN execution: sets use_mkldnn=true on
// every op declaring that attribute. Stateless, hence derived directly
// from Pass rather than FusePassBase.
class MKLDNNPlacementPass : public Pass {
 protected:
  std::unique_ptr<ir::Graph> ApplyImpl(
      std::unique_ptr<ir::Graph> graph) const override;
};
|
||||
|
||||
} // namespace ir
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
@ -0,0 +1,101 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/fluid/framework/ir/seqconv_eltadd_relu_fuse_pass.h"
|
||||
#include <string>
|
||||
#include "paddle/fluid/framework/lod_tensor.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
namespace ir {
|
||||
|
||||
// Detects sequence_conv -> elementwise_add -> relu chains and replaces
// each with a single fusion_seqconv_eltadd_relu op. Returns the number of
// subgraphs fused.
int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope) {
  GraphPatternDetector gpd;
  auto* pattern = gpd.mutable_pattern();

  // Anchor: a non-persistable var feeding a sequence_conv op.
  PDNode* x = pattern->NewNode(patterns::PDNodeName(name_scope, "X"))
                  ->assert_is_op_input("sequence_conv")
                  ->assert_var_not_persistable();
  patterns::SeqConvEltAddRelu fuse_pattern(pattern, name_scope);
  fuse_pattern(x);

  // Create New OpDesc
  auto fuse_creator = [&](Node* seqconv, Node* input, Node* seqconv_weight,
                          Node* eltadd_bias, Node* relu_out) {
    OpDesc op_desc;
    op_desc.SetType("fusion_seqconv_eltadd_relu");
    op_desc.SetInput("X", {input->Name()});
    op_desc.SetInput("Filter", {seqconv_weight->Name()});
    op_desc.SetInput("Bias", {eltadd_bias->Name()});
    // Carry over the context window attributes of the original seqconv.
    op_desc.SetAttr("contextLength", seqconv->Op()->GetAttr("contextLength"));
    op_desc.SetAttr("contextStart", seqconv->Op()->GetAttr("contextStart"));
    op_desc.SetAttr("contextStride", seqconv->Op()->GetAttr("contextStride"));
    PADDLE_ENFORCE(graph->Has(kParamScopeAttr));
    // NOTE(review): this local shadows the `scope` parameter; both are
    // presumably the same param scope — confirm, or drop the parameter.
    auto* scope = graph->Get<Scope*>(kParamScopeAttr);
    // Scratch output, unique per fused op; the fused kernel writes the
    // unfolded column matrix here.
    const std::string ColMat = patterns::UniqueKey("SeqConvColMat");
    op_desc.SetOutput("ColMat", {ColMat});
    op_desc.SetOutput("Out", {relu_out->Name()});
    scope->Var(ColMat)->GetMutable<LoDTensor>();

    auto* op = graph->CreateOpNode(&op_desc);
    IR_NODE_LINK_TO(input, op);
    IR_NODE_LINK_TO(seqconv_weight, op);
    IR_NODE_LINK_TO(eltadd_bias, op);
    IR_NODE_LINK_TO(op, relu_out);
    return op;
  };

  int fusion_count{0};

  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    VLOG(4) << "handle SeqConv EltAdd Relu fuse";
    GET_IR_NODE_FROM_SUBGRAPH(seqconv, seqconv, fuse_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(seqconv_weight, seqconv_weight, fuse_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(seqconv_out, seqconv_out, fuse_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(eltadd, eltadd, fuse_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(eltadd_bias, eltadd_bias, fuse_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(eltadd_out, eltadd_out, fuse_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(relu, relu, fuse_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(relu_out, relu_out, fuse_pattern);

    fuse_creator(seqconv, subgraph.at(x), seqconv_weight, eltadd_bias,
                 relu_out);
    // Remove the three original ops plus the two intermediate vars; the
    // input, weight, bias and final output nodes survive.
    std::unordered_set<const Node*> marked_nodes(
        {seqconv, seqconv_out, eltadd, eltadd_out, relu});
    GraphSafeRemoveNodes(graph, marked_nodes);
    ++fusion_count;
  };

  gpd(graph, handler);

  return fusion_count;
}
|
||||
|
||||
// Pass entry point: initializes the base, runs BuildFusion and records the
// number of fused subgraphs in the graph statistics.
std::unique_ptr<ir::Graph> SeqConvEltAddReluFusePass::ApplyImpl(
    std::unique_ptr<ir::Graph> graph) const {
  FusePassBase::Init(name_scope_, graph.get());

  int fusion_count = BuildFusion(graph.get(), name_scope_, param_scope());
  AddStatis(fusion_count);

  return graph;
}
|
||||
|
||||
} // namespace ir
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
||||
|
||||
REGISTER_PASS(seqconv_eltadd_relu_fuse_pass,
|
||||
paddle::framework::ir::SeqConvEltAddReluFusePass);
|
@ -0,0 +1,38 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
|
||||
#include "paddle/fluid/framework/ir/graph.h"
|
||||
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
namespace ir {
|
||||
|
||||
class SeqConvEltAddReluFusePass : public FusePassBase {
|
||||
public:
|
||||
virtual ~SeqConvEltAddReluFusePass() {}
|
||||
|
||||
protected:
|
||||
std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const;
|
||||
|
||||
const std::string name_scope_{"seqconv_eltadd_relu_fuse"};
|
||||
};
|
||||
|
||||
} // namespace ir
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in new issue