|
|
@ -19,26 +19,27 @@ class ShuffleChannelOp : public framework::OperatorWithKernel {
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
|
|
|
|
|
|
|
|
void InferShape(framework::InferShapeContext* ctx) const override {
|
|
|
|
void InferShape(framework::InferShapeContext* ctx) const override {
|
|
|
|
PADDLE_ENFORCE(ctx - > HasInput("X"),
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("X"),
|
|
|
|
"Input(X) of ShuffleChannelOp should not be null.");
|
|
|
|
"Input(X) of ShuffleChannelOp should not be null.");
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("Out"),
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("Out"),
|
|
|
|
"Output(Out) of ShuffleChannelOp should not be null.");
|
|
|
|
"Output(Out) of ShuffleChannelOp should not be null.");
|
|
|
|
|
|
|
|
|
|
|
|
auto input_dims = ctx->GetInputDim("X");
|
|
|
|
auto input_dims = ctx->GetInputDim("X");
|
|
|
|
PADDLE_ENFORCE(input_dims.size() == 4, "The layout of input is NCHW.");
|
|
|
|
PADDLE_ENFORCE(input_dims.size() == 4, "The layout of input is NCHW.");
|
|
|
|
|
|
|
|
|
|
|
|
// ENFORCE group
|
|
|
|
// ENFORCE group
|
|
|
|
auto group = ctx->Attrs().Get<std::vector<int>>("group");
|
|
|
|
// auto group = ctx->Attrs().Get<int>("group");
|
|
|
|
ctx->SetOutputDim("Out", input_dims);
|
|
|
|
ctx->SetOutputDim("Out", input_dims);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
/*
|
|
|
|
protected:
|
|
|
|
protected:
|
|
|
|
framework::OpKernelType GetExpectedKernelType(
|
|
|
|
framework::OpKernelType GetExpectedKernelType(
|
|
|
|
const framework::ExecutionContext& ctx) const override {
|
|
|
|
const framework::ExecutionContext& ctx) const override {
|
|
|
|
return framework::OpKernelType(
|
|
|
|
return framework::OpKernelType(
|
|
|
|
framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
|
|
|
|
framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
|
|
|
|
ctx.GetPlace());
|
|
|
|
ctx.device_context());
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
*/
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
class ShuffleChannelOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
class ShuffleChannelOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
@ -63,7 +64,7 @@ class ShuffleChannelOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
then, feed each group in the next layer with different subgroups.
|
|
|
|
then, feed each group in the next layer with different subgroups.
|
|
|
|
|
|
|
|
|
|
|
|
According to the paper, "Suppose a convolution layer with g groups
|
|
|
|
According to the paper, "Suppose a convolution layer with g groups
|
|
|
|
whose output has g x n channels, first reshape the output channel dimension into(g,n),
|
|
|
|
whose output has g * n channels, first reshape the output channel dimension into(g,n),
|
|
|
|
transposing and then flattening it back as the input of next layer. "
|
|
|
|
transposing and then flattening it back as the input of next layer. "
|
|
|
|
|
|
|
|
|
|
|
|
Shuffle channel operation makes it possible to build more powerful structures
|
|
|
|
Shuffle channel operation makes it possible to build more powerful structures
|
|
|
@ -75,52 +76,49 @@ class ShuffleChannelOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
}
|
|
|
|
}
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
// Grad
|
|
|
|
class ShuffleChannelGradOp : public framework::OperatorWithKernel {
|
|
|
|
|
|
|
|
|
|
|
|
class ShuffleChannelOpGrad : public framework::OperatorWithKernel {
|
|
|
|
|
|
|
|
public:
|
|
|
|
public:
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
|
|
|
|
|
|
|
|
void InferShape(framework::InferShapeContext* ctx) const override {
|
|
|
|
void InferShape(framework::InferShapeContext* ctx) const override {
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
|
|
|
|
"Input(Out@Grad) should not be null")
|
|
|
|
"Input(Out@Grad) should not be null");
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
|
|
|
|
"Output(X@Grad) should not be null");
|
|
|
|
"Output(X@Grad) should not be null");
|
|
|
|
|
|
|
|
|
|
|
|
auto input_dims = ctx->GetInputDim("X");
|
|
|
|
auto input_dims = ctx->GetInputDim("X");
|
|
|
|
ctx->SetOutputDim(framework::GradVarName("X"), input_dims);
|
|
|
|
ctx->SetOutputDim(framework::GradVarName("X"), input_dims);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
/*
|
|
|
|
protected:
|
|
|
|
protected:
|
|
|
|
framework::OpKernelType GetExpectedKernelType(
|
|
|
|
framework::OpKernelType GetExpectedKernelType(
|
|
|
|
const framework::ExecutionContext& ctx) const override {
|
|
|
|
const framework::ExecutionContext& ctx) const override {
|
|
|
|
return framework::OpKernelType(
|
|
|
|
return framework::OpKernelType(
|
|
|
|
framework::ToDataType(
|
|
|
|
framework::ToDataType(
|
|
|
|
ctx.Input<framework::Tensor>(framework::GradVarName("Out"))
|
|
|
|
framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
|
|
|
|
->type()),
|
|
|
|
ctx.device_context());
|
|
|
|
ctx.device_context());
|
|
|
|
}
|
|
|
|
}
|
|
|
|
*/
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
} // namespace operators
|
|
|
|
} // namespace operators
|
|
|
|
} // namespace paddle
|
|
|
|
} // namespace paddle
|
|
|
|
|
|
|
|
|
|
|
|
// how to write gpu kernal
|
|
|
|
|
|
|
|
namespace ops = paddle::operators;
|
|
|
|
namespace ops = paddle::operators;
|
|
|
|
REGISTER_OPERATOR(shufflechannel, ops::ShuffleChannelOp,
|
|
|
|
REGISTER_OPERATOR(shuffle_channel, ops::ShuffleChannelOp,
|
|
|
|
ops::ShuffleChannelOpMaker,
|
|
|
|
ops::ShuffleChannelOpMaker,
|
|
|
|
paddle::framework::DefaultGradOpDescMaker<true>);
|
|
|
|
paddle::framework::DefaultGradOpDescMaker<true>);
|
|
|
|
// paddle::framework::EmptyGradOpMaker);
|
|
|
|
// paddle::framework::EmptyGradOpMaker);
|
|
|
|
|
|
|
|
|
|
|
|
REGISTER_OPERATOR(shufflechannel_grad, ops::ShuffleChannelGradOp);
|
|
|
|
REGISTER_OPERATOR(shuffle_channel_grad, ops::ShuffleChannelGradOp);
|
|
|
|
|
|
|
|
|
|
|
|
REGISTER_OP_CPU_KERNEL(
|
|
|
|
REGISTER_OP_CPU_KERNEL(
|
|
|
|
shufflechannel,
|
|
|
|
shuffle_channel,
|
|
|
|
ops::ShuffleChannelOpKernel<paddle::platform::CPUDeviceContext, float>,
|
|
|
|
ops::ShuffleChannelOpKernel<paddle::platform::CPUDeviceContext, float>,
|
|
|
|
ops::ShuffleChannelOpKernel<paddle::platform::CPUDeviceContext, double>);
|
|
|
|
ops::ShuffleChannelOpKernel<paddle::platform::CPUDeviceContext, double>);
|
|
|
|
|
|
|
|
|
|
|
|
REGISTER_OP_CPU_KERNEL(
|
|
|
|
REGISTER_OP_CPU_KERNEL(
|
|
|
|
shufflechannel_grad,
|
|
|
|
shuffle_channel_grad,
|
|
|
|
ops::ShuffleChannelGradOpKernel<paddle::platform::CPUDeviceContext, float>,
|
|
|
|
ops::ShuffleChannelGradOpKernel<paddle::platform::CPUDeviceContext, float>,
|
|
|
|
ops::ShuffleChannelGradOpKernel<paddle::platform::CPUDeviceContext,
|
|
|
|
ops::ShuffleChannelGradOpKernel<paddle::platform::CPUDeviceContext,
|
|
|
|
double>);
|
|
|
|
double>);
|
|
|
|