@@ -18,10 +18,21 @@
 #include <stdint.h>
 #include <cuda_runtime.h>
 #include "kernel/gpu/cuda_impl/layer_norm_grad_impl.cuh"
+#include "kernel/gpu/cuda_impl/layer_norm_impl.cuh"

 constexpr int NUM_PER_THREAD_REDUCE = 4;
 constexpr int WARP_SIZE = 32;

+template <typename T>
+inline __device__ T my_pow(T a, double b) {
+  return pow(a, static_cast<float>(b));
+}
+
+template <>
+inline __device__ half my_pow(half a, double b) {
+  return __float2half(pow(__half2float(a), static_cast<float>(b)));
+}
+
 template <typename T>
 inline __device__ void GammaAndBetaThreadReduce(const int& col, const int& row_dim, const int& col_dim,
                                                 const T& epsilon, const T* dy, const T* x, const T* mean, const T* var,
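
The generic my_pow added above resolves to the device pow(float, float) overload, while the half specialization is needed because device code has no pow overload for half, so the value round-trips through float. A minimal sketch of the two paths (not part of the patch; the demo kernel and its names are hypothetical):

    #include <cuda_fp16.h>

    __global__ void MyPowDemo(const float* fin, float* fout, const half* hin, half* hout) {
      int i = blockIdx.x * blockDim.x + threadIdx.x;
      fout[i] = my_pow(fin[i], -0.5);  // generic template: pow(float, float)
      hout[i] = my_pow(hin[i], -0.5);  // specialization: half -> float -> pow -> half
    }
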
@@ -35,7 +46,7 @@ inline __device__ void GammaAndBetaThreadReduce(const int& col, const int& row_d
       }

       int pos = row * col_dim + col;
-      dg[0] += dy[pos] * pow(var[row] + epsilon, -0.5) * (x[pos] - mean[row]);
+      dg[0] += dy[pos] * my_pow(var[row] + epsilon, -0.5) * (x[pos] - mean[row]);
       db[0] += dy[pos];
     }
   }
@@ -58,26 +69,26 @@ inline __device__ void GammaAndBetaBlockReduce(const int& col, const int& row_di
   // load data to share memory
   // thread(0, 32, 64, 96, ...) keep the data
-  extern __shared__ T share_mem[];
+  DynamicSharedMem<T> share_mem;
   if (threadIdx.x % WARP_SIZE == 0) {
     int offset = threadIdx.x / WARP_SIZE * 2;
-    share_mem[offset] = dg[0];
-    share_mem[offset + 1] = db[0];
+    share_mem.addr()[offset] = dg[0];
+    share_mem.addr()[offset + 1] = db[0];
   }
   __syncthreads();

   for (int stride = blockDim.x / WARP_SIZE / 2; stride > 0; stride >>= 1) {
     if (threadIdx.x < stride) {
       int offset = (threadIdx.x + stride) * 2;
-      share_mem[threadIdx.x * 2] += share_mem[offset];
-      share_mem[threadIdx.x * 2 + 1] += share_mem[offset + 1];
+      share_mem.addr()[threadIdx.x * 2] += share_mem.addr()[offset];
+      share_mem.addr()[threadIdx.x * 2 + 1] += share_mem.addr()[offset + 1];
     }
   }
   __syncthreads();

   if (threadIdx.x == 0) {
-    dg_addr[col] = share_mem[0];
-    db_addr[col] = share_mem[1];
+    dg_addr[col] = share_mem.addr()[0];
+    db_addr[col] = share_mem.addr()[1];
   }
 }

 template <typename T>
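
DynamicSharedMem comes in via the newly included layer_norm_impl.cuh; its definition is not part of this diff. The move away from a raw extern __shared__ T share_mem[] matters once these kernels are instantiated for both float and half: nvcc then sees the same dynamic shared-memory symbol redeclared with conflicting types. The usual idiom such a wrapper captures is to declare one untyped buffer and cast per instantiation, roughly (a sketch only, not the actual MindSpore definition):

    template <typename T>
    struct DynamicSharedMemSketch {
      __device__ T* addr() {
        extern __shared__ unsigned char raw[];  // one untyped symbol shared by all instantiations
        return reinterpret_cast<T*>(raw);
      }
    };
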
@@ -114,13 +125,37 @@ inline __device__ void InputThreadReduce(const int& row, const int& col_dim, con
       T v1 = dy[pos] * gamma[gamma_offset];
       T v2 = x[pos] - mean[row];

-      sum1[0] += -0.5 * v1 * v2 * pow(var[row] + epsilon, -1.5);
+      sum1[0] += -0.5 * v1 * v2 * my_pow(var[row] + epsilon, -1.5);
       sum2[0] += v1;
       sum3[0] += -2.0 * v2;
     }
   }
 }

+template <>
+inline __device__ void InputThreadReduce(const int& row, const int& col_dim, const int& param_dim, const half& epsilon,
+                                         half* sum1, half* sum2, half* sum3, const half* dy, const half* x,
+                                         const half* mean, const half* var, const half* gamma) {
+  int loop_num = (col_dim + NUM_PER_THREAD_REDUCE - 1) / NUM_PER_THREAD_REDUCE;
+  for (int i = threadIdx.x; i < loop_num; i += blockDim.x) {
+    for (int j = 0; j < NUM_PER_THREAD_REDUCE; j++) {
+      int col = NUM_PER_THREAD_REDUCE * i + j;
+      if (col >= col_dim) {
+        return;
+      }
+
+      int pos = row * col_dim + col;
+      int gamma_offset = pos % param_dim;
+      half v1 = dy[pos] * gamma[gamma_offset];
+      half v2 = x[pos] - mean[row];
+
+      sum1[0] += __float2half(-0.5) * v1 * v2 * my_pow(var[row] + epsilon, -1.5);
+      sum2[0] += v1;
+      sum3[0] += __float2half(-2.0) * v2;
+    }
+  }
+}
+
 template <typename T>
 inline __device__ void InputWarpReduce(T* sum1, T* sum2, T* sum3) {
   for (int delta = (WARP_SIZE >> 1); delta > 0; delta >>= 1) {
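
For reference, the three partial sums reduced here are the standard layer-norm backward terms over a row of length N = col_dim; the derivation below is implied by the code rather than stated in the patch. With mu = mean[row] and sigma^2 = var[row]:

    \mathrm{sum1} = \frac{\partial L}{\partial \sigma^2}
                  = \sum_i dy_i \, \gamma_i \, (x_i - \mu) \cdot \left(-\tfrac{1}{2}\right) (\sigma^2 + \epsilon)^{-3/2}
    \mathrm{sum2} = \sum_i dy_i \, \gamma_i
    \mathrm{sum3} = \sum_i -2 \, (x_i - \mu)
    \frac{\partial L}{\partial \mu} = -(\sigma^2 + \epsilon)^{-1/2} \, \mathrm{sum2}
                                      + \frac{\partial L}{\partial \sigma^2} \cdot \frac{\mathrm{sum3}}{N}
    dx_i = dy_i \, \gamma_i \, (\sigma^2 + \epsilon)^{-1/2}
           + \frac{\partial L}{\partial \sigma^2} \cdot \frac{2 (x_i - \mu)}{N}
           + \frac{1}{N} \, \frac{\partial L}{\partial \mu}

InputProp in the next hunk consumes the block-reduced values as share_mem[0] = sum1, share_mem[1] = sum2, share_mem[2] = sum3, and its dx expression is exactly the last line above.
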
@@ -166,12 +201,28 @@ inline __device__ void InputProp(const int& row, const int& col_dim, const int&
     int gamma_offset = pos % param_dim;
     T v1 = dy[pos] * gamma[gamma_offset];
     T v2 = x[pos] - mean[row];
-    T v3 = pow(var[row] + epsilon, -0.5);
+    T v3 = my_pow(var[row] + epsilon, -0.5);
     dx[pos] = v1 * v3 + share_mem[0] * (2.0 / col_dim) * v2 +
               (-1.0 * v3 * share_mem[1] + (1.0 / col_dim) * share_mem[0] * share_mem[2]) * (1.0 / col_dim);
   }
 }

+template <>
+inline __device__ void InputProp(const int& row, const int& col_dim, const int& param_dim, const half& epsilon,
+                                 const half* dy, const half* x, const half* mean, const half* var, const half* gamma,
+                                 half* dx, const half* share_mem) {
+  for (int col = threadIdx.x; col < col_dim; col += blockDim.x) {
+    int pos = (row * col_dim + col);
+    int gamma_offset = pos % param_dim;
+    half v1 = dy[pos] * gamma[gamma_offset];
+    half v2 = x[pos] - mean[row];
+    half v3 = my_pow(var[row] + epsilon, -0.5);
+    dx[pos] = v1 * v3 + share_mem[0] * __float2half(2.0 / col_dim) * v2 +
+              (__float2half(-1.0) * v3 * share_mem[1] + __float2half(1.0 / col_dim) * share_mem[0] * share_mem[2]) *
+                __float2half(1.0 / col_dim);
+  }
+}
+
 template <typename T>
 __global__ void InputPropKernel(const int row_dim, const int col_dim, const int param_dim, const T epsilon, const T* dy,
                                 const T* x, const T* mean, const T* var, const T* gamma, T* dx) {
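
A note on the half overloads above (editorial, not from the patch): every scalar constant is wrapped in __float2half(...), the cuda_fp16.h float-to-half conversion intrinsic, so each subexpression stays in half precision and uses the half operator overloads provided by cuda_fp16.h; the same pattern appears with __float2half(-0.5) and __float2half(-2.0) in the InputThreadReduce specialization.
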
@@ -179,27 +230,30 @@ __global__ void InputPropKernel(const int row_dim, const int col_dim, const int
     T sum1 = 0;
     T sum2 = 0;
     T sum3 = 0;
-    extern __shared__ T share_mem[];
+    DynamicSharedMem<T> share_mem;
     InputThreadReduce(row, col_dim, param_dim, epsilon, &sum1, &sum2, &sum3, dy, x, mean, var, gamma);
     InputWarpReduce(&sum1, &sum2, &sum3);
-    InputBlockReduce(col_dim, &sum1, &sum2, &sum3, share_mem);
-    InputProp(row, col_dim, param_dim, epsilon, dy, x, mean, var, gamma, dx, share_mem);
+    InputBlockReduce(col_dim, &sum1, &sum2, &sum3, share_mem.addr());
+    InputProp(row, col_dim, param_dim, epsilon, dy, x, mean, var, gamma, dx, share_mem.addr());
   }
 }

 template <typename T>
 void LayerNormGrad(const int& row_dim, const int& col_dim, const int& param_dim, const T& epsilon, const T* dy,
                    const T* x, const T* mean, const T* var, const T* gamma, T* dx, T* dg, T* db, cudaStream_t stream) {
-  int share_mem =
+  int share_mem_size =
     ((col_dim + NUM_PER_THREAD_REDUCE - 1) / NUM_PER_THREAD_REDUCE + WARP_SIZE - 1) / WARP_SIZE * 3 * sizeof(T);
-  InputPropKernel<<<row_dim, 256, share_mem, stream>>>(row_dim, col_dim, param_dim, epsilon, dy, x, mean, var, gamma,
-                                                       dx);
+  InputPropKernel<<<row_dim, 256, share_mem_size, stream>>>(row_dim, col_dim, param_dim, epsilon, dy, x, mean, var,
+                                                            gamma, dx);

-  share_mem =
+  share_mem_size =
     ((row_dim + NUM_PER_THREAD_REDUCE - 1) / NUM_PER_THREAD_REDUCE + WARP_SIZE - 1) / WARP_SIZE * 2 * sizeof(T);
-  GammaAndBetaPropKernel<<<col_dim, 256, share_mem, stream>>>(row_dim, col_dim, epsilon, dy, x, mean, var, dg, db);
+  GammaAndBetaPropKernel<<<col_dim, 256, share_mem_size, stream>>>(row_dim, col_dim, epsilon, dy, x, mean, var, dg, db);
 }

 template void LayerNormGrad(const int& row_dim, const int& col_dim, const int& param_dim, const float& epsilon,
                             const float* dy, const float* x, const float* mean, const float* var, const float* gamma,
                             float* dx, float* dg, float* db, cudaStream_t stream);
+template void LayerNormGrad(const int& row_dim, const int& col_dim, const int& param_dim, const half& epsilon,
+                            const half* dy, const half* x, const half* mean, const half* var, const half* gamma,
+                            half* dx, half* dg, half* db, cudaStream_t stream);
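
To make the launch contract concrete, a hypothetical host-side caller (shapes, epsilon, and all d_* names below are illustrative, not from the patch; input buffers would be filled before the call): a [32 x 768] fp32 tensor normalized over the last axis with per-element gamma (param_dim == col_dim). The dynamic shared memory the launcher computes for InputPropKernel works out to ((768 + 3) / 4 + 31) / 32 * 3 * sizeof(float) = 6 * 3 * 4 = 72 bytes.

    #include <cuda_runtime.h>

    int main() {
      const int rows = 32, cols = 768, params = 768;
      float *d_dy, *d_x, *d_mean, *d_var, *d_gamma, *d_dx, *d_dg, *d_db;
      cudaMalloc(&d_dy, rows * cols * sizeof(float));  // upstream gradient
      cudaMalloc(&d_x, rows * cols * sizeof(float));   // forward input
      cudaMalloc(&d_dx, rows * cols * sizeof(float));  // input gradient (output)
      cudaMalloc(&d_mean, rows * sizeof(float));       // per-row mean from the forward pass
      cudaMalloc(&d_var, rows * sizeof(float));        // per-row variance from the forward pass
      cudaMalloc(&d_gamma, params * sizeof(float));
      cudaMalloc(&d_dg, params * sizeof(float));       // gamma gradient (output)
      cudaMalloc(&d_db, params * sizeof(float));       // beta gradient (output)

      cudaStream_t stream;
      cudaStreamCreate(&stream);
      LayerNormGrad(rows, cols, params, 1e-12f, d_dy, d_x, d_mean, d_var, d_gamma, d_dx, d_dg, d_db, stream);
      cudaStreamSynchronize(stream);
      return 0;
    }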