Fix jit tls issue (#21151)

revert-21172-masked_select_api
Yihua Xu 5 years ago committed by whs
parent a9d4eed3a8
commit eec9c9cbe7

@ -22,6 +22,12 @@ namespace paddle {
namespace operators {
namespace jit {
std::unordered_map<std::string, std::shared_ptr<void>>& GetFuncCacheMap() {
static thread_local std::unordered_map<std::string, std::shared_ptr<void>>
g_func_cache_map;
return g_func_cache_map;
}
#define ONE_CASE(key) \
case key: \
return #key

@ -15,6 +15,7 @@
#pragma once
#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility> // for std::move
@ -175,13 +176,25 @@ typename KernelTuple::func_type GetDefaultBestFunc(
return funcs[0];
}
extern std::unordered_map<std::string, std::shared_ptr<void>>&
GetFuncCacheMap();
template <typename KernelTuple, typename PlaceType>
class KernelFuncs {
public:
KernelFuncs() = default;
static KernelFuncs& Cache() {
static thread_local KernelFuncs<KernelTuple, PlaceType> g_func_cache;
return g_func_cache;
auto& func_cache_map = GetFuncCacheMap();
std::string key = typeid(KernelFuncs<KernelTuple, PlaceType>).name();
auto iter = func_cache_map.find(key);
if (iter != func_cache_map.end()) {
return *(KernelFuncs<KernelTuple, PlaceType>*)(iter->second.get());
} else {
std::shared_ptr<void> cache =
std::make_shared<KernelFuncs<KernelTuple, PlaceType>>();
func_cache_map.emplace(key, cache);
return *(KernelFuncs<KernelTuple, PlaceType>*)(cache.get());
}
}
// the exposed interface to use

@ -21,6 +21,12 @@ namespace paddle {
namespace operators {
namespace jit {
std::unordered_map<std::string, std::shared_ptr<void>>& GetJITCodesMap() {
static thread_local std::unordered_map<std::string, std::shared_ptr<void>>
g_jit_codes_map;
return g_jit_codes_map;
}
JitCodeCreatorPool& JitCodeCreatorPool::Instance() {
static JitCodeCreatorPool g_creator_pool;
return g_creator_pool;

@ -28,6 +28,8 @@ namespace paddle {
namespace operators {
namespace jit {
extern std::unordered_map<std::string, std::shared_ptr<void>>& GetJITCodesMap();
template <KernelType KT>
class JitCodePool {
typedef std::unique_ptr<GenBase> GenBasePtr;
@ -36,8 +38,16 @@ class JitCodePool {
public:
JitCodePool() = default;
static JitCodePool& Instance() {
static thread_local JitCodePool<KT> g_jit_codes;
return g_jit_codes;
auto& jit_codes_map = GetJITCodesMap();
std::string key = typeid(JitCodePool<KT>).name();
auto iter = jit_codes_map.find(key);
if (iter != jit_codes_map.end()) {
return *(JitCodePool<KT>*)(iter->second.get());
} else {
std::shared_ptr<void> cache = std::make_shared<JitCodePool<KT>>();
jit_codes_map.emplace(key, cache);
return *(JitCodePool<KT>*)(cache.get());
}
}
const JitCodeMap& AllKernels() { return codes_; }

Loading…
Cancel
Save