From 53cfac94923a59c6b0bbd97cb7e396aff488733a Mon Sep 17 00:00:00 2001 From: liu zhengxi <380185688@qq.com> Date: Fri, 17 Apr 2020 15:47:57 +0800 Subject: [PATCH] Fix trt fc fuse test (#23852) * fix trt fc fuse test, test=develop * fix trt_transpose_flatten_concat shape, test=develop --- .../unittests/ir/inference/test_trt_fc_fuse_pass.py | 9 ++++++--- .../test_trt_transpose_flatten_concat_fuse_pass.py | 2 +- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_fc_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_fc_fuse_pass.py index f66822171c..0f035d6026 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_fc_fuse_pass.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_fc_fuse_pass.py @@ -25,7 +25,8 @@ from paddle.fluid.core import AnalysisConfig class FCFusePassTRTTest(InferencePassTest): def setUp(self): with fluid.program_guard(self.main_program, self.startup_program): - data = fluid.data(name="data", shape=[32, 128], dtype="float32") + data = fluid.data( + name="data", shape=[32, 128, 2, 2], dtype="float32") fc_out1 = fluid.layers.fc(input=data, size=128, num_flatten_dims=1, @@ -35,10 +36,12 @@ class FCFusePassTRTTest(InferencePassTest): num_flatten_dims=1) out = fluid.layers.softmax(input=fc_out2) - self.feeds = {"data": np.random.random((32, 128)).astype("float32")} + self.feeds = { + "data": np.random.random((32, 128, 2, 2)).astype("float32") + } self.enable_trt = True self.trt_parameters = FCFusePassTRTTest.TensorRTParam( - 1 << 20, 1, 3, AnalysisConfig.Precision.Float32, False, False) + 1 << 30, 32, 3, AnalysisConfig.Precision.Float32, False, False) self.fetch_list = [out] def test_check_output(self): diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_transpose_flatten_concat_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_transpose_flatten_concat_fuse_pass.py index c85c54c741..41f02b0427 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_transpose_flatten_concat_fuse_pass.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_transpose_flatten_concat_fuse_pass.py @@ -42,7 +42,7 @@ class TransposeFlattenConcatFusePassTRTTest(InferencePassTest): } self.enable_trt = True self.trt_parameters = TransposeFlattenConcatFusePassTRTTest.TensorRTParam( - 1 << 20, 1, 3, AnalysisConfig.Precision.Float32, False, False) + 1 << 20, 8, 3, AnalysisConfig.Precision.Float32, False, False) self.fetch_list = [out] def test_check_output(self):