|
|
@ -24,6 +24,20 @@ namespace jit {
|
|
|
|
namespace more {
|
|
|
|
namespace more {
|
|
|
|
namespace mkl {
|
|
|
|
namespace mkl {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
|
|
|
void MatMul<float>(const float* a, const float* b, float* c, int m, int n,
|
|
|
|
|
|
|
|
int k) {
|
|
|
|
|
|
|
|
platform::dynload::cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, m,
|
|
|
|
|
|
|
|
n, k, 1.f, a, k, b, n, 0.f, c, n);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
|
|
|
void MatMul<double>(const double* a, const double* b, double* c, int m, int n,
|
|
|
|
|
|
|
|
int k) {
|
|
|
|
|
|
|
|
platform::dynload::cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, m,
|
|
|
|
|
|
|
|
n, k, 1.0, a, k, b, n, 0.0, c, n);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
template <>
|
|
|
|
void VMul<float>(const float* x, const float* y, float* z, int n) {
|
|
|
|
void VMul<float>(const float* x, const float* y, float* z, int n) {
|
|
|
|
platform::dynload::vsMul(n, x, y, z);
|
|
|
|
platform::dynload::vsMul(n, x, y, z);
|
|
|
@ -93,6 +107,11 @@ void VAXPY<double>(double a, const double* x, double* y, int n) {
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// TODO(TJ): tuning me carefully on AVX, AVX2 and AVX512
|
|
|
|
// TODO(TJ): tuning me carefully on AVX, AVX2 and AVX512
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
|
|
|
bool MatMulKernel<float>::UseMe(const int& d) const {
|
|
|
|
|
|
|
|
return platform::MayIUse(platform::avx);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
template <>
|
|
|
|
bool VMulKernel<float>::UseMe(const int& d) const {
|
|
|
|
bool VMulKernel<float>::UseMe(const int& d) const {
|
|
|
|
return platform::MayIUse(platform::avx512f) && d > 512;
|
|
|
|
return platform::MayIUse(platform::avx512f) && d > 512;
|
|
|
@ -139,6 +158,7 @@ bool SeqPoolKernel<double>::UseMe(const seq_pool_attr_t& attr) const {
|
|
|
|
return true; \
|
|
|
|
return true; \
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
AWALYS_USE_ME_WITH_DOUBLE(MatMul);
|
|
|
|
AWALYS_USE_ME_WITH_DOUBLE(VMul);
|
|
|
|
AWALYS_USE_ME_WITH_DOUBLE(VMul);
|
|
|
|
AWALYS_USE_ME_WITH_DOUBLE(VAdd);
|
|
|
|
AWALYS_USE_ME_WITH_DOUBLE(VAdd);
|
|
|
|
AWALYS_USE_ME_WITH_DOUBLE(VScal);
|
|
|
|
AWALYS_USE_ME_WITH_DOUBLE(VScal);
|
|
|
@ -159,6 +179,7 @@ namespace mkl = paddle::operators::jit::more::mkl;
|
|
|
|
REGISTER_JITKERNEL_MORE(key, mkl, mkl::func##Kernel<float>, \
|
|
|
|
REGISTER_JITKERNEL_MORE(key, mkl, mkl::func##Kernel<float>, \
|
|
|
|
mkl::func##Kernel<double>)
|
|
|
|
mkl::func##Kernel<double>)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
REGISTER_MKL_KERNEL(kMatMul, MatMul);
|
|
|
|
REGISTER_MKL_KERNEL(kVMul, VMul);
|
|
|
|
REGISTER_MKL_KERNEL(kVMul, VMul);
|
|
|
|
REGISTER_MKL_KERNEL(kVAdd, VAdd);
|
|
|
|
REGISTER_MKL_KERNEL(kVAdd, VAdd);
|
|
|
|
REGISTER_MKL_KERNEL(kVScal, VScal);
|
|
|
|
REGISTER_MKL_KERNEL(kVScal, VScal);
|
|
|
|