|
|
|
@ -32,9 +32,11 @@ namespace jit {
|
|
|
|
|
#define SIGMOID_THRESHOLD_MAX 13.0
|
|
|
|
|
#define EXP_MAX_INPUT 40.0
|
|
|
|
|
|
|
|
|
|
template <KernelType KT, typename T, typename Func, typename Attr,
|
|
|
|
|
typename PlaceType>
|
|
|
|
|
inline Func GetJitCode(Attr attr) {
|
|
|
|
|
template <KernelType KT, typename KernelTuples, typename PlaceType>
|
|
|
|
|
inline typename KernelTuples::func_type GetJitCode(
|
|
|
|
|
typename KernelTuples::attr_type attr) {
|
|
|
|
|
using Func = typename KernelTuples::func_type;
|
|
|
|
|
using Attr = typename KernelTuples::attr_type;
|
|
|
|
|
size_t key = JitCodeKey<Attr>(attr);
|
|
|
|
|
auto& codes = JitCodePool<KT>().Instance();
|
|
|
|
|
if (codes.Has(key)) {
|
|
|
|
@ -65,8 +67,8 @@ inline Func GetJitCode(Attr attr) {
|
|
|
|
|
|
|
|
|
|
// Refer code do not related with attr, which is just for cast
|
|
|
|
|
// Refer is always on CPUPlace
|
|
|
|
|
template <KernelType KT, typename T, typename Func, typename Attr>
|
|
|
|
|
inline Func GetRefer() {
|
|
|
|
|
template <KernelType KT, typename KernelTuples>
|
|
|
|
|
inline typename KernelTuples::func_type GetRefer() {
|
|
|
|
|
auto& ref_pool = ReferKernelPool().Instance().AllKernels();
|
|
|
|
|
KernelKey kkey(KT, platform::CPUPlace());
|
|
|
|
|
auto ref_iter = ref_pool.find(kkey);
|
|
|
|
@ -74,7 +76,7 @@ inline Func GetRefer() {
|
|
|
|
|
"Every Kernel should have reference function.");
|
|
|
|
|
auto& ref_impls = ref_iter->second;
|
|
|
|
|
for (auto& impl : ref_impls) {
|
|
|
|
|
auto i = dynamic_cast<const ReferKernel<T, Func, Attr>*>(impl.get());
|
|
|
|
|
auto i = dynamic_cast<const ReferKernel<KernelTuples>*>(impl.get());
|
|
|
|
|
if (i) {
|
|
|
|
|
return i->GetFunc();
|
|
|
|
|
}
|
|
|
|
@ -82,10 +84,10 @@ inline Func GetRefer() {
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <KernelType KT, typename T, typename Func, typename Attr,
|
|
|
|
|
template <KernelType KT, typename KernelTuples,
|
|
|
|
|
typename PlaceType = platform::CPUPlace>
|
|
|
|
|
Func Get(Attr attr) {
|
|
|
|
|
auto jitfunc = GetJitCode<KT, T, Func, Attr, PlaceType>(attr);
|
|
|
|
|
typename KernelTuples::func_type Get(typename KernelTuples::attr_type attr) {
|
|
|
|
|
auto jitfunc = GetJitCode<KT, KernelTuples, PlaceType>(attr);
|
|
|
|
|
if (jitfunc) {
|
|
|
|
|
return jitfunc;
|
|
|
|
|
}
|
|
|
|
@ -97,7 +99,7 @@ Func Get(Attr attr) {
|
|
|
|
|
if (iter != pool.end()) {
|
|
|
|
|
auto& impls = iter->second;
|
|
|
|
|
for (auto& impl : impls) {
|
|
|
|
|
auto i = dynamic_cast<const KernelImpl<T, Func, Attr>*>(impl.get());
|
|
|
|
|
auto i = dynamic_cast<const KernelImpl<KernelTuples>*>(impl.get());
|
|
|
|
|
if (i && i->UseMe(attr)) {
|
|
|
|
|
return i->GetFunc();
|
|
|
|
|
}
|
|
|
|
@ -105,7 +107,7 @@ Func Get(Attr attr) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// The last implementation should be reference function on CPUPlace.
|
|
|
|
|
return GetRefer<KT, T, Func, Attr>();
|
|
|
|
|
return GetRefer<KT, KernelTuples>();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
} // namespace jit
|
|
|
|
|