enable mkldnn in infer api

fix-develop-build.sh
tensor-tang 7 years ago
parent 33c21291e7
commit 01f0f16884

@ -77,6 +77,9 @@ bool AnalysisPredictor::Init(
OptimizeInferenceProgram();
ctx_ = executor_->Prepare(*inference_program_, 0);
if (config_.use_mkldnn) {
executor_->EnableMKLDNN(*inference_program_);
}
VLOG(5) << "to create variables";
PADDLE_ENFORCE(scope_.get());

@ -106,6 +106,9 @@ bool NativePaddlePredictor::Init(
}
ctx_ = executor_->Prepare(*inference_program_, 0);
if (config_.use_mkldnn) {
executor_->EnableMKLDNN(*inference_program_);
}
executor_->CreateVariables(*inference_program_,
sub_scope_ ? sub_scope_ : scope_.get(), 0);

@ -45,7 +45,7 @@ class PaddleBuf {
PaddleBuf(void* data, size_t length)
: data_(data), length_(length), memory_owned_{false} {}
// Own memory.
PaddleBuf(size_t length)
explicit PaddleBuf(size_t length)
: data_(new char[length]), length_(length), memory_owned_(true) {}
// Resize to `length` bytes.
void Resize(size_t length);
@ -121,6 +121,8 @@ struct NativeConfig : public PaddlePredictor::Config {
bool use_gpu{false};
int device{0};
float fraction_of_gpu_memory{-1.f}; // Negative to notify initialization.
// MKLDNN related fields.
bool use_mkldnn{false};
// Specify the variable's name of each input.
bool specify_input_name{false};

@ -66,12 +66,13 @@ Record ProcessALine(const std::string &line) {
* Use the native and analysis fluid engine to inference the demo.
* ocr, mobilenet and se_resnext50
*/
void TestVisualPrediction() {
void TestVisualPrediction(bool use_mkldnn) {
std::unique_ptr<PaddlePredictor> predictor;
AnalysisConfig cfg;
cfg.param_file = FLAGS_infer_model + "/__params__";
cfg.prog_file = FLAGS_infer_model + "/__model__";
cfg.use_gpu = false;
cfg.use_mkldnn = use_mkldnn;
cfg.device = 0;
cfg.enable_ir_optim = true;
cfg.ir_passes.push_back("fc_gru_fuse_pass");
@ -163,7 +164,10 @@ void TestVisualPrediction() {
}
}
TEST(Analyzer_vis, analysis) { TestVisualPrediction(); }
TEST(Analyzer_vis, analysis) { TestVisualPrediction(/*use_mkldnn*/ false); }
TEST(Analyzer_vis, analysis_mkldnn) {
TestVisualPrediction(/*use_mkldnn*/ true);
}
} // namespace analysis
} // namespace inference

Loading…
Cancel
Save