@@ -14,7 +14,9 @@ limitations under the License. */
 
 #pragma once
 
+#if !defined(PADDLE_WITH_ARM)
 #include <immintrin.h>
+#endif
 #include <cfloat>
 #include <cmath>
 #include <cstring>
@@ -72,6 +74,8 @@ void call_gemm_batched(const framework::ExecutionContext& ctx,
   }
 }
 
+#if !defined(PADDLE_WITH_ARM)
+
 #define __m256x __m256
 
 static const unsigned int AVX_STEP_SIZE = 8;
@@ -83,16 +87,25 @@ static const unsigned int AVX_CUT_LEN_MASK = 7U;
 #define _mm256_store_px _mm256_storeu_ps
 #define _mm256_broadcast_sx _mm256_broadcast_ss
 
-#define _mm256_mul_pd _mm256_mul_pd
-#define _mm256_add_pd _mm256_add_pd
-#define _mm256_load_pd _mm256_loadu_pd
-#define _mm256_store_pd _mm256_storeu_pd
-#define _mm256_broadcast_sd _mm256_broadcast_sd
+#define __m128x __m128
+
+static const unsigned int SSE_STEP_SIZE = 2;
+static const unsigned int SSE_CUT_LEN_MASK = 1U;
+
+#define _mm_add_px _mm_add_ps
+#define _mm_mul_px _mm_mul_ps
+#define _mm_load_px _mm_loadu_ps
+#define _mm_store_px _mm_storeu_ps
+#define _mm_load1_px _mm_load1_ps
+
+#endif
 
-inline void avx_axpy(const float* x, float* y, size_t len, const float alpha) {
+template <typename T>
+inline void axpy(const T* x, T* y, size_t len, const T alpha) {
   unsigned int jjj, lll;
   jjj = lll = 0;
 
+#ifdef PADDLE_WITH_AVX
   lll = len & ~AVX_CUT_LEN_MASK;
   __m256x mm_alpha = _mm256_broadcast_sx(&alpha);
   for (jjj = 0; jjj < lll; jjj += AVX_STEP_SIZE) {
@@ -101,66 +114,55 @@ inline void avx_axpy(const float* x, float* y, size_t len, const float alpha) {
         _mm256_add_px(_mm256_load_px(y + jjj),
                       _mm256_mul_px(mm_alpha, _mm256_load_px(x + jjj))));
   }
-
-  for (; jjj < len; jjj++) {
-    y[jjj] += alpha * x[jjj];
+#elif defined(PADDLE_WITH_ARM)
+  PADDLE_THROW(platform::errors::Unimplemented("axpy is not supported"));
+#else
+  lll = len & ~SSE_CUT_LEN_MASK;
+  __m128x mm_alpha = _mm_load1_px(&alpha);
+  for (jjj = 0; jjj < lll; jjj += SSE_STEP_SIZE) {
+    _mm_store_px(y + jjj,
+                 _mm_add_px(_mm_load_px(y + jjj),
+                            _mm_mul_px(mm_alpha, _mm_load_px(x + jjj))));
   }
-}
-
-inline void avx_axpy(const double* x, double* y, size_t len,
-                     const float alpha) {
-  unsigned int jjj, lll;
-  jjj = lll = 0;
-
-  lll = len & ~AVX_CUT_LEN_MASK;
-  double alpha_d = static_cast<double>(alpha);
-
-  __m256d mm_alpha = _mm256_broadcast_sd(&alpha_d);
-  for (jjj = 0; jjj < lll; jjj += AVX_STEP_SIZE) {
-    _mm256_store_pd(
-        y + jjj,
-        _mm256_add_pd(_mm256_load_pd(y + jjj),
-                      _mm256_mul_pd(mm_alpha, _mm256_load_pd(x + jjj))));
-  }
+#endif
 
   for (; jjj < len; jjj++) {
     y[jjj] += alpha * x[jjj];
   }
 }
-inline void avx_axpy_noadd(const double* x, double* y, size_t len,
-                           const float alpha) {
-  unsigned int jjj, lll;
-  jjj = lll = 0;
-  double alpha_d = static_cast<double>(alpha);
-  lll = len & ~AVX_CUT_LEN_MASK;
-  __m256d mm_alpha = _mm256_broadcast_sd(&alpha_d);
-  for (jjj = 0; jjj < lll; jjj += AVX_STEP_SIZE) {
-    _mm256_store_pd(y + jjj, _mm256_mul_pd(mm_alpha, _mm256_load_pd(x + jjj)));
-  }
-
-  for (; jjj < len; jjj++) {
-    y[jjj] = alpha * x[jjj];
-  }
-}
-inline void avx_axpy_noadd(const float* x, float* y, size_t len,
-                           const float alpha) {
+template <typename T>
+inline void axpy_noadd(const T* x, T* y, size_t len, const T alpha) {
   unsigned int jjj, lll;
   jjj = lll = 0;
 
+#ifdef PADDLE_WITH_AVX
   lll = len & ~AVX_CUT_LEN_MASK;
   __m256x mm_alpha = _mm256_broadcast_sx(&alpha);
   for (jjj = 0; jjj < lll; jjj += AVX_STEP_SIZE) {
     _mm256_store_px(y + jjj, _mm256_mul_px(mm_alpha, _mm256_load_px(x + jjj)));
   }
+#elif defined(PADDLE_WITH_ARM)
+  PADDLE_THROW(platform::errors::Unimplemented("axpy_noadd is not supported"));
+#else
+  lll = len & ~SSE_CUT_LEN_MASK;
+  __m128x mm_alpha = _mm_load1_px(&alpha);
+  for (jjj = 0; jjj < lll; jjj += SSE_STEP_SIZE) {
+    _mm_store_px(y + jjj, _mm_mul_px(mm_alpha, _mm_load_px(x + jjj)));
+  }
+
+#endif
 
   for (; jjj < len; jjj++) {
     y[jjj] = alpha * x[jjj];
   }
 }
-inline void avx_axpy_noadd(const int8_t* x, int8_t* y, size_t len,
-                           const float alpha) {
+
+inline void axpy_noadd(const int8_t* x, int8_t* y, size_t len,
+                       const float alpha) {
   PADDLE_THROW(platform::errors::Unimplemented(
-      "int8_t input of avx_axpy_noadd is not supported"));
+      "int8_t input of axpy_noadd is not supported"));
 }
 
 }  // namespace operators
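Note on the loop structure shared by axpy and axpy_noadd: len & ~AVX_CUT_LEN_MASK (or len & ~SSE_CUT_LEN_MASK on the SSE path) rounds len down to a multiple of the vector step, the intrinsic loop covers that prefix, and the scalar tail loop finishes the remainder. A small worked example of the mask arithmetic, as a standalone sketch with illustrative sample values only:

#include <cstdio>

int main() {
  const unsigned int AVX_CUT_LEN_MASK = 7U;  // AVX step of 8 floats, as in the header
  const unsigned int SSE_CUT_LEN_MASK = 1U;  // SSE step of 2, as defined above
  unsigned int len = 11;
  unsigned int avx_main = len & ~AVX_CUT_LEN_MASK;  // 11 & ~7 == 8
  unsigned int sse_main = len & ~SSE_CUT_LEN_MASK;  // 11 & ~1 == 10
  // The AVX loop covers elements [0, 8) and the tail loop covers [8, 11);
  // the SSE loop covers [0, 10) and the tail loop handles the last element.
  std::printf("AVX main=%u tail=%u\n", avx_main, len - avx_main);
  std::printf("SSE main=%u tail=%u\n", sse_main, len - sse_main);
  return 0;
}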
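For checking the templated kernels outside a Paddle build, the scalar reference below spells out the contract the SIMD branches preserve: axpy accumulates y[i] += alpha * x[i], while axpy_noadd overwrites y[i] = alpha * x[i]. This is a minimal sketch; the *_reference names and the main harness are illustrative and not part of the patch.

#include <cstddef>
#include <iostream>
#include <vector>

// Scalar reference for the templated axpy in the diff: y[i] += alpha * x[i].
template <typename T>
void axpy_reference(const T* x, T* y, std::size_t len, T alpha) {
  for (std::size_t i = 0; i < len; ++i) y[i] += alpha * x[i];
}

// Scalar reference for axpy_noadd: y[i] = alpha * x[i], no accumulation.
template <typename T>
void axpy_noadd_reference(const T* x, T* y, std::size_t len, T alpha) {
  for (std::size_t i = 0; i < len; ++i) y[i] = alpha * x[i];
}

int main() {
  std::vector<float> x(11, 1.0f), y(11, 2.0f);
  axpy_reference(x.data(), y.data(), x.size(), 0.5f);        // every y[i] becomes 2.5f
  axpy_noadd_reference(x.data(), y.data(), x.size(), 3.0f);  // every y[i] becomes 3.0f
  std::cout << y.front() << " " << y.back() << std::endl;    // prints: 3 3
  return 0;
}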