You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Paddle/paddle/math/SIMDFunctions.cpp

391 lines
10 KiB

/* Copyright (c) 2016 Baidu, Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "SIMDFunctions.h"
#include <immintrin.h>
#include <algorithm>
#ifndef __AVX__
/**
 * In-place element-wise add using SSE: a[j] += b[j] for every j in [0, len).
 *
 * The first 16 * (len / 16) elements are processed four 128-bit registers at
 * a time with aligned loads/stores (_mm_load_ps/_mm_store_ps), so `a` and `b`
 * must both be 16-byte aligned. The remaining len % 16 elements are added
 * with scalar code.
 */
static void addto_sse(float* a, const float* b, size_t len) {
  // size_t counters: an `unsigned int` chunk counter compared against the
  // size_t bound `len / 16` would wrap (and never terminate) for huge len.
  const size_t offset = len % 16;
  const size_t chunks = len / 16;
  __m128 ma0, ma1, ma2, ma3;
  __m128 mb0, mb1, mb2, mb3;
  for (size_t k = 0; k < chunks; k++, a += 16, b += 16) {
    // All loads happen before any store, as in a straight element-wise pass.
    ma0 = _mm_load_ps(a);
    ma1 = _mm_load_ps(a + 4);
    ma2 = _mm_load_ps(a + 8);
    ma3 = _mm_load_ps(a + 12);
    mb0 = _mm_load_ps(b);
    mb1 = _mm_load_ps(b + 4);
    mb2 = _mm_load_ps(b + 8);
    mb3 = _mm_load_ps(b + 12);
    ma0 = _mm_add_ps(ma0, mb0);
    ma1 = _mm_add_ps(ma1, mb1);
    ma2 = _mm_add_ps(ma2, mb2);
    ma3 = _mm_add_ps(ma3, mb3);
    _mm_store_ps(a, ma0);
    _mm_store_ps(a + 4, ma1);
    _mm_store_ps(a + 8, ma2);
    _mm_store_ps(a + 12, ma3);
  }
  // Scalar tail: `a` and `b` now point at the first unprocessed element.
  for (size_t i = 0; i < offset; i++) a[i] += b[i];
}
/**
 * Accumulate `batch` source arrays into `a` using SSE:
 *   a[j] += b[0][j] + b[1][j] + ... + b[batch - 1][j]   for j in [0, len).
 *
 * For each element the sources are added in order b[0], b[1], ... . The first
 * 16 * (len / 16) elements use aligned SSE loads/stores, so `a` and every
 * b[i] must be 16-byte aligned; the last len % 16 elements use scalar code.
 *
 * The pointer array `b` is only read: on return every b[i] still points at the
 * start of its source array, so callers may safely reuse `b`.
 */
static void batch_addto_sse(float* a, const float* b[], int batch, size_t len) {
  const size_t offset = len % 16;
  const size_t vecLen = len - offset;
  for (size_t j = 0; j < vecLen; j += 16) {
    __m128 ma0 = _mm_load_ps(a + j);
    __m128 ma1 = _mm_load_ps(a + j + 4);
    __m128 ma2 = _mm_load_ps(a + j + 8);
    __m128 ma3 = _mm_load_ps(a + j + 12);
    for (int i = 0; i < batch; i++) {
      // Index with j rather than advancing b[i] in place, so the caller's
      // pointer array is left untouched.
      const float* src = b[i] + j;
      ma0 = _mm_add_ps(ma0, _mm_load_ps(src));
      ma1 = _mm_add_ps(ma1, _mm_load_ps(src + 4));
      ma2 = _mm_add_ps(ma2, _mm_load_ps(src + 8));
      ma3 = _mm_add_ps(ma3, _mm_load_ps(src + 12));
    }
    _mm_store_ps(a + j, ma0);
    _mm_store_ps(a + j + 4, ma1);
    _mm_store_ps(a + j + 8, ma2);
    _mm_store_ps(a + j + 12, ma3);
  }
  // Scalar tail for the last len % 16 elements.
  for (size_t j = vecLen; j < len; j++) {
    for (int k = 0; k < batch; k++) a[j] += b[k][j];
  }
}
/**
 * Column-wise maximum using SSE.
 *
 * `data` holds `numSamples` rows of `dim` floats stored contiguously (row i
 * starts at data + i * dim). On return, result[d] is the maximum of column d
 * over all rows. Row 0 is always read. NOTE(review): numSamples <= 0 is not
 * rejected here; callers presumably pass numSamples >= 1.
 *
 * The first 16 * (dim / 16) columns are processed four 128-bit registers at a
 * time with aligned loads/stores. This requires `result` and every row start
 * (data + i * dim) to be 16-byte aligned. The remaining dim % 16 columns use
 * scalar code.
 */
