@@ -23,7 +23,7 @@ constexpr int NUM_PER_THREAD_REDUCE = 4;
 constexpr int WARP_SIZE = 32;
 
 template <typename T>
-inline __device__ void MeanAndVarAccumulation(T* mean, T* var, T* num, const T& val) {
+inline __device__ void MeanAndVarAccumulation(T *mean, T *var, T *num, const T &val) {
   // Welford Algorithm:
   // \mu_k = \mu_{k-1} + (x_k - \mu_{k-1})/k
   // \sigma_k^2 = \sigma_{k-1}^2 + (x_k - \mu_{k-1}) * (x_k - \mu_k)
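The Welford comments above give the running-mean and squared-deviation updates. As a hedged, host-side illustration (the accumulation body itself is elided from this hunk), the update can be written as below; note that the quantity the kernel carries as `var` behaves like the running sum of squared deviations, since `BlockReduce` only divides `share_mem[1]` by `col_dim` at the very end. Names such as `WelfordAccumulate` are illustrative, not from this file.

#include <cstdio>

// Hedged host-side sketch of the Welford update described by the comments;
// "m2" is the running sum of squared deviations, divided by the count only
// when the final variance is needed.
void WelfordAccumulate(float *mean, float *m2, float *count, float val) {
  *count += 1.0f;
  float delta = val - *mean;      // x_k - mu_{k-1}
  *mean += delta / *count;        // mu_k = mu_{k-1} + (x_k - mu_{k-1}) / k
  *m2 += delta * (val - *mean);   // M2_k = M2_{k-1} + (x_k - mu_{k-1}) * (x_k - mu_k)
}

int main() {
  float data[] = {1.0f, 2.0f, 3.0f, 4.0f};
  float mean = 0.0f, m2 = 0.0f, count = 0.0f;
  for (float v : data) WelfordAccumulate(&mean, &m2, &count, v);
  printf("mean=%f var=%f\n", mean, m2 / count);  // mean = 2.5, population variance = 1.25
  return 0;
}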
@@ -34,7 +34,7 @@ inline __device__ void MeanAndVarAccumulation(T* mean, T* var, T* num, const T&
 }
 
 template <typename T>
-inline __device__ void MeanAndVarMerge(T* m1, T* v1, T* n1, const T& m2, const T& v2, const T& n2) {
+inline __device__ void MeanAndVarMerge(T *m1, T *v1, T *n1, const T &m2, const T &v2, const T &n2) {
   if (n2 == 0) {
     return;
   }
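The body of MeanAndVarMerge is elided from this hunk apart from the `n2 == 0` early return. Combining two partial (mean, squared-deviation sum, count) triples is normally done with Chan et al.'s pairwise formula; a hedged sketch with illustrative names:

// Hedged sketch of the standard pairwise combination (Chan et al.); the real
// MeanAndVarMerge body is not shown in this diff, so treat this as an assumption.
void MergeMeanAndM2(float *m1, float *v1, float *n1, float m2, float v2, float n2) {
  if (n2 == 0.0f) {
    return;  // mirrors the early return in the hunk above
  }
  float count = *n1 + n2;
  float delta = m2 - *m1;
  *v1 += v2 + delta * delta * (*n1) * n2 / count;  // combined sum of squared deviations
  *m1 += delta * n2 / count;                       // count-weighted mean
  *n1 = count;
}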
@@ -46,7 +46,7 @@ inline __device__ void MeanAndVarMerge(T* m1, T* v1, T* n1, const T& m2, const T
 }
 
 template <typename T>
-inline __device__ void ThreadReduce(const int& col_dim, const T* block_addr, T* mean, T* var, T* num) {
+inline __device__ void ThreadReduce(const int &col_dim, const T *block_addr, T *mean, T *var, T *num) {
   int loop_num = (col_dim + NUM_PER_THREAD_REDUCE - 1) / NUM_PER_THREAD_REDUCE;
   for (int i = threadIdx.x; i < loop_num; i += blockDim.x) {
     for (int j = 0; j < NUM_PER_THREAD_REDUCE; j++) {
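Only the loop headers of ThreadReduce survive in this hunk. The usual reading is that thread `threadIdx.x` handles chunks of NUM_PER_THREAD_REDUCE consecutive elements, strides over chunks by blockDim.x, and stops once the flattened index passes col_dim; the lines below are a hedged guess at that elided inner body.

// Hedged sketch of the elided inner body: map (i, j) to a flat column index,
// stop at the end of the row, and fold the element into the running statistics.
//   int pos = NUM_PER_THREAD_REDUCE * i + j;
//   if (pos >= col_dim) {
//     return;
//   }
//   MeanAndVarAccumulation(mean, var, num, block_addr[pos]);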
@@ -60,7 +60,7 @@ inline __device__ void ThreadReduce(const int& col_dim, const T* block_addr, T*
 }
 
 template <typename T>
-inline __device__ void WarpReduce(T* mean, T* var, T* num) {
+inline __device__ void WarpReduce(T *mean, T *var, T *num) {
   for (int delta = (WARP_SIZE >> 1); delta > 0; delta >>= 1) {
     T mean_other = __shfl_down_sync(0xffffffff, mean[0], delta);
     T var_other = __shfl_down_sync(0xffffffff, var[0], delta);
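The shuffle loop halves `delta` from 16 down to 1, each time pulling the partial statistics of the lane `delta` positions above and, presumably, folding them in with MeanAndVarMerge, so lane 0 ends up holding the whole warp's result. A minimal, self-contained warp-sum sketch using the same primitive:

// Minimal warp-level sum using the same __shfl_down_sync pattern as WarpReduce;
// after the loop, lane 0 of the warp holds the sum over all 32 lanes.
// kWarpSize is an illustrative stand-in for WARP_SIZE.
constexpr int kWarpSize = 32;

__device__ float WarpSum(float val) {
  for (int delta = kWarpSize >> 1; delta > 0; delta >>= 1) {
    val += __shfl_down_sync(0xffffffff, val, delta);  // add the partner lane's partial sum
  }
  return val;  // meaningful in lane 0
}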
@@ -70,8 +70,8 @@ inline __device__ void WarpReduce(T* mean, T* var, T* num) {
 }
 
 template <typename T>
-inline __device__ void BlockReduce(const int& col_dim, T* mean, T* var, T* num, T* mean_addr, T* var_addr,
-                                   T* share_mem) {
+inline __device__ void BlockReduce(const int &col_dim, T *mean, T *var, T *num, T *mean_addr, T *var_addr,
+                                   T *share_mem) {
   if (threadIdx.x >= col_dim) {
     return;
   }
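The middle of BlockReduce is not part of this hunk. A common layout, and one consistent with the launcher's "keep the mean/var/num after warp reduce" comment and with the share_mem[0]/share_mem[1] reads in the next hunk, is: each warp's lane 0 parks its (mean, m2, count) triple in dynamic shared memory, then one thread folds the per-warp triples into slot 0. The sketch below assumes that layout.

// Hedged sketch of the presumed staging step; the 3-value-per-warp layout and
// the serial fold by thread 0 are assumptions, not taken from this file.
__device__ void StageAndFoldSketch(float mean, float m2, float num, float *share_mem) {
  const int kWarp = 32;
  const int warp_id = threadIdx.x / kWarp;
  const int lane_id = threadIdx.x % kWarp;
  if (lane_id == 0) {  // one (mean, m2, count) triple per warp
    share_mem[warp_id * 3] = mean;
    share_mem[warp_id * 3 + 1] = m2;
    share_mem[warp_id * 3 + 2] = num;
  }
  __syncthreads();
  if (threadIdx.x == 0) {  // fold every warp's triple into share_mem[0..2]
    const int num_warps = (blockDim.x + kWarp - 1) / kWarp;
    for (int w = 1; w < num_warps; w++) {
      float n2 = share_mem[w * 3 + 2];
      if (n2 == 0.0f) continue;
      float count = share_mem[2] + n2;
      float delta = share_mem[w * 3] - share_mem[0];
      share_mem[1] += share_mem[w * 3 + 1] + delta * delta * share_mem[2] * n2 / count;
      share_mem[0] += delta * n2 / count;
      share_mem[2] = count;
    }
  }
}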
@@ -96,15 +96,15 @@ inline __device__ void BlockReduce(const int& col_dim, T* mean, T* var, T* num,
   __syncthreads();
 
   if (threadIdx.x == 0) {
-    mean_addr[blockIdx.x] = share_mem[0]; // todo: blockDim.x < row
+    mean_addr[blockIdx.x] = share_mem[0];
     share_mem[1] /= col_dim;
     var_addr[blockIdx.x] = share_mem[1];
   }
 }
 
 template <typename T>
-inline __device__ void LayerNorm(const int& row, const int& col_dim, const int& param_dim, const T* x,
-                                 const T* share_mem, const T* gamma, const T* beta, const T epsilon, T* y) {
+inline __device__ void LayerNorm(const int &row, const int &col_dim, const int &param_dim, const T *x,
+                                 const T *share_mem, const T *gamma, const T *beta, const T epsilon, T *y) {
   for (int col = threadIdx.x; col < col_dim; col += blockDim.x) {
     int pos = row * col_dim + col;
     int i = pos % param_dim;
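The per-element body of the device LayerNorm is elided here. With the row mean in share_mem[0] and the already-divided variance in share_mem[1], the standard transform it presumably applies is y = gamma * (x - mean) / sqrt(var + epsilon) + beta; a hedged per-element sketch:

// Hedged sketch of the standard LayerNorm transform for one element; whether
// the real code uses sqrt, rsqrt, or another rearrangement is an assumption.
__device__ void NormalizeOneElement(int pos, int i, const float *x, const float *share_mem,
                                    const float *gamma, const float *beta, float epsilon, float *y) {
  float mean = share_mem[0];
  float var = share_mem[1];
  y[pos] = (x[pos] - mean) / sqrtf(var + epsilon) * gamma[i] + beta[i];
}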
@@ -113,13 +113,13 @@ inline __device__ void LayerNorm(const int& row, const int& col_dim, const int&
 }
 
 template <typename T>
-__global__ void LayerNormKernel(const int row_dim, const int col_dim, const int param_dim, const T epsilon, const T* x,
-                                const T* gamma, const T* beta, T* y, T* mean_addr, T* var_addr) {
+__global__ void LayerNormKernel(const int row_dim, const int col_dim, const int param_dim, const T epsilon, const T *x,
+                                const T *gamma, const T *beta, T *y, T *mean_addr, T *var_addr) {
   for (auto row = blockIdx.x; row < row_dim; row += gridDim.x) {
     T mean = 0;
     T var = 0;
     T num = 0;
-    const T* block_addr = x + row * col_dim;
+    const T *block_addr = x + row * col_dim;
     extern __shared__ T share_mem[];
 
     ThreadReduce(col_dim, block_addr, &mean, &var, &num);
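The rest of the kernel's row loop is outside this hunk; given the helpers above, it presumably continues with the reduction pipeline and the normalization, roughly as sketched below. The exact calls and their order are an assumption.

//   WarpReduce(&mean, &var, &num);
//   BlockReduce(col_dim, &mean, &var, &num, mean_addr, var_addr, share_mem);
//   __syncthreads();
//   LayerNorm(row, col_dim, param_dim, x, share_mem, gamma, beta, epsilon, y);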
@@ -132,8 +132,8 @@ __global__ void LayerNormKernel(const int row_dim, const int col_dim, const int
 }
 
 template <typename T>
-void LayerNorm(const int& row_dim, const int& col_dim, const int& param_dim, const T& epsilon, const T* x,
-               const T* gamma, const T* beta, T* y, T* mean, T* var, cudaStream_t stream) {
+void LayerNorm(const int &row_dim, const int &col_dim, const int &param_dim, const T &epsilon, const T *x,
+               const T *gamma, const T *beta, T *y, T *mean, T *var, cudaStream_t stream) {
   const dim3 block(row_dim);
   const dim3 thread(256);
   // keep the mean/var/num after warp reduce
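The elided tail of the launcher computes the dynamic shared-memory size and launches the kernel (the `var);` fragment in the next hunk is the end of that call). Given the "keep the mean/var/num after warp reduce" comment, the size is presumably three values per warp; a hedged reconstruction:

//   const int share_mem_size = thread.x / WARP_SIZE * 3 * sizeof(T);  // assumed sizing
//   LayerNormKernel<<<block, thread, share_mem_size, stream>>>(row_dim, col_dim, param_dim, epsilon,
//                                                              x, gamma, beta, y, mean, var);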
@@ -143,6 +143,6 @@ void LayerNorm(const int& row_dim, const int& col_dim, const int& param_dim, con
                                                              var);
 }
 
-template void LayerNorm(const int& row_dim, const int& col_dim, const int& param_dim, const float& epsilon,
-                        const float* x, const float* gamma, const float* beta, float* y, float* mean, float* var,
+template void LayerNorm(const int &row_dim, const int &col_dim, const int &param_dim, const float &epsilon,
+                        const float *x, const float *gamma, const float *beta, float *y, float *mean, float *var,
                         cudaStream_t stream);
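For context, a hedged host-side usage sketch for the float instantiation above, assuming the template declaration is visible to the caller. Buffer contents, the epsilon value, and the assumption that param_dim equals col_dim are illustrative, and error checking is omitted.

#include <cuda_runtime.h>

void RunLayerNormExample() {
  const int row_dim = 8, col_dim = 1024, param_dim = 1024;  // param_dim == col_dim assumed
  const float epsilon = 1e-7f;                              // illustrative value
  float *x, *gamma, *beta, *y, *mean, *var;
  cudaMalloc(reinterpret_cast<void **>(&x), row_dim * col_dim * sizeof(float));
  cudaMalloc(reinterpret_cast<void **>(&gamma), param_dim * sizeof(float));
  cudaMalloc(reinterpret_cast<void **>(&beta), param_dim * sizeof(float));
  cudaMalloc(reinterpret_cast<void **>(&y), row_dim * col_dim * sizeof(float));
  cudaMalloc(reinterpret_cast<void **>(&mean), row_dim * sizeof(float));  // one mean per row
  cudaMalloc(reinterpret_cast<void **>(&var), row_dim * sizeof(float));   // one variance per row
  cudaStream_t stream;
  cudaStreamCreate(&stream);
  // ... fill x, gamma, beta on the device ...
  LayerNorm(row_dim, col_dim, param_dim, epsilon, x, gamma, beta, y, mean, var, stream);
  cudaStreamSynchronize(stream);
  cudaStreamDestroy(stream);
  // cudaFree calls omitted for brevity
}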