|
|
|
@ -233,13 +233,13 @@ class LayerNormGradKernel<platform::CPUDeviceContext, T>
|
|
|
|
|
if (d_x) {
|
|
|
|
|
d_x->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
auto d_x_map = EigenMatrixMapRowMajor<T>(d_x->data<T>(), left, right);
|
|
|
|
|
auto triple_product = [](T ele) { return ele * ele * ele; };
|
|
|
|
|
auto neg_inv_std = [](T ele) { return T(-1.0) * std::sqrt(1 / ele); };
|
|
|
|
|
auto triple_product = [](T ele) { return ele * ele; };
|
|
|
|
|
auto neg_inv_std = [](T ele) { return -std::sqrt(1 / ele); };
|
|
|
|
|
auto inv_std_scale_func = [scale_data](T ele) {
|
|
|
|
|
return std::sqrt(1 / ele) * scale_data;
|
|
|
|
|
};
|
|
|
|
|
auto neg_inv_std_scale_func = [scale_data](T ele) {
|
|
|
|
|
return T(-1.0) * std::sqrt(1 / ele) * scale_data;
|
|
|
|
|
return -std::sqrt(1 / ele) * scale_data;
|
|
|
|
|
};
|
|
|
|
|
// dy_dx
|
|
|
|
|
auto dx_end = var_map.unaryExpr(inv_std_scale_func)
|
|
|
|
@ -260,10 +260,13 @@ class LayerNormGradKernel<platform::CPUDeviceContext, T>
|
|
|
|
|
auto dvar_end = var_map.unaryExpr(neg_inv_std)
|
|
|
|
|
.unaryExpr(triple_product)
|
|
|
|
|
.cwiseProduct(dvar_end_0);
|
|
|
|
|
auto dx_var = (1.0f / right) *
|
|
|
|
|
auto dx_var = (T(1.0) / right) *
|
|
|
|
|
(x_map - mean_map.replicate(1, right))
|
|
|
|
|
.cwiseProduct(dvar_end.replicate(1, right));
|
|
|
|
|
|
|
|
|
|
// d_x = (1. / N) * scale * inv_var * (N * d_y - np.sum(d_y, axis=0)
|
|
|
|
|
// - (X - mean) * inv_var * inv_var * np.sum(d_y * (X - mean), axis=0))
|
|
|
|
|
|
|
|
|
|
d_x_map = dx_end + dx_mean + dx_var;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|