static void col_max_sse(float* result, const float* data, int dim,
                        int numSamples) {
  // Row stride as size_t: computing `i * dim` in int overflows (undefined
  // behavior) once numSamples * dim exceeds INT_MAX.
  const size_t stride = static_cast<size_t>(dim);
  const int offset = dim % 16;
  // First 16n columns. Row 0 seeds the running maximum directly, so no
  // separate copy pass into `result` is needed.
  for (int k = 0; k < dim / 16; k++, result += 16, data += 16) {
    __m128 ma0 = _mm_load_ps(data);
    __m128 ma1 = _mm_load_ps(data + 4);
    __m128 ma2 = _mm_load_ps(data + 8);
    __m128 ma3 = _mm_load_ps(data + 12);
    for (int i = 1; i < numSamples; i++) {
      const float* row = data + i * stride;
      ma0 = _mm_max_ps(ma0, _mm_load_ps(row));
      ma1 = _mm_max_ps(ma1, _mm_load_ps(row + 4));
      ma2 = _mm_max_ps(ma2, _mm_load_ps(row + 8));
      ma3 = _mm_max_ps(ma3, _mm_load_ps(row + 12));
    }
    _mm_store_ps(result, ma0);
    _mm_store_ps(result + 4, ma1);
    _mm_store_ps(result + 8, ma2);
    _mm_store_ps(result + 12, ma3);
  }
  // Remaining dim % 16 columns; `result` and `data` now point at the first one.
  for (int d = 0; d < offset; ++d) {
    float sm = data[d];
    for (int i = 1; i < numSamples; ++i) {
      sm = std::max(sm, data[i * stride + d]);
    }
    result[d] = sm;
  }
}
#else
/**
 * In-place element-wise add using AVX: a[j] += b[j] for every j in [0, len).
 *
 * The first 32 * (len / 32) elements are processed four 256-bit registers at
 * a time with aligned loads/stores (_mm256_load_ps/_mm256_store_ps), so `a`
 * and `b` must both be 32-byte aligned. The remaining len % 32 elements are
 * added with scalar code.
 */
static void addto_avx(float* a, const float* b, size_t len) {
  // size_t counters: an `unsigned int` chunk counter compared against the
  // size_t bound `len / 32` would wrap (and never terminate) for huge len.
  const size_t offset = len % 32;
  const size_t chunks = len / 32;
  __m256 ma0, ma1, ma2, ma3;
  __m256 mb0, mb1, mb2, mb3;
  for (size_t k = 0; k < chunks; k++, a += 32, b += 32) {
    // All loads happen before any store, as in a straight element-wise pass.
    ma0 = _mm256_load_ps(a);
    ma1 = _mm256_load_ps(a + 8);
    ma2 = _mm256_load_ps(a + 16);
    ma3 = _mm256_load_ps(a + 24);
    mb0 = _mm256_load_ps(b);
    mb1 = _mm256_load_ps(b + 8);
    mb2 = _mm256_load_ps(b + 16);
    mb3 = _mm256_load_ps(b + 24);
    ma0 = _mm256_add_ps(ma0, mb0);
    ma1 = _mm256_add_ps(ma1, mb1);
    ma2 = _mm256_add_ps(ma2, mb2);
    ma3 = _mm256_add_ps(ma3, mb3);
    _mm256_store_ps(a, ma0);
    _mm256_store_ps(a + 8, ma1);
    _mm256_store_ps(a + 16, ma2);
    _mm256_store_ps(a + 24, ma3);
  }
  // Scalar tail: `a` and `b` now point at the first unprocessed element.
  for (size_t i = 0; i < offset; i++) a[i] += b[i];
}
/**
 * Accumulate `batch` source arrays into `a` using AVX:
 *   a[j] += b[0][j] + b[1][j] + ... + b[batch - 1][j]   for j in [0, len).
 *
 * For each element the sources are added in order b[0], b[1], ... . The first
 * 32 * (len / 32) elements use aligned AVX loads/stores, so `a` and every
 * b[i] must be 32-byte aligned; the last len % 32 elements use scalar code.
 *
 * The pointer array `b` is only read: on return every b[i] still points at the
 * start of its source array, so callers may safely reuse `b`.
 */
