|
|
|
@ -26,19 +26,10 @@ class SpectralNormOp : public framework::OperatorWithKernel {
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
void InferShape(framework::InferShapeContext* ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
ctx->HasInput("Weight"), true,
|
|
|
|
|
platform::errors::NotFound(
|
|
|
|
|
"Input(Weight) of SpectralNormOp should not be null."));
|
|
|
|
|
PADDLE_ENFORCE_EQ(ctx->HasInput("U"), true,
|
|
|
|
|
platform::errors::NotFound(
|
|
|
|
|
"Input(U) of SpectralNormOp should not be null."));
|
|
|
|
|
PADDLE_ENFORCE_EQ(ctx->HasInput("V"), true,
|
|
|
|
|
platform::errors::NotFound(
|
|
|
|
|
"Input(V) of SpectralNormOp should not be null."));
|
|
|
|
|
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
|
|
|
|
|
platform::errors::NotFound(
|
|
|
|
|
"Output(Out) of SpectralNormOp should not be null."));
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("Weight"), "Input", "Weight", "SpectralNorm");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("U"), "Input", "U", "SpectralNorm");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("V"), "Input", "V", "SpectralNorm");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "SpectralNorm");
|
|
|
|
|
|
|
|
|
|
auto dim_weight = ctx->GetInputDim("Weight");
|
|
|
|
|
auto rank_weight = dim_weight.size();
|
|
|
|
@ -220,15 +211,13 @@ class SpectralNormOpGrad : public framework::OperatorWithKernel {
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
void InferShape(framework::InferShapeContext* ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
ctx->HasInput("Weight"), true,
|
|
|
|
|
platform::errors::NotFound("Input(Weight) should not be null"));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
ctx->HasInput("U"), true,
|
|
|
|
|
platform::errors::NotFound("Input(U) should not be null"));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
ctx->HasInput("V"), true,
|
|
|
|
|
platform::errors::NotFound("Input(V) should not be null"));
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("Weight"), "Input", "Weight",
|
|
|
|
|
"SpectralNormGrad");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("U"), "Input", "U", "SpectralNormGrad");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("V"), "Input", "V", "SpectralNormGrad");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
|
|
|
|
|
"Out@GRAD", "SpectralNormGrad");
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
ctx->HasInput(framework::GradVarName("Out")), true,
|
|
|
|
|
platform::errors::NotFound("Input(Out@GRAD) should not be null"));
|
|
|
|
|