Merge pull request #14610 from Superjomn/revert/cache_fix

Revert "fix transfer cache thread_local bug (#14581)"
local_add_cudnn_lstm
Tao Luo 6 years ago committed by GitHub
commit e8ef14d2a7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -17,28 +17,16 @@
namespace paddle { namespace paddle {
namespace framework { namespace framework {
// Holds all the transfer scope across the process.
std::unordered_map<size_t, Scope*>& global_transfer_data_cache() { std::unordered_map<size_t, Scope*>& global_transfer_data_cache() {
typedef std::unordered_map<size_t, Scope*> map_t; thread_local auto* x = new std::unordered_map<size_t, Scope*>;
thread_local std::unique_ptr<map_t> x(new map_t);
return *x; return *x;
} }
// Holds all the transfer scope for this thread.
std::unordered_set<Scope*>& global_transfer_scope_cache() { std::unordered_set<Scope*>& global_transfer_scope_cache() {
typedef std::unordered_set<Scope*> set_t; thread_local auto* x = new std::unordered_set<Scope*>;
thread_local std::unique_ptr<set_t> x(new set_t);
return *x; return *x;
} }
// Try to create a transfer scope. If one cached scope has match the
// requirement, just return that one.
// Inputs:
// @type0: the source kernel type.
// @type1: the target kernel type.
// @scope: the execution scope of this op.
// Returns: A scope used to hold the transfer data across the different kernel
// type.
Scope* TryCreateTransferScope(OpKernelType type0, OpKernelType type1, Scope* TryCreateTransferScope(OpKernelType type0, OpKernelType type1,
const Scope* scope) { const Scope* scope) {
Scope* new_scope{nullptr}; Scope* new_scope{nullptr};
@ -58,5 +46,27 @@ Scope* TryCreateTransferScope(OpKernelType type0, OpKernelType type1,
return new_scope; return new_scope;
} }
void RemoveKidsFromTransferScopeCache(Scope* scope) {
auto it = global_transfer_scope_cache().find(scope);
if (it != global_transfer_scope_cache().end()) {
global_transfer_scope_cache().erase(it);
}
for (auto* s : scope->kids()) {
auto it = global_transfer_scope_cache().find(s);
if (it != global_transfer_scope_cache().end()) {
global_transfer_scope_cache().erase(it);
}
}
// remove global transfer data cache
auto& cache = global_transfer_data_cache();
for (auto it = cache.begin(); it != cache.end();) {
if (it->second == scope)
it = cache.erase(it);
else
it++;
}
}
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle

@ -35,4 +35,5 @@ function(inference_analysis_test TARGET)
endif() endif()
endfunction(inference_analysis_test) endfunction(inference_analysis_test)
inference_analysis_test(test_analyzer SRCS analyzer_tester.cc EXTRA_DEPS reset_tensor_array paddle_inference_api) inference_analysis_test(test_analyzer SRCS analyzer_tester.cc
EXTRA_DEPS reset_tensor_array paddle_inference_api)

Loading…
Cancel
Save