|
|
|
@ -19,14 +19,15 @@
|
|
|
|
|
|
|
|
|
|
void LayerNormGrad(const float *x, const float *dy, const float *var, const float *mean, const float *gamma,
|
|
|
|
|
int param_num, int param_size, int block_num, int block_size, float *dx, float *dg, float *db) {
|
|
|
|
|
// var is actually 1/sqrf(var)-> var^0.5
|
|
|
|
|
// var is actually layer_norm forward output var
|
|
|
|
|
float eps = 1e-12;
|
|
|
|
|
const float *var_sqrt_rev = var;
|
|
|
|
|
for (size_t i = 0; i < param_num; ++i) {
|
|
|
|
|
float dgamma = 0.0f;
|
|
|
|
|
float dbeta = 0.0f;
|
|
|
|
|
for (size_t j = i; j < param_size * param_num; j += param_num) {
|
|
|
|
|
int norm_shift = (int)(j / block_size);
|
|
|
|
|
dgamma += dy[j] * var_sqrt_rev[norm_shift] * (x[j] - mean[norm_shift]);
|
|
|
|
|
dgamma += dy[j] * pow(var[norm_shift] + eps, -0.5) * (x[j] - mean[norm_shift]);
|
|
|
|
|
dbeta += dy[j];
|
|
|
|
|
}
|
|
|
|
|
dg[i] = dgamma;
|
|
|
|
@ -41,13 +42,14 @@ void LayerNormGrad(const float *x, const float *dy, const float *var, const floa
|
|
|
|
|
int norm_shift = (int)(j / block_size);
|
|
|
|
|
float dxm = x[j] - mean[norm_shift];
|
|
|
|
|
float dyg = dy[j] * gamma[param_shift];
|
|
|
|
|
sum1 += -0.5f * dyg * dxm * var_sqrt_rev[norm_shift] * var_sqrt_rev[norm_shift] * var_sqrt_rev[norm_shift];
|
|
|
|
|
sum1 += -0.5f * dyg * dxm * pow(var_sqrt_rev[norm_shift] + eps, -1.5);
|
|
|
|
|
sum2 += dyg;
|
|
|
|
|
sum3 += -2.0f * dxm;
|
|
|
|
|
}
|
|
|
|
|
for (size_t j = i * block_size; j < (i + 1) * block_size; ++j) {
|
|
|
|
|
int param_shift = j % param_num;
|
|
|
|
|
int norm_shift = (int)(j / block_size);
|
|
|
|
|
float var_sqrt = var_sqrt_rev[norm_shift];
|
|
|
|
|
float var_sqrt = pow(var_sqrt_rev[norm_shift] + eps, -0.5);
|
|
|
|
|
float dx1 = dy[j] * gamma[param_shift] * var_sqrt;
|
|
|
|
|
float dx2 = sum1 * 2.0f / block_size * (x[j] - mean[norm_shift]);
|
|
|
|
|
float dx3 = (-1.0f * var_sqrt * sum2 + (1.0f / block_size) * sum1 * sum3) * (1.0f / block_size);
|
|
|
|
|