!3707 GPU: support LayerNorm kernel

Merge pull request !3707 from chenweifeng/LayerNorm
pull/3707/MERGE
mindspore-ci-bot committed 5 years ago via Gitee
commit 773f8e5352

@@ -63,10 +63,6 @@ inline __device__ void GammaAndBetaWarpReduce(T* dg, T* db) {
template <typename T>
inline __device__ void GammaAndBetaBlockReduce(const int &col, const int &row_dim, T *dg, T *db, T *dg_addr,
T *db_addr) {
  if (threadIdx.x >= row_dim) {
    return;
  }
  // load data to share memory
  // thread(0, 32, 64, 96, ...) keep the data
  DynamicSharedMem<T> share_mem;
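The early-return guard above is dropped from the block-level reduce helpers, presumably because the shared-memory stage that follows ends in a barrier: if threads whose index exceeds the reduced dimension bail out before it, part of the block never reaches __syncthreads(). A minimal standalone sketch of the pattern (kernel and variable names are mine, not from this patch): out-of-range threads contribute a neutral element and still reach the barrier.

// Hypothetical standalone sketch (not the MindSpore kernel): a block-wide
// sum in which every thread reaches the barrier. Lanes past `n` load 0
// instead of returning early, so __syncthreads() is hit by the whole block.
#include <cuda_runtime.h>

#define WARP_SIZE 32

__global__ void BlockSumSketch(const float *in, float *out, int n) {
  // out-of-range threads participate with a neutral element
  float v = (threadIdx.x < n) ? in[threadIdx.x] : 0.0f;

  // warp-level reduction via shuffles; all 32 lanes stay active
  for (int delta = WARP_SIZE / 2; delta > 0; delta /= 2) {
    v += __shfl_down_sync(0xffffffff, v, delta);
  }

  // one partial per warp goes to shared memory (<= 32 warps per block)
  __shared__ float warp_sum[32];
  if (threadIdx.x % WARP_SIZE == 0) {
    warp_sum[threadIdx.x / WARP_SIZE] = v;
  }
  __syncthreads();  // must be reached by every thread in the block

  // thread 0 folds the per-warp partials into the block result
  if (threadIdx.x == 0) {
    float total = 0.0f;
    int num_warps = (blockDim.x + WARP_SIZE - 1) / WARP_SIZE;
    for (int i = 0; i < num_warps; ++i) {
      total += warp_sum[i];
    }
    out[blockIdx.x] = total;
  }
}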
@@ -167,10 +163,6 @@ inline __device__ void InputWarpReduce(T* sum1, T* sum2, T* sum3) {
template <typename T>
inline __device__ void InputBlockReduce(const int &col_dim, T *sum1, T *sum2, T *sum3, T *share_mem) {
  if (threadIdx.x >= col_dim) {
    return;
  }
  // load data to share memory
  // thread(0, 32, 64, 96, ...) keep the data
  if (threadIdx.x % WARP_SIZE == 0) {
@@ -218,8 +210,8 @@ inline __device__ void InputProp(const int& row, const int& col_dim, const int&
    half v2 = x[pos] - mean[row];
    half v3 = my_pow(var[row] + epsilon, -0.5);
    dx[pos] = v1 * v3 + share_mem[0] * __float2half(2.0 / col_dim) * v2 +
              (__float2half(-1.0) * v3 * share_mem[1] + __float2half(1.0 / col_dim) * share_mem[0] * share_mem[2])\
              * __float2half(1.0 / col_dim);
              (__float2half(-1.0) * v3 * share_mem[1] + __float2half(1.0 / col_dim) * share_mem[0] * share_mem[2]) *
              __float2half(1.0 / col_dim);
  }
}
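This hunk only changes how the half-precision dx expression is wrapped across lines (dropping the trailing backslash); the arithmetic is unchanged. For readability, the same expression in single precision with hypothetical parameter names; what each of the three block-reduced sums in share_mem holds is defined by the reductions earlier in the file and is not restated here.

// Readability sketch only, float version of the half expression above.
// Parameter names are assumptions; s[0..2] are the three block-reduced
// row sums produced earlier (their exact definitions live in the file).
__device__ __forceinline__ float InputPropElementSketch(float dy_gamma,      // v1 (assumed dy * gamma)
                                                        float x_minus_mean,  // v2: x - mean
                                                        float inv_std,       // v3: pow(var + epsilon, -0.5)
                                                        const float *s,      // share_mem[0..2]
                                                        int col_dim) {
  float inv_n = 1.0f / col_dim;
  return dy_gamma * inv_std + s[0] * (2.0f * inv_n) * x_minus_mean +
         (-inv_std * s[1] + inv_n * s[0] * s[2]) * inv_n;
}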
@@ -241,14 +233,14 @@ __global__ void InputPropKernel(const int row_dim, const int col_dim, const int
template <typename T>
void LayerNormGrad(const int &row_dim, const int &col_dim, const int &param_dim, const T &epsilon, const T *dy,
const T *x, const T *mean, const T *var, const T *gamma, T *dx, T *dg, T *db, cudaStream_t stream) {
  int share_mem_size =
    ((col_dim + NUM_PER_THREAD_REDUCE - 1) / NUM_PER_THREAD_REDUCE + WARP_SIZE - 1) / WARP_SIZE * 3 * sizeof(T);
  InputPropKernel<<<row_dim, 256, share_mem_size, stream>>>(row_dim, col_dim, param_dim, epsilon, dy, x, mean, var,
                                                            gamma, dx);
  share_mem_size =
    ((row_dim + NUM_PER_THREAD_REDUCE - 1) / NUM_PER_THREAD_REDUCE + WARP_SIZE - 1) / WARP_SIZE * 2 * sizeof(T);
  GammaAndBetaPropKernel<<<col_dim, 256, share_mem_size, stream>>>(row_dim, col_dim, epsilon, dy, x, mean, var, dg, db);
  const int thread_per_block = 256;
  int share_mem_size = thread_per_block / WARP_SIZE * 3 * sizeof(T);
  InputPropKernel<<<row_dim, thread_per_block, share_mem_size, stream>>>(row_dim, col_dim, param_dim, epsilon, dy, x,
                                                                         mean, var, gamma, dx);
  share_mem_size = thread_per_block / WARP_SIZE * 2 * sizeof(T);
  GammaAndBetaPropKernel<<<col_dim, thread_per_block, share_mem_size, stream>>>(row_dim, col_dim, epsilon, dy, x, mean,
                                                                                var, dg, db);
}
template void LayerNormGrad(const int &row_dim, const int &col_dim, const int &param_dim, const float &epsilon,
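The launcher now sizes dynamic shared memory from the fixed block geometry rather than from col_dim/row_dim: the warp-level reduce leaves one partial per warp, so a 256-thread block needs 256 / WARP_SIZE slots per reduced value, independent of how many columns a thread covers. A small sizing sketch (the helper name is mine) with the resulting byte counts for float:

// Sizing sketch, not part of the patch: dynamic shared memory for a
// block reduce that keeps `values_per_warp` partials of type T per warp.
#include <cstddef>

constexpr int kWarpSize = 32;         // WARP_SIZE in the kernels
constexpr int kThreadPerBlock = 256;  // thread_per_block in the launchers

template <typename T>
constexpr size_t ReduceShareMemBytes(int values_per_warp) {
  return kThreadPerBlock / kWarpSize * values_per_warp * sizeof(T);
}

// With T = float:
//   InputPropKernel        keeps 3 values/warp -> 256 / 32 * 3 * 4 = 96 bytes
//   GammaAndBetaPropKernel keeps 2 values/warp -> 256 / 32 * 2 * 4 = 64 bytes
static_assert(ReduceShareMemBytes<float>(3) == 96, "backward input reduce");
static_assert(ReduceShareMemBytes<float>(2) == 64, "gamma/beta reduce");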

@@ -73,10 +73,6 @@ inline __device__ void WarpReduce(T *mean, T *var, T *num) {
template <typename T>
inline __device__ void BlockReduce(const int &col_dim, T *mean, T *var, T *num, T *mean_addr, T *var_addr,
T *share_mem) {
  if (threadIdx.x >= col_dim) {
    return;
  }
  // load data to share memory
  // thread(0, 32, 64, 96, ...) keep the data
  if (threadIdx.x % WARP_SIZE == 0) {
@@ -146,13 +142,11 @@ __global__ void LayerNormKernel(const int row_dim, const int col_dim, const int
template <typename T>
void LayerNorm(const int &row_dim, const int &col_dim, const int &param_dim, const T &epsilon, const T *x,
const T *gamma, const T *beta, T *y, T *mean, T *var, cudaStream_t stream) {
  const dim3 block(row_dim);
  const dim3 thread(256);
  const int thread_per_block = 256;
  // keep the mean/var/num after warp reduce
  int share_mem_size =
    ((col_dim + NUM_PER_THREAD_REDUCE - 1) / NUM_PER_THREAD_REDUCE + WARP_SIZE - 1) / WARP_SIZE * 3 * sizeof(T);
  LayerNormKernel<<<block, thread, share_mem_size, stream>>>(row_dim, col_dim, param_dim, epsilon, x, gamma, beta, y,
                                                             mean, var);
  int share_mem_size = thread_per_block / WARP_SIZE * 3 * sizeof(T);
  LayerNormKernel<<<row_dim, thread_per_block, share_mem_size, stream>>>(row_dim, col_dim, param_dim, epsilon, x, gamma,
                                                                         beta, y, mean, var);
}
template void LayerNorm(const int &row_dim, const int &col_dim, const int &param_dim, const float &epsilon,
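For the forward launcher the change is the same: one block per row, 256 threads, shared memory sized from the block. A hypothetical caller sketch for a 2-D float input, assuming the usual mapping of the outer axis to row_dim and the normalized axis to col_dim/param_dim; everything outside the LayerNorm call is mine, and it links against the translation unit above via the explicit float instantiation.

#include <cuda_runtime.h>

// Declaration of the launcher shown above; the explicit float
// instantiation at the end of the .cu file provides the definition.
template <typename T>
void LayerNorm(const int &row_dim, const int &col_dim, const int &param_dim, const T &epsilon, const T *x,
               const T *gamma, const T *beta, T *y, T *mean, T *var, cudaStream_t stream);

// Hypothetical helper: normalize the last axis of an (n, c) float input.
void LaunchLayerNorm2D(const float *x, const float *gamma, const float *beta, float *y, float *mean, float *var,
                       int n, int c, cudaStream_t stream) {
  const float epsilon = 1e-12f;  // placeholder value for the sketch
  // assumed mapping: one block per row (n rows), c columns reduced per row,
  // gamma/beta of length c
  LayerNorm<float>(n, c, c, epsilon, x, gamma, beta, y, mean, var, stream);
}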

@@ -141,3 +141,55 @@ def test_layernormgrad2():
    assert np.allclose(dx_ms.asnumpy(), dx_np, rtol=1e-6, atol=1e-6)
    assert np.allclose(dg_ms.asnumpy(), dg_np, rtol=1e-6, atol=1e-3)
    assert np.allclose(db_ms.asnumpy(), db_np, rtol=1e-6, atol=1e-3)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_layernormgrad3():
    begin_norm_axis = -1
    begin_params_axis = -1
    x_np = np.random.randn(32, 64).astype(np.float32)
    dy_np = np.random.randn(32, 64).astype(np.float32)
    gamma_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
    epsilon = 10e-12
    dx_np, dg_np, db_np, mean_np, var_np = LayerNormGradReference(x_np, dy_np, gamma_np, epsilon, begin_norm_axis,
                                                                  begin_params_axis)
    dy_ms = Tensor(dy_np)
    x_ms = Tensor(x_np)
    var_ms = Tensor(var_np)
    mean_ms = Tensor(mean_np)
    gamma_ms = Tensor(gamma_np)
    net = LayerNormGradNet(begin_norm_axis, begin_params_axis)
    dx_ms, dg_ms, db_ms = net(x_ms, dy_ms, var_ms, mean_ms, gamma_ms)
    assert np.allclose(dx_ms.asnumpy(), dx_np, rtol=1e-6, atol=1e-6)
    assert np.allclose(dg_ms.asnumpy(), dg_np, rtol=1e-6, atol=1e-3)
    assert np.allclose(db_ms.asnumpy(), db_np, rtol=1e-6, atol=1e-3)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_layernormgrad4():
    begin_norm_axis = -1
    begin_params_axis = -1
    x_np = np.random.randn(32, 64).astype(np.float32)
    dy_np = np.random.randn(32, 64).astype(np.float32)
    gamma_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
    epsilon = 10e-12
    dx_np, dg_np, db_np, mean_np, var_np = LayerNormGradReference(x_np, dy_np, gamma_np, epsilon, begin_norm_axis,
                                                                  begin_params_axis)
    dy_ms = Tensor(dy_np)
    x_ms = Tensor(x_np)
    var_ms = Tensor(var_np)
    mean_ms = Tensor(mean_np)
    gamma_ms = Tensor(gamma_np)
    net = LayerNormGradNet(begin_norm_axis, begin_params_axis)
    dx_ms, dg_ms, db_ms = net(x_ms, dy_ms, var_ms, mean_ms, gamma_ms)
    assert np.allclose(dx_ms.asnumpy(), dx_np, rtol=1e-6, atol=1e-6)
    assert np.allclose(dg_ms.asnumpy(), dg_np, rtol=1e-6, atol=1e-3)
    assert np.allclose(db_ms.asnumpy(), db_np, rtol=1e-6, atol=1e-3)

@@ -133,3 +133,45 @@ def test_layernorm3d_2():
    assert np.allclose(y_ms.asnumpy(), y_np, rtol=1e-6, atol=1e-6)
    assert np.allclose(mean_ms.asnumpy(), mean_np, rtol=1e-6, atol=1e-6)
    assert np.allclose(var_ms.asnumpy(), var_np, rtol=1e-6, atol=1e-6)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_layernorm2d_2():
    begin_norm_axis = -1
    begin_params_axis = 1
    x_np = np.random.randn(64, 32).astype(np.float32)
    gamma_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
    beta_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
    y_np, mean_np, var_np = LayerNormReference(begin_norm_axis, begin_params_axis, x_np, gamma_np, beta_np)
    x_ms = Tensor(x_np)
    gamma_ms = Tensor(gamma_np)
    beta_ms = Tensor(beta_np)
    net = LayerNormNet(begin_norm_axis, begin_params_axis)
    y_ms, mean_ms, var_ms = net(x_ms, gamma_ms, beta_ms)
    assert np.allclose(y_ms.asnumpy(), y_np, rtol=1e-6, atol=1e-6)
    assert np.allclose(mean_ms.asnumpy(), mean_np, rtol=1e-6, atol=1e-6)
    assert np.allclose(var_ms.asnumpy(), var_np, rtol=1e-6, atol=1e-6)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_layernorm2d_3():
    begin_norm_axis = -1
    begin_params_axis = 1
    x_np = np.random.randn(128, 128).astype(np.float32)
    gamma_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
    beta_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
    y_np, mean_np, var_np = LayerNormReference(begin_norm_axis, begin_params_axis, x_np, gamma_np, beta_np)
    x_ms = Tensor(x_np)
    gamma_ms = Tensor(gamma_np)
    beta_ms = Tensor(beta_np)
    net = LayerNormNet(begin_norm_axis, begin_params_axis)
    y_ms, mean_ms, var_ms = net(x_ms, gamma_ms, beta_ms)
    assert np.allclose(y_ms.asnumpy(), y_np, rtol=1e-6, atol=1e-6)
    assert np.allclose(mean_ms.asnumpy(), mean_np, rtol=1e-6, atol=1e-6)
    assert np.allclose(var_ms.asnumpy(), var_np, rtol=1e-6, atol=1e-6)
