Fix float64 bug in layer norm (#30452)

built-in `rsqrt` is shadowed
revert-31068-fix_conv3d_windows
Yang Zhang 5 years ago committed by GitHub
parent 715d862868
commit 008b0a8b56
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -108,23 +108,23 @@ struct PairForLayerNormAddFunctor {
};
template <typename T>
__inline__ __device__ T rsqrt(const T val) {
__inline__ __device__ T rsqrt_(const T val) {
return static_cast<T>(1) / sqrt(val);
}
template <>
__inline__ __device__ float rsqrt(const float val) {
__inline__ __device__ float rsqrt_(const float val) {
return rsqrtf(val);
}
template <>
__inline__ __device__ double rsqrt(const double val) {
__inline__ __device__ double rsqrt_(const double val) {
return rsqrt(val);
}
#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
template <>
__inline__ __device__ half rsqrt(const half val) {
__inline__ __device__ half rsqrt_(const half val) {
return hrsqrt(val);
}
#endif
@ -161,7 +161,7 @@ __global__ void LayerNormForward(const T *x, const U *scale, const U *bias,
__syncthreads();
mean_val = mean_share;
U invvar = rsqrt<U>(var_share + static_cast<U>(epsilon));
U invvar = rsqrt_<U>(var_share + static_cast<U>(epsilon));
// Step 2: Calculate y
if (scale != nullptr) {
@ -204,7 +204,7 @@ __inline__ __device__ void cuLoadAddStridedInputs(
const int i1 = i1_block + thr_load_row_off;
if (i1 >= i1_end) return;
U curr_mean = mean[i1];
U curr_invvar = rsqrt<U>(var[i1] + epsilon);
U curr_invvar = rsqrt_<U>(var[i1] + epsilon);
for (int k = 0; k < VPT; ++k) {
const int i2 = i2_off + k;
const int load_idx = i1 * n2 + i2;
@ -352,7 +352,7 @@ __global__ void LayerNormBackwardComputeGradInput(
U sum_loss1 = U(0);
U sum_loss2 = U(0);
const U c_mean = mean[i1];
const U c_invvar = rsqrt<U>(var[i1] + epsilon);
const U c_invvar = rsqrt_<U>(var[i1] + epsilon);
const T *k_input = input + i1 * n2;
const T *k_dout = dout + i1 * n2;
constexpr int numx = BDIMX * BDIMY;

Loading…
Cancel
Save