[NGraph] Bert model for a capi, ngraph's support test=develop (#17844)

dependabot/pip/python/requests-2.20.0
mozga-intel 6 years ago committed by tensor-tang
parent 83e51ded21
commit c1379bf238

@ -146,7 +146,7 @@ bool LoadInputData(std::vector<std::vector<paddle::PaddleTensor>> *inputs) {
void SetConfig(AnalysisConfig *config) { config->SetModel(FLAGS_infer_model); }
void profile(bool use_mkldnn = false) {
void profile(bool use_mkldnn = false, bool use_ngraph = false) {
AnalysisConfig config;
SetConfig(&config);
@ -155,6 +155,10 @@ void profile(bool use_mkldnn = false) {
config.pass_builder()->AppendPass("fc_mkldnn_pass");
}
if (use_ngraph) {
config.EnableNgraph();
}
std::vector<std::vector<PaddleTensor>> outputs;
std::vector<std::vector<PaddleTensor>> inputs;
LoadInputData(&inputs);
@ -164,7 +168,11 @@ void profile(bool use_mkldnn = false) {
TEST(Analyzer_bert, profile) { profile(); }
#ifdef PADDLE_WITH_MKLDNN
TEST(Analyzer_bert, profile_mkldnn) { profile(true); }
TEST(Analyzer_bert, profile_mkldnn) { profile(true, false); }
#endif
#ifdef PADDLE_WITH_NGRAPH
TEST(Analyzer_bert, profile_ngraph) { profile(false, true); }
#endif
// Check the fuse status
@ -179,7 +187,7 @@ TEST(Analyzer_bert, fuse_statis) {
}
// Compare result of NativeConfig and AnalysisConfig
void compare(bool use_mkldnn = false) {
void compare(bool use_mkldnn = false, bool use_ngraph = false) {
AnalysisConfig cfg;
SetConfig(&cfg);
if (use_mkldnn) {
@ -187,6 +195,10 @@ void compare(bool use_mkldnn = false) {
cfg.pass_builder()->AppendPass("fc_mkldnn_pass");
}
if (use_ngraph) {
cfg.EnableNgraph();
}
std::vector<std::vector<PaddleTensor>> inputs;
LoadInputData(&inputs);
CompareNativeAndAnalysis(
@ -195,7 +207,15 @@ void compare(bool use_mkldnn = false) {
TEST(Analyzer_bert, compare) { compare(); }
#ifdef PADDLE_WITH_MKLDNN
TEST(Analyzer_bert, compare_mkldnn) { compare(true /* use_mkldnn */); }
TEST(Analyzer_bert, compare_mkldnn) {
compare(true, false /* use_mkldnn, no use_ngraph */);
}
#endif
#ifdef PADDLE_WITH_NGRAPH
TEST(Analyzer_bert, compare_ngraph) {
compare(false, true /* no use_mkldnn, use_ngraph */);
}
#endif
// Compare Deterministic result

Loading…
Cancel
Save