add precise roi pooling op test=develop (#18960)

* add precise roi pooling op test=develop

* test=develop

* test=develop

* test=develop

* test=develop

* test=develop

* test=develop

* test=develop

* test=develop

* test=develop

* test=develop

* test=develop

* test=develop

* test=develop

* detail the description test=develop

* test=develop

* elaborate the doc for return type test=develop

* test=develop
wopeizl 6 years ago committed by GitHub
parent 3cd985a669
commit a7c440d303

@ -276,6 +276,7 @@ paddle.fluid.layers.shuffle_channel (ArgSpec(args=['x', 'group', 'name'], vararg
paddle.fluid.layers.temporal_shift (ArgSpec(args=['x', 'seg_num', 'shift_ratio', 'name'], varargs=None, keywords=None, defaults=(0.25, None)), ('document', '13b1cdcb01f5ffdc26591ff9a2ec4669'))
paddle.fluid.layers.py_func (ArgSpec(args=['func', 'x', 'out', 'backward_func', 'skip_vars_in_backward_input'], varargs=None, keywords=None, defaults=(None, None)), ('document', '8404e472ac12b4a30a505d3d3a3e5fdb'))
paddle.fluid.layers.psroi_pool (ArgSpec(args=['input', 'rois', 'output_channels', 'spatial_scale', 'pooled_height', 'pooled_width', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '42d5155374f69786300d90d751956998'))
paddle.fluid.layers.prroi_pool (ArgSpec(args=['input', 'rois', 'output_channels', 'spatial_scale', 'pooled_height', 'pooled_width', 'name'], varargs=None, keywords=None, defaults=(1.0, 1, 1, None)), ('document', '454c7ea8c73313dd41513929d7526303'))
paddle.fluid.layers.teacher_student_sigmoid_loss (ArgSpec(args=['input', 'label', 'soft_max_up_bound', 'soft_max_lower_bound'], varargs=None, keywords=None, defaults=(15.0, -15.0)), ('document', '07cb0d95a646dba1b9cc7cdce89e59f0'))
paddle.fluid.layers.huber_loss (ArgSpec(args=['input', 'label', 'delta'], varargs=None, keywords=None, defaults=None), ('document', '11bb8e62cc9256958eff3991fe4834da'))
paddle.fluid.layers.kldiv_loss (ArgSpec(args=['x', 'target', 'reduction', 'name'], varargs=None, keywords=None, defaults=('mean', None)), ('document', '18bc95c62d3300456c3c7da5278b47bb'))

@ -0,0 +1,188 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/prroi_pool_op.h"
#include <memory>
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
class PRROIPoolOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X",
"(Tensor), "
"the input of PRROIPoolOp. "
"The format of input tensor is NCHW. Where N is the batch size, "
"C is the number of input channels, "
"H is the height of the input feature map, and "
"W is the width.");
AddInput("ROIs",
"(LoDTensor), "
"ROIs (Regions of Interest) to pool over. "
"should be a 2-D LoDTensor of shape (num_rois, 4) "
"given as [(x1, y1, x2, y2), ...]. "
"where (x1, y1) is the top left coordinates, and "
"(x2, y2) is the bottom right coordinates. "
"The roi batch index can be calculated from LoD.");
AddOutput("Out",
"(Tensor), "
"the output of PRROIPoolOp is a 4-D Tensor with shape "
"(num_rois, output_channels, pooled_h, pooled_w).");
AddAttr<int>(
"output_channels",
"(int), "
"the number of channels of the output feature map. "
"For a task of C classes of objects, output_channels should be "
"(C + 1) for classification only.");
AddAttr<float>("spatial_scale",
"(float, default 1.0), "
"Multiplicative spatial scale factor "
"to translate ROI coords from their input scale "
"to the scale used when pooling.")
.SetDefault(1.0);
AddAttr<int>("pooled_height",
"(int, default 1), "
"the pooled output height.")
.SetDefault(1);
AddAttr<int>("pooled_width",
"(int, default 1), "
"the pooled output width.")
.SetDefault(1);
AddComment(R"Doc(
**PRROIPool Operator**
Precise region of interest pooling (also known as PRROIPooling) performs average
pooling over each RoI bin by integrating the bilinearly interpolated feature map,
instead of sampling a fixed number of points as in standard RoI pooling.
Please refer to https://arxiv.org/abs/1807.11590 for more details.
)Doc");
}
};
class PRROIPoolOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
"Input(X) of op(PRROIPool) should not be null.");
PADDLE_ENFORCE_EQ(ctx->HasInput("ROIs"), true,
"Input(ROIs) of op(PRROIPool) should not be null.");
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
"Output(Out) of op(PRROIPool) should not be null.");
auto input_dims = ctx->GetInputDim("X");
auto rois_dims = ctx->GetInputDim("ROIs");
PADDLE_ENFORCE_EQ(input_dims.size(), 4,
"The format of input tensor is NCHW");
PADDLE_ENFORCE_EQ(rois_dims.size(), 2,
"ROIs should be a 2-D LoDTensor of shape (num_rois, 4) "
"given as [(x1, y1, x2, y2), ...]");
PADDLE_ENFORCE_EQ(rois_dims[1], 4,
"ROIs should be a 2-D LoDTensor of shape (num_rois, 4) "
"given as [(x1, y1, x2, y2), ...]");
int pooled_height = ctx->Attrs().Get<int>("pooled_height");
int pooled_width = ctx->Attrs().Get<int>("pooled_width");
int output_channels = ctx->Attrs().Get<int>("output_channels");
float spatial_scale = ctx->Attrs().Get<float>("spatial_scale");
PADDLE_ENFORCE_EQ(
input_dims[1], output_channels * pooled_height * pooled_width,
"the channel of X(%d) should be equal to the product of "
"output_channels(%d), pooled_height(%d) and pooled_width(%d)",
input_dims[1], output_channels, pooled_height, pooled_width);
PADDLE_ENFORCE_GT(pooled_height, 0,
"The pooled output height must be greater than 0");
PADDLE_ENFORCE_GT(pooled_width, 0,
"The pooled output width must be greater than 0");
PADDLE_ENFORCE_GT(output_channels, 1,
"The pooled output channels must greater than 1");
PADDLE_ENFORCE_GT(spatial_scale, 0.0f,
"The spatial scale must greater than 0.");
auto out_dims = input_dims;
out_dims[0] = rois_dims[0];
out_dims[1] =
output_channels; // input_dims[1] / (pooled_height * pooled_width);
out_dims[2] = pooled_height;
out_dims[3] = pooled_width;
ctx->SetOutputDim("Out", out_dims);
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<framework::Tensor>("X")->type(),
ctx.device_context());
}
};
class PRROIPoolGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInput(framework::GradVarName("Out")), true,
"The gradient of Out should not be null.");
PADDLE_ENFORCE_EQ(ctx->HasOutput(framework::GradVarName("X")), true,
"The gradient of X should not be null.");
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<framework::Tensor>("X")->type(),
ctx.device_context());
}
};
class PRROIPoolGradDescMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
op->SetType("prroi_pool_grad");
op->SetInput("X", Input("X"));
op->SetInput("ROIs", Input("ROIs"));
op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
op->SetAttrMap(Attrs());
return op;
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(prroi_pool, ops::PRROIPoolOp, ops::PRROIPoolOpMaker,
ops::PRROIPoolGradDescMaker);
REGISTER_OPERATOR(prroi_pool_grad, ops::PRROIPoolGradOp);
REGISTER_OP_CPU_KERNEL(
prroi_pool,
ops::CPUPRROIPoolOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::CPUPRROIPoolOpKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
prroi_pool_grad,
ops::CPUPRROIPoolGradOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::CPUPRROIPoolGradOpKernel<paddle::platform::CPUDeviceContext, double>);
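To make the InferShape constraints above concrete, here is a small NumPy sketch of the required channel relation and the resulting output shape; the concrete numbers are illustrative only and are not taken from this diff:

import numpy as np

num_rois = 8                                    # hypothetical number of RoIs
output_channels = 10                            # attr "output_channels"
pooled_height, pooled_width = 7, 7              # attrs "pooled_height" / "pooled_width"

# The op enforces C == output_channels * pooled_height * pooled_width.
in_channels = output_channels * pooled_height * pooled_width    # 490
x = np.zeros((2, in_channels, 28, 28), dtype=np.float32)        # NCHW input
assert x.shape[1] == output_channels * pooled_height * pooled_width

# The inferred output shape is (num_rois, output_channels, pooled_h, pooled_w).
out_shape = (num_rois, output_channels, pooled_height, pooled_width)
print(out_shape)                                # (8, 10, 7, 7)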

File diff suppressed because it is too large

File diff suppressed because it is too large

@ -203,6 +203,7 @@ __all__ = [
'temporal_shift',
'py_func',
'psroi_pool',
'prroi_pool',
'teacher_student_sigmoid_loss',
'huber_loss',
'kldiv_loss',
@ -12716,6 +12717,70 @@ def psroi_pool(input,
return out
@templatedoc()
def prroi_pool(input,
rois,
output_channels,
spatial_scale=1.0,
pooled_height=1,
pooled_width=1,
name=None):
"""
The precise ROI pooling implementation for Paddle. See https://arxiv.org/pdf/1807.11590.pdf for details.
Args:
input (Variable): The input of precise ROI pooling. The shape of the input tensor is
[N, C, H, W], where N is the batch size, C is the number of input channels, H
is the height of the feature map, and W is the width of the feature map.
rois (Variable): ROIs (Regions of Interest) to pool over. It should be
a 2-D LoDTensor of shape (num_rois, 4) with lod level 1,
given as [[x1, y1, x2, y2], ...], where (x1, y1) is
the top left coordinate, and (x2, y2) is the bottom
right coordinate.
output_channels (integer): The number of channels of the output feature map.
spatial_scale (float): Ratio of the input feature map height (or width) to the raw image height (or width),
i.e. the reciprocal of the total stride of the preceding convolutional layers. Default: 1.0.
pooled_height (integer): The pooled output height. Default: 1.
pooled_width (integer): The pooled output width. Default: 1.
name (str, default None): The name of this operation.
Returns:
Variable(Tensor): A 4-D Tensor of shape (num_rois, output_channels, pooled_h, pooled_w), with the same data type as the input.
Examples:
.. code-block:: python
import paddle.fluid as fluid
x = fluid.layers.data(name='x', shape=[490, 28, 28], dtype='float32')
rois = fluid.layers.data(name='rois', shape=[4], lod_level=1, dtype='float32')
pool_out = fluid.layers.prroi_pool(x, rois, 10, 1.0, 7, 7)
"""
helper = LayerHelper('prroi_pool', **locals())
# check attrs
if not isinstance(output_channels, int):
raise TypeError("output_channels must be int type")
if not isinstance(spatial_scale, float):
raise TypeError("spatial_scale must be float type")
if not isinstance(pooled_height, int):
raise TypeError("pooled_height must be int type")
if not isinstance(pooled_width, int):
raise TypeError("pooled_width must be int type")
dtype = helper.input_dtype()
out = helper.create_variable_for_type_inference(dtype)
helper.append_op(
type='prroi_pool',
inputs={'X': input,
'ROIs': rois},
outputs={'Out': out},
attrs={
'output_channels': output_channels,
'spatial_scale': spatial_scale,
'pooled_height': pooled_height,
'pooled_width': pooled_width
})
return out
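For quick reference, below is a minimal end-to-end sketch of running the prroi_pool layer defined above through an executor; the tensor names, shapes, and RoI values are illustrative and assume the fluid 1.x graph API used elsewhere in this diff:

import numpy as np
import paddle.fluid as fluid

# Build a small program that applies precise RoI pooling.
x = fluid.layers.data(name='x', shape=[490, 28, 28], dtype='float32')
rois = fluid.layers.data(name='rois', shape=[4], lod_level=1, dtype='float32')
pool_out = fluid.layers.prroi_pool(x, rois, 10, 1.0, 7, 7)

place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())

# One image in the batch with two RoIs belonging to it (lod = [[2]]).
x_np = np.random.rand(1, 490, 28, 28).astype('float32')
rois_np = np.array([[1, 1, 10, 10], [2, 2, 14, 14]]).astype('float32')
rois_t = fluid.create_lod_tensor(rois_np, [[2]], place)

out, = exe.run(fluid.default_main_program(),
               feed={'x': x_np, 'rois': rois_t},
               fetch_list=[pool_out])
print(out.shape)   # (2, 10, 7, 7)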
def huber_loss(input, label, delta):
"""
Huber loss is a loss function used in robust regression.

@ -0,0 +1,151 @@
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import numpy as np
class PyPrRoIPool(object):
def __init__(self):
pass
def _PrRoIPoolingGetData(self, data, h, w, height, width):
overflow = (h < 0) or (w < 0) or (h >= height) or (w >= width)
if overflow:
return 0.0
else:
return data[h][w]
def _PrRoIPoolingMatCalculation(self, this_data, s_h, s_w, e_h, e_w, y0, x0,
y1, x1, h0, w0):
sum_out = 0.0
alpha = x0 - float(s_w)
beta = y0 - float(s_h)
lim_alpha = x1 - float(s_w)
lim_beta = y1 - float(s_h)
tmp = (
lim_alpha - 0.5 * lim_alpha * lim_alpha - alpha + 0.5 * alpha *
alpha) * (
lim_beta - 0.5 * lim_beta * lim_beta - beta + 0.5 * beta * beta)
sum_out += self._PrRoIPoolingGetData(this_data, s_h, s_w, h0, w0) * tmp
alpha = float(e_w) - x1
lim_alpha = float(e_w) - x0
tmp = (
lim_alpha - 0.5 * lim_alpha * lim_alpha - alpha + 0.5 * alpha *
alpha) * (
lim_beta - 0.5 * lim_beta * lim_beta - beta + 0.5 * beta * beta)
sum_out += self._PrRoIPoolingGetData(this_data, s_h, e_w, h0, w0) * tmp
alpha = x0 - float(s_w)
beta = float(e_h) - y1
lim_alpha = x1 - float(s_w)
lim_beta = float(e_h) - y0
tmp = (
lim_alpha - 0.5 * lim_alpha * lim_alpha - alpha + 0.5 * alpha *
alpha) * (
lim_beta - 0.5 * lim_beta * lim_beta - beta + 0.5 * beta * beta)
sum_out += self._PrRoIPoolingGetData(this_data, e_h, s_w, h0, w0) * tmp
alpha = float(e_w) - x1
lim_alpha = float(e_w) - x0
tmp = (
lim_alpha - 0.5 * lim_alpha * lim_alpha - alpha + 0.5 * alpha *
alpha) * (
lim_beta - 0.5 * lim_beta * lim_beta - beta + 0.5 * beta * beta)
sum_out += self._PrRoIPoolingGetData(this_data, e_h, e_w, h0, w0) * tmp
return sum_out
def compute(self,
x,
rois,
output_channels,
spatial_scale=0.1,
pooled_height=1,
pooled_width=1):
'''
Calculate the precise RoI pooling values.
Note: This function is implemented in pure Python, without involving any Paddle concepts.
:param x (array): array[N, C, H, W]
:param rois (array): ROIs[id, x1, y1, x2, y2] (Regions of Interest) to pool over.
:param output_channels (Integer): Expected output channels
:param spatial_scale (float): spatial scale, default = 0.1
:param pooled_height (Integer): Expected output height, default = 1
:param pooled_width (Integer): Expected output width, default = 1
:return: array[len(rois), output_channels, pooled_height, pooled_width]
'''
if not isinstance(output_channels, int):
raise TypeError("output_channels must be int type")
if not isinstance(spatial_scale, float):
raise TypeError("spatial_scale must be float type")
if not isinstance(pooled_height, int):
raise TypeError("pooled_height must be int type")
if not isinstance(pooled_width, int):
raise TypeError("pooled_width must be int type")
(batch_size, channels, height, width) = np.array(x).shape
rois_num = len(rois)
output_shape = (rois_num, output_channels, pooled_height, pooled_width)
out_data = np.zeros(output_shape)
for i in range(rois_num):
roi = rois[i]
roi_batch_id = int(roi[0])
roi_start_w = roi[1] * spatial_scale
roi_start_h = roi[2] * spatial_scale
roi_end_w = roi[3] * spatial_scale
roi_end_h = roi[4] * spatial_scale
roi_width = max(roi_end_w - roi_start_w, 0.0)
roi_height = max(roi_end_h - roi_start_h, 0.0)
bin_size_h = roi_height / float(pooled_height)
bin_size_w = roi_width / float(pooled_width)
x_i = x[roi_batch_id]
for c in range(output_channels):
for ph in range(pooled_height):
for pw in range(pooled_width):
win_start_w = roi_start_w + bin_size_w * pw
win_start_h = roi_start_h + bin_size_h * ph
win_end_w = win_start_w + bin_size_w
win_end_h = win_start_h + bin_size_h
win_size = max(0.0, bin_size_w * bin_size_h)
if win_size == 0.0:
out_data[i, c, ph, pw] = 0.0
else:
sum_out = 0
s_w = math.floor(win_start_w)
e_w = math.ceil(win_end_w)
s_h = math.floor(win_start_h)
e_h = math.ceil(win_end_h)
c_in = (c * pooled_height + ph) * pooled_width + pw
for w_iter in range(int(s_w), int(e_w)):
for h_iter in range(int(s_h), int(e_h)):
sum_out += self._PrRoIPoolingMatCalculation(
x_i[c_in], h_iter, w_iter, h_iter + 1,
w_iter + 1,
max(win_start_h, float(h_iter)),
max(win_start_w, float(w_iter)),
min(win_end_h, float(h_iter) + 1.0),
min(win_end_w, float(w_iter + 1.0)),
height, width)
out_data[i, c, ph, pw] = sum_out / win_size
return out_data
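As a usage sketch of this pure-Python reference implementation (the toy shapes and RoI values below are made up for illustration only):

import numpy as np

pool = PyPrRoIPool()
# N=1, C=8, H=6, W=4; C must cover (c * pooled_height + ph) * pooled_width + pw.
x = np.random.rand(1, 8, 6, 4).astype('float32')
# Each RoI is [batch_id, x1, y1, x2, y2] in the original image scale.
rois = np.array([[0, 0.0, 0.0, 8.0, 12.0]], dtype='float32')
out = pool.compute(x, rois, output_channels=2, spatial_scale=0.5,
                   pooled_height=2, pooled_width=2)
print(out.shape)   # (1, 2, 2, 2)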

@ -0,0 +1,138 @@
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import numpy as np
import unittest
from py_precise_roi_pool import PyPrRoIPool
from op_test import OpTest
import paddle.fluid as fluid
from paddle.fluid import compiler, Program, program_guard
class TestPRROIPoolOp(OpTest):
def set_data(self):
self.init_test_case()
self.make_rois()
self.prRoIPool = PyPrRoIPool()
self.outs = self.prRoIPool.compute(
self.x, self.rois, self.output_channels, self.spatial_scale,
self.pooled_height, self.pooled_width).astype('float32')
self.inputs = {'X': self.x, 'ROIs': (self.rois[:, 1:5], self.rois_lod)}
self.attrs = {
'output_channels': self.output_channels,
'spatial_scale': self.spatial_scale,
'pooled_height': self.pooled_height,
'pooled_width': self.pooled_width
}
self.outputs = {'Out': self.outs}
def init_test_case(self):
self.batch_size = 3
self.channels = 3 * 2 * 2
self.height = 6
self.width = 4
self.x_dim = [self.batch_size, self.channels, self.height, self.width]
self.spatial_scale = 1.0 / 4.0
self.output_channels = 3
self.pooled_height = 2
self.pooled_width = 2
self.x = np.random.random(self.x_dim).astype('float32')
def make_rois(self):
rois = []
self.rois_lod = [[]]
for bno in range(self.batch_size):
self.rois_lod[0].append(bno + 1)
for i in range(bno + 1):
x1 = np.random.random_integers(
0, self.width // self.spatial_scale - self.pooled_width)
y1 = np.random.random_integers(
0, self.height // self.spatial_scale - self.pooled_height)
x2 = np.random.random_integers(x1 + self.pooled_width,
self.width // self.spatial_scale)
y2 = np.random.random_integers(
y1 + self.pooled_height, self.height // self.spatial_scale)
roi = [bno, x1, y1, x2, y2]
rois.append(roi)
self.rois_num = len(rois)
self.rois = np.array(rois).astype('float32')
def setUp(self):
self.op_type = 'prroi_pool'
self.set_data()
def test_check_output(self):
self.check_output()
def test_backward(self):
for place in self._get_places():
self._get_gradient(['X'], place, ["Out"], None)
def run_net(self, place):
with program_guard(Program(), Program()):
x = fluid.layers.data(
name="X",
shape=[self.channels, self.height, self.width],
dtype="float32")
rois = fluid.layers.data(
name="ROIs", shape=[4], dtype="float32", lod_level=1)
output = fluid.layers.prroi_pool(x, rois, self.output_channels,
0.25, 2, 2)
loss = fluid.layers.mean(output)
optimizer = fluid.optimizer.SGD(learning_rate=1e-3)
optimizer.minimize(loss)
input_x = fluid.create_lod_tensor(self.x, [], place)
input_rois = fluid.create_lod_tensor(self.rois[:, 1:5],
self.rois_lod, place)
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
exe.run(fluid.default_main_program(),
{'X': input_x,
"ROIs": input_rois})
def test_net(self):
places = [fluid.CPUPlace()]
if fluid.core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
for place in places:
self.run_net(place)
def test_errors(self):
with program_guard(Program(), Program()):
x = fluid.layers.data(
name="x", shape=[245, 30, 30], dtype="float32")
rois = fluid.layers.data(
name="rois", shape=[4], dtype="float32", lod_level=1)
# channel must be int type
self.assertRaises(TypeError, fluid.layers.prroi_pool, x, rois, 0.5,
0.25, 7, 7)
# spatial_scale must be float type
self.assertRaises(TypeError, fluid.layers.prroi_pool, x, rois, 5, 2,
7, 7)
# pooled_height must be int type
self.assertRaises(TypeError, fluid.layers.prroi_pool, x, rois, 5,
0.25, 0.7, 7)
# pooled_width must be int type
self.assertRaises(TypeError, fluid.layers.prroi_pool, x, rois, 5,
0.25, 7, 0.7)
if __name__ == '__main__':
unittest.main()