@@ -75,6 +75,34 @@ __device__ inline void LayerNorm(const kvp<T> &thread_data, const int ld,
   }
 }
 
+template <typename T, typename T2, int TPB>
+__device__ inline void LayerNorm2(const kvp<T> &thread_data, const int ld,
+                                  const int offset, const float2 *bias,
+                                  const float2 *scale, T2 *output, T eps) {
+  using BlockReduce = cub::BlockReduce<kvp<T>, TPB>;
+  __shared__ typename BlockReduce::TempStorage temp_storage;
+  __shared__ T mu;      // mean
+  __shared__ T rsigma;  // 1 / std.dev.
+
+  const auto sum_kv = BlockReduce(temp_storage).Reduce(thread_data, cub::Sum());
+
+  if (threadIdx.x == 0) {
+    mu = sum_kv.key;
+    rsigma = rsqrt(sum_kv.value - mu * mu + eps);
+  }
+  __syncthreads();
+
+  for (int i = threadIdx.x; i < ld; i += TPB) {
+    const int idx = offset + i;
+    T2 val = output[idx];
+    const float2 g = scale[i];
+    const float2 b = bias[i];
+    val.x = T(g.x) * (val.x - mu) * rsigma + T(b.x);
+    val.y = T(g.y) * (val.y - mu) * rsigma + T(b.y);
+    output[idx] = val;
+  }
+}
+
 template <typename T, unsigned TPB>
 __global__ void EmbEltwiseLayernormKernel(int hidden, const int64_t *ids,
                                           const float *scale, const float *bias,
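Note on the LayerNorm2 helper added above: each thread contributes a kvp<T> pair holding its share of the row mean and of the mean of squares, so a single cub::BlockReduce produces both mu = E[x] and E[x^2], and the variance is recovered as E[x^2] - mu^2 before eps is added and rsqrt is taken; the loop then normalizes both lanes of each packed element in place. A minimal host-side sketch of the same single-pass identity, for illustration only (the sample array and the ld/eps values are made up and not part of the patch):

#include <cmath>
#include <cstdio>

int main() {
  const int ld = 4;
  const float eps = 1e-5f;
  const float x[ld] = {1.f, 2.f, 3.f, 4.f};
  float mean = 0.f, mean_sq = 0.f;  // the two halves of kvp<T>
  for (int i = 0; i < ld; ++i) {    // one pass, two accumulators
    mean += x[i] / ld;              // E[x]
    mean_sq += x[i] * x[i] / ld;    // E[x^2]
  }
  const float rsigma = 1.f / std::sqrt(mean_sq - mean * mean + eps);
  for (int i = 0; i < ld; ++i) {
    std::printf("%f\n", (x[i] - mean) * rsigma);  // normalized, before scale/bias
  }
  return 0;
}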
@@ -323,6 +351,27 @@ __global__ void SkipLayerNormKernel(int num, int hidden, const T *input1,
   LayerNorm<T, TPB>(thread_data, hidden, offset, bias, scale, output, eps);
 }
 
+template <typename T, typename T2, unsigned TPB>
+__global__ void SkipLayerNormKernel2(int num, int hidden, const T2 *input1,
+                                     const T2 *input2, T2 *output,
+                                     const float2 *scale, const float2 *bias,
+                                     float eps) {
+  const T rld = T(0.5f / hidden);  // because hidden is hidden/2
+  const int offset = blockIdx.x * hidden;
+  cub::Sum pair_sum;
+  kvp<T> thread_data(0, 0);
+
+  for (int it = threadIdx.x; it < hidden; it += TPB) {
+    const int idx = offset + it;
+    const T2 val2 = input1[idx] + input2[idx];
+    thread_data = pair_sum(
+        thread_data, kvp<T>(rld * (val2.x + val2.y),
+                            rld * val2.x * val2.x + rld * val2.y * val2.y));
+    output[idx] = val2;
+  }
+  LayerNorm2<T, T2, TPB>(thread_data, hidden, offset, bias, scale, output, eps);
+}
+
 template <typename T>
 void SkipLayerNormFunctor<T>::operator()(const int num, const int hidden,
                                          const T *input1, const T *input2,
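Note on SkipLayerNormKernel2 added above: the caller passes hidden already divided by two (each T2 element packs two scalars), so rld = 0.5f / hidden equals one over the original row width, and every packed element feeds both of its lanes into the running mean and mean-of-squares before the skip sum is written back. A small host-side check of that weighting, for illustration only (the stand-in pair2 struct and sample values are not part of the patch):

#include <cstdio>

struct pair2 { float x, y; };  // host stand-in for CUDA's float2

int main() {
  const float v[6] = {1.f, 2.f, 3.f, 4.f, 5.f, 6.f};
  const int hidden = 6, half = hidden / 2;
  const float rld = 0.5f / half;    // == 1.0f / hidden
  float mean = 0.f, mean_sq = 0.f;
  for (int i = 0; i < half; ++i) {  // walk the row two lanes at a time
    const pair2 p = {v[2 * i], v[2 * i + 1]};
    mean += rld * (p.x + p.y);
    mean_sq += rld * p.x * p.x + rld * p.y * p.y;
  }
  std::printf("mean=%f  E[x^2]=%f\n", mean, mean_sq);  // matches a scalar pass
  return 0;
}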
@@ -344,8 +393,35 @@ void SkipLayerNormFunctor<T>::operator()(const int num, const int hidden,
         num, hidden, input1, input2, output, scale, bias, eps);
   } else {
     const int threads = 256;
-    SkipLayerNormKernel<T, threads><<<block, threads, 0, stream>>>(
-        num, hidden, input1, input2, output, scale, bias, eps);
+    if (hidden % 2 == 0) {
+#ifdef SUPPORTS_CUDA_FP16
+      if (std::is_same<T, float>::value) {
+#endif
+        SkipLayerNormKernel2<float, float2,
+                             threads><<<block, threads, 0, stream>>>(
+            num, hidden / 2, reinterpret_cast<const float2 *>(input1),
+            reinterpret_cast<const float2 *>(input2),
+            reinterpret_cast<float2 *>(output),
+            reinterpret_cast<const float2 *>(scale),
+            reinterpret_cast<const float2 *>(bias), eps);
+#ifdef SUPPORTS_CUDA_FP16
+      } else if (std::is_same<T, __half>::value) {
+        SkipLayerNormKernel2<__half, __half2,
+                             threads><<<block, threads, 0, stream>>>(
+            num, hidden / 2, reinterpret_cast<const __half2 *>(input1),
+            reinterpret_cast<const __half2 *>(input2),
+            reinterpret_cast<__half2 *>(output),
+            reinterpret_cast<const float2 *>(scale),
+            reinterpret_cast<const float2 *>(bias), eps);
+      } else {
+        assert(false);
+        // should not be here
+      }
+#endif
+    } else {
+      SkipLayerNormKernel<T, threads><<<block, threads, 0, stream>>>(
+          num, hidden, input1, input2, output, scale, bias, eps);
+    }
   }
 }
 
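Note on the dispatch change above: when hidden is even, the row is processed as packed float2 elements (or __half2 when SUPPORTS_CUDA_FP16 is defined), so each global load and store moves two values at once; scale and bias are read as float2 in both branches, and odd widths fall back to the original scalar kernel. The packed path relies on the pointers being aligned for 8-byte (float2) or 4-byte (__half2) accesses, which device allocations of whole tensors normally satisfy. A tiny host-side sketch of the reinterpret-cast view used here, for illustration only (the f2 struct and buffer are made up):

#include <cstdio>

int main() {
  struct f2 { float x, y; };  // host stand-in for CUDA's float2
  alignas(8) float buf[8] = {0, 1, 2, 3, 4, 5, 6, 7};
  const f2 *packed = reinterpret_cast<const f2 *>(buf);  // same bytes, half as many elements
  for (int i = 0; i < 4; ++i) {
    std::printf("(%g, %g)\n", packed[i].x, packed[i].y);
  }
  return 0;
}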