added reshape transpose matmul fuse pass (#23754)
parent 61d19a8e1c
commit e1a7a88057
@@ -0,0 +1,119 @@
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/ir/mkldnn/reshape_transpose_matmul_mkldnn_fuse_pass.h"
#include <string>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/string/pretty_log.h"

namespace paddle {
namespace framework {
namespace ir {
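
// Replaces a reshape2->transpose2 chain feeding a MatMul input with
// fused_reshape_{X,Y}/fused_transpose_{X,Y} attributes on the MatMul op,
// so the MKL-DNN matmul kernel can apply them on the fly.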
void ReshapeTransposeMatmulMkldnnFusePass::Fuse(
    Graph *graph, bool with_reshape_xshape, bool with_transpose_xshape) const {
  GraphPatternDetector gpd;
  patterns::ReshapeTransposeMatmulPattern rtm_pattern(gpd.mutable_pattern(),
                                                      name_scope_);

  rtm_pattern(with_reshape_xshape, with_transpose_xshape);

  int found_reshape_transpose_matmul_count = 0;
  auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
                     Graph *g) {
    VLOG(4) << "handle ReshapeTransposeMatmulMkldnn fuse";
    GET_IR_NODE_FROM_SUBGRAPH(reshape_in, reshape_in, rtm_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(reshape_op, reshape_op, rtm_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(reshape_out, reshape_out, rtm_pattern);
    ir::Node *reshape_xshape{nullptr};
    if (with_reshape_xshape) {
      GET_IR_NODE_FROM_SUBGRAPH(reshape_xshape1, reshape_xshape, rtm_pattern);
      reshape_xshape = reshape_xshape1;
    }
    GET_IR_NODE_FROM_SUBGRAPH(transpose_op, transpose_op, rtm_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(transpose_out, transpose_out, rtm_pattern);
    ir::Node *transpose_xshape{nullptr};
    if (with_transpose_xshape) {
      GET_IR_NODE_FROM_SUBGRAPH(transpose_xshape1, transpose_xshape,
                                rtm_pattern);
      transpose_xshape = transpose_xshape1;
    }
    GET_IR_NODE_FROM_SUBGRAPH(matmul_op, matmul_op, rtm_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(matmul_out, matmul_out, rtm_pattern);

    auto reshape_shape =
        boost::get<std::vector<int>>(reshape_op->Op()->GetAttr("shape"));
    auto transpose_axis =
        boost::get<std::vector<int>>(transpose_op->Op()->GetAttr("axis"));

    OpDesc *matmul_desc = matmul_op->Op();
    std::string input_var_name = transpose_out->Name();
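
    // Point the matched MatMul input at reshape's input and record the
    // folded shape/axis as attributes on the MatMul op.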
    auto UpdateMatmul = [&](std::string matmul_input_name) {
      matmul_desc->SetInput(matmul_input_name, {(reshape_in)->Name()});
      matmul_desc->SetAttr("fused_reshape_" + matmul_input_name, reshape_shape);
      matmul_desc->SetAttr("fused_transpose_" + matmul_input_name,
                           transpose_axis);
    };
    if (matmul_desc->Inputs().at("X").at(0) == input_var_name) {
      UpdateMatmul("X");
    } else if (matmul_desc->Inputs().at("Y").at(0) == input_var_name) {
      UpdateMatmul("Y");
    } else {
      throw platform::errors::InvalidArgument(
          "Unexpected input to MatMul encountered.");
    }
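
    // The intermediate reshape/transpose ops and their outputs are now
    // bypassed; remove them and link reshape's input directly to MatMul.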
    std::unordered_set<const ir::Node *> nodes_to_remove{
        reshape_op, reshape_out, transpose_op, transpose_out};
    if (with_reshape_xshape) nodes_to_remove.insert(reshape_xshape);
    if (with_transpose_xshape) nodes_to_remove.insert(transpose_xshape);
    GraphSafeRemoveNodes(graph, nodes_to_remove);

    IR_NODE_LINK_TO(reshape_in, matmul_op);

    ++found_reshape_transpose_matmul_count;
  };

  gpd(graph, handler);
  AddStatis(found_reshape_transpose_matmul_count);

  std::stringstream msg_ss;
  msg_ss << "--- Fused " << found_reshape_transpose_matmul_count
         << " ReshapeTransposeMatmulMkldnn patterns";
  if (with_reshape_xshape) msg_ss << " with reshape's xshape";
  if (with_transpose_xshape) msg_ss << " with transpose's xshape";
  string::PrettyLogDetail(msg_ss.str().c_str());
}

void ReshapeTransposeMatmulMkldnnFusePass::ApplyImpl(ir::Graph *graph) const {
  PADDLE_ENFORCE_NOT_NULL(graph,
                          platform::errors::InvalidArgument(
                              "Pointer to graph argument should not be NULL."));
  FusePassBase::Init(name_scope_, graph);
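
  // reshape2 and transpose2 may each emit an auxiliary XShape output, so run
  // the detector over all four pattern variants.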
  Fuse(graph, false, false);
  Fuse(graph, false, true);
  Fuse(graph, true, false);
  Fuse(graph, true, true);
}

}  // namespace ir
}  // namespace framework
}  // namespace paddle

REGISTER_PASS(reshape_transpose_matmul_mkldnn_fuse_pass,
              paddle::framework::ir::ReshapeTransposeMatmulMkldnnFusePass);
@@ -0,0 +1,41 @@
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"

namespace paddle {
namespace framework {
namespace ir {
/*
 * Fuse Reshape->Transpose->MatMul when MatMul uses mkldnn.
 */
class ReshapeTransposeMatmulMkldnnFusePass : public FusePassBase {
 public:
  virtual ~ReshapeTransposeMatmulMkldnnFusePass() {}

 protected:
  void ApplyImpl(ir::Graph* graph) const override;
  const std::string name_scope_{"reshape_transpose_matmul_fuse"};
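
  // The two flags select the pattern variant in which reshape2/transpose2
  // also produce an XShape output.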
  void Fuse(Graph* graph, bool with_reshape_xshape,
            bool with_transpose_xshape) const;
};
}  // namespace ir
}  // namespace framework
}  // namespace paddle
@@ -0,0 +1,124 @@
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/ir/mkldnn/reshape_transpose_matmul_mkldnn_fuse_pass.h"

#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/pass_tester_helper.h"

namespace paddle {
namespace framework {
namespace ir {

void AddVarToScope(Scope* param_scope, const std::string& name,
                   const DDim& dims) {
  auto* tensor = param_scope->Var(name)->GetMutable<LoDTensor>();
  tensor->Resize(dims);
  tensor->mutable_data<float>(platform::CPUPlace());
}

Scope* CreateParamScope() {
  auto param_scope = new Scope();
  AddVarToScope(param_scope, "w1", {768, 768});
  AddVarToScope(param_scope, "bias1", {768});
  AddVarToScope(param_scope, "w2", {768, 768});
  AddVarToScope(param_scope, "bias2", {768});
  return param_scope;
}

void TestMain(bool with_xshapes) {
  // inputs         operator     output
  // ---------------------------------------------
  // a1,w1,bias1    fc        -> b1
  // b1             reshape   -> c1
  // c1             transpose -> d1
  // a2,w2,bias2    fc        -> b2
  // b2             reshape   -> c2
  // c2             transpose -> d2
  // (d1, d2)       matmul    -> (...)
  Layers layers;
  auto* a1 = layers.data("a1", {-1, 128, 768});
  auto* w1 = layers.data("w1", {768, 768}, true);
  auto* bias1 = layers.data("bias1", {768}, true);
  auto* b1 = layers.fc(a1, w1, bias1, 2);
  b1->SetShape({-1, 128, 768});
  auto* c1 = layers.reshape2(b1, {0, 0, 12, 64}, with_xshapes);
  c1->SetShape({-1, 128, 12, 64});
  auto* d1 = layers.transpose2(c1, {0, 2, 1, 3}, with_xshapes);
  d1->SetShape({-1, 12, 128, 64});
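  // with_xshapes only affects the first branch; the second
  // reshape2/transpose2 pair always uses the default (no XShape output).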
  auto* a2 = layers.data("a2", {-1, 128, 768});
  auto* w2 = layers.data("w2", {768, 768}, true);
  auto* bias2 = layers.data("bias2", {768}, true);
  auto* b2 = layers.fc(a2, w2, bias2, 2);
  b2->SetShape({-1, 128, 768});
  auto* c2 = layers.reshape2(b2, {0, 0, 12, 64});
  c2->SetShape({-1, 128, 12, 64});
  auto* d2 = layers.transpose2(c2, {0, 2, 1, 3});
  d2->SetShape({-1, 12, 128, 64});
  layers.matmul(d1, d2);

  std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
  graph->Set("__param_scope__", CreateParamScope());

  int num_reshape_nodes_before = GetNumOpNodes(graph, "reshape2");
  int num_transpose_nodes_before = GetNumOpNodes(graph, "transpose2");
  int total_nodes_before = graph->Nodes().size();
  VLOG(3) << DebugString(graph);

  auto pass =
      PassRegistry::Instance().Get("reshape_transpose_matmul_mkldnn_fuse_pass");
  graph.reset(pass->Apply(graph.release()));

  int num_reshape_nodes_after = GetNumOpNodes(graph, "reshape2");
  int num_transpose_nodes_after = GetNumOpNodes(graph, "transpose2");
  int total_nodes_after = graph->Nodes().size();
  VLOG(3) << DebugString(graph);

  EXPECT_EQ(num_reshape_nodes_before, 2);
  EXPECT_EQ(num_reshape_nodes_after, 0);
  EXPECT_EQ(num_transpose_nodes_before, 2);
  EXPECT_EQ(num_transpose_nodes_after, 0);
  int removed = 8;  // 2 * (reshape op, reshape_out, transpose op, transpose_out)
  if (with_xshapes) removed += 2;  // first branch's reshape_xshape, transpose_xshape
  EXPECT_EQ(total_nodes_before - removed, total_nodes_after);
  auto* matmul_op_desc = GetOpNodes(graph, "matmul").at(0)->Op();
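
  // Both MatMul inputs went through reshape+transpose, so the fused
  // attributes must appear on both the X and the Y side.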
  auto check = [&matmul_op_desc](std::string a) {
    std::string shape_str = "fused_reshape_" + a;
    EXPECT_THAT(matmul_op_desc->GetAttrIfExists<std::vector<int>>(shape_str),
                testing::ElementsAre(0, 0, 12, 64));
    std::string axis_str = "fused_transpose_" + a;
    EXPECT_THAT(matmul_op_desc->GetAttrIfExists<std::vector<int>>(axis_str),
                testing::ElementsAre(0, 2, 1, 3));
  };
  check("X");
  check("Y");
}

TEST(ReshapeTransposeMatmulMkldnnFusePass,
     both_matmul_inputs_reshape_transpose) {
  TestMain(false);
}

TEST(ReshapeTransposeMatmulMkldnnFusePass,
     both_matmul_inputs_reshape_transpose_one_with_xshapes) {
  TestMain(true);
}

}  // namespace ir
}  // namespace framework
}  // namespace paddle

USE_PASS(reshape_transpose_matmul_mkldnn_fuse_pass);