[2.0RC]refine error message in shuffle channel OP (#27505)

* refine err msg in shuffle channel op
revert-27520-disable_pr
ruri 4 years ago committed by GitHub
parent 4236367401
commit e1fb77d123
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -21,13 +21,13 @@ class ShuffleChannelOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of ShuffleChannelOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of ShuffleChannelOp should not be null.");
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "ShuffleChannelOp");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "ShuffleChannelOp");
auto input_dims = ctx->GetInputDim("X");
PADDLE_ENFORCE(input_dims.size() == 4, "The layout of input is NCHW.");
PADDLE_ENFORCE_EQ(
input_dims.size(), 4,
platform::errors::InvalidArgument("The layout of input is NCHW."));
ctx->SetOutputDim("Out", input_dims);
}
@ -53,7 +53,8 @@ class ShuffleChannelOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<int>("group", "the number of groups.")
.SetDefault(1)
.AddCustomChecker([](const int& group) {
PADDLE_ENFORCE_GE(group, 1, "group should be larger than 0.");
PADDLE_ENFORCE_GE(group, 1, platform::errors::InvalidArgument(
"group should be larger than 0."));
});
AddComment(R"DOC(
@ -76,7 +77,9 @@ class ShuffleChannelGradOp : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext* ctx) const override {
auto input_dims = ctx->GetInputDim(framework::GradVarName("Out"));
PADDLE_ENFORCE(input_dims.size() == 4, "The layout of input is NCHW.");
PADDLE_ENFORCE_EQ(
input_dims.size(), 4,
platform::errors::InvalidArgument("The layout of input is NCHW."));
ctx->SetOutputDim(framework::GradVarName("X"), input_dims);
}

Loading…
Cancel
Save