|
|
|
@ -146,7 +146,7 @@ bool LoadInputData(std::vector<std::vector<paddle::PaddleTensor>> *inputs) {
|
|
|
|
|
|
|
|
|
|
void SetConfig(AnalysisConfig *config) { config->SetModel(FLAGS_infer_model); }
|
|
|
|
|
|
|
|
|
|
void profile(bool use_mkldnn = false) {
|
|
|
|
|
void profile(bool use_mkldnn = false, bool use_ngraph = false) {
|
|
|
|
|
AnalysisConfig config;
|
|
|
|
|
SetConfig(&config);
|
|
|
|
|
|
|
|
|
@ -155,6 +155,10 @@ void profile(bool use_mkldnn = false) {
|
|
|
|
|
config.pass_builder()->AppendPass("fc_mkldnn_pass");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (use_ngraph) {
|
|
|
|
|
config.EnableNgraph();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<std::vector<PaddleTensor>> outputs;
|
|
|
|
|
std::vector<std::vector<PaddleTensor>> inputs;
|
|
|
|
|
LoadInputData(&inputs);
|
|
|
|
@ -164,7 +168,11 @@ void profile(bool use_mkldnn = false) {
|
|
|
|
|
|
|
|
|
|
TEST(Analyzer_bert, profile) { profile(); }
|
|
|
|
|
#ifdef PADDLE_WITH_MKLDNN
|
|
|
|
|
TEST(Analyzer_bert, profile_mkldnn) { profile(true); }
|
|
|
|
|
TEST(Analyzer_bert, profile_mkldnn) { profile(true, false); }
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
#ifdef PADDLE_WITH_NGRAPH
|
|
|
|
|
TEST(Analyzer_bert, profile_ngraph) { profile(false, true); }
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
// Check the fuse status
|
|
|
|
@ -179,7 +187,7 @@ TEST(Analyzer_bert, fuse_statis) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Compare result of NativeConfig and AnalysisConfig
|
|
|
|
|
void compare(bool use_mkldnn = false) {
|
|
|
|
|
void compare(bool use_mkldnn = false, bool use_ngraph = false) {
|
|
|
|
|
AnalysisConfig cfg;
|
|
|
|
|
SetConfig(&cfg);
|
|
|
|
|
if (use_mkldnn) {
|
|
|
|
@ -187,6 +195,10 @@ void compare(bool use_mkldnn = false) {
|
|
|
|
|
cfg.pass_builder()->AppendPass("fc_mkldnn_pass");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (use_ngraph) {
|
|
|
|
|
cfg.EnableNgraph();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<std::vector<PaddleTensor>> inputs;
|
|
|
|
|
LoadInputData(&inputs);
|
|
|
|
|
CompareNativeAndAnalysis(
|
|
|
|
@ -195,7 +207,15 @@ void compare(bool use_mkldnn = false) {
|
|
|
|
|
|
|
|
|
|
TEST(Analyzer_bert, compare) { compare(); }
|
|
|
|
|
#ifdef PADDLE_WITH_MKLDNN
|
|
|
|
|
TEST(Analyzer_bert, compare_mkldnn) { compare(true /* use_mkldnn */); }
|
|
|
|
|
TEST(Analyzer_bert, compare_mkldnn) {
|
|
|
|
|
compare(true, false /* use_mkldnn, no use_ngraph */);
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
#ifdef PADDLE_WITH_NGRAPH
|
|
|
|
|
TEST(Analyzer_bert, compare_ngraph) {
|
|
|
|
|
compare(false, true /* no use_mkldnn, use_ngraph */);
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
// Compare Deterministic result
|
|
|
|
|