From fbd3604cad8fdb3ad7fa2f6717395b1c40e6ecaf Mon Sep 17 00:00:00 2001 From: Liu Yiqun Date: Tue, 3 Apr 2018 05:31:52 +0000 Subject: [PATCH 1/4] Split Executor.Run to Executor.Prepare and Executor.RunPreparedContext for inference. --- paddle/fluid/framework/executor.cc | 94 ++++++++++++------- paddle/fluid/framework/executor.h | 7 ++ .../test_inference_image_classification.cc | 4 +- paddle/fluid/inference/tests/test_helper.h | 20 +++- 4 files changed, 85 insertions(+), 40 deletions(-) diff --git a/paddle/fluid/framework/executor.cc b/paddle/fluid/framework/executor.cc index 64c06687b6..009d0fbeb8 100644 --- a/paddle/fluid/framework/executor.cc +++ b/paddle/fluid/framework/executor.cc @@ -129,13 +129,15 @@ static bool has_feed_operators( feed_count, feed_targets.size(), "The number of feed operators should match 'feed_targets'"); - // When feed operator are present, so should be feed_holder - auto var = block.FindVar(feed_holder_name); - PADDLE_ENFORCE_NOT_NULL(var, "Block should already have a '%s' variable", - feed_holder_name); - PADDLE_ENFORCE_EQ(var->GetType(), proto::VarType::FEED_MINIBATCH, - "'%s' variable should be 'FEED_MINIBATCH' type", - feed_holder_name); + if (!feed_holder_name.empty()) { + // When feed operator are present, so should be feed_holder + auto var = block.FindVar(feed_holder_name); + PADDLE_ENFORCE_NOT_NULL(var, "Block should already have a '%s' variable", + feed_holder_name); + PADDLE_ENFORCE_EQ(var->GetType(), proto::VarType::FEED_MINIBATCH, + "'%s' variable should be 'FEED_MINIBATCH' type", + feed_holder_name); + } } return feed_count > 0; @@ -169,13 +171,15 @@ static bool has_fetch_operators( fetch_count, fetch_targets.size(), "The number of fetch operators should match 'fetch_targets'"); - // When fetch operator are present, so should be fetch_holder - auto var = block.FindVar(fetch_holder_name); - PADDLE_ENFORCE_NOT_NULL(var, "Block should already have a '%s' variable", - fetch_holder_name); - PADDLE_ENFORCE_EQ(var->GetType(), proto::VarType::FETCH_LIST, - "'%s' variable should be 'FETCH_LIST' type", - fetch_holder_name); + if (!fetch_holder_name.empty()) { + // When fetch operator are present, so should be fetch_holder + auto var = block.FindVar(fetch_holder_name); + PADDLE_ENFORCE_NOT_NULL(var, "Block should already have a '%s' variable", + fetch_holder_name); + PADDLE_ENFORCE_EQ(var->GetType(), proto::VarType::FETCH_LIST, + "'%s' variable should be 'FETCH_LIST' type", + fetch_holder_name); + } } return fetch_count > 0; @@ -222,16 +226,6 @@ void Executor::Run(const ProgramDesc& program, Scope* scope, } } - // map the data of feed_targets to feed_holder - for (auto* op : global_block->AllOps()) { - if (op->Type() == kFeedOpType) { - std::string feed_target_name = op->Output("Out")[0]; - int idx = boost::get(op->GetAttr("col")); - SetFeedVariable(scope, *feed_targets[feed_target_name], feed_holder_name, - idx); - } - } - if (!has_fetch_ops) { // create fetch_holder variable auto* fetch_holder = global_block->Var(fetch_holder_name); @@ -255,17 +249,9 @@ void Executor::Run(const ProgramDesc& program, Scope* scope, } } - Run(*copy_program, scope, 0, create_vars, create_vars); - - // obtain the data of fetch_targets from fetch_holder - for (auto* op : global_block->AllOps()) { - if (op->Type() == kFetchOpType) { - std::string fetch_target_name = op->Input("X")[0]; - int idx = boost::get(op->GetAttr("col")); - *fetch_targets[fetch_target_name] = - GetFetchVariable(*scope, fetch_holder_name, idx); - } - } + auto ctx = Prepare(*copy_program, 0); + RunPreparedContext(ctx.get(), scope, feed_targets, fetch_targets, + feed_holder_name, fetch_holder_name, create_vars); } std::unique_ptr Executor::Prepare( @@ -343,5 +329,43 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope, } } +void Executor::RunPreparedContext( + ExecutorPrepareContext* ctx, Scope* scope, + std::map& feed_targets, + std::map& fetch_targets, + const std::string& feed_holder_name, const std::string& fetch_holder_name, + bool create_vars) { + auto& global_block = ctx->prog_.Block(ctx->block_id_); + + // map the data of feed_targets to feed_holder + for (auto* op : global_block.AllOps()) { + if (op->Type() == kFeedOpType) { + std::string feed_target_name = op->Output("Out")[0]; + PADDLE_ENFORCE(feed_targets.find(feed_target_name) != feed_targets.end(), + "Variable %s is not feeded."); + + int idx = boost::get(op->GetAttr("col")); + SetFeedVariable(scope, *feed_targets[feed_target_name], feed_holder_name, + idx); + } + } + + RunPreparedContext(ctx, scope, create_vars, create_vars); + + // obtain the data of fetch_targets from fetch_holder + for (auto* op : global_block.AllOps()) { + if (op->Type() == kFetchOpType) { + std::string fetch_target_name = op->Input("X")[0]; + PADDLE_ENFORCE( + fetch_targets.find(fetch_target_name) != fetch_targets.end(), + "Variable %s is not fetched."); + + int idx = boost::get(op->GetAttr("col")); + *fetch_targets[fetch_target_name] = + GetFetchVariable(*scope, fetch_holder_name, idx); + } + } +} + } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/executor.h b/paddle/fluid/framework/executor.h index 7173c51c95..b0e64d5de0 100644 --- a/paddle/fluid/framework/executor.h +++ b/paddle/fluid/framework/executor.h @@ -65,6 +65,13 @@ class Executor { bool create_local_scope = true, bool create_vars = true); + void RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope, + std::map& feed_targets, + std::map& fetch_targets, + const std::string& feed_holder_name = "feed", + const std::string& fetch_holder_name = "fetch", + bool create_vars = true); + private: const platform::Place place_; }; diff --git a/paddle/fluid/inference/tests/book/test_inference_image_classification.cc b/paddle/fluid/inference/tests/book/test_inference_image_classification.cc index e9a27171f1..9126efb8c2 100644 --- a/paddle/fluid/inference/tests/book/test_inference_image_classification.cc +++ b/paddle/fluid/inference/tests/book/test_inference_image_classification.cc @@ -48,7 +48,7 @@ TEST(inference, image_classification) { // Run inference on CPU LOG(INFO) << "--- CPU Runs: ---"; - TestInference( + TestInference( dirname, cpu_feeds, cpu_fetchs1, FLAGS_repeat); LOG(INFO) << output1.dims(); @@ -59,7 +59,7 @@ TEST(inference, image_classification) { // Run inference on CUDA GPU LOG(INFO) << "--- GPU Runs: ---"; - TestInference( + TestInference( dirname, cpu_feeds, cpu_fetchs2, FLAGS_repeat); LOG(INFO) << output2.dims(); diff --git a/paddle/fluid/inference/tests/test_helper.h b/paddle/fluid/inference/tests/test_helper.h index dce541c097..d559cc7d03 100644 --- a/paddle/fluid/inference/tests/test_helper.h +++ b/paddle/fluid/inference/tests/test_helper.h @@ -88,7 +88,7 @@ void CheckError(paddle::framework::LoDTensor& output1, EXPECT_EQ(count, 0U) << "There are " << count << " different elements."; } -template +template void TestInference(const std::string& dirname, const std::vector& cpu_feeds, std::vector& cpu_fetchs, @@ -170,7 +170,14 @@ void TestInference(const std::string& dirname, // 6. Run the inference program { // Ignore the profiling results of the first run - executor.Run(*inference_program, scope, feed_targets, fetch_targets); + std::unique_ptr ctx; + if (PrepareContext) { + ctx = executor.Prepare(*inference_program, 0); + executor.RunPreparedContext( + ctx.get(), scope, feed_targets, fetch_targets); + } else { + executor.Run(*inference_program, scope, feed_targets, fetch_targets); + } // Enable the profiler paddle::platform::EnableProfiler(state); @@ -181,7 +188,14 @@ void TestInference(const std::string& dirname, "run_inference", paddle::platform::DeviceContextPool::Instance().Get(place)); - executor.Run(*inference_program, scope, feed_targets, fetch_targets); + if (PrepareContext) { + // Note: if you changed the inference_program, you need to call + // executor.Prepare() again to get a new ExecutorPrepareContext. + executor.RunPreparedContext( + ctx.get(), scope, feed_targets, fetch_targets); + } else { + executor.Run(*inference_program, scope, feed_targets, fetch_targets); + } } // Disable the profiler and print the timing information From a9e826ed495bcd5a5b625d4ce364c8c42d0d0b7d Mon Sep 17 00:00:00 2001 From: Liu Yiqun Date: Sun, 8 Apr 2018 06:32:30 +0000 Subject: [PATCH 2/4] Add the check of has_feed/fetch_operators back. --- paddle/fluid/framework/executor.cc | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/paddle/fluid/framework/executor.cc b/paddle/fluid/framework/executor.cc index 8a0ab118d0..3edaede8d6 100644 --- a/paddle/fluid/framework/executor.cc +++ b/paddle/fluid/framework/executor.cc @@ -352,13 +352,17 @@ void Executor::RunPreparedContext( bool create_vars) { auto& global_block = ctx->prog_.Block(ctx->block_id_); + PADDLE_ENFORCE( + has_feed_operators(global_block, feed_targets, feed_holder_name), + "Program in ExecutorPrepareContext should has feed_ops."); + PADDLE_ENFORCE( + has_fetch_operators(global_block, fetch_targets, fetch_holder_name), + "Program in the prepared context should has fetch_ops."); + // map the data of feed_targets to feed_holder for (auto* op : global_block.AllOps()) { if (op->Type() == kFeedOpType) { std::string feed_target_name = op->Output("Out")[0]; - PADDLE_ENFORCE(feed_targets.find(feed_target_name) != feed_targets.end(), - "Variable %s is not feeded."); - int idx = boost::get(op->GetAttr("col")); SetFeedVariable(scope, *feed_targets[feed_target_name], feed_holder_name, idx); @@ -371,10 +375,6 @@ void Executor::RunPreparedContext( for (auto* op : global_block.AllOps()) { if (op->Type() == kFetchOpType) { std::string fetch_target_name = op->Input("X")[0]; - PADDLE_ENFORCE( - fetch_targets.find(fetch_target_name) != fetch_targets.end(), - "Variable %s is not fetched."); - int idx = boost::get(op->GetAttr("col")); *fetch_targets[fetch_target_name] = GetFetchVariable(*scope, fetch_holder_name, idx); From 339be6254ea5e3432e4cbe44f35609bb45662e12 Mon Sep 17 00:00:00 2001 From: Liu Yiqun Date: Thu, 12 Apr 2018 05:58:26 +0000 Subject: [PATCH 3/4] Refine the order of arguments. --- paddle/fluid/framework/executor.cc | 5 ++--- paddle/fluid/framework/executor.h | 4 ++-- paddle/fluid/inference/tests/test_helper.h | 6 +++--- 3 files changed, 7 insertions(+), 8 deletions(-) diff --git a/paddle/fluid/framework/executor.cc b/paddle/fluid/framework/executor.cc index 910012927b..34bba77f40 100644 --- a/paddle/fluid/framework/executor.cc +++ b/paddle/fluid/framework/executor.cc @@ -359,9 +359,8 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope, void Executor::RunPreparedContext( ExecutorPrepareContext* ctx, Scope* scope, std::map& feed_targets, - std::map& fetch_targets, - const std::string& feed_holder_name, const std::string& fetch_holder_name, - bool create_vars) { + std::map& fetch_targets, bool create_vars, + const std::string& feed_holder_name, const std::string& fetch_holder_name) { auto& global_block = ctx->prog_.Block(ctx->block_id_); PADDLE_ENFORCE( diff --git a/paddle/fluid/framework/executor.h b/paddle/fluid/framework/executor.h index cbd70d9544..8b3ea01542 100644 --- a/paddle/fluid/framework/executor.h +++ b/paddle/fluid/framework/executor.h @@ -73,9 +73,9 @@ class Executor { void RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope, std::map& feed_targets, std::map& fetch_targets, + bool create_vars = true, const std::string& feed_holder_name = "feed", - const std::string& fetch_holder_name = "fetch", - bool create_vars = true); + const std::string& fetch_holder_name = "fetch"); private: const platform::Place place_; diff --git a/paddle/fluid/inference/tests/test_helper.h b/paddle/fluid/inference/tests/test_helper.h index 09fe344ec7..9875e43860 100644 --- a/paddle/fluid/inference/tests/test_helper.h +++ b/paddle/fluid/inference/tests/test_helper.h @@ -178,8 +178,8 @@ void TestInference(const std::string& dirname, std::unique_ptr ctx; if (PrepareContext) { ctx = executor.Prepare(*inference_program, 0); - executor.RunPreparedContext(ctx.get(), scope, feed_targets, - fetch_targets); + executor.RunPreparedContext(ctx.get(), scope, feed_targets, fetch_targets, + CreateVars); } else { executor.Run(*inference_program, scope, feed_targets, fetch_targets, CreateVars); @@ -198,7 +198,7 @@ void TestInference(const std::string& dirname, // Note: if you changed the inference_program, you need to call // executor.Prepare() again to get a new ExecutorPrepareContext. executor.RunPreparedContext(ctx.get(), scope, feed_targets, - fetch_targets); + fetch_targets, CreateVars); } else { executor.Run(*inference_program, scope, feed_targets, fetch_targets, CreateVars); From 449bdde58accc9beb94d56c8ef33c0bde4c007b7 Mon Sep 17 00:00:00 2001 From: Liu Yiqun Date: Thu, 12 Apr 2018 06:15:24 +0000 Subject: [PATCH 4/4] Correct some typos. --- cmake/cblas.cmake | 2 +- paddle/fluid/framework/executor.cc | 19 +++++++++++-------- paddle/fluid/framework/executor.h | 3 +++ paddle/fluid/inference/io.cc | 2 +- paddle/fluid/inference/tests/test_helper.h | 2 +- 5 files changed, 17 insertions(+), 11 deletions(-) diff --git a/cmake/cblas.cmake b/cmake/cblas.cmake index 52a22c1fbf..e3b9d94215 100644 --- a/cmake/cblas.cmake +++ b/cmake/cblas.cmake @@ -78,7 +78,7 @@ if(NOT CMAKE_CROSSCOMPILING) /usr/lib/reference/ ) else() - # Diable the finding of reference cblas under host's system path + # Disable the finding of reference cblas under host's system path set(REFERENCE_CBLAS_INCLUDE_SEARCH_PATHS ${REFERENCE_CBLAS_ROOT}/include) set(REFERENCE_CBLAS_LIB_SEARCH_PATHS ${REFERENCE_CBLAS_ROOT}/lib) endif() diff --git a/paddle/fluid/framework/executor.cc b/paddle/fluid/framework/executor.cc index 34bba77f40..513e720fd0 100644 --- a/paddle/fluid/framework/executor.cc +++ b/paddle/fluid/framework/executor.cc @@ -83,8 +83,8 @@ static void CheckTensorNANOrInf(const std::string& name, if (tensor.memory_size() == 0) { return; } - if (tensor.type().hash_code() != typeid(float).hash_code() && - tensor.type().hash_code() != typeid(double).hash_code()) { + if (tensor.type().hash_code() != typeid(float).hash_code() && // NOLINT + tensor.type().hash_code() != typeid(double).hash_code()) { // NOLINT return; } PADDLE_ENFORCE(!framework::TensorContainsInf(tensor), @@ -145,12 +145,13 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id, // Return true if the block has feed operators and holder of matching info. static bool has_feed_operators( const BlockDesc& block, - std::map& feed_targets, + const std::map& feed_targets, const std::string& feed_holder_name) { size_t feed_count = 0; for (auto* op : block.AllOps()) { if (op->Type() == kFeedOpType) { feed_count++; + // The input variable's name of feed_op should be feed_holder_name. PADDLE_ENFORCE_EQ(op->Input("X")[0], feed_holder_name, "Input to feed op should be '%s'", feed_holder_name); std::string feed_target_name = op->Output("Out")[0]; @@ -167,7 +168,7 @@ static bool has_feed_operators( "The number of feed operators should match 'feed_targets'"); if (!feed_holder_name.empty()) { - // When feed operator are present, so should be feed_holder + // When feed operator are present, so should be feed_holder. auto var = block.FindVar(feed_holder_name); PADDLE_ENFORCE_NOT_NULL(var, "Block should already have a '%s' variable", feed_holder_name); @@ -187,12 +188,14 @@ static bool has_feed_operators( // and fetch_holder_name. Raise exception when any mismatch is found. // Return true if the block has fetch operators and holder of matching info. static bool has_fetch_operators( - const BlockDesc& block, std::map& fetch_targets, + const BlockDesc& block, + const std::map& fetch_targets, const std::string& fetch_holder_name) { size_t fetch_count = 0; for (auto* op : block.AllOps()) { if (op->Type() == kFetchOpType) { fetch_count++; + // The output variable's name of fetch_op should be fetch_holder_name. PADDLE_ENFORCE_EQ(op->Output("Out")[0], fetch_holder_name, "Output of fetch op should be '%s'", fetch_holder_name); std::string fetch_target_name = op->Input("X")[0]; @@ -209,7 +212,7 @@ static bool has_fetch_operators( "The number of fetch operators should match 'fetch_targets'"); if (!fetch_holder_name.empty()) { - // When fetch operator are present, so should be fetch_holder + // When fetch operator are present, so should be fetch_holder. auto var = block.FindVar(fetch_holder_name); PADDLE_ENFORCE_NOT_NULL(var, "Block should already have a '%s' variable", fetch_holder_name); @@ -287,8 +290,8 @@ void Executor::Run(const ProgramDesc& program, Scope* scope, } auto ctx = Prepare(*copy_program, 0); - RunPreparedContext(ctx.get(), scope, feed_targets, fetch_targets, - feed_holder_name, fetch_holder_name, create_vars); + RunPreparedContext(ctx.get(), scope, feed_targets, fetch_targets, create_vars, + feed_holder_name, fetch_holder_name); } std::unique_ptr Executor::Prepare( diff --git a/paddle/fluid/framework/executor.h b/paddle/fluid/framework/executor.h index 8b3ea01542..43defdacf2 100644 --- a/paddle/fluid/framework/executor.h +++ b/paddle/fluid/framework/executor.h @@ -14,6 +14,9 @@ limitations under the License. */ #pragma once +#include +#include +#include #include "paddle/fluid/framework/op_info.h" #include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/scope.h" diff --git a/paddle/fluid/inference/io.cc b/paddle/fluid/inference/io.cc index a29d457b6f..3b58019db6 100644 --- a/paddle/fluid/inference/io.cc +++ b/paddle/fluid/inference/io.cc @@ -23,7 +23,7 @@ limitations under the License. */ namespace paddle { namespace inference { -// Temporarilly add this function for exposing framework::InitDevices() when +// Temporarily add this function for exposing framework::InitDevices() when // linking the inference shared library. void Init(bool init_p2p) { framework::InitDevices(init_p2p); } diff --git a/paddle/fluid/inference/tests/test_helper.h b/paddle/fluid/inference/tests/test_helper.h index 9875e43860..c3a8d0889c 100644 --- a/paddle/fluid/inference/tests/test_helper.h +++ b/paddle/fluid/inference/tests/test_helper.h @@ -195,7 +195,7 @@ void TestInference(const std::string& dirname, paddle::platform::DeviceContextPool::Instance().Get(place)); if (PrepareContext) { - // Note: if you changed the inference_program, you need to call + // Note: if you change the inference_program, you need to call // executor.Prepare() again to get a new ExecutorPrepareContext. executor.RunPreparedContext(ctx.get(), scope, feed_targets, fetch_targets, CreateVars);