added fusing matmul-transpose-reshape pass (#23866)
parent 46f3139c7f
commit d31a174f51
@ -0,0 +1,100 @@
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_fuse_pass.h"
#include <paddle/fluid/string/pretty_log.h>
#include <sstream>
#include <vector>
#include "paddle/fluid/platform/enforce.h"

namespace paddle {
namespace framework {
namespace ir {
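// Fuses the pattern matmul -> transpose2 -> reshape2 into a single matmul op
// by folding the transpose axis and reshape shape into the matmul's
// "fused_transpose_Out" / "fused_reshape_Out" attributes:
//
//   before: matmul -> transpose2 (axis = {0, 2, 1, 3}) -> reshape2 (rank-3 shape)
//   after:  matmul (with fused_transpose_Out and fused_reshape_Out set)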
void MatmulTransposeReshapeMKLDNNPass::ApplyImpl(ir::Graph *graph) const {
  PADDLE_ENFORCE_NOT_NULL(graph,
                          platform::errors::InvalidArgument(
                              "Pointer to graph argument should not be NULL."));
  FusePassBase::Init(name_scope_, graph);

  GraphPatternDetector gpd;
  patterns::MatmulTransposeReshapePattern mtrp(gpd.mutable_pattern(),
                                               name_scope_);

  mtrp();
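  // For every subgraph that matches the pattern, validate it and fold the
  // transpose/reshape pair into the preceding matmul.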
  int found_matmul_transpose_reshape_count = 0;
  auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
                     Graph *g) {
    VLOG(4) << "handle matmul_transpose_reshape fuse";
    GET_IR_NODE_FROM_SUBGRAPH(matmul_op, matmul_op, mtrp);
    GET_IR_NODE_FROM_SUBGRAPH(matmul_out, matmul_out, mtrp);
    GET_IR_NODE_FROM_SUBGRAPH(transpose_op, transpose_op, mtrp);
    GET_IR_NODE_FROM_SUBGRAPH(transpose_out, transpose_out, mtrp);
    GET_IR_NODE_FROM_SUBGRAPH(transpose_out_xshape, transpose_out_xshape, mtrp);
    GET_IR_NODE_FROM_SUBGRAPH(reshape_op, reshape_op, mtrp);
    GET_IR_NODE_FROM_SUBGRAPH(reshape_out, reshape_out, mtrp);
    GET_IR_NODE_FROM_SUBGRAPH(reshape_out_xshape, reshape_out_xshape, mtrp);
    auto reshape_shape =
        boost::get<std::vector<int>>(reshape_op->Op()->GetAttr("shape"));
    auto transpose_axis =
        boost::get<std::vector<int>>(transpose_op->Op()->GetAttr("axis"));

    auto reshape_out_size = reshape_shape.size();
    auto transpose_out_size = transpose_axis.size();
    const std::vector<int> supported_axis{0, 2, 1, 3};
    const bool supported_transpose_axis = std::equal(
        transpose_axis.begin(), transpose_axis.end(), supported_axis.begin());
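    // Only the layout used by the typical multi-head-attention block is
    // supported: a rank-4 transpose with axis {0, 2, 1, 3} followed by a
    // reshape to rank 3. Anything else is skipped rather than rejected.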
    if (transpose_out_size != 4) {
      VLOG(3) << "do not perform matmul_transpose_reshape fuse: "
              << "supported rank is 4, received " << transpose_out_size;
      return;
    }
    if (!supported_transpose_axis) {
      VLOG(3) << "do not perform matmul_transpose_reshape fuse: "
              << "supported transpose axis for the fuse are {0, 2, 1, 3}";
      return;
    }
    if (reshape_out_size != 3) {
      VLOG(3) << "do not perform matmul_transpose_reshape fuse: "
              << "reshape_out supported rank is 3, received "
              << reshape_out_size;
      return;
    }
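    // Rewire matmul to produce reshape_out directly and record the folded
    // parameters as attributes read by the MKL-DNN matmul kernel.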
    OpDesc *matmul_desc = matmul_op->Op();
    matmul_desc->SetOutput("Out", {reshape_out->Name()});
    matmul_desc->SetAttr("fused_reshape_Out", reshape_shape);
    matmul_desc->SetAttr("fused_transpose_Out", transpose_axis);

    GraphSafeRemoveNodes(graph,
                         {matmul_out, transpose_op, transpose_out, reshape_op,
                          transpose_out_xshape, reshape_out_xshape});

    IR_OP_VAR_LINK(matmul_op, reshape_out);

    found_matmul_transpose_reshape_count++;
  };
  gpd(graph, handler);
  AddStatis(found_matmul_transpose_reshape_count);
  std::stringstream msg_ss;
  msg_ss << "--- Fused " << found_matmul_transpose_reshape_count
         << " MatmulTransposeReshape patterns";
  paddle::string::PrettyLogDetail(msg_ss.str().c_str());
}
}  // namespace ir
}  // namespace framework
}  // namespace paddle

REGISTER_PASS(matmul_transpose_reshape_fuse_pass,
              paddle::framework::ir::MatmulTransposeReshapeMKLDNNPass);
@ -0,0 +1,35 @@
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"

namespace paddle {
namespace framework {
namespace ir {
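// Pass that fuses a matmul -> transpose2 -> reshape2 chain into one
// MKL-DNN matmul op.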
class MatmulTransposeReshapeMKLDNNPass : public FusePassBase {
 public:
  virtual ~MatmulTransposeReshapeMKLDNNPass() {}

 protected:
  void ApplyImpl(Graph* graph) const override;
  const std::string name_scope_{"matmul_transpose_reshape_fuse"};
};
}  // namespace ir
}  // namespace framework
}  // namespace paddle
@ -0,0 +1,93 @@
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_fuse_pass.h"
#include <gtest/gtest.h>

namespace paddle {
namespace framework {
namespace ir {
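// Appends an op of the given type to block 0 of the program, wiring the
// inputs, outputs and attributes the fuse pattern expects for that op type.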
void SetOp(ProgramDesc *prog, const std::string &type,
           const std::vector<std::string> &inputs,
           const std::vector<std::string> &outputs) {
  auto *op = prog->MutableBlock(0)->AppendOp();
  op->SetType(type);
  op->SetInput("X", {inputs[0]});
  op->SetOutput("Out", {outputs[0]});
  if (type == "transpose2") {
    op->SetAttr("axis", std::vector<int>({0, 2, 1, 3}));
    op->SetOutput("XShape", {outputs[1]});
  }
  if (type == "reshape2") {
    op->SetAttr("shape", std::vector<int>({4, 5, 6}));
    op->SetOutput("XShape", {outputs[1]});
  }

  if (type == "matmul") {
    op->SetInput("Y", {inputs[1]});
    op->SetAttr("use_mkldnn", true);
  }
}
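// Builds the exact chain the pass targets: matmul -> transpose2 -> reshape2,
// followed by an fc op that consumes the fused output.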
ProgramDesc BuildProgramDesc() {
  ProgramDesc prog;
  for (auto &v : std::initializer_list<std::string>(
           {"a1", "a2", "b", "c", "cx", "d", "dx", "e"})) {
    auto *var = prog.MutableBlock(0)->Var(v);
    var->SetType(proto::VarType::SELECTED_ROWS);
  }

  SetOp(&prog, "matmul", {"a1", "a2"}, {"b"});
  SetOp(&prog, "transpose2", {"b"}, {"c", "cx"});
  SetOp(&prog, "reshape2", {"c"}, {"d", "dx"});
  SetOp(&prog, "fc", {"d"}, {"e"});

  return prog;
}
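// Applies the pass and verifies that six nodes vanish (the transpose2 and
// reshape2 ops plus four intermediate variables) and that the surviving
// matmul carries the folded shape/axis attributes.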
void MainTest(const ProgramDesc &prog) {
  std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));

  int original_nodes_num = graph->Nodes().size();

  auto pass =
      PassRegistry::Instance().Get("matmul_transpose_reshape_fuse_pass");
  graph.reset(pass->Apply(graph.release()));

  int current_nodes_num = graph->Nodes().size();
  EXPECT_EQ(original_nodes_num - 6, current_nodes_num);

  for (auto *node : graph->Nodes()) {
    if (node->IsOp()) {
      auto *op = node->Op();
      if (op->Type() == "matmul") {
        EXPECT_EQ(op->GetAttrIfExists<std::vector<int>>("fused_reshape_Out"),
                  std::vector<int>({4, 5, 6}));
        EXPECT_EQ(op->GetAttrIfExists<std::vector<int>>("fused_transpose_Out"),
                  std::vector<int>({0, 2, 1, 3}));
      }
    }
  }
}
TEST(MatmulTransposeReshapeFusePass, matmul_inputs) {
  auto prog = BuildProgramDesc();
  MainTest(prog);
}
}  // namespace ir
}  // namespace framework
}  // namespace paddle

USE_PASS(matmul_transpose_reshape_fuse_pass);
@ -0,0 +1,110 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest
import numpy as np

import paddle.fluid as fluid
from inference_pass_test import InferencePassTest
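# Positive case: matmul -> transpose(perm=[0, 2, 1, 3]) -> reshape to rank 3
# should be fused into a single MKL-DNN matmul without changing the output.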
class TestMKLDNNMatmulFuseOp(InferencePassTest):
    def init_data(self):
        self.bs = 8
        self.d_type = np.float32
        self.shape_x = [12, 128, 128]
        self.shape_y = [12, 128, 64]
        self.enable_mkldnn = True

    def make_network(self):
        with fluid.program_guard(self.main_program, self.startup_program):
            x = fluid.data(
                name='x', shape=[-1] + self.shape_x, dtype=self.d_type)
            y = fluid.data(
                name='y', shape=[-1] + self.shape_y, dtype=self.d_type)
            out = fluid.layers.matmul(x, y)
            out = fluid.layers.transpose(out, perm=[0, 2, 1, 3])
            out = fluid.layers.reshape(
                out, [0, 0, self.shape_y[0] * self.shape_y[2]])
            out = fluid.layers.fc(out, size=1)
        return out

    def setUp(self):
        self.init_data()
        out = self.make_network()
        self.set_feeds(out)

    def set_feeds(self, out):
        self.feeds = {
            "x": np.random.random([self.bs] + self.shape_x).astype(self.d_type),
            "y": np.random.random([self.bs] + self.shape_y).astype(self.d_type)
        }
        self.fetch_list = [out]

    def test_check_output(self):
        use_gpu = False
        self.check_output_with_option(use_gpu)
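# Same pattern with degenerate inner dimensions; the fuse should still apply.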
class TestMKLDNNMatmulOtherDimsFuseOp(TestMKLDNNMatmulFuseOp):
    def init_data(self):
        self.bs = 8
        self.d_type = np.float32
        self.shape_x = [12, 1, 1]
        self.shape_y = [12, 1, 64]
        self.enable_mkldnn = True
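# Negative case: a transpose perm other than [0, 2, 1, 3] must not be fused.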
class TestMKLDNNMatmulOpNotFusedWrongTransposeAxis(TestMKLDNNMatmulFuseOp):
    def make_network(self):
        with fluid.program_guard(self.main_program, self.startup_program):
            x = fluid.data(
                name='x', shape=[-1] + self.shape_x, dtype=self.d_type)
            y = fluid.data(
                name='y', shape=[-1] + self.shape_y, dtype=self.d_type)
            out = fluid.layers.matmul(x, y)
            out = fluid.layers.transpose(out, perm=[0, 1, 2, 3])
            out = fluid.layers.reshape(out, [0, 0, 0, 0])
            out = fluid.layers.fc(out, size=1)
        return out
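# Negative case: a second transpose between the matched transpose and the
# reshape breaks the pattern; the graph must run unfused and still be correct.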
class TestMKLDNNMatmulOpNotFusedBreakPattern(TestMKLDNNMatmulFuseOp):
    def init_data(self):
        self.bs = 7
        self.d_type = np.float32
        self.shape_x = [12, 128, 128]
        self.shape_y = [12, 128, 64]
        self.enable_mkldnn = True

    def make_network(self):
        with fluid.program_guard(self.main_program, self.startup_program):
            x = fluid.data(
                name='x', shape=[-1] + self.shape_x, dtype=self.d_type)
            y = fluid.data(
                name='y', shape=[-1] + self.shape_y, dtype=self.d_type)
            out = fluid.layers.matmul(x, y)
            out = fluid.layers.transpose(out, perm=[0, 2, 1, 3])
            out = fluid.layers.transpose(
                out, perm=[0, 1, 2, 3])  # breaks the fuse pattern
            out = fluid.layers.reshape(
                out, [0, 0, self.shape_y[0] * self.shape_y[2]])
            out = fluid.layers.fc(out, size=1)
        return out
if __name__ == '__main__':
    unittest.main()