@@ -63,10 +63,6 @@ inline __device__ void GammaAndBetaWarpReduce(T* dg, T* db) {
 template <typename T>
 inline __device__ void GammaAndBetaBlockReduce(const int &col, const int &row_dim, T *dg, T *db, T *dg_addr,
                                                T *db_addr) {
-  if (threadIdx.x >= row_dim) {
-    return;
-  }
-
   // load data to share memory
   // thread(0, 32, 64, 96, ...) keep the data
   DynamicSharedMem<T> share_mem;
@@ -167,10 +163,6 @@ inline __device__ void InputWarpReduce(T* sum1, T* sum2, T* sum3) {
 
 template <typename T>
 inline __device__ void InputBlockReduce(const int &col_dim, T *sum1, T *sum2, T *sum3, T *share_mem) {
-  if (threadIdx.x >= col_dim) {
-    return;
-  }
-
   // load data to share memory
   // thread(0, 32, 64, 96, ...) keep the data
   if (threadIdx.x % WARP_SIZE == 0) {
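
Aside on the two hunks above: both drop an early-return guard from a block-reduce helper. A plausible reading, hedged: with the fixed 256-thread launch introduced in the last hunk of this diff, every thread takes part in the strided accumulation and reduction, and returning from only part of a block before a block-wide barrier such as __syncthreads() is undefined behavior in CUDA. A minimal sketch of that barrier hazard; the kernel below is hypothetical, not code from this file:

// Hypothetical kernel showing why an early return ahead of __syncthreads()
// is unsafe; illustration only, not taken from this diff.
__global__ void GuardedBlockSum(const float *in, float *out, int n) {
  extern __shared__ float buf[];
  if (threadIdx.x >= n) {
    return;  // BUG: exited threads never reach the barrier below, so the
             // surviving threads can wait at __syncthreads() forever.
  }
  buf[threadIdx.x] = in[threadIdx.x];
  __syncthreads();  // block-wide barrier: every thread of the block must arrive
  if (threadIdx.x == 0) {
    float sum = 0.0f;
    for (int i = 0; i < n; ++i) {
      sum += buf[i];
    }
    out[0] = sum;
  }
}
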
@@ -218,8 +210,8 @@ inline __device__ void InputProp(const int& row, const int& col_dim, const int&
     half v2 = x[pos] - mean[row];
     half v3 = my_pow(var[row] + epsilon, -0.5);
     dx[pos] = v1 * v3 + share_mem[0] * __float2half(2.0 / col_dim) * v2 +
-              (__float2half(-1.0) * v3 * share_mem[1] + __float2half(1.0 / col_dim) * share_mem[0] * share_mem[2])\
-              * __float2half(1.0 / col_dim);
+              (__float2half(-1.0) * v3 * share_mem[1] + __float2half(1.0 / col_dim) * share_mem[0] * share_mem[2]) *
+                __float2half(1.0 / col_dim);
   }
 }
 
@@ -241,14 +233,14 @@ __global__ void InputPropKernel(const int row_dim, const int col_dim, const int
 template <typename T>
 void LayerNormGrad(const int &row_dim, const int &col_dim, const int &param_dim, const T &epsilon, const T *dy,
                    const T *x, const T *mean, const T *var, const T *gamma, T *dx, T *dg, T *db, cudaStream_t stream) {
-  int share_mem_size =
-    ((col_dim + NUM_PER_THREAD_REDUCE - 1) / NUM_PER_THREAD_REDUCE + WARP_SIZE - 1) / WARP_SIZE * 3 * sizeof(T);
-  InputPropKernel<<<row_dim, 256, share_mem_size, stream>>>(row_dim, col_dim, param_dim, epsilon, dy, x, mean, var,
-                                                            gamma, dx);
-
-  share_mem_size =
-    ((row_dim + NUM_PER_THREAD_REDUCE - 1) / NUM_PER_THREAD_REDUCE + WARP_SIZE - 1) / WARP_SIZE * 2 * sizeof(T);
-  GammaAndBetaPropKernel<<<col_dim, 256, share_mem_size, stream>>>(row_dim, col_dim, epsilon, dy, x, mean, var, dg, db);
+  const int thread_per_block = 256;
+  int share_mem_size = thread_per_block / WARP_SIZE * 3 * sizeof(T);
+  InputPropKernel<<<row_dim, thread_per_block, share_mem_size, stream>>>(row_dim, col_dim, param_dim, epsilon, dy, x,
+                                                                         mean, var, gamma, dx);
+
+  share_mem_size = thread_per_block / WARP_SIZE * 2 * sizeof(T);
+  GammaAndBetaPropKernel<<<col_dim, thread_per_block, share_mem_size, stream>>>(row_dim, col_dim, epsilon, dy, x, mean,
+                                                                                var, dg, db);
 }
 
 template void LayerNormGrad(const int &row_dim, const int &col_dim, const int &param_dim, const float &epsilon,
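
On the new shared-memory sizing in the last hunk: with a fixed thread_per_block of 256, each warp contributes one partial result per reduced quantity, so the dynamic shared buffer needs thread_per_block / WARP_SIZE slots per quantity (three running sums for InputPropKernel, two for dgamma/dbeta in GammaAndBetaPropKernel), instead of the old formula that scaled with col_dim or row_dim. A standalone host-side restatement of the arithmetic; WARP_SIZE == 32 and T = float are assumptions for illustration:

// Restates the sizing from the hunk above; WARP_SIZE = 32 (the usual CUDA
// warp width) and T = float are assumed here for concreteness.
#include <cstdio>

constexpr int WARP_SIZE = 32;

int main() {
  const int thread_per_block = 256;                // fixed launch width
  const int warps = thread_per_block / WARP_SIZE;  // 8 warps per block

  // InputPropKernel keeps three per-warp partial sums per block:
  const size_t input_prop_bytes = warps * 3 * sizeof(float);  // 96 bytes

  // GammaAndBetaPropKernel keeps two (dgamma and dbeta partials):
  const size_t gamma_beta_bytes = warps * 2 * sizeof(float);  // 64 bytes

  printf("InputPropKernel: %zu bytes, GammaAndBetaPropKernel: %zu bytes\n",
         input_prop_bytes, gamma_beta_bytes);
  return 0;
}

Either way the demand is now a small per-block constant (96 and 64 bytes for float), tied to the fixed launch width rather than to the data shape.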