fix trt_dynamic_shape_ernie_deserialize_test (#27290)

* fix trt_dynamic_shape_ernie_deserialize_test

* support when opt cache dir does not exist
ut_move_night
Pei Yang 5 years ago committed by GitHub
parent 1483ea2304
commit 3ae3b86489
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -480,10 +480,9 @@ if(WITH_GPU AND TENSORRT_FOUND)
inference_download_and_uncompress(${TEST_TRT_ERNIE_MODEL} ${INFERENCE_URL}/tensorrt_test "ernie_model_4_unserialized.tgz") inference_download_and_uncompress(${TEST_TRT_ERNIE_MODEL} ${INFERENCE_URL}/tensorrt_test "ernie_model_4_unserialized.tgz")
endif() endif()
# disable test_trt_dynamic_shape_ernie_ser_deser temporary inference_analysis_test(test_trt_dynamic_shape_ernie_ser_deser SRCS trt_dynamic_shape_ernie_deserialize_test.cc
#inference_analysis_test(test_trt_dynamic_shape_ernie_ser_deser SRCS trt_dynamic_shape_ernie_deserialize_test.cc EXTRA_DEPS ${INFERENCE_EXTRA_DEPS}
# EXTRA_DEPS ${INFERENCE_EXTRA_DEPS} ARGS --infer_model=${TEST_TRT_ERNIE_MODEL}/ernie_model_4_unserialized)
# ARGS --infer_model=${TEST_TRT_ERNIE_MODEL}/ernie_model_4_unserialized)
endif() endif()

@ -12,15 +12,33 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include <dirent.h>
#include <gflags/gflags.h> #include <gflags/gflags.h>
#include <glog/logging.h> #include <glog/logging.h>
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <unistd.h>
#include "paddle/fluid/inference/tests/api/trt_test_helper.h" #include "paddle/fluid/inference/tests/api/trt_test_helper.h"
namespace paddle { namespace paddle {
namespace inference { namespace inference {
int DeleteCache(std::string path) {
DIR* dir = opendir(path.c_str());
if (dir == NULL) return 0;
struct dirent* ptr;
while ((ptr = readdir(dir)) != NULL) {
if (std::strcmp(ptr->d_name, ".") == 0 ||
std::strcmp(ptr->d_name, "..") == 0) {
continue;
} else if (ptr->d_type == 8) {
std::string file_rm = path + "/" + ptr->d_name;
return remove(file_rm.c_str());
}
}
return 0;
}
void run(const AnalysisConfig& config, std::vector<float>* out_data) { void run(const AnalysisConfig& config, std::vector<float>* out_data) {
auto predictor = CreatePaddlePredictor(config); auto predictor = CreatePaddlePredictor(config);
auto input_names = predictor->GetInputNames(); auto input_names = predictor->GetInputNames();
@ -86,6 +104,11 @@ void run(const AnalysisConfig& config, std::vector<float>* out_data) {
void trt_ernie(bool with_fp16, std::vector<float> result) { void trt_ernie(bool with_fp16, std::vector<float> result) {
AnalysisConfig config; AnalysisConfig config;
std::string model_dir = FLAGS_infer_model; std::string model_dir = FLAGS_infer_model;
// Delete serialization cache to perform serialization first rather than
// deserialization.
std::string opt_cache_dir = FLAGS_infer_model + "/_opt_cache";
DeleteCache(opt_cache_dir);
SetConfig(&config, model_dir, true /* use_gpu */); SetConfig(&config, model_dir, true /* use_gpu */);
config.SwitchUseFeedFetchOps(false); config.SwitchUseFeedFetchOps(false);

Loading…
Cancel
Save