|
|
|
@ -99,10 +99,9 @@ bool IsReachable(ir::Graph* graph, Node* from, Node* to) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::pair<bool, Node*> ResidualConnectionMKLDNNFusePass::HasBias(
|
|
|
|
|
const Node& op) const {
|
|
|
|
|
std::pair<bool, Node*> HasBias(const Node& op, const std::string& bias_name) {
|
|
|
|
|
auto bias_input_names = op.Op()->Inputs();
|
|
|
|
|
auto bias_it = bias_input_names.find("Bias");
|
|
|
|
|
auto bias_it = bias_input_names.find(bias_name);
|
|
|
|
|
|
|
|
|
|
if (bias_it != std::end(bias_input_names)) {
|
|
|
|
|
bool has_bias = !bias_it->second.empty();
|
|
|
|
@ -121,6 +120,74 @@ std::pair<bool, Node*> ResidualConnectionMKLDNNFusePass::HasBias(
|
|
|
|
|
return std::make_pair(false, nullptr);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Constructs a fuse handler from the three callables that its operator()
// invokes for every matched subgraph.
//
// get_node_from_conv_op            - given a subgraph, yields the conv op
//                                    node together with its input, filter
//                                    and output nodes.
// get_node_from_elementwise_add_op - given a subgraph, yields the
//                                    elementwise_add op node, its identity
//                                    operand node and its output node.
// can_fuse_func                    - predicate called with the conv op and
//                                    the elementwise_add op; fusion is
//                                    skipped when it returns false.
//
// Each argument is copied into the member of the same name.
ResidualConnectionMKLDNNFusePass::FuseHandler::FuseHandler(
    const ConvFunc& get_node_from_conv_op,
    const ElementwiseAddFunc& get_node_from_elementwise_add_op,
    const CanFuseFunc& can_fuse_func)
    : get_node_from_conv_op(get_node_from_conv_op),
      get_node_from_elementwise_add_op(get_node_from_elementwise_add_op),
      can_fuse_func(can_fuse_func) {
  // Nothing else to do: all state is set by the initializer list.
}
|
|
|
|
|
|
|
|
|
|
void ResidualConnectionMKLDNNFusePass::FuseHandler::operator()(
|
|
|
|
|
const GraphPatternDetector::subgraph_t& subgraph, Graph* graph) {
|
|
|
|
|
Node* conv_op;
|
|
|
|
|
Node* conv_input;
|
|
|
|
|
Node* conv_filter;
|
|
|
|
|
Node* conv_output;
|
|
|
|
|
|
|
|
|
|
Node* elementwise_add_op;
|
|
|
|
|
Node* elementwise_add_identity;
|
|
|
|
|
Node* elementwise_add_out;
|
|
|
|
|
|
|
|
|
|
std::tie(conv_op, conv_input, conv_filter, conv_output) =
|
|
|
|
|
get_node_from_conv_op(subgraph);
|
|
|
|
|
std::tie(elementwise_add_op, elementwise_add_identity, elementwise_add_out) =
|
|
|
|
|
get_node_from_elementwise_add_op(subgraph);
|
|
|
|
|
|
|
|
|
|
if (!can_fuse_func(conv_op, elementwise_add_op)) return;
|
|
|
|
|
|
|
|
|
|
if (!IsReachable(graph, elementwise_add_identity, conv_output)) return;
|
|
|
|
|
|
|
|
|
|
OpDesc op_desc;
|
|
|
|
|
op_desc.SetType("conv2d");
|
|
|
|
|
|
|
|
|
|
op_desc.SetInput("Input", {conv_input->Name()});
|
|
|
|
|
op_desc.SetInput("Filter", {conv_filter->Name()});
|
|
|
|
|
op_desc.SetInput("ResidualData", {elementwise_add_identity->Name()});
|
|
|
|
|
op_desc.SetOutput("Output", {conv_output->Name()});
|
|
|
|
|
|
|
|
|
|
bool has_bias;
|
|
|
|
|
Node* conv_bias;
|
|
|
|
|
|
|
|
|
|
std::tie(has_bias, conv_bias) = HasBias(*conv_op, "Bias");
|
|
|
|
|
|
|
|
|
|
if (has_bias) {
|
|
|
|
|
op_desc.SetInput("Bias", {conv_bias->Name()});
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for (const auto& attr : conv_op->Op()->GetAttrMap()) {
|
|
|
|
|
op_desc.SetAttr(attr.first, attr.second);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
op_desc.SetAttr("fuse_residual_connection", true);
|
|
|
|
|
|
|
|
|
|
auto fused_conv_op = graph->CreateOpNode(&op_desc);
|
|
|
|
|
|
|
|
|
|
IR_NODE_LINK_TO(conv_input, fused_conv_op);
|
|
|
|
|
IR_NODE_LINK_TO(conv_filter, fused_conv_op);
|
|
|
|
|
IR_NODE_LINK_TO(elementwise_add_identity, fused_conv_op);
|
|
|
|
|
IR_NODE_LINK_TO(fused_conv_op, conv_output);
|
|
|
|
|
|
|
|
|
|
if (has_bias) {
|
|
|
|
|
IR_NODE_LINK_TO(conv_bias, fused_conv_op);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
CorrectGraphEdges(graph, elementwise_add_out, conv_output);
|
|
|
|
|
GraphSafeRemoveNodes(graph,
|
|
|
|
|
{elementwise_add_out, conv_op, elementwise_add_op});
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
graph_ptr ResidualConnectionMKLDNNFusePass::FuseConvAsX(
|
|
|
|
|
const std::string& name_scope_, graph_ptr graph) const {
|
|
|
|
|
GraphPatternDetector gpd;
|
|
|
|
@ -135,8 +202,8 @@ graph_ptr ResidualConnectionMKLDNNFusePass::FuseConvAsX(
|
|
|
|
|
pattern->NewNode(elementwise_add_pattern.elementwise_add_y_repr()));
|
|
|
|
|
conv_output->AsIntermediate();
|
|
|
|
|
|
|
|
|
|
auto get_node_from_conv = [](const patterns::Conv& conv_pattern,
|
|
|
|
|
const GraphPatternDetector::subgraph_t& subgraph)
|
|
|
|
|
auto get_node_from_conv =
|
|
|
|
|
[&conv_pattern](const GraphPatternDetector::subgraph_t& subgraph)
|
|
|
|
|
-> std::tuple<Node*, Node*, Node*, Node*> {
|
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(conv_op, conv_op, conv_pattern);
|
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(conv_input, conv_input, conv_pattern);
|
|
|
|
@ -146,8 +213,7 @@ graph_ptr ResidualConnectionMKLDNNFusePass::FuseConvAsX(
|
|
|
|
|
return std::make_tuple(conv_op, conv_input, conv_filter, conv_output);
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
auto get_node_from_elementwise_add = [](
|
|
|
|
|
const patterns::ElementwiseAdd& elementwise_add_pattern,
|
|
|
|
|
auto get_node_from_elementwise_add = [&elementwise_add_pattern](
|
|
|
|
|
const GraphPatternDetector::subgraph_t& subgraph)
|
|
|
|
|
-> std::tuple<Node*, Node*, Node*> {
|
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op,
|
|
|
|
@ -161,10 +227,14 @@ graph_ptr ResidualConnectionMKLDNNFusePass::FuseConvAsX(
|
|
|
|
|
elementwise_add_out);
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
auto handler =
|
|
|
|
|
GenerateFuseHandler(conv_pattern, elementwise_add_pattern,
|
|
|
|
|
get_node_from_conv, get_node_from_elementwise_add);
|
|
|
|
|
gpd(graph.get(), handler);
|
|
|
|
|
auto can_fuse = [this](Node* op1, Node* op2) -> bool {
|
|
|
|
|
return this->FindFuseOption(*op1, *op2) == FUSE_MKLDNN;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
auto fuse_handler =
|
|
|
|
|
FuseHandler{get_node_from_conv, get_node_from_elementwise_add, can_fuse};
|
|
|
|
|
|
|
|
|
|
gpd(graph.get(), fuse_handler);
|
|
|
|
|
|
|
|
|
|
return graph;
|
|
|
|
|
}
|
|
|
|
@ -183,8 +253,8 @@ graph_ptr ResidualConnectionMKLDNNFusePass::FuseConvAsY(
|
|
|
|
|
conv_output);
|
|
|
|
|
conv_output->AsIntermediate();
|
|
|
|
|
|
|
|
|
|
auto get_node_from_conv = [](const patterns::Conv& conv_pattern,
|
|
|
|
|
const GraphPatternDetector::subgraph_t& subgraph)
|
|
|
|
|
auto get_node_from_conv =
|
|
|
|
|
[&conv_pattern](const GraphPatternDetector::subgraph_t& subgraph)
|
|
|
|
|
-> std::tuple<Node*, Node*, Node*, Node*> {
|
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(conv_op, conv_op, conv_pattern);
|
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(conv_input, conv_input, conv_pattern);
|
|
|
|
@ -194,8 +264,7 @@ graph_ptr ResidualConnectionMKLDNNFusePass::FuseConvAsY(
|
|
|
|
|
return std::make_tuple(conv_op, conv_input, conv_filter, conv_output);
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
auto get_node_from_elementwise_add = [](
|
|
|
|
|
const patterns::ElementwiseAdd& elementwise_add_pattern,
|
|
|
|
|
auto get_node_from_elementwise_add = [&elementwise_add_pattern](
|
|
|
|
|
const GraphPatternDetector::subgraph_t& subgraph)
|
|
|
|
|
-> std::tuple<Node*, Node*, Node*> {
|
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op,
|
|
|
|
@ -209,10 +278,14 @@ graph_ptr ResidualConnectionMKLDNNFusePass::FuseConvAsY(
|
|
|
|
|
elementwise_add_out);
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
auto handler =
|
|
|
|
|
GenerateFuseHandler(conv_pattern, elementwise_add_pattern,
|
|
|
|
|
get_node_from_conv, get_node_from_elementwise_add);
|
|
|
|
|
gpd(graph.get(), handler);
|
|
|
|
|
auto can_fuse = [this](Node* op1, Node* op2) -> bool {
|
|
|
|
|
return this->FindFuseOption(*op1, *op2) == FUSE_MKLDNN;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
auto fuse_handler =
|
|
|
|
|
FuseHandler{get_node_from_conv, get_node_from_elementwise_add, can_fuse};
|
|
|
|
|
|
|
|
|
|
gpd(graph.get(), fuse_handler);
|
|
|
|
|
|
|
|
|
|
return graph;
|
|
|
|
|
}
|
|
|
|
|