static void batch_addto_avx(float* a, const float* b[], int batch, size_t len) {
  const size_t offset = len % 32;
  const size_t vecLen = len - offset;
  for (size_t j = 0; j < vecLen; j += 32) {
    __m256 ma0 = _mm256_load_ps(a + j);
    __m256 ma1 = _mm256_load_ps(a + j + 8);
    __m256 ma2 = _mm256_load_ps(a + j + 16);
    __m256 ma3 = _mm256_load_ps(a + j + 24);
    for (int i = 0; i < batch; i++) {
      // Index with j rather than advancing b[i] in place, so the caller's
      // pointer array is left untouched.
      const float* src = b[i] + j;
      ma0 = _mm256_add_ps(ma0, _mm256_load_ps(src));
      ma1 = _mm256_add_ps(ma1, _mm256_load_ps(src + 8));
      ma2 = _mm256_add_ps(ma2, _mm256_load_ps(src + 16));
      ma3 = _mm256_add_ps(ma3, _mm256_load_ps(src + 24));
    }
    _mm256_store_ps(a + j, ma0);
    _mm256_store_ps(a + j + 8, ma1);
    _mm256_store_ps(a + j + 16, ma2);
    _mm256_store_ps(a + j + 24, ma3);
  }
  // Scalar tail for the last len % 32 elements.
  for (size_t j = vecLen; j < len; j++) {
    for (int k = 0; k < batch; k++) a[j] += b[k][j];
  }
}
/**
 * Column-wise maximum using AVX.
 *
 * `data` holds `numSamples` rows of `dim` floats stored contiguously (row i
 * starts at data + i * dim). On return, result[d] is the maximum of column d
 * over all rows. Row 0 is always read. NOTE(review): numSamples <= 0 is not
 * rejected here; callers presumably pass numSamples >= 1.
 *
 * The first 32 * (dim / 32) columns are processed four 256-bit registers at a
 * time with aligned loads/stores. This requires `result` and every row start
 * (data + i * dim) to be 32-byte aligned. The remaining dim % 32 columns use
 * scalar code.
 */
