!13672 [MSLITE] fix layernorm grad op bug

From: @zhengjun10
Reviewed-by: @HilbertDavid,@zhanghaibo5
Signed-off-by: @HilbertDavid
pull/13672/MERGE
Committed-by: mindspore-ci-bot (via Gitee)
commit c28946eee2

@@ -19,14 +19,15 @@
 void LayerNormGrad(const float *x, const float *dy, const float *var, const float *mean, const float *gamma,
                    int param_num, int param_size, int block_num, int block_size, float *dx, float *dg, float *db) {
-  // var is actually 1/sqrf(var)-> var^0.5
+  // var is actually layer_norm forward output var
+  float eps = 1e-12;
   const float *var_sqrt_rev = var;
   for (size_t i = 0; i < param_num; ++i) {
     float dgamma = 0.0f;
     float dbeta = 0.0f;
     for (size_t j = i; j < param_size * param_num; j += param_num) {
       int norm_shift = (int)(j / block_size);
-      dgamma += dy[j] * var_sqrt_rev[norm_shift] * (x[j] - mean[norm_shift]);
+      dgamma += dy[j] * pow(var[norm_shift] + eps, -0.5) * (x[j] - mean[norm_shift]);
       dbeta += dy[j];
     }
     dg[i] = dgamma;
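
In short: the kernel previously treated the `var` input as if it already held 1/sqrt(var), but the forward LayerNorm op outputs the raw variance, so dgamma (and dx in the next hunk) were scaled by the wrong quantity. The fixed accumulation matches the standard LayerNorm backward for the affine parameters, written here in LaTeX for reference (mu and sigma^2 are the per-block forward mean and variance):

\[
  \frac{\partial L}{\partial \gamma_i} = \sum_{j} \frac{\partial L}{\partial y_j} \cdot \frac{x_j - \mu}{\sqrt{\sigma^2 + \epsilon}},
  \qquad
  \frac{\partial L}{\partial \beta_i} = \sum_{j} \frac{\partial L}{\partial y_j}
\]
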
@@ -41,13 +42,14 @@ void LayerNormGrad(const float *x, const float *dy, const float *var, const floa
       int norm_shift = (int)(j / block_size);
       float dxm = x[j] - mean[norm_shift];
       float dyg = dy[j] * gamma[param_shift];
-      sum1 += -0.5f * dyg * dxm * var_sqrt_rev[norm_shift] * var_sqrt_rev[norm_shift] * var_sqrt_rev[norm_shift];
+      sum1 += -0.5f * dyg * dxm * pow(var_sqrt_rev[norm_shift] + eps, -1.5);
       sum2 += dyg;
       sum3 += -2.0f * dxm;
     }
     for (size_t j = i * block_size; j < (i + 1) * block_size; ++j) {
       int param_shift = j % param_num;
       int norm_shift = (int)(j / block_size);
-      float var_sqrt = var_sqrt_rev[norm_shift];
+      float var_sqrt = pow(var_sqrt_rev[norm_shift] + eps, -0.5);
       float dx1 = dy[j] * gamma[param_shift] * var_sqrt;
       float dx2 = sum1 * 2.0f / block_size * (x[j] - mean[norm_shift]);
       float dx3 = (-1.0f * var_sqrt * sum2 + (1.0f / block_size) * sum1 * sum3) * (1.0f / block_size);
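
Below is a minimal standalone sketch of the corrected convention, runnable on its own; the toy values and the main() harness are illustrative only and not part of the MSLITE kernel. It derives the inverse standard deviation from the raw forward-output variance, as the fixed code does, rather than assuming the caller pre-inverted it:

#include <math.h>
#include <stdio.h>

int main(void) {
  const float eps = 1e-12f;
  /* toy inputs: one normalized block of 4 elements */
  const float x[4] = {1.0f, 2.0f, 3.0f, 4.0f};
  const float dy[4] = {0.1f, -0.2f, 0.3f, -0.4f};
  const float mean = 2.5f; /* forward-pass mean of x */
  const float var = 1.25f; /* forward-pass (raw) variance of x */

  /* fixed convention: recompute 1/sqrt(var + eps) from the raw variance;
   * the old code used `var` directly as if it were already inverted */
  float rstd = (float)pow(var + eps, -0.5);

  float dgamma = 0.0f, dbeta = 0.0f;
  for (int j = 0; j < 4; ++j) {
    dgamma += dy[j] * rstd * (x[j] - mean); /* dL/dgamma accumulation */
    dbeta += dy[j];                         /* dL/dbeta accumulation */
  }
  printf("dgamma = %f, dbeta = %f\n", dgamma, dbeta);
  return 0;
}
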
