@@ -22,6 +22,8 @@ limitations under the License. */
 #include "paddle/fluid/inference/io.h"
 #include "paddle/fluid/platform/profiler.h"
 
+DECLARE_bool(use_mkldnn);
+
 template <typename T>
 void SetupTensor(paddle::framework::LoDTensor* input,
                  paddle::framework::DDim dims, T lower, T upper) {
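
Note on the added declaration: `DECLARE_bool(use_mkldnn)` only announces the flag; a matching `DEFINE_bool` must live in exactly one translation unit elsewhere in Paddle, which this diff does not show. A minimal gflags sketch, with the definition site, default, and help text invented for illustration:

```cpp
#include <gflags/gflags.h>

// Defined once, in some .cc file (Paddle defines use_mkldnn elsewhere;
// the default and help text here are illustrative).
DEFINE_bool(use_mkldnn, false, "Run operators with MKL-DNN where available");

// Every other file just declares the flag and reads the generated global.
DECLARE_bool(use_mkldnn);

bool MKLDNNRequested() { return FLAGS_use_mkldnn; }
```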
@@ -133,24 +135,11 @@ std::vector<std::vector<int64_t>> GetFeedTargetShapes(
   return feed_target_shapes;
 }
 
-void EnableMKLDNN(
-    const std::unique_ptr<paddle::framework::ProgramDesc>& program) {
-  for (size_t bid = 0; bid < program->Size(); ++bid) {
-    auto* block = program->MutableBlock(bid);
-    for (auto* op : block->AllOps()) {
-      if (op->HasAttr("use_mkldnn")) {
-        op->SetAttr("use_mkldnn", true);
-      }
-    }
-  }
-}
-
 template <typename Place, bool CreateVars = true, bool PrepareContext = false>
 void TestInference(const std::string& dirname,
                    const std::vector<paddle::framework::LoDTensor*>& cpu_feeds,
                    const std::vector<paddle::framework::LoDTensor*>& cpu_fetchs,
-                   const int repeat = 1, const bool is_combined = false,
-                   const bool use_mkldnn = false) {
+                   const int repeat = 1, const bool is_combined = false) {
   // 1. Define place, executor, scope
   auto place = Place();
   auto executor = paddle::framework::Executor(place);
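
With `use_mkldnn` dropped from the signature, call sites lose one argument and the MKL-DNN decision moves to the flag. A hedged sketch of a caller after this change; the model path and tensor shape are placeholders, not taken from this patch:

```cpp
// Feed one randomly initialized tensor and fetch one result.
paddle::framework::LoDTensor input;
SetupTensor<float>(&input, paddle::framework::make_ddim({1, 28, 28}),
                   static_cast<float>(0), static_cast<float>(1));

paddle::framework::LoDTensor output;
std::vector<paddle::framework::LoDTensor*> cpu_feeds{&input};
std::vector<paddle::framework::LoDTensor*> cpu_fetchs{&output};

// repeat and is_combined keep their defaults; MKL-DNN is now controlled by
// FLAGS_use_mkldnn at run time instead of a trailing bool.
TestInference<paddle::platform::CPUPlace>("/path/to/model_dir", cpu_feeds,
                                          cpu_fetchs);
```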
@@ -182,9 +171,6 @@ void TestInference(const std::string& dirname,
         "init_program",
         paddle::platform::DeviceContextPool::Instance().Get(place));
     inference_program = InitProgram(&executor, scope, dirname, is_combined);
-    if (use_mkldnn) {
-      EnableMKLDNN(inference_program);
-    }
   }
   // Disable the profiler and print the timing information
   paddle::platform::DisableProfiler(paddle::platform::EventSortingKey::kDefault,
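
Removing this call site is safe because the flag-gated `executor.EnableMKLDNN(*inference_program)` in the next hunk takes over. `Executor::EnableMKLDNN` itself is not shown in this patch; assuming it mirrors the deleted free function above, a plausible sketch (the `const_cast` and the build guard are assumptions, not the actual framework source):

```cpp
void Executor::EnableMKLDNN(const paddle::framework::ProgramDesc& program) {
#ifdef PADDLE_WITH_MKLDNN
  // Same walk as the deleted helper: visit every op in every block and
  // switch on its use_mkldnn attribute when the op supports one.
  for (size_t bid = 0; bid < program.Size(); ++bid) {
    auto* block =
        const_cast<paddle::framework::ProgramDesc&>(program).MutableBlock(bid);
    for (auto* op : block->AllOps()) {
      if (op->HasAttr("use_mkldnn")) {
        op->SetAttr("use_mkldnn", true);
      }
    }
  }
#else
  LOG(WARNING) << "Paddle was compiled without MKL-DNN; flag has no effect.";
#endif
}
```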
@@ -210,7 +196,10 @@ void TestInference(const std::string& dirname,
     fetch_targets[fetch_target_names[i]] = cpu_fetchs[i];
   }
 
-  // 6. Run the inference program
+  // 6. If FLAGS_use_mkldnn=true is exported, use MKL-DNN related ops.
+  if (FLAGS_use_mkldnn) executor.EnableMKLDNN(*inference_program);
+
+  // 7. Run the inference program
   {
     if (!CreateVars) {
       // If users don't want to create and destroy variables every time they
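
How `FLAGS_use_mkldnn` gets set is outside this patch. With stock gflags the test binary picks it up from the command line once flags are parsed; the `export FLAGS_use_mkldnn=true` wording in the new comment suggests Paddle can also map the environment variable onto the flag during its own initialization, which is an assumption here. A minimal gtest-style entry point, for illustration only:

```cpp
#include <gflags/gflags.h>
#include <gtest/gtest.h>

int main(int argc, char** argv) {
  testing::InitGoogleTest(&argc, argv);
  // After this call, `./test_binary --use_mkldnn=true` flips FLAGS_use_mkldnn
  // before any TestInference case runs.
  google::ParseCommandLineFlags(&argc, &argv, true);
  return RUN_ALL_TESTS();
}
```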