static void col_max_avx(float* result, const float* data, int dim,
                        int numSamples) {
  // Row stride as size_t: computing `i * dim` in int overflows (undefined
  // behavior) once numSamples * dim exceeds INT_MAX.
  const size_t stride = static_cast<size_t>(dim);
  const int offset = dim % 32;
  // First 32n columns. Row 0 seeds the running maximum directly, so no
  // separate copy pass into `result` is needed.
  for (int k = 0; k < dim / 32; k++, result += 32, data += 32) {
    __m256 ma0 = _mm256_load_ps(data);
    __m256 ma1 = _mm256_load_ps(data + 8);
    __m256 ma2 = _mm256_load_ps(data + 16);
    __m256 ma3 = _mm256_load_ps(data + 24);
    for (int i = 1; i < numSamples; i++) {
      const float* row = data + i * stride;
      ma0 = _mm256_max_ps(ma0, _mm256_load_ps(row));
      ma1 = _mm256_max_ps(ma1, _mm256_load_ps(row + 8));
      ma2 = _mm256_max_ps(ma2, _mm256_load_ps(row + 16));
      ma3 = _mm256_max_ps(ma3, _mm256_load_ps(row + 24));
    }
    _mm256_store_ps(result, ma0);
    _mm256_store_ps(result + 8, ma1);
    _mm256_store_ps(result + 16, ma2);
    _mm256_store_ps(result + 24, ma3);
  }
  // Remaining dim % 32 columns; `result` and `data` now point at the first one.
  for (int d = 0; d < offset; ++d) {
    float sm = data[d];
    for (int i = 1; i < numSamples; ++i) {
      sm = std::max(sm, data[i * stride + d]);
    }
    result[d] = sm;
  }
}
/**
 * L1 decay (soft thresholding) with a uniform threshold, using AVX.
 *
 * Writes, for each i in [0, sz):
 *   dst[i] = src[i] - lambda   if src[i] >  lambda
 *            src[i] + lambda   if src[i] < -lambda
 *            0                 otherwise
 * (this is the scalar-tail formula; the vector path computes
 * max(x - lambda, 0) | min(x + lambda, 0), which agrees when lambda >= 0
 * because then at most one of the two terms is non-zero).
 *
 * Uses aligned 256-bit loads/stores, so `src` and `dst` must be 32-byte
 * aligned.
 */
static void decayL1_avx(float* dst, float* src, float lambda, size_t sz) {
  const int64_t n = sz;
  const __m256 vLambda = _mm256_set1_ps(lambda);
  const __m256 vZero = _mm256_setzero_ps();
  int64_t i = 0;

  // Main loop: two independent 8-lane vectors (16 floats) per iteration.
  for (; i + 16 <= n; i += 16) {
    const __m256 x0 = _mm256_load_ps(src + i);
    const __m256 x1 = _mm256_load_ps(src + i + 8);
    const __m256 above0 = _mm256_max_ps(_mm256_sub_ps(x0, vLambda), vZero);
    const __m256 above1 = _mm256_max_ps(_mm256_sub_ps(x1, vLambda), vZero);
    const __m256 below0 = _mm256_min_ps(_mm256_add_ps(x0, vLambda), vZero);
    const __m256 below1 = _mm256_min_ps(_mm256_add_ps(x1, vLambda), vZero);
    _mm256_store_ps(dst + i, _mm256_or_ps(above0, below0));
    _mm256_store_ps(dst + i + 8, _mm256_or_ps(above1, below1));
  }

  // At most one more full 8-lane vector.
  if (i + 8 <= n) {
    const __m256 x = _mm256_load_ps(src + i);
    const __m256 above = _mm256_max_ps(_mm256_sub_ps(x, vLambda), vZero);
    const __m256 below = _mm256_min_ps(_mm256_add_ps(x, vLambda), vZero);
    _mm256_store_ps(dst + i, _mm256_or_ps(above, below));
    i += 8;
  }

  // Scalar tail (fewer than 8 elements).
  for (; i < n; ++i) {
    const float v = src[i];
    float shrunk;
    if (v > 0) {
      shrunk = (v > lambda) ? (v - lambda) : 0;
    } else {
      shrunk = (-v > lambda) ? (v + lambda) : 0;
    }
    dst[i] = shrunk;
  }
}
/**
 * L1 decay (soft thresholding) with a per-element threshold, using AVX.
 *
 * Each element uses the threshold t = lr[i] * lambda and writes:
 *   dst[i] = src[i] - t   if src[i] >  t
 *            src[i] + t   if src[i] < -t
 *            0            otherwise
 * (this is the scalar-tail formula; the vector path computes
 * max(x - t, 0) | min(x + t, 0), which agrees when t >= 0).
 *
 * Uses aligned 256-bit loads/stores, so `src`, `lr` and `dst` must be 32-byte
 * aligned.
 */
