|
|
|
@ -14,7 +14,7 @@
|
|
|
|
|
|
|
|
|
|
#pragma once
|
|
|
|
|
|
|
|
|
|
#include <memory> // for shared_ptr
|
|
|
|
|
#include <memory> // for unique_ptr
|
|
|
|
|
#include <string>
|
|
|
|
|
#include <unordered_map>
|
|
|
|
|
#include <vector>
|
|
|
|
@ -52,6 +52,28 @@ class JitCodePool {
|
|
|
|
|
DISABLE_COPY_AND_ASSIGN(JitCodePool);
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class JitCodeCreatorPool {
|
|
|
|
|
typedef std::unique_ptr<const GenCreator> GenCreatorPtr;
|
|
|
|
|
typedef std::unordered_map<KernelKey, std::vector<GenCreatorPtr>,
|
|
|
|
|
KernelKey::Hash>
|
|
|
|
|
GenCreatorPtrMap;
|
|
|
|
|
|
|
|
|
|
public:
|
|
|
|
|
JitCodeCreatorPool() = default;
|
|
|
|
|
static JitCodeCreatorPool& Instance();
|
|
|
|
|
GenCreatorPtrMap& AllCreators() { return creators_; }
|
|
|
|
|
void Insert(const KernelKey& key, GenCreatorPtr value) {
|
|
|
|
|
if (creators_.find(key) == creators_.end()) {
|
|
|
|
|
creators_.emplace(key, std::vector<GenCreatorPtr>());
|
|
|
|
|
}
|
|
|
|
|
creators_.at(key).emplace_back(std::move(value));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
GenCreatorPtrMap creators_;
|
|
|
|
|
DISABLE_COPY_AND_ASSIGN(JitCodeCreatorPool);
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
typedef std::unique_ptr<const Kernel> KernelPtr;
|
|
|
|
|
typedef std::unordered_map<KernelKey, std::vector<KernelPtr>, KernelKey::Hash>
|
|
|
|
|
KernelMap;
|
|
|
|
@ -113,24 +135,33 @@ inline Func GetRefer() {
|
|
|
|
|
template <KernelType KT, typename T, typename Func, typename Attr,
|
|
|
|
|
typename PlaceType = platform::CPUPlace>
|
|
|
|
|
const Func Get(Attr attr) {
|
|
|
|
|
size_t key = GetKey<Attr>(attr);
|
|
|
|
|
size_t key = JitCodeKey<Attr>(attr);
|
|
|
|
|
auto& codes = JitCodePool<KT>().Instance();
|
|
|
|
|
if (codes.Has(key)) {
|
|
|
|
|
return codes.AllKernels().at(key)->template getCode<Func>();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
KernelKey kkey(KT, PlaceType());
|
|
|
|
|
if (std::is_same<PlaceType, platform::CPUPlace>::value) {
|
|
|
|
|
auto p = CreateJitCode<KT, T, Attr>(attr);
|
|
|
|
|
if (p) {
|
|
|
|
|
auto f = p->template getCode<Func>();
|
|
|
|
|
codes.Insert(key, std::move(p));
|
|
|
|
|
return f;
|
|
|
|
|
// pool: (KernelKey(type, place), vector<GenCreatorPtr>)
|
|
|
|
|
auto& creator_map = JitCodeCreatorPool().Instance().AllCreators();
|
|
|
|
|
auto iter = creator_map.find(kkey);
|
|
|
|
|
auto& creators = iter->second;
|
|
|
|
|
for (auto& cur : creators) {
|
|
|
|
|
auto i = dynamic_cast<const JitCodeCreator<Attr>*>(cur.get());
|
|
|
|
|
if (i && i->UseMe(attr)) {
|
|
|
|
|
auto p = i->CreateJitCode(attr);
|
|
|
|
|
if (p) {
|
|
|
|
|
auto f = p->template getCode<Func>();
|
|
|
|
|
codes.Insert(key, std::move(p));
|
|
|
|
|
return f;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// pool: (KernelKey(type, place), vector<Kernel>)
|
|
|
|
|
// pool: (KernelKey(type, place), vector<KernelPtr>)
|
|
|
|
|
auto& pool = KernelPool().Instance().AllKernels();
|
|
|
|
|
KernelKey kkey(KT, PlaceType());
|
|
|
|
|
auto iter = pool.find(kkey);
|
|
|
|
|
if (iter != pool.end()) {
|
|
|
|
|
auto& impls = iter->second;
|
|
|
|
|