Update OP_INOUT_CHECK (#23757)

* update NotFound -> OP_INOUT_CHECK: grid_sampler, kldiv_loss, spectral_norm, temporal_shift. test=develop
revert-23830-2.0-beta
Kaipeng Deng 5 years ago committed by GitHub
parent 9e85d02373
commit 63232e4941
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -28,16 +28,9 @@ class GridSampleOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
platform::errors::NotFound(
"Input(X) of GridSampleOp should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasInput("Grid"), true,
platform::errors::NotFound(
"Input(Grid) of GridSampleOp should not be null."));
PADDLE_ENFORCE_EQ(
ctx->HasOutput("Output"), true,
platform::errors::NotFound(
"Output(Output) of GridSampleOp should not be null."));
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "GridSampler");
OP_INOUT_CHECK(ctx->HasInput("Grid"), "Input", "Grid", "GridSampler");
OP_INOUT_CHECK(ctx->HasOutput("Output"), "Output", "Output", "GridSampler");
auto x_dims = ctx->GetInputDim("X");
auto grid_dims = ctx->GetInputDim("Grid");

@ -23,15 +23,9 @@ class KLDivLossOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
platform::errors::NotFound(
"Input(X) of KLDivLossOp should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasInput("Target"), true,
platform::errors::NotFound(
"Input(Target) of KLDivLossOp should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasOutput("Loss"), true,
platform::errors::NotFound(
"Output(Loss) of KLDivLossOp should not be null."));
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "KLDivLoss");
OP_INOUT_CHECK(ctx->HasInput("Target"), "Input", "Target", "KLDivLoss");
OP_INOUT_CHECK(ctx->HasOutput("Loss"), "Output", "Loss", "KLDivLoss");
auto dim_x = ctx->GetInputDim("X");
auto dim_target = ctx->GetInputDim("Target");
@ -135,15 +129,10 @@ class KLDivLossOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(
ctx->HasInput("X"), true,
platform::errors::NotFound("Input(X) should not be null"));
PADDLE_ENFORCE_EQ(
ctx->HasInput("Target"), true,
platform::errors::NotFound("Input(Target) should not be null"));
PADDLE_ENFORCE_EQ(
ctx->HasInput(framework::GradVarName("Loss")), true,
platform::errors::NotFound("Input(Loss@GRAD) should not be null"));
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "KLDivLossGrad");
OP_INOUT_CHECK(ctx->HasInput("Target"), "Input", "Target", "KLDivLossGrad");
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Loss")), "Input",
"Loss@GRAD", "KLDivLossGrad");
auto dim_x = ctx->GetInputDim("X");
if (ctx->HasOutput(framework::GradVarName("X"))) {
ctx->SetOutputDim(framework::GradVarName("X"), dim_x);

@ -26,19 +26,10 @@ class SpectralNormOp : public framework::OperatorWithKernel {
protected:
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(
ctx->HasInput("Weight"), true,
platform::errors::NotFound(
"Input(Weight) of SpectralNormOp should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasInput("U"), true,
platform::errors::NotFound(
"Input(U) of SpectralNormOp should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasInput("V"), true,
platform::errors::NotFound(
"Input(V) of SpectralNormOp should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
platform::errors::NotFound(
"Output(Out) of SpectralNormOp should not be null."));
OP_INOUT_CHECK(ctx->HasInput("Weight"), "Input", "Weight", "SpectralNorm");
OP_INOUT_CHECK(ctx->HasInput("U"), "Input", "U", "SpectralNorm");
OP_INOUT_CHECK(ctx->HasInput("V"), "Input", "V", "SpectralNorm");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "SpectralNorm");
auto dim_weight = ctx->GetInputDim("Weight");
auto rank_weight = dim_weight.size();
@ -220,15 +211,13 @@ class SpectralNormOpGrad : public framework::OperatorWithKernel {
protected:
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(
ctx->HasInput("Weight"), true,
platform::errors::NotFound("Input(Weight) should not be null"));
PADDLE_ENFORCE_EQ(
ctx->HasInput("U"), true,
platform::errors::NotFound("Input(U) should not be null"));
PADDLE_ENFORCE_EQ(
ctx->HasInput("V"), true,
platform::errors::NotFound("Input(V) should not be null"));
OP_INOUT_CHECK(ctx->HasInput("Weight"), "Input", "Weight",
"SpectralNormGrad");
OP_INOUT_CHECK(ctx->HasInput("U"), "Input", "U", "SpectralNormGrad");
OP_INOUT_CHECK(ctx->HasInput("V"), "Input", "V", "SpectralNormGrad");
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
"Out@GRAD", "SpectralNormGrad");
PADDLE_ENFORCE_EQ(
ctx->HasInput(framework::GradVarName("Out")), true,
platform::errors::NotFound("Input(Out@GRAD) should not be null"));

@ -26,13 +26,8 @@ class TemporalShiftOp : public framework::OperatorWithKernel {
protected:
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
platform::errors::NotFound(
"Input(X) of TemporalShiftOp should not be null."));
PADDLE_ENFORCE_EQ(
ctx->HasOutput("Out"), true,
platform::errors::NotFound(
"Output(Out) of TemporalShiftOp should not be null."));
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "SpectralNorm");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "SpectralNorm");
auto dim_x = ctx->GetInputDim("X");
PADDLE_ENFORCE_EQ(dim_x.size(), 4,

Loading…
Cancel
Save