|
|
|
@ -156,10 +156,27 @@ void TestInference(const std::string& dirname,
|
|
|
|
|
auto executor = paddle::framework::Executor(place);
|
|
|
|
|
auto* scope = new paddle::framework::Scope();
|
|
|
|
|
|
|
|
|
|
// Profile the performance
|
|
|
|
|
paddle::platform::ProfilerState state;
|
|
|
|
|
if (paddle::platform::is_cpu_place(place)) {
|
|
|
|
|
state = paddle::platform::ProfilerState::kCPU;
|
|
|
|
|
} else {
|
|
|
|
|
#ifdef PADDLE_WITH_CUDA
|
|
|
|
|
state = paddle::platform::ProfilerState::kAll;
|
|
|
|
|
// The default device_id of paddle::platform::CUDAPlace is 0.
|
|
|
|
|
// Users can get the device_id using:
|
|
|
|
|
// int device_id = place.GetDeviceId();
|
|
|
|
|
paddle::platform::SetDeviceId(0);
|
|
|
|
|
#else
|
|
|
|
|
PADDLE_THROW("'CUDAPlace' is not supported in CPU only device.");
|
|
|
|
|
#endif
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// 2. Initialize the inference_program and load parameters
|
|
|
|
|
std::unique_ptr<paddle::framework::ProgramDesc> inference_program;
|
|
|
|
|
|
|
|
|
|
// Enable the profiler
|
|
|
|
|
paddle::platform::EnableProfiler(state);
|
|
|
|
|
{
|
|
|
|
|
paddle::platform::RecordEvent record_event(
|
|
|
|
|
"init_program",
|
|
|
|
@ -172,6 +189,10 @@ void TestInference(const std::string& dirname,
|
|
|
|
|
EnableMKLDNN(inference_program);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
// Disable the profiler and print the timing information
|
|
|
|
|
paddle::platform::DisableProfiler(paddle::platform::EventSortingKey::kDefault,
|
|
|
|
|
"load_program_profiler");
|
|
|
|
|
paddle::platform::ResetProfiler();
|
|
|
|
|
|
|
|
|
|
// 3. Get the feed_target_names and fetch_target_names
|
|
|
|
|
const std::vector<std::string>& feed_target_names =
|
|
|
|
@ -212,6 +233,9 @@ void TestInference(const std::string& dirname,
|
|
|
|
|
true, CreateVars);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Enable the profiler
|
|
|
|
|
paddle::platform::EnableProfiler(state);
|
|
|
|
|
|
|
|
|
|
// Run repeat times to profile the performance
|
|
|
|
|
for (int i = 0; i < repeat; ++i) {
|
|
|
|
|
paddle::platform::RecordEvent record_event(
|
|
|
|
@ -228,6 +252,11 @@ void TestInference(const std::string& dirname,
|
|
|
|
|
CreateVars);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Disable the profiler and print the timing information
|
|
|
|
|
paddle::platform::DisableProfiler(
|
|
|
|
|
paddle::platform::EventSortingKey::kDefault, "run_inference_profiler");
|
|
|
|
|
paddle::platform::ResetProfiler();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
delete scope;
|
|
|
|
|