@@ -12,21 +12,21 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include "paddle/operators/block_expand_op.h"
+#include "paddle/operators/im2sequence_op.h"
 
 namespace paddle {
 namespace operators {
 
-class BlockExpandOp : public framework::OperatorWithKernel {
+class Im2SequenceOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
 
  protected:
   void InferShape(framework::InferShapeContext* ctx) const override {
     PADDLE_ENFORCE(ctx->HasInput("X"),
-                   "Input of BlockExpandOp should not be null.");
+                   "Input(X) of Im2SequenceOp should not be null.");
     PADDLE_ENFORCE(ctx->HasOutput("Out"),
-                   "Output of BlockExpandOp op should not be null.");
+                   "Output(Out) of Im2SequenceOp op should not be null.");
 
     auto in_dim = ctx->GetInputDim("X");
     PADDLE_ENFORCE_EQ(in_dim.size(), 4,
@@ -55,9 +55,9 @@ class BlockExpandOp : public framework::OperatorWithKernel {
   }
 };
 
-class BlockExpandOpMaker : public framework::OpProtoAndCheckerMaker {
+class Im2SequenceOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  BlockExpandOpMaker(OpProto* proto, OpAttrChecker* op_checker)
+  Im2SequenceOpMaker(OpProto* proto, OpAttrChecker* op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X",
              "(Tensor)The input tensor has NCHW format."
@@ -65,7 +65,7 @@ class BlockExpandOpMaker : public framework::OpProtoAndCheckerMaker {
              "C: channels"
              "H: height"
              "W: width");
-    AddOutput("Out", "(LodTensor)The output data of block_expand op,");
+    AddOutput("Out", "(LodTensor)The output data of im2sequence op,");
     AddAttr<int>("block_height", "(int)height of block.");
     AddAttr<int>("block_width", "(int)width of block.");
     AddAttr<int>("stride_height", "(int)height of stride.");
@@ -73,7 +73,7 @@ class BlockExpandOpMaker : public framework::OpProtoAndCheckerMaker {
     AddAttr<int>("padding_height", "(int)height of padding.");
     AddAttr<int>("padding_width", "(int)width of padding.");
     AddComment(R"DOC(
-Expand feature map to minibatch matrix.
+Convert feature map to minibatch matrix.
 - matirx height is: output_height * output_width
 - matrix width is: block_height * block_width * channels
 
@@ -133,7 +133,7 @@ output.lod = [[0, 4, 8]]
   }
 };
 
-class BlockExpandGradOp : public framework::OperatorWithKernel {
+class Im2SequenceGradOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
 
@@ -150,11 +150,11 @@ class BlockExpandGradOp : public framework::OperatorWithKernel {
 }  // namespace paddle
 
 namespace ops = paddle::operators;
-REGISTER_OP(block_expand, ops::BlockExpandOp, ops::BlockExpandOpMaker,
-            block_expand_grad, ops::BlockExpandGradOp);
+REGISTER_OP(im2sequence, ops::Im2SequenceOp, ops::Im2SequenceOpMaker,
+            im2sequence_grad, ops::Im2SequenceGradOp);
 REGISTER_OP_CPU_KERNEL(
-    block_expand,
-    ops::BlockExpandKernel<paddle::platform::CPUDeviceContext, float>);
+    im2sequence,
+    ops::Im2SequenceKernel<paddle::platform::CPUDeviceContext, float>);
 REGISTER_OP_CPU_KERNEL(
-    block_expand_grad,
-    ops::BlockExpandGradKernel<paddle::platform::CPUDeviceContext, float>);
+    im2sequence_grad,
+    ops::Im2SequenceGradKernel<paddle::platform::CPUDeviceContext, float>);
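
For context, here is a minimal standalone sketch of the shape arithmetic described by the op's DOC string above (matrix height = output_height * output_width, matrix width = block_height * block_width * channels). The OutputSize helper and its formula are assumptions following the usual im2col convention, since the operator's actual InferShape body is elided from this diff; the example sizes are chosen to be consistent with the "output.lod = [[0, 4, 8]]" context line visible in the hunk header (two images contributing 4 rows each).

// Sketch only: not part of the diff, and not the operator's real InferShape.
#include <cstdio>

// Assumed im2col-style count of block positions along one spatial dimension.
int OutputSize(int img_size, int block_size, int stride, int padding) {
  return (img_size + 2 * padding - block_size) / stride + 1;
}

int main() {
  // Hypothetical example: two 1 x 3 x 3 feature maps, 2 x 2 blocks,
  // stride 1, no padding.
  const int channels = 1, height = 3, width = 3;
  const int block_height = 2, block_width = 2;
  const int stride_height = 1, stride_width = 1;
  const int padding_height = 0, padding_width = 0;

  const int output_height =
      OutputSize(height, block_height, stride_height, padding_height);
  const int output_width =
      OutputSize(width, block_width, stride_width, padding_width);

  // Matrix height: output_height * output_width rows per image (here 4,
  // hence lod [[0, 4, 8]] for two images); matrix width:
  // block_height * block_width * channels columns (here 4).
  std::printf("rows per image = %d, cols = %d\n",
              output_height * output_width,
              block_height * block_width * channels);
  return 0;
}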