|
|
|
@ -26,26 +26,45 @@ class SpectralNormOp : public framework::OperatorWithKernel {
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
void InferShape(framework::InferShapeContext* ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("Weight"),
|
|
|
|
|
"Input(Weight) of SpectralNormOp should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("U"),
|
|
|
|
|
"Input(U) of SpectralNormOp should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("V"),
|
|
|
|
|
"Input(V) of SpectralNormOp should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("Out"),
|
|
|
|
|
"Output(Out) of SpectralNormOp should not be null.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
ctx->HasInput("Weight"), true,
|
|
|
|
|
platform::errors::NotFound(
|
|
|
|
|
"Input(Weight) of SpectralNormOp should not be null."));
|
|
|
|
|
PADDLE_ENFORCE_EQ(ctx->HasInput("U"), true,
|
|
|
|
|
platform::errors::NotFound(
|
|
|
|
|
"Input(U) of SpectralNormOp should not be null."));
|
|
|
|
|
PADDLE_ENFORCE_EQ(ctx->HasInput("V"), true,
|
|
|
|
|
platform::errors::NotFound(
|
|
|
|
|
"Input(V) of SpectralNormOp should not be null."));
|
|
|
|
|
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
|
|
|
|
|
platform::errors::NotFound(
|
|
|
|
|
"Output(Out) of SpectralNormOp should not be null."));
|
|
|
|
|
|
|
|
|
|
auto dim_weight = ctx->GetInputDim("Weight");
|
|
|
|
|
auto rank_weight = dim_weight.size();
|
|
|
|
|
PADDLE_ENFORCE(rank_weight >= 2 && rank_weight <= 5,
|
|
|
|
|
"The rank of Input(Weights) can only be 2, 3,"
|
|
|
|
|
"4, 5 for fc, conv1d, conv2d, conv3d layers.");
|
|
|
|
|
PADDLE_ENFORCE_GE(rank_weight, 2,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The rank of Input(Weights) should be greater equal "
|
|
|
|
|
"than 2, but received Weight rank(%d)",
|
|
|
|
|
rank_weight));
|
|
|
|
|
PADDLE_ENFORCE_LE(rank_weight, 5,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The rank of Input(Weights) should be less equal "
|
|
|
|
|
"than 5, but received Weight rank(%d)",
|
|
|
|
|
rank_weight));
|
|
|
|
|
|
|
|
|
|
int dim = ctx->Attrs().Get<int>("dim");
|
|
|
|
|
int power_iters = ctx->Attrs().Get<int>("power_iters");
|
|
|
|
|
PADDLE_ENFORCE(dim == 0 || dim == 1, "Attr(dim) can only be 0 or 1");
|
|
|
|
|
PADDLE_ENFORCE(power_iters >= 0,
|
|
|
|
|
"Attr(power_iters) should be larger equal then 0");
|
|
|
|
|
auto dim_valid = dim == 0 || dim == 1;
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
dim_valid, true,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Attr(dim) can only be 0 or 1, but received %d", dim));
|
|
|
|
|
PADDLE_ENFORCE_GE(
|
|
|
|
|
power_iters, 0,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Attr(power_iters) should be greater equal then 0, but received %d",
|
|
|
|
|
power_iters));
|
|
|
|
|
|
|
|
|
|
int h = dim_weight[dim];
|
|
|
|
|
int w = 1;
|
|
|
|
@ -59,15 +78,22 @@ class SpectralNormOp : public framework::OperatorWithKernel {
|
|
|
|
|
|
|
|
|
|
if (ctx->IsRuntime() || (dim_u[0] > 0 && h > 0)) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(dim_u[0], h,
|
|
|
|
|
"Input(U) dims[0] should be equal to "
|
|
|
|
|
"Input(Weight) dims[Attr(dim)]");
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Input(U) dimension[0] should be equal to "
|
|
|
|
|
"Input(Weight) dimension[Attr(dim)], but received "
|
|
|
|
|
"U dimension[0](%d) != Weight dimension[%d](%d)",
|
|
|
|
|
dim_u[0], dim, h));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (ctx->IsRuntime() || (dim_v[0] > 0 && w > 0)) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
dim_v[0], w,
|
|
|
|
|
"Input(V) dims[0] should be equal to "
|
|
|
|
|
"the product of Input(Weight) dims except dims[Attr(dim)]");
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Input(V) dimension[0] should be equal to the product of "
|
|
|
|
|
"Input(Weight) dimension except dimension[Attr(dim)], but "
|
|
|
|
|
"received V dimension[0](%d) != product of Input(Weight) "
|
|
|
|
|
"dimension(%d)",
|
|
|
|
|
dim_v[0], w));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
ctx->SetOutputDim("Out", dim_weight);
|
|
|
|
@ -194,11 +220,18 @@ class SpectralNormOpGrad : public framework::OperatorWithKernel {
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
void InferShape(framework::InferShapeContext* ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("Weight"), "Input(Weight) should not be null");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("U"), "Input(U) should not be null");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("V"), "Input(V) should not be null");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
|
|
|
|
|
"Input(Out@GRAD) should not be null");
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
ctx->HasInput("Weight"), true,
|
|
|
|
|
platform::errors::NotFound("Input(Weight) should not be null"));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
ctx->HasInput("U"), true,
|
|
|
|
|
platform::errors::NotFound("Input(U) should not be null"));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
ctx->HasInput("V"), true,
|
|
|
|
|
platform::errors::NotFound("Input(V) should not be null"));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
ctx->HasInput(framework::GradVarName("Out")), true,
|
|
|
|
|
platform::errors::NotFound("Input(Out@GRAD) should not be null"));
|
|
|
|
|
auto dim_x = ctx->GetInputDim("Weight");
|
|
|
|
|
if (ctx->HasOutput(framework::GradVarName("Weight"))) {
|
|
|
|
|
ctx->SetOutputDim(framework::GradVarName("Weight"), dim_x);
|
|
|
|
|