@@ -45,6 +45,37 @@ inline static int GetDesiredBlockDim(int block_dim) {
   FIXED_BLOCK_DIM_CASE_BASE(2, ##__VA_ARGS__); \
   FIXED_BLOCK_DIM_CASE_BASE(1, ##__VA_ARGS__)
 
+#define FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE(                            \
+    log2_block_dim, feature_size, kMaxBlockNum, ...)                          \
+  case (1 << (log2_block_dim)): {                                             \
+    for (int i = 0; i < std::ceil(feature_size / (1.0 * kMaxBlockNum)); i++) { \
+      int col_offset = i * kMaxBlockNum;                                      \
+      int block_num = std::min(feature_size - col_offset, kMaxBlockNum);      \
+      constexpr auto kBlockDim = (1 << (log2_block_dim));                     \
+      __VA_ARGS__;                                                            \
+    }                                                                         \
+  } break
+
+#define FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE(feature_size, kMaxBlockNum, ...) \
+  FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE(9, feature_size, kMaxBlockNum,    \
+                                            ##__VA_ARGS__);                   \
+  FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE(8, feature_size, kMaxBlockNum,    \
+                                            ##__VA_ARGS__);                   \
+  FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE(7, feature_size, kMaxBlockNum,    \
+                                            ##__VA_ARGS__);                   \
+  FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE(6, feature_size, kMaxBlockNum,    \
+                                            ##__VA_ARGS__);                   \
+  FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE(5, feature_size, kMaxBlockNum,    \
+                                            ##__VA_ARGS__);                   \
+  FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE(4, feature_size, kMaxBlockNum,    \
+                                            ##__VA_ARGS__);                   \
+  FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE(3, feature_size, kMaxBlockNum,    \
+                                            ##__VA_ARGS__);                   \
+  FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE(2, feature_size, kMaxBlockNum,    \
+                                            ##__VA_ARGS__);                   \
+  FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE(1, feature_size, kMaxBlockNum,    \
+                                            ##__VA_ARGS__)
+
 static __device__ __forceinline__ float real_sqrt(float x) { return sqrtf(x); }
 static __device__ __forceinline__ double real_sqrt(double x) { return sqrt(x); }
 
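The added FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE case body caps each launch at kMaxBlockNum blocks and walks the feature dimension in chunks, handing col_offset and block_num to whatever launch expression is passed in __VA_ARGS__. A minimal host-side sketch of that chunking arithmetic, with made-up sizes (feature_size = 300, kMaxBlockNum = 128) purely for illustration:

```cpp
// Illustrative sketch only: replays the macro's chunking loop on the host.
// feature_size and kMaxBlockNum are example values, not taken from the patch.
#include <algorithm>
#include <cmath>
#include <cstdio>

int main() {
  const int feature_size = 300;  // number of columns to cover
  const int kMaxBlockNum = 128;  // per-launch cap on the grid size

  // Each iteration corresponds to one kernel launch of block_num blocks,
  // covering columns [col_offset, col_offset + block_num).
  for (int i = 0; i < std::ceil(feature_size / (1.0 * kMaxBlockNum)); i++) {
    int col_offset = i * kMaxBlockNum;
    int block_num = std::min(feature_size - col_offset, kMaxBlockNum);
    std::printf("launch %d: col_offset=%d, block_num=%d\n", i, col_offset,
                block_num);
  }
  return 0;
}
```

With these example numbers the loop runs three times, launching 128, 128 and 44 blocks, so no single launch exceeds the cap.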
@@ -131,12 +162,13 @@ __global__ void LayerNormBackwardGradientAll(const T *x, const T *d_y,
                                              T *d_scale, T *d_bias, T *d_x,
                                              const T *mean, const T *var,
                                              const T *scale, float epsilon,
-                                             int batch_size, int feature_size) {
+                                             int batch_size, int feature_size,
+                                             int col_offset) {
   using BlockReduce = cub::BlockReduce<PairForLayerNorm<T>, BlockDim>;
   __shared__ typename BlockReduce::TempStorage temp_storage;
 
-  int beg_idx = threadIdx.x * feature_size + blockIdx.x;
-  int end_idx = batch_size * feature_size + blockIdx.x;
+  int beg_idx = threadIdx.x * feature_size + (blockIdx.x + col_offset);
+  int end_idx = batch_size * feature_size + (blockIdx.x + col_offset);
   int stride = BlockDim * feature_size;
 
   T d_scale_partial = 0, d_bias_partial = 0;
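With the extra col_offset argument, each block of LayerNormBackwardGradientAll still owns a single feature column, now blockIdx.x + col_offset instead of blockIdx.x, and its threads stride down that column in steps of BlockDim rows. A small host-side sketch of the index walk for one hypothetical (blockIdx.x, threadIdx.x) pair, with all sizes invented for the example:

```cpp
// Illustrative sketch only: evaluates beg_idx/end_idx/stride on the host for
// one made-up thread. All values below are example inputs, not patch values.
#include <cstdio>

int main() {
  const int BlockDim = 4;       // threads per block (kBlockDim in the patch)
  const int batch_size = 8;     // rows
  const int feature_size = 10;  // columns
  const int col_offset = 8;     // first column handled by this launch
  const int block_idx = 1;      // blockIdx.x within this launch
  const int thread_idx = 2;     // threadIdx.x

  // The block owns column block_idx + col_offset; each thread visits every
  // BlockDim-th row of that column.
  int beg_idx = thread_idx * feature_size + (block_idx + col_offset);
  int end_idx = batch_size * feature_size + (block_idx + col_offset);
  int stride = BlockDim * feature_size;
  for (int i = beg_idx; i < end_idx; i += stride) {
    std::printf("row=%d col=%d flat_index=%d\n", i / feature_size,
                i % feature_size, i);
  }
  return 0;
}
```

Here thread 2 of that block touches rows 2 and 6 of column 9, which is exactly the column selected by blockIdx.x + col_offset.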
@@ -147,7 +179,7 @@ __global__ void LayerNormBackwardGradientAll(const T *x, const T *d_y,
       d_scale_partial += d_y[i] * (x[i] - mean[row_idx]) / var_val;
       d_bias_partial += d_y[i];
       if (HasDx) {
-        d_x[i] = d_y[i] * scale[blockIdx.x] / var_val;
+        d_x[i] = d_y[i] * scale[blockIdx.x + col_offset] / var_val;
       }
     }
 
@@ -156,8 +188,8 @@ __global__ void LayerNormBackwardGradientAll(const T *x, const T *d_y,
                           PairForLayerNormAddFunctor<T>());
 
   if (threadIdx.x == 0) {
-    d_scale[blockIdx.x] = pair.first_;
-    d_bias[blockIdx.x] = pair.second_;
+    d_scale[blockIdx.x + col_offset] = pair.first_;
+    d_bias[blockIdx.x + col_offset] = pair.second_;
   }
 }
 
@@ -168,11 +200,11 @@ template <typename T, int BlockDim, bool HasDx, bool HasDScale>
 __global__ void LayerNormBackwardGradientScaleOrBias(
     const T *x, const T *d_y, T *d_scale, T *d_bias, T *d_x, const T *mean,
     const T *var, const T *scale, float epsilon, int batch_size,
-    int feature_size) {
+    int feature_size, int col_offset) {
   using BlockReduce = cub::BlockReduce<T, BlockDim>;
   __shared__ typename BlockReduce::TempStorage temp_storage;
 
-  int beg_idx = threadIdx.x * feature_size + blockIdx.x;
-  int end_idx = batch_size * feature_size + blockIdx.x;
+  int beg_idx = threadIdx.x * feature_size + blockIdx.x + col_offset;
+  int end_idx = batch_size * feature_size + blockIdx.x + col_offset;
   int stride = BlockDim * feature_size;
   T d_scale_or_d_bias_partial = 0;
@@ -187,7 +219,7 @@ __global__ void LayerNormBackwardGradientScaleOrBias(
 
     if (HasDx) {
       if (scale != nullptr) {
-        d_x[i] = d_y[i] * scale[blockIdx.x] / var_val;
+        d_x[i] = d_y[i] * scale[blockIdx.x + col_offset] / var_val;
       } else {
         d_x[i] = d_y[i] / var_val;
       }
@@ -199,9 +231,9 @@ __global__ void LayerNormBackwardGradientScaleOrBias(
 
   if (threadIdx.x == 0) {
     if (HasDScale) {
-      d_scale[blockIdx.x] = d_scale_or_d_bias_partial;
+      d_scale[blockIdx.x + col_offset] = d_scale_or_d_bias_partial;
     } else {
-      d_bias[blockIdx.x] = d_scale_or_d_bias_partial;
+      d_bias[blockIdx.x + col_offset] = d_scale_or_d_bias_partial;
     }
   }
 }
@@ -322,6 +354,7 @@ static void LayerNormBackward(const T *x, const T *d_y, const T *scale,
                               T *d_bias, float epsilon, int batch_size,
                               int feature_size, cudaStream_t stream) {
   const int kMaxBlockDim = 512;
+  const int kMaxBlockNum = 128;
   int gradient_flag = ((d_x != nullptr ? 1 : 0) << 2) |
                       ((d_scale != nullptr ? 1 : 0) << 1) |
                       ((d_bias != nullptr ? 1 : 0));
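gradient_flag packs the three "is this output requested" tests into a 3-bit value: bit 2 for d_x, bit 1 for d_scale, bit 0 for d_bias, which is what the case labels 1 through 7 in the switch below decode. A tiny standalone check of that encoding, using dummy pointers chosen only for the example:

```cpp
// Illustrative sketch only: shows which switch case a given combination of
// requested gradients maps to. The pointers are dummies for the example.
#include <cstdio>

int main() {
  float dummy = 0.0f;
  float *d_x = &dummy;       // dx requested
  float *d_scale = nullptr;  // dscale not requested
  float *d_bias = &dummy;    // dbias requested

  int gradient_flag = ((d_x != nullptr ? 1 : 0) << 2) |
                      ((d_scale != nullptr ? 1 : 0) << 1) |
                      ((d_bias != nullptr ? 1 : 0));

  // Prints 5 (binary 101), i.e. the "d_x != nullptr, d_scale == nullptr,
  // d_bias != nullptr" combination handled by case 5 below.
  std::printf("gradient_flag = %d\n", gradient_flag);
  return 0;
}
```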
@@ -347,29 +380,33 @@ static void LayerNormBackward(const T *x, const T *d_y, const T *scale,
   switch (gradient_flag) {
     case 1:  // d_x == nullptr, d_scale == nullptr, d_bias != nullptr
       switch (block_dim) {
-        FIXED_BLOCK_DIM_CASE(LayerNormBackwardGradientScaleOrBias<
-                             T, kBlockDim, false,
-                             false><<<feature_size, kBlockDim, 0, stream>>>(
-            x, d_y, d_scale, d_bias, d_x, mean, var, scale, epsilon, batch_size,
-            feature_size));
+        FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE(
+            feature_size, kMaxBlockNum,
+            LayerNormBackwardGradientScaleOrBias<
+                T, kBlockDim, false,
+                false><<<block_num, kBlockDim, 0, stream>>>(
+                x, d_y, d_scale, d_bias, d_x, mean, var, scale, epsilon,
+                batch_size, feature_size, col_offset));
       }
       break;
     case 2:  // d_x == nullptr, d_scale != nullptr, d_bias == nullptr
       switch (block_dim) {
-        FIXED_BLOCK_DIM_CASE(LayerNormBackwardGradientScaleOrBias<
-                             T, kBlockDim, false,
-                             true><<<feature_size, kBlockDim, 0, stream>>>(
-            x, d_y, d_scale, d_bias, d_x, mean, var, scale, epsilon, batch_size,
-            feature_size));
+        FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE(
+            feature_size, kMaxBlockNum,
+            LayerNormBackwardGradientScaleOrBias<
+                T, kBlockDim, false, true><<<block_num, kBlockDim, 0, stream>>>(
+                x, d_y, d_scale, d_bias, d_x, mean, var, scale, epsilon,
+                batch_size, feature_size, col_offset));
       }
       break;
     case 3:  // d_x == nullptr, d_scale != nullptr, d_bias != nullptr
       switch (block_dim) {
-        FIXED_BLOCK_DIM_CASE(
+        FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE(
+            feature_size, kMaxBlockNum,
             LayerNormBackwardGradientAll<
-                T, kBlockDim, false><<<feature_size, kBlockDim, 0, stream>>>(
+                T, kBlockDim, false><<<block_num, kBlockDim, 0, stream>>>(
                 x, d_y, d_scale, d_bias, d_x, mean, var, scale, epsilon,
-                batch_size, feature_size));
+                batch_size, feature_size, col_offset));
       }
       break;
     case 4:  // d_x != nullptr, d_scale == nullptr, d_bias == nullptr
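Because each dispatch case now loops over launches of at most kMaxBlockNum blocks, it is worth convincing oneself that the chunking assigns every column to exactly one block. A host-side sketch of that check, again with example sizes rather than values taken from the patch:

```cpp
// Illustrative sketch only: verifies on the host that the capped-launch loop
// covers each of feature_size columns exactly once. Sizes are example values.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  const int feature_size = 300;
  const int kMaxBlockNum = 128;
  std::vector<int> hits(feature_size, 0);

  for (int i = 0; i < std::ceil(feature_size / (1.0 * kMaxBlockNum)); i++) {
    int col_offset = i * kMaxBlockNum;
    int block_num = std::min(feature_size - col_offset, kMaxBlockNum);
    for (int block_idx = 0; block_idx < block_num; ++block_idx) {
      ++hits[block_idx + col_offset];  // column this block would write
    }
  }

  bool exactly_once =
      std::all_of(hits.begin(), hits.end(), [](int h) { return h == 1; });
  std::printf("every column covered exactly once: %s\n",
              exactly_once ? "yes" : "no");
  return 0;
}
```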
@@ -382,11 +419,12 @@ static void LayerNormBackward(const T *x, const T *d_y, const T *scale,
       break;
     case 5:  // d_x != nullptr, d_scale == nullptr, d_bias != nullptr
       switch (block_dim) {
-        FIXED_BLOCK_DIM_CASE(LayerNormBackwardGradientScaleOrBias<
-                             T, kBlockDim, true,
-                             false><<<feature_size, kBlockDim, 0, stream>>>(
-            x, d_y, d_scale, d_bias, d_x, mean, var, scale, epsilon, batch_size,
-            feature_size));
+        FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE(
+            feature_size, kMaxBlockNum,
+            LayerNormBackwardGradientScaleOrBias<
+                T, kBlockDim, true, false><<<block_num, kBlockDim, 0, stream>>>(
+                x, d_y, d_scale, d_bias, d_x, mean, var, scale, epsilon,
+                batch_size, feature_size, col_offset));
       }
       switch (GetDesiredBlockDim(feature_size)) {
         FIXED_BLOCK_DIM_CASE(
@@ -397,11 +435,12 @@ static void LayerNormBackward(const T *x, const T *d_y, const T *scale,
       break;
     case 6:  // d_x != nullptr, d_scale != nullptr, d_bias == nullptr
       switch (block_dim) {
-        FIXED_BLOCK_DIM_CASE(LayerNormBackwardGradientScaleOrBias<
-                             T, kBlockDim, true,
-                             true><<<feature_size, kBlockDim, 0, stream>>>(
-            x, d_y, d_scale, d_bias, d_x, mean, var, scale, epsilon, batch_size,
-            feature_size));
+        FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE(
+            feature_size, kMaxBlockNum,
+            LayerNormBackwardGradientScaleOrBias<
+                T, kBlockDim, true, true><<<block_num, kBlockDim, 0, stream>>>(
+                x, d_y, d_scale, d_bias, d_x, mean, var, scale, epsilon,
+                batch_size, feature_size, col_offset));
       }
       switch (GetDesiredBlockDim(feature_size)) {
         FIXED_BLOCK_DIM_CASE(
@@ -412,11 +451,12 @@ static void LayerNormBackward(const T *x, const T *d_y, const T *scale,
       break;
     case 7:  // d_x != nullptr, d_scale != nullptr, d_bias != nullptr
       switch (block_dim) {
-        FIXED_BLOCK_DIM_CASE(
+        FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE(
+            feature_size, kMaxBlockNum,
             LayerNormBackwardGradientAll<
-                T, kBlockDim, true><<<feature_size, kBlockDim, 0, stream>>>(
+                T, kBlockDim, true><<<block_num, kBlockDim, 0, stream>>>(
                 x, d_y, d_scale, d_bias, d_x, mean, var, scale, epsilon,
-                batch_size, feature_size));
+                batch_size, feature_size, col_offset));
       }
       switch (GetDesiredBlockDim(feature_size)) {
         FIXED_BLOCK_DIM_CASE(
@@ -539,6 +579,8 @@ class LayerNormGradKernel<platform::CUDADeviceContext, T>
   }
 };
 template class LayerNormDirectCUDAFunctor<float>;
+#undef FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE
+#undef FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE
 #undef FIXED_BLOCK_DIM_CASE_BASE
 #undef FIXED_BLOCK_DIM_CASE
 }  // namespace operators