diff --git a/paddle/fluid/framework/transfer_scope_cache.cc b/paddle/fluid/framework/transfer_scope_cache.cc
index 2b138280fb..e1326f8896 100644
--- a/paddle/fluid/framework/transfer_scope_cache.cc
+++ b/paddle/fluid/framework/transfer_scope_cache.cc
@@ -17,12 +17,61 @@
 namespace paddle {
 namespace framework {
 
+#ifdef PADDLE_WITH_MKLDNN
+using transfer_data_cache_map = std::unordered_map<size_t, Scope*>;
+using transfer_scope_cache_map = std::unordered_set<Scope*>;
+static std::unordered_map<size_t, transfer_data_cache_map*>
+    static_transfer_data_caches;
+static std::unordered_map<size_t, transfer_scope_cache_map*>
+    static_transfer_scope_caches;
+#endif
+
 std::unordered_map<size_t, Scope*>& global_transfer_data_cache() {
+#ifdef PADDLE_WITH_MKLDNN
+  size_t sid = platform::get_cur_mkldnn_session_id();
+
+  // If there is a specific mkldnn session id set by the user.
+  if (sid != platform::kMKLDNNSessionID_Default) {
+    sid = std::hash<std::thread::id>()(std::this_thread::get_id());
+
+    static std::mutex acquire_barrier;
+    std::lock_guard<std::mutex> block_until_finish_this_job(acquire_barrier);
+
+    auto map_it = static_transfer_data_caches.find(sid);
+    if (map_it == static_transfer_data_caches.end()) {
+      auto* x = new transfer_data_cache_map;
+      static_transfer_data_caches[sid] = x;
+      return *x;
+    } else {
+      return *static_transfer_data_caches[sid];
+    }
+  }
+#endif
   thread_local auto* x = new std::unordered_map<size_t, Scope*>;
   return *x;
 }
 
 std::unordered_set<Scope*>& global_transfer_scope_cache() {
+#ifdef PADDLE_WITH_MKLDNN
+  size_t sid = platform::get_cur_mkldnn_session_id();
+
+  // If there is a specific mkldnn session id set by the user.
+  if (sid != platform::kMKLDNNSessionID_Default) {
+    sid = std::hash<std::thread::id>()(std::this_thread::get_id());
+
+    static std::mutex acquire_barrier;
+    std::lock_guard<std::mutex> block_until_finish_this_job(acquire_barrier);
+
+    auto map_it = static_transfer_scope_caches.find(sid);
+    if (map_it == static_transfer_scope_caches.end()) {
+      auto* x = new transfer_scope_cache_map;
+      static_transfer_scope_caches[sid] = x;
+      return *x;
+    } else {
+      return *static_transfer_scope_caches[sid];
+    }
+  }
+#endif
   thread_local auto* x = new std::unordered_set<Scope*>;
   return *x;
 }
diff --git a/paddle/fluid/inference/tests/api/analyzer_bert_tester.cc b/paddle/fluid/inference/tests/api/analyzer_bert_tester.cc
index f679e12218..406c028a9f 100644
--- a/paddle/fluid/inference/tests/api/analyzer_bert_tester.cc
+++ b/paddle/fluid/inference/tests/api/analyzer_bert_tester.cc
@@ -230,7 +230,7 @@ TEST(Analyzer_bert, compare_determine) {
                        inputs);
 }
 
-TEST(Analyzer_bert, transfer_scope_cache) {
+void verify_transfer_scope_cache(bool is_static = false) {
   AnalysisConfig config;
   SetConfig(&config);
 
@@ -251,6 +251,11 @@ TEST(Analyzer_bert, transfer_scope_cache) {
     threads.emplace_back([&, i]() {
       std::getline(fin, line);
       ParseLine(line, &input);
+#ifdef PADDLE_WITH_MKLDNN
+      // Use the static transfer_scope_cache() instead of thread_local.
+      // TODO(intel): explicit session id setting will be deprecated.
+      if (is_static) platform::set_cur_mkldnn_session_id(1);
+#endif
       predictor->Run(input, &output, FLAGS_batch_size);
       global_transfer_scope_cache.insert(
           &paddle::framework::global_transfer_scope_cache());
@@ -261,12 +266,31 @@ TEST(Analyzer_bert, transfer_scope_cache) {
     threads.clear();
     std::vector<PaddleTensor>().swap(input);
   }
-  // Since paddle::framework::global_transfer_scope_cache() and
-  // paddle::framework::global_transfer_data_cache() are thread_local,
-  // their pointer should be different among different thread id.
-  PADDLE_ENFORCE(global_transfer_scope_cache.size(), threads_num);
-  PADDLE_ENFORCE(global_transfer_data_cache.size(), threads_num);
+#ifdef PADDLE_WITH_MKLDNN
+  if (is_static) {
+    // The static transfer_scope_cache() is shared instead of thread_local,
+    // so each cache should end up holding exactly one entry.
+    PADDLE_ENFORCE_EQ(global_transfer_scope_cache.size(), 1);
+    PADDLE_ENFORCE_EQ(global_transfer_data_cache.size(), 1);
+  } else {
+#endif
+    // Since paddle::framework::global_transfer_scope_cache() and
+    // paddle::framework::global_transfer_data_cache() are thread_local,
+    // their pointers should differ across threads.
+    PADDLE_ENFORCE_EQ(global_transfer_scope_cache.size(), threads_num);
+    PADDLE_ENFORCE_EQ(global_transfer_data_cache.size(), threads_num);
+#ifdef PADDLE_WITH_MKLDNN
+  }
+#endif
 }
 
+TEST(Analyzer_bert, threadlocal_transfer_scope_cache) {
+  verify_transfer_scope_cache();
+}
+#ifdef PADDLE_WITH_MKLDNN
+TEST(Analyzer_bert, static_transfer_scope_cache) {
+  verify_transfer_scope_cache(true);
+}
+#endif
 }  // namespace inference
 }  // namespace paddle
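
For context, the cache selection introduced above reduces to the following pattern: when the user has set an explicit MKL-DNN session id, the cache is looked up in a process-wide static map keyed by a hash of the current thread id, under a mutex; otherwise a thread_local instance is used. The sketch below is a minimal standalone illustration of that pattern, not the framework code itself: Scope, get_session_id() and kSessionID_Default are simplified stand-ins for paddle::framework::Scope and the platform session-id helpers referenced in the diff.

    #include <functional>
    #include <mutex>
    #include <thread>
    #include <unordered_map>
    #include <unordered_set>

    // Simplified stand-ins (assumptions) for paddle::framework::Scope and the
    // platform::get_cur_mkldnn_session_id() / kMKLDNNSessionID_Default helpers.
    struct Scope {};
    constexpr size_t kSessionID_Default = 0;
    size_t get_session_id() { return kSessionID_Default; }

    using transfer_scope_cache_map = std::unordered_set<Scope*>;

    transfer_scope_cache_map& global_transfer_scope_cache() {
      // Process-wide registry of caches, keyed by a hash of the thread id.
      static std::unordered_map<size_t, transfer_scope_cache_map*> static_caches;
      static std::mutex acquire_barrier;

      size_t sid = get_session_id();
      if (sid != kSessionID_Default) {
        // An explicit session id opts into the shared static registry, so the
        // cache is not tied to the thread_local storage of short-lived threads.
        sid = std::hash<std::thread::id>()(std::this_thread::get_id());
        std::lock_guard<std::mutex> lock(acquire_barrier);
        auto it = static_caches.find(sid);
        if (it == static_caches.end()) {
          it = static_caches.emplace(sid, new transfer_scope_cache_map).first;
        }
        return *it->second;
      }
      // Default path: one lazily created cache per thread (intentionally leaked).
      thread_local auto* cache = new transfer_scope_cache_map;
      return *cache;
    }

    int main() {
      Scope s;
      global_transfer_scope_cache().insert(&s);
      return 0;
    }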