@@ -44,13 +44,15 @@ class DataNormOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
   void InferShape(framework::InferShapeContext *ctx) const override {
-    PADDLE_ENFORCE(ctx->HasInput("X"), "");
-    PADDLE_ENFORCE(ctx->HasInput("BatchSize"), "");
-    PADDLE_ENFORCE(ctx->HasInput("BatchSum"), "");
-    PADDLE_ENFORCE(ctx->HasInput("BatchSquareSum"), "");
-    PADDLE_ENFORCE(ctx->HasOutput("Means"), "");
-    PADDLE_ENFORCE(ctx->HasOutput("Scales"), "");
-    PADDLE_ENFORCE(ctx->HasOutput("Y"), "");
+    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "DataNorm");
+    OP_INOUT_CHECK(ctx->HasInput("BatchSize"), "Input", "BatchSize",
+                   "DataNorm");
+    OP_INOUT_CHECK(ctx->HasInput("BatchSum"), "Input", "BatchSum", "DataNorm");
+    OP_INOUT_CHECK(ctx->HasInput("BatchSquareSum"), "Input", "BatchSquareSum",
+                   "DataNorm");
+    OP_INOUT_CHECK(ctx->HasOutput("Means"), "Output", "Means", "DataNorm");
+    OP_INOUT_CHECK(ctx->HasOutput("Scales"), "Output", "Scales", "DataNorm");
+    OP_INOUT_CHECK(ctx->HasOutput("Y"), "Output", "Y", "DataNorm");
     bool enable_scale_and_shift =
         ctx->Attrs().Get<bool>("enable_scale_and_shift");
     if (enable_scale_and_shift) {
@@ -67,20 +69,33 @@ class DataNormOp : public framework::OperatorWithKernel {
     const DataLayout data_layout = framework::StringToDataLayout(
         ctx->Attrs().Get<std::string>("data_layout"));
 
-    PADDLE_ENFORCE(x_dims.size() >= 2 && x_dims.size() <= 5,
-                   "Input X must have 2 to 5 dimensions.");
+    PADDLE_ENFORCE_EQ(x_dims.size() >= 2 && x_dims.size() <= 5, true,
+                      platform::errors::InvalidArgument(
+                          "Input X must have 2 to 5 dimensions."));
 
     const int64_t C =
         (data_layout == DataLayout::kNCHW ? x_dims[1]
                                           : x_dims[x_dims.size() - 1]);
 
-    PADDLE_ENFORCE_EQ(ctx->GetInputDim("BatchSize").size(), 1UL);
-    PADDLE_ENFORCE_EQ(ctx->GetInputDim("BatchSum").size(), 1UL);
-    PADDLE_ENFORCE_EQ(ctx->GetInputDim("BatchSquareSum").size(), 1UL);
+    PADDLE_ENFORCE_EQ(ctx->GetInputDim("BatchSize").size(), 1UL,
+                      platform::errors::InvalidArgument(
+                          "The input dim of BatchSize should be 1"));
+    PADDLE_ENFORCE_EQ(ctx->GetInputDim("BatchSum").size(), 1UL,
+                      platform::errors::InvalidArgument(
+                          "The input dim of BatchSum should be 1"));
+    PADDLE_ENFORCE_EQ(ctx->GetInputDim("BatchSquareSum").size(), 1UL,
+                      platform::errors::InvalidArgument(
+                          "The input dim of BatchSquareSum should be 1"));
     if (ctx->IsRuntime()) {
-      PADDLE_ENFORCE_EQ(ctx->GetInputDim("BatchSize")[0], C);
-      PADDLE_ENFORCE_EQ(ctx->GetInputDim("BatchSum")[0], C);
-      PADDLE_ENFORCE_EQ(ctx->GetInputDim("BatchSquareSum")[0], C);
+      PADDLE_ENFORCE_EQ(ctx->GetInputDim("BatchSize")[0], C,
+                        platform::errors::InvalidArgument(
+                            "The input dim[0] of BatchSize should be C"));
+      PADDLE_ENFORCE_EQ(ctx->GetInputDim("BatchSum")[0], C,
+                        platform::errors::InvalidArgument(
+                            "The input dim[0] of BatchSum should be C"));
+      PADDLE_ENFORCE_EQ(ctx->GetInputDim("BatchSquareSum")[0], C,
+                        platform::errors::InvalidArgument(
+                            "The input dim[0] of BatchSquareSum should be C"));
     }
 
     if (enable_scale_and_shift) {
@@ -141,13 +156,16 @@ class DataNormOp : public framework::OperatorWithKernel {
     }
     PADDLE_ENFORCE_EQ(dn_param_type,
                       OperatorWithKernel::IndicateVarDataType(ctx, "BatchSize"),
-                      "BatchSize input should be of float type");
+                      platform::errors::InvalidArgument(
+                          "BatchSize input should be of float type"));
     PADDLE_ENFORCE_EQ(dn_param_type,
                       OperatorWithKernel::IndicateVarDataType(ctx, "BatchSum"),
-                      "BatchSum input should be of float type");
+                      platform::errors::InvalidArgument(
+                          "BatchSum input should be of float type"));
     PADDLE_ENFORCE_EQ(dn_param_type, OperatorWithKernel::IndicateVarDataType(
                                          ctx, "BatchSquareSum"),
-                      "BatchSquareSum input should be of float type");
+                      platform::errors::InvalidArgument(
+                          "BatchSquareSum input should be of float type"));
 
     bool enable_scale_and_shift = ctx.Attr<bool>("enable_scale_and_shift");
     if (enable_scale_and_shift) {
@@ -183,8 +201,9 @@ class DataNormOpMaker : public framework::OpProtoAndCheckerMaker {
     AddAttr<float>("epsilon", "")
         .SetDefault(1e-4)
         .AddCustomChecker([](const float &epsilon) {
-          PADDLE_ENFORCE(epsilon >= 0.0f && epsilon <= 0.001f,
-                         "'epsilon' should be between 0.0 and 0.001.");
+          PADDLE_ENFORCE_EQ(epsilon >= 0.0f && epsilon <= 0.001f, true,
+                            platform::errors::InvalidArgument(
+                                "'epsilon' should be between 0.0 and 0.001."));
         });
     AddAttr<int>("slot_dim",
                  "(int, default -1) Dimension of one slot if set, "
@@ -256,7 +275,8 @@ class DataNormKernel<platform::CPUDeviceContext, T>
 
     const auto *x = ctx.Input<Tensor>("X");
     const auto &x_dims = x->dims();
-    PADDLE_ENFORCE(x_dims.size() == 2, "The Input dim size should be 2");
+    PADDLE_ENFORCE_EQ(x_dims.size(), 2, platform::errors::InvalidArgument(
+                                            "The Input dim size should be 2"));
     const int N = x_dims[0];
     const int C =
         (data_layout == DataLayout::kNCHW ? x_dims[1]
@@ -379,8 +399,9 @@ class DataNormGradOp : public framework::OperatorWithKernel {
 
   void InferShape(framework::InferShapeContext *ctx) const override {
     // check input
-    PADDLE_ENFORCE(ctx->HasInput("X"));
-    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Y")), "");
+    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "DataNormGrad");
+    OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Y")), "Input",
+                   framework::GradVarName("Y"), "DataNormGrad");
     PADDLE_ENFORCE_EQ(
         ctx->HasOutput("BatchSize"), true,
         platform::errors::NotFound(
@@ -393,15 +414,19 @@ class DataNormGradOp : public framework::OperatorWithKernel {
         ctx->HasOutput("BatchSquareSum"), true,
         platform::errors::NotFound(
             "Output(BatchSquareSum) of DataNormGradOp should not be null."));
-    PADDLE_ENFORCE(ctx->HasInput("Means"), "");
-    PADDLE_ENFORCE(ctx->HasInput("Scales"), "");
+    OP_INOUT_CHECK(ctx->HasInput("Means"), "Input", "Means", "DataNormGrad");
+    OP_INOUT_CHECK(ctx->HasInput("Scales"), "Input", "Scales", "DataNormGrad");
     bool enable_scale_and_shift =
         ctx->Attrs().Get<bool>("enable_scale_and_shift");
     // check output
-    PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("BatchSize")), "");
-    PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("BatchSum")), "");
-    PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("BatchSquareSum")),
-                   "");
+    OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("BatchSize")),
+                   "Output", framework::GradVarName("BatchSize"),
+                   "DataNormGrad");
+    OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("BatchSum")), "Output",
+                   framework::GradVarName("BatchSum"), "DataNormGrad");
+    OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("BatchSquareSum")),
+                   "Output", framework::GradVarName("BatchSquareSum"),
+                   "DataNormGrad");
 
     const auto x_dims = ctx->GetInputDim("X");
     const DataLayout data_layout = framework::StringToDataLayout(
@@ -486,7 +511,8 @@ class DataNormGradKernel<platform::CPUDeviceContext, T>
     // Get the size for each dimension.
     // NCHW [batch_size, in_channels, in_height, in_width]
    const auto &x_dims = x->dims();
-    PADDLE_ENFORCE(x_dims.size() == 2, "The Input dim size should be 2");
+    PADDLE_ENFORCE_EQ(x_dims.size(), 2, platform::errors::InvalidArgument(
+                                            "The Input dim size should be 2"));
     const int N = x_dims[0];
     const int C =
         (data_layout == DataLayout::kNCHW ? x_dims[1]
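For reference, a minimal sketch of the check style this patch migrates to: pairing PADDLE_ENFORCE_EQ with a typed platform::errors::InvalidArgument message instead of a bare check or an empty-string message. It assumes the paddle/fluid enforce headers of this era; the CheckBatchParamDim helper is hypothetical and exists only to illustrate the pattern, it is not part of the patch.

#include <string>

#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/platform/enforce.h"

namespace {

// Hypothetical helper: verify that a per-channel statistic tensor
// (e.g. BatchSize / BatchSum / BatchSquareSum) is 1-D with C elements,
// reporting a readable error instead of a bare assertion failure.
void CheckBatchParamDim(const paddle::framework::DDim &dim, int64_t C,
                        const std::string &name) {
  // Old style: PADDLE_ENFORCE_EQ(dim.size(), 1UL); -- fails without context.
  PADDLE_ENFORCE_EQ(dim.size(), 1UL,
                    paddle::platform::errors::InvalidArgument(
                        "The input dim of %s should be 1", name));
  PADDLE_ENFORCE_EQ(dim[0], C,
                    paddle::platform::errors::InvalidArgument(
                        "The input dim[0] of %s should be C", name));
}

}  // namespace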