add MKL-DNN placement pass

This patch also refactors conv+bn (includes changes from PR
https://github.com/PaddlePaddle/Paddle/pull/13926)
updated to use the mkldnn-placement-pass.

test=develop
ce
Wojciech Uss 7 years ago committed by Michal Gallus
parent 0a9f5f1790
commit 5632019f0f

@ -226,18 +226,21 @@ void AnalysisPredictor::OptimizeInferenceProgram() {
argument_.origin_program_desc.reset(
new ProgramDesc(*inference_program_->Proto()));
bool use_mkldnn = config_._use_mkldnn;
switch (config_.ir_mode) {
case contrib::AnalysisConfig::IrPassMode::kExclude:
Analyzer()
.IncludeAllIrPasses()
.SetUseMkldnn(config_._use_mkldnn)
.DisableIrPasses(config_.ir_passes)
.SetUseMkldnn(use_mkldnn)
.DisableIrPasses(use_mkldnn ? config_.ir_mkldnn_passes
: config_.ir_passes)
.Run(&argument_);
break;
case contrib::AnalysisConfig::IrPassMode::kInclude:
Analyzer()
.SetUseMkldnn(config_._use_mkldnn)
.IncludeIrPasses(config_.ir_passes)
.SetUseMkldnn(use_mkldnn)
.IncludeIrPasses(use_mkldnn ? config_.ir_mkldnn_passes
: config_.ir_passes)
.Run(&argument_);
break;
default:

@ -261,8 +261,8 @@ struct AnalysisConfig : public NativeConfig {
void SetIncludeMode() {
ir_mode = IrPassMode::kInclude;
// this pass has to be run at the beginning of all fuse passes
ir_passes = {"infer_clean_graph_pass"};
ir_mkldnn_passes = {"infer_clean_graph_pass"};
}
// Determine whether to perform graph optimization.
@ -271,6 +271,8 @@ struct AnalysisConfig : public NativeConfig {
IrPassMode ir_mode{IrPassMode::kExclude};
// passes to be excluded/included
std::vector<std::string> ir_passes{"embedding_fc_lstm_fuse_pass"};
// passes to be excluded/included when MKL-DNN is enabled
std::vector<std::string> ir_mkldnn_passes{"embedding_fc_lstm_fuse_pass"};
// NOT stable yet.
bool use_feed_fetch_ops{true};

Loading…
Cancel
Save