diff --git a/cmake/cblas.cmake b/cmake/cblas.cmake index 52a22c1fbf..e3b9d94215 100644 --- a/cmake/cblas.cmake +++ b/cmake/cblas.cmake @@ -78,7 +78,7 @@ if(NOT CMAKE_CROSSCOMPILING) /usr/lib/reference/ ) else() - # Diable the finding of reference cblas under host's system path + # Disable the finding of reference cblas under host's system path set(REFERENCE_CBLAS_INCLUDE_SEARCH_PATHS ${REFERENCE_CBLAS_ROOT}/include) set(REFERENCE_CBLAS_LIB_SEARCH_PATHS ${REFERENCE_CBLAS_ROOT}/lib) endif() diff --git a/paddle/fluid/framework/executor.cc b/paddle/fluid/framework/executor.cc index c2ca1bbc78..513e720fd0 100644 --- a/paddle/fluid/framework/executor.cc +++ b/paddle/fluid/framework/executor.cc @@ -83,8 +83,8 @@ static void CheckTensorNANOrInf(const std::string& name, if (tensor.memory_size() == 0) { return; } - if (tensor.type().hash_code() != typeid(float).hash_code() && - tensor.type().hash_code() != typeid(double).hash_code()) { + if (tensor.type().hash_code() != typeid(float).hash_code() && // NOLINT + tensor.type().hash_code() != typeid(double).hash_code()) { // NOLINT return; } PADDLE_ENFORCE(!framework::TensorContainsInf(tensor), @@ -145,12 +145,13 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id, // Return true if the block has feed operators and holder of matching info. static bool has_feed_operators( const BlockDesc& block, - std::map<std::string, const LoDTensor*>& feed_targets, + const std::map<std::string, const LoDTensor*>& feed_targets, const std::string& feed_holder_name) { size_t feed_count = 0; for (auto* op : block.AllOps()) { if (op->Type() == kFeedOpType) { feed_count++; + // The input variable's name of feed_op should be feed_holder_name. PADDLE_ENFORCE_EQ(op->Input("X")[0], feed_holder_name, "Input to feed op should be '%s'", feed_holder_name); std::string feed_target_name = op->Output("Out")[0]; @@ -166,13 +167,15 @@ static bool has_feed_operators( feed_count, feed_targets.size(), "The number of feed operators should match 'feed_targets'"); - // When feed operator are present, so should be feed_holder - auto var = block.FindVar(feed_holder_name); - PADDLE_ENFORCE_NOT_NULL(var, "Block should already have a '%s' variable", - feed_holder_name); - PADDLE_ENFORCE_EQ(var->GetType(), proto::VarType::FEED_MINIBATCH, - "'%s' variable should be 'FEED_MINIBATCH' type", - feed_holder_name); + if (!feed_holder_name.empty()) { + // When feed operator are present, so should be feed_holder. + auto var = block.FindVar(feed_holder_name); + PADDLE_ENFORCE_NOT_NULL(var, "Block should already have a '%s' variable", + feed_holder_name); + PADDLE_ENFORCE_EQ(var->GetType(), proto::VarType::FEED_MINIBATCH, + "'%s' variable should be 'FEED_MINIBATCH' type", + feed_holder_name); + } } return feed_count > 0; @@ -185,12 +188,14 @@ static bool has_feed_operators( // and fetch_holder_name. Raise exception when any mismatch is found. // Return true if the block has fetch operators and holder of matching info. static bool has_fetch_operators( - const BlockDesc& block, std::map<std::string, LoDTensor*>& fetch_targets, + const BlockDesc& block, + const std::map<std::string, LoDTensor*>& fetch_targets, const std::string& fetch_holder_name) { size_t fetch_count = 0; for (auto* op : block.AllOps()) { if (op->Type() == kFetchOpType) { fetch_count++; + // The output variable's name of fetch_op should be fetch_holder_name. PADDLE_ENFORCE_EQ(op->Output("Out")[0], fetch_holder_name, "Output of fetch op should be '%s'", fetch_holder_name); std::string fetch_target_name = op->Input("X")[0]; @@ -206,13 +211,15 @@ static bool has_fetch_operators( fetch_count, fetch_targets.size(), "The number of fetch operators should match 'fetch_targets'"); - // When fetch operator are present, so should be fetch_holder - auto var = block.FindVar(fetch_holder_name); - PADDLE_ENFORCE_NOT_NULL(var, "Block should already have a '%s' variable", - fetch_holder_name); - PADDLE_ENFORCE_EQ(var->GetType(), proto::VarType::FETCH_LIST, - "'%s' variable should be 'FETCH_LIST' type", - fetch_holder_name); + if (!fetch_holder_name.empty()) { + // When fetch operator are present, so should be fetch_holder. + auto var = block.FindVar(fetch_holder_name); + PADDLE_ENFORCE_NOT_NULL(var, "Block should already have a '%s' variable", + fetch_holder_name); + PADDLE_ENFORCE_EQ(var->GetType(), proto::VarType::FETCH_LIST, + "'%s' variable should be 'FETCH_LIST' type", + fetch_holder_name); + } } return fetch_count > 0; @@ -259,16 +266,6 @@ void Executor::Run(const ProgramDesc& program, Scope* scope, } } - // map the data of feed_targets to feed_holder - for (auto* op : global_block->AllOps()) { - if (op->Type() == kFeedOpType) { - std::string feed_target_name = op->Output("Out")[0]; - int idx = boost::get<int>(op->GetAttr("col")); - SetFeedVariable(scope, *feed_targets[feed_target_name], feed_holder_name, - idx); - } - } - if (!has_fetch_ops) { // create fetch_holder variable auto* fetch_holder = global_block->Var(fetch_holder_name); @@ -292,17 +289,9 @@ void Executor::Run(const ProgramDesc& program, Scope* scope, } } - Run(*copy_program, scope, 0, create_vars, create_vars); - - // obtain the data of fetch_targets from fetch_holder - for (auto* op : global_block->AllOps()) { - if (op->Type() == kFetchOpType) { - std::string fetch_target_name = op->Input("X")[0]; - int idx = boost::get<int>(op->GetAttr("col")); - *fetch_targets[fetch_target_name] = - GetFetchVariable(*scope, fetch_holder_name, idx); - } - } + auto ctx = Prepare(*copy_program, 0); + RunPreparedContext(ctx.get(), scope, feed_targets, fetch_targets, create_vars, + feed_holder_name, fetch_holder_name); } std::unique_ptr<ExecutorPrepareContext> Executor::Prepare( @@ -370,5 +359,42 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope, } } +void Executor::RunPreparedContext( + ExecutorPrepareContext* ctx, Scope* scope, + std::map<std::string, const LoDTensor*>& feed_targets, + std::map<std::string, LoDTensor*>& fetch_targets, bool create_vars, + const std::string& feed_holder_name, const std::string& fetch_holder_name) { + auto& global_block = ctx->prog_.Block(ctx->block_id_); + + PADDLE_ENFORCE( + has_feed_operators(global_block, feed_targets, feed_holder_name), + "Program in ExecutorPrepareContext should has feed_ops."); + PADDLE_ENFORCE( + has_fetch_operators(global_block, fetch_targets, fetch_holder_name), + "Program in the prepared context should has fetch_ops."); + + // map the data of feed_targets to feed_holder + for (auto* op : global_block.AllOps()) { + if (op->Type() == kFeedOpType) { + std::string feed_target_name = op->Output("Out")[0]; + int idx = boost::get<int>(op->GetAttr("col")); + SetFeedVariable(scope, *feed_targets[feed_target_name], feed_holder_name, + idx); + } + } + + RunPreparedContext(ctx, scope, create_vars, create_vars); + + // obtain the data of fetch_targets from fetch_holder + for (auto* op : global_block.AllOps()) { + if (op->Type() == kFetchOpType) { + std::string fetch_target_name = op->Input("X")[0]; + int idx = boost::get<int>(op->GetAttr("col")); + *fetch_targets[fetch_target_name] = + GetFetchVariable(*scope, fetch_holder_name, idx); + } + } +} + } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/executor.h b/paddle/fluid/framework/executor.h index 75b29b2f40..43defdacf2 100644 --- a/paddle/fluid/framework/executor.h +++ b/paddle/fluid/framework/executor.h @@ -14,6 +14,9 @@ limitations under the License. */ #pragma once +#include <map> +#include <string> +#include <vector> #include "paddle/fluid/framework/op_info.h" #include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/scope.h" @@ -70,6 +73,13 @@ class Executor { bool create_local_scope = true, bool create_vars = true); + void RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope, + std::map<std::string, const LoDTensor*>& feed_targets, + std::map<std::string, LoDTensor*>& fetch_targets, + bool create_vars = true, + const std::string& feed_holder_name = "feed", + const std::string& fetch_holder_name = "fetch"); + private: const platform::Place place_; }; diff --git a/paddle/fluid/inference/io.cc b/paddle/fluid/inference/io.cc index a29d457b6f..3b58019db6 100644 --- a/paddle/fluid/inference/io.cc +++ b/paddle/fluid/inference/io.cc @@ -23,7 +23,7 @@ limitations under the License. */ namespace paddle { namespace inference { -// Temporarilly add this function for exposing framework::InitDevices() when +// Temporarily add this function for exposing framework::InitDevices() when // linking the inference shared library. void Init(bool init_p2p) { framework::InitDevices(init_p2p); } diff --git a/paddle/fluid/inference/tests/book/test_inference_image_classification.cc b/paddle/fluid/inference/tests/book/test_inference_image_classification.cc index ca2077d074..1e6555bb02 100644 --- a/paddle/fluid/inference/tests/book/test_inference_image_classification.cc +++ b/paddle/fluid/inference/tests/book/test_inference_image_classification.cc @@ -46,8 +46,8 @@ TEST(inference, image_classification) { // Run inference on CPU LOG(INFO) << "--- CPU Runs: ---"; - TestInference<paddle::platform::CPUPlace, false>(dirname, cpu_feeds, - cpu_fetchs1, FLAGS_repeat); + TestInference<paddle::platform::CPUPlace, false, true>( + dirname, cpu_feeds, cpu_fetchs1, FLAGS_repeat); LOG(INFO) << output1.dims(); #ifdef PADDLE_WITH_CUDA @@ -57,8 +57,8 @@ TEST(inference, image_classification) { // Run inference on CUDA GPU LOG(INFO) << "--- GPU Runs: ---"; - TestInference<paddle::platform::CUDAPlace, false>(dirname, cpu_feeds, - cpu_fetchs2, FLAGS_repeat); + TestInference<paddle::platform::CUDAPlace, false, true>( + dirname, cpu_feeds, cpu_fetchs2, FLAGS_repeat); LOG(INFO) << output2.dims(); CheckError<float>(output1, output2); diff --git a/paddle/fluid/inference/tests/test_helper.h b/paddle/fluid/inference/tests/test_helper.h index 064e400f0c..c3a8d0889c 100644 --- a/paddle/fluid/inference/tests/test_helper.h +++ b/paddle/fluid/inference/tests/test_helper.h @@ -89,7 +89,7 @@ void CheckError(const paddle::framework::LoDTensor& output1, EXPECT_EQ(count, 0U) << "There are " << count << " different elements."; } -template <typename Place, bool CreateVars = true> +template <typename Place, bool CreateVars = true, bool PrepareContext = false> void TestInference(const std::string& dirname, const std::vector<paddle::framework::LoDTensor*>& cpu_feeds, const std::vector<paddle::framework::LoDTensor*>& cpu_fetchs, @@ -175,8 +175,15 @@ void TestInference(const std::string& dirname, } // Ignore the profiling results of the first run - executor.Run(*inference_program, scope, feed_targets, fetch_targets, - CreateVars); + std::unique_ptr<paddle::framework::ExecutorPrepareContext> ctx; + if (PrepareContext) { + ctx = executor.Prepare(*inference_program, 0); + executor.RunPreparedContext(ctx.get(), scope, feed_targets, fetch_targets, + CreateVars); + } else { + executor.Run(*inference_program, scope, feed_targets, fetch_targets, + CreateVars); + } // Enable the profiler paddle::platform::EnableProfiler(state); @@ -187,8 +194,15 @@ void TestInference(const std::string& dirname, "run_inference", paddle::platform::DeviceContextPool::Instance().Get(place)); - executor.Run(*inference_program, scope, feed_targets, fetch_targets, - CreateVars); + if (PrepareContext) { + // Note: if you change the inference_program, you need to call + // executor.Prepare() again to get a new ExecutorPrepareContext. + executor.RunPreparedContext(ctx.get(), scope, feed_targets, + fetch_targets, CreateVars); + } else { + executor.Run(*inference_program, scope, feed_targets, fetch_targets, + CreateVars); + } } // Disable the profiler and print the timing information