Fix trt fc fuse test (#23852)

* fix trt fc fuse test, test=develop

* fix trt_transpose_flatten_concat shape, test=develop
revert-22778-infer_var_type
liu zhengxi 5 years ago committed by GitHub
parent 477cb1fdb3
commit 53cfac9492
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -25,7 +25,8 @@ from paddle.fluid.core import AnalysisConfig
class FCFusePassTRTTest(InferencePassTest):
def setUp(self):
with fluid.program_guard(self.main_program, self.startup_program):
data = fluid.data(name="data", shape=[32, 128], dtype="float32")
data = fluid.data(
name="data", shape=[32, 128, 2, 2], dtype="float32")
fc_out1 = fluid.layers.fc(input=data,
size=128,
num_flatten_dims=1,
@ -35,10 +36,12 @@ class FCFusePassTRTTest(InferencePassTest):
num_flatten_dims=1)
out = fluid.layers.softmax(input=fc_out2)
self.feeds = {"data": np.random.random((32, 128)).astype("float32")}
self.feeds = {
"data": np.random.random((32, 128, 2, 2)).astype("float32")
}
self.enable_trt = True
self.trt_parameters = FCFusePassTRTTest.TensorRTParam(
1 << 20, 1, 3, AnalysisConfig.Precision.Float32, False, False)
1 << 30, 32, 3, AnalysisConfig.Precision.Float32, False, False)
self.fetch_list = [out]
def test_check_output(self):

@ -42,7 +42,7 @@ class TransposeFlattenConcatFusePassTRTTest(InferencePassTest):
}
self.enable_trt = True
self.trt_parameters = TransposeFlattenConcatFusePassTRTTest.TensorRTParam(
1 << 20, 1, 3, AnalysisConfig.Precision.Float32, False, False)
1 << 20, 8, 3, AnalysisConfig.Precision.Float32, False, False)
self.fetch_list = [out]
def test_check_output(self):

Loading…
Cancel
Save