fix layer_norm accuracy (#29434)

Leo Chen 5 years ago committed by GitHub
parent 24ba9ed436
commit a040c055a5

@@ -135,7 +135,7 @@ __global__ void LayerNormForward(const T *x, const U *scale, const U *bias,
   }
   __syncthreads();
   mean_val = mean[blockIdx.x];
-  var_val = static_cast<U>(real_sqrt(var[blockIdx.x]) + epsilon);
+  var_val = static_cast<U>(real_sqrt(var[blockIdx.x] + epsilon));
   // Step 2: Calculate y
   if (scale != nullptr) {
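
For context, the kernel fix moves epsilon inside the square root so the denominator matches layer norm's reference definition, 1 / sqrt(var + eps), instead of sqrt(var) + eps. A minimal NumPy sketch (illustrative only, not the Paddle kernel) shows how far the two forms drift apart when the variance is small:

import numpy as np

# Compare the two ways of folding epsilon into the denominator.
var = np.array([1e-8, 1e-4, 1.0], dtype=np.float32)
eps = np.float32(1e-5)

old_denom = np.sqrt(var) + eps   # before the fix: sqrt(var) + eps
new_denom = np.sqrt(var + eps)   # after the fix:  sqrt(var + eps)

print(old_denom)  # approx [1.1e-04, 1.0e-02, 1.0e+00]
print(new_denom)  # approx [3.2e-03, 1.05e-02, 1.0e+00]
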

@@ -211,7 +211,7 @@ class TestLayerNormOp(unittest.TestCase):
                     for name in ['x', 'scale', 'bias', 'y@GRAD']
                 },
                 fetch_list=fetch_list)
-            self.__assert_close(y, out[0], "y", 1e-3)
+            self.__assert_close(y, out[0], "y")
             self.__assert_close(mean, out[1], "mean")
             self.__assert_close(variance, out[2], "variance", 1e-3)
             self.__assert_close(x_grad, out[3], "x_grad")
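
The test change drops the loosened 1e-3 tolerance for y and falls back to the helper's default, since the fixed kernel now matches the standard formula closely. A hedged NumPy sketch of the kind of reference computation such a test compares against (the function name and defaults here are illustrative, not Paddle's actual test helpers):

import numpy as np

def ref_layer_norm(x, scale, bias, epsilon=1e-5):
    # Normalize over the last axis, adding epsilon to the variance
    # before the square root, as the fixed CUDA kernel does.
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    y = (x - mean) / np.sqrt(var + epsilon)
    return y * scale + bias, mean.squeeze(-1), var.squeeze(-1)

x = np.random.rand(2, 3).astype(np.float32)
y, mean, var = ref_layer_norm(x, np.ones(3, np.float32), np.zeros(3, np.float32))
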