static void decayL1_avx(float* dst, float* src, float* lr, float lambda,
                        size_t sz) {
  const int64_t n = sz;
  const __m256 vLambda = _mm256_set1_ps(lambda);
  const __m256 vZero = _mm256_setzero_ps();
  int64_t i = 0;

  // Main loop: two independent 8-lane vectors (16 floats) per iteration.
  for (; i + 16 <= n; i += 16) {
    const __m256 r0 = _mm256_load_ps(lr + i);
    const __m256 r1 = _mm256_load_ps(lr + i + 8);
    const __m256 x0 = _mm256_load_ps(src + i);
    const __m256 x1 = _mm256_load_ps(src + i + 8);
    const __m256 t0 = _mm256_mul_ps(r0, vLambda);
    const __m256 t1 = _mm256_mul_ps(r1, vLambda);
    const __m256 above0 = _mm256_max_ps(_mm256_sub_ps(x0, t0), vZero);
    const __m256 above1 = _mm256_max_ps(_mm256_sub_ps(x1, t1), vZero);
    const __m256 below0 = _mm256_min_ps(_mm256_add_ps(x0, t0), vZero);
    const __m256 below1 = _mm256_min_ps(_mm256_add_ps(x1, t1), vZero);
    _mm256_store_ps(dst + i, _mm256_or_ps(above0, below0));
    _mm256_store_ps(dst + i + 8, _mm256_or_ps(above1, below1));
  }

  // At most one more full 8-lane vector.
  if (i + 8 <= n) {
    const __m256 x = _mm256_load_ps(src + i);
    const __m256 t = _mm256_mul_ps(_mm256_load_ps(lr + i), vLambda);
    const __m256 above = _mm256_max_ps(_mm256_sub_ps(x, t), vZero);
    const __m256 below = _mm256_min_ps(_mm256_add_ps(x, t), vZero);
    _mm256_store_ps(dst + i, _mm256_or_ps(above, below));
    i += 8;
  }

  // Scalar tail (fewer than 8 elements).
  for (; i < n; ++i) {
    const float v = src[i];
    const float thr = lr[i] * lambda;
    float shrunk;
    if (v > 0) {
      shrunk = (v > thr) ? (v - thr) : 0;
    } else {
      shrunk = (-v > thr) ? (v + thr) : 0;
    }
    dst[i] = shrunk;
  }
}
#endif
// SIMD_INVOKE(func, args...) calls the kernel variant matching the build's
// instruction set: func##_avx(args...) when __AVX__ is defined, otherwise
// func##_sse(args...).
#ifdef __AVX__
#define SIMD_INVOKE(func, ...) func##_avx(__VA_ARGS__)
#else
#define SIMD_INVOKE(func, ...) func##_sse(__VA_ARGS__)
#endif
namespace paddle {
namespace simd {
namespace internal {

// Each entry point below forwards to the file-local kernel for the
// instruction set chosen at compile time (SSE or AVX via SIMD_INVOKE).

/// a[j] += b[j] for j in [0, len).
void addToImpl(float* a, const float* b, size_t len) {
  return SIMD_INVOKE(addto, a, b, len);
}

/// a[j] += b[0][j] + ... + b[batch - 1][j] for j in [0, len).
void batchAddToImpl(float* a, const float* b[], int batch, size_t len) {
  return SIMD_INVOKE(batch_addto, a, b, batch, len);
}

/// result[d] = max over the numSamples rows of column d of `data`.
void colMaxImpl(float* result, const float* data, int dim, int numSamples) {
  return SIMD_INVOKE(col_max, result, data, dim, numSamples);
}

#if defined(__AVX__)
/// L1 decay (soft thresholding) of `src` into `dst` with threshold `lambda`.
void decayL1AvxImpl(float* dst, float* src, float lambda, size_t len) {
  return decayL1_avx(dst, src, lambda, len);
}

/// L1 decay with a per-element threshold lr[j] * lambda.
void decayL1AvxImpl(float* dst, float* src, float* lr, float lambda,
                    size_t len) {
  return decayL1_avx(dst, src, lr, lambda, len);
}
#endif

}  // namespace internal
}  // namespace simd
}  // namespace paddle