|
|
|
|
@ -27,8 +27,6 @@ namespace paddle {
|
|
|
|
|
namespace operators {
|
|
|
|
|
namespace jitkernels {
|
|
|
|
|
|
|
|
|
|
// TODO(TJ): rename file to kernel_pool
|
|
|
|
|
|
|
|
|
|
template <KernelType KT>
|
|
|
|
|
class JitCodePool {
|
|
|
|
|
typedef std::unique_ptr<JitBase> JitBasePtr;
|
|
|
|
|
@ -54,14 +52,6 @@ class JitCodePool {
|
|
|
|
|
DISABLE_COPY_AND_ASSIGN(JitCodePool);
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
// TODO(TJ): std::tuple<T, Func, Attr>
|
|
|
|
|
// template <typename T, typename Func, typename Attr>
|
|
|
|
|
// struct KernelAttr {
|
|
|
|
|
// typedef T data_type;
|
|
|
|
|
// typedef Func return_type;
|
|
|
|
|
// typedef Attr attr_type;
|
|
|
|
|
// };
|
|
|
|
|
|
|
|
|
|
typedef std::unique_ptr<const Kernel> KernelPtr;
|
|
|
|
|
typedef std::unordered_map<KernelKey, std::vector<KernelPtr>, KernelKey::Hash>
|
|
|
|
|
KernelMap;
|
|
|
|
|
@ -120,7 +110,6 @@ inline Func GetRefer() {
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// TODO(TJ): make tuple? named KernelAttr
|
|
|
|
|
template <KernelType KT, typename T, typename Func, typename Attr,
|
|
|
|
|
typename PlaceType = platform::CPUPlace>
|
|
|
|
|
const Func Get(Attr attr) {
|
|
|
|
|
@ -130,8 +119,7 @@ const Func Get(Attr attr) {
|
|
|
|
|
return codes.AllKernels().at(key)->template getCode<Func>();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (std::is_same<PlaceType, platform::CPUPlace>::value) { // TODO(TJ): float
|
|
|
|
|
// move to create
|
|
|
|
|
if (std::is_same<PlaceType, platform::CPUPlace>::value) {
|
|
|
|
|
auto p = CreateJitCode<KT, T, Attr>(attr);
|
|
|
|
|
if (p) {
|
|
|
|
|
auto f = p->template getCode<Func>();
|
|
|
|
|
|