add precise roi pooling op test=develop (#18960)
* add precise roi pooling op test=develop * test=develop * test=develop * test=develop * test=develop * test=develop * test=develop * test=develop * test=develop * test=develop * test=develop * test=develop * test=develop * test=develop * detail the description test=develop * test=develop * elaborate the doc for return type test=develop * test=developexpand_as_op_1
parent
3cd985a669
commit
a7c440d303
@ -0,0 +1,188 @@
|
||||
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/fluid/operators/prroi_pool_op.h"
|
||||
#include <memory>
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
using Tensor = framework::Tensor;
|
||||
using LoDTensor = framework::LoDTensor;
|
||||
|
||||
// Declares the public interface of the `prroi_pool` operator: its two inputs
// (X, ROIs), its output (Out), its attributes and the user-facing doc string.
class PRROIPoolOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  // Registers inputs, outputs, attributes and the operator comment.
  void Make() override {
    AddInput("X",
             "(Tensor), "
             "the input of PRROIPoolOp. "
             "The format of input tensor is NCHW. Where N is the batch size, "
             "C is the number of input channels, "
             "H is the height of the input feature map, and "
             "W is the width.");
    AddInput("ROIs",
             "(LoDTensor), "
             "ROIs (Regions of Interest) to pool over. "
             "should be a 2-D LoDTensor of shape (num_rois, 4) "
             "given as [(x1, y1, x2, y2), ...]. "
             "where (x1, y1) is the top left coordinates, and "
             "(x2, y2) is the bottom right coordinates. "
             "The roi batch index can be calculated from LoD.");
    AddOutput("Out",
              "(Tensor), "
              "the output of PRROIPoolOp is a 4-D Tensor with shape "
              "(num_rois, output_channels, pooled_h, pooled_w).");
    // `output_channels` has no default: it must always be supplied.
    AddAttr<int>(
        "output_channels",
        "(int), "
        "the number of channels of the output feature map. "
        "For a task of C classes of objects, output_channels should be "
        "(C + 1) for classification only.");
    AddAttr<float>("spatial_scale",
                   "(float, default 1.0), "
                   "Multiplicative spatial scale factor "
                   "to translate ROI coords from their input scale "
                   "to the scale used when pooling.")
        .SetDefault(1.0f);
    AddAttr<int>("pooled_height",
                 "(int, default 1), "
                 "the pooled output height.")
        .SetDefault(1);
    AddAttr<int>("pooled_width",
                 "(int, default 1), "
                 "the pooled output width.")
        .SetDefault(1);
    AddComment(R"Doc(
**PRROIPool Operator**

Precise region of interest pooling (also known as PRROIPooling) is to perform
bilinear interpolation average pooling method for RoI Pooling.

Please refer to https://arxiv.org/abs/1807.11590 for more details.

)Doc");
  }
};
|
||||
|
||||
class PRROIPoolOp : public framework::OperatorWithKernel {
|
||||
public:
|
||||
using framework::OperatorWithKernel::OperatorWithKernel;
|
||||
|
||||
void InferShape(framework::InferShapeContext* ctx) const override {
|
||||
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
|
||||
"Input(X) of op(PRROIPool) should not be null.");
|
||||
PADDLE_ENFORCE_EQ(ctx->HasInput("ROIs"), true,
|
||||
"Input(ROIs) of op(PRROIPool) should not be null.");
|
||||
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
|
||||
"Output(Out) of op(PRROIPool) should not be null.");
|
||||
auto input_dims = ctx->GetInputDim("X");
|
||||
auto rois_dims = ctx->GetInputDim("ROIs");
|
||||
|
||||
PADDLE_ENFORCE_EQ(input_dims.size(), 4,
|
||||
"The format of input tensor is NCHW");
|
||||
PADDLE_ENFORCE_EQ(rois_dims.size(), 2,
|
||||
"ROIs should be a 2-D LoDTensor of shape (num_rois, 4) "
|
||||
"given as [(x1, y1, x2, y2), ...]");
|
||||
PADDLE_ENFORCE_EQ(rois_dims[1], 4,
|
||||
"ROIs should be a 2-D LoDTensor of shape (num_rois, 4) "
|
||||
"given as [(x1, y1, x2, y2), ...]");
|
||||
|
||||
int pooled_height = ctx->Attrs().Get<int>("pooled_height");
|
||||
int pooled_width = ctx->Attrs().Get<int>("pooled_width");
|
||||
int output_channels = ctx->Attrs().Get<int>("output_channels");
|
||||
float spatial_scale = ctx->Attrs().Get<float>("spatial_scale");
|
||||
|
||||
PADDLE_ENFORCE_EQ(
|
||||
input_dims[1], output_channels * pooled_height * pooled_width,
|
||||
"the channel of X(%d) should be equal to the product of "
|
||||
"output_channels(%d), pooled_height(%d) and pooled_width(%d)",
|
||||
input_dims[1], output_channels, pooled_height, pooled_width);
|
||||
|
||||
PADDLE_ENFORCE_GT(pooled_height, 0,
|
||||
"The pooled output height must be greater than 0");
|
||||
PADDLE_ENFORCE_GT(pooled_width, 0,
|
||||
"The pooled output width must be greater than 0");
|
||||
PADDLE_ENFORCE_GT(output_channels, 1,
|
||||
"The pooled output channels must greater than 1");
|
||||
PADDLE_ENFORCE_GT(spatial_scale, 0.0f,
|
||||
"The spatial scale must greater than 0.");
|
||||
|
||||
auto out_dims = input_dims;
|
||||
out_dims[0] = rois_dims[0];
|
||||
out_dims[1] =
|
||||
output_channels; // input_dims[1] / (pooled_height * pooled_width);
|
||||
out_dims[2] = pooled_height;
|
||||
out_dims[3] = pooled_width;
|
||||
ctx->SetOutputDim("Out", out_dims);
|
||||
}
|
||||
|
||||
protected:
|
||||
framework::OpKernelType GetExpectedKernelType(
|
||||
const framework::ExecutionContext& ctx) const override {
|
||||
return framework::OpKernelType(ctx.Input<framework::Tensor>("X")->type(),
|
||||
ctx.device_context());
|
||||
}
|
||||
};
|
||||
|
||||
class PRROIPoolGradOp : public framework::OperatorWithKernel {
|
||||
public:
|
||||
using framework::OperatorWithKernel::OperatorWithKernel;
|
||||
|
||||
void InferShape(framework::InferShapeContext* ctx) const override {
|
||||
PADDLE_ENFORCE_EQ(ctx->HasInput(framework::GradVarName("Out")), true,
|
||||
"The gradient of Out should not be null.");
|
||||
PADDLE_ENFORCE_EQ(ctx->HasOutput(framework::GradVarName("X")), true,
|
||||
"The gradient of X should not be null.");
|
||||
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
|
||||
}
|
||||
|
||||
protected:
|
||||
framework::OpKernelType GetExpectedKernelType(
|
||||
const framework::ExecutionContext& ctx) const override {
|
||||
return framework::OpKernelType(ctx.Input<framework::Tensor>("X")->type(),
|
||||
ctx.device_context());
|
||||
}
|
||||
};
|
||||
|
||||
class PRROIPoolGradDescMaker : public framework::SingleGradOpDescMaker {
|
||||
public:
|
||||
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
|
||||
|
||||
protected:
|
||||
std::unique_ptr<framework::OpDesc> Apply() const override {
|
||||
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
|
||||
op->SetType("prroi_pool_grad");
|
||||
op->SetInput("X", Input("X"));
|
||||
op->SetInput("ROIs", Input("ROIs"));
|
||||
op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
|
||||
op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
|
||||
op->SetAttrMap(Attrs());
|
||||
return op;
|
||||
}
|
||||
};
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
||||
|
||||
namespace ops = paddle::operators;
|
||||
REGISTER_OPERATOR(prroi_pool, ops::PRROIPoolOp, ops::PRROIPoolOpMaker,
|
||||
ops::PRROIPoolGradDescMaker);
|
||||
REGISTER_OPERATOR(prroi_pool_grad, ops::PRROIPoolGradOp);
|
||||
REGISTER_OP_CPU_KERNEL(
|
||||
prroi_pool,
|
||||
ops::CPUPRROIPoolOpKernel<paddle::platform::CPUDeviceContext, float>,
|
||||
ops::CPUPRROIPoolOpKernel<paddle::platform::CPUDeviceContext, double>);
|
||||
REGISTER_OP_CPU_KERNEL(
|
||||
prroi_pool_grad,
|
||||
ops::CPUPRROIPoolGradOpKernel<paddle::platform::CPUDeviceContext, float>,
|
||||
ops::CPUPRROIPoolGradOpKernel<paddle::platform::CPUDeviceContext, double>);
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,151 @@
|
||||
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
import numpy as np
|
||||
|
||||
|
||||
class PyPrRoIPool(object):
    """Pure-Python reference implementation of precise RoI pooling.

    For every output bin the bilinearly interpolated feature map is
    integrated over the bin's window and divided by the window area.
    Samples outside the feature map are treated as 0.
    """

    def _PrRoIPoolingGetData(self, data, h, w, height, width):
        """Return ``data[h][w]``, or 0.0 when (h, w) is out of bounds.

        :param data: 2-D feature map indexable as ``data[h][w]``
        :param h: row index
        :param w: column index
        :param height: number of valid rows
        :param width: number of valid columns
        """
        if h < 0 or w < 0 or h >= height or w >= width:
            return 0.0
        return data[h][w]

    @staticmethod
    def _PrRoIPoolingSegmentWeight(lower, upper):
        """Integral of ``(1 - t)`` for ``t`` from ``lower`` to ``upper``.

        This is the 1-D weight a grid point receives when the linear
        interpolation kernel is integrated over a sub-segment of a unit cell;
        ``lower`` and ``upper`` are distances measured from that grid point.
        """
        return upper - 0.5 * upper * upper - lower + 0.5 * lower * lower

    def _PrRoIPoolingMatCalculation(self, this_data, s_h, s_w, e_h, e_w, y0, x0,
                                    y1, x1, h0, w0):
        """Integrate the bilinear interpolation over part of one grid cell.

        :param this_data: 2-D feature map of one channel
        :param s_h, s_w: top-left grid point of the cell
        :param e_h, e_w: bottom-right grid point of the cell
        :param y0, x0: start of the integration rectangle inside the cell
        :param y1, x1: end of the integration rectangle inside the cell
        :param h0, w0: height and width of the feature map (bounds check)
        :return: the sum of the four corner values, each multiplied by its
            integration weight
        """
        weight = self._PrRoIPoolingSegmentWeight
        # Horizontal weights of the left (s_w) and right (e_w) grid columns.
        w_left = weight(x0 - float(s_w), x1 - float(s_w))
        w_right = weight(float(e_w) - x1, float(e_w) - x0)
        # Vertical weights of the top (s_h) and bottom (e_h) grid rows.
        w_top = weight(y0 - float(s_h), y1 - float(s_h))
        w_bottom = weight(float(e_h) - y1, float(e_h) - y0)

        sum_out = 0.0
        # Corner order is kept as top-left, top-right, bottom-left,
        # bottom-right so that floating-point accumulation is unchanged.
        tmp = w_left * w_top
        sum_out += self._PrRoIPoolingGetData(this_data, s_h, s_w, h0, w0) * tmp
        tmp = w_right * w_top
        sum_out += self._PrRoIPoolingGetData(this_data, s_h, e_w, h0, w0) * tmp
        tmp = w_left * w_bottom
        sum_out += self._PrRoIPoolingGetData(this_data, e_h, s_w, h0, w0) * tmp
        tmp = w_right * w_bottom
        sum_out += self._PrRoIPoolingGetData(this_data, e_h, e_w, h0, w0) * tmp

        return sum_out

    def compute(self,
                x,
                rois,
                output_channels,
                spatial_scale=0.1,
                pooled_height=1,
                pooled_width=1):
        '''
        calculate the precise roi pooling values
        Note: This function is implemented in pure python without any paddle concept involved
        :param x (array): array[N, C, H, W]
        :param rois (array): ROIs[id, x1, y1, x2, y2] (Regions of Interest) to pool over.
        :param output_channels (Integer): Expected output channels
        :param spatial_scale (float): spatial scale, default = 0.1
        :param pooled_height (Integer): Expected output height, default = 1
        :param pooled_width (Integer): Expected output width, default = 1
        :return: array[len(rois), output_channels, pooled_height, pooled_width]
        :raises TypeError: if output_channels, pooled_height or pooled_width
            is not an int, or spatial_scale is not a float
        '''
        if not isinstance(output_channels, int):
            raise TypeError("output_channels must be int type")
        if not isinstance(spatial_scale, float):
            raise TypeError("spatial_scale must be float type")
        if not isinstance(pooled_height, int):
            raise TypeError("pooled_height must be int type")
        if not isinstance(pooled_width, int):
            raise TypeError("pooled_width must be int type")

        (batch_size, channels, height, width) = np.array(x).shape
        rois_num = len(rois)
        output_shape = (rois_num, output_channels, pooled_height, pooled_width)
        out_data = np.zeros(output_shape)

        for i in range(rois_num):
            roi = rois[i]
            roi_batch_id = int(roi[0])
            # Map the ROI from input-image scale to feature-map scale.
            roi_start_w = roi[1] * spatial_scale
            roi_start_h = roi[2] * spatial_scale
            roi_end_w = roi[3] * spatial_scale
            roi_end_h = roi[4] * spatial_scale

            roi_width = max(roi_end_w - roi_start_w, 0.0)
            roi_height = max(roi_end_h - roi_start_h, 0.0)
            bin_size_h = roi_height / float(pooled_height)
            bin_size_w = roi_width / float(pooled_width)

            x_i = x[roi_batch_id]

            for c in range(output_channels):
                for ph in range(pooled_height):
                    for pw in range(pooled_width):
                        win_start_w = roi_start_w + bin_size_w * pw
                        win_start_h = roi_start_h + bin_size_h * ph
                        win_end_w = win_start_w + bin_size_w
                        win_end_h = win_start_h + bin_size_h

                        win_size = max(0.0, bin_size_w * bin_size_h)
                        if win_size == 0.0:
                            # Degenerate (zero-area) bin: output is 0.
                            out_data[i, c, ph, pw] = 0.0
                            continue

                        s_w = math.floor(win_start_w)
                        e_w = math.ceil(win_end_w)
                        s_h = math.floor(win_start_h)
                        e_h = math.ceil(win_end_h)

                        # Each output bin (c, ph, pw) reads its own input
                        # channel.
                        c_in = (c * pooled_height + ph) * pooled_width + pw

                        sum_out = 0.0
                        # Visit every unit grid cell the window overlaps and
                        # integrate over the overlapping part of that cell.
                        for w_iter in range(int(s_w), int(e_w)):
                            for h_iter in range(int(s_h), int(e_h)):
                                sum_out += self._PrRoIPoolingMatCalculation(
                                    x_i[c_in], h_iter, w_iter, h_iter + 1,
                                    w_iter + 1,
                                    max(win_start_h, float(h_iter)),
                                    max(win_start_w, float(w_iter)),
                                    min(win_end_h, float(h_iter) + 1.0),
                                    min(win_end_w, float(w_iter + 1.0)),
                                    height, width)

                        out_data[i, c, ph, pw] = sum_out / win_size

        return out_data
