|
|
|
@ -35,29 +35,6 @@ const std::shared_ptr<Kernel> KernelPool::Get(const std::string& key) const {
|
|
|
|
|
return kers_.at(key);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#define DEFINE_WITH_DTYPE(ker_key, ker_class, ker_dtype, dtype_key) \
|
|
|
|
|
template <> \
|
|
|
|
|
const std::shared_ptr<ker_class<ker_dtype>> \
|
|
|
|
|
KernelPool::Get<ker_class<ker_dtype>>(int d) { \
|
|
|
|
|
std::string key = #ker_key #dtype_key + std::to_string(d); \
|
|
|
|
|
if (kers_.find(key) == kers_.end()) { \
|
|
|
|
|
auto p = std::make_shared<ker_class<ker_dtype>>(d); \
|
|
|
|
|
kers_.insert({key, std::dynamic_pointer_cast<Kernel>(p)}); \
|
|
|
|
|
return p; \
|
|
|
|
|
} \
|
|
|
|
|
return std::dynamic_pointer_cast<ker_class<ker_dtype>>(kers_.at(key)); \
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#define REGISTER_BLAS_JITKERNEL(ker_key, ker_class) \
|
|
|
|
|
DEFINE_WITH_DTYPE(ker_key, ker_class, float, f); \
|
|
|
|
|
DEFINE_WITH_DTYPE(ker_key, ker_class, double, d)
|
|
|
|
|
|
|
|
|
|
REGISTER_BLAS_JITKERNEL(vmul, VMulKernel);
|
|
|
|
|
REGISTER_BLAS_JITKERNEL(vadd, VAddKernel);
|
|
|
|
|
|
|
|
|
|
#undef REGISTER_BLAS_JITKERNEL
|
|
|
|
|
#undef DEFINE_WITH_DTYPE
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
const std::shared_ptr<LSTMKernel<float>>
|
|
|
|
|
KernelPool::Get<LSTMKernel<float>, int, const std::string&, const std::string&,
|
|
|
|
|