|
|
|
@ -54,11 +54,13 @@ namespace paddle {
|
|
|
|
|
namespace inference {
|
|
|
|
|
|
|
|
|
|
void PrintConfig(const PaddlePredictor::Config *config, bool use_analysis) {
|
|
|
|
|
const auto *analysis_config =
|
|
|
|
|
reinterpret_cast<const contrib::AnalysisConfig *>(config);
|
|
|
|
|
if (use_analysis) {
|
|
|
|
|
LOG(INFO) << *reinterpret_cast<const contrib::AnalysisConfig *>(config);
|
|
|
|
|
LOG(INFO) << *analysis_config;
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
LOG(INFO) << *reinterpret_cast<const NativeConfig *>(config);
|
|
|
|
|
LOG(INFO) << analysis_config->ToNativeConfig();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void CompareResult(const std::vector<PaddleTensor> &outputs,
|
|
|
|
@ -96,12 +98,13 @@ void CompareResult(const std::vector<PaddleTensor> &outputs,
|
|
|
|
|
|
|
|
|
|
std::unique_ptr<PaddlePredictor> CreateTestPredictor(
|
|
|
|
|
const PaddlePredictor::Config *config, bool use_analysis = true) {
|
|
|
|
|
const auto *analysis_config =
|
|
|
|
|
reinterpret_cast<const contrib::AnalysisConfig *>(config);
|
|
|
|
|
if (use_analysis) {
|
|
|
|
|
return CreatePaddlePredictor<contrib::AnalysisConfig>(
|
|
|
|
|
*(reinterpret_cast<const contrib::AnalysisConfig *>(config)));
|
|
|
|
|
return CreatePaddlePredictor<contrib::AnalysisConfig>(*analysis_config);
|
|
|
|
|
}
|
|
|
|
|
return CreatePaddlePredictor<NativeConfig>(
|
|
|
|
|
*(reinterpret_cast<const NativeConfig *>(config)));
|
|
|
|
|
auto native_config = analysis_config->ToNativeConfig();
|
|
|
|
|
return CreatePaddlePredictor<NativeConfig>(native_config);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
size_t GetSize(const PaddleTensor &out) { return VecReduceToInt(out.shape); }
|
|
|
|
@ -328,10 +331,7 @@ void CompareNativeAndAnalysis(
|
|
|
|
|
const std::vector<std::vector<PaddleTensor>> &inputs) {
|
|
|
|
|
PrintConfig(config, true);
|
|
|
|
|
std::vector<PaddleTensor> native_outputs, analysis_outputs;
|
|
|
|
|
const auto *analysis_config =
|
|
|
|
|
reinterpret_cast<const contrib::AnalysisConfig *>(config);
|
|
|
|
|
auto native_config = analysis_config->ToNativeConfig();
|
|
|
|
|
TestOneThreadPrediction(&native_config, inputs, &native_outputs, false);
|
|
|
|
|
TestOneThreadPrediction(config, inputs, &native_outputs, false);
|
|
|
|
|
TestOneThreadPrediction(config, inputs, &analysis_outputs, true);
|
|
|
|
|
CompareResult(analysis_outputs, native_outputs);
|
|
|
|
|
}
|
|
|
|
|