@@ -21,6 +21,13 @@ using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
using DataLayout = framework::DataLayout;
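
// The CPU kernels below flatten the input into a 2-D matrix of shape
// [left, right] (see framework::flatten_to_2d) and view the tensor buffers
// through row-major Eigen maps, so each row is normalized independently.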
template <typename T>
using EigenMatrixMapRowMajor = Eigen::Map<
    Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>;
template <typename T>
using ConstEigenMatrixMapRowMajor = Eigen::Map<
    const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>;

class LayerNormOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;
@@ -101,7 +108,6 @@ class LayerNormOpMaker : public framework::OpProtoAndCheckerMaker {

    AddComment(R"DOC(
Layer Normalization.

Layer Norm has been implemented as discussed in the paper:
https://arxiv.org/abs/1607.06450
...
@@ -109,6 +115,75 @@ https://arxiv.org/abs/1607.06450
  }
};

template <typename T>
class LayerNormKernel<platform::CPUDeviceContext, T>
    : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    const float epsilon = ctx.Attr<float>("epsilon");
    const auto *scale = ctx.Input<Tensor>("Scale");
    const auto *bias = ctx.Input<Tensor>("Bias");
    const auto *x = ctx.Input<Tensor>("X");
    const auto &x_dims = x->dims();
    const auto begin_norm_axis = ctx.Attr<int>("begin_norm_axis");

    auto *output = ctx.Output<Tensor>("Y");
    auto *mean = ctx.Output<Tensor>("Mean");
    auto *var = ctx.Output<Tensor>("Variance");
    output->mutable_data<T>(ctx.GetPlace());
    mean->mutable_data<T>(ctx.GetPlace());
    var->mutable_data<T>(ctx.GetPlace());

    auto matrix_dim = framework::flatten_to_2d(x_dims, begin_norm_axis);
    int left = static_cast<int>(matrix_dim[0]);
    int right = static_cast<int>(matrix_dim[1]);
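    // `left` is the number of rows to normalize (product of the dimensions
    // before `begin_norm_axis`); `right` is the size of each row (product of
    // the remaining dimensions).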

    auto input_map = ConstEigenMatrixMapRowMajor<T>(x->data<T>(), left, right);

    auto mean_map = EigenMatrixMapRowMajor<T>(mean->data<T>(), left, 1);
    auto var_map = EigenMatrixMapRowMajor<T>(var->data<T>(), left, 1);
    auto output_map = EigenMatrixMapRowMajor<T>(output->data<T>(), left, right);

    auto square = [](T ele) { return ele * ele; };
    auto add_epsilon = [epsilon](T ele) { return ele + epsilon; };
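
    // For each row of the flattened [left, right] matrix:
    //   mean = mean(x),  var = mean((x - mean)^2) + epsilon,
    //   y = (x - mean) / sqrt(var), optionally scaled by Scale and shifted
    //   by Bias.
    // Epsilon is folded into the saved Variance output, so sqrt(1 / var)
    // below is already the numerically stabilized inverse std.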
    mean_map = input_map.rowwise().mean();
    var_map = (input_map - mean_map.replicate(1, right))
                  .unaryExpr(square)
                  .rowwise()
                  .mean()
                  .unaryExpr(add_epsilon);

    auto inv_std_func = [](T ele) { return std::sqrt(1 / ele); };
    // TODO(zcd): Reconsider output_map: would it be appropriate for
    // `output_map` and `input_map` to point to the same memory?
    auto inv_std = var_map.unaryExpr(inv_std_func);
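    // Scale and Bias are optional inputs, so the four branches below compute
    // y = (x - mean) * inv_std, optionally scaled and/or shifted.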
    if (scale && bias) {
      auto scale_map =
          ConstEigenMatrixMapRowMajor<T>(scale->data<T>(), 1, right);
      auto bias_map = ConstEigenMatrixMapRowMajor<T>(bias->data<T>(), 1, right);
      output_map = (input_map - mean_map.replicate(1, right))
                       .cwiseProduct(inv_std.replicate(1, right))
                       .cwiseProduct(scale_map.replicate(left, 1)) +
                   bias_map.replicate(left, 1);
    } else if (scale) {
      auto scale_map =
          ConstEigenMatrixMapRowMajor<T>(scale->data<T>(), 1, right);
      output_map = (input_map - mean_map.replicate(1, right))
                       .cwiseProduct(inv_std.replicate(1, right))
                       .cwiseProduct(scale_map.replicate(left, 1));
    } else if (bias) {
      auto bias_map = ConstEigenMatrixMapRowMajor<T>(bias->data<T>(), 1, right);
      output_map = (input_map - mean_map.replicate(1, right))
                       .cwiseProduct(inv_std.replicate(1, right)) +
                   bias_map.replicate(left, 1);
    } else {
      output_map = (input_map - mean_map.replicate(1, right))
                       .cwiseProduct(inv_std.replicate(1, right));
    }
  }
};

class LayerNormGradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;
@@ -161,6 +236,115 @@ class LayerNormGradOp : public framework::OperatorWithKernel {
  }
};

template <typename T>
class LayerNormGradKernel<platform::CPUDeviceContext, T>
    : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    const auto *x = ctx.Input<Tensor>("X");
    const auto *mean = ctx.Input<Tensor>("Mean");
    const auto *var = ctx.Input<Tensor>("Variance");
    const auto *scale = ctx.Input<Tensor>("Scale");
    const auto *d_y = ctx.Input<Tensor>(framework::GradVarName("Y"));

    const auto &x_dims = x->dims();

    const auto begin_norm_axis = ctx.Attr<int>("begin_norm_axis");
    auto matrix_dim = framework::flatten_to_2d(x_dims, begin_norm_axis);
    int left = static_cast<int>(matrix_dim[0]);
    int right = static_cast<int>(matrix_dim[1]);

    // init output
    auto *d_x = ctx.Output<Tensor>(framework::GradVarName("X"));
    auto *d_scale = ctx.Output<Tensor>(framework::GradVarName("Scale"));
    auto *d_bias = ctx.Output<Tensor>(framework::GradVarName("Bias"));
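    // Each gradient output is optional and is only computed when the
    // corresponding output variable is actually requested.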

    auto x_map = ConstEigenMatrixMapRowMajor<T>(x->data<T>(), left, right);
    auto d_y_map = ConstEigenMatrixMapRowMajor<T>(d_y->data<T>(), left, right);
    auto mean_map = ConstEigenMatrixMapRowMajor<T>(mean->data<T>(), left, 1);
    auto var_map = ConstEigenMatrixMapRowMajor<T>(var->data<T>(), left, 1);
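
    // Backward formulas, with x_hat = (x - mean) * inv_std and H = right:
    //   d_bias_j  = sum_i d_y_ij
    //   d_scale_j = sum_i x_hat_ij * d_y_ij
    //   d_x_ij    = inv_std_i * (g_ij - mean_k(g_ik)
    //                            - x_hat_ij * mean_k(g_ik * x_hat_ik)),
    // where g = d_y * scale (or just d_y when Scale is absent), and inv_std
    // is 1 / sqrt(Variance) with epsilon already included in Variance.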

    if (d_bias) {
      d_bias->mutable_data<T>(ctx.GetPlace());
      auto d_bias_map = EigenMatrixMapRowMajor<T>(d_bias->data<T>(), 1, right);
      d_bias_map = d_y_map.colwise().sum();
    }
    if (d_scale) {
      d_scale->mutable_data<T>(ctx.GetPlace());
      auto d_scale_map =
          EigenMatrixMapRowMajor<T>(d_scale->data<T>(), 1, right);
      auto inv_std_func = [](T ele) { return std::sqrt(1 / ele); };
      // There are two equations for computing d_scale: one uses "Y" and the
      // other does not. The form below does not use "Y".
      d_scale_map =
          ((x_map - mean_map.replicate(1, right))
               .cwiseProduct(
                   var_map.unaryExpr(inv_std_func).replicate(1, right))
               .cwiseProduct(d_y_map))
              .colwise()
              .sum();
    }

    if (d_x) {
      d_x->mutable_data<T>(ctx.GetPlace());
      auto d_x_map = EigenMatrixMapRowMajor<T>(d_x->data<T>(), left, right);
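      // d_x is assembled from three terms: the direct path through x_hat
      // (dx_end), the path through the mean (dx_mean), and the path through
      // the variance (dx_var).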
      auto triple_product_func = [](T ele) { return ele * ele * ele; };
      auto inv_std_func = [](T ele) { return std::sqrt(1 / ele); };

      auto inv_std_map = var_map.unaryExpr(inv_std_func).eval();
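      // .eval() materializes inv_std once so the expression is not
      // re-evaluated by every use below.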
      // TODO(zcd): this code can be refined.
      if (d_scale) {
        auto scale_map =
            ConstEigenMatrixMapRowMajor<T>(scale->data<T>(), 1, right);
        // dy_dx: direct gradient path from d_y to x.
        auto dx_end =
            inv_std_map.replicate(1, right).cwiseProduct(d_y_map).cwiseProduct(
                scale_map.replicate(left, 1));

        // dy_dmean_dx: gradient path through the mean.
        auto dx_mean =
            (T(-1.0) / right) * dx_end.rowwise().sum().replicate(1, right);

        // dy_var_dx: gradient path through the variance.
        auto dvar_end_part = (x_map - mean_map.replicate(1, right))
                                 .cwiseProduct(scale_map.replicate(left, 1))
                                 .cwiseProduct(d_y_map)
                                 .rowwise()
                                 .sum();
        auto dvar_end = inv_std_map.unaryExpr(triple_product_func)
                            .cwiseProduct(dvar_end_part)
                            .replicate(1, right);
        auto dx_var =
            (T(-1.0) / right) *
            (x_map - mean_map.replicate(1, right)).cwiseProduct(dvar_end);

        d_x_map = dx_end + dx_mean + dx_var;
      } else {
        // dy_dx: direct gradient path from d_y to x.
        auto dx_end = inv_std_map.replicate(1, right).cwiseProduct(d_y_map);

        // dy_dmean_dx: gradient path through the mean.
        auto dx_mean =
            (T(-1.0) / right) * dx_end.rowwise().sum().replicate(1, right);

        // dy_var_dx: gradient path through the variance.
        auto dvar_end_part = (x_map - mean_map.replicate(1, right))
                                 .cwiseProduct(d_y_map)
                                 .rowwise()
                                 .sum();
        auto dvar_end = inv_std_map.unaryExpr(triple_product_func)
                            .cwiseProduct(dvar_end_part)
                            .replicate(1, right);
        auto dx_var =
            (T(-1.0) / right) *
            (x_map - mean_map.replicate(1, right)).cwiseProduct(dvar_end);

        d_x_map = dx_end + dx_mean + dx_var;
      }
    }
  }
};

}  // namespace operators
}  // namespace paddle