fix trt dynamic ernie serialization unit test (#26228)

revert-24895-update_cub
Pei Yang 5 years ago committed by GitHub
parent ea6716a55b
commit b757466b0d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -309,7 +309,8 @@ std::vector<std::vector<Node *>> SubgraphDetector::ExtractSubGraphs() {
BriefNode *brief_node = itr.second;
if (!Agent(brief_node->node).marked()) {
VLOG(4) << brief_node->node->id() << " node not a trt candidate.";
VLOG(4) << brief_node->node->id() << " node named "
<< brief_node->node->Name() << " is not a trt candidate.";
continue;
}

@ -471,19 +471,10 @@ if(WITH_GPU AND TENSORRT_FOUND)
inference_download_and_uncompress(${TEST_TRT_ERNIE_MODEL} ${INFERENCE_URL}/tensorrt_test "ernie_model_4_unserialized.tgz")
endif()
inference_analysis_test(test_trt_dynamic_shape_ernie_serialize SRCS trt_dynamic_shape_ernie_deserialize_test.cc
inference_analysis_test(test_trt_dynamic_shape_ernie_ser_deser SRCS trt_dynamic_shape_ernie_deserialize_test.cc
EXTRA_DEPS ${INFERENCE_EXTRA_DEPS}
ARGS --infer_model=${TEST_TRT_ERNIE_MODEL}/ernie_model_4_unserialized)
set(TEST_TRT_ERNIE_SER_MODEL "${TRT_MODEL_INSTALL_DIR}/ernie_test/ernie_model_4_serialized/")
if (NOT EXISTS ${TEST_TRT_ERNIE_SER_MODEL})
inference_download_and_uncompress(${TEST_TRT_ERNIE_MODEL} ${INFERENCE_URL}/tensorrt_test "ernie_model_4_serialized.tgz")
endif()
inference_analysis_test(test_trt_dynamic_shape_ernie_deserialize SRCS trt_dynamic_shape_ernie_deserialize_test.cc
EXTRA_DEPS ${INFERENCE_EXTRA_DEPS}
ARGS --infer_model=${TEST_TRT_ERNIE_MODEL}/ernie_model_4_serialized)
endif()
set(LITE_MODEL_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/lite")

@ -123,8 +123,11 @@ void trt_ernie(bool with_fp16, std::vector<float> result) {
config.EnableTensorRtEngine(1 << 30, 1, 5, precision, true, false);
config.SetTRTDynamicShapeInfo(min_input_shape, max_input_shape,
opt_input_shape);
AnalysisConfig* config_deser = new AnalysisConfig(config);
std::vector<float> out_data;
run(config, &out_data);
run(config, &out_data); // serialize
run(*config_deser, &out_data); // deserialize
for (size_t i = 0; i < out_data.size(); i++) {
EXPECT_NEAR(result[i], out_data[i], 1e-6);
}

@ -126,7 +126,7 @@ void trt_ernie(bool with_fp16, std::vector<float> result) {
std::vector<float> out_data;
run(config, &out_data);
for (size_t i = 0; i < out_data.size(); i++) {
EXPECT_NEAR(result[i], out_data[i], 1e-6);
EXPECT_NEAR(result[i], out_data[i], 1e-5);
}
}

Loading…
Cancel
Save