move param exclusive to the last in pool2d/pool3d for forward compatibility:. test=develop

fix_recordio_link
dengkaipeng 6 years ago
parent c93e044ae0
commit 8f1e398824

@ -67,8 +67,8 @@ paddle.fluid.layers.conv3d ArgSpec(args=['input', 'num_filters', 'filter_size',
paddle.fluid.layers.sequence_pool ArgSpec(args=['input', 'pool_type'], varargs=None, keywords=None, defaults=None) paddle.fluid.layers.sequence_pool ArgSpec(args=['input', 'pool_type'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.sequence_softmax ArgSpec(args=['input', 'use_cudnn', 'name'], varargs=None, keywords=None, defaults=(False, None)) paddle.fluid.layers.sequence_softmax ArgSpec(args=['input', 'use_cudnn', 'name'], varargs=None, keywords=None, defaults=(False, None))
paddle.fluid.layers.softmax ArgSpec(args=['input', 'use_cudnn', 'name'], varargs=None, keywords=None, defaults=(True, None)) paddle.fluid.layers.softmax ArgSpec(args=['input', 'use_cudnn', 'name'], varargs=None, keywords=None, defaults=(True, None))
paddle.fluid.layers.pool2d ArgSpec(args=['input', 'pool_size', 'pool_type', 'pool_stride', 'pool_padding', 'global_pooling', 'use_cudnn', 'ceil_mode', 'name'], varargs=None, keywords=None, defaults=(-1, 'max', 1, 0, False, True, False, None)) paddle.fluid.layers.pool2d ArgSpec(args=['input', 'pool_size', 'pool_type', 'pool_stride', 'pool_padding', 'global_pooling', 'use_cudnn', 'ceil_mode', 'name', 'exclusive'], varargs=None, keywords=None, defaults=(-1, 'max', 1, 0, False, True, False, None))
paddle.fluid.layers.pool3d ArgSpec(args=['input', 'pool_size', 'pool_type', 'pool_stride', 'pool_padding', 'global_pooling', 'use_cudnn', 'ceil_mode', 'name'], varargs=None, keywords=None, defaults=(-1, 'max', 1, 0, False, True, False, None)) paddle.fluid.layers.pool3d ArgSpec(args=['input', 'pool_size', 'pool_type', 'pool_stride', 'pool_padding', 'global_pooling', 'use_cudnn', 'ceil_mode', 'name', 'exclusive'], varargs=None, keywords=None, defaults=(-1, 'max', 1, 0, False, True, False, None))
paddle.fluid.layers.batch_norm ArgSpec(args=['input', 'act', 'is_test', 'momentum', 'epsilon', 'param_attr', 'bias_attr', 'data_layout', 'in_place', 'name', 'moving_mean_name', 'moving_variance_name', 'do_model_average_for_mean_and_var', 'fuse_with_relu'], varargs=None, keywords=None, defaults=(None, False, 0.9, 1e-05, None, None, 'NCHW', False, None, None, None, False, False)) paddle.fluid.layers.batch_norm ArgSpec(args=['input', 'act', 'is_test', 'momentum', 'epsilon', 'param_attr', 'bias_attr', 'data_layout', 'in_place', 'name', 'moving_mean_name', 'moving_variance_name', 'do_model_average_for_mean_and_var', 'fuse_with_relu'], varargs=None, keywords=None, defaults=(None, False, 0.9, 1e-05, None, None, 'NCHW', False, None, None, None, False, False))
paddle.fluid.layers.beam_search_decode ArgSpec(args=['ids', 'scores', 'beam_size', 'end_id', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.beam_search_decode ArgSpec(args=['ids', 'scores', 'beam_size', 'end_id', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.conv2d_transpose ArgSpec(args=['input', 'num_filters', 'output_size', 'filter_size', 'padding', 'stride', 'dilation', 'groups', 'param_attr', 'bias_attr', 'use_cudnn', 'act', 'name'], varargs=None, keywords=None, defaults=(None, None, 0, 1, 1, None, None, None, True, None, None)) paddle.fluid.layers.conv2d_transpose ArgSpec(args=['input', 'num_filters', 'output_size', 'filter_size', 'padding', 'stride', 'dilation', 'groups', 'param_attr', 'bias_attr', 'use_cudnn', 'act', 'name'], varargs=None, keywords=None, defaults=(None, None, 0, 1, 1, None, None, None, True, None, None))

@ -29,9 +29,9 @@ class Pool2dFunctor<platform::CPUDeviceContext, PoolProcess, T> {
public: public:
void operator()(const platform::CPUDeviceContext& context, void operator()(const platform::CPUDeviceContext& context,
const framework::Tensor& input, const std::vector<int>& ksize, const framework::Tensor& input, const std::vector<int>& ksize,
const std::vector<int>& strides, const std::vector<int>& paddings, const std::vector<int>& strides,
PoolProcess pool_process, bool exclusive, const std::vector<int>& paddings, PoolProcess pool_process,
framework::Tensor* output) { bool exclusive, framework::Tensor* output) {
const int batch_size = input.dims()[0]; const int batch_size = input.dims()[0];
const int input_height = input.dims()[2]; const int input_height = input.dims()[2];
const int input_width = input.dims()[3]; const int input_width = input.dims()[3];
@ -69,7 +69,7 @@ class Pool2dFunctor<platform::CPUDeviceContext, PoolProcess, T> {
} }
} }
int pool_size = exclusive ? (hend - hstart) * (wend - wstart) int pool_size = exclusive ? (hend - hstart) * (wend - wstart)
: ksize_height * ksize_width; : ksize_height * ksize_width;
pool_process.finalize(static_cast<T>(pool_size), &ele); pool_process.finalize(static_cast<T>(pool_size), &ele);
output_data[ph * output_width + pw] = ele; output_data[ph * output_width + pw] = ele;
} }
@ -126,7 +126,7 @@ class Pool2dGradFunctor<platform::CPUDeviceContext, PoolProcess, T> {
int wend = std::min(wstart + ksize_width, input_width); int wend = std::min(wstart + ksize_width, input_width);
wstart = std::max(wstart, 0); wstart = std::max(wstart, 0);
int pool_size = exclusive ? (hend - hstart) * (wend - wstart) int pool_size = exclusive ? (hend - hstart) * (wend - wstart)
: ksize_height * ksize_width; : ksize_height * ksize_width;
float scale = 1.0 / pool_size; float scale = 1.0 / pool_size;
for (int h = hstart; h < hend; ++h) { for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) { for (int w = wstart; w < wend; ++w) {
@ -249,8 +249,8 @@ class Pool3dFunctor<platform::CPUDeviceContext, PoolProcess, T> {
public: public:
void operator()(const platform::CPUDeviceContext& context, void operator()(const platform::CPUDeviceContext& context,
const framework::Tensor& input, const std::vector<int>& ksize, const framework::Tensor& input, const std::vector<int>& ksize,
const std::vector<int>& strides, const std::vector<int>& paddings, const std::vector<int>& strides,
PoolProcess pool_process, const std::vector<int>& paddings, PoolProcess pool_process,
bool exclusive, framework::Tensor* output) { bool exclusive, framework::Tensor* output) {
const int batch_size = input.dims()[0]; const int batch_size = input.dims()[0];
const int input_depth = input.dims()[2]; const int input_depth = input.dims()[2];
@ -301,9 +301,10 @@ class Pool3dFunctor<platform::CPUDeviceContext, PoolProcess, T> {
} }
} }
} }
int pool_size = exclusive ? int pool_size =
(dend - dstart) * (hend - hstart) * (wend - wstart) exclusive
: ksize_depth * ksize_height * ksize_width; ? (dend - dstart) * (hend - hstart) * (wend - wstart)
: ksize_depth * ksize_height * ksize_width;
pool_process.finalize(static_cast<T>(pool_size), &ele); pool_process.finalize(static_cast<T>(pool_size), &ele);
output_data[output_idx] = ele; output_data[output_idx] = ele;
} }
@ -371,9 +372,10 @@ class Pool3dGradFunctor<platform::CPUDeviceContext, PoolProcess, T> {
int wend = std::min(wstart + ksize_width, input_width); int wend = std::min(wstart + ksize_width, input_width);
wstart = std::max(wstart, 0); wstart = std::max(wstart, 0);
int pool_size = exclusive ? int pool_size =
(dend - dstart) * (hend - hstart) * (wend - wstart) exclusive
: ksize_depth * ksize_height * ksize_width; ? (dend - dstart) * (hend - hstart) * (wend - wstart)
: ksize_depth * ksize_height * ksize_width;
float scale = 1.0 / pool_size; float scale = 1.0 / pool_size;
for (int d = dstart; d < dend; ++d) { for (int d = dstart; d < dend; ++d) {
for (int h = hstart; h < hend; ++h) { for (int h = hstart; h < hend; ++h) {

@ -53,7 +53,7 @@ __global__ void KernelPool2D(const int nthreads, const T* input_data,
} }
} }
int pool_size = exclusive ? (hend - hstart) * (wend - wstart) int pool_size = exclusive ? (hend - hstart) * (wend - wstart)
: ksize_height * ksize_width; : ksize_height * ksize_width;
pool_process.finalize(static_cast<T>(pool_size), &ele); pool_process.finalize(static_cast<T>(pool_size), &ele);
output_data[index] = ele; output_data[index] = ele;
} }
@ -97,7 +97,7 @@ __global__ void KernelPool2DGrad(
hstart = max(hstart, 0); hstart = max(hstart, 0);
wstart = max(wstart, 0); wstart = max(wstart, 0);
int pool_size = exclusive ? (hend - hstart) * (wend - wstart) int pool_size = exclusive ? (hend - hstart) * (wend - wstart)
: ksize_height * ksize_width; : ksize_height * ksize_width;
int output_sub_idx = ph * output_width + pw; int output_sub_idx = ph * output_width + pw;
pool_process.compute(input, output_data[output_sub_idx], pool_process.compute(input, output_data[output_sub_idx],
output_grad[output_sub_idx], output_grad[output_sub_idx],
@ -191,7 +191,7 @@ class Pool2dFunctor<platform::CUDADeviceContext, PoolProcess, T> {
KernelPool2D<PoolProcess, T><<<grid, threads, 0, context.stream()>>>( KernelPool2D<PoolProcess, T><<<grid, threads, 0, context.stream()>>>(
nthreads, input_data, input_channels, input_height, input_width, nthreads, input_data, input_channels, input_height, input_width,
output_height, output_width, ksize_height, ksize_width, stride_height, output_height, output_width, ksize_height, ksize_width, stride_height,
stride_width, padding_height, padding_width, pool_process, exclusive, stride_width, padding_height, padding_width, pool_process, exclusive,
output_data); output_data);
} }
}; };
@ -317,11 +317,11 @@ template class Pool2dGradFunctor<platform::CUDADeviceContext,
template <typename PoolProcess, typename T> template <typename PoolProcess, typename T>
__global__ void KernelPool3D( __global__ void KernelPool3D(
const int nthreads, const T* input_data, const int channels, const int nthreads, const T* input_data, const int channels,
const int input_depth, const int input_height, const int input_width, const int input_depth, const int input_height, const int input_width,
const int output_depth, const int output_height, const int output_width, const int output_depth, const int output_height, const int output_width,
const int ksize_depth, const int ksize_height, const int ksize_width, const int ksize_depth, const int ksize_height, const int ksize_width,
const int stride_depth, const int stride_height, const int stride_width, const int stride_depth, const int stride_height, const int stride_width,
const int padding_depth, const int padding_height, const int padding_width, const int padding_depth, const int padding_height, const int padding_width,
PoolProcess pool_process, bool exclusive, T* output_data) { PoolProcess pool_process, bool exclusive, T* output_data) {
for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads; for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
@ -352,9 +352,9 @@ __global__ void KernelPool3D(
} }
} }
} }
int pool_size = exclusive ? int pool_size = exclusive
(dend - dstart) * (hend - hstart) * (wend - wstart) ? (dend - dstart) * (hend - hstart) * (wend - wstart)
: ksize_depth * ksize_height * ksize_width; : ksize_depth * ksize_height * ksize_width;
pool_process.finalize(static_cast<T>(pool_size), &ele); pool_process.finalize(static_cast<T>(pool_size), &ele);
output_data[index] = ele; output_data[index] = ele;
} }
@ -412,9 +412,9 @@ __global__ void KernelPool3DGrad(
dstart = max(dstart, 0); dstart = max(dstart, 0);
hstart = max(hstart, 0); hstart = max(hstart, 0);
wstart = max(wstart, 0); wstart = max(wstart, 0);
int pool_size = exclusive ? int pool_size =
(dend - dstart) * (hend - hstart) * (wend - wstart) exclusive ? (dend - dstart) * (hend - hstart) * (wend - wstart)
: ksize_depth * ksize_height * ksize_width; : ksize_depth * ksize_height * ksize_width;
int output_sub_idx = (pd * output_height + ph) * output_width + pw; int output_sub_idx = (pd * output_height + ph) * output_width + pw;
pool_process.compute(input, output_data[output_sub_idx], pool_process.compute(input, output_data[output_sub_idx],
output_grad[output_sub_idx], output_grad[output_sub_idx],
@ -522,8 +522,8 @@ class Pool3dFunctor<platform::CUDADeviceContext, PoolProcess, T> {
nthreads, input_data, input_channels, input_depth, input_height, nthreads, input_data, input_channels, input_depth, input_height,
input_width, output_depth, output_height, output_width, ksize_depth, input_width, output_depth, output_height, output_width, ksize_depth,
ksize_height, ksize_width, stride_depth, stride_height, stride_width, ksize_height, ksize_width, stride_depth, stride_height, stride_width,
padding_depth, padding_height, padding_width, pool_process, padding_depth, padding_height, padding_width, pool_process, exclusive,
exclusive, output_data); output_data);
} }
}; };

@ -73,7 +73,8 @@ class PoolCUDNNOpKernel : public framework::OpKernel<T> {
if (pooling_type == "max") { if (pooling_type == "max") {
pooling_mode = PoolingMode::kMaximum; pooling_mode = PoolingMode::kMaximum;
} else { } else {
pooling_mode = exclusive ? PoolingMode::kAverageExclusive : PoolingMode::kAverageInclusive; pooling_mode = exclusive ? PoolingMode::kAverageExclusive
: PoolingMode::kAverageInclusive;
} }
cudnnPoolingDescriptor_t cudnn_pool_desc = cudnnPoolingDescriptor_t cudnn_pool_desc =
@ -143,7 +144,8 @@ class PoolCUDNNGradOpKernel : public framework::OpKernel<T> {
pooling_mode = PoolingMode::kMaximum; pooling_mode = PoolingMode::kMaximum;
} }
} else { } else {
pooling_mode = exclusive ? PoolingMode::kAverageExclusive : PoolingMode::kAverageInclusive; pooling_mode = exclusive ? PoolingMode::kAverageExclusive
: PoolingMode::kAverageInclusive;
} }
cudnnPoolingDescriptor_t cudnn_pool_desc = cudnnPoolingDescriptor_t cudnn_pool_desc =

@ -2067,8 +2067,8 @@ def pool2d(input,
global_pooling=False, global_pooling=False,
use_cudnn=True, use_cudnn=True,
ceil_mode=False, ceil_mode=False,
exclusive=True, name=None,
name=None): exclusive=True):
""" """
${comment} ${comment}
@ -2085,10 +2085,10 @@ def pool2d(input,
global_pooling (bool): ${global_pooling_comment} global_pooling (bool): ${global_pooling_comment}
use_cudnn (bool): ${use_cudnn_comment} use_cudnn (bool): ${use_cudnn_comment}
ceil_mode (bool): ${ceil_mode_comment} ceil_mode (bool): ${ceil_mode_comment}
exclusive (bool): Whether to exclude padding points in average pooling
mode, default is true
name (str|None): A name for this layer(optional). If set None, the name (str|None): A name for this layer(optional). If set None, the
layer will be named automatically. layer will be named automatically.
exclusive (bool): Whether to exclude padding points in average pooling
mode, default is true
Returns: Returns:
Variable: The pooling result. Variable: The pooling result.
@ -2161,8 +2161,8 @@ def pool3d(input,
global_pooling=False, global_pooling=False,
use_cudnn=True, use_cudnn=True,
ceil_mode=False, ceil_mode=False,
exclusive=True, name=None,
name=None): exclusive=True):
""" """
This function adds the operator for pooling in 3-dimensions, using the This function adds the operator for pooling in 3-dimensions, using the
pooling configurations mentioned in input parameters. pooling configurations mentioned in input parameters.
@ -2176,10 +2176,10 @@ def pool3d(input,
global_pooling (bool): ${global_pooling_comment} global_pooling (bool): ${global_pooling_comment}
use_cudnn (bool): ${use_cudnn_comment} use_cudnn (bool): ${use_cudnn_comment}
ceil_mode (bool): ${ceil_mode_comment} ceil_mode (bool): ${ceil_mode_comment}
exclusive (bool): Whether to exclude padding points in average pooling
mode, default is true
name (str): A name for this layer(optional). If set None, the layer name (str): A name for this layer(optional). If set None, the layer
will be named automatically. will be named automatically.
exclusive (bool): Whether to exclude padding points in average pooling
mode, default is true
Returns: Returns:
Variable: output of pool3d layer. Variable: output of pool3d layer.

@ -96,9 +96,9 @@ class TestPool2d_Op(OpTest):
if self.global_pool: if self.global_pool:
self.paddings = [0 for _ in range(len(self.paddings))] self.paddings = [0 for _ in range(len(self.paddings))]
input = np.random.random(self.shape).astype(self.dtype) input = np.random.random(self.shape).astype(self.dtype)
output = self.pool2D_forward_naive(input, self.ksize, self.strides, output = self.pool2D_forward_naive(
self.paddings, self.global_pool, input, self.ksize, self.strides, self.paddings, self.global_pool,
self.ceil_mode, self.exclusive).astype(self.dtype) self.ceil_mode, self.exclusive).astype(self.dtype)
self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(input)} self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(input)}
self.attrs = { self.attrs = {
@ -110,7 +110,8 @@ class TestPool2d_Op(OpTest):
'use_cudnn': self.use_cudnn, 'use_cudnn': self.use_cudnn,
'use_mkldnn': self.use_mkldnn, 'use_mkldnn': self.use_mkldnn,
'ceil_mode': self.ceil_mode, 'ceil_mode': self.ceil_mode,
'data_format': 'AnyLayout', # TODO(dzhwinter) : should be fix latter 'data_format':
'AnyLayout', # TODO(dzhwinter) : should be fix latter
'exclusive': self.exclusive 'exclusive': self.exclusive
} }
@ -329,10 +330,12 @@ class TestCeilModeCase4(TestCase2):
def init_ceil_mode(self): def init_ceil_mode(self):
self.ceil_mode = True self.ceil_mode = True
class TestAvgInclude(TestCase2): class TestAvgInclude(TestCase2):
def init_exclusive(self): def init_exclusive(self):
self.exclusive = False self.exclusive = False
class TestCUDNNAvgInclude(TestCUDNNCase3): class TestCUDNNAvgInclude(TestCUDNNCase3):
def init_exclusive(self): def init_exclusive(self):
self.exclusive = False self.exclusive = False

@ -89,7 +89,8 @@ def avg_pool3D_forward_naive(x,
field_size = (d_end - d_start) * (h_end - h_start) * (w_end - w_start) \ field_size = (d_end - d_start) * (h_end - h_start) * (w_end - w_start) \
if exclusive else ksize[0] * ksize[1] * ksize[2] if exclusive else ksize[0] * ksize[1] * ksize[2]
out[:, :, k, i, j] = np.sum(x_masked, axis=(2, 3, 4)) / field_size out[:, :, k, i, j] = np.sum(x_masked, axis=(2, 3,
4)) / field_size
return out return out
@ -108,9 +109,9 @@ class TestPool3d_Op(OpTest):
if self.global_pool: if self.global_pool:
self.paddings = [0 for _ in range(len(self.paddings))] self.paddings = [0 for _ in range(len(self.paddings))]
input = np.random.random(self.shape).astype(self.dtype) input = np.random.random(self.shape).astype(self.dtype)
output = self.pool3D_forward_naive(input, self.ksize, self.strides, output = self.pool3D_forward_naive(
self.paddings, self.global_pool, input, self.ksize, self.strides, self.paddings, self.global_pool,
self.ceil_mode, self.exclusive).astype(self.dtype) self.ceil_mode, self.exclusive).astype(self.dtype)
self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(input)} self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(input)}
self.attrs = { self.attrs = {
@ -121,8 +122,9 @@ class TestPool3d_Op(OpTest):
'global_pooling': self.global_pool, 'global_pooling': self.global_pool,
'use_cudnn': self.use_cudnn, 'use_cudnn': self.use_cudnn,
'ceil_mode': self.ceil_mode, 'ceil_mode': self.ceil_mode,
'data_format': 'AnyLayout', # TODO(dzhwinter) : should be fix latter 'data_format':
'exclusive': self.exclusive 'AnyLayout', # TODO(dzhwinter) : should be fix latter
'exclusive': self.exclusive
} }
self.outputs = {'Out': output} self.outputs = {'Out': output}
@ -167,7 +169,7 @@ class TestPool3d_Op(OpTest):
self.ceil_mode = False self.ceil_mode = False
def init_exclusive(self): def init_exclusive(self):
self.exclusive = True self.exclusive = True
class TestCase1(TestPool3d_Op): class TestCase1(TestPool3d_Op):
@ -340,10 +342,12 @@ class TestCeilModeCase4(TestCase2):
def init_ceil_mode(self): def init_ceil_mode(self):
self.ceil_mode = True self.ceil_mode = True
class TestAvgInclude(TestCase2): class TestAvgInclude(TestCase2):
def init_exclusive(self): def init_exclusive(self):
self.exclusive = False self.exclusive = False
class TestCUDNNAvgInclude(TestCUDNNCase3): class TestCUDNNAvgInclude(TestCUDNNCase3):
def init_exclusive(self): def init_exclusive(self):
self.exclusive = False self.exclusive = False

Loading…
Cancel
Save