[Inference] Solve 2.0 trt performance reduce compare 1.8. (#29925)
parent
913f77a4b7
commit
2b1d796cd0
@ -0,0 +1,61 @@
|
||||
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/fluid/framework/ir/adaptive_pool2d_convert_global_pass.h"
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "paddle/fluid/framework/ir/graph_helper.h"
|
||||
#include "paddle/fluid/framework/op_version_registry.h"
|
||||
#include "paddle/fluid/platform/enforce.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
namespace ir {
|
||||
|
||||
// Scans every op node of `graph`. An op that carries both an "adaptive" and a
// "ksize" attribute, with adaptive == true and ksize == {1, 1}, is rewritten
// to adaptive = false, global_pooling = true. The number of rewritten ops is
// reported through AddStatis().
//
// graph: the graph to rewrite in place; must not be null.
void AdaptivePool2dConvertGlobalPass::ApplyImpl(ir::Graph* graph) const {
  // Fail with a readable error instead of dereferencing a null graph below.
  PADDLE_ENFORCE_NOT_NULL(
      graph, platform::errors::PreconditionNotMet("graph should not be null."));
  std::string name_scope = "adaptive_pool2d_convert_global_pass";
  FusePassBase::Init(name_scope, graph);

  int num = 0;
  for (const Node* n : graph->Nodes()) {
    if (!n->IsOp()) continue;

    auto* op = n->Op();
    // Only ops exposing both attributes are candidates.
    if (!op->HasAttr("adaptive") || !op->HasAttr("ksize")) continue;

    bool adaptive = BOOST_GET_CONST(bool, op->GetAttr("adaptive"));
    std::vector<int> ksize =
        BOOST_GET_CONST(std::vector<int>, op->GetAttr("ksize"));
    if (adaptive && ksize.size() == 2 && ksize[0] == 1 && ksize[1] == 1) {
      op->SetAttr("adaptive", false);
      op->SetAttr("global_pooling", true);
      ++num;
    }
  }
  AddStatis(num);
}
|
||||
|
||||
} // namespace ir
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
||||
|
||||
// Register the pass under the name "adaptive_pool2d_convert_global_pass".
REGISTER_PASS(adaptive_pool2d_convert_global_pass,
              paddle::framework::ir::AdaptivePool2dConvertGlobalPass);

// Declare the op version combination this pass is registered for:
// pool2d == 0.
REGISTER_PASS_CAPABILITY(adaptive_pool2d_convert_global_pass)
    .AddCombination(
        paddle::framework::compatible::OpVersionComparatorCombination().EQ(
            "pool2d", 0));
|
@ -0,0 +1,42 @@
|
||||
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
|
||||
#include "paddle/fluid/framework/ir/pass.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
namespace ir {
|
||||
|
||||
class Graph;
|
||||
|
||||
/*
|
||||
* Update pool2d's attr to speed up trt engine.
|
||||
*
|
||||
* when adaptive=true, ksize=[1,1], we turn to adaptive=false,
|
||||
* global_pooling=true.
|
||||
*/
|
||||
class AdaptivePool2dConvertGlobalPass : public FusePassBase {
|
||||
public:
|
||||
virtual ~AdaptivePool2dConvertGlobalPass() {}
|
||||
|
||||
protected:
|
||||
void ApplyImpl(ir::Graph* graph) const override;
|
||||
};
|
||||
|
||||
} // namespace ir
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
@ -0,0 +1,67 @@
|
||||
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/fluid/framework/ir/adaptive_pool2d_convert_global_pass.h"
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
|
||||
#include "paddle/fluid/framework/op_version_registry.h"
|
||||
#include "paddle/fluid/platform/enforce.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
namespace ir {
|
||||
|
||||
// Builds a program with one pool2d op (adaptive=true, ksize={1, 1}), runs the
// pass, and checks both halves of the rewrite: global_pooling must be true and
// adaptive must be false afterwards.
TEST(AdaptivePool2dConvertGlobalPass, basic) {
  Layers layers;
  auto* x = layers.data("x", {1, 92, 28, 28});
  AttributeMap attrs;
  attrs["adaptive"] = true;
  attrs["ksize"] = std::vector<int>{1, 1};
  layers.pool2d(x, false, &attrs);

  std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
  auto pass =
      PassRegistry::Instance().Get("adaptive_pool2d_convert_global_pass");
  VLOG(3) << DebugString(graph);

  graph.reset(pass->Apply(graph.release()));
  VLOG(3) << DebugString(graph);

  // Defaults are the "pass did nothing" values, so a missing pool2d op or a
  // missing attribute makes the checks below fail.
  bool global_pooling = false;
  bool adaptive = true;
  for (auto* node : graph->Nodes()) {
    if (!node->IsOp() || node->Op()->Type() != "pool2d") continue;

    auto* op = node->Op();
    if (op->HasAttr("global_pooling")) {
      global_pooling = BOOST_GET_CONST(bool, op->GetAttr("global_pooling"));
    }
    if (op->HasAttr("adaptive")) {
      adaptive = BOOST_GET_CONST(bool, op->GetAttr("adaptive"));
    }
  }
  PADDLE_ENFORCE_EQ(
      global_pooling, true,
      platform::errors::PreconditionNotMet(
          "The attribute of pool2d global_pooling should be true after fuse"));
  PADDLE_ENFORCE_EQ(
      adaptive, false,
      platform::errors::PreconditionNotMet(
          "The attribute of pool2d adaptive should be false after fuse"));
}
|
||||
|
||||
// Checks that the pass version checker reports the pass as compatible.
TEST(AdaptivePool2dConvertGlobalPass, pass_op_version_check) {
  auto&& checker =
      paddle::framework::compatible::PassVersionCheckerRegistrar::GetInstance();
  ASSERT_TRUE(checker.IsPassCompatible("adaptive_pool2d_convert_global_pass"));
}
|
||||
|
||||
} // namespace ir
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
||||
|
||||
// Reference the pass that the tests above fetch from PassRegistry by name.
// NOTE(review): presumably this forces the pass registration to be linked into
// the test binary -- confirm against the USE_PASS definition.
USE_PASS(adaptive_pool2d_convert_global_pass);
|
@ -0,0 +1,134 @@
|
||||
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/fluid/framework/ir/unsqueeze2_eltwise_fuse_pass.h"
|
||||
|
||||
#include <string>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
|
||||
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
|
||||
#include "paddle/fluid/framework/op_version_registry.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
namespace ir {
|
||||
namespace patterns {
|
||||
|
||||
struct UnsqueezeEltwise : public PatternBase {
|
||||
UnsqueezeEltwise(PDPattern *pattern, const std::string &name_scope)
|
||||
: PatternBase(pattern, name_scope, "unsqueeze2_eltwise_fuse_pass") {}
|
||||
|
||||
PDNode *operator()(PDNode *x, PDNode *y);
|
||||
|
||||
// declare operator node's name
|
||||
PATTERN_DECL_NODE(unsqz);
|
||||
PATTERN_DECL_NODE(elementwise);
|
||||
// declare variable node's name
|
||||
PATTERN_DECL_NODE(eltwise_in_x);
|
||||
PATTERN_DECL_NODE(unsqz_in);
|
||||
PATTERN_DECL_NODE(unsqz_out);
|
||||
PATTERN_DECL_NODE(eltwise_out);
|
||||
};
|
||||
|
||||
// Wires up the pattern:
//   y -> unsqueeze2 -> unsqz_out --(Y)--> elementwise_mul -> eltwise_out
//   x ----------------------------(X)-->
// and returns the eltwise_out pattern node.
PDNode *UnsqueezeEltwise::operator()(PDNode *x, PDNode *y) {
  // Constrain the externally created input nodes.
  x->assert_is_op_input("elementwise_mul", "X");
  y->assert_is_op_input("unsqueeze2", "X");

  // unsqueeze2 and its output, which must be the Y operand of the multiply.
  PDNode *unsqz_node = pattern->NewNode(unsqz_repr());
  unsqz_node->assert_is_op("unsqueeze2");
  PDNode *unsqz_result = pattern->NewNode(unsqz_out_repr());
  unsqz_result->assert_is_op_output("unsqueeze2", "Out")
      ->assert_is_op_input("elementwise_mul", "Y");
  unsqz_node->LinksFrom({y}).LinksTo({unsqz_result});

  // elementwise_mul and its output, which is the pattern's output.
  PDNode *mul_node = pattern->NewNode(elementwise_repr());
  mul_node->assert_is_op("elementwise_mul");
  PDNode *mul_result = pattern->NewNode(eltwise_out_repr());
  mul_result->AsOutput()->assert_is_op_output("elementwise_mul");

  mul_node->LinksFrom({x, unsqz_result}).LinksTo({mul_result});
  return mul_result;
}
|
||||
|
||||
} // namespace patterns
|
||||
|
||||
// Detects `unsqueeze2(axes=[2,3]) -> elementwise_mul(axis=-1)` where the
// multiply's X input has rank 4 and the unsqueeze2 input has rank 2, and
// rewrites it to a single elementwise_mul(axis=0) that reads the un-squeezed
// variable directly. The unsqueeze2 op and its output variable are removed.
// The number of rewritten subgraphs is reported through AddStatis().
//
// graph: the graph to rewrite in place; must not be null.
void UnsqueezeEltwiseFusePass::ApplyImpl(ir::Graph *graph) const {
  PADDLE_ENFORCE_NOT_NULL(
      graph, platform::errors::PreconditionNotMet("graph should not be null."));
  FusePassBase::Init("unsqueeze2_eltwise_fuse_pass", graph);
  int found_subgraph_count = 0;

  GraphPatternDetector gpd;
  auto *x = gpd.mutable_pattern()
                ->NewNode("unsqueeze2_eltwise_fuse_pass/x")
                ->AsInput()
                ->assert_is_op_input("elementwise_mul", "X")
                ->assert_var_not_persistable();
  auto *y = gpd.mutable_pattern()
                ->NewNode("unsqueeze2_eltwise_fuse_pass/y")
                ->AsInput()
                ->assert_is_op_input("unsqueeze2", "X")
                ->assert_var_not_persistable();
  patterns::UnsqueezeEltwise fused_pattern(gpd.mutable_pattern(),
                                           "unsqueeze2_eltwise_fuse_pass");
  fused_pattern(x, y);

  auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
                     Graph *graph) {
    if (subgraph.count(x) <= 0 || subgraph.count(y) <= 0) {
      LOG(WARNING) << "The subgraph is empty.";
      return;
    }

    VLOG(4) << "handle UnsqueezeEltwise fuse";
    GET_IR_NODE_FROM_SUBGRAPH(eltwise_op, elementwise, fused_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(eltwise_out, eltwise_out, fused_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(unsqz_op, unsqz, fused_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(unsqz_out, unsqz_out, fused_pattern);

    // unsqz_out is deleted by the rewrite below. If anything besides the
    // matched elementwise_mul consumes it, that consumer would be left reading
    // a removed variable, so leave such subgraphs untouched.
    if (unsqz_out->outputs.size() != 1) {
      VLOG(4) << "skip UnsqueezeEltwise fuse: the unsqueeze2 output has other "
                 "consumers";
      return;
    }

    // Ranks come from the var descs recorded in the graph.
    size_t eltwise_in_x_rank = (subgraph.at(x)->Var()->GetShape()).size();
    size_t unsqz_in_rank = (subgraph.at(y)->Var()->GetShape()).size();
    std::vector<int> unsqz_op_axes =
        BOOST_GET_CONST(std::vector<int>, unsqz_op->Op()->GetAttr("axes"));
    int eltwise_op_axis =
        BOOST_GET_CONST(int, eltwise_op->Op()->GetAttr("axis"));

    if (eltwise_in_x_rank == 4 && unsqz_in_rank == 2 &&
        unsqz_op_axes == std::vector<int>{2, 3} && eltwise_op_axis == -1) {
      eltwise_op->Op()->SetAttr("axis", 0);
      eltwise_op->Op()->SetInput("Y", {subgraph.at(y)->Name()});
      // The matched pattern already guarantees the edges x -> eltwise_op and
      // eltwise_op -> eltwise_out, so re-linking them would only add
      // duplicate entries. The one new edge is y -> eltwise_op.
      IR_NODE_LINK_TO(subgraph.at(y), eltwise_op);
      GraphSafeRemoveNodes(graph, {unsqz_op, unsqz_out});
      found_subgraph_count++;
    }
  };

  gpd(graph, handler);
  AddStatis(found_subgraph_count);
}
|
||||
|
||||
} // namespace ir
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
||||
|
||||
// Register the pass under the name "unsqueeze2_eltwise_fuse_pass".
REGISTER_PASS(unsqueeze2_eltwise_fuse_pass,
              paddle::framework::ir::UnsqueezeEltwiseFusePass);

// Declare the op version combination this pass is registered for:
// unsqueeze2 == 0 and elementwise_mul == 0.
REGISTER_PASS_CAPABILITY(unsqueeze2_eltwise_fuse_pass)
    .AddCombination(
        paddle::framework::compatible::OpVersionComparatorCombination()
            .EQ("unsqueeze2", 0)
            .EQ("elementwise_mul", 0));
|
@ -0,0 +1,45 @@
|
||||
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
namespace ir {
|
||||
|
||||
class Graph;
|
||||
|
||||
// |(rank 4) |(rank 2) |(rank 4) |(rank 2)
|
||||
// | unsqueeze2(axes=[2,3]) | |
|
||||
// | | fuse \ /
|
||||
// |------elementwise_mul(axis=-1) -> elementwise_mul(axis=0)
|
||||
// | |
|
||||
// | |
|
||||
//
|
||||
// Notice:
|
||||
// the rank of input is obtained from var_desc,
|
||||
// it maybe change in runtime.
|
||||
class UnsqueezeEltwiseFusePass : public FusePassBase {
|
||||
public:
|
||||
virtual ~UnsqueezeEltwiseFusePass() {}
|
||||
|
||||
protected:
|
||||
void ApplyImpl(ir::Graph* graph) const override;
|
||||
};
|
||||
|
||||
} // namespace ir
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
@ -0,0 +1,65 @@
|
||||
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/fluid/framework/ir/unsqueeze2_eltwise_fuse_pass.h"
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
|
||||
#include "paddle/fluid/framework/op_version_registry.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
namespace ir {
|
||||
|
||||
// Builds x(rank 4) * unsqueeze2(y(rank 2), axes={2, 3}) with axis=-1, runs the
// pass, and checks that two nodes disappeared and exactly one elementwise_mul
// op remains.
TEST(UnsqueezeEltwiseFusePass, basic) {
  Layers layers;
  auto* x = layers.data("x", {1, 92, 28, 28});
  auto* y = layers.data("y", {1, 92});
  std::vector<int> axes{2, 3};
  auto* unsqz_out = layers.unsqueeze2(y, axes);
  AttributeMap attrs;
  attrs["axis"] = -1;
  layers.elementwise_mul(x, unsqz_out, nullptr, &attrs);

  std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
  auto pass = PassRegistry::Instance().Get("unsqueeze2_eltwise_fuse_pass");

  // Snapshot the node count before the pass runs.
  int nodes_before = graph->Nodes().size();
  VLOG(3) << DebugString(graph);

  graph.reset(pass->Apply(graph.release()));

  // Collect the post-pass numbers.
  int nodes_after = graph->Nodes().size();
  int eltwise_mul_ops = GetNumOpNodes(graph, "elementwise_mul");
  VLOG(3) << DebugString(graph);

  PADDLE_ENFORCE_EQ(nodes_before, nodes_after + 2,
                    platform::errors::PreconditionNotMet(
                        "The number of nodes before and after the fuse does "
                        "not meet expectations"));
  PADDLE_ENFORCE_EQ(
      eltwise_mul_ops, 1,
      platform::errors::PreconditionNotMet(
          "The number of fusion nodes does not meet expectations after fuse"));
}
|
||||
|
||||
// Checks that the pass version checker reports the pass as compatible.
TEST(UnsqueezeEltwiseFusePass, pass_op_version_check) {
  auto&& checker =
      paddle::framework::compatible::PassVersionCheckerRegistrar::GetInstance();
  ASSERT_TRUE(checker.IsPassCompatible("unsqueeze2_eltwise_fuse_pass"));
}
|
||||
|
||||
} // namespace ir
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
||||
|
||||
// Reference the pass that the tests above fetch from PassRegistry by name.
// NOTE(review): presumably this forces the pass registration to be linked into
// the test binary -- confirm against the USE_PASS definition.
USE_PASS(unsqueeze2_eltwise_fuse_pass);
|
Loading…
Reference in new issue