@@ -27,36 +27,62 @@ class LayerNormOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
   void InferShape(framework::InferShapeContext *ctx) const override {
-    PADDLE_ENFORCE(ctx->HasInput("X"),
-                   "Input(X) of LayerNormOp should not be null.");
-    PADDLE_ENFORCE(ctx->HasOutput("Y"),
-                   "Output(Y) of LayerNormOp should not be null.");
-    PADDLE_ENFORCE(ctx->HasOutput("Mean"),
-                   "Output(Mean) of LayerNormOp should not be null.");
-    PADDLE_ENFORCE(ctx->HasOutput("Variance"),
-                   "Output(Variance) of LayerNormOp should not be null.");
+    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "LayerNorm");
+    OP_INOUT_CHECK(ctx->HasOutput("Y"), "Output", "Y", "LayerNorm");
+    OP_INOUT_CHECK(ctx->HasOutput("Mean"), "Output", "Mean", "LayerNorm");
+    OP_INOUT_CHECK(ctx->HasOutput("Variance"), "Output", "Variance",
+                   "LayerNorm");
 
     auto x_dim = ctx->GetInputDim("X");
     auto begin_norm_axis = ctx->Attrs().Get<int>("begin_norm_axis");
-    PADDLE_ENFORCE_LT(begin_norm_axis, x_dim.size(),
-                      "'begin_norm_axis' must be less than the rank of X.");
+    PADDLE_ENFORCE_LT(
+        begin_norm_axis, x_dim.size(),
+        platform::errors::InvalidArgument(
+            "'begin_norm_axis' must be less than the dimensions of X, "
+            "but received 'begin_norm_axis' is [%d], "
+            "the dimensions of X is [%d].",
+            begin_norm_axis, x_dim.size()));
 
     auto matrix_dim = framework::flatten_to_2d(x_dim, begin_norm_axis);
     int left = static_cast<int>(matrix_dim[0]);
     int right = static_cast<int>(matrix_dim[1]);
     if (ctx->HasInput("Scale")) {
-      PADDLE_ENFORCE_EQ(ctx->GetInputDim("Scale").size(), 1);
+      PADDLE_ENFORCE_EQ(ctx->GetInputDim("Scale").size(), 1,
+                        platform::errors::InvalidArgument(
+                            "The dimensions of Input(Scale) must be 1, but "
+                            "received dimensions of "
+                            "Input(Scale) is [%d].",
+                            ctx->GetInputDim("Scale").size()));
 
       if (ctx->IsRuntime()) {
-        PADDLE_ENFORCE_EQ(ctx->GetInputDim("Scale")[0], right,
-                          "scale should with right");
+        PADDLE_ENFORCE_EQ(
+            ctx->GetInputDim("Scale")[0], right,
+            platform::errors::InvalidArgument(
+                "The first dimension value of Input(Scale) must equal to the "
+                "second dimension value of the flattened 2D matrix of Input(X), "
+                "but received the first dimension value of Input(Scale) is "
+                "[%d], the second dimension value of the flattened 2D matrix of "
+                "Input(X) is [%d].",
+                ctx->GetInputDim("Scale")[0], right));
       }
     }
     if (ctx->HasInput("Bias")) {
-      PADDLE_ENFORCE_EQ(ctx->GetInputDim("Bias").size(), 1);
+      PADDLE_ENFORCE_EQ(ctx->GetInputDim("Bias").size(), 1,
+                        platform::errors::InvalidArgument(
+                            "The dimensions of Input(Bias) must be 1, but "
+                            "received dimensions of "
+                            "Input(Bias) is [%d].",
+                            ctx->GetInputDim("Bias").size()));
       if (ctx->IsRuntime()) {
-        PADDLE_ENFORCE_EQ(ctx->GetInputDim("Bias")[0], right,
-                          "bias should with right");
+        PADDLE_ENFORCE_EQ(
+            ctx->GetInputDim("Bias")[0], right,
+            platform::errors::InvalidArgument(
+                "The first dimension value of Input(Bias) must equal to the "
+                "second dimension value of the flattened 2D matrix of Input(X), "
+                "but received the first dimension value of Input(Bias) is "
+                "[%d], the second dimension value of the flattened 2D matrix of "
+                "Input(X) is [%d].",
+                ctx->GetInputDim("Bias")[0], right));
       }
     }
@@ -90,8 +116,11 @@ class LayerNormOpMaker : public framework::OpProtoAndCheckerMaker {
                    "Constant for numerical stability [default 1e-5].")
         .SetDefault(1e-5)
         .AddCustomChecker([](const float &epsilon) {
-          PADDLE_ENFORCE(epsilon >= 0.0f && epsilon <= 0.001f,
-                         "'epsilon' should be between 0.0 and 0.001.");
+          PADDLE_ENFORCE_EQ(epsilon >= 0.0f && epsilon <= 0.001f, true,
+                            platform::errors::InvalidArgument(
+                                "'epsilon' in Op(LayerNorm) should be between "
+                                "0.0 and 0.001, but received [%s].",
+                                epsilon));
         });
     AddAttr<int>("begin_norm_axis",
                  "the axis of `begin_norm_axis ... Rank(X) - 1` will be "
@@ -100,7 +129,10 @@ class LayerNormOpMaker : public framework::OpProtoAndCheckerMaker {
         .SetDefault(1)
         .AddCustomChecker([](const int &begin_norm_axis) {
           PADDLE_ENFORCE_GT(begin_norm_axis, 0,
-                            "'begin_norm_axis' should be greater than zero.");
+                            platform::errors::InvalidArgument(
+                                "'begin_norm_axis' in Op(LayerNorm) should be "
+                                "greater than zero. But received [%d].",
+                                begin_norm_axis));
         });
 
     AddComment(R"DOC(
@@ -122,14 +154,12 @@ class LayerNormGradOp : public framework::OperatorWithKernel {
 
   void InferShape(framework::InferShapeContext *ctx) const override {
     // check input
-    PADDLE_ENFORCE(ctx->HasInput("X"),
-                   "Input(X) of LayerNormOp should not be null.");
-    PADDLE_ENFORCE(ctx->HasInput("Mean"),
-                   "Input(Mean) of LayerNormOp should not be null.");
-    PADDLE_ENFORCE(ctx->HasInput("Variance"),
-                   "Input(Variance) of LayerNormOp should not be null.");
-    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Y")),
-                   "Input(Y@GRAD) of LayerNormOp should not be null.");
+    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "LayerNormGrad");
+    OP_INOUT_CHECK(ctx->HasInput("Mean"), "Input", "Mean", "LayerNormGrad");
+    OP_INOUT_CHECK(ctx->HasInput("Variance"), "Input", "Variance",
+                   "LayerNormGrad");
+    OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Y")), "Input",
+                   framework::GradVarName("Y"), "LayerNormGrad");
 
     // check output
    if (ctx->HasOutput(framework::GradVarName("X"))) {
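
For reference, the convention these hunks adopt: OP_INOUT_CHECK standardizes the
null checks on an op's declared inputs/outputs, and the comparison macros
(PADDLE_ENFORCE_EQ/_LT/_GT) now take a platform::errors payload whose message is
format-expanded with the offending values. Below is a minimal sketch of the call
shape only, assuming Paddle's enforce/error headers; the two dim values are
hypothetical stand-ins, not taken from this diff.

    // Sketch of the new error-reporting call shape (hypothetical values).
    // In the real op these come from ctx->GetInputDim("Scale")[0] and the
    // second dimension of the flattened 2D matrix of Input(X).
    int scale_dim0 = 64;   // first dimension of Scale (hypothetical)
    int right = 128;       // second dimension of flattened X (hypothetical)
    PADDLE_ENFORCE_EQ(
        scale_dim0, right,
        platform::errors::InvalidArgument(
            "The first dimension value of Input(Scale) must equal to the "
            "second dimension value of the flattened 2D matrix of Input(X), "
            "but received [%d] vs [%d].",
            scale_dim0, right));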