|
|
|
@ -33,8 +33,11 @@ namespace jit {
|
|
|
|
|
#define EXP_MAX_INPUT 40.0
|
|
|
|
|
|
|
|
|
|
template <KernelType KT, typename KernelTuples, typename PlaceType>
|
|
|
|
|
inline typename KernelTuples::func_type GetJitCode(
|
|
|
|
|
typename KernelTuples::attr_type attr) {
|
|
|
|
|
inline typename std::enable_if<
|
|
|
|
|
std::is_same<typename KernelTuples::data_type, float>::value &&
|
|
|
|
|
std::is_same<PlaceType, platform::CPUPlace>::value,
|
|
|
|
|
typename KernelTuples::func_type>::type
|
|
|
|
|
GetJitCode(typename KernelTuples::attr_type attr) {
|
|
|
|
|
using Func = typename KernelTuples::func_type;
|
|
|
|
|
using Attr = typename KernelTuples::attr_type;
|
|
|
|
|
size_t key = JitCodeKey<Attr>(attr);
|
|
|
|
@ -45,21 +48,19 @@ inline typename KernelTuples::func_type GetJitCode(
|
|
|
|
|
|
|
|
|
|
// creator is not related with attr, so can use KernelKey as key
|
|
|
|
|
KernelKey kkey(KT, PlaceType());
|
|
|
|
|
if (std::is_same<PlaceType, platform::CPUPlace>::value) {
|
|
|
|
|
// pool: (KernelKey(type, place), vector<GenCreatorPtr>)
|
|
|
|
|
auto& creator_map = JitCodeCreatorPool().Instance().AllCreators();
|
|
|
|
|
auto iter = creator_map.find(kkey);
|
|
|
|
|
if (iter != creator_map.end()) {
|
|
|
|
|
auto& creators = iter->second;
|
|
|
|
|
for (auto& cur : creators) {
|
|
|
|
|
auto i = dynamic_cast<const JitCodeCreator<Attr>*>(cur.get());
|
|
|
|
|
if (i && i->UseMe(attr)) {
|
|
|
|
|
auto p = i->CreateJitCode(attr);
|
|
|
|
|
if (p) {
|
|
|
|
|
auto f = p->template getCode<Func>();
|
|
|
|
|
codes.Insert(key, std::move(p));
|
|
|
|
|
return f;
|
|
|
|
|
}
|
|
|
|
|
// pool: (KernelKey(type, place), vector<GenCreatorPtr>)
|
|
|
|
|
auto& creator_map = JitCodeCreatorPool().Instance().AllCreators();
|
|
|
|
|
auto iter = creator_map.find(kkey);
|
|
|
|
|
if (iter != creator_map.end()) {
|
|
|
|
|
auto& creators = iter->second;
|
|
|
|
|
for (auto& cur : creators) {
|
|
|
|
|
auto i = dynamic_cast<const JitCodeCreator<Attr>*>(cur.get());
|
|
|
|
|
if (i && i->UseMe(attr)) {
|
|
|
|
|
auto p = i->CreateJitCode(attr);
|
|
|
|
|
if (p) {
|
|
|
|
|
auto f = p->template getCode<Func>();
|
|
|
|
|
codes.Insert(key, std::move(p));
|
|
|
|
|
return f;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -67,6 +68,15 @@ inline typename KernelTuples::func_type GetJitCode(
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <KernelType KT, typename KernelTuples, typename PlaceType>
|
|
|
|
|
inline typename std::enable_if<
|
|
|
|
|
!std::is_same<typename KernelTuples::data_type, float>::value ||
|
|
|
|
|
!std::is_same<PlaceType, platform::CPUPlace>::value,
|
|
|
|
|
typename KernelTuples::func_type>::type
|
|
|
|
|
GetJitCode(typename KernelTuples::attr_type attr) {
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Refer code do not related with attr, which is just for cast
|
|
|
|
|
// Refer is always on CPUPlace
|
|
|
|
|
template <KernelType KT, typename KernelTuples>
|
|
|
|
|