|
@ -0,0 +1,138 @@
|
||||
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
import unittest
|
||||
from py_precise_roi_pool import PyPrRoIPool
|
||||
from op_test import OpTest
|
||||
import paddle.fluid as fluid
|
||||
from paddle.fluid import compiler, Program, program_guard
|
||||
|
||||
|
||||
class TestPRROIPoolOp(OpTest):
    """Unit test of the prroi_pool op against the PyPrRoIPool reference."""

    def set_data(self):
        """Build the op inputs, attributes and the expected output."""
        self.init_test_case()
        self.make_rois()
        self.prRoIPool = PyPrRoIPool()
        self.outs = self.prRoIPool.compute(
            self.x, self.rois, self.output_channels, self.spatial_scale,
            self.pooled_height, self.pooled_width).astype('float32')
        # The op takes only the 4 box coordinates; the batch index column is
        # conveyed through the LoD instead.
        self.inputs = {'X': self.x, 'ROIs': (self.rois[:, 1:5], self.rois_lod)}
        self.attrs = {
            'output_channels': self.output_channels,
            'spatial_scale': self.spatial_scale,
            'pooled_height': self.pooled_height,
            'pooled_width': self.pooled_width
        }
        self.outputs = {'Out': self.outs}

    def init_test_case(self):
        """Set the input shape, the op attributes and a random input X."""
        self.batch_size = 3
        # channels = output_channels * pooled_height * pooled_width
        self.channels = 3 * 2 * 2
        self.height = 6
        self.width = 4

        self.x_dim = [self.batch_size, self.channels, self.height, self.width]

        self.spatial_scale = 1.0 / 4.0
        self.output_channels = 3
        self.pooled_height = 2
        self.pooled_width = 2

        self.x = np.random.random(self.x_dim).astype('float32')

    def make_rois(self):
        """Generate (bno + 1) random ROIs for each batch item bno.

        Sets self.rois (rows of [batch_id, x1, y1, x2, y2]), self.rois_lod
        and self.rois_num.
        """
        rois = []
        self.rois_lod = [[]]
        # Largest coordinates in input-image scale, as plain ints so they can
        # be used as bounds of np.random.randint.
        max_x = int(self.width // self.spatial_scale)
        max_y = int(self.height // self.spatial_scale)
        for bno in range(self.batch_size):
            self.rois_lod[0].append(bno + 1)
            for i in range(bno + 1):
                # np.random.randint excludes the upper bound, hence the "+ 1"
                # to keep the inclusive ranges of the deprecated
                # np.random.random_integers.
                x1 = np.random.randint(0, max_x - self.pooled_width + 1)
                y1 = np.random.randint(0, max_y - self.pooled_height + 1)

                x2 = np.random.randint(x1 + self.pooled_width, max_x + 1)
                y2 = np.random.randint(y1 + self.pooled_height, max_y + 1)
                roi = [bno, x1, y1, x2, y2]
                rois.append(roi)
        self.rois_num = len(rois)
        self.rois = np.array(rois).astype('float32')

    def setUp(self):
        self.op_type = 'prroi_pool'
        self.set_data()

    def test_check_output(self):
        self.check_output()

    def test_backward(self):
        # Only runs the backward pass on every place; the gradient values
        # are not compared against a reference here.
        for place in self._get_places():
            self._get_gradient(['X'], place, ["Out"], None)

    def run_net(self, place):
        """Build a small network around prroi_pool and run one SGD step."""
        with program_guard(Program(), Program()):
            x = fluid.layers.data(
                name="X",
                shape=[self.channels, self.height, self.width],
                dtype="float32")
            rois = fluid.layers.data(
                name="ROIs", shape=[4], dtype="float32", lod_level=1)
            # Use the configured attributes so the layer stays consistent with
            # the shape of self.x.
            output = fluid.layers.prroi_pool(x, rois, self.output_channels,
                                             self.spatial_scale,
                                             self.pooled_height,
                                             self.pooled_width)
            loss = fluid.layers.mean(output)
            optimizer = fluid.optimizer.SGD(learning_rate=1e-3)
            optimizer.minimize(loss)
            input_x = fluid.create_lod_tensor(self.x, [], place)
            input_rois = fluid.create_lod_tensor(self.rois[:, 1:5],
                                                 self.rois_lod, place)
            exe = fluid.Executor(place)
            exe.run(fluid.default_startup_program())
            exe.run(fluid.default_main_program(),
                    {'X': input_x,
                     "ROIs": input_rois})

    def test_net(self):
        places = [fluid.CPUPlace()]
        if fluid.core.is_compiled_with_cuda():
            places.append(fluid.CUDAPlace(0))
        for place in places:
            self.run_net(place)

    def test_errors(self):
        """The python layer must reject attributes of the wrong type."""
        with program_guard(Program(), Program()):
            x = fluid.layers.data(
                name="x", shape=[245, 30, 30], dtype="float32")
            rois = fluid.layers.data(
                name="rois", shape=[4], dtype="float32", lod_level=1)
            # channel must be int type
            self.assertRaises(TypeError, fluid.layers.prroi_pool, x, rois, 0.5,
                              0.25, 7, 7)
            # spatial_scale must be float type
            self.assertRaises(TypeError, fluid.layers.prroi_pool, x, rois, 5, 2,
                              7, 7)
            # pooled_height must be int type
            self.assertRaises(TypeError, fluid.layers.prroi_pool, x, rois, 5,
                              0.25, 0.7, 7)
            # pooled_width must be int type
            self.assertRaises(TypeError, fluid.layers.prroi_pool, x, rois, 5,
                              0.25, 7, 0.7)
|
||||
|
||||
|
||||
# Run the test suite when this file is executed as a script.
if __name__ == '__main__':
    unittest.main()
|
Loading…
Reference in new issue