|
|
|
@ -108,23 +108,23 @@ struct PairForLayerNormAddFunctor {
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
// Generic device-side reciprocal square root: returns 1 / sqrt(val), with the
// numerator cast to T so the division is carried out in T.
// NOTE(review): this span is diff residue -- it carries two signature lines
// for one body, one named `rsqrt` and one named `rsqrt_`. Presumably `rsqrt_`
// is the post-change name; confirm against the real file before relying on it.
template <typename T>
|
|
|
|
|
__inline__ __device__ T rsqrt(const T val) {
|
|
|
|
|
__inline__ __device__ T rsqrt_(const T val) {
|
|
|
|
|
return static_cast<T>(1) / sqrt(val);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// float specialization: forwards directly to rsqrtf(val) instead of the
// generic 1 / sqrt(val) form.
// NOTE(review): both the `rsqrt` and `rsqrt_` signature lines appear here
// because this text is diff residue; only one of them exists in the real file.
template <>
|
|
|
|
|
__inline__ __device__ float rsqrt(const float val) {
|
|
|
|
|
__inline__ __device__ float rsqrt_(const float val) {
|
|
|
|
|
return rsqrtf(val);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// double specialization: forwards to an unqualified call `rsqrt(val)`.
// NOTE(review): under the `rsqrt_` signature this call presumably resolves to
// the CUDA double-precision math function rsqrt(double). Under the `rsqrt`
// signature the callee shares its name with the enclosing template, which is
// presumably what the rename is meant to avoid -- confirm against the real
// file and the toolkit's overload set.
template <>
|
|
|
|
|
__inline__ __device__ double rsqrt(const double val) {
|
|
|
|
|
__inline__ __device__ double rsqrt_(const double val) {
|
|
|
|
|
return rsqrt(val);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// half specialization: forwards to hrsqrt(val). It is only compiled when
// CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__) evaluates true; that macro is
// defined outside this view, so its exact architecture cutoff is not shown.
// NOTE(review): both the `rsqrt` and `rsqrt_` signature lines appear here
// because this text is diff residue; only one of them exists in the real file.
#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
|
|
|
|
|
template <>
|
|
|
|
|
__inline__ __device__ half rsqrt(const half val) {
|
|
|
|
|
__inline__ __device__ half rsqrt_(const half val) {
|
|
|
|
|
return hrsqrt(val);
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
@ -161,7 +161,7 @@ __global__ void LayerNormForward(const T *x, const U *scale, const U *bias,
|
|
|
|
|
__syncthreads();
|
|
|
|
|
|
|
|
|
|
mean_val = mean_share;
|
|
|
|
|
U invvar = rsqrt<U>(var_share + static_cast<U>(epsilon));
|
|
|
|
|
U invvar = rsqrt_<U>(var_share + static_cast<U>(epsilon));
|
|
|
|
|
|
|
|
|
|
// Step 2: Calculate y
|
|
|
|
|
if (scale != nullptr) {
|
|
|
|
@ -204,7 +204,7 @@ __inline__ __device__ void cuLoadAddStridedInputs(
|
|
|
|
|
const int i1 = i1_block + thr_load_row_off;
|
|
|
|
|
if (i1 >= i1_end) return;
|
|
|
|
|
U curr_mean = mean[i1];
|
|
|
|
|
U curr_invvar = rsqrt<U>(var[i1] + epsilon);
|
|
|
|
|
U curr_invvar = rsqrt_<U>(var[i1] + epsilon);
|
|
|
|
|
for (int k = 0; k < VPT; ++k) {
|
|
|
|
|
const int i2 = i2_off + k;
|
|
|
|
|
const int load_idx = i1 * n2 + i2;
|
|
|
|
@ -352,7 +352,7 @@ __global__ void LayerNormBackwardComputeGradInput(
|
|
|
|
|
U sum_loss1 = U(0);
|
|
|
|
|
U sum_loss2 = U(0);
|
|
|
|
|
const U c_mean = mean[i1];
|
|
|
|
|
const U c_invvar = rsqrt<U>(var[i1] + epsilon);
|
|
|
|
|
const U c_invvar = rsqrt_<U>(var[i1] + epsilon);
|
|
|
|
|
const T *k_input = input + i1 * n2;
|
|
|
|
|
const T *k_dout = dout + i1 * n2;
|
|
|
|
|
constexpr int numx = BDIMX * BDIMY;
|
|
|
|
|