add unfold op (new op),test=develop (#17944)
* add unfold op test=develop * fix divide bug in python3 when calculating output width and height test=develop * add name=None in python api, move redundant code into inline function * try to trigger ci for this code test=developlite
parent
b5c35ae3e7
commit
40885c225b
@ -0,0 +1,184 @@
|
||||
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License. */
|
||||
|
||||
#include "paddle/fluid/operators/unfold_op.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
// Declares the proto of the `unfold` operator: its input X, its output Y,
// the four vector<int> attributes and the user-facing documentation.
class UnfoldOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X",
             "Tensor, "
             "the input of unfold op. "
             "The format of X is [N, C_in, H, W], "
             "where N is the batch size, C_in is the input channels, "
             "H is the height and W is the width");
    AddOutput(
        "Y",
        "Tensor, "
        "the output of unfold op. "
        "The format of Y is [N, C_in*filter_height*filter_width, "
        "output_height*output_width], where N is the batch size, "
        "C_in is the input channels of X, filter_height and filter_width are "
        "the height and width of the filtering kernel, output_height and "
        "output_width "
        "are the calculated height and width of the output feature map.");
    // For every attribute below, index 0 applies to the height axis and
    // index 1 to the width axis; the docs spell that out because the shape
    // inference indexes them positionally.
    AddAttr<std::vector<int>>(
        "kernel_sizes",
        "vector<int>, the kernel sizes of the convolution operator, "
        "in the order [filter_height, filter_width].");
    AddAttr<std::vector<int>>(
        "strides",
        "vector<int>, the strides of the convolution operator, "
        "in the order [stride_height, stride_width].");
    AddAttr<std::vector<int>>(
        "paddings",
        "vector<int>, the paddings applied to pad the feature map, "
        "in the order [padding_top, padding_left, padding_bottom, "
        "padding_right].");
    AddAttr<std::vector<int>>(
        "dilations",
        "vector<int>, the dilations of the convolution operator, "
        "in the order [dilation_height, dilation_width].");
    AddComment(R"DOC(
**Unfold Operator**

This Operator is used to extract sliding local blocks from a batched input tensor, also known
as im2col when operated on batched 2D image tensor. For each block under the convolution filter,
all elements will be rearranged as a column. While the convolution filter is sliding over the input
feature map, a series of such columns will be formed.
    )DOC");
  }
};
|
||||
|
||||
class UnfoldOp : public framework::OperatorWithKernel {
|
||||
public:
|
||||
using framework::OperatorWithKernel::OperatorWithKernel;
|
||||
void InferShape(framework::InferShapeContext* ctx) const override {
|
||||
PADDLE_ENFORCE(ctx->HasInput("X"),
|
||||
"Input(X) of UnfoldOp should not be null");
|
||||
PADDLE_ENFORCE(ctx->HasOutput("Y"),
|
||||
"Output(Y) of UnfoldOp should not be null");
|
||||
auto in_dims = ctx->GetInputDim("X");
|
||||
std::vector<int> kernel_sizes =
|
||||
ctx->Attrs().Get<std::vector<int>>("kernel_sizes");
|
||||
std::vector<int> strides = ctx->Attrs().Get<std::vector<int>>("strides");
|
||||
std::vector<int> paddings = ctx->Attrs().Get<std::vector<int>>("paddings");
|
||||
std::vector<int> dilations =
|
||||
ctx->Attrs().Get<std::vector<int>>("dilations");
|
||||
|
||||
// Only [N, C, H, W] input supported now
|
||||
PADDLE_ENFORCE(
|
||||
in_dims.size() == 4,
|
||||
"Input shold be 4-D tensor of format [N, C, H, W], but get %u",
|
||||
in_dims.size());
|
||||
PADDLE_ENFORCE(
|
||||
in_dims.size() - kernel_sizes.size() == 2U,
|
||||
"The dims of X should be larger than that of kernel_sizes "
|
||||
"by a number of 2, due to the batch size and input channel dim. "
|
||||
"But recieved dims(X:%u) - dims(kernel_sizes:%u) != 2",
|
||||
in_dims.size(), kernel_sizes.size());
|
||||
PADDLE_ENFORCE_EQ(
|
||||
strides.size(), kernel_sizes.size(),
|
||||
"The dims of strides shold be the same with that of kernel_sizes. "
|
||||
"But recieved dims(strides: %u) != dims(kernel_sizes: %u).",
|
||||
strides.size(), kernel_sizes.size());
|
||||
PADDLE_ENFORCE_EQ(
|
||||
paddings.size(), 2 * strides.size(),
|
||||
"The dims of paddings should be 2 times of that of strides. "
|
||||
"But recieved dims(paddings: %u) != 2*dims(strides: %u).",
|
||||
paddings.size(), strides.size());
|
||||
PADDLE_ENFORCE_EQ(
|
||||
strides.size(), dilations.size(),
|
||||
"The dims of strides shold be the same with that of dilations. "
|
||||
"But recieved dims(strides: %u) != dims(dilations: %u).",
|
||||
strides.size(), dilations.size());
|
||||
|
||||
std::vector<int> out_dims;
|
||||
out_dims.push_back(in_dims[0]);
|
||||
|
||||
int output_channels = in_dims[1] * kernel_sizes[0] * kernel_sizes[1];
|
||||
out_dims.push_back(output_channels);
|
||||
|
||||
int output_height =
|
||||
CalcOutputSize(in_dims[2], kernel_sizes[0], dilations[0], paddings[0],
|
||||
paddings[2], strides[0]);
|
||||
int output_width = CalcOutputSize(in_dims[3], kernel_sizes[1], dilations[1],
|
||||
paddings[1], paddings[3], strides[1]);
|
||||
int output_col_length = output_height * output_width;
|
||||
out_dims.push_back(output_col_length);
|
||||
|
||||
ctx->SetOutputDim("Y", framework::make_ddim(out_dims));
|
||||
}
|
||||
|
||||
protected:
|
||||
framework::OpKernelType GetExpectedKernelType(
|
||||
const framework::ExecutionContext& ctx) const override {
|
||||
return framework::OpKernelType(ctx.Input<framework::Tensor>("X")->type(),
|
||||
ctx.device_context());
|
||||
}
|
||||
};
|
||||
|
||||
class UnfoldGradOp : public framework::OperatorWithKernel {
|
||||
public:
|
||||
using framework::OperatorWithKernel::OperatorWithKernel;
|
||||
|
||||
void InferShape(framework::InferShapeContext* ctx) const override {
|
||||
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Y")),
|
||||
"The gradient of Y should not be null");
|
||||
PADDLE_ENFORCE(ctx->HasInput("X"), "The input X should not be null");
|
||||
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
|
||||
"The gradient of X should not be null");
|
||||
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
|
||||
}
|
||||
|
||||
protected:
|
||||
framework::OpKernelType GetExpectedKernelType(
|
||||
const framework::ExecutionContext& ctx) const override {
|
||||
return framework::OpKernelType(
|
||||
ctx.Input<framework::Tensor>(framework::GradVarName("Y"))->type(),
|
||||
ctx.device_context());
|
||||
}
|
||||
};
|
||||
|
||||
class UnfoldGradDescMaker : public framework::SingleGradOpDescMaker {
|
||||
public:
|
||||
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
|
||||
|
||||
protected:
|
||||
std::unique_ptr<framework::OpDesc> Apply() const override {
|
||||
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
|
||||
op->SetType("unfold_grad");
|
||||
op->SetInput(framework::GradVarName("Y"), OutputGrad("Y"));
|
||||
op->SetInput("X", Input("X"));
|
||||
op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
|
||||
op->SetAttrMap(Attrs());
|
||||
return op;
|
||||
}
|
||||
};
|
||||
|
||||
// Declares the inference helper that is attached to `unfold_grad` at
// registration time. NOTE(review): presumably this marks X as an input whose
// data buffer is not required (only its dims are read) — confirm against the
// macro's definition.
DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(
    UnfoldGradOpNoNeedBufferVarsInference, "X");
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
||||
|
||||
namespace ops = paddle::operators;
|
||||
REGISTER_OPERATOR(unfold, ops::UnfoldOp, ops::UnfoldOpMaker,
|
||||
ops::UnfoldGradDescMaker);
|
||||
REGISTER_OPERATOR(unfold_grad, ops::UnfoldGradOp,
|
||||
ops::UnfoldGradOpNoNeedBufferVarsInference);
|
||||
|
||||
REGISTER_OP_CPU_KERNEL(
|
||||
unfold, ops::UnfoldOpKernel<paddle::platform::CPUDeviceContext, float>,
|
||||
ops::UnfoldOpKernel<paddle::platform::CPUDeviceContext, double>);
|
||||
REGISTER_OP_CPU_KERNEL(
|
||||
unfold_grad,
|
||||
ops::UnfoldGradOpKernel<paddle::platform::CPUDeviceContext, float>,
|
||||
ops::UnfoldGradOpKernel<paddle::platform::CPUDeviceContext, double>);
|
@ -0,0 +1,26 @@
|
||||
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/fluid/operators/unfold_op.h"
|
||||
|
||||
namespace ops = paddle::operators;
|
||||
|
||||
REGISTER_OP_CUDA_KERNEL(
|
||||
unfold, ops::UnfoldOpKernel<paddle::platform::CUDADeviceContext, float>,
|
||||
ops::UnfoldOpKernel<paddle::platform::CUDADeviceContext, double>);
|
||||
|
||||
REGISTER_OP_CUDA_KERNEL(
|
||||
unfold_grad,
|
||||
ops::UnfoldGradOpKernel<paddle::platform::CUDADeviceContext, float>,
|
||||
ops::UnfoldGradOpKernel<paddle::platform::CUDADeviceContext, double>);
|
@ -0,0 +1,127 @@
|
||||
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include "paddle/fluid/framework/op_registry.h"
|
||||
#include "paddle/fluid/operators/math/im2col.h"
|
||||
#include "paddle/fluid/operators/math/math_function.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
using Tensor = framework::Tensor;
|
||||
|
||||
// Computes how many sliding-window positions fit along one spatial axis:
//
//   (input_size + padding1 + padding2 - dkernel) / stride + 1
//
// where dkernel = dilation * (filter_size - 1) + 1 is the extent covered by
// the dilated filter.
//
// input_size   extent of the input along this axis
// filter_size  number of filter taps along this axis
// dilation     spacing between filter taps
// padding1/2   padding added on the two sides of the axis
// stride       step between consecutive window positions
//
// Raises through PADDLE_ENFORCE when stride is not positive or when the
// dilated filter does not fit into the padded input, i.e. when no valid
// window position exists.
inline int CalcOutputSize(int input_size, int filter_size, int dilation,
                          int padding1, int padding2, int stride) {
  // A zero stride would make the division below undefined.
  PADDLE_ENFORCE(stride > 0,
                 "The stride should be greater than 0, but received "
                 "stride(%d).",
                 stride);

  const int dkernel = dilation * (filter_size - 1) + 1;
  const int padded_size = input_size + padding1 + padding2;

  // Check the fit before dividing: integer division truncates toward zero,
  // so a negative numerator with stride > 1 (e.g. -1 / 2 == 0) would
  // otherwise produce a bogus output size of 1 instead of an error.
  PADDLE_ENFORCE(padded_size >= dkernel,
                 "Due to the settings of padding(%d, %d), filter_size(%d), "
                 "dilation(%d) and "
                 "stride(%d), the output size is not greater than 0, please "
                 "check again. Input_size:%d",
                 padding1, padding2, filter_size, dilation, stride, input_size);

  // The numerator is non-negative here, so the result is at least 1.
  return (padded_size - dkernel) / stride + 1;
}
|
||||
|
||||
template <typename DeviceContext, typename T>
|
||||
class UnfoldOpKernel : public framework::OpKernel<T> {
|
||||
public:
|
||||
void Compute(const framework::ExecutionContext& ctx) const override {
|
||||
const Tensor* input = ctx.Input<Tensor>("X");
|
||||
const int batch_size = static_cast<int>(input->dims()[0]);
|
||||
Tensor* output = ctx.Output<Tensor>("Y");
|
||||
output->mutable_data<T>(ctx.GetPlace());
|
||||
|
||||
std::vector<int> kernel_sizes = ctx.Attr<std::vector<int>>("kernel_sizes");
|
||||
std::vector<int> strides = ctx.Attr<std::vector<int>>("strides");
|
||||
std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
|
||||
std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");
|
||||
|
||||
math::Im2ColFunctor<math::ColFormat::kCFO, DeviceContext, T> im2col;
|
||||
auto& dev_ctx = ctx.template device_context<DeviceContext>();
|
||||
|
||||
auto input_dims = input->dims();
|
||||
|
||||
int output_height =
|
||||
CalcOutputSize(input_dims[2], kernel_sizes[0], dilations[0],
|
||||
paddings[0], paddings[2], strides[0]);
|
||||
int output_width =
|
||||
CalcOutputSize(input_dims[3], kernel_sizes[1], dilations[1],
|
||||
paddings[1], paddings[3], strides[1]);
|
||||
|
||||
framework::DDim input_shape({input_dims[1], input_dims[2], input_dims[3]});
|
||||
framework::DDim output_matrix_shape({input_dims[1], kernel_sizes[0],
|
||||
kernel_sizes[1], output_height,
|
||||
output_width});
|
||||
|
||||
for (int i = 0; i < batch_size; i++) {
|
||||
Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape);
|
||||
Tensor out_batch = output->Slice(i, i + 1).Resize(output_matrix_shape);
|
||||
im2col(dev_ctx, in_batch, dilations, strides, paddings, &out_batch);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename DeviceContext, typename T>
|
||||
class UnfoldGradOpKernel : public framework::OpKernel<T> {
|
||||
public:
|
||||
void Compute(const framework::ExecutionContext& ctx) const override {
|
||||
const Tensor* output_grad = ctx.Input<Tensor>(framework::GradVarName("Y"));
|
||||
Tensor* input_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
|
||||
input_grad->mutable_data<T>(ctx.GetPlace());
|
||||
|
||||
if ((!output_grad) || (!input_grad)) return;
|
||||
|
||||
std::vector<int> kernel_sizes = ctx.Attr<std::vector<int>>("kernel_sizes");
|
||||
std::vector<int> strides = ctx.Attr<std::vector<int>>("strides");
|
||||
std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
|
||||
std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");
|
||||
|
||||
const int batch_size = static_cast<int>(input_grad->dims()[0]);
|
||||
|
||||
auto input_dims = input_grad->dims();
|
||||
|
||||
int output_height =
|
||||
CalcOutputSize(input_dims[2], kernel_sizes[0], dilations[0],
|
||||
paddings[0], paddings[2], strides[0]);
|
||||
int output_width =
|
||||
CalcOutputSize(input_dims[3], kernel_sizes[1], dilations[1],
|
||||
paddings[1], paddings[3], strides[1]);
|
||||
|
||||
framework::DDim input_shape({input_dims[1], input_dims[2], input_dims[3]});
|
||||
framework::DDim output_matrix_shape({input_dims[1], kernel_sizes[0],
|
||||
kernel_sizes[1], output_height,
|
||||
output_width});
|
||||
|
||||
math::Col2ImFunctor<math::ColFormat::kCFO, DeviceContext, T> col2im;
|
||||
auto& dev_ctx = ctx.template device_context<DeviceContext>();
|
||||
|
||||
math::SetConstant<DeviceContext, T> set_zero;
|
||||
set_zero(dev_ctx, input_grad, static_cast<T>(0));
|
||||
for (int i = 0; i < batch_size; i++) {
|
||||
Tensor out_grad_batch =
|
||||
output_grad->Slice(i, i + 1).Resize(output_matrix_shape);
|
||||
Tensor in_grad_batch = input_grad->Slice(i, i + 1).Resize(input_shape);
|
||||
col2im(dev_ctx, out_grad_batch, dilations, strides, paddings,
|
||||
&in_grad_batch);
|
||||
}
|
||||
}
|
||||
};
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -0,0 +1,102 @@
|
||||
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import print_function
|
||||
|
||||
import math
|
||||
import numpy as np
|
||||
import unittest
|
||||
from op_test import OpTest
|
||||
|
||||
|
||||
class TestUnfoldOp(OpTest):
    """
    Unit test for the unfold op: compares the op's output and gradient
    against a pure-Python reference computed in calc_unfold.
    """

    def init_data(self):
        # Problem size and op attributes used by the reference and the op.
        self.batch_size = 3
        self.input_channels = 3
        self.input_height = 20
        self.input_width = 20
        self.kernel_sizes = [3, 3]
        self.strides = [1, 1]
        self.paddings = [1, 1, 1, 1]
        self.dilations = [1, 1]
        self.x = np.random.rand(self.batch_size, self.input_channels,
                                self.input_height,
                                self.input_width).astype(np.float32)

    def calc_unfold(self):
        # Reference implementation; the result is stored in self.outputs.
        kernel_h, kernel_w = self.kernel_sizes[0], self.kernel_sizes[1]
        stride_h, stride_w = self.strides[0], self.strides[1]
        dilation_h, dilation_w = self.dilations[0], self.dilations[1]
        pad_top, pad_left = self.paddings[0], self.paddings[1]
        pad_bottom, pad_right = self.paddings[2], self.paddings[3]

        # Extent covered by the dilated kernel along each axis.
        dkernel_h = dilation_h * (kernel_h - 1) + 1
        dkernel_w = dilation_w * (kernel_w - 1) + 1
        out_height = int(
            (self.input_height + pad_top + pad_bottom - dkernel_h) /
            stride_h) + 1
        out_width = int(
            (self.input_width + pad_left + pad_right - dkernel_w) /
            stride_w) + 1

        num_rows = self.input_channels * kernel_h * kernel_w
        num_cols = out_height * out_width
        output = np.zeros(
            [self.batch_size, num_rows, num_cols]).astype(np.float32)

        for row in range(num_rows):
            # A row index encodes (input channel, kernel row, kernel column).
            c_in, kernel_pos = divmod(row, kernel_h * kernel_w)
            h_offset, w_offset = divmod(kernel_pos, kernel_w)
            for col in range(num_cols):
                # A column index encodes (output row, output column).
                h_out, w_out = divmod(col, out_width)
                h_in = h_offset * dilation_h + h_out * stride_h - pad_top
                w_in = w_offset * dilation_w + w_out * stride_w - pad_left
                # Positions that fall into the padding keep their zero.
                if not (0 <= h_in < self.input_height):
                    continue
                if not (0 <= w_in < self.input_width):
                    continue
                output[:, row, col] = self.x[:, c_in, h_in, w_in]

        self.outputs = output

    def set_data(self):
        # Populates inputs, attrs and expected outputs for the op under test.
        self.init_data()
        self.calc_unfold()

        self.inputs = {'X': self.x}
        self.attrs = {
            'kernel_sizes': self.kernel_sizes,
            'paddings': self.paddings,
            'dilations': self.dilations,
            'strides': self.strides
        }
        self.outputs = {'Y': self.outputs}

    def setUp(self):
        self.op_type = 'unfold'
        self.set_data()

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(['X'], 'Y')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
Loading…
Reference in new issue