MKLDNN residual connections fuse pass:

* implements reachability check between identity node and non-identity argument to elementwise_add
* implements handling identity node as x and as y argument to elementwise_add
panyx0718-patch-1
Tomasz Patejko 7 years ago
parent 1722678258
commit 7423748e37

@ -15,6 +15,7 @@
#pragma once
#include <string>
#include <utility>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
@ -23,16 +24,105 @@ namespace paddle {
namespace framework {
namespace ir {
class ConvElementwiseAddMKLDNNFusePass : public FusePassBase {
using graph_ptr = std::unique_ptr<ir::Graph>;
void CorrectGraphEdges(Graph* graph, Node* from, Node* to);
bool IsReachable(ir::Graph* graph, Node* from, Node* to);
using handler_func = std::function<void(
const GraphPatternDetector::subgraph_t& subgraph, Graph* g)>;
class ResidualConnectionMKLDNNFusePass : public FusePassBase {
private:
graph_ptr FuseConvAsX(const std::string& name_scope_, graph_ptr graph) const;
graph_ptr FuseConvAsY(const std::string& name_scope_, graph_ptr graph) const;
std::pair<bool, Node*> HasBias(const Node& op) const;
template <typename CONV_FUNC, typename ELEMENTWISE_ADD_FUNC,
typename HANDLER_FUNC = handler_func>
HANDLER_FUNC GenerateFuseHandler(
const patterns::Conv& conv_pattern,
const patterns::ElementwiseAdd& elementwise_add_pattern,
CONV_FUNC get_node_from_conv_op,
ELEMENTWISE_ADD_FUNC get_node_from_elementwise_add_op) const;
public:
virtual ~ConvElementwiseAddMKLDNNFusePass() {}
virtual ~ResidualConnectionMKLDNNFusePass() {}
protected:
std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const;
std::unique_ptr<ir::Graph> ApplyImpl(graph_ptr graph) const;
const std::string name_scope_{"residual_connections_fuse_pass"};
const std::string name_scope_{"residual_connection_fuse_pass"};
};
template <typename CONV_FUNC, typename ELEMENTWISE_ADD_FUNC,
typename HANDLER_FUNC>
HANDLER_FUNC ResidualConnectionMKLDNNFusePass::GenerateFuseHandler(
const patterns::Conv& conv_pattern,
const patterns::ElementwiseAdd& elementwise_add_pattern,
CONV_FUNC get_node_from_conv_op,
ELEMENTWISE_ADD_FUNC get_node_from_elementwise_add_op) const {
return [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* graph) {
Node* conv_op;
Node* conv_input;
Node* conv_filter;
Node* conv_output;
Node* elementwise_add_op;
Node* elementwise_add_identity;
Node* elementwise_add_out;
std::tie(conv_op, conv_input, conv_filter, conv_output) =
get_node_from_conv_op(conv_pattern, subgraph);
std::tie(elementwise_add_op, elementwise_add_identity,
elementwise_add_out) =
get_node_from_elementwise_add_op(elementwise_add_pattern, subgraph);
if (this->FindFuseOption(*conv_op, *elementwise_add_op) != FUSE_MKLDNN)
return;
if (!IsReachable(graph, elementwise_add_identity, conv_output)) return;
OpDesc op_desc;
op_desc.SetType("conv2d");
op_desc.SetInput("Input", {conv_input->Name()});
op_desc.SetInput("Filter", {conv_filter->Name()});
op_desc.SetInput("ResidualData", {elementwise_add_identity->Name()});
op_desc.SetOutput("Output", {conv_output->Name()});
bool has_bias;
Node* conv_bias;
std::tie(has_bias, conv_bias) = this->HasBias(*conv_op);
if (has_bias) {
op_desc.SetInput("Bias", {conv_bias->Name()});
}
for (const auto& attr : conv_op->Op()->GetAttrMap()) {
op_desc.SetAttr(attr.first, attr.second);
}
op_desc.SetAttr("fuse_residual_connection", true);
auto fused_conv_op = graph->CreateOpNode(&op_desc);
IR_NODE_LINK_TO(conv_input, fused_conv_op);
IR_NODE_LINK_TO(conv_filter, fused_conv_op);
IR_NODE_LINK_TO(elementwise_add_identity, fused_conv_op);
IR_NODE_LINK_TO(fused_conv_op, conv_output);
if (has_bias) {
IR_NODE_LINK_TO(conv_bias, fused_conv_op);
}
CorrectGraphEdges(graph, elementwise_add_out, conv_output);
GraphSafeRemoveNodes(graph,
{elementwise_add_out, conv_op, elementwise_add_op});
};
}
} // namespace ir
} // namespace framework
} // namespace paddle

@ -1084,16 +1084,12 @@ PDNode *patterns::Conv::operator()() {
return output_var;
}
PDNode *patterns::ElementwiseAdd::operator()(PDNode *x_var) {
PDNode *patterns::ElementwiseAdd::operator()(PDNode *x_var, PDNode *y_var) {
auto elementwise_add_op = pattern->NewNode(elementwise_add_op_repr())
->assert_is_op("elementwise_add");
x_var->assert_is_op_input("elementwise_add", "X");
auto y_var = pattern->NewNode(elementwise_add_x_repr())
->AsInput()
->assert_is_op_input("elementwise_add", "Y");
x_var->AsInput()->assert_is_op_input("elementwise_add", "X");
y_var->AsInput()->assert_is_op_input("elementwise_add", "Y");
auto out_var = pattern->NewNode(elementwise_add_out_repr())
->AsOutput()
->assert_is_op_output("elementwise_add", "Out");

@ -664,7 +664,7 @@ struct ElementwiseAdd : public PatternBase {
ElementwiseAdd(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "elementwise_add") {}
PDNode* operator()(PDNode* x_var);
PDNode* operator()(PDNode* x_var, PDNode* y_var);
PATTERN_DECL_NODE(elementwise_add_op);
PATTERN_DECL_NODE(elementwise_add_x);

Loading…
Cancel
Save