|
|
|
@ -78,6 +78,24 @@ void VScal<double>(const double* a, const double* x, double* y, int n) {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
void StrideScal<float>(const float* a, const float* x, float* y, int n, int stride) {
|
|
|
|
|
if (x == y) {
|
|
|
|
|
platform::dynload::cblas_sscal(n, *a, y, stride);
|
|
|
|
|
} else {
|
|
|
|
|
refer::StrideScal<float>(a, x, y, n, stride);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
void StrideScal<double>(const double* a, const double* x, double* y, int n, int stride) {
|
|
|
|
|
if (x == y) {
|
|
|
|
|
platform::dynload::cblas_dscal(n, *a, y, stride);
|
|
|
|
|
} else {
|
|
|
|
|
refer::StrideScal<double>(a, x, y, n, stride);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
void VExp<float>(const float* x, float* y, int n) {
|
|
|
|
|
platform::dynload::vsExp(n, x, y);
|
|
|
|
@ -128,6 +146,16 @@ void ASum<double>(const double* x, double* res, int n) {
|
|
|
|
|
res[0] = platform::dynload::cblas_dasum(n, x, 1);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
void StrideSum<float>(const float* x, float* res, int n, int stride) {
|
|
|
|
|
res[0] = platform::dynload::cblas_sasum(n, x, stride);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
void StrideSum<double>(const double* x, double* res, int n, int stride) {
|
|
|
|
|
res[0] = platform::dynload::cblas_dasum(n, x, stride);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// TODO(TJ): tuning me carefully on AVX, AVX2 and AVX512
|
|
|
|
|
template <>
|
|
|
|
|
bool VMulKernel<float>::CanBeUsed(const int& d) const {
|
|
|
|
@ -144,6 +172,11 @@ bool VScalKernel<float>::CanBeUsed(const int& d) const {
|
|
|
|
|
return platform::MayIUse(platform::avx512f) && d > 512;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
bool StrideScalKernel<float>::CanBeUsed(const int& d) const {
|
|
|
|
|
return platform::MayIUse(platform::avx512f) && d > 512;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
bool VExpKernel<float>::CanBeUsed(const int& d) const {
|
|
|
|
|
return d > 7;
|
|
|
|
@ -235,6 +268,7 @@ bool SoftmaxKernel<float>::CanBeUsed(const int& d) const {
|
|
|
|
|
AWALYS_USE_ME_WITH_DOUBLE(VMul);
|
|
|
|
|
AWALYS_USE_ME_WITH_DOUBLE(VAdd);
|
|
|
|
|
AWALYS_USE_ME_WITH_DOUBLE(VScal);
|
|
|
|
|
AWALYS_USE_ME_WITH_DOUBLE(StrideScal);
|
|
|
|
|
AWALYS_USE_ME_WITH_DOUBLE(VExp);
|
|
|
|
|
AWALYS_USE_ME_WITH_DOUBLE(VSigmoid);
|
|
|
|
|
AWALYS_USE_ME_WITH_DOUBLE(VTanh);
|
|
|
|
@ -259,6 +293,7 @@ REGISTER_MKL_KERNEL(MatMul);
|
|
|
|
|
REGISTER_MKL_KERNEL(VMul);
|
|
|
|
|
REGISTER_MKL_KERNEL(VAdd);
|
|
|
|
|
REGISTER_MKL_KERNEL(VScal);
|
|
|
|
|
REGISTER_MKL_KERNEL(StrideScal);
|
|
|
|
|
REGISTER_MKL_KERNEL(VExp);
|
|
|
|
|
REGISTER_MKL_KERNEL(VSquare);
|
|
|
|
|
REGISTER_MKL_KERNEL(VCopy);
|
|
|
|
|