|
|
|
@ -230,7 +230,7 @@ TEST(Analyzer_bert, compare_determine) {
|
|
|
|
|
inputs);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
TEST(Analyzer_bert, transfer_scope_cache) {
|
|
|
|
|
void verify_transfer_scope_cache(bool is_static = false) {
|
|
|
|
|
AnalysisConfig config;
|
|
|
|
|
SetConfig(&config);
|
|
|
|
|
|
|
|
|
@ -251,6 +251,11 @@ TEST(Analyzer_bert, transfer_scope_cache) {
|
|
|
|
|
threads.emplace_back([&, i]() {
|
|
|
|
|
std::getline(fin, line);
|
|
|
|
|
ParseLine(line, &input);
|
|
|
|
|
#ifdef PADDLE_WITH_MKLDNN
|
|
|
|
|
// Use static method to handle transfer_scope_cache()
|
|
|
|
|
// TODO(intel) explicit session id setting will be deprecated.
|
|
|
|
|
if (is_static) platform::set_cur_mkldnn_session_id(1);
|
|
|
|
|
#endif
|
|
|
|
|
predictor->Run(input, &output, FLAGS_batch_size);
|
|
|
|
|
global_transfer_scope_cache.insert(
|
|
|
|
|
&paddle::framework::global_transfer_scope_cache());
|
|
|
|
@ -261,12 +266,31 @@ TEST(Analyzer_bert, transfer_scope_cache) {
|
|
|
|
|
threads.clear();
|
|
|
|
|
std::vector<PaddleTensor>().swap(input);
|
|
|
|
|
}
|
|
|
|
|
// Since paddle::framework::global_transfer_scope_cache() and
|
|
|
|
|
// paddle::framework::global_transfer_data_cache() are thread_local,
|
|
|
|
|
// their pointer should be different among different thread id.
|
|
|
|
|
PADDLE_ENFORCE(global_transfer_scope_cache.size(), threads_num);
|
|
|
|
|
PADDLE_ENFORCE(global_transfer_data_cache.size(), threads_num);
|
|
|
|
|
#ifdef PADDLE_WITH_MKLDNN
|
|
|
|
|
if (is_static) {
|
|
|
|
|
// Use static method to do transfer_scope_cache() instead of thread_local
|
|
|
|
|
// so paddle::framework::global_transfer_data_cache() should be 1
|
|
|
|
|
PADDLE_ENFORCE(global_transfer_scope_cache.size(), 1);
|
|
|
|
|
PADDLE_ENFORCE(global_transfer_data_cache.size(), 1);
|
|
|
|
|
} else {
|
|
|
|
|
#endif
|
|
|
|
|
// Since paddle::framework::global_transfer_scope_cache() and
|
|
|
|
|
// paddle::framework::global_transfer_data_cache() are thread_local,
|
|
|
|
|
// their pointer should be different among different thread id.
|
|
|
|
|
PADDLE_ENFORCE(global_transfer_scope_cache.size(), threads_num);
|
|
|
|
|
PADDLE_ENFORCE(global_transfer_data_cache.size(), threads_num);
|
|
|
|
|
#ifdef PADDLE_WITH_MKLDNN
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
TEST(Analyzer_bert, threadlocal_transfer_scope_cache) {
|
|
|
|
|
verify_transfer_scope_cache();
|
|
|
|
|
}
|
|
|
|
|
#ifdef PADDLE_WITH_MKLDNN
|
|
|
|
|
TEST(Analyzer_bert, static_transfer_scope_cache) {
|
|
|
|
|
verify_transfer_scope_cache(true);
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
} // namespace inference
|
|
|
|
|
} // namespace paddle
|
|
|
|
|