|
|
|
@ -14,9 +14,6 @@
|
|
|
|
|
|
|
|
|
|
#pragma once
|
|
|
|
|
|
|
|
|
|
extern "C" {
|
|
|
|
|
#include <xxhash.h>
|
|
|
|
|
}
|
|
|
|
|
#include <iostream>
|
|
|
|
|
#include <string>
|
|
|
|
|
#include <unordered_map>
|
|
|
|
@ -36,31 +33,30 @@ template <typename KernelTuple, typename PlaceType>
|
|
|
|
|
inline typename std::enable_if<
|
|
|
|
|
std::is_same<typename KernelTuple::data_type, float>::value &&
|
|
|
|
|
std::is_same<PlaceType, platform::CPUPlace>::value,
|
|
|
|
|
typename KernelTuple::func_type>::type
|
|
|
|
|
const Kernel*>::type
|
|
|
|
|
GetJitCode(const typename KernelTuple::attr_type& attr) {
|
|
|
|
|
using Func = typename KernelTuple::func_type;
|
|
|
|
|
using Attr = typename KernelTuple::attr_type;
|
|
|
|
|
size_t key = JitCodeKey<Attr>(attr);
|
|
|
|
|
auto& codes = JitCodePool<KernelTuple::kernel_type>().Instance();
|
|
|
|
|
auto& codes = JitCodePool<KernelTuple::kernel_type>::Instance();
|
|
|
|
|
if (codes.Has(key)) {
|
|
|
|
|
return codes.AllKernels().at(key)->template getCode<Func>();
|
|
|
|
|
return codes.AllKernels().at(key).get();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// creator is not related with attr, so can use KernelKey as key
|
|
|
|
|
KernelKey kkey(KernelTuple::kernel_type, PlaceType());
|
|
|
|
|
// pool: (KernelKey(type, place), vector<GenCreatorPtr>)
|
|
|
|
|
auto& creator_map = JitCodeCreatorPool().Instance().AllCreators();
|
|
|
|
|
auto& creator_map = JitCodeCreatorPool::Instance().AllCreators();
|
|
|
|
|
auto iter = creator_map.find(kkey);
|
|
|
|
|
if (iter != creator_map.end()) {
|
|
|
|
|
auto& creators = iter->second;
|
|
|
|
|
for (auto& cur : creators) {
|
|
|
|
|
auto i = dynamic_cast<const JitCodeCreator<Attr>*>(cur.get());
|
|
|
|
|
if (i && i->UseMe(attr)) {
|
|
|
|
|
if (i && i->CanBeUsed(attr)) {
|
|
|
|
|
auto p = i->CreateJitCode(attr);
|
|
|
|
|
if (p) {
|
|
|
|
|
auto f = p->template getCode<Func>();
|
|
|
|
|
auto res = p.get();
|
|
|
|
|
codes.Insert(key, std::move(p));
|
|
|
|
|
return f;
|
|
|
|
|
return res;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -72,7 +68,7 @@ template <typename KernelTuple, typename PlaceType>
|
|
|
|
|
inline typename std::enable_if<
|
|
|
|
|
!std::is_same<typename KernelTuple::data_type, float>::value ||
|
|
|
|
|
!std::is_same<PlaceType, platform::CPUPlace>::value,
|
|
|
|
|
typename KernelTuple::func_type>::type
|
|
|
|
|
const Kernel*>::type
|
|
|
|
|
GetJitCode(const typename KernelTuple::attr_type& attr) {
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
@ -80,8 +76,8 @@ GetJitCode(const typename KernelTuple::attr_type& attr) {
|
|
|
|
|
// Refer code do not related with attr, which is just for cast
|
|
|
|
|
// Refer is always on CPUPlace
|
|
|
|
|
template <typename KernelTuple>
|
|
|
|
|
inline typename KernelTuple::func_type GetRefer() {
|
|
|
|
|
auto& ref_pool = ReferKernelPool().Instance().AllKernels();
|
|
|
|
|
inline const Kernel* GetReferKernel() {
|
|
|
|
|
auto& ref_pool = ReferKernelPool::Instance().AllKernels();
|
|
|
|
|
KernelKey kkey(KernelTuple::kernel_type, platform::CPUPlace());
|
|
|
|
|
auto ref_iter = ref_pool.find(kkey);
|
|
|
|
|
PADDLE_ENFORCE(ref_iter != ref_pool.end(),
|
|
|
|
@ -90,36 +86,93 @@ inline typename KernelTuple::func_type GetRefer() {
|
|
|
|
|
for (auto& impl : ref_impls) {
|
|
|
|
|
auto i = dynamic_cast<const ReferKernel<KernelTuple>*>(impl.get());
|
|
|
|
|
if (i) {
|
|
|
|
|
return i->GetFunc();
|
|
|
|
|
return i;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename KernelTuple, typename PlaceType = platform::CPUPlace>
|
|
|
|
|
typename KernelTuple::func_type Get(
|
|
|
|
|
template <typename KernelTuple>
|
|
|
|
|
inline typename KernelTuple::func_type GetReferFunc() {
|
|
|
|
|
auto ker = GetReferKernel<KernelTuple>();
|
|
|
|
|
auto p = dynamic_cast<const ReferKernel<KernelTuple>*>(ker);
|
|
|
|
|
PADDLE_ENFORCE(p, "The Refer kernel should exsit");
|
|
|
|
|
return p->GetFunc();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Return all Kernels that can be used
|
|
|
|
|
template <typename KernelTuple, typename PlaceType>
|
|
|
|
|
std::vector<const Kernel*> GetAllCandidateKernels(
|
|
|
|
|
const typename KernelTuple::attr_type& attr) {
|
|
|
|
|
auto jitfunc = GetJitCode<KernelTuple, PlaceType>(attr);
|
|
|
|
|
if (jitfunc) {
|
|
|
|
|
return jitfunc;
|
|
|
|
|
// the search order shoudl be jitcode > more > refer
|
|
|
|
|
std::vector<const Kernel*> res;
|
|
|
|
|
auto jitker = GetJitCode<KernelTuple, PlaceType>(attr);
|
|
|
|
|
if (jitker) {
|
|
|
|
|
res.emplace_back(jitker);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// pool: (KernelKey(type, place), vector<KernelPtr>)
|
|
|
|
|
// more kernelpool: (KernelKey(type, place), vector<KernelPtr>)
|
|
|
|
|
KernelKey kkey(KernelTuple::kernel_type, PlaceType());
|
|
|
|
|
auto& pool = KernelPool().Instance().AllKernels();
|
|
|
|
|
auto& pool = KernelPool::Instance().AllKernels();
|
|
|
|
|
auto iter = pool.find(kkey);
|
|
|
|
|
if (iter != pool.end()) {
|
|
|
|
|
auto& impls = iter->second;
|
|
|
|
|
for (auto& impl : impls) {
|
|
|
|
|
auto i = dynamic_cast<const KernelMore<KernelTuple>*>(impl.get());
|
|
|
|
|
if (i && i->UseMe(attr)) {
|
|
|
|
|
return i->GetFunc();
|
|
|
|
|
if (i && i->CanBeUsed(attr)) {
|
|
|
|
|
res.emplace_back(i);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// The last implementation should be reference function on CPUPlace.
|
|
|
|
|
return GetRefer<KernelTuple>();
|
|
|
|
|
auto ref = GetReferKernel<KernelTuple>();
|
|
|
|
|
PADDLE_ENFORCE(ref != nullptr, "Refer Kernel can not be empty.");
|
|
|
|
|
res.emplace_back(ref);
|
|
|
|
|
return res;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename KernelTuple, typename PlaceType = platform::CPUPlace>
|
|
|
|
|
std::vector<std::pair<std::string, typename KernelTuple::func_type>>
|
|
|
|
|
GetAllCandidateFuncsWithTypes(const typename KernelTuple::attr_type& attr) {
|
|
|
|
|
using Func = typename KernelTuple::func_type;
|
|
|
|
|
auto kers = GetAllCandidateKernels<KernelTuple, PlaceType>(attr);
|
|
|
|
|
std::vector<std::pair<std::string, Func>> res;
|
|
|
|
|
for (auto k : kers) {
|
|
|
|
|
std::string name = k->ImplType();
|
|
|
|
|
if (name == "JitCode") {
|
|
|
|
|
auto i = dynamic_cast<const GenBase*>(k);
|
|
|
|
|
PADDLE_ENFORCE(i, "jitcode kernel cast can not fail.");
|
|
|
|
|
res.emplace_back(std::make_pair(name, i->template getCode<Func>()));
|
|
|
|
|
} else {
|
|
|
|
|
auto i = dynamic_cast<const KernelMore<KernelTuple>*>(k);
|
|
|
|
|
PADDLE_ENFORCE(i, "kernel cast can not fail.");
|
|
|
|
|
res.emplace_back(std::make_pair(name, i->GetFunc()));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return res;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename KernelTuple, typename PlaceType = platform::CPUPlace>
|
|
|
|
|
std::vector<typename KernelTuple::func_type> GetAllCandidateFuncs(
|
|
|
|
|
const typename KernelTuple::attr_type& attr) {
|
|
|
|
|
auto funcs = GetAllCandidateFuncsWithTypes<KernelTuple, PlaceType>(attr);
|
|
|
|
|
std::vector<typename KernelTuple::func_type> res;
|
|
|
|
|
for (auto& i : funcs) {
|
|
|
|
|
res.emplace_back(i.second);
|
|
|
|
|
}
|
|
|
|
|
return res;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename KernelTuple, typename PlaceType = platform::CPUPlace>
|
|
|
|
|
typename KernelTuple::func_type GetDefaultBestFunc(
|
|
|
|
|
const typename KernelTuple::attr_type& attr) {
|
|
|
|
|
auto funcs = GetAllCandidateFuncs<KernelTuple, PlaceType>(attr);
|
|
|
|
|
PADDLE_ENFORCE_GE(funcs.size(), 1UL);
|
|
|
|
|
// Here could do some runtime benchmark of this attr and return the best one.
|
|
|
|
|
// But yet just get the first one as the default best one,
|
|
|
|
|
// which is searched in order and tuned by offline.
|
|
|
|
|
return funcs[0];
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename KernelTuple, typename PlaceType>
|
|
|
|
@ -134,17 +187,13 @@ class KernelFuncs {
|
|
|
|
|
// the exposed interface to use
|
|
|
|
|
typename KernelTuple::func_type At(
|
|
|
|
|
const typename KernelTuple::attr_type& attr) {
|
|
|
|
|
// XXH64: 13.8 GB/s
|
|
|
|
|
// TODO(TJ): change me, maybe not all attr change need one key, should be
|
|
|
|
|
// attrkey
|
|
|
|
|
int64_t key = XXH64(&attr, sizeof(typename KernelTuple::attr_type), 0);
|
|
|
|
|
// Maybe here is not good enough, not all kernels should have jitcode
|
|
|
|
|
int64_t key = JitCodeKey<typename KernelTuple::attr_type>(attr);
|
|
|
|
|
if (Has(key)) {
|
|
|
|
|
return funcs_.at(key);
|
|
|
|
|
}
|
|
|
|
|
// If do not have this attr in cache,
|
|
|
|
|
// then could run some runtime benchmark of this attr and save the best one.
|
|
|
|
|
// Here just get the offline benchmarked best one.
|
|
|
|
|
auto func = Get<KernelTuple, PlaceType>(attr);
|
|
|
|
|
// If do not have this attr in cache then get the default best
|
|
|
|
|
auto func = GetDefaultBestFunc<KernelTuple, PlaceType>(attr);
|
|
|
|
|
Insert(key, func);
|
|
|
|
|
return func;
|
|
|
|
|
}
|
|
|
|
@ -156,7 +205,6 @@ class KernelFuncs {
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
bool Has(int64_t key) const { return funcs_.find(key) != funcs_.end(); }
|
|
|
|
|
|
|
|
|
|
void Insert(int64_t key, typename KernelTuple::func_type func) {
|
|
|
|
|
funcs_.emplace(key, func);
|
|
|
|
|
}
|
|
|
|
|