|
|
|
@ -21,13 +21,13 @@ class ShuffleChannelOp : public framework::OperatorWithKernel {
|
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
|
|
|
|
|
|
void InferShape(framework::InferShapeContext* ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("X"),
|
|
|
|
|
"Input(X) of ShuffleChannelOp should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("Out"),
|
|
|
|
|
"Output(Out) of ShuffleChannelOp should not be null.");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "ShuffleChannelOp");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "ShuffleChannelOp");
|
|
|
|
|
|
|
|
|
|
auto input_dims = ctx->GetInputDim("X");
|
|
|
|
|
PADDLE_ENFORCE(input_dims.size() == 4, "The layout of input is NCHW.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
input_dims.size(), 4,
|
|
|
|
|
platform::errors::InvalidArgument("The layout of input is NCHW."));
|
|
|
|
|
|
|
|
|
|
ctx->SetOutputDim("Out", input_dims);
|
|
|
|
|
}
|
|
|
|
@ -53,7 +53,8 @@ class ShuffleChannelOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
AddAttr<int>("group", "the number of groups.")
|
|
|
|
|
.SetDefault(1)
|
|
|
|
|
.AddCustomChecker([](const int& group) {
|
|
|
|
|
PADDLE_ENFORCE_GE(group, 1, "group should be larger than 0.");
|
|
|
|
|
PADDLE_ENFORCE_GE(group, 1, platform::errors::InvalidArgument(
|
|
|
|
|
"group should be larger than 0."));
|
|
|
|
|
});
|
|
|
|
|
|
|
|
|
|
AddComment(R"DOC(
|
|
|
|
@ -76,7 +77,9 @@ class ShuffleChannelGradOp : public framework::OperatorWithKernel {
|
|
|
|
|
|
|
|
|
|
void InferShape(framework::InferShapeContext* ctx) const override {
|
|
|
|
|
auto input_dims = ctx->GetInputDim(framework::GradVarName("Out"));
|
|
|
|
|
PADDLE_ENFORCE(input_dims.size() == 4, "The layout of input is NCHW.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
input_dims.size(), 4,
|
|
|
|
|
platform::errors::InvalidArgument("The layout of input is NCHW."));
|
|
|
|
|
|
|
|
|
|
ctx->SetOutputDim(framework::GradVarName("X"), input_dims);
|
|
|
|
|
}
|
|
|
|
|