|
|
|
@ -23,27 +23,30 @@ namespace paddle {
|
|
|
|
|
namespace operators {
|
|
|
|
|
|
|
|
|
|
void InstanceNormOp::InferShape(framework::InferShapeContext *ctx) const {
|
|
|
|
|
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
|
|
|
|
|
"Input(X) of Instance Norm Op should not be null.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(ctx->HasInput("Scale"), true,
|
|
|
|
|
"Input(Scale) of Instance Norm Op should not be null.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(ctx->HasInput("Bias"), true,
|
|
|
|
|
"Input(Bias) of Instance Norm Op should not be null.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(ctx->HasOutput("Y"), true,
|
|
|
|
|
"Output(Y) of Instance Norm Op should not be null.");
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
ctx->HasOutput("SavedMean"), true,
|
|
|
|
|
"Output(SavedMean) of Instance Norm Op should not be null.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
ctx->HasOutput("SavedVariance"), true,
|
|
|
|
|
"Output(SavedVariance) of Instance Norm Op should not be null.");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "InstanceNorm");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("Scale"), "Input", "Scale", "InstanceNorm");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("Bias"), "Input", "Bias", "InstanceNorm");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasOutput("Y"), "Output", "Y", "InstanceNorm");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasOutput("SavedMean"), "Output", "SavedMean",
|
|
|
|
|
"InstanceNorm");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasOutput("SavedVariance"), "Output", "SavedVariance",
|
|
|
|
|
"InstanceNorm");
|
|
|
|
|
|
|
|
|
|
const auto x_dims = ctx->GetInputDim("X");
|
|
|
|
|
PADDLE_ENFORCE_GE(x_dims.size(), 2,
|
|
|
|
|
"the dimension of input X must greater than or equal to 2");
|
|
|
|
|
PADDLE_ENFORCE_LE(x_dims.size(), 5,
|
|
|
|
|
"the dimension of input X must smaller than or equal to 5");
|
|
|
|
|
PADDLE_ENFORCE_GE(
|
|
|
|
|
x_dims.size(), 2,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"ShapeError: the dimension of input X must "
|
|
|
|
|
"greater than or equal to 2. But received: the shape of input "
|
|
|
|
|
"X = [%s], the dimension of input X =[%d]",
|
|
|
|
|
x_dims, x_dims.size()));
|
|
|
|
|
PADDLE_ENFORCE_LE(
|
|
|
|
|
x_dims.size(), 5,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"ShapeError: the dimension of input X must "
|
|
|
|
|
"smaller than or equal to 5, But received: the shape of input "
|
|
|
|
|
"X = [%s], the dimension of input X = [%d]",
|
|
|
|
|
x_dims, x_dims.size()));
|
|
|
|
|
auto N = x_dims[0];
|
|
|
|
|
auto C = x_dims[1];
|
|
|
|
|
auto NxC = N * C;
|
|
|
|
@ -51,15 +54,34 @@ void InstanceNormOp::InferShape(framework::InferShapeContext *ctx) const {
|
|
|
|
|
auto scale_dim = ctx->GetInputDim("Scale");
|
|
|
|
|
auto bias_dim = ctx->GetInputDim("Bias");
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(scale_dim.size(), 1UL);
|
|
|
|
|
PADDLE_ENFORCE_EQ(bias_dim.size(), 1UL);
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
scale_dim.size(), 1UL,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"ShapeError: the dimension of scale must equal to 1."
|
|
|
|
|
"But received: the shape of scale is [%s], the dimension "
|
|
|
|
|
"of scale is [%d]",
|
|
|
|
|
scale_dim, scale_dim.size()));
|
|
|
|
|
PADDLE_ENFORCE_EQ(bias_dim.size(), 1UL,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"ShapeError: the dimension of bias must equal to 1."
|
|
|
|
|
"But received: the shape of bias is [%s],the dimension "
|
|
|
|
|
"of bias is [%d]",
|
|
|
|
|
bias_dim, bias_dim.size()));
|
|
|
|
|
|
|
|
|
|
bool check = !((!ctx->IsRuntime()) && (framework::product(scale_dim) <= 0 ||
|
|
|
|
|
framework::product(bias_dim) <= 0));
|
|
|
|
|
|
|
|
|
|
if (check) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(scale_dim[0], C);
|
|
|
|
|
PADDLE_ENFORCE_EQ(bias_dim[0], C);
|
|
|
|
|
PADDLE_ENFORCE_EQ(scale_dim[0], C,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"ShapeError: the shape of scale must equal to [%d]"
|
|
|
|
|
"But received: the shape of scale is [%d]",
|
|
|
|
|
C, scale_dim[0]));
|
|
|
|
|
PADDLE_ENFORCE_EQ(bias_dim[0], C,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"ShapeError: the shape of bias must equal to [%d]"
|
|
|
|
|
"But received: the shape of bias is [%d]",
|
|
|
|
|
C, bias_dim[0]));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
ctx->SetOutputDim("Y", x_dims);
|
|
|
|
@ -78,10 +100,12 @@ framework::OpKernelType InstanceNormOp::GetExpectedKernelType(
|
|
|
|
|
if (input_data_type == framework::proto::VarType::FP64) {
|
|
|
|
|
in_param_type = framework::proto::VarType::FP64;
|
|
|
|
|
}
|
|
|
|
|
PADDLE_ENFORCE_EQ(in_param_type, ctx.Input<Tensor>("Scale")->type(),
|
|
|
|
|
"Scale input should be of float type");
|
|
|
|
|
PADDLE_ENFORCE_EQ(in_param_type, ctx.Input<Tensor>("Bias")->type(),
|
|
|
|
|
"Bias input should be of float type");
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
in_param_type, ctx.Input<Tensor>("Scale")->type(),
|
|
|
|
|
platform::errors::InvalidArgument("Scale input should be of float type"));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
in_param_type, ctx.Input<Tensor>("Bias")->type(),
|
|
|
|
|
platform::errors::InvalidArgument("Bias input should be of float type"));
|
|
|
|
|
|
|
|
|
|
return framework::OpKernelType(input_data_type, ctx.GetPlace());
|
|
|
|
|
}
|
|
|
|
@ -91,7 +115,8 @@ void InstanceNormOpMaker::Make() {
|
|
|
|
|
.SetDefault(1e-5)
|
|
|
|
|
.AddCustomChecker([](const float &epsilon) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(epsilon >= 0.0f && epsilon <= 0.001f, true,
|
|
|
|
|
"'epsilon' should be between 0.0 and 0.001.");
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"'epsilon' should be between 0.0 and 0.001."));
|
|
|
|
|
});
|
|
|
|
|
AddInput("X", "The input tensor");
|
|
|
|
|
AddInput("Scale",
|
|
|
|
@ -193,24 +218,21 @@ class InstanceNormKernel<platform::CPUDeviceContext, T>
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
void InstanceNormGradOp::InferShape(framework::InferShapeContext *ctx) const {
|
|
|
|
|
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true, "Input(X) should not be null");
|
|
|
|
|
PADDLE_ENFORCE_EQ(ctx->HasInput("Scale"), true,
|
|
|
|
|
"Input(scale) should not be null");
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(ctx->HasInput(framework::GradVarName("Y")), true,
|
|
|
|
|
"Input(Y@GRAD) should not be null");
|
|
|
|
|
PADDLE_ENFORCE_EQ(ctx->HasInput("SavedMean"), true,
|
|
|
|
|
"Input(SavedMean) should not be null");
|
|
|
|
|
PADDLE_ENFORCE_EQ(ctx->HasInput("SavedVariance"), true,
|
|
|
|
|
"Input(SavedVariance) should not be null");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "InstanceNormGrad");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("Scale"), "Input", "Scale", "InstanceNormGrad");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Y")), "Input",
|
|
|
|
|
framework::GradVarName("Y"), "InstanceNormGrad");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("SavedMean"), "Input", "SavedMean",
|
|
|
|
|
"InstanceNormGrad");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("SavedVariance"), "Input", "SavedVariance",
|
|
|
|
|
"InstanceNormGrad");
|
|
|
|
|
|
|
|
|
|
// check output
|
|
|
|
|
PADDLE_ENFORCE_EQ(ctx->HasOutput(framework::GradVarName("X")), true,
|
|
|
|
|
"Output(x@GRAD) should not be null");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("X")), "Output",
|
|
|
|
|
framework::GradVarName("X"), "InstanceNormGrad");
|
|
|
|
|
if (ctx->HasOutput(framework::GradVarName("Scale"))) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(ctx->HasOutput(framework::GradVarName("Bias")), true,
|
|
|
|
|
"Output(Scale@GRAD) and Output(Bias@GRAD) should not be "
|
|
|
|
|
"null at the same time");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("Bias")), "Output",
|
|
|
|
|
framework::GradVarName("Bias"), "InstanceNormGrad");
|
|
|
|
|
}
|
|
|
|
|
const auto x_dims = ctx->GetInputDim("X");
|
|
|
|
|
const int C = x_dims[1];
|
|
|
|
@ -333,21 +355,20 @@ class InstanceNormGradKernel<platform::CPUDeviceContext, T>
|
|
|
|
|
|
|
|
|
|
void InstanceNormDoubleGradOp::InferShape(
|
|
|
|
|
framework::InferShapeContext *ctx) const {
|
|
|
|
|
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true, "Input(X) should not be null");
|
|
|
|
|
PADDLE_ENFORCE_EQ(ctx->HasInput("Scale"), true,
|
|
|
|
|
"Input(Scale) should not be null.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(ctx->HasInput("SavedMean"), true,
|
|
|
|
|
"Input(SavedMean) should not be null");
|
|
|
|
|
PADDLE_ENFORCE_EQ(ctx->HasInput("SavedVariance"), true,
|
|
|
|
|
"Input(SavedVariance) should not be null");
|
|
|
|
|
PADDLE_ENFORCE_EQ(ctx->HasInput("DDX"), true,
|
|
|
|
|
"Input(DDX) should not be null.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(ctx->HasInput("DY"), true,
|
|
|
|
|
"Input(Y@GRAD) should not be null");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "InstanceNormDoubleGrad");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("Scale"), "Input", "Scale",
|
|
|
|
|
"InstanceNormDoubleGrad");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("SavedMean"), "Input", "SavedMean",
|
|
|
|
|
"InstanceNormDoubleGrad");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("SavedVariance"), "Input", "SavedVariance",
|
|
|
|
|
"InstanceNormDoubleGrad");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("DDX"), "Input", "DDX",
|
|
|
|
|
"InstanceNormDoubleGrad");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("DY"), "Input", "DY", "InstanceNormDoubleGrad");
|
|
|
|
|
|
|
|
|
|
// check output
|
|
|
|
|
PADDLE_ENFORCE_EQ(ctx->HasOutput("DX"), true,
|
|
|
|
|
"Output(DX) should not be null");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasOutput("DX"), "Output", "DX",
|
|
|
|
|
"InstanceNormDoubleGrad");
|
|
|
|
|
|
|
|
|
|
const auto x_dims = ctx->GetInputDim("X");
|
|
|
|
|
const int C = x_dims[1];
|
|
|
|
|