@@ -45,11 +45,12 @@ class LayerNormOp : public framework::OperatorWithKernel {
 
     auto matrix_dim = framework::flatten_to_2d(x_dim, begin_norm_axis);
     int left = static_cast<int>(matrix_dim[0]);
     int right = static_cast<int>(matrix_dim[1]);
+
     PADDLE_ENFORCE_EQ(ctx->GetInputDim("Scale").size(), 1UL);
-    PADDLE_ENFORCE_EQ(ctx->GetInputDim("Scale")[0], left);
+    PADDLE_ENFORCE_EQ(ctx->GetInputDim("Scale")[0], right);
     PADDLE_ENFORCE_EQ(ctx->GetInputDim("Bias").size(), 1UL);
-    PADDLE_ENFORCE_EQ(ctx->GetInputDim("Bias")[0], left);
+    PADDLE_ENFORCE_EQ(ctx->GetInputDim("Bias")[0], right);
 
     ctx->SetOutputDim("Y", ctx->GetInputDim("X"));
     ctx->SetOutputDim("Mean", {left});
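The check corrected above follows from the 2-D flattening convention: `left` is the product of the dims before `begin_norm_axis` (the number of rows normalized independently) and `right` is the product of the dims from `begin_norm_axis` on (the number of elements each row is normalized over). Scale and Bias hold one value per normalized element, so their single dimension must equal `right`, while the per-row Mean/Variance outputs get `left` entries. A minimal sketch of the convention; `FlattenTo2D` below is an illustrative stand-in, not the real `framework::flatten_to_2d` signature:

#include <cstdint>
#include <functional>
#include <iostream>
#include <numeric>
#include <utility>
#include <vector>

// Illustrative stand-in for framework::flatten_to_2d: collapse the dims
// in [0, axis) into "left" and those in [axis, rank) into "right".
std::pair<int64_t, int64_t> FlattenTo2D(const std::vector<int64_t>& dims,
                                        int axis) {
  auto mul = [](auto first, auto last) {
    return std::accumulate(first, last, int64_t{1},
                           std::multiplies<int64_t>());
  };
  return {mul(dims.begin(), dims.begin() + axis),
          mul(dims.begin() + axis, dims.end())};
}

int main() {
  // X: [batch=4, seq=8, hidden=16] with begin_norm_axis = 2.
  auto [left, right] = FlattenTo2D({4, 8, 16}, 2);
  std::cout << left << " x " << right << "\n";  // prints "32 x 16"
  // Scale/Bias must have shape [right] = [16], one value per normalized
  // element; Mean/Variance have shape [left] = [32], one value per row.
}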
@@ -143,10 +144,10 @@ class LayerNormKernel<platform::CPUDeviceContext, T>
 
     // TODO(zcd): Some thinking about output_map, is it appropriate that
     // `output_map` and `input_map` point to the same memory.
-    auto inv_std_scale = var_map.unaryExpr(inv_std_func);
+    auto inv_std = var_map.unaryExpr(inv_std_func);
     output_map = (input_map - mean_map.replicate(1, right))
-                     .cwiseProduct(inv_std_scale.replicate(1, right))
-                     .cwiseProduct(scale_map.replicate(left, 1)) -
+                     .cwiseProduct(inv_std.replicate(1, right))
+                     .cwiseProduct(scale_map.replicate(left, 1)) +
                      bias_map.replicate(left, 1);
   }
 };
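This hunk fixes two things in the forward kernel: the misleading `inv_std_scale` name (at that point the expression is only the inverse standard deviation; scale has not been applied yet) and the sign of the bias term, since layer normalization computes y = (x - mean) / sqrt(var + epsilon) * scale + bias, not "- bias". A self-contained sketch of the corrected arithmetic on an already-flattened [left, right] matrix; plain Eigen types and made-up values stand in for the kernel's `EigenMatrixMapRowMajor` views and the op's `epsilon` attribute:

#include <Eigen/Dense>
#include <cmath>
#include <iostream>

int main() {
  const int left = 2, right = 4;
  const float epsilon = 1e-5f;  // assumed value of the epsilon attribute

  Eigen::MatrixXf x(left, right);
  x << 1, 2, 3, 4,
       2, 4, 6, 8;
  Eigen::RowVectorXf scale = Eigen::RowVectorXf::Constant(right, 1.5f);
  Eigen::RowVectorXf bias = Eigen::RowVectorXf::Constant(right, 0.5f);

  // Per-row statistics over the normalized (right) dimension.
  Eigen::VectorXf mean = x.rowwise().mean();
  Eigen::MatrixXf centered = x - mean.replicate(1, right);
  Eigen::VectorXf var = centered.cwiseProduct(centered).rowwise().mean();
  auto inv_std_func = [=](float v) { return 1.0f / std::sqrt(v + epsilon); };
  Eigen::MatrixXf inv_std = var.unaryExpr(inv_std_func);

  // y = (x - mean) .* inv_std .* scale + bias; the "+" on the bias term
  // is exactly what the hunk above corrects.
  Eigen::MatrixXf y = centered.cwiseProduct(inv_std.replicate(1, right))
                          .cwiseProduct(scale.replicate(left, 1)) +
                      bias.replicate(left, 1);
  std::cout << y << std::endl;
}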
@@ -230,7 +231,7 @@ class LayerNormGradKernel<platform::CPUDeviceContext, T>
     if (d_bias) {
       d_bias->mutable_data<T>(ctx.GetPlace());
       auto d_bias_map = EigenMatrixMapRowMajor<T>(d_bias->data<T>(), 1, right);
-      d_bias_map = d_y_map.colwise().mean();
+      d_bias_map = d_y_map.colwise().sum();
     }
     if (d_scale) {
       d_scale->mutable_data<T>(ctx.GetPlace());
@@ -245,7 +246,7 @@ class LayerNormGradKernel<platform::CPUDeviceContext, T>
                             var_map.unaryExpr(inv_std_func).replicate(1, right))
                    .cwiseProduct(d_y_map))
                   .colwise()
-                  .mean();
+                  .sum();
     }
 
     if (d_x) {
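The two gradient hunks above make the same correction: Bias and Scale are shared across all `left` rows, so each parameter's gradient accumulates contributions from every row, and the column-wise reduction must be a sum, not a mean (`.mean()` was silently scaling both gradients down by `left`). With \hat{x} denoting the normalized input, the intended quantities are:

\[
\frac{\partial L}{\partial \beta_j} = \sum_{i=1}^{left} \frac{\partial L}{\partial y_{ij}},
\qquad
\frac{\partial L}{\partial \gamma_j} = \sum_{i=1}^{left} \hat{x}_{ij}\,\frac{\partial L}{\partial y_{ij}},
\qquad
\hat{x}_{ij} = \frac{x_{ij} - \mu_i}{\sqrt{\sigma_i^2 + \varepsilon}}.
\]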
@@ -269,14 +270,14 @@ class LayerNormGradKernel<platform::CPUDeviceContext, T>
                            .replicate(1, right);
       // dy_var_dx
       auto dvar_end_part = (x_map - mean_map.replicate(1, right))
+                               .cwiseProduct(scale_map.replicate(left, 1))
                                .cwiseProduct(d_y_map)
                                .rowwise()
                                .sum();
       auto dvar_end = var_map.unaryExpr(inv_std_func)
                           .unaryExpr(triple_product_func)
                           .cwiseProduct(dvar_end_part)
-                          .replicate(1, right)
-                          .cwiseProduct(scale_map.replicate(left, 1));
+                          .replicate(1, right);
       auto dx_var =
           (T(-1.0) / right) *
           (x_map - mean_map.replicate(1, right)).cwiseProduct(dvar_end);
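The last hunk is the subtle one. In the variance path of dX, the per-row reduction that produces dL/d(sigma_i^2) runs over the column index j, and gamma_j varies with j, so the multiplication by `scale_map` has to happen inside `dvar_end_part`, before `.rowwise().sum()`; the old code applied it after the sum (and after `.replicate`), where it no longer matches any term of the derivative. Reading `inv_std_func` composed with `triple_product_func` as forming the (sigma^2 + epsilon)^{-3/2} factor (an assumption based on the names; both lambdas are defined outside this diff), the corrected `dvar_end` builds:

\[
\frac{\partial L}{\partial \sigma_i^2}
  = -\frac{1}{2}\,(\sigma_i^2 + \varepsilon)^{-3/2}
    \sum_{j=1}^{right} \gamma_j\,(x_{ij} - \mu_i)\,\frac{\partial L}{\partial y_{ij}},
\]

up to a constant: the -1/2 here and the 2/right from d(sigma_i^2)/dx_ij = 2(x_ij - mu_i)/right combine into the `T(-1.0) / right` factor that `dx_var` applies.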