|
|
|
@ -47,23 +47,24 @@ struct Pattern : public PatternBase {
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
struct Conv {
|
|
|
|
|
std::string op_name() { return "conv2d"; }
|
|
|
|
|
std::string input_name() { return "Input"; }
|
|
|
|
|
std::string filter_name() { return "Filter"; }
|
|
|
|
|
std::string output_name() { return "Output"; }
|
|
|
|
|
std::string op_name() const { return "conv2d"; }
|
|
|
|
|
std::string input_name() const { return "Input"; }
|
|
|
|
|
std::string filter_name() const { return "Filter"; }
|
|
|
|
|
std::string residual_data_name() const { return "ResidualData"; }
|
|
|
|
|
std::string output_name() const { return "Output"; }
|
|
|
|
|
|
|
|
|
|
std::function<PDNode*()> operator()(std::shared_ptr<Pattern> pattern) {
|
|
|
|
|
return [&]() -> PDNode* {
|
|
|
|
|
auto conv_op = pattern->new_node(op_name())->assert_is_op("conv2d");
|
|
|
|
|
auto conv_op = pattern->new_node(op_name())->assert_is_op(op_name());
|
|
|
|
|
|
|
|
|
|
auto input_var = pattern->new_node(input_name())
|
|
|
|
|
->assert_is_op_input(op_name(), input_name());
|
|
|
|
|
->assert_is_op_input(op_name(), input_name());
|
|
|
|
|
|
|
|
|
|
auto filter_var = pattern->new_node(filter_name())
|
|
|
|
|
->assert_is_op_input(op_name(), filter_name());
|
|
|
|
|
->assert_is_op_input(op_name(), filter_name());
|
|
|
|
|
|
|
|
|
|
auto output_var = pattern->new_node(output_name())
|
|
|
|
|
->assert_is_op_output(op_name(), output_name());
|
|
|
|
|
->assert_is_op_output(op_name(), output_name());
|
|
|
|
|
|
|
|
|
|
conv_op->LinksFrom({input_var, filter_var});
|
|
|
|
|
conv_op->LinksTo({output_var});
|
|
|
|
@ -74,15 +75,15 @@ struct Conv {
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
struct ElementwiseAdd {
|
|
|
|
|
std::string op_name() { return "elementwise_add"; }
|
|
|
|
|
std::string x_name() { return "X"; }
|
|
|
|
|
std::string y_name() { return "Y"; }
|
|
|
|
|
std::string out_name() { return "Out"; }
|
|
|
|
|
std::string op_name() const { return "elementwise_add"; }
|
|
|
|
|
std::string x_name() const { return "X"; }
|
|
|
|
|
std::string y_name() const { return "Y"; }
|
|
|
|
|
std::string out_name() const { return "Out"; }
|
|
|
|
|
|
|
|
|
|
std::function<PDNode*(PDNode*)> operator()(std::shared_ptr<Pattern> pattern) {
|
|
|
|
|
return [&](PDNode* conv_output) -> PDNode* {
|
|
|
|
|
auto elementwise_add_op =
|
|
|
|
|
pattern->new_node(op_name())->assert_is_op("elementwise_add");
|
|
|
|
|
pattern->new_node(op_name())->assert_is_op(op_name());
|
|
|
|
|
|
|
|
|
|
auto x_var =
|
|
|
|
|
pattern->new_node(x_name())->assert_is_op_input(op_name(), x_name());
|
|
|
|
@ -90,8 +91,8 @@ struct ElementwiseAdd {
|
|
|
|
|
conv_output->assert_is_op_input(op_name(), y_name());
|
|
|
|
|
|
|
|
|
|
auto out_var = pattern->new_node(out_name())
|
|
|
|
|
->AsOutput()
|
|
|
|
|
->assert_is_op_output(op_name(), out_name());
|
|
|
|
|
->AsOutput()
|
|
|
|
|
->assert_is_op_output(op_name(), out_name());
|
|
|
|
|
|
|
|
|
|
elementwise_add_op->LinksFrom({x_var, conv_output});
|
|
|
|
|
elementwise_add_op->LinksTo({out_var});
|
|
|
|
@ -177,15 +178,17 @@ graph_ptr ConvElementwiseAddMKLDNNFusePass::ApplyImpl(graph_ptr graph) const {
|
|
|
|
|
|
|
|
|
|
conv_output->AsIntermediate();
|
|
|
|
|
|
|
|
|
|
auto fuse_conv = [](Graph* g, Node* conv_input, Node* conv_filter,
|
|
|
|
|
Node* conv_output, Node* elementwise_add_x) {
|
|
|
|
|
auto fuse_conv = [&conv_pattern](Graph* g, Node* conv_input,
|
|
|
|
|
Node* conv_filter,
|
|
|
|
|
Node* conv_output,
|
|
|
|
|
Node* elementwise_add_x) {
|
|
|
|
|
OpDesc op_desc;
|
|
|
|
|
op_desc.SetType("conv2d");
|
|
|
|
|
op_desc.SetType(conv_pattern.op_name());
|
|
|
|
|
|
|
|
|
|
op_desc.SetInput("Input", {conv_input->Name()});
|
|
|
|
|
op_desc.SetInput("Filter", {conv_filter->Name()});
|
|
|
|
|
op_desc.SetInput("ResidualData", {elementwise_add_x->Name()});
|
|
|
|
|
op_desc.SetOutput("Output", {conv_output->Name()});
|
|
|
|
|
op_desc.SetInput(conv_pattern.input_name(), {conv_input->Name()});
|
|
|
|
|
op_desc.SetInput(conv_pattern.filter_name(), {conv_filter->Name()});
|
|
|
|
|
op_desc.SetInput(conv_pattern.residual_data_name(), {elementwise_add_x->Name()});
|
|
|
|
|
op_desc.SetOutput(conv_pattern.output_name(), {conv_output->Name()});
|
|
|
|
|
|
|
|
|
|
op_desc.SetAttr("use_mkldnn", true);
|
|
|
|
|
op_desc.SetAttr("fuse_eltwise", true);
|
|
|
|
@ -198,8 +201,8 @@ graph_ptr ConvElementwiseAddMKLDNNFusePass::ApplyImpl(graph_ptr graph) const {
|
|
|
|
|
patterns::LinkNodes(fused_conv_op, conv_output);
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
|
|
|
|
|
Graph* g) {
|
|
|
|
|
auto handler = [&conv_pattern, &elementwise_add_pattern, pattern_ptr, fuse_conv]
|
|
|
|
|
(const GraphPatternDetector::subgraph_t& subgraph, Graph* g) {
|
|
|
|
|
auto conv_op = patterns::GetNodeFromSubgraph(subgraph, pattern_ptr,
|
|
|
|
|
conv_pattern.op_name());
|
|
|
|
|
auto conv_input = patterns::GetNodeFromSubgraph(subgraph, pattern_ptr,
|
|
|
|
|