|
|
|
@ -178,6 +178,7 @@ void TestLACPrediction(const std::string &model_path,
|
|
|
|
|
cfg.device = 0;
|
|
|
|
|
cfg.specify_input_name = true;
|
|
|
|
|
cfg.enable_ir_optim = true;
|
|
|
|
|
cfg.ir_passes.push_back("fc_gru_fuse_pass");
|
|
|
|
|
predictor =
|
|
|
|
|
CreatePaddlePredictor<AnalysisConfig, PaddleEngineKind::kAnalysis>(cfg);
|
|
|
|
|
} else {
|
|
|
|
@ -208,13 +209,6 @@ void TestLACPrediction(const std::string &model_path,
|
|
|
|
|
PrintTime(timer.toc(), batch_size, repeat);
|
|
|
|
|
|
|
|
|
|
// check result
|
|
|
|
|
if (use_analysis) {
|
|
|
|
|
// run once for comparion as reference
|
|
|
|
|
auto ref_predictor =
|
|
|
|
|
CreatePaddlePredictor<NativeConfig, PaddleEngineKind::kNative>(config);
|
|
|
|
|
ref_predictor->Run(input_slots, &ref_outputs_slots);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
EXPECT_EQ(outputs_slots.size(), 1UL);
|
|
|
|
|
auto &out = outputs_slots[0];
|
|
|
|
|
size_t size = std::accumulate(out.shape.begin(), out.shape.end(), 1,
|
|
|
|
@ -228,6 +222,10 @@ void TestLACPrediction(const std::string &model_path,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (use_analysis) {
|
|
|
|
|
// run once for comparion as reference
|
|
|
|
|
auto ref_predictor =
|
|
|
|
|
CreatePaddlePredictor<NativeConfig, PaddleEngineKind::kNative>(config);
|
|
|
|
|
ref_predictor->Run(input_slots, &ref_outputs_slots);
|
|
|
|
|
EXPECT_EQ(ref_outputs_slots.size(), outputs_slots.size());
|
|
|
|
|
auto &ref_out = ref_outputs_slots[0];
|
|
|
|
|
size_t ref_size =
|
|
|
|
@ -256,12 +254,9 @@ void TestLACPrediction(const std::string &model_path,
|
|
|
|
|
}
|
|
|
|
|
LOG(INFO) << "has num ops: " << num_ops;
|
|
|
|
|
ASSERT_TRUE(fuse_statis.count("fc_fuse"));
|
|
|
|
|
ASSERT_TRUE(fuse_statis.count("fc_gru_fuse"));
|
|
|
|
|
LOG(INFO) << "fc fuse num:" << fuse_statis.at("fc_fuse");
|
|
|
|
|
LOG(INFO) << "fc gru fuse num:" << fuse_statis.at("fc_gru_fuse");
|
|
|
|
|
|
|
|
|
|
// ASSERT_TRUE(fuse_statis.count("fc_gru_fuse"));
|
|
|
|
|
// LOG(INFO) << fuse_statis.at("fc_gru_fuse");
|
|
|
|
|
LOG(INFO) << "fc fuse num:" << fuse_statis.at("fc_fuse");
|
|
|
|
|
// LOG(INFO) << "fc gru fuse num:" << fuse_statis.at("fc_gru_fuse");
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|