refine and seepdup

fix-readmd
tensor-tang 6 years ago
parent 77fc42d2d1
commit 3d928d4f9d

@ -35,29 +35,6 @@ const std::shared_ptr<Kernel> KernelPool::Get(const std::string& key) const {
return kers_.at(key);
}
#define DEFINE_WITH_DTYPE(ker_key, ker_class, ker_dtype, dtype_key) \
template <> \
const std::shared_ptr<ker_class<ker_dtype>> \
KernelPool::Get<ker_class<ker_dtype>>(int d) { \
std::string key = #ker_key #dtype_key + std::to_string(d); \
if (kers_.find(key) == kers_.end()) { \
auto p = std::make_shared<ker_class<ker_dtype>>(d); \
kers_.insert({key, std::dynamic_pointer_cast<Kernel>(p)}); \
return p; \
} \
return std::dynamic_pointer_cast<ker_class<ker_dtype>>(kers_.at(key)); \
}
#define REGISTER_BLAS_JITKERNEL(ker_key, ker_class) \
DEFINE_WITH_DTYPE(ker_key, ker_class, float, f); \
DEFINE_WITH_DTYPE(ker_key, ker_class, double, d)
REGISTER_BLAS_JITKERNEL(vmul, VMulKernel);
REGISTER_BLAS_JITKERNEL(vadd, VAddKernel);
#undef REGISTER_BLAS_JITKERNEL
#undef DEFINE_WITH_DTYPE
template <>
const std::shared_ptr<LSTMKernel<float>>
KernelPool::Get<LSTMKernel<float>, int, const std::string&, const std::string&,

@ -40,7 +40,7 @@ typedef enum { kLT8, kEQ8, kGT8LT16, kEQ16, kGT16 } jit_block;
class Kernel {
public:
Kernel() {}
Kernel() = default;
virtual ~Kernel() = default;
private:
@ -66,15 +66,13 @@ class KernelPool {
template <typename T>
class VMulKernel : public Kernel {
public:
explicit VMulKernel(int n);
void (*Compute)(const int n, const T *, const T *, T *);
virtual void Compute(const int n, const T *x, const T *y, T *z) = 0;
};
template <typename T>
class VAddKernel : public Kernel {
public:
explicit VAddKernel(int n);
void (*Compute)(const int n, const T *, const T *, T *);
virtual void Compute(const int n, const T *x, const T *y, T *z) = 0;
};
template <typename T>

File diff suppressed because it is too large Load Diff
Loading…
Cancel
Save