From 6326c40d2709da9839edbf3fb2a280ca92804a23 Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Thu, 28 Sep 2017 09:49:00 +0800 Subject: [PATCH 01/26] Add max pool with index --- paddle/operators/CMakeLists.txt | 6 + paddle/operators/math/CMakeLists.txt | 4 +- paddle/operators/math/pooling.cc | 255 ++++++++++++ paddle/operators/math/pooling.cu | 387 ++++++++++++++++++ paddle/operators/math/pooling.h | 68 +++ paddle/operators/pool_with_index_op.cc | 198 +++++++++ paddle/operators/pool_with_index_op.cu | 31 ++ paddle/operators/pool_with_index_op.h | 99 +++++ .../v2/framework/tests/test_pool_max_op.py | 125 ++++++ 9 files changed, 1171 insertions(+), 2 deletions(-) create mode 100644 paddle/operators/math/pooling.cc create mode 100644 paddle/operators/math/pooling.cu create mode 100644 paddle/operators/math/pooling.h create mode 100644 paddle/operators/pool_with_index_op.cc create mode 100644 paddle/operators/pool_with_index_op.cu create mode 100644 paddle/operators/pool_with_index_op.h create mode 100644 python/paddle/v2/framework/tests/test_pool_max_op.py diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt index e56895c63a..0feb969c62 100644 --- a/paddle/operators/CMakeLists.txt +++ b/paddle/operators/CMakeLists.txt @@ -62,6 +62,12 @@ function(op_library TARGET) file(APPEND ${pybind_file} "USE_OP(sigmoid);\n") endif() + if ("${TARGET}" STREQUAL "pool_with_index_op") + set(pybind_flag 1) + # It's enough to just adding one operator to pybind + file(APPEND ${pybind_file} "USE_OP(maxPool2dWithIndex);\n") + endif() + # pybind USE_NO_KERNEL_OP file(READ ${TARGET}.cc TARGET_CONTENT) string(REGEX MATCH "OperatorWithKernel" regex_result "${TARGET_CONTENT}") diff --git a/paddle/operators/math/CMakeLists.txt b/paddle/operators/math/CMakeLists.txt index 91ae3d49f1..811deb4c2c 100644 --- a/paddle/operators/math/CMakeLists.txt +++ b/paddle/operators/math/CMakeLists.txt @@ -1,12 +1,12 @@ if(WITH_GPU) nv_library(math_function SRCS math_function.cc math_function.cu im2col.cc - im2col.cu DEPS cblas device_context operator) + im2col.cu pooling.cc pooling.cu DEPS cblas device_context operator) nv_library(softmax_function SRCS softmax.cc softmax.cu DEPS operator) nv_library(cross_entropy_function SRCS cross_entropy.cc cross_entropy.cu DEPS operator) else() - cc_library(math_function SRCS math_function.cc im2col.cc + cc_library(math_function SRCS math_function.cc im2col.cc pooling.cc DEPS cblas device_context operator) cc_library(softmax_function SRCS softmax.cc DEPS operator) cc_library(cross_entropy_function SRCS cross_entropy.cc DEPS operator) diff --git a/paddle/operators/math/pooling.cc b/paddle/operators/math/pooling.cc new file mode 100644 index 0000000000..0e4d9007a6 --- /dev/null +++ b/paddle/operators/math/pooling.cc @@ -0,0 +1,255 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/math/pooling.h" + +namespace paddle { +namespace operators { +namespace math { + +template +class MaxPool2dWithIndexFunctor { + public: + void operator()(const platform::DeviceContext& context, + const framework::Tensor& input, framework::Tensor& output, + framework::Tensor& mask, std::vector& ksize, + std::vector& strides, std::vector& paddings) { + const int batch_size = input.dims()[0]; + + const int input_height = input.dims()[2]; + const int input_width = input.dims()[3]; + const int output_channels = output.dims()[1]; + const int output_height = output.dims()[2]; + const int output_width = output.dims()[3]; + const int ksize_height = ksize[0]; + const int ksize_width = ksize[1]; + const int stride_height = strides[0]; + const int stride_width = strides[1]; + const int padding_height = paddings[0]; + const int padding_width = paddings[1]; + + const int input_stride = input_height * input_width; + const int output_stride = output_height * output_width; + + const T* input_data = input.data(); + T* output_data = output.mutable_data(context.GetPlace()); + + T* mask_data = mask.mutable_data(context.GetPlace()); + + for (int i = 0; i < batch_size; i++) { + for (int c = 0; c < output_channels; ++c) { + for (int ph = 0; ph < output_height; ++ph) { + int hstart = ph * stride_height - padding_height; + int hend = std::min(hstart + ksize_height, input_height); + hstart = std::max(hstart, 0); + for (int pw = 0; pw < output_width; ++pw) { + int wstart = pw * stride_width - padding_width; + int wend = std::min(wstart + ksize_width, input_width); + wstart = std::max(wstart, 0); + + T ele = static_cast(-FLT_MAX); + int index = -1; + for (int h = hstart; h < hend; ++h) { + for (int w = wstart; w < wend; ++w) { + if (ele < input_data[h * input_width + w]) { + ele = input_data[h * input_width + w]; + index = h * input_width + w; + } + } + } + output_data[ph * output_width + pw] = ele; + mask_data[ph * output_width + pw] = index; + } + } + // offset + input_data += input_stride; + output_data += output_stride; + mask_data += output_stride; + } + } + } +}; + +template +class MaxPool2dWithIndexGradFunctor { + public: + void operator()(const platform::DeviceContext& context, + framework::Tensor& input_grad, + const framework::Tensor& output_grad, + const framework::Tensor& mask, std::vector& ksize, + std::vector& strides, std::vector& paddings) { + const int batch_size = input_grad.dims()[0]; + const int input_height = input_grad.dims()[2]; + const int input_width = input_grad.dims()[3]; + const int output_channels = output_grad.dims()[1]; + const int output_height = output_grad.dims()[2]; + const int output_width = output_grad.dims()[3]; + const int input_stride = input_height * input_width; + const int output_stride = output_height * output_width; + + const T* mask_data = mask.data(); + const T* output_grad_data = output_grad.data(); + T* input_grad_data = input_grad.mutable_data(context.GetPlace()); + + for (size_t n = 0; n < batch_size; ++n) { + for (size_t c = 0; c < output_channels; ++c) { + for (size_t ph = 0; ph < output_height; ++ph) { + for (size_t pw = 0; pw < output_width; ++pw) { + const size_t output_idx = ph * output_width + pw; + const size_t input_idx = static_cast(mask_data[output_idx]); + + input_grad_data[input_idx] += output_grad_data[output_idx]; + } + } + } + // offset + input_grad_data += input_stride; + output_grad_data += output_stride; + mask_data += output_stride; + } + } +}; + +template class MaxPool2dWithIndexFunctor; +template class MaxPool2dWithIndexGradFunctor; +template class MaxPool2dWithIndexFunctor; +template class MaxPool2dWithIndexGradFunctor; + +template +class MaxPool3dWithIndexFunctor { + public: + void operator()(const platform::DeviceContext& context, + const framework::Tensor& input, framework::Tensor& output, + framework::Tensor& mask, std::vector& ksize, + std::vector& strides, std::vector& paddings) { + const int batch_size = input.dims()[0]; + const int input_depth = input.dims()[2]; + const int input_height = input.dims()[3]; + const int input_width = input.dims()[4]; + const int output_channels = output.dims()[1]; + const int output_depth = output.dims()[2]; + const int output_height = output.dims()[3]; + const int output_width = output.dims()[4]; + const int ksize_depth = ksize[0]; + const int ksize_height = ksize[1]; + const int ksize_width = ksize[2]; + const int stride_depth = strides[0]; + const int stride_height = strides[1]; + const int stride_width = strides[2]; + const int padding_depth = paddings[0]; + const int padding_height = paddings[1]; + const int padding_width = paddings[2]; + const int input_stride = input_depth * input_height * input_width; + const int output_stride = output_depth * output_height * output_width; + const T* input_data = input.data(); + T* output_data = output.mutable_data(context.GetPlace()); + T* mask_data = mask.mutable_data(context.GetPlace()); + + for (int i = 0; i < batch_size; i++) { + for (int c = 0; c < output_channels; ++c) { + for (int pd = 0; pd < output_depth; ++pd) { + int dstart = pd * stride_depth - padding_depth; + int dend = std::min(dstart + ksize_depth, input_depth); + dstart = std::max(dstart, 0); + for (int ph = 0; ph < output_height; ++ph) { + int hstart = ph * stride_height - padding_height; + int hend = std::min(hstart + ksize_height, input_height); + hstart = std::max(hstart, 0); + for (int pw = 0; pw < output_width; ++pw) { + int wstart = pw * stride_width - padding_width; + int wend = std::min(wstart + ksize_width, input_width); + wstart = std::max(wstart, 0); + int output_idx = (pd * output_height + ph) * output_width + pw; + T ele = static_cast(-FLT_MAX); + int index = -1; + for (int d = dstart; d < dend; ++d) { + for (int h = hstart; h < hend; ++h) { + for (int w = wstart; w < wend; ++w) { + if (ele < + input_data[(d * input_height + h) * input_width + w]) { + index = (d * input_height + h) * input_width + w; + ele = + input_data[(d * input_height + h) * input_width + w]; + } + } + } + } + output_data[output_idx] = ele; + mask_data[output_idx] = index; + } + } + } + // offset + input_data += input_stride; + output_data += output_stride; + mask_data += output_stride; + } + } + } +}; + +template +class MaxPool3dWithIndexGradFunctor { + public: + void operator()(const platform::DeviceContext& context, + framework::Tensor& input_grad, + const framework::Tensor& output_grad, + const framework::Tensor& mask, std::vector& ksize, + std::vector& strides, std::vector& paddings) { + const int batch_size = input_grad.dims()[0]; + const int input_depth = input_grad.dims()[2]; + const int input_height = input_grad.dims()[3]; + const int input_width = input_grad.dims()[4]; + const int output_channels = output_grad.dims()[1]; + const int output_depth = output_grad.dims()[2]; + const int output_height = output_grad.dims()[3]; + const int output_width = output_grad.dims()[4]; + const int input_stride = input_depth * input_height * input_width; + const int output_stride = output_depth * output_height * output_width; + + const T* mask_data = mask.data(); + const T* output_grad_data = output_grad.data(); + T* input_grad_data = input_grad.mutable_data(context.GetPlace()); + + for (size_t n = 0; n < batch_size; ++n) { + for (size_t c = 0; c < output_channels; ++c) { + for (size_t pd = 0; pd < output_depth; ++pd) { + for (size_t ph = 0; ph < output_height; ++ph) { + for (size_t pw = 0; pw < output_width; ++pw) { + const size_t output_idx = + (pd * output_height + ph) * output_width + pw; + const size_t input_idx = + static_cast(mask_data[output_idx]); + + input_grad_data[input_idx] += output_grad_data[output_idx]; + } + } + } + // offset + input_grad_data += input_stride; + output_grad_data += output_stride; + mask_data += output_stride; + } + } + } +}; + +template class MaxPool3dWithIndexFunctor; +template class MaxPool3dWithIndexGradFunctor; +template class MaxPool3dWithIndexFunctor; +template class MaxPool3dWithIndexGradFunctor; + +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/math/pooling.cu b/paddle/operators/math/pooling.cu new file mode 100644 index 0000000000..f32e6a26d0 --- /dev/null +++ b/paddle/operators/math/pooling.cu @@ -0,0 +1,387 @@ +/* Copyright (c) 2016 paddlepaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/math/pooling.h" +#include "paddle/platform/cuda_helper.h" + +namespace paddle { +namespace operators { +namespace math { + +template +__global__ void KernelMaxPool2dWithIdxForward( + const int nthreads, const T* input_data, T* output_data, T* mask_data, + const int channels, const int input_height, const int input_width, + const int output_height, const int output_width, const int ksize_height, + const int ksize_width, const int stride_height, const int stride_width, + const int padding_height, const int padding_width) { + int index = blockIdx.x * blockDim.x + threadIdx.x; + if (index < nthreads) { + int pw = index % output_width; + int ph = (index / output_width) % output_height; + int c = (index / output_width / output_height) % channels; + int batch_idx = index / output_width / output_height / channels; + + int hstart = ph * stride_height - padding_height; + int hend = min(hstart + ksize_height, input_height); + hstart = max(hstart, 0); + + int wstart = pw * stride_width - padding_width; + int wend = min(wstart + ksize_width, input_width); + wstart = max(wstart, 0); + + input_data += (batch_idx * channels + c) * input_height * input_width; + T ele = -FLT_MAX; + int index = -1; + for (int h = hstart; h < hend; ++h) { + for (int w = wstart; w < wend; ++w) { + if (ele < input_data[h * input_width + w]) { + index = h * input_width + w; + ele = input_data[h * input_width + w]; + } + } + } + output_data[index] = ele; + mask_data[index] = index; + } +} + +template +__global__ void KernelMaxPool2DWithIdxBackward( + const int nthreads, T* input_grad, const T* output_grad, const T* mask_data, + const int channels, const int input_height, const int input_width, + const int output_height, const int output_width, const int ksize_height, + const int ksize_width, const int stride_height, const int stride_width, + const int padding_height, const int padding_width) { + int index = blockIdx.x * blockDim.x + threadIdx.x; + if (index < nthreads) { + int offsetW = index % input_width + padding_width; + int offsetH = (index / input_width) % input_height + padding_height; + int offsetC = (index / input_width / input_height) % channels; + int batch_idx = index / input_width / input_height / channels; + + int phstart = (offsetH < ksize_height) + ? 0 + : (offsetH - ksize_height) / stride_height + 1; + int pwstart = (offsetW < ksize_width) + ? 0 + : (offsetW - ksize_width) / stride_width + 1; + int phend = min(offsetH / stride_height + 1, output_height); + int pwend = min(offsetW / stride_width + 1, output_width); + T gradient = 0; + int output_idx = + (batch_idx * channels + offsetC) * output_height * output_width; + mask_data += output_idx; + output_grad += output_idx; + for (int ph = phstart; ph < phend; ++ph) { + for (int pw = pwstart; pw < pwend; ++pw) { + if ((offsetH * input_width + offsetW) == + mask_data[ph * output_width + pw]) + gradient += output_grad[ph * output_width + pw]; + } + } + input_grad[index] = gradient; + } +} + +template +class MaxPool2dWithIndexFunctor { + public: + void operator()(const platform::DeviceContext& context, + const framework::Tensor& input, framework::Tensor& output, + framework::Tensor& mask, std::vector& ksize, + std::vector& strides, std::vector& paddings) { + const int batch_size = input.dims()[0]; + const int input_channels = input.dims()[1]; + const int input_height = input.dims()[2]; + const int input_width = input.dims()[3]; + const int output_channels = output.dims()[1]; + const int output_height = output.dims()[2]; + const int output_width = output.dims()[3]; + const int ksize_height = ksize[0]; + const int ksize_width = ksize[1]; + const int stride_height = strides[0]; + const int stride_width = strides[1]; + const int padding_height = paddings[0]; + const int padding_width = paddings[1]; + + const T* input_data = input.data(); + T* output_data = output.mutable_data(context.GetPlace()); + T* mask_data = mask.mutable_data(context.GetPlace()); + + int nthreads = batch_size * output_channels * output_height * output_width; + int blocks = (nthreads + 1024 - 1) / 1024; + dim3 threads(1024, 1); + dim3 grid(blocks, 1); + + KernelMaxPool2dWithIdxForward< + T><<(context) + .stream()>>>(nthreads, input_data, output_data, mask_data, + input_channels, input_height, input_width, + output_height, output_width, ksize_height, + ksize_width, stride_height, stride_width, + padding_height, padding_width); + } +}; + +template +class MaxPool2dWithIndexGradFunctor { + public: + void operator()(const platform::DeviceContext& context, + framework::Tensor& input_grad, + const framework::Tensor& output_grad, + const framework::Tensor& mask, std::vector& ksize, + std::vector& strides, std::vector& paddings) { + const int batch_size = input_grad.dims()[0]; + const int input_channels = input_grad.dims()[1]; + const int input_height = input_grad.dims()[2]; + const int input_width = input_grad.dims()[3]; + const int output_channels = output_grad.dims()[1]; + const int output_height = output_grad.dims()[2]; + const int output_width = output_grad.dims()[3]; + const int ksize_height = ksize[0]; + const int ksize_width = ksize[1]; + const int stride_height = strides[0]; + const int stride_width = strides[1]; + const int padding_height = paddings[0]; + const int padding_width = paddings[1]; + + const T* mask_data = mask.data(); + const T* output_grad_data = output_grad.data(); + T* input_grad_data = input_grad.mutable_data(context.GetPlace()); + + int nthreads = batch_size * input_channels * input_height * input_width; + int blocks = (nthreads + 1024 - 1) / 1024; + dim3 threads(1024, 1); + dim3 grid(blocks, 1); + + KernelMaxPool2DWithIdxBackward< + T><<(context) + .stream()>>>(nthreads, input_grad_data, output_grad_data, + mask_data, input_channels, input_height, + input_width, output_height, output_width, + ksize_height, ksize_width, stride_height, + stride_width, padding_height, padding_width); + } +}; + +template class MaxPool2dWithIndexFunctor; +template class MaxPool2dWithIndexGradFunctor; +template class MaxPool2dWithIndexFunctor; +template class MaxPool2dWithIndexGradFunctor; + +template +__global__ void KernelMaxPool3DWithIdxForward( + const int nthreads, const T* input_data, T* output_data, T* mask_data, + const int channels, const int input_depth, const int input_height, + const int input_width, const int output_depth, const int output_height, + const int output_width, const int ksize_depth, const int ksize_height, + const int ksize_width, const int stride_depth, const int stride_height, + const int stride_width, const int padding_depth, const int padding_height, + const int padding_width) { + for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < (nthreads); + index += blockDim.x * gridDim.x) { + int pw = index % output_width; + int ph = (index / output_width) % output_height; + int pd = (index / output_width / output_height) % output_depth; + int c = (index / output_width / output_height / output_depth) % channels; + int batch_idx = + index / output_width / output_height / output_depth / channels; + int dstart = pd * stride_depth - padding_depth; + int hstart = ph * stride_height - padding_height; + int wstart = pw * stride_width - padding_width; + int dend = min(dstart + ksize_depth, input_depth); + int hend = min(hstart + ksize_height, input_height); + int wend = min(wstart + ksize_width, input_width); + dstart = max(dstart, 0); + hstart = max(hstart, 0); + wstart = max(wstart, 0); + T ele = -FLT_MAX; + int index = -1; + input_data += + (batch_idx * channels + c) * input_depth * input_height * input_width; + + for (int d = dstart; d < dend; ++d) { + for (int h = hstart; h < hend; ++h) { + for (int w = wstart; w < wend; ++w) { + if (ele < input_data[(d * input_height + h) * input_width + w]) { + index = (d * input_height + h) * input_width + w; + ele = input_data[(d * input_height + h) * input_width + w]; + } + } + } + } + output_data[index] = ele; + mask_data[index] = index; + } +} + +template +__global__ void KernelMaxPool3DWithIdxBackward( + const int nthreads, T* input_grad, const T* output_grad, const T* mask, + const int channels, const int input_depth, const int input_height, + const int input_width, const int output_depth, const int output_height, + const int output_width, const int ksize_depth, const int ksize_height, + const int ksize_width, const int stride_depth, const int stride_height, + const int stride_width, const int padding_depth, const int padding_height, + const int padding_width) { + for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < (nthreads); + index += blockDim.x * gridDim.x) { + int offsetW = index % input_width + padding_width; + int offsetH = (index / input_width) % input_height + padding_height; + int offsetD = + (index / input_width / input_height) % input_depth + padding_depth; + int offsetC = (index / input_width / input_height / input_depth) % channels; + int batch_idx = index / input_width / input_height / input_depth / channels; + + int pdstart = (offsetD < ksize_depth) + ? 0 + : (offsetD - ksize_depth) / stride_depth + 1; + int phstart = (offsetH < ksize_height) + ? 0 + : (offsetH - ksize_height) / stride_height + 1; + int pwstart = (offsetW < ksize_width) + ? 0 + : (offsetW - ksize_width) / stride_width + 1; + int pdend = min((offsetD) / stride_depth + 1, output_depth); + int phend = min((offsetH) / stride_height + 1, output_height); + int pwend = min((offsetW) / stride_width + 1, output_width); + + T gradient = 0; + int output_idx = (batch_idx * channels + offsetC) * output_depth * + output_height * output_width; + mask += output_idx; + output_grad += output_idx; + + for (int pd = pdstart; pd < pdend; ++pd) { + for (int ph = phstart; ph < phend; ++ph) { + for (int pw = pwstart; pw < pwend; ++pw) { + if (((offsetD * input_height + offsetH) * input_width + offsetW) == + mask[(pd * output_height + ph) * output_width + pw]) + gradient += + output_grad[(pd * output_height + ph) * output_width + pw]; + } + } + } + input_grad[index] = gradient; + } +} + +template +class MaxPool3dWithIndexFunctor { + public: + void operator()(const platform::DeviceContext& context, + const framework::Tensor& input, framework::Tensor& output, + framework::Tensor& mask, std::vector& ksize, + std::vector& strides, std::vector& paddings) { + const int batch_size = input.dims()[0]; + const int input_channels = input.dims()[1]; + const int input_depth = input.dims()[2]; + const int input_height = input.dims()[3]; + const int input_width = input.dims()[4]; + const int output_channels = output.dims()[1]; + const int output_depth = output.dims()[2]; + const int output_height = output.dims()[3]; + const int output_width = output.dims()[4]; + const int ksize_depth = ksize[0]; + const int ksize_height = ksize[1]; + const int ksize_width = ksize[2]; + const int stride_depth = strides[0]; + const int stride_height = strides[1]; + const int stride_width = strides[2]; + const int padding_depth = paddings[0]; + const int padding_height = paddings[1]; + const int padding_width = paddings[2]; + + const T* input_data = input.data(); + T* output_data = output.mutable_data(context.GetPlace()); + T* mask_data = output.mutable_data(context.GetPlace()); + + int nthreads = batch_size * output_channels * output_depth * output_height * + output_width; + int blocks = (nthreads + 1024 - 1) / 1024; + dim3 threads(1024, 1); + dim3 grid(blocks, 1); + + KernelMaxPool3DWithIdxForward< + T><<(context) + .stream()>>>( + nthreads, input_data, output_data, mask_data, input_channels, + input_depth, input_height, input_width, output_depth, output_height, + output_width, ksize_depth, ksize_height, ksize_width, stride_depth, + stride_height, stride_width, padding_depth, padding_height, + padding_width); + } +}; + +template +class MaxPool3dWithIndexGradFunctor { + public: + void operator()(const platform::DeviceContext& context, + framework::Tensor& input_grad, + const framework::Tensor& output_grad, + const framework::Tensor& mask, std::vector& ksize, + std::vector& strides, std::vector& paddings) { + const int batch_size = input_grad.dims()[0]; + const int input_channels = input_grad.dims()[1]; + const int input_depth = input_grad.dims()[2]; + const int input_height = input_grad.dims()[3]; + const int input_width = input_grad.dims()[4]; + const int output_channels = input_grad.dims()[1]; + const int output_depth = input_grad.dims()[2]; + const int output_height = input_grad.dims()[3]; + const int output_width = input_grad.dims()[4]; + const int ksize_depth = ksize[0]; + const int ksize_height = ksize[1]; + const int ksize_width = ksize[2]; + const int stride_depth = strides[0]; + const int stride_height = strides[1]; + const int stride_width = strides[2]; + const int padding_depth = paddings[0]; + const int padding_height = paddings[1]; + const int padding_width = paddings[2]; + + const T* output_grad_data = output_grad.data(); + const T* mask_data = mask.data(); + T* input_grad_data = input_grad.mutable_data(context.GetPlace()); + + int nthreads = + batch_size * input_channels * input_depth * input_height * input_width; + int blocks = (nthreads + 1024 - 1) / 1024; + dim3 threads(1024, 1); + dim3 grid(blocks, 1); + + KernelMaxPool3DWithIdxBackward< + T><<(context) + .stream()>>>( + nthreads, input_grad_data, output_grad_data, mask_data, input_channels, + input_depth, input_height, input_width, output_depth, output_height, + output_width, ksize_depth, ksize_height, ksize_width, stride_depth, + stride_height, stride_width, padding_depth, padding_height, + padding_width); + } +}; + +template class MaxPool3dWithIndexFunctor; +template class MaxPool3dWithIndexGradFunctor; +template class MaxPool3dWithIndexFunctor; +template class MaxPool3dWithIndexGradFunctor; + +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/math/pooling.h b/paddle/operators/math/pooling.h new file mode 100644 index 0000000000..3a05cd98fe --- /dev/null +++ b/paddle/operators/math/pooling.h @@ -0,0 +1,68 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "paddle/framework/eigen.h" +#include "paddle/framework/tensor.h" +#include "paddle/platform/device_context.h" +#include "paddle/platform/hostdevice.h" + +namespace paddle { +namespace operators { +namespace math { +////////////////////// +#define FLT_MAX __FLT_MAX__ +///////////////////// + +template +class MaxPool2dWithIndexFunctor { + public: + void operator()(const platform::DeviceContext& context, + const framework::Tensor& input, framework::Tensor& output, + framework::Tensor& mask, std::vector& ksize, + std::vector& strides, std::vector& paddings); +}; + +template +class MaxPool2dWithIndexGradFunctor { + public: + void operator()(const platform::DeviceContext& context, + framework::Tensor& input_grad, + const framework::Tensor& output_grad, + const framework::Tensor& mask, std::vector& ksize, + std::vector& strides, std::vector& paddings); +}; + +template +class MaxPool3dWithIndexFunctor { + public: + void operator()(const platform::DeviceContext& context, + const framework::Tensor& input, framework::Tensor& output, + framework::Tensor& mask, std::vector& ksize, + std::vector& strides, std::vector& paddings); +}; + +template +class MaxPool3dWithIndexGradFunctor { + public: + void operator()(const platform::DeviceContext& context, + framework::Tensor& input_grad, + const framework::Tensor& output_grad, + const framework::Tensor& mask, std::vector& ksize, + std::vector& strides, std::vector& paddings); +}; + +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/pool_with_index_op.cc b/paddle/operators/pool_with_index_op.cc new file mode 100644 index 0000000000..d7a07a403d --- /dev/null +++ b/paddle/operators/pool_with_index_op.cc @@ -0,0 +1,198 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/pool_with_index_op.h" + +namespace paddle { +namespace operators { + +int OutputSizeMaxPool(int input_size, int filter_size, int padding, + int stride) { + int output_size = (input_size - filter_size + 2 * padding) / stride + 1; + return output_size; +} + +class MaxPoolWithIndexOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContextBase *ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), + "X(Input) of Pooling should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Out(Output) of Pooling should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Mask"), + "Out(Output) of Pooling should not be null."); + + auto in_x_dims = ctx->GetInputDim("X"); + + std::vector ksize = ctx->Attrs().Get>("ksize"); + std::vector strides = ctx->Attrs().Get>("strides"); + std::vector paddings = ctx->Attrs().Get>("paddings"); + + PADDLE_ENFORCE(in_x_dims.size() == 4 || in_x_dims.size() == 5, + "Pooling intput should be 4-D or 5-D"); + + if (ctx->Attrs().Get("globalPooling")) { + ksize.resize(static_cast(in_x_dims.size()) - 2); + for (size_t i = 0; i < ksize.size(); ++i) + ksize[i] = static_cast(in_x_dims[i + 2]); + } + + PADDLE_ENFORCE(in_x_dims.size() - ksize.size() == 2U, + "Pooling intput size and pooling size should be consistent"); + PADDLE_ENFORCE(ksize.size() == 2 || ksize.size() == 3, + "Pooling size size should be 2 elements. or 3 elements."); + PADDLE_ENFORCE_EQ(ksize.size(), strides.size(), + "strides size and pooling size should be the same."); + PADDLE_ENFORCE_EQ(ksize.size(), paddings.size(), + "paddings size and pooling size should be the same."); + + std::vector output_shape({in_x_dims[0], in_x_dims[1]}); + for (size_t i = 0; i < ksize.size(); ++i) { + output_shape.push_back(OutputSizeMaxPool(in_x_dims[i + 2], ksize[i], + paddings[i], strides[i])); + } + ctx->SetOutputDim("Out", framework::make_ddim(output_shape)); + ctx->SetOutputDim("Mask", framework::make_ddim(output_shape)); + } +}; + +class MaxPoolWithIndexOpGrad : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContextBase *ctx) const override { + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("X")), + "X(Input) of MaxPoolWithIndexOpGrad should not be null."); + PADDLE_ENFORCE( + ctx->HasOutput(framework::GradVarName("X")), + "X@GRAD(Input@GRAD) of MaxPoolWithIndexOpGrad should not be null."); + ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); + } +}; + +class MaxPool2dWithIndexOpMaker : public framework::OpProtoAndCheckerMaker { + public: + MaxPool2dWithIndexOpMaker(framework::OpProto *proto, + framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput( + "X", + "The input tensor of pooling operator. " + "The format of input tensor is NCHW. Where N is batch size, C is the " + "number of channels, H and W is the height and width of image."); + AddOutput("Out", + "The output tensor of pooling operator." + "The format of output tensor is also NCHW."); + AddOutput("Mask", + "The Mask tensor of pooling operator." + "The format of output tensor is also NCHW."); + + AddAttr>( + "ksize", "pooling size(height, width) of pooling operator."); + AddAttr( + "globalPooling", + "whether to use the globalPooling." + "int constant equal to false or true" + "default false" + "If globalPooling = true, ksize is ignored and need not be specified.") + .SetDefault(false); + AddAttr>("strides", + "strides(height, width) of pooling operator." + "default {1,1}") + .SetDefault({1, 1}); + AddAttr>("paddings", + "paddings(height, width) of pooling operator." + "default {0,0}") + .SetDefault({0, 0}); + + AddComment(R"DOC( +The maxPooling2d with index operation calculates the output and the mask based on +the input and ksize, strides, paddings parameters. +)DOC"); + } +}; + +class MaxPool3dWithIndexOpMaker : public framework::OpProtoAndCheckerMaker { + public: + MaxPool3dWithIndexOpMaker(framework::OpProto *proto, + framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput( + "X", + "The input tensor of pooling operator. " + "The format of input tensor is NCDHW. Where N is batch size, C is " + "the number of channels, D, H and W is the depth, height and width of " + "image."); + AddOutput("Out", + "The output tensor of pooling operator." + "The format of output tensor is also NCDHW."); + AddOutput("Mask", + "The Mask tensor of pooling operator." + "The format of output tensor is also NCDHW."); + + AddAttr>( + "ksize", "pooling size(depth, height, width) of pooling operator."); + AddAttr( + "globalPooling", + "whether to use the globalPooling." + "int constant equal to false or true" + "default false" + "If globalPooling = true, ksize is ignored and need not be specified.") + .SetDefault(false); + AddAttr>( + "strides", + "strides(depth, height, width) of pooling operator." + "default {1,1,1}") + .SetDefault({1, 1, 1}); + AddAttr>( + "paddings", + "paddings(depth, height, width) of pooling operator." + "default {0,0,0}") + .SetDefault({0, 0, 0}); + AddComment(R"DOC( +The maxpooling3d with index operation calculates the output and the mask based on +the input and ksize, strides, paddings parameters. +)DOC"); + } +}; +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OP(maxPool2dWithIndex, ops::MaxPoolWithIndexOp, + ops::MaxPool2dWithIndexOpMaker, maxPool2dWithIndex_grad, + ops::MaxPoolWithIndexOpGrad); + +REGISTER_OP_CPU_KERNEL( + maxPool2dWithIndex, + ops::MaxPoolWithIndexKernel); +REGISTER_OP_CPU_KERNEL( + maxPool2dWithIndex_grad, + ops::MaxPoolWithIndexGradKernel) + +REGISTER_OP(maxPool3dWithIndex, ops::MaxPoolWithIndexOp, + ops::MaxPool3dWithIndexOpMaker, maxPool3dWithIndex_grad, + ops::MaxPoolWithIndexOpGrad); + +REGISTER_OP_CPU_KERNEL( + maxPool3dWithIndex, + ops::MaxPoolWithIndexKernel); +REGISTER_OP_CPU_KERNEL( + maxPool3dWithIndex_grad, + ops::MaxPoolWithIndexGradKernel) diff --git a/paddle/operators/pool_with_index_op.cu b/paddle/operators/pool_with_index_op.cu new file mode 100644 index 0000000000..8007fc7ccf --- /dev/null +++ b/paddle/operators/pool_with_index_op.cu @@ -0,0 +1,31 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/pool_with_index_op.h" + +namespace ops = paddle::operators; + +REGISTER_OP_GPU_KERNEL( + maxPool2dWithIndex, + ops::MaxPoolWithIndexKernel); +REGISTER_OP_GPU_KERNEL( + maxPool2dWithIndex_grad, + ops::MaxPoolWithIndexGradKernel) + +REGISTER_OP_GPU_KERNEL( + maxPool3dWithIndex, + ops::MaxPoolWithIndexKernel); +REGISTER_OP_GPU_KERNEL( + maxPool3dWithIndex_grad, + ops::MaxPoolWithIndexGradKernel) diff --git a/paddle/operators/pool_with_index_op.h b/paddle/operators/pool_with_index_op.h new file mode 100644 index 0000000000..91abeed016 --- /dev/null +++ b/paddle/operators/pool_with_index_op.h @@ -0,0 +1,99 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/framework/eigen.h" +#include "paddle/framework/op_registry.h" +#include "paddle/operators/math/math_function.h" +#include "paddle/operators/math/pooling.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +template +class MaxPoolWithIndexKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + const Tensor* in_x = context.Input("X"); + Tensor* out = context.Output("Out"); + Tensor* mask = context.Output("Mask"); + + bool global_pooling = context.Attr("globalPooling"); + std::vector ksize = context.Attr>("ksize"); + std::vector strides = context.Attr>("strides"); + std::vector paddings = context.Attr>("paddings"); + if (global_pooling) { + for (size_t i = 0; i < ksize.size(); ++i) { + ksize[i] = static_cast(in_x->dims()[i + 2]); + } + } + + switch (ksize.size()) { + case 2: { + paddle::operators::math::MaxPool2dWithIndexFunctor + pool2d_forward; + pool2d_forward(context.device_context(), *in_x, *out, *mask, ksize, + strides, paddings); + } break; + case 3: { + paddle::operators::math::MaxPool3dWithIndexFunctor + pool3d_forward; + pool3d_forward(context.device_context(), *in_x, *out, *mask, ksize, + strides, paddings); + } break; + } + } +}; + +template +class MaxPoolWithIndexGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + const Tensor* mask = context.Input("Maks"); + const Tensor* out_grad = + context.Input(framework::GradVarName("Out")); + Tensor* in_x_grad = context.Output(framework::GradVarName("X")); + + std::vector ksize = context.Attr>("ksize"); + std::vector strides = context.Attr>("strides"); + std::vector paddings = context.Attr>("paddings"); + + if (in_x_grad) { + in_x_grad->mutable_data(context.GetPlace()); + auto temp = framework::EigenVector::Flatten(*in_x_grad); + temp.device(context.GetEigenDevice()) = + temp.constant(static_cast(0)); + + switch (ksize.size()) { + case 2: { + paddle::operators::math::MaxPool2dWithIndexGradFunctor + pool2d_backward; + pool2d_backward(context.device_context(), *in_x_grad, *out_grad, + *mask, ksize, strides, paddings); + } break; + case 3: { + paddle::operators::math::MaxPool3dWithIndexGradFunctor + pool3d_backward; + pool3d_backward(context.device_context(), *in_x_grad, *out_grad, + *mask, ksize, strides, paddings); + } break; + } + } + } +}; +} // namespace operators +} // namespace paddle diff --git a/python/paddle/v2/framework/tests/test_pool_max_op.py b/python/paddle/v2/framework/tests/test_pool_max_op.py new file mode 100644 index 0000000000..2945c8b7a4 --- /dev/null +++ b/python/paddle/v2/framework/tests/test_pool_max_op.py @@ -0,0 +1,125 @@ +import unittest +import numpy as np +from op_test import OpTest + + +def max_pool3D_forward_naive(x, ksize, strides, paddings=[0, 0], global_pool=0): + + N, C, D, H, W = x.shape + if global_pool == 1: + ksize = [D, H, W] + D_out = (D - ksize[0] + 2 * paddings[0]) / strides[0] + 1 + H_out = (H - ksize[1] + 2 * paddings[1]) / strides[1] + 1 + W_out = (W - ksize[2] + 2 * paddings[2]) / strides[2] + 1 + out = np.zeros((N, C, D_out, H_out, W_out)) + mask = np.zeros((N, C, D_out, H_out, W_out)) + for k in xrange(D_out): + d_start = np.max((k * strides[0] - paddings[0], 0)) + d_end = np.min((k * strides[0] + ksize[0] - paddings[0], D)) + for i in xrange(H_out): + h_start = np.max((i * strides[0] - paddings[0], 0)) + h_end = np.min((i * strides[0] + ksize[0] - paddings[0], H)) + for j in xrange(W_out): + w_start = np.max((j * strides[1] - paddings[1], 0)) + w_end = np.min((j * strides[1] + ksize[1] - paddings[1], W)) + x_masked = x[:, :, d_start:d_end, h_start:h_end, w_start:w_end] + + out[:, :, k, i, j] = np.max(x_masked, axis=(2, 3, 4)) + # mask[:,:, k, i, j] = np.argmax(x_masked, axis=(2, 3, 4)) + return out + + +def max_pool2D_forward_naive(x, ksize, strides, paddings=[0, 0], global_pool=0): + + N, C, H, W = x.shape + if global_pool == 1: + ksize = [H, W] + H_out = (H - ksize[0] + 2 * paddings[0]) / strides[0] + 1 + W_out = (W - ksize[1] + 2 * paddings[1]) / strides[1] + 1 + out = np.zeros((N, C, H_out, W_out)) + mask = np.zeros((N, C, H_out, W_out)) + for i in xrange(H_out): + for j in xrange(W_out): + r_start = np.max((i * strides[0] - paddings[0], 0)) + r_end = np.min((i * strides[0] + ksize[0] - paddings[0], H)) + c_start = np.max((j * strides[1] - paddings[1], 0)) + c_end = np.min((j * strides[1] + ksize[1] - paddings[1], W)) + x_masked = x[:, :, r_start:r_end, c_start:c_end] + + out[:, :, i, j] = np.max(x_masked, axis=(2, 3)) + # mask[:,:, i, j] = np.argmax(x_masked, axis=(2, 3)) + + return out + + +class TestMaxPoolWithIndex_Op(OpTest): + def setUp(self): + self.initTestCase() + self.op_type = "maxPool3dWithIndex" + input = np.random.random(self.shape).astype("float32") + output = self.pool_forward_naive(input, self.ksize, self.strides, + self.paddings, self.global_pool) + # mask = np.zeros(output.shape) + + self.attrs = { + 'strides': self.strides, + 'paddings': self.paddings, + 'ksize': self.ksize, + 'globalPooling': self.global_pool, + } + + self.inputs = {'X': input} + self.outputs = {'Out': output} + + def test_check_output(self): + self.check_output() + + # def test_check_grad(self): + # self.check_grad(set(['X']), ['Out'], max_relative_error=0.07) + + def initTestCase(self): + self.global_pool = 0 + self.pool_forward_naive = max_pool3D_forward_naive + self.shape = [2, 3, 7, 7, 7] + self.ksize = [3, 3, 3] + self.strides = [1, 1, 1] + self.paddings = [1, 1, 1] + + +"""" +class TestCase1(TestMaxPoolWithIndex_Op): + def initTestCase(self): + self.global_pool = 1 + self.op_type = "maxPool3dWithIndex" + self.pool_forward_naive = max_pool3D_forward_naive + self.shape = [2, 3, 5, 5, 5] + self.ksize = [3, 3, 3] + self.strides = [1, 1, 1] + self.paddings = [0, 0, 0] + + +class TestCase2(TestMaxPoolWithIndex_Op): + def initTestCase(self): + self.global_pool = 0 + self.op_type = "maxPool2dWithIndex" + self.pool_forward_naive = max_pool2D_forward_naive + self.shape = [2, 3, 7, 7] + self.ksize = [3, 3] + self.strides = [1, 1] + self.paddings = [1, 1] + + +class TestCase3(TestMaxPoolWithIndex_Op): + def initTestCase(self): + self.global_pool = 1 + self.op_type = "maxPool2dWithIndex" + self.pool_forward_naive = max_pool2D_forward_naive + self.shape = [2, 3, 5, 5] + self.ksize = [3, 3] + self.strides = [1, 1] + self.paddings = [0, 0] + + +if __name__ == '__main__': + unittest.main() +""" From 884e31a59b72856ea1a807561f01a623c1138053 Mon Sep 17 00:00:00 2001 From: Luo Tao Date: Fri, 29 Sep 2017 15:28:25 +0800 Subject: [PATCH 02/26] add interpolation op --- paddle/operators/interp_op.cc | 107 ++++++++++++++++++ .../v2/framework/tests/test_interp_op.py | 28 +++++ 2 files changed, 135 insertions(+) create mode 100644 paddle/operators/interp_op.cc create mode 100644 python/paddle/v2/framework/tests/test_interp_op.py diff --git a/paddle/operators/interp_op.cc b/paddle/operators/interp_op.cc new file mode 100644 index 0000000000..04bcb9ade8 --- /dev/null +++ b/paddle/operators/interp_op.cc @@ -0,0 +1,107 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/framework/op_registry.h" +#include "paddle/operators/net_op.h" + +namespace paddle { +namespace operators { + +class InterpOp : public NetOp { + public: + InterpOp(const std::string &type, const framework::VariableNameMap &inputs, + const framework::VariableNameMap &outputs, + const framework::AttributeMap &attrs) + : NetOp(type, inputs, outputs, attrs) { + PADDLE_ENFORCE_NE(Input("X"), framework::kEmptyVarName, + "Input(X) of InterpOp should not be null."); + PADDLE_ENFORCE_NE(Input("Y"), framework::kEmptyVarName, + "Input(Y) of InterpOp should not be null."); + PADDLE_ENFORCE_NE(Input("W"), framework::kEmptyVarName, + "Input(W) of InterpOp should not be null."); + PADDLE_ENFORCE_NE(Output("MinusOut"), framework::kEmptyVarName, + "Output(MinusOut) of InterpOp should not be null."); + PADDLE_ENFORCE_NE(Output("MulOut"), framework::kEmptyVarName, + "Output(MulOut) of InterpOp should not be null."); + PADDLE_ENFORCE_NE(Output("Out"), framework::kEmptyVarName, + "Output(Out) of InterpOp should not be null."); + + // MinusOut = X - Y + auto x = Input("X"); + auto y = Input("Y"); + auto minus_out = Output("MinusOut"); + AppendOp(framework::OpRegistry::CreateOp("elementwise_sub", + {{"X", {x}}, {"Y", {y}}}, + {{"Out", {minus_out}}}, {})); + + // MulOut = MinusOut * W = (X - Y) * W + auto w = Input("W"); + auto mul_out = Output("MulOut"); + AppendOp(framework::OpRegistry::CreateOp( + "elementwise_mul", {{"X", {minus_out}}, {"Y", {w}}}, + {{"Out", {mul_out}}}, {{"axis", 0}})); + + // Out = MulOut + Y = (X - Y) * W + Y = X * W + Y * (1 - W) + AppendOp(framework::OpRegistry::CreateOp("elementwise_add", + {{"X", {mul_out}}, {"Y", {y}}}, + {{"Out", {Output("Out")}}}, {})); + + CompleteAddOp(false); + LOG(INFO) << DebugString(); + } +}; + +class InterpOpMaker : public framework::OpProtoAndCheckerMaker { + public: + InterpOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", "A 2-D Tensor, the first input of interp_op"); + AddInput("Y", "A 2-D Tensor, the second input of interp_op"); + AddInput("W", "A 1-D Tensor, the interpolated values"); + AddOutput("MinusOut", + "A 2-D Tensor, the intermediate outputs, saving X - Y.") + .AsIntermediate(); + AddOutput("MulOut", + "A 2-D Tensor, the intermediate outputs," + "saving the mul mul of (X - Y) and W") + .AsIntermediate(); + AddOutput("Out", + "A 2-D Tensor, the output of interp_op, same shape with X"); + AddComment(R"DOC( + Linear Interpolation with two inputs, used in NEURAL TURING MACHINE. + + Equation: + Out.row[i] = X.row[i] * W[i] + Y.row[i] * (1 - W[i]) + = (X.row[i] - Y.row[i]) * W[i] + Y.row[i] + + Example: + X = [[1,2],[3,4]], + Y = [[2,1],[4,3]], + W = [0.3, 0.4] + + Then, Out = [[1.7,1.3],[3.6,3.4]] + + where 1.7 = 1*0.3+2*(1-0.3), + 1.3 = 2*0.3+1*(1-0.3), + 3.6 = 3*0.4+4*(1-0.4), + 3.4 = 4*0.4+3*(1-0.4) +)DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_WITHOUT_GRADIENT(interp, ops::InterpOp, ops::InterpOpMaker); diff --git a/python/paddle/v2/framework/tests/test_interp_op.py b/python/paddle/v2/framework/tests/test_interp_op.py new file mode 100644 index 0000000000..f82dcc7f50 --- /dev/null +++ b/python/paddle/v2/framework/tests/test_interp_op.py @@ -0,0 +1,28 @@ +import unittest +import numpy as np +from op_test import OpTest + + +class TestInterpOp(OpTest): + def setUp(self): + self.op_type = "interp" + x = np.random.random((2, 3)).astype("float32") + y = np.random.random((2, 3)).astype("float32") + w = np.random.random(2).astype("float32") + + minus_out = x - y + mul_out = minus_out * w.reshape(2, 1) + out = mul_out + y + + self.inputs = {'X': x, 'Y': y, 'W': w} + self.outputs = {'Out': out, 'MinusOut': minus_out, 'MulOut': mul_out} + + def test_check_output(self): + self.check_output() + + def test_check_grad_normal(self): + self.check_grad(['X', 'Y'], 'Out') + + +if __name__ == "__main__": + unittest.main() From a815d6abcf49d4778d0a49c852c45264bd8a684a Mon Sep 17 00:00:00 2001 From: zhouxiao-coder Date: Fri, 29 Sep 2017 17:29:52 +0800 Subject: [PATCH 03/26] elu: Optimize gradient calculation;Add more comments --- paddle/operators/activation_op.cc | 25 ++++++++++++ paddle/operators/activation_op.cu | 4 ++ paddle/operators/activation_op.h | 40 +++++++++++++++++++ .../v2/framework/tests/test_activation_op.py | 20 ++++++++++ 4 files changed, 89 insertions(+) diff --git a/paddle/operators/activation_op.cc b/paddle/operators/activation_op.cc index 1e1d3cf7f7..e83666c9f9 100644 --- a/paddle/operators/activation_op.cc +++ b/paddle/operators/activation_op.cc @@ -174,6 +174,25 @@ class SoftReluOpMaker : public framework::OpProtoAndCheckerMaker { } }; +template +class ELUOpMaker : public framework::OpProtoAndCheckerMaker { + public: + ELUOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", + "Input of ELU operator, it shouldn't be empty. Input is flattened " + "and treated as a 1D array."); + AddOutput("Y", "Output of ELU operator, has same shape as the input."); + AddComment( + "ELU activation operator. It applies this element-wise computation on " + "the input: f(x) = max(0, x) + min(0, alpha * (exp(x) - 1))." + "Check .. _Link: https://arxiv.org/abs/1511.07289 for more details"); + AddAttr("alpha", + "alpha value in the elu formulation, default to 1.") + .SetDefault(static_cast(1.)); + } +}; + template class PowOpMaker : public framework::OpProtoAndCheckerMaker { public: @@ -311,6 +330,12 @@ REGISTER_OP_CPU_KERNEL(soft_relu, REGISTER_OP_CPU_KERNEL( soft_relu_grad, ops::SoftReluGradKernel); +REGISTER_OP(elu, ops::ActivationOp, ops::ELUOpMaker, elu_grad, + ops::ActivationOpGrad); +REGISTER_OP_CPU_KERNEL(elu, ops::ELUKernel); +REGISTER_OP_CPU_KERNEL(elu_grad, + ops::ELUGradKernel); + REGISTER_OP(pow, ops::ActivationOp, ops::PowOpMaker, pow_grad, ops::ActivationOpGrad); REGISTER_OP_CPU_KERNEL(pow, ops::PowKernel); diff --git a/paddle/operators/activation_op.cu b/paddle/operators/activation_op.cu index 56886d8b1b..48800b11ec 100644 --- a/paddle/operators/activation_op.cu +++ b/paddle/operators/activation_op.cu @@ -97,6 +97,10 @@ REGISTER_OP_GPU_KERNEL(soft_relu, REGISTER_OP_GPU_KERNEL( soft_relu_grad, ops::SoftReluGradKernel); +REGISTER_OP_GPU_KERNEL(elu, ops::ELUKernel); +REGISTER_OP_GPU_KERNEL(elu_grad, + ops::ELUGradKernel); + REGISTER_OP_GPU_KERNEL(pow, ops::PowKernel); REGISTER_OP_GPU_KERNEL(pow_grad, ops::PowGradKernel); diff --git a/paddle/operators/activation_op.h b/paddle/operators/activation_op.h index b9f52e1af3..3428aca817 100644 --- a/paddle/operators/activation_op.h +++ b/paddle/operators/activation_op.h @@ -296,6 +296,46 @@ class SoftReluGradKernel : public framework::OpKernel { } }; +template +class ELUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* X = context.Input("X"); + auto* Y = context.Output("Y"); + auto alpha = static_cast(context.Attr("alpha")); + Y->mutable_data(context.GetPlace()); + + auto x = framework::EigenVector::Flatten(*X); + auto y = framework::EigenVector::Flatten(*Y); + auto place = context.GetEigenDevice(); + y.device(place) = + x.cwiseMax(static_cast(0)) + + (alpha * (x.exp() - static_cast(1))).cwiseMin(static_cast(0)); + } +}; + +template +class ELUGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* X = context.Input("X"); + auto* Y = context.Input("Y"); + auto* dY = context.Input(framework::GradVarName("Y")); + auto* dX = context.Output(framework::GradVarName("X")); + auto alpha = static_cast(context.Attr("alpha")); + dX->mutable_data(context.GetPlace()); + + auto x = framework::EigenVector::Flatten(*X); + auto y = framework::EigenVector::Flatten(*Y); + auto dy = framework::EigenVector::Flatten(*dY); + auto dx = framework::EigenVector::Flatten(*dX); + auto place = context.GetEigenDevice(); + dx.device(place) = + dy * (x > static_cast(0)).template cast() + + dy * (y + alpha) * (x < static_cast(0)).template cast(); + } +}; + template class PowKernel : public framework::OpKernel { public: diff --git a/python/paddle/v2/framework/tests/test_activation_op.py b/python/paddle/v2/framework/tests/test_activation_op.py index c44eb84906..9ea01d43c5 100644 --- a/python/paddle/v2/framework/tests/test_activation_op.py +++ b/python/paddle/v2/framework/tests/test_activation_op.py @@ -144,6 +144,26 @@ class TestSoftRelu(OpTest): self.check_grad(['X'], 'Y', max_relative_error=0.02) +class TestELU(OpTest): + def setUp(self): + self.op_type = "elu" + x = np.random.uniform(-3, 3, [4, 4]).astype("float32") + alpha = 1. + # Note: unlike other Relu extensions, point 0 on standard ELU function (i.e. alpha = 1) + # is differentiable, so we can skip modifications like x[np.abs(x) < 0.005] = 0.02 here + self.inputs = {'X': x} + self.attrs = {'alpha': alpha} + self.outputs = { + 'Y': np.maximum(0, x) + np.minimum(0, alpha * (np.exp(x) - 1)) + } + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X'], 'Y', max_relative_error=0.02) + + class TestReciprocal(OpTest): def setUp(self): self.op_type = "reciprocal" From 4436ba0c56d105b0c1305a739158fdc08258f7a9 Mon Sep 17 00:00:00 2001 From: zhouxiao-coder Date: Fri, 29 Sep 2017 17:52:18 +0800 Subject: [PATCH 04/26] elu: Optimize gradient calculation;Add more comments --- paddle/operators/activation_op.cc | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/paddle/operators/activation_op.cc b/paddle/operators/activation_op.cc index e83666c9f9..7d086ac5df 100644 --- a/paddle/operators/activation_op.cc +++ b/paddle/operators/activation_op.cc @@ -180,16 +180,18 @@ class ELUOpMaker : public framework::OpProtoAndCheckerMaker { ELUOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("X", - "Input of ELU operator, it shouldn't be empty. Input is flattened " - "and treated as a 1D array."); - AddOutput("Y", "Output of ELU operator, has same shape as the input."); - AddComment( - "ELU activation operator. It applies this element-wise computation on " - "the input: f(x) = max(0, x) + min(0, alpha * (exp(x) - 1))." - "Check .. _Link: https://arxiv.org/abs/1511.07289 for more details"); - AddAttr("alpha", - "alpha value in the elu formulation, default to 1.") + "(Tensor) The input of ELU operator, it shouldn't be empty. Input " + "is flattened and treated as a 1D array."); + AddOutput("Y", + "(Tensor) The output of ELU operator. It has the same shape as " + "the input."); + AddAttr( + "alpha", "(float, default 1.0) Alpha value in the elu formulation.") .SetDefault(static_cast(1.)); + AddComment(R"DOC( + ELU activation operator. It applies this element-wise computation on + the input: f(x) = max(0, x) + min(0, alpha * (exp(x) - 1)). + Check .. _Link: https://arxiv.org/abs/1511.07289 for more details.)DOC"); } }; From bee95fc8917e09f61ba46586a94d2b9003cddf13 Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Fri, 29 Sep 2017 11:45:04 +0800 Subject: [PATCH 05/26] fix code format and some bug --- paddle/operators/math/pooling.cc | 20 +-- paddle/operators/math/pooling.cu | 147 ++++++++++-------- paddle/operators/math/pooling.h | 1 - paddle/operators/pool_with_index_op.cc | 71 +++++---- paddle/operators/pool_with_index_op.h | 10 +- .../v2/framework/tests/test_pool_max_op.py | 52 +++++-- 6 files changed, 180 insertions(+), 121 deletions(-) diff --git a/paddle/operators/math/pooling.cc b/paddle/operators/math/pooling.cc index 0e4d9007a6..da0e8ff3d2 100644 --- a/paddle/operators/math/pooling.cc +++ b/paddle/operators/math/pooling.cc @@ -26,7 +26,6 @@ class MaxPool2dWithIndexFunctor { framework::Tensor& mask, std::vector& ksize, std::vector& strides, std::vector& paddings) { const int batch_size = input.dims()[0]; - const int input_height = input.dims()[2]; const int input_width = input.dims()[3]; const int output_channels = output.dims()[1]; @@ -112,11 +111,11 @@ class MaxPool2dWithIndexGradFunctor { input_grad_data[input_idx] += output_grad_data[output_idx]; } } + // offset + input_grad_data += input_stride; + output_grad_data += output_stride; + mask_data += output_stride; } - // offset - input_grad_data += input_stride; - output_grad_data += output_stride; - mask_data += output_stride; } } }; @@ -152,6 +151,7 @@ class MaxPool3dWithIndexFunctor { const int padding_width = paddings[2]; const int input_stride = input_depth * input_height * input_width; const int output_stride = output_depth * output_height * output_width; + const T* input_data = input.data(); T* output_data = output.mutable_data(context.GetPlace()); T* mask_data = mask.mutable_data(context.GetPlace()); @@ -170,17 +170,17 @@ class MaxPool3dWithIndexFunctor { int wstart = pw * stride_width - padding_width; int wend = std::min(wstart + ksize_width, input_width); wstart = std::max(wstart, 0); + int output_idx = (pd * output_height + ph) * output_width + pw; T ele = static_cast(-FLT_MAX); int index = -1; for (int d = dstart; d < dend; ++d) { for (int h = hstart; h < hend; ++h) { for (int w = wstart; w < wend; ++w) { - if (ele < - input_data[(d * input_height + h) * input_width + w]) { - index = (d * input_height + h) * input_width + w; - ele = - input_data[(d * input_height + h) * input_width + w]; + int input_idx = (d * input_height + h) * input_width + w; + if (ele < input_data[input_idx]) { + index = input_idx; + ele = input_data[input_idx]; } } } diff --git a/paddle/operators/math/pooling.cu b/paddle/operators/math/pooling.cu index f32e6a26d0..5321ed2163 100644 --- a/paddle/operators/math/pooling.cu +++ b/paddle/operators/math/pooling.cu @@ -20,14 +20,14 @@ namespace operators { namespace math { template -__global__ void KernelMaxPool2dWithIdxForward( +__global__ void KernelMaxPool2dWithIdx( const int nthreads, const T* input_data, T* output_data, T* mask_data, const int channels, const int input_height, const int input_width, const int output_height, const int output_width, const int ksize_height, const int ksize_width, const int stride_height, const int stride_width, const int padding_height, const int padding_width) { - int index = blockIdx.x * blockDim.x + threadIdx.x; - if (index < nthreads) { + for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < (nthreads); + index += blockDim.x * gridDim.x) { int pw = index % output_width; int ph = (index / output_width) % output_height; int c = (index / output_width / output_height) % channels; @@ -43,51 +43,58 @@ __global__ void KernelMaxPool2dWithIdxForward( input_data += (batch_idx * channels + c) * input_height * input_width; T ele = -FLT_MAX; - int index = -1; + int max_index = -1; for (int h = hstart; h < hend; ++h) { for (int w = wstart; w < wend; ++w) { - if (ele < input_data[h * input_width + w]) { - index = h * input_width + w; - ele = input_data[h * input_width + w]; + int input_index = h * input_width + w; + if (ele < input_data[input_index]) { + max_index = input_index; + ele = input_data[input_index]; } } } output_data[index] = ele; - mask_data[index] = index; + mask_data[index] = max_index; } } template -__global__ void KernelMaxPool2DWithIdxBackward( +__global__ void KernelMaxPool2DWithIdxGrad( const int nthreads, T* input_grad, const T* output_grad, const T* mask_data, const int channels, const int input_height, const int input_width, const int output_height, const int output_width, const int ksize_height, const int ksize_width, const int stride_height, const int stride_width, const int padding_height, const int padding_width) { - int index = blockIdx.x * blockDim.x + threadIdx.x; - if (index < nthreads) { - int offsetW = index % input_width + padding_width; - int offsetH = (index / input_width) % input_height + padding_height; - int offsetC = (index / input_width / input_height) % channels; + for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < (nthreads); + index += blockDim.x * gridDim.x) { + int w_offset = index % input_width; + int h_offset = (index / input_width) % input_height; + int c_offset = (index / input_width / input_height) % channels; int batch_idx = index / input_width / input_height / channels; - int phstart = (offsetH < ksize_height) - ? 0 - : (offsetH - ksize_height) / stride_height + 1; - int pwstart = (offsetW < ksize_width) - ? 0 - : (offsetW - ksize_width) / stride_width + 1; - int phend = min(offsetH / stride_height + 1, output_height); - int pwend = min(offsetW / stride_width + 1, output_width); + int ph_start = + (h_offset + padding_height < ksize_height) + ? 0 + : (h_offset + padding_height - ksize_height) / stride_height + 1; + int pw_start = + (w_offset + padding_width < ksize_width) + ? 0 + : (w_offset + padding_width - ksize_width) / stride_width + 1; + int ph_end = + min((h_offset + padding_height) / stride_height + 1, output_height); + int pw_end = + min((w_offset + padding_width) / stride_width + 1, output_width); + T gradient = 0; + int input_current_featuremap_idx = h_offset * input_width + w_offset; int output_idx = - (batch_idx * channels + offsetC) * output_height * output_width; + (batch_idx * channels + c_offset) * output_height * output_width; + mask_data += output_idx; output_grad += output_idx; - for (int ph = phstart; ph < phend; ++ph) { - for (int pw = pwstart; pw < pwend; ++pw) { - if ((offsetH * input_width + offsetW) == - mask_data[ph * output_width + pw]) + for (int ph = ph_start; ph < ph_end; ++ph) { + for (int pw = pw_start; pw < pw_end; ++pw) { + if (mask_data[ph * output_width + pw] == input_current_featuremap_idx) gradient += output_grad[ph * output_width + pw]; } } @@ -125,7 +132,7 @@ class MaxPool2dWithIndexFunctor { dim3 threads(1024, 1); dim3 grid(blocks, 1); - KernelMaxPool2dWithIdxForward< + KernelMaxPool2dWithIdx< T><<(context) .stream()>>>(nthreads, input_data, output_data, mask_data, @@ -167,7 +174,7 @@ class MaxPool2dWithIndexGradFunctor { dim3 threads(1024, 1); dim3 grid(blocks, 1); - KernelMaxPool2DWithIdxBackward< + KernelMaxPool2DWithIdxGrad< T><<(context) .stream()>>>(nthreads, input_grad_data, output_grad_data, @@ -184,7 +191,7 @@ template class MaxPool2dWithIndexFunctor; template class MaxPool2dWithIndexGradFunctor; template -__global__ void KernelMaxPool3DWithIdxForward( +__global__ void KernelMaxPool3DWithIdx( const int nthreads, const T* input_data, T* output_data, T* mask_data, const int channels, const int input_depth, const int input_height, const int input_width, const int output_depth, const int output_height, @@ -200,6 +207,7 @@ __global__ void KernelMaxPool3DWithIdxForward( int c = (index / output_width / output_height / output_depth) % channels; int batch_idx = index / output_width / output_height / output_depth / channels; + int dstart = pd * stride_depth - padding_depth; int hstart = ph * stride_height - padding_height; int wstart = pw * stride_width - padding_width; @@ -209,8 +217,9 @@ __global__ void KernelMaxPool3DWithIdxForward( dstart = max(dstart, 0); hstart = max(hstart, 0); wstart = max(wstart, 0); + T ele = -FLT_MAX; - int index = -1; + int max_index = -1; input_data += (batch_idx * channels + c) * input_depth * input_height * input_width; @@ -218,19 +227,19 @@ __global__ void KernelMaxPool3DWithIdxForward( for (int h = hstart; h < hend; ++h) { for (int w = wstart; w < wend; ++w) { if (ele < input_data[(d * input_height + h) * input_width + w]) { - index = (d * input_height + h) * input_width + w; - ele = input_data[(d * input_height + h) * input_width + w]; + max_index = (d * input_height + h) * input_width + w; + ele = input_data[max_index]; } } } } output_data[index] = ele; - mask_data[index] = index; + mask_data[index] = max_index; } } template -__global__ void KernelMaxPool3DWithIdxBackward( +__global__ void KernelMaxPool3DWithIdxGrad( const int nthreads, T* input_grad, const T* output_grad, const T* mask, const int channels, const int input_depth, const int input_height, const int input_width, const int output_depth, const int output_height, @@ -240,37 +249,45 @@ __global__ void KernelMaxPool3DWithIdxBackward( const int padding_width) { for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < (nthreads); index += blockDim.x * gridDim.x) { - int offsetW = index % input_width + padding_width; - int offsetH = (index / input_width) % input_height + padding_height; - int offsetD = - (index / input_width / input_height) % input_depth + padding_depth; - int offsetC = (index / input_width / input_height / input_depth) % channels; + int w_offset = index % input_width; + int h_offset = (index / input_width) % input_height; + int d_offset = (index / input_width / input_height) % input_depth; + int c_offset = + (index / input_width / input_height / input_depth) % channels; int batch_idx = index / input_width / input_height / input_depth / channels; - int pdstart = (offsetD < ksize_depth) - ? 0 - : (offsetD - ksize_depth) / stride_depth + 1; - int phstart = (offsetH < ksize_height) - ? 0 - : (offsetH - ksize_height) / stride_height + 1; - int pwstart = (offsetW < ksize_width) - ? 0 - : (offsetW - ksize_width) / stride_width + 1; - int pdend = min((offsetD) / stride_depth + 1, output_depth); - int phend = min((offsetH) / stride_height + 1, output_height); - int pwend = min((offsetW) / stride_width + 1, output_width); + int pd_start = + (d_offset + padding_depth < ksize_depth) + ? 0 + : (d_offset + padding_depth - ksize_depth) / stride_depth + 1; + int ph_start = + (h_offset + padding_height < ksize_height) + ? 0 + : (h_offset + padding_height - ksize_height) / stride_height + 1; + int pw_start = + (w_offset + padding_width < ksize_width) + ? 0 + : (w_offset + padding_width - ksize_width) / stride_width + 1; + int pd_end = + min((d_offset + padding_depth) / stride_depth + 1, output_depth); + int ph_end = + min((h_offset + padding_height) / stride_height + 1, output_height); + int pw_end = + min((w_offset + padding_width) / stride_width + 1, output_width); T gradient = 0; - int output_idx = (batch_idx * channels + offsetC) * output_depth * + int input_current_feature_map_idx = + (d_offset * input_height + h_offset) * input_width + w_offset; + int output_idx = (batch_idx * channels + c_offset) * output_depth * output_height * output_width; mask += output_idx; output_grad += output_idx; - for (int pd = pdstart; pd < pdend; ++pd) { - for (int ph = phstart; ph < phend; ++ph) { - for (int pw = pwstart; pw < pwend; ++pw) { - if (((offsetD * input_height + offsetH) * input_width + offsetW) == - mask[(pd * output_height + ph) * output_width + pw]) + for (int pd = pd_start; pd < pd_end; ++pd) { + for (int ph = ph_start; ph < ph_end; ++ph) { + for (int pw = pw_start; pw < pw_end; ++pw) { + if (mask[(pd * output_height + ph) * output_width + pw] == + input_current_feature_map_idx) gradient += output_grad[(pd * output_height + ph) * output_width + pw]; } @@ -308,7 +325,7 @@ class MaxPool3dWithIndexFunctor { const T* input_data = input.data(); T* output_data = output.mutable_data(context.GetPlace()); - T* mask_data = output.mutable_data(context.GetPlace()); + T* mask_data = mask.mutable_data(context.GetPlace()); int nthreads = batch_size * output_channels * output_depth * output_height * output_width; @@ -316,7 +333,7 @@ class MaxPool3dWithIndexFunctor { dim3 threads(1024, 1); dim3 grid(blocks, 1); - KernelMaxPool3DWithIdxForward< + KernelMaxPool3DWithIdx< T><<(context) .stream()>>>( @@ -341,10 +358,10 @@ class MaxPool3dWithIndexGradFunctor { const int input_depth = input_grad.dims()[2]; const int input_height = input_grad.dims()[3]; const int input_width = input_grad.dims()[4]; - const int output_channels = input_grad.dims()[1]; - const int output_depth = input_grad.dims()[2]; - const int output_height = input_grad.dims()[3]; - const int output_width = input_grad.dims()[4]; + const int output_channels = output_grad.dims()[1]; + const int output_depth = output_grad.dims()[2]; + const int output_height = output_grad.dims()[3]; + const int output_width = output_grad.dims()[4]; const int ksize_depth = ksize[0]; const int ksize_height = ksize[1]; const int ksize_width = ksize[2]; @@ -365,7 +382,7 @@ class MaxPool3dWithIndexGradFunctor { dim3 threads(1024, 1); dim3 grid(blocks, 1); - KernelMaxPool3DWithIdxBackward< + KernelMaxPool3DWithIdxGrad< T><<(context) .stream()>>>( diff --git a/paddle/operators/math/pooling.h b/paddle/operators/math/pooling.h index 3a05cd98fe..308a9341b6 100644 --- a/paddle/operators/math/pooling.h +++ b/paddle/operators/math/pooling.h @@ -23,7 +23,6 @@ namespace operators { namespace math { ////////////////////// #define FLT_MAX __FLT_MAX__ -///////////////////// template class MaxPool2dWithIndexFunctor { diff --git a/paddle/operators/pool_with_index_op.cc b/paddle/operators/pool_with_index_op.cc index d7a07a403d..c51145b923 100644 --- a/paddle/operators/pool_with_index_op.cc +++ b/paddle/operators/pool_with_index_op.cc @@ -76,8 +76,8 @@ class MaxPoolWithIndexOpGrad : public framework::OperatorWithKernel { protected: void InferShape(framework::InferShapeContextBase *ctx) const override { - PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("X")), - "X(Input) of MaxPoolWithIndexOpGrad should not be null."); + PADDLE_ENFORCE(ctx->HasInput("X"), + "X(Input) of Pooling should not be null."); PADDLE_ENFORCE( ctx->HasOutput(framework::GradVarName("X")), "X@GRAD(Input@GRAD) of MaxPoolWithIndexOpGrad should not be null."); @@ -97,28 +97,37 @@ class MaxPool2dWithIndexOpMaker : public framework::OpProtoAndCheckerMaker { "number of channels, H and W is the height and width of image."); AddOutput("Out", "The output tensor of pooling operator." - "The format of output tensor is also NCHW."); + "The format of output tensor is also NCHW." + "Where N is batch size, C is " + "the number of channels, H and W is the height and " + "width of image."); AddOutput("Mask", "The Mask tensor of pooling operator." - "The format of output tensor is also NCHW."); + "The format of output tensor is also NCHW." + "Where N is batch size, C is the number of channels, H and W " + "is the height and width of image." + "The value in it is the index in current feature map"); AddAttr>( - "ksize", "pooling size(height, width) of pooling operator."); + "ksize", + "Pooling size(height, width) of pooling operator." + "If globalPooling = true, ksize is ignored and need not be " + "specified."); // TODO(Add checker) AddAttr( "globalPooling", - "whether to use the globalPooling." - "int constant equal to false or true" - "default false" + "Whether to use the globalPooling." + "Bool constant equal to false or true." + "Default false." "If globalPooling = true, ksize is ignored and need not be specified.") .SetDefault(false); AddAttr>("strides", - "strides(height, width) of pooling operator." - "default {1,1}") - .SetDefault({1, 1}); + "Strides(height, width) of pooling operator." + "Default {1,1}.") + .SetDefault({1, 1}); // TODO(Add checker) AddAttr>("paddings", - "paddings(height, width) of pooling operator." - "default {0,0}") - .SetDefault({0, 0}); + "Paddings(height, width) of pooling operator." + "Default {0,0}.") + .SetDefault({0, 0}); // TODO(Add checker) AddComment(R"DOC( The maxPooling2d with index operation calculates the output and the mask based on @@ -140,30 +149,40 @@ class MaxPool3dWithIndexOpMaker : public framework::OpProtoAndCheckerMaker { "image."); AddOutput("Out", "The output tensor of pooling operator." - "The format of output tensor is also NCDHW."); + "The format of output tensor is also NCDHW." + "Where N is batch size, C is " + "the number of channels, D, H and W is the depth, height and " + "width of image."); AddOutput("Mask", "The Mask tensor of pooling operator." - "The format of output tensor is also NCDHW."); + "The format of output tensor is also NCDHW." + "Where N is batch size, C is the number of channels, D, H and W " + "is the depth, height and width of image." + "The value in it is the index in current feature map"); AddAttr>( - "ksize", "pooling size(depth, height, width) of pooling operator."); + "ksize", + "Pooling size(depth, height, width) of pooling operator." + "If globalPooling = true, ksize is ignored and need not be " + "specified."); // TODO(Add checker) AddAttr( "globalPooling", - "whether to use the globalPooling." - "int constant equal to false or true" - "default false" + "Whether to use the globalPooling." + "Bool constant equal to false or true." + "Default false." "If globalPooling = true, ksize is ignored and need not be specified.") .SetDefault(false); AddAttr>( "strides", - "strides(depth, height, width) of pooling operator." - "default {1,1,1}") - .SetDefault({1, 1, 1}); + "Strides(depth, height, width) of pooling operator." + "Default {1,1,1}.") + .SetDefault({1, 1, 1}); // TODO(Add checker) AddAttr>( "paddings", - "paddings(depth, height, width) of pooling operator." - "default {0,0,0}") - .SetDefault({0, 0, 0}); + "Paddings(depth, height, width) of pooling operator." + "Default {0,0,0}.") + .SetDefault({0, 0, 0}); // TODO(Add checker) + AddComment(R"DOC( The maxpooling3d with index operation calculates the output and the mask based on the input and ksize, strides, paddings parameters. diff --git a/paddle/operators/pool_with_index_op.h b/paddle/operators/pool_with_index_op.h index 91abeed016..5fe2f5df93 100644 --- a/paddle/operators/pool_with_index_op.h +++ b/paddle/operators/pool_with_index_op.h @@ -32,11 +32,10 @@ class MaxPoolWithIndexKernel : public framework::OpKernel { Tensor* out = context.Output("Out"); Tensor* mask = context.Output("Mask"); - bool global_pooling = context.Attr("globalPooling"); std::vector ksize = context.Attr>("ksize"); std::vector strides = context.Attr>("strides"); std::vector paddings = context.Attr>("paddings"); - if (global_pooling) { + if (context.Attr("globalPooling")) { for (size_t i = 0; i < ksize.size(); ++i) { ksize[i] = static_cast(in_x->dims()[i + 2]); } @@ -63,7 +62,7 @@ template class MaxPoolWithIndexGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { - const Tensor* mask = context.Input("Maks"); + const Tensor* mask = context.Input("Mask"); const Tensor* out_grad = context.Input(framework::GradVarName("Out")); Tensor* in_x_grad = context.Output(framework::GradVarName("X")); @@ -71,6 +70,11 @@ class MaxPoolWithIndexGradKernel : public framework::OpKernel { std::vector ksize = context.Attr>("ksize"); std::vector strides = context.Attr>("strides"); std::vector paddings = context.Attr>("paddings"); + if (context.Attr("globalPooling")) { + for (size_t i = 0; i < ksize.size(); ++i) { + ksize[i] = static_cast(in_x_grad->dims()[i + 2]); + } + } if (in_x_grad) { in_x_grad->mutable_data(context.GetPlace()); diff --git a/python/paddle/v2/framework/tests/test_pool_max_op.py b/python/paddle/v2/framework/tests/test_pool_max_op.py index 2945c8b7a4..ffc345198d 100644 --- a/python/paddle/v2/framework/tests/test_pool_max_op.py +++ b/python/paddle/v2/framework/tests/test_pool_max_op.py @@ -3,7 +3,11 @@ import numpy as np from op_test import OpTest -def max_pool3D_forward_naive(x, ksize, strides, paddings=[0, 0], global_pool=0): +def max_pool3D_forward_naive(x, + ksize, + strides, + paddings=[0, 0, 0], + global_pool=0): N, C, D, H, W = x.shape if global_pool == 1: @@ -25,8 +29,19 @@ def max_pool3D_forward_naive(x, ksize, strides, paddings=[0, 0], global_pool=0): x_masked = x[:, :, d_start:d_end, h_start:h_end, w_start:w_end] out[:, :, k, i, j] = np.max(x_masked, axis=(2, 3, 4)) - # mask[:,:, k, i, j] = np.argmax(x_masked, axis=(2, 3, 4)) - return out + + for n in xrange(N): + for c in xrange(C): + arr = x_masked[n, c, :, :, :] + index = np.where(arr == np.max(arr)) + sub_deep = index[0][0] + sub_row = index[1][0] + sub_col = index[2][0] + index = ((d_start + sub_deep) * H + + (h_start + sub_row)) * W + w_start + sub_col + mask[n, c, k, i, j] = index + + return out, mask def max_pool2D_forward_naive(x, ksize, strides, paddings=[0, 0], global_pool=0): @@ -47,19 +62,25 @@ def max_pool2D_forward_naive(x, ksize, strides, paddings=[0, 0], global_pool=0): x_masked = x[:, :, r_start:r_end, c_start:c_end] out[:, :, i, j] = np.max(x_masked, axis=(2, 3)) - # mask[:,:, i, j] = np.argmax(x_masked, axis=(2, 3)) - return out + for n in xrange(N): + for c in xrange(C): + arr = x_masked[n, c, :, :] + index = np.where(arr == np.max(arr)) + sub_row = index[0][0] + sub_col = index[1][0] + index = (r_start + sub_row) * W + c_start + sub_col + mask[n, c, i, j] = index + + return out, mask class TestMaxPoolWithIndex_Op(OpTest): def setUp(self): self.initTestCase() - self.op_type = "maxPool3dWithIndex" input = np.random.random(self.shape).astype("float32") - output = self.pool_forward_naive(input, self.ksize, self.strides, - self.paddings, self.global_pool) - # mask = np.zeros(output.shape) + output, mask = self.pool_forward_naive(input, self.ksize, self.strides, + self.paddings, self.global_pool) self.attrs = { 'strides': self.strides, @@ -69,7 +90,7 @@ class TestMaxPoolWithIndex_Op(OpTest): } self.inputs = {'X': input} - self.outputs = {'Out': output} + self.outputs = {'Out': output, "Mask": mask} def test_check_output(self): self.check_output() @@ -78,7 +99,8 @@ class TestMaxPoolWithIndex_Op(OpTest): # self.check_grad(set(['X']), ['Out'], max_relative_error=0.07) def initTestCase(self): - self.global_pool = 0 + self.global_pool = False + self.op_type = "maxPool3dWithIndex" self.pool_forward_naive = max_pool3D_forward_naive self.shape = [2, 3, 7, 7, 7] self.ksize = [3, 3, 3] @@ -86,10 +108,9 @@ class TestMaxPoolWithIndex_Op(OpTest): self.paddings = [1, 1, 1] -"""" class TestCase1(TestMaxPoolWithIndex_Op): def initTestCase(self): - self.global_pool = 1 + self.global_pool = True self.op_type = "maxPool3dWithIndex" self.pool_forward_naive = max_pool3D_forward_naive self.shape = [2, 3, 5, 5, 5] @@ -100,7 +121,7 @@ class TestCase1(TestMaxPoolWithIndex_Op): class TestCase2(TestMaxPoolWithIndex_Op): def initTestCase(self): - self.global_pool = 0 + self.global_pool = False self.op_type = "maxPool2dWithIndex" self.pool_forward_naive = max_pool2D_forward_naive self.shape = [2, 3, 7, 7] @@ -111,7 +132,7 @@ class TestCase2(TestMaxPoolWithIndex_Op): class TestCase3(TestMaxPoolWithIndex_Op): def initTestCase(self): - self.global_pool = 1 + self.global_pool = True self.op_type = "maxPool2dWithIndex" self.pool_forward_naive = max_pool2D_forward_naive self.shape = [2, 3, 5, 5] @@ -122,4 +143,3 @@ class TestCase3(TestMaxPoolWithIndex_Op): if __name__ == '__main__': unittest.main() -""" From 6fc44800ed800e78822c6af5750e202c041d9173 Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Sat, 30 Sep 2017 15:20:44 +0800 Subject: [PATCH 06/26] fix unit test --- .../v2/framework/tests/test_pool_max_op.py | 72 ++++++++++++++++++- 1 file changed, 69 insertions(+), 3 deletions(-) diff --git a/python/paddle/v2/framework/tests/test_pool_max_op.py b/python/paddle/v2/framework/tests/test_pool_max_op.py index ffc345198d..17028c3bf6 100644 --- a/python/paddle/v2/framework/tests/test_pool_max_op.py +++ b/python/paddle/v2/framework/tests/test_pool_max_op.py @@ -98,6 +98,28 @@ class TestMaxPoolWithIndex_Op(OpTest): # def test_check_grad(self): # self.check_grad(set(['X']), ['Out'], max_relative_error=0.07) + def initTestCase(self): + self.global_pool = True + self.op_type = "maxPool3dWithIndex" + self.pool_forward_naive = max_pool3D_forward_naive + self.shape = [2, 3, 5, 5, 5] + self.ksize = [3, 3, 3] + self.strides = [1, 1, 1] + self.paddings = [1, 1, 1] + + +class TestCase1(TestMaxPoolWithIndex_Op): + def initTestCase(self): + self.global_pool = True + self.op_type = "maxPool3dWithIndex" + self.pool_forward_naive = max_pool3D_forward_naive + self.shape = [2, 3, 5, 5, 5] + self.ksize = [3, 3, 3] + self.strides = [1, 1, 1] + self.paddings = [1, 1, 1] + + +class TestCase2(TestMaxPoolWithIndex_Op): def initTestCase(self): self.global_pool = False self.op_type = "maxPool3dWithIndex" @@ -108,7 +130,18 @@ class TestMaxPoolWithIndex_Op(OpTest): self.paddings = [1, 1, 1] -class TestCase1(TestMaxPoolWithIndex_Op): +class TestCase3(TestMaxPoolWithIndex_Op): + def initTestCase(self): + self.global_pool = False + self.op_type = "maxPool3dWithIndex" + self.pool_forward_naive = max_pool3D_forward_naive + self.shape = [2, 3, 7, 7, 7] + self.ksize = [3, 3, 3] + self.strides = [2, 2, 2] + self.paddings = [0, 0, 0] + + +class TestCase4(TestMaxPoolWithIndex_Op): def initTestCase(self): self.global_pool = True self.op_type = "maxPool3dWithIndex" @@ -116,10 +149,21 @@ class TestCase1(TestMaxPoolWithIndex_Op): self.shape = [2, 3, 5, 5, 5] self.ksize = [3, 3, 3] self.strides = [1, 1, 1] + self.paddings = [1, 1, 1] + + +class TestCase5(TestMaxPoolWithIndex_Op): + def initTestCase(self): + self.global_pool = True + self.op_type = "maxPool3dWithIndex" + self.pool_forward_naive = max_pool3D_forward_naive + self.shape = [2, 3, 5, 5, 5] + self.ksize = [3, 3, 3] + self.strides = [2, 2, 2] self.paddings = [0, 0, 0] -class TestCase2(TestMaxPoolWithIndex_Op): +class TestCase6(TestMaxPoolWithIndex_Op): def initTestCase(self): self.global_pool = False self.op_type = "maxPool2dWithIndex" @@ -130,7 +174,18 @@ class TestCase2(TestMaxPoolWithIndex_Op): self.paddings = [1, 1] -class TestCase3(TestMaxPoolWithIndex_Op): +class TestCase7(TestMaxPoolWithIndex_Op): + def initTestCase(self): + self.global_pool = False + self.op_type = "maxPool2dWithIndex" + self.pool_forward_naive = max_pool2D_forward_naive + self.shape = [2, 3, 7, 7] + self.ksize = [3, 3] + self.strides = [2, 2] + self.paddings = [0, 0] + + +class TestCase8(TestMaxPoolWithIndex_Op): def initTestCase(self): self.global_pool = True self.op_type = "maxPool2dWithIndex" @@ -138,6 +193,17 @@ class TestCase3(TestMaxPoolWithIndex_Op): self.shape = [2, 3, 5, 5] self.ksize = [3, 3] self.strides = [1, 1] + self.paddings = [1, 1] + + +class TestCase9(TestMaxPoolWithIndex_Op): + def initTestCase(self): + self.global_pool = True + self.op_type = "maxPool2dWithIndex" + self.pool_forward_naive = max_pool2D_forward_naive + self.shape = [2, 3, 5, 5] + self.ksize = [3, 3] + self.strides = [2, 2] self.paddings = [0, 0] From bb33c2b3a5f2f2aa50e9ceaf1f2120edfe76455b Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Sat, 30 Sep 2017 15:44:02 +0800 Subject: [PATCH 07/26] fix kernel func --- paddle/operators/math/pooling.cc | 227 ++++++++++++++++++++++++++ paddle/operators/math/pooling.cu | 10 +- paddle/operators/math/pooling.h | 37 +++++ paddle/operators/pool_with_index_op.h | 4 +- 4 files changed, 270 insertions(+), 8 deletions(-) diff --git a/paddle/operators/math/pooling.cc b/paddle/operators/math/pooling.cc index 3b706529d8..5accde8b07 100644 --- a/paddle/operators/math/pooling.cc +++ b/paddle/operators/math/pooling.cc @@ -458,6 +458,233 @@ template class Pool3dGradFunctor< platform::CPUPlace, paddle::operators::math::MaxPoolGrad, double>; template class Pool3dGradFunctor< platform::CPUPlace, paddle::operators::math::AvgPoolGrad, double>; + +template +class MaxPool2dWithIndexFunctor { + public: + void operator()(const platform::DeviceContext& context, + const framework::Tensor& input, framework::Tensor& output, + framework::Tensor& mask, std::vector& ksize, + std::vector& strides, std::vector& paddings) { + const int batch_size = input.dims()[0]; + const int input_height = input.dims()[2]; + const int input_width = input.dims()[3]; + const int output_channels = output.dims()[1]; + const int output_height = output.dims()[2]; + const int output_width = output.dims()[3]; + const int ksize_height = ksize[0]; + const int ksize_width = ksize[1]; + const int stride_height = strides[0]; + const int stride_width = strides[1]; + const int padding_height = paddings[0]; + const int padding_width = paddings[1]; + const int input_stride = input_height * input_width; + const int output_stride = output_height * output_width; + + const T* input_data = input.data(); + T* output_data = output.mutable_data(context.GetPlace()); + T* mask_data = mask.mutable_data(context.GetPlace()); + + for (int i = 0; i < batch_size; i++) { + for (int c = 0; c < output_channels; ++c) { + for (int ph = 0; ph < output_height; ++ph) { + int hstart = ph * stride_height - padding_height; + int hend = std::min(hstart + ksize_height, input_height); + hstart = std::max(hstart, 0); + for (int pw = 0; pw < output_width; ++pw) { + int wstart = pw * stride_width - padding_width; + int wend = std::min(wstart + ksize_width, input_width); + wstart = std::max(wstart, 0); + + T ele = static_cast(-FLT_MAX); + int index = -1; + for (int h = hstart; h < hend; ++h) { + for (int w = wstart; w < wend; ++w) { + if (ele < input_data[h * input_width + w]) { + ele = input_data[h * input_width + w]; + index = h * input_width + w; + } + } + } + output_data[ph * output_width + pw] = ele; + mask_data[ph * output_width + pw] = index; + } + } + // offset + input_data += input_stride; + output_data += output_stride; + mask_data += output_stride; + } + } + } +}; + +template +class MaxPool2dWithIndexGradFunctor { + public: + void operator()(const platform::DeviceContext& context, + framework::Tensor& input_grad, + const framework::Tensor& output_grad, + const framework::Tensor& mask, std::vector& ksize, + std::vector& strides, std::vector& paddings) { + const int batch_size = input_grad.dims()[0]; + const int input_height = input_grad.dims()[2]; + const int input_width = input_grad.dims()[3]; + const int output_channels = output_grad.dims()[1]; + const int output_height = output_grad.dims()[2]; + const int output_width = output_grad.dims()[3]; + const int input_stride = input_height * input_width; + const int output_stride = output_height * output_width; + + const T* mask_data = mask.data(); + const T* output_grad_data = output_grad.data(); + T* input_grad_data = input_grad.mutable_data(context.GetPlace()); + + for (int n = 0; n < batch_size; ++n) { + for (int c = 0; c < output_channels; ++c) { + for (int ph = 0; ph < output_height; ++ph) { + for (int pw = 0; pw < output_width; ++pw) { + const int output_idx = ph * output_width + pw; + const int input_idx = static_cast(mask_data[output_idx]); + input_grad_data[input_idx] += output_grad_data[output_idx]; + } + } + // offset + input_grad_data += input_stride; + output_grad_data += output_stride; + mask_data += output_stride; + } + } + } +}; + +template class MaxPool2dWithIndexFunctor; +template class MaxPool2dWithIndexGradFunctor; +template class MaxPool2dWithIndexFunctor; +template class MaxPool2dWithIndexGradFunctor; + +template +class MaxPool3dWithIndexFunctor { + public: + void operator()(const platform::DeviceContext& context, + const framework::Tensor& input, framework::Tensor& output, + framework::Tensor& mask, std::vector& ksize, + std::vector& strides, std::vector& paddings) { + const int batch_size = input.dims()[0]; + const int input_depth = input.dims()[2]; + const int input_height = input.dims()[3]; + const int input_width = input.dims()[4]; + const int output_channels = output.dims()[1]; + const int output_depth = output.dims()[2]; + const int output_height = output.dims()[3]; + const int output_width = output.dims()[4]; + const int ksize_depth = ksize[0]; + const int ksize_height = ksize[1]; + const int ksize_width = ksize[2]; + const int stride_depth = strides[0]; + const int stride_height = strides[1]; + const int stride_width = strides[2]; + const int padding_depth = paddings[0]; + const int padding_height = paddings[1]; + const int padding_width = paddings[2]; + const int input_stride = input_depth * input_height * input_width; + const int output_stride = output_depth * output_height * output_width; + + const T* input_data = input.data(); + T* output_data = output.mutable_data(context.GetPlace()); + T* mask_data = mask.mutable_data(context.GetPlace()); + + for (int i = 0; i < batch_size; i++) { + for (int c = 0; c < output_channels; ++c) { + for (int pd = 0; pd < output_depth; ++pd) { + int dstart = pd * stride_depth - padding_depth; + int dend = std::min(dstart + ksize_depth, input_depth); + dstart = std::max(dstart, 0); + for (int ph = 0; ph < output_height; ++ph) { + int hstart = ph * stride_height - padding_height; + int hend = std::min(hstart + ksize_height, input_height); + hstart = std::max(hstart, 0); + for (int pw = 0; pw < output_width; ++pw) { + int wstart = pw * stride_width - padding_width; + int wend = std::min(wstart + ksize_width, input_width); + wstart = std::max(wstart, 0); + + int output_idx = (pd * output_height + ph) * output_width + pw; + T ele = static_cast(-FLT_MAX); + int index = -1; + for (int d = dstart; d < dend; ++d) { + for (int h = hstart; h < hend; ++h) { + for (int w = wstart; w < wend; ++w) { + int input_idx = (d * input_height + h) * input_width + w; + if (ele < input_data[input_idx]) { + index = input_idx; + ele = input_data[input_idx]; + } + } + } + } + output_data[output_idx] = ele; + mask_data[output_idx] = index; + } + } + } + // offset + input_data += input_stride; + output_data += output_stride; + mask_data += output_stride; + } + } + } +}; + +template +class MaxPool3dWithIndexGradFunctor { + public: + void operator()(const platform::DeviceContext& context, + framework::Tensor& input_grad, + const framework::Tensor& output_grad, + const framework::Tensor& mask, std::vector& ksize, + std::vector& strides, std::vector& paddings) { + const int batch_size = input_grad.dims()[0]; + const int input_depth = input_grad.dims()[2]; + const int input_height = input_grad.dims()[3]; + const int input_width = input_grad.dims()[4]; + const int output_channels = output_grad.dims()[1]; + const int output_depth = output_grad.dims()[2]; + const int output_height = output_grad.dims()[3]; + const int output_width = output_grad.dims()[4]; + const int input_stride = input_depth * input_height * input_width; + const int output_stride = output_depth * output_height * output_width; + + const T* mask_data = mask.data(); + const T* output_grad_data = output_grad.data(); + T* input_grad_data = input_grad.mutable_data(context.GetPlace()); + + for (int n = 0; n < batch_size; ++n) { + for (int c = 0; c < output_channels; ++c) { + for (int pd = 0; pd < output_depth; ++pd) { + for (int ph = 0; ph < output_height; ++ph) { + for (int pw = 0; pw < output_width; ++pw) { + const int output_idx = + (pd * output_height + ph) * output_width + pw; + const int input_idx = static_cast(mask_data[output_idx]); + input_grad_data[input_idx] += output_grad_data[output_idx]; + } + } + } + // offset + input_grad_data += input_stride; + output_grad_data += output_stride; + mask_data += output_stride; + } + } + } +}; + +template class MaxPool3dWithIndexFunctor; +template class MaxPool3dWithIndexGradFunctor; +template class MaxPool3dWithIndexFunctor; +template class MaxPool3dWithIndexGradFunctor; } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/operators/math/pooling.cu b/paddle/operators/math/pooling.cu index 6aafedf912..06263737a9 100644 --- a/paddle/operators/math/pooling.cu +++ b/paddle/operators/math/pooling.cu @@ -637,7 +637,7 @@ __global__ void KernelMaxPool2dWithIdx( const int output_height, const int output_width, const int ksize_height, const int ksize_width, const int stride_height, const int stride_width, const int padding_height, const int padding_width) { - for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < (nthreads); + for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads; index += blockDim.x * gridDim.x) { int pw = index % output_width; int ph = (index / output_width) % output_height; @@ -676,7 +676,7 @@ __global__ void KernelMaxPool2DWithIdxGrad( const int output_height, const int output_width, const int ksize_height, const int ksize_width, const int stride_height, const int stride_width, const int padding_height, const int padding_width) { - for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < (nthreads); + for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads; index += blockDim.x * gridDim.x) { int w_offset = index % input_width; int h_offset = (index / input_width) % input_height; @@ -766,7 +766,6 @@ class MaxPool2dWithIndexGradFunctor { const int input_channels = input_grad.dims()[1]; const int input_height = input_grad.dims()[2]; const int input_width = input_grad.dims()[3]; - const int output_channels = output_grad.dims()[1]; const int output_height = output_grad.dims()[2]; const int output_width = output_grad.dims()[3]; const int ksize_height = ksize[0]; @@ -810,7 +809,7 @@ __global__ void KernelMaxPool3DWithIdx( const int ksize_width, const int stride_depth, const int stride_height, const int stride_width, const int padding_depth, const int padding_height, const int padding_width) { - for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < (nthreads); + for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads; index += blockDim.x * gridDim.x) { int pw = index % output_width; int ph = (index / output_width) % output_height; @@ -858,7 +857,7 @@ __global__ void KernelMaxPool3DWithIdxGrad( const int ksize_width, const int stride_depth, const int stride_height, const int stride_width, const int padding_depth, const int padding_height, const int padding_width) { - for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < (nthreads); + for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads; index += blockDim.x * gridDim.x) { int w_offset = index % input_width; int h_offset = (index / input_width) % input_height; @@ -969,7 +968,6 @@ class MaxPool3dWithIndexGradFunctor { const int input_depth = input_grad.dims()[2]; const int input_height = input_grad.dims()[3]; const int input_width = input_grad.dims()[4]; - const int output_channels = output_grad.dims()[1]; const int output_depth = output_grad.dims()[2]; const int output_height = output_grad.dims()[3]; const int output_width = output_grad.dims()[4]; diff --git a/paddle/operators/math/pooling.h b/paddle/operators/math/pooling.h index d214c68923..d819e5986e 100644 --- a/paddle/operators/math/pooling.h +++ b/paddle/operators/math/pooling.h @@ -117,6 +117,43 @@ class MaxPool3dGradFunctor { std::vector& strides, std::vector& paddings); }; +template +class MaxPool2dWithIndexFunctor { + public: + void operator()(const platform::DeviceContext& context, + const framework::Tensor& input, framework::Tensor& output, + framework::Tensor& mask, std::vector& ksize, + std::vector& strides, std::vector& paddings); +}; + +template +class MaxPool2dWithIndexGradFunctor { + public: + void operator()(const platform::DeviceContext& context, + framework::Tensor& input_grad, + const framework::Tensor& output_grad, + const framework::Tensor& mask, std::vector& ksize, + std::vector& strides, std::vector& paddings); +}; + +template +class MaxPool3dWithIndexFunctor { + public: + void operator()(const platform::DeviceContext& context, + const framework::Tensor& input, framework::Tensor& output, + framework::Tensor& mask, std::vector& ksize, + std::vector& strides, std::vector& paddings); +}; + +template +class MaxPool3dWithIndexGradFunctor { + public: + void operator()(const platform::DeviceContext& context, + framework::Tensor& input_grad, + const framework::Tensor& output_grad, + const framework::Tensor& mask, std::vector& ksize, + std::vector& strides, std::vector& paddings); +}; } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/operators/pool_with_index_op.h b/paddle/operators/pool_with_index_op.h index 5fe2f5df93..01b961ca82 100644 --- a/paddle/operators/pool_with_index_op.h +++ b/paddle/operators/pool_with_index_op.h @@ -25,7 +25,7 @@ namespace operators { using Tensor = framework::Tensor; template -class MaxPoolWithIndexKernel : public framework::OpKernel { +class MaxPoolWithIndexKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { const Tensor* in_x = context.Input("X"); @@ -59,7 +59,7 @@ class MaxPoolWithIndexKernel : public framework::OpKernel { }; template -class MaxPoolWithIndexGradKernel : public framework::OpKernel { +class MaxPoolWithIndexGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { const Tensor* mask = context.Input("Mask"); From 5b862fedf1feb78b7dc63451e5219cef2fde33a3 Mon Sep 17 00:00:00 2001 From: Luo Tao Date: Mon, 9 Oct 2017 11:31:13 +0800 Subject: [PATCH 08/26] remove debug log in interp_op.cc --- paddle/operators/interp_op.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/paddle/operators/interp_op.cc b/paddle/operators/interp_op.cc index 04bcb9ade8..fc8b9a11b8 100644 --- a/paddle/operators/interp_op.cc +++ b/paddle/operators/interp_op.cc @@ -58,7 +58,6 @@ class InterpOp : public NetOp { {{"Out", {Output("Out")}}}, {})); CompleteAddOp(false); - LOG(INFO) << DebugString(); } }; From fcfce48421650f983b484af9fe20d2e843dc042b Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Mon, 9 Oct 2017 19:02:24 +0800 Subject: [PATCH 09/26] follow coments --- paddle/operators/CMakeLists.txt | 3 +- paddle/operators/math/pooling.h | 42 +++++++++++++++++-- paddle/operators/pool_with_index_op.cc | 20 ++++----- paddle/operators/pool_with_index_op.cu | 8 ++-- .../v2/framework/tests/test_pool_max_op.py | 21 +++++----- 5 files changed, 65 insertions(+), 29 deletions(-) diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt index 49da132049..39af318ca5 100644 --- a/paddle/operators/CMakeLists.txt +++ b/paddle/operators/CMakeLists.txt @@ -75,10 +75,11 @@ function(op_library TARGET) file(APPEND ${pybind_file} "USE_OP(reduce_sum);\n") endif() + # pool_with_index_op contains several operators if ("${TARGET}" STREQUAL "pool_with_index_op") set(pybind_flag 1) # It's enough to just adding one operator to pybind - file(APPEND ${pybind_file} "USE_OP(maxPool2dWithIndex);\n") + file(APPEND ${pybind_file} "USE_OP(max_pool2d_with_index);\n") endif() # pybind USE_NO_KERNEL_OP diff --git a/paddle/operators/math/pooling.h b/paddle/operators/math/pooling.h index d819e5986e..f15ddca69a 100644 --- a/paddle/operators/math/pooling.h +++ b/paddle/operators/math/pooling.h @@ -21,15 +21,26 @@ limitations under the License. */ namespace paddle { namespace operators { namespace math { -////////////////////// -#define FLT_MAX __FLT_MAX__ // +#define FLT_MAX \ + __FLT_MAX__ // It might need to be placed in another file, but I'm still + // wondering where to put it + +/* + * \brief Extracting simple operations from pooling. + * Both MaxPool and AvgPool need initial, compute and finalize operation. + * MaxPool initializes temp variable to the negative maximum to find the + * maximum value in the pooling field. + * AvgPool initializes temp variable to the zero to accumulate all values + * in pool pooling, and takes the average. + * MaxPoolGrad and AvgPoolGrad are gradient operations respectively. + */ template class MaxPool { public: DEVICE inline T initial() { return static_cast(-FLT_MAX); } DEVICE inline void compute(T& y, const T& x) { y = y > x ? y : x; } - DEVICE inline void finalize(T& y, const T& poo_size) {} + DEVICE inline void finalize(T& y, const T& pool_field) {} }; template @@ -37,8 +48,9 @@ class AvgPool { public: DEVICE inline T initial() { return static_cast(0); } DEVICE inline void compute(T& y, const T& x) { y += x; } - DEVICE inline void finalize(T& y, const T& poo_size) { y /= poo_size; } + DEVICE inline void finalize(T& y, const T& pool_field) { y /= pool_field; } }; + template class MaxPoolGrad { public: @@ -57,6 +69,20 @@ class AvgPoolGrad { } }; +/* + * \brief Getting pooling results, and calculating gradient. + * + * In pool2d, all tensors are in NCHW format. In pool3d, all tensors are in + * NCDHW format. + * + * In max pooling, it is possible that the pooling region has multiple maximum + * elements. + * In this case, we should compute the gradient of the first maximum element. + * This is different from average pooling. So we rewrite the max_pool_grad: + * MaxPool2dGradFunctor, MaxPool3dGradFunctor. + * + */ + template class Pool2dFunctor { public: @@ -117,6 +143,14 @@ class MaxPool3dGradFunctor { std::vector& strides, std::vector& paddings); }; +/* + * \brief Getting max pooling results and corresponding max index, and + * calculating gradient. + * In sub-sampling-pooling, it is necessary to know max element index. + * In pool2d, all tensors are in NCHW format. In pool3d, all tensors are in + * NCDHW format. + * + */ template class MaxPool2dWithIndexFunctor { public: diff --git a/paddle/operators/pool_with_index_op.cc b/paddle/operators/pool_with_index_op.cc index c51145b923..2e6a5f2555 100644 --- a/paddle/operators/pool_with_index_op.cc +++ b/paddle/operators/pool_with_index_op.cc @@ -17,8 +17,8 @@ limitations under the License. */ namespace paddle { namespace operators { -int OutputSizeMaxPool(int input_size, int filter_size, int padding, - int stride) { +inline int OutputSizeMaxPool(int input_size, int filter_size, int padding, + int stride) { int output_size = (input_size - filter_size + 2 * padding) / stride + 1; return output_size; } @@ -194,24 +194,24 @@ the input and ksize, strides, paddings parameters. namespace ops = paddle::operators; -REGISTER_OP(maxPool2dWithIndex, ops::MaxPoolWithIndexOp, - ops::MaxPool2dWithIndexOpMaker, maxPool2dWithIndex_grad, +REGISTER_OP(max_pool2d_with_index, ops::MaxPoolWithIndexOp, + ops::MaxPool2dWithIndexOpMaker, max_pool2d_with_index_grad, ops::MaxPoolWithIndexOpGrad); REGISTER_OP_CPU_KERNEL( - maxPool2dWithIndex, + max_pool2d_with_index, ops::MaxPoolWithIndexKernel); REGISTER_OP_CPU_KERNEL( - maxPool2dWithIndex_grad, + max_pool2d_with_index_grad, ops::MaxPoolWithIndexGradKernel) -REGISTER_OP(maxPool3dWithIndex, ops::MaxPoolWithIndexOp, - ops::MaxPool3dWithIndexOpMaker, maxPool3dWithIndex_grad, +REGISTER_OP(max_pool3d_with_index, ops::MaxPoolWithIndexOp, + ops::MaxPool3dWithIndexOpMaker, max_pool3d_with_index_grad, ops::MaxPoolWithIndexOpGrad); REGISTER_OP_CPU_KERNEL( - maxPool3dWithIndex, + max_pool3d_with_index, ops::MaxPoolWithIndexKernel); REGISTER_OP_CPU_KERNEL( - maxPool3dWithIndex_grad, + max_pool3d_with_index_grad, ops::MaxPoolWithIndexGradKernel) diff --git a/paddle/operators/pool_with_index_op.cu b/paddle/operators/pool_with_index_op.cu index 8007fc7ccf..287657d4b1 100644 --- a/paddle/operators/pool_with_index_op.cu +++ b/paddle/operators/pool_with_index_op.cu @@ -17,15 +17,15 @@ limitations under the License. */ namespace ops = paddle::operators; REGISTER_OP_GPU_KERNEL( - maxPool2dWithIndex, + max_pool2d_with_index, ops::MaxPoolWithIndexKernel); REGISTER_OP_GPU_KERNEL( - maxPool2dWithIndex_grad, + max_pool2d_with_index_grad, ops::MaxPoolWithIndexGradKernel) REGISTER_OP_GPU_KERNEL( - maxPool3dWithIndex, + max_pool3d_with_index, ops::MaxPoolWithIndexKernel); REGISTER_OP_GPU_KERNEL( - maxPool3dWithIndex_grad, + max_pool3d_with_index_grad, ops::MaxPoolWithIndexGradKernel) diff --git a/python/paddle/v2/framework/tests/test_pool_max_op.py b/python/paddle/v2/framework/tests/test_pool_max_op.py index 17028c3bf6..f0f8aa6089 100644 --- a/python/paddle/v2/framework/tests/test_pool_max_op.py +++ b/python/paddle/v2/framework/tests/test_pool_max_op.py @@ -100,7 +100,8 @@ class TestMaxPoolWithIndex_Op(OpTest): def initTestCase(self): self.global_pool = True - self.op_type = "maxPool3dWithIndex" + self.index = "max_pool3d_with_index" + self.op_type = "%s" % self.index self.pool_forward_naive = max_pool3D_forward_naive self.shape = [2, 3, 5, 5, 5] self.ksize = [3, 3, 3] @@ -111,7 +112,7 @@ class TestMaxPoolWithIndex_Op(OpTest): class TestCase1(TestMaxPoolWithIndex_Op): def initTestCase(self): self.global_pool = True - self.op_type = "maxPool3dWithIndex" + self.op_type = "max_pool3d_with_index" self.pool_forward_naive = max_pool3D_forward_naive self.shape = [2, 3, 5, 5, 5] self.ksize = [3, 3, 3] @@ -122,7 +123,7 @@ class TestCase1(TestMaxPoolWithIndex_Op): class TestCase2(TestMaxPoolWithIndex_Op): def initTestCase(self): self.global_pool = False - self.op_type = "maxPool3dWithIndex" + self.op_type = "max_pool3d_with_index" self.pool_forward_naive = max_pool3D_forward_naive self.shape = [2, 3, 7, 7, 7] self.ksize = [3, 3, 3] @@ -133,7 +134,7 @@ class TestCase2(TestMaxPoolWithIndex_Op): class TestCase3(TestMaxPoolWithIndex_Op): def initTestCase(self): self.global_pool = False - self.op_type = "maxPool3dWithIndex" + self.op_type = "max_pool3d_with_index" self.pool_forward_naive = max_pool3D_forward_naive self.shape = [2, 3, 7, 7, 7] self.ksize = [3, 3, 3] @@ -144,7 +145,7 @@ class TestCase3(TestMaxPoolWithIndex_Op): class TestCase4(TestMaxPoolWithIndex_Op): def initTestCase(self): self.global_pool = True - self.op_type = "maxPool3dWithIndex" + self.op_type = "max_pool3d_with_index" self.pool_forward_naive = max_pool3D_forward_naive self.shape = [2, 3, 5, 5, 5] self.ksize = [3, 3, 3] @@ -155,7 +156,7 @@ class TestCase4(TestMaxPoolWithIndex_Op): class TestCase5(TestMaxPoolWithIndex_Op): def initTestCase(self): self.global_pool = True - self.op_type = "maxPool3dWithIndex" + self.op_type = "max_pool3d_with_index" self.pool_forward_naive = max_pool3D_forward_naive self.shape = [2, 3, 5, 5, 5] self.ksize = [3, 3, 3] @@ -166,7 +167,7 @@ class TestCase5(TestMaxPoolWithIndex_Op): class TestCase6(TestMaxPoolWithIndex_Op): def initTestCase(self): self.global_pool = False - self.op_type = "maxPool2dWithIndex" + self.op_type = "max_pool2d_with_index" self.pool_forward_naive = max_pool2D_forward_naive self.shape = [2, 3, 7, 7] self.ksize = [3, 3] @@ -177,7 +178,7 @@ class TestCase6(TestMaxPoolWithIndex_Op): class TestCase7(TestMaxPoolWithIndex_Op): def initTestCase(self): self.global_pool = False - self.op_type = "maxPool2dWithIndex" + self.op_type = "max_pool2d_with_index" self.pool_forward_naive = max_pool2D_forward_naive self.shape = [2, 3, 7, 7] self.ksize = [3, 3] @@ -188,7 +189,7 @@ class TestCase7(TestMaxPoolWithIndex_Op): class TestCase8(TestMaxPoolWithIndex_Op): def initTestCase(self): self.global_pool = True - self.op_type = "maxPool2dWithIndex" + self.op_type = "max_pool2d_with_index" self.pool_forward_naive = max_pool2D_forward_naive self.shape = [2, 3, 5, 5] self.ksize = [3, 3] @@ -199,7 +200,7 @@ class TestCase8(TestMaxPoolWithIndex_Op): class TestCase9(TestMaxPoolWithIndex_Op): def initTestCase(self): self.global_pool = True - self.op_type = "maxPool2dWithIndex" + self.op_type = "max_pool2d_with_index" self.pool_forward_naive = max_pool2D_forward_naive self.shape = [2, 3, 5, 5] self.ksize = [3, 3] From a06f099d9f54b47ce4df7d1ae32c928fb8d7593e Mon Sep 17 00:00:00 2001 From: Luo Tao Date: Mon, 9 Oct 2017 16:34:05 +0800 Subject: [PATCH 10/26] refine comment of interp_op --- paddle/operators/interp_op.cc | 43 +++++++++++-------- .../v2/framework/tests/test_interp_op.py | 6 +-- 2 files changed, 28 insertions(+), 21 deletions(-) diff --git a/paddle/operators/interp_op.cc b/paddle/operators/interp_op.cc index fc8b9a11b8..d02b01c3f3 100644 --- a/paddle/operators/interp_op.cc +++ b/paddle/operators/interp_op.cc @@ -30,27 +30,26 @@ class InterpOp : public NetOp { "Input(Y) of InterpOp should not be null."); PADDLE_ENFORCE_NE(Input("W"), framework::kEmptyVarName, "Input(W) of InterpOp should not be null."); - PADDLE_ENFORCE_NE(Output("MinusOut"), framework::kEmptyVarName, - "Output(MinusOut) of InterpOp should not be null."); + PADDLE_ENFORCE_NE(Output("SubOut"), framework::kEmptyVarName, + "Output(SubOut) of InterpOp should not be null."); PADDLE_ENFORCE_NE(Output("MulOut"), framework::kEmptyVarName, "Output(MulOut) of InterpOp should not be null."); PADDLE_ENFORCE_NE(Output("Out"), framework::kEmptyVarName, "Output(Out) of InterpOp should not be null."); - // MinusOut = X - Y + // SubOut = X - Y auto x = Input("X"); auto y = Input("Y"); - auto minus_out = Output("MinusOut"); - AppendOp(framework::OpRegistry::CreateOp("elementwise_sub", - {{"X", {x}}, {"Y", {y}}}, - {{"Out", {minus_out}}}, {})); + auto sub_out = Output("SubOut"); + AppendOp(framework::OpRegistry::CreateOp( + "elementwise_sub", {{"X", {x}}, {"Y", {y}}}, {{"Out", {sub_out}}}, {})); - // MulOut = MinusOut * W = (X - Y) * W + // MulOut = SubOut * W = (X - Y) * W auto w = Input("W"); auto mul_out = Output("MulOut"); AppendOp(framework::OpRegistry::CreateOp( - "elementwise_mul", {{"X", {minus_out}}, {"Y", {w}}}, - {{"Out", {mul_out}}}, {{"axis", 0}})); + "elementwise_mul", {{"X", {sub_out}}, {"Y", {w}}}, {{"Out", {mul_out}}}, + {{"axis", 0}})); // Out = MulOut + Y = (X - Y) * W + Y = X * W + Y * (1 - W) AppendOp(framework::OpRegistry::CreateOp("elementwise_add", @@ -65,18 +64,26 @@ class InterpOpMaker : public framework::OpProtoAndCheckerMaker { public: InterpOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("X", "A 2-D Tensor, the first input of interp_op"); - AddInput("Y", "A 2-D Tensor, the second input of interp_op"); - AddInput("W", "A 1-D Tensor, the interpolated values"); - AddOutput("MinusOut", - "A 2-D Tensor, the intermediate outputs, saving X - Y.") + AddInput("X", + "(Tensor), 2-D Matrix of shape [batch_size, data_dim]" + "containing data samples, the first input of interp_op"); + AddInput("Y", + "(Tensor), 2-D Matrix of shape `[batch_size, data_dim]`" + "containing data samples, the second input of interp_op"); + AddInput("W", + "(Tensor), 1-D Vector of shape [batch_size]," + "the interpolated values in the half-open interval [0.0, 1.0)"); + AddOutput("SubOut", + "(Tensor), the intermediate subtraction outputs, saving X - Y.") .AsIntermediate(); AddOutput("MulOut", - "A 2-D Tensor, the intermediate outputs," - "saving the mul mul of (X - Y) and W") + "(Tensor), the intermediate multiplication outputs," + "saving the elementwise multiplication of (X - Y) and W.") .AsIntermediate(); AddOutput("Out", - "A 2-D Tensor, the output of interp_op, same shape with X"); + "(Tensor), the output of interp_op, same shape with X," + "returns the first-dimensional piecewise linear interpolant " + "between X and Y"); AddComment(R"DOC( Linear Interpolation with two inputs, used in NEURAL TURING MACHINE. diff --git a/python/paddle/v2/framework/tests/test_interp_op.py b/python/paddle/v2/framework/tests/test_interp_op.py index f82dcc7f50..066569b96c 100644 --- a/python/paddle/v2/framework/tests/test_interp_op.py +++ b/python/paddle/v2/framework/tests/test_interp_op.py @@ -10,12 +10,12 @@ class TestInterpOp(OpTest): y = np.random.random((2, 3)).astype("float32") w = np.random.random(2).astype("float32") - minus_out = x - y - mul_out = minus_out * w.reshape(2, 1) + sub_out = x - y + mul_out = sub_out * w.reshape(2, 1) out = mul_out + y self.inputs = {'X': x, 'Y': y, 'W': w} - self.outputs = {'Out': out, 'MinusOut': minus_out, 'MulOut': mul_out} + self.outputs = {'Out': out, 'SubOut': sub_out, 'MulOut': mul_out} def test_check_output(self): self.check_output() From c464ec21d8b0a1e7ad6da7115b78cd047d9a2041 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Mon, 9 Oct 2017 12:09:39 -0700 Subject: [PATCH 11/26] Fix bug of foward default attribute not passed to backward --- paddle/framework/backward.cc | 2 +- paddle/framework/op_desc.h | 5 +++++ paddle/framework/op_registry.cc | 11 ++++++++--- paddle/framework/op_registry.h | 2 +- 4 files changed, 15 insertions(+), 5 deletions(-) diff --git a/paddle/framework/backward.cc b/paddle/framework/backward.cc index c970e01dd1..0a4688db9c 100644 --- a/paddle/framework/backward.cc +++ b/paddle/framework/backward.cc @@ -302,7 +302,7 @@ std::vector> MakeOpGrad( return grad_op_descs; // empty vector } - grad_op_descs = OpRegistry::CreateGradOpDescs(*op_desc); + grad_op_descs = OpRegistry::CreateGradOpDescs(op_desc.get()); std::list> pending_fill_zeros_ops; for (auto& desc : grad_op_descs) { diff --git a/paddle/framework/op_desc.h b/paddle/framework/op_desc.h index b39808dad1..b729029412 100644 --- a/paddle/framework/op_desc.h +++ b/paddle/framework/op_desc.h @@ -97,6 +97,11 @@ class OpDescBind { const VariableNameMap &Outputs() const { return outputs_; } + AttributeMap *MutableAttrMap() { + this->need_update_ = true; + return &this->attrs_; + } + private: template static std::vector MapKeys(const MapType &map) { diff --git a/paddle/framework/op_registry.cc b/paddle/framework/op_registry.cc index 66043f6e04..b118edae17 100644 --- a/paddle/framework/op_registry.cc +++ b/paddle/framework/op_registry.cc @@ -60,9 +60,14 @@ std::unique_ptr OpRegistry::CreateOp(const OpDescBind& op_desc) { } std::vector> OpRegistry::CreateGradOpDescs( - const OpDescBind& op_desc) { - auto& info = OpInfoMap::Instance().Get(op_desc.Type()); - return info.grad_op_maker_(op_desc); + OpDescBind* op_desc) { + auto& info = OpInfoMap::Instance().Get(op_desc->Type()); + + if (info.Checker() != nullptr) { + info.Checker()->Check(*op_desc->MutableAttrMap()); + } + + return info.grad_op_maker_(*op_desc); } } // namespace framework diff --git a/paddle/framework/op_registry.h b/paddle/framework/op_registry.h index cce3605fd4..5ca3af52a6 100644 --- a/paddle/framework/op_registry.h +++ b/paddle/framework/op_registry.h @@ -80,7 +80,7 @@ class OpRegistry { static std::unique_ptr CreateOp(const OpDesc& op_desc); static std::vector> CreateGradOpDescs( - const OpDescBind& op_desc); + OpDescBind* op_desc); static std::unique_ptr CreateOp(const OpDescBind& op_desc); }; From dcb09e932d57701b553a5308aaab5b16bf214910 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Mon, 9 Oct 2017 14:21:58 -0700 Subject: [PATCH 12/26] Use PROTO_LITE when refactoring Paddle It will significantly reduce binary size. It is useful for mobile deployment. --- paddle/framework/framework.proto | 1 + paddle/framework/op_desc.h | 2 -- paddle/framework/program_desc.h | 2 -- paddle/operators/net_op.h | 1 + paddle/pybind/protobuf.cc | 3 --- 5 files changed, 2 insertions(+), 7 deletions(-) diff --git a/paddle/framework/framework.proto b/paddle/framework/framework.proto index ac2827e547..b7a63f9ba1 100644 --- a/paddle/framework/framework.proto +++ b/paddle/framework/framework.proto @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ syntax = "proto2"; +option optimize_for = LITE_RUNTIME; package paddle.framework; enum AttrType { diff --git a/paddle/framework/op_desc.h b/paddle/framework/op_desc.h index b729029412..d0c314771c 100644 --- a/paddle/framework/op_desc.h +++ b/paddle/framework/op_desc.h @@ -52,8 +52,6 @@ class OpDescBind { void SetOutput(const std::string ¶m_name, const std::vector &args); - std::string DebugString() { return this->Proto()->DebugString(); } - bool HasAttr(const std::string &name) const { return attrs_.find(name) != attrs_.end(); } diff --git a/paddle/framework/program_desc.h b/paddle/framework/program_desc.h index 9b34a06aef..d684b08d16 100644 --- a/paddle/framework/program_desc.h +++ b/paddle/framework/program_desc.h @@ -31,8 +31,6 @@ class ProgramDescBind { BlockDescBind *Block(size_t idx) { return blocks_[idx].get(); } - std::string DebugString() { return Proto()->DebugString(); } - size_t Size() const { return blocks_.size(); } ProgramDesc *Proto(); diff --git a/paddle/operators/net_op.h b/paddle/operators/net_op.h index 2388b094d2..ebeb262d96 100644 --- a/paddle/operators/net_op.h +++ b/paddle/operators/net_op.h @@ -14,6 +14,7 @@ limitations under the License. */ #pragma once +#include #include "paddle/framework/framework.pb.h" #include "paddle/framework/op_registry.h" diff --git a/paddle/pybind/protobuf.cc b/paddle/pybind/protobuf.cc index 218821b35b..47bd7bc3bb 100644 --- a/paddle/pybind/protobuf.cc +++ b/paddle/pybind/protobuf.cc @@ -117,7 +117,6 @@ void BindProgramDesc(py::module &m) { .def("append_block", &ProgramDescBind::AppendBlock, py::return_value_policy::reference) .def("block", &ProgramDescBind::Block, py::return_value_policy::reference) - .def("__str__", &ProgramDescBind::DebugString) .def("num_blocks", &ProgramDescBind::Size); } @@ -191,8 +190,6 @@ void BindOpDesc(py::module &m) { .def("output", &OpDescBind::Output) .def("output_names", &OpDescBind::OutputNames) .def("set_output", &OpDescBind::SetOutput) - .def("__str__", &OpDescBind::DebugString) - .def("__repr__", &OpDescBind::DebugString) .def("has_attr", &OpDescBind::HasAttr) .def("attr_type", &OpDescBind::GetAttrType) .def("attr_names", &OpDescBind::AttrNames) From 383faaf71907605b92feef2ffc91f955623fa82d Mon Sep 17 00:00:00 2001 From: Abhinav Arora Date: Mon, 9 Oct 2017 14:25:38 -0700 Subject: [PATCH 13/26] Adding implementation for copying a vector to a tensor (#4635) * Adding implementation for copying a vector to tensor * Changing Tensor test to access gpu memory indirectly --- paddle/framework/tensor.h | 13 +++++ paddle/framework/tensor_impl.h | 23 +++++++++ paddle/framework/tensor_test.cc | 87 +++++++++++++++++++++++++++++++++ 3 files changed, 123 insertions(+) diff --git a/paddle/framework/tensor.h b/paddle/framework/tensor.h index 80a3f0a393..ba82127d9c 100644 --- a/paddle/framework/tensor.h +++ b/paddle/framework/tensor.h @@ -95,6 +95,19 @@ class Tensor { template inline void CopyFrom(const Tensor& src, const platform::Place& dst_place); + /** + * @brief Copy the content of an external vector to a tensor. + * + * @param[in] src The external vector. + * @param[in] ctx The device context contains place where to store. + * + * * @note CopyFromVector assumes that the tensor has been resized + * before invoking. + */ + template + inline void CopyFromVector(const std::vector& src, + const platform::Place& dst_place); + /** * @brief Return the slice of the tensor. * diff --git a/paddle/framework/tensor_impl.h b/paddle/framework/tensor_impl.h index 379eac94f9..8ee9941982 100644 --- a/paddle/framework/tensor_impl.h +++ b/paddle/framework/tensor_impl.h @@ -123,6 +123,29 @@ inline void Tensor::CopyFrom(const Tensor& src, #endif } +template +inline void Tensor::CopyFromVector(const std::vector& src, + const platform::Place& dst_place) { + auto src_ptr = static_cast(src.data()); + platform::CPUPlace src_place; + auto dst_ptr = static_cast(mutable_data(dst_place)); + auto size = src.size() * sizeof(T); + + if (platform::is_cpu_place(dst_place)) { + memory::Copy(boost::get(dst_place), dst_ptr, src_place, + src_ptr, size); + } +#ifdef PADDLE_WITH_CUDA + else if (platform::is_gpu_place(dst_place)) { + memory::Copy(boost::get(dst_place), dst_ptr, src_place, + src_ptr, size, 0); + } + PADDLE_ENFORCE(cudaStreamSynchronize(0), + "cudaStreamSynchronize failed in Tensor CopyFromVector"); + +#endif +} + template inline Tensor Tensor::Slice(const int& begin_idx, const int& end_idx) const { check_memory_size(); diff --git a/paddle/framework/tensor_test.cc b/paddle/framework/tensor_test.cc index 58cf0fc3cb..492eba69e1 100644 --- a/paddle/framework/tensor_test.cc +++ b/paddle/framework/tensor_test.cc @@ -263,6 +263,93 @@ TEST(Tensor, CopyFrom) { #endif } +TEST(Tensor, CopyFromVector) { + using namespace paddle::framework; + using namespace paddle::platform; + { + std::vector src_vec = {1, 2, 3, 4, 5, 6, 7, 8, 9}; + Tensor cpu_tensor; + + // Copy to CPU Tensor + cpu_tensor.Resize(make_ddim({3, 3})); + auto cpu_place = new paddle::platform::CPUPlace(); + cpu_tensor.CopyFromVector(src_vec, *cpu_place); + + // Compare Tensors + const int* cpu_ptr = cpu_tensor.data(); + const int* src_ptr = src_vec.data(); + ASSERT_NE(src_ptr, cpu_ptr); + for (size_t i = 0; i < 9; ++i) { + EXPECT_EQ(src_ptr[i], cpu_ptr[i]); + } + + src_vec.erase(src_vec.begin(), src_vec.begin() + 5); + cpu_tensor.Resize(make_ddim({2, 2})); + cpu_tensor.CopyFromVector(src_vec, *cpu_place); + cpu_ptr = cpu_tensor.data(); + src_ptr = src_vec.data(); + ASSERT_NE(src_ptr, cpu_ptr); + for (size_t i = 0; i < 5; ++i) { + EXPECT_EQ(src_ptr[i], cpu_ptr[i]); + } + + delete cpu_place; + } + +#ifdef PADDLE_WITH_CUDA + { + std::vector src_vec = {1, 2, 3, 4, 5, 6, 7, 8, 9}; + Tensor cpu_tensor; + Tensor gpu_tensor; + Tensor dst_tensor; + + // Copy to CPU Tensor + cpu_tensor.Resize(make_ddim({3, 3})); + auto cpu_place = new paddle::platform::CPUPlace(); + cpu_tensor.CopyFromVector(src_vec, *cpu_place); + + // Copy to GPUTensor + gpu_tensor.Resize(make_ddim({3, 3})); + auto gpu_place = new paddle::platform::GPUPlace(); + gpu_tensor.CopyFromVector(src_vec, *gpu_place); + // Copy from GPU to CPU tensor for comparison + dst_tensor.CopyFrom(gpu_tensor, *cpu_place); + + // Compare Tensors + const int* src_ptr = src_vec.data(); + const int* cpu_ptr = cpu_tensor.data(); + const int* dst_ptr = dst_tensor.data(); + ASSERT_NE(src_ptr, cpu_ptr); + ASSERT_NE(src_ptr, dst_ptr); + for (size_t i = 0; i < 9; ++i) { + EXPECT_EQ(src_ptr[i], cpu_ptr[i]); + EXPECT_EQ(src_ptr[i], dst_ptr[i]); + } + + src_vec.erase(src_vec.begin(), src_vec.begin() + 5); + + cpu_tensor.Resize(make_ddim({2, 2})); + cpu_tensor.CopyFromVector(src_vec, *cpu_place); + gpu_tensor.Resize(make_ddim({2, 2})); + gpu_tensor.CopyFromVector(src_vec, *gpu_place); + dst_tensor.CopyFrom(gpu_tensor, *cpu_place); + + src_ptr = src_vec.data(); + cpu_ptr = cpu_tensor.data(); + dst_ptr = dst_tensor.data(); + ASSERT_NE(src_ptr, cpu_ptr); + ASSERT_NE(src_ptr, dst_ptr); + for (size_t i = 0; i < 5; ++i) { + EXPECT_EQ(src_ptr[i], cpu_ptr[i]); + EXPECT_EQ(src_ptr[i], dst_ptr[i]); + } + + delete cpu_place; + delete gpu_place; + } +#endif +} + TEST(Tensor, ReshapeToMatrix) { using namespace paddle::framework; using namespace paddle::platform; From 5984cbca47a4663b47b16390fc028829dbc9f183 Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Mon, 9 Oct 2017 14:30:31 -0700 Subject: [PATCH 14/26] Add Attr test --- paddle/framework/backward_test.cc | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/paddle/framework/backward_test.cc b/paddle/framework/backward_test.cc index 30225a4a99..05ebf356ba 100644 --- a/paddle/framework/backward_test.cc +++ b/paddle/framework/backward_test.cc @@ -440,6 +440,25 @@ TEST(Backward, simple_single_op) { std::vector({f::GradVarName("b")})); } +TEST(Backward, default_attribute) { + f::ProgramDesc *program_desc = GetNewProgramDesc(); + f::ProgramDescBind &program = f::ProgramDescBind::Instance(program_desc); + f::BlockDescBind *block = program.Block(0); + f::OpDescBind *op = block->AppendOp(); + op->SetType("mul"); + op->SetInput("X", {"x"}); + op->SetInput("Y", {"y"}); + op->SetOutput("Out", {"out"}); + + AppendBackward(program, {}); + + ASSERT_EQ(block->AllOps().size(), 2UL); + f::OpDescBind *grad_op = block->AllOps()[1]; + ASSERT_EQ(grad_op->Type(), "mul_grad"); + EXPECT_EQ(boost::get(grad_op->GetAttr("x_num_col_dims")), 1); + EXPECT_EQ(boost::get(grad_op->GetAttr("y_num_col_dims")), 1); +} + TEST(Backward, simple_mult_op) { f::ProgramDesc *program_desc = GetNewProgramDesc(); f::ProgramDescBind &program = f::ProgramDescBind::Instance(program_desc); From ce901b1186b671781cd86b91ce530e2be3408f37 Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Mon, 9 Oct 2017 15:16:34 -0700 Subject: [PATCH 15/26] Refine unit test --- paddle/framework/backward_test.cc | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/paddle/framework/backward_test.cc b/paddle/framework/backward_test.cc index 05ebf356ba..3b7cbcd989 100644 --- a/paddle/framework/backward_test.cc +++ b/paddle/framework/backward_test.cc @@ -58,6 +58,8 @@ class MulOpMaker : public OpProtoAndCheckerMaker { AddInput("X", "A"); AddInput("Y", "B"); AddOutput("Out", "Out"); + AddAttr("x_num_col_dims", "").SetDefault(1).EqualGreaterThan(1); + AddAttr("y_num_col_dims", "").SetDefault(1).EqualGreaterThan(1); AddComment("Mul"); } }; @@ -453,6 +455,9 @@ TEST(Backward, default_attribute) { AppendBackward(program, {}); ASSERT_EQ(block->AllOps().size(), 2UL); + EXPECT_EQ(boost::get(op->GetAttr("x_num_col_dims")), 1); + EXPECT_EQ(boost::get(op->GetAttr("y_num_col_dims")), 1); + f::OpDescBind *grad_op = block->AllOps()[1]; ASSERT_EQ(grad_op->Type(), "mul_grad"); EXPECT_EQ(boost::get(grad_op->GetAttr("x_num_col_dims")), 1); From d2217ffa39f30ec719e93cd2cb6b9449467bcbc1 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Mon, 9 Oct 2017 16:03:39 -0700 Subject: [PATCH 16/26] Fix CI --- paddle/framework/block_desc.h | 1 + paddle/framework/program_desc.h | 1 + 2 files changed, 2 insertions(+) diff --git a/paddle/framework/block_desc.h b/paddle/framework/block_desc.h index 2de270f60e..3437e89923 100644 --- a/paddle/framework/block_desc.h +++ b/paddle/framework/block_desc.h @@ -15,6 +15,7 @@ limitations under the License. */ #pragma once #include +#include #include #include #include "paddle/framework/op_desc.h" diff --git a/paddle/framework/program_desc.h b/paddle/framework/program_desc.h index d684b08d16..f29b1c54e7 100644 --- a/paddle/framework/program_desc.h +++ b/paddle/framework/program_desc.h @@ -14,6 +14,7 @@ limitations under the License. */ #pragma once +#include #include #include "paddle/framework/framework.pb.h" #include "paddle/platform/macros.h" From 6efacc14d857bd117d5918bf02afc9cca702bd78 Mon Sep 17 00:00:00 2001 From: Abhinav Arora Date: Mon, 9 Oct 2017 17:24:21 -0700 Subject: [PATCH 17/26] Implementing the fill constant op for the executor --- paddle/operators/fill_constant_op.cc | 68 +++++++++++++++++++ paddle/operators/fill_constant_op.cu | 22 ++++++ paddle/operators/fill_constant_op.h | 37 ++++++++++ .../framework/tests/test_fill_constant_op.py | 35 ++++++++++ 4 files changed, 162 insertions(+) create mode 100644 paddle/operators/fill_constant_op.cc create mode 100644 paddle/operators/fill_constant_op.cu create mode 100644 paddle/operators/fill_constant_op.h create mode 100644 python/paddle/v2/framework/tests/test_fill_constant_op.py diff --git a/paddle/operators/fill_constant_op.cc b/paddle/operators/fill_constant_op.cc new file mode 100644 index 0000000000..65d03d5fa4 --- /dev/null +++ b/paddle/operators/fill_constant_op.cc @@ -0,0 +1,68 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/fill_constant_op.h" + +namespace paddle { +namespace operators { + +class FillConstantOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output(Out) of FillConstantOp should not be null."); + auto &shape = ctx->Attrs().Get>("shape"); + std::vector shape_int64(shape.size(), 0); + std::transform(shape.begin(), shape.end(), shape_int64.begin(), + [](int a) { return static_cast(a); }); + auto dims = framework::make_ddim(shape_int64); + ctx->SetOutputDim("Out", dims); + } + + framework::DataType IndicateDataType( + const framework::ExecutionContext &ctx) const override { + return static_cast(ctx.Attr("dataType")); + } +}; + +class FillConstantOpMaker : public framework::OpProtoAndCheckerMaker { + public: + FillConstantOpMaker(framework::OpProto *proto, + framework::OpAttrChecker *op_checker) + : framework::OpProtoAndCheckerMaker(proto, op_checker) { + AddAttr("dataType", + "(int, default 5 (FP32)) " + "Output data type") + .SetDefault(framework::DataType::FP32); + AddAttr>("shape", "(vector) The shape of the output"); + AddAttr("value", "(float, default 0) The value to be filled") + .SetDefault(0.0f); + AddOutput("Out", + "(Tensor) Tensor of specified shape will be filled " + "with the specified value"); + AddComment(R"DOC(Fill up a variable with specified constant value.)DOC"); + } +}; +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_WITHOUT_GRADIENT(fill_constant, ops::FillConstantOp, + ops::FillConstantOpMaker); +REGISTER_OP_CPU_KERNEL( + fill_constant, + ops::FillConstantOpKernel); diff --git a/paddle/operators/fill_constant_op.cu b/paddle/operators/fill_constant_op.cu new file mode 100644 index 0000000000..eef8fcbd7f --- /dev/null +++ b/paddle/operators/fill_constant_op.cu @@ -0,0 +1,22 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#define EIGEN_USE_GPU +#include "paddle/framework/op_registry.h" +#include "paddle/operators/fill_constant_op.h" + +namespace ops = paddle::operators; +REGISTER_OP_GPU_KERNEL( + fill_constant, + ops::FillConstantOpKernel); diff --git a/paddle/operators/fill_constant_op.h b/paddle/operators/fill_constant_op.h new file mode 100644 index 0000000000..53b8b548ec --- /dev/null +++ b/paddle/operators/fill_constant_op.h @@ -0,0 +1,37 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "paddle/framework/eigen.h" +#include "paddle/framework/op_registry.h" + +namespace paddle { +namespace operators { + +template +class FillConstantOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* out = ctx.Output("Out"); + out->mutable_data(ctx.GetPlace()); + auto value = ctx.Attr("value"); + + auto out_eigen = framework::EigenVector::Flatten(*out); + auto place = ctx.GetEigenDevice(); + out_eigen.device(place) = out_eigen.constant(static_cast(value)); + } +}; + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/v2/framework/tests/test_fill_constant_op.py b/python/paddle/v2/framework/tests/test_fill_constant_op.py new file mode 100644 index 0000000000..dff7b615aa --- /dev/null +++ b/python/paddle/v2/framework/tests/test_fill_constant_op.py @@ -0,0 +1,35 @@ +import unittest +import numpy as np +from op_test import OpTest + + +class TestFillConstantOp1(OpTest): + def setUp(self): + '''Test fill_constant op with specified value + ''' + self.op_type = "fill_constant" + + self.inputs = {} + self.attrs = {'shape': [123, 92], 'value': 3.8} + self.outputs = {'Out': np.full((123, 92), 3.8)} + + def test_check_output(self): + self.check_output() + + +class TestFillConstantOp2(OpTest): + def setUp(self): + '''Test fill_constant op with default value + ''' + self.op_type = "fill_constant" + + self.inputs = {} + self.attrs = {'shape': [123, 92]} + self.outputs = {'Out': np.full((123, 92), 0.0)} + + def test_check_output(self): + self.check_output() + + +if __name__ == "__main__": + unittest.main() From c876cacc7b7936f356b0c712dff8f32534eab2ae Mon Sep 17 00:00:00 2001 From: Luo Tao Date: Tue, 10 Oct 2017 10:24:18 +0800 Subject: [PATCH 18/26] =?UTF-8?q?fix=20compiler=20error:=20=E2=80=98unique?= =?UTF-8?q?=5Fptr=E2=80=99=20is=20not=20a=20member=20of=20=E2=80=98std?= =?UTF-8?q?=E2=80=99?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- paddle/framework/type_defs.h | 1 + 1 file changed, 1 insertion(+) diff --git a/paddle/framework/type_defs.h b/paddle/framework/type_defs.h index a5b9472213..6f65a942ba 100644 --- a/paddle/framework/type_defs.h +++ b/paddle/framework/type_defs.h @@ -15,6 +15,7 @@ #pragma once #include #include +#include #include "paddle/platform/variant.h" namespace paddle { From 6c6474cbd8514011b1c63d3439d49bd4700e46c8 Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Tue, 10 Oct 2017 10:32:19 +0800 Subject: [PATCH 19/26] follow coments --- paddle/operators/CMakeLists.txt | 15 +++---- paddle/operators/math/pooling.h | 23 ++++++----- paddle/operators/pool_with_index_op.cc | 57 +++++++++++++++----------- 3 files changed, 54 insertions(+), 41 deletions(-) diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt index 39af318ca5..31ae4b2cc1 100644 --- a/paddle/operators/CMakeLists.txt +++ b/paddle/operators/CMakeLists.txt @@ -55,12 +55,20 @@ function(op_library TARGET) set(pybind_flag 1) endif() + # pool_op contains several operators if ("${TARGET}" STREQUAL "pool_op") set(pybind_flag 1) # It's enough to just adding one operator to pybind file(APPEND ${pybind_file} "USE_OP(pool2d);\n") endif() + # pool_with_index_op contains several operators + if ("${TARGET}" STREQUAL "pool_with_index_op") + set(pybind_flag 1) + # It's enough to just adding one operator to pybind + file(APPEND ${pybind_file} "USE_OP(max_pool2d_with_index);\n") + endif() + # activation_op contains several operators if ("${TARGET}" STREQUAL "activation_op") set(pybind_flag 1) @@ -75,13 +83,6 @@ function(op_library TARGET) file(APPEND ${pybind_file} "USE_OP(reduce_sum);\n") endif() - # pool_with_index_op contains several operators - if ("${TARGET}" STREQUAL "pool_with_index_op") - set(pybind_flag 1) - # It's enough to just adding one operator to pybind - file(APPEND ${pybind_file} "USE_OP(max_pool2d_with_index);\n") - endif() - # pybind USE_NO_KERNEL_OP file(READ ${TARGET}.cc TARGET_CONTENT) string(REGEX MATCH "OperatorWithKernel" regex_result "${TARGET_CONTENT}") diff --git a/paddle/operators/math/pooling.h b/paddle/operators/math/pooling.h index f15ddca69a..c50c57b5c5 100644 --- a/paddle/operators/math/pooling.h +++ b/paddle/operators/math/pooling.h @@ -24,15 +24,16 @@ namespace math { #define FLT_MAX \ __FLT_MAX__ // It might need to be placed in another file, but I'm still - // wondering where to put it + // wondering where to put it. /* * \brief Extracting simple operations from pooling. - * Both MaxPool and AvgPool need initial, compute and finalize operation. + * Both MaxPool and AvgPool need "initial", "compute" and "finalize" + * operation. * MaxPool initializes temp variable to the negative maximum to find the * maximum value in the pooling field. * AvgPool initializes temp variable to the zero to accumulate all values - * in pool pooling, and takes the average. + * in pool pooling, and finally takes the average. * MaxPoolGrad and AvgPoolGrad are gradient operations respectively. */ template @@ -72,17 +73,17 @@ class AvgPoolGrad { /* * \brief Getting pooling results, and calculating gradient. * - * In pool2d, all tensors are in NCHW format. In pool3d, all tensors are in - * NCDHW format. + * In pool2d, all tensors are in NCHW format. Where N is batch size, C is the + * number of channels, H and W is the height and width of feature. + * In pool3d, all tensors are in NCDHW format. Where N is batch size, C is the + * number of channels, D, H and W is the depth, height and width of feature. * * In max pooling, it is possible that the pooling region has multiple maximum - * elements. - * In this case, we should compute the gradient of the first maximum element. + * elements. In this case, we should compute the gradient of the first maximum + * element. * This is different from average pooling. So we rewrite the max_pool_grad: * MaxPool2dGradFunctor, MaxPool3dGradFunctor. - * */ - template class Pool2dFunctor { public: @@ -146,10 +147,9 @@ class MaxPool3dGradFunctor { /* * \brief Getting max pooling results and corresponding max index, and * calculating gradient. - * In sub-sampling-pooling, it is necessary to know max element index. + * In up-sampling-pooling, it is necessary to know max element index. * In pool2d, all tensors are in NCHW format. In pool3d, all tensors are in * NCDHW format. - * */ template class MaxPool2dWithIndexFunctor { @@ -188,6 +188,7 @@ class MaxPool3dWithIndexGradFunctor { const framework::Tensor& mask, std::vector& ksize, std::vector& strides, std::vector& paddings); }; + } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/operators/pool_with_index_op.cc b/paddle/operators/pool_with_index_op.cc index 2e6a5f2555..ab933a3400 100644 --- a/paddle/operators/pool_with_index_op.cc +++ b/paddle/operators/pool_with_index_op.cc @@ -34,7 +34,7 @@ class MaxPoolWithIndexOp : public framework::OperatorWithKernel { PADDLE_ENFORCE(ctx->HasOutput("Out"), "Out(Output) of Pooling should not be null."); PADDLE_ENFORCE(ctx->HasOutput("Mask"), - "Out(Output) of Pooling should not be null."); + "Mask(Output) of Pooling should not be null."); auto in_x_dims = ctx->GetInputDim("X"); @@ -52,13 +52,11 @@ class MaxPoolWithIndexOp : public framework::OperatorWithKernel { } PADDLE_ENFORCE(in_x_dims.size() - ksize.size() == 2U, - "Pooling intput size and pooling size should be consistent"); - PADDLE_ENFORCE(ksize.size() == 2 || ksize.size() == 3, - "Pooling size size should be 2 elements. or 3 elements."); + "Intput size and pooling size should be consistent."); PADDLE_ENFORCE_EQ(ksize.size(), strides.size(), - "strides size and pooling size should be the same."); + "Strides size and pooling size should be the same."); PADDLE_ENFORCE_EQ(ksize.size(), paddings.size(), - "paddings size and pooling size should be the same."); + "Paddings size and pooling size should be the same."); std::vector output_shape({in_x_dims[0], in_x_dims[1]}); for (size_t i = 0; i < ksize.size(); ++i) { @@ -76,11 +74,9 @@ class MaxPoolWithIndexOpGrad : public framework::OperatorWithKernel { protected: void InferShape(framework::InferShapeContextBase *ctx) const override { - PADDLE_ENFORCE(ctx->HasInput("X"), - "X(Input) of Pooling should not be null."); - PADDLE_ENFORCE( - ctx->HasOutput(framework::GradVarName("X")), - "X@GRAD(Input@GRAD) of MaxPoolWithIndexOpGrad should not be null."); + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must not be null."); + PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")), + "Input(X@GRAD) should not be null."); ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); } }; @@ -110,9 +106,10 @@ class MaxPool2dWithIndexOpMaker : public framework::OpProtoAndCheckerMaker { AddAttr>( "ksize", - "Pooling size(height, width) of pooling operator." + "The pooling size(height, width) of pooling operator." "If globalPooling = true, ksize is ignored and need not be " - "specified."); // TODO(Add checker) + "specified."); // TODO(Chengduo): Add checker. (Currently, + // TypedAttrChecker don't support vector type.) AddAttr( "globalPooling", "Whether to use the globalPooling." @@ -123,15 +120,21 @@ class MaxPool2dWithIndexOpMaker : public framework::OpProtoAndCheckerMaker { AddAttr>("strides", "Strides(height, width) of pooling operator." "Default {1,1}.") - .SetDefault({1, 1}); // TODO(Add checker) + .SetDefault({1, 1}); // TODO(Chengduo): Add checker. (Currently, + // TypedAttrChecker don't support vector type.) AddAttr>("paddings", "Paddings(height, width) of pooling operator." "Default {0,0}.") - .SetDefault({0, 0}); // TODO(Add checker) + .SetDefault({0, 0}); // TODO(Chengduo): Add checker. (Currently, + // TypedAttrChecker don't support vector type.) AddComment(R"DOC( -The maxPooling2d with index operation calculates the output and the mask based on -the input and ksize, strides, paddings parameters. +The maxPooling2d with index operation calculates the output and the mask +based on the input and ksize, strides, paddings parameters. Input(X) and +output(Out, Mask) are in NCHW format. Where N is batch size, C is the +number of channels, H and W is the height and width of feature. +Parameters(ksize, strides, paddings) are two elements. +These two elements represent height and width, respectively. )DOC"); } }; @@ -162,9 +165,10 @@ class MaxPool3dWithIndexOpMaker : public framework::OpProtoAndCheckerMaker { AddAttr>( "ksize", - "Pooling size(depth, height, width) of pooling operator." + "The pooling size(depth, height, width) of pooling operator." "If globalPooling = true, ksize is ignored and need not be " - "specified."); // TODO(Add checker) + "specified."); // TODO(Chengduo): Add checker. (Currently, + // TypedAttrChecker don't support vector type.) AddAttr( "globalPooling", "Whether to use the globalPooling." @@ -176,19 +180,26 @@ class MaxPool3dWithIndexOpMaker : public framework::OpProtoAndCheckerMaker { "strides", "Strides(depth, height, width) of pooling operator." "Default {1,1,1}.") - .SetDefault({1, 1, 1}); // TODO(Add checker) + .SetDefault({1, 1, 1}); // TODO(Chengduo): Add checker. (Currently, + // TypedAttrChecker don't support vector type.) AddAttr>( "paddings", "Paddings(depth, height, width) of pooling operator." "Default {0,0,0}.") - .SetDefault({0, 0, 0}); // TODO(Add checker) + .SetDefault({0, 0, 0}); // TODO(Chengduo): Add checker. (Currently, + // TypedAttrChecker don't support vector type.) AddComment(R"DOC( -The maxpooling3d with index operation calculates the output and the mask based on -the input and ksize, strides, paddings parameters. +The maxpooling3d with index operation calculates the output and the mask +based on the input and ksize, strides, paddings parameters. +Input(X) and output(Out, Mask) are in NCDHW format. Where N is batch +size, C is the number of channels, D, H and W is the depth, height and +width of feature. Parameters(ksize, strides, paddings) are three elements. +These three elements represent depth, height and width, respectively. )DOC"); } }; + } // namespace operators } // namespace paddle From 0e1f21a57050df474a4bbf47f538ddf1b4ef4f61 Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Mon, 9 Oct 2017 20:43:29 -0700 Subject: [PATCH 20/26] Fix bug --- paddle/framework/op_desc.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/paddle/framework/op_desc.cc b/paddle/framework/op_desc.cc index 02aa74a842..c2e796b7c1 100644 --- a/paddle/framework/op_desc.cc +++ b/paddle/framework/op_desc.cc @@ -25,6 +25,7 @@ OpDescBind::OpDescBind(const std::string &type, const VariableNameMap &inputs, inputs_ = inputs; outputs_ = outputs; attrs_ = attrs; + need_update_ = true; } OpDesc *OpDescBind::Proto() { From 32cb74be3ebea9c9c59602576f45086934308789 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Mon, 9 Oct 2017 21:06:01 -0700 Subject: [PATCH 21/26] Removed unreached code --- paddle/framework/data_type.h | 1 - 1 file changed, 1 deletion(-) diff --git a/paddle/framework/data_type.h b/paddle/framework/data_type.h index 55e3931f87..649899d425 100644 --- a/paddle/framework/data_type.h +++ b/paddle/framework/data_type.h @@ -28,7 +28,6 @@ inline DataType ToDataType(std::type_index type) { return DataType::INT32; } else { PADDLE_THROW("Not supported"); - return static_cast(-1); } } From e21e5646a574b9e2fa299bacb3a8ee85472e84b5 Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Tue, 10 Oct 2017 13:55:27 +0800 Subject: [PATCH 22/26] fix atomicAdd -> CudaAtomicAdd --- paddle/operators/math/pooling.cu | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/paddle/operators/math/pooling.cu b/paddle/operators/math/pooling.cu index 06263737a9..4d50121de4 100644 --- a/paddle/operators/math/pooling.cu +++ b/paddle/operators/math/pooling.cu @@ -144,7 +144,7 @@ __global__ void KernelMaxPool2DGrad( if (maxIndex != -1) { // atomic add - atomicAdd(input_grad + maxIndex, output_grad[index]); + platform::CudaAtomicAdd(input_grad + maxIndex, output_grad[index]); } } } @@ -278,9 +278,7 @@ class MaxPool2dGradFunctor { }; template class MaxPool2dGradFunctor; -// template class MaxPool2dGradFunctor; // The -// 64-bit floating-point version of atomicAdd() is only supported by devices of -// compute capability 6.x and higher. +template class MaxPool2dGradFunctor; template class Pool2dFunctor, float>; @@ -453,7 +451,7 @@ __global__ void KernelMaxPool3DGrad( } if (maxIdx != -1) { // atomic add - atomicAdd(input_grad + maxIdx, output_grad[index]); + platform::CudaAtomicAdd(input_grad + maxIdx, output_grad[index]); } } } @@ -609,9 +607,7 @@ class MaxPool3dGradFunctor { }; template class MaxPool3dGradFunctor; -// template class MaxPool3dGradFunctor; // The -// 64-bit floating-point version of atomicAdd() is only supported by devices of -// compute capability 6.x and higher. +template class MaxPool3dGradFunctor; template class Pool3dFunctor, float>; From 871a3f6e76f57432d64b0410f49277a6e4f7d477 Mon Sep 17 00:00:00 2001 From: Luo Tao Date: Tue, 10 Oct 2017 15:18:02 +0800 Subject: [PATCH 23/26] remove unused PADDLE_ONLY_CPU comment --- paddle/math/tests/test_GpuProfiler.cpp | 2 +- paddle/memory/detail/buddy_allocator.cc | 2 +- paddle/memory/detail/system_allocator.cc | 2 +- paddle/memory/detail/system_allocator.h | 2 +- paddle/memory/detail/system_allocator_test.cc | 2 +- paddle/memory/memcpy.cc | 2 +- paddle/memory/memcpy.h | 2 +- paddle/memory/memory.cc | 2 +- paddle/memory/memory_test.cc | 2 +- paddle/platform/device_context.cc | 2 +- paddle/platform/enforce.h | 2 +- paddle/platform/gpu_info.h | 2 +- 12 files changed, 12 insertions(+), 12 deletions(-) diff --git a/paddle/math/tests/test_GpuProfiler.cpp b/paddle/math/tests/test_GpuProfiler.cpp index 9402bd3ec4..d9f146f0d1 100644 --- a/paddle/math/tests/test_GpuProfiler.cpp +++ b/paddle/math/tests/test_GpuProfiler.cpp @@ -162,4 +162,4 @@ int main(int argc, char** argv) { return RUN_ALL_TESTS(); } -#endif /* PADDLE_ONLY_CPU */ +#endif diff --git a/paddle/memory/detail/buddy_allocator.cc b/paddle/memory/detail/buddy_allocator.cc index fdc5ed19dc..e212f7737a 100644 --- a/paddle/memory/detail/buddy_allocator.cc +++ b/paddle/memory/detail/buddy_allocator.cc @@ -182,7 +182,7 @@ BuddyAllocator::PoolSet::iterator BuddyAllocator::RefillPool() { max_chunk_size_ = platform::GpuMaxChunkSize(); } } -#endif // PADDLE_ONLY_CPU +#endif // Allocate a new maximum sized block size_t index = 0; diff --git a/paddle/memory/detail/system_allocator.cc b/paddle/memory/detail/system_allocator.cc index 6c9a46dd09..33166d9ce2 100644 --- a/paddle/memory/detail/system_allocator.cc +++ b/paddle/memory/detail/system_allocator.cc @@ -134,7 +134,7 @@ void GPUAllocator::Free(void* p, size_t size, size_t index) { bool GPUAllocator::UseGpu() const { return true; } -#endif // PADDLE_ONLY_CPU +#endif } // namespace detail } // namespace memory diff --git a/paddle/memory/detail/system_allocator.h b/paddle/memory/detail/system_allocator.h index ee9b012f91..552cab4f96 100644 --- a/paddle/memory/detail/system_allocator.h +++ b/paddle/memory/detail/system_allocator.h @@ -51,7 +51,7 @@ class GPUAllocator : public SystemAllocator { size_t gpu_alloc_size_ = 0; size_t fallback_alloc_size_ = 0; }; -#endif // PADDLE_ONLY_CPU +#endif } // namespace detail } // namespace memory diff --git a/paddle/memory/detail/system_allocator_test.cc b/paddle/memory/detail/system_allocator_test.cc index cd563844e7..6a8558937b 100644 --- a/paddle/memory/detail/system_allocator_test.cc +++ b/paddle/memory/detail/system_allocator_test.cc @@ -62,4 +62,4 @@ TEST(GPUAllocator, Alloc) { TestAllocator(a, 2048); TestAllocator(a, 0); } -#endif // PADDLE_ONLY_CPU +#endif diff --git a/paddle/memory/memcpy.cc b/paddle/memory/memcpy.cc index 790420a8ab..1df88a6da9 100644 --- a/paddle/memory/memcpy.cc +++ b/paddle/memory/memcpy.cc @@ -89,7 +89,7 @@ void Copy(platform::GPUPlace dst_place, platform::GpuMemcpySync(dst, src, num, cudaMemcpyDeviceToDevice); } -#endif // PADDLE_ONLY_CPU +#endif } // namespace memory } // namespace paddle diff --git a/paddle/memory/memcpy.h b/paddle/memory/memcpy.h index 0bccee58c3..9b36182c2b 100644 --- a/paddle/memory/memcpy.h +++ b/paddle/memory/memcpy.h @@ -53,7 +53,7 @@ template void Copy(DstPlace, void* dst, SrcPlace, const void* src, size_t num, cudaStream_t stream); -#endif // PADDLE_ONLY_CPU +#endif } // namespace memory } // namespace paddle diff --git a/paddle/memory/memory.cc b/paddle/memory/memory.cc index 30ce8a82e1..5087c02385 100644 --- a/paddle/memory/memory.cc +++ b/paddle/memory/memory.cc @@ -111,7 +111,7 @@ size_t Used(platform::GPUPlace place) { return GetGPUBuddyAllocator(place.device)->Used(); } -#endif // PADDLE_ONLY_CPU +#endif } // namespace memory } // namespace paddle diff --git a/paddle/memory/memory_test.cc b/paddle/memory/memory_test.cc index 0d402038a0..2444931e26 100644 --- a/paddle/memory/memory_test.cc +++ b/paddle/memory/memory_test.cc @@ -135,4 +135,4 @@ TEST(BuddyAllocator, GPUMultAlloc) { } } -#endif // PADDLE_ONLY_CPU +#endif diff --git a/paddle/platform/device_context.cc b/paddle/platform/device_context.cc index a9b6b79903..36450e9268 100644 --- a/paddle/platform/device_context.cc +++ b/paddle/platform/device_context.cc @@ -136,7 +136,7 @@ cudnnHandle_t CUDADeviceContext::cudnn_handle() const { return cudnn_handle_; } cudaStream_t CUDADeviceContext::stream() const { return stream_; } -#endif // PADDLE_ONLY_CPU +#endif } // namespace platform } // namespace paddle diff --git a/paddle/platform/enforce.h b/paddle/platform/enforce.h index 15d8446cd8..cd906c3fa9 100644 --- a/paddle/platform/enforce.h +++ b/paddle/platform/enforce.h @@ -41,7 +41,7 @@ limitations under the License. */ #include #include -#endif // PADDLE_ONLY_CPU +#endif namespace paddle { namespace platform { diff --git a/paddle/platform/gpu_info.h b/paddle/platform/gpu_info.h index fb33db07bd..37665b97d7 100644 --- a/paddle/platform/gpu_info.h +++ b/paddle/platform/gpu_info.h @@ -63,4 +63,4 @@ void GpuMemcpyPeer(void *dst, int dst_device, const void *src, int src_device, } // namespace platform } // namespace paddle -#endif // PADDLE_ONLY_CPU +#endif From 36da82550af759fcfcaec571921851dd04bc4a3b Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Tue, 10 Oct 2017 16:28:24 +0800 Subject: [PATCH 24/26] Add code comments --- paddle/operators/math/pooling.cc | 54 ++++++++++++++++++++++++++++++-- paddle/operators/math/pooling.cu | 50 +++++++++++++++++++++++++++++ 2 files changed, 102 insertions(+), 2 deletions(-) diff --git a/paddle/operators/math/pooling.cc b/paddle/operators/math/pooling.cc index 5accde8b07..50cfb88bb5 100644 --- a/paddle/operators/math/pooling.cc +++ b/paddle/operators/math/pooling.cc @@ -18,6 +18,11 @@ namespace paddle { namespace operators { namespace math { +/* + * All tensors are in NCHW format. + * Ksize, strides, paddings are two elements. These two elements represent + * height and width, respectively. + */ template class Pool2dFunctor { public: @@ -73,6 +78,11 @@ class Pool2dFunctor { } }; +/* +* All tensors are in NCHW format. +* Ksize, strides, paddings are two elements. These two elements represent height +* and width, respectively. +*/ template class Pool2dGradFunctor { public: @@ -135,6 +145,11 @@ class Pool2dGradFunctor { } }; +/* + * All tensors are in NCHW format. + * Ksize, strides, paddings are two elements. These two elements represent + * height and width, respectively. + */ template class MaxPool2dGradFunctor { public: @@ -197,7 +212,7 @@ class MaxPool2dGradFunctor { }; template class MaxPool2dGradFunctor; -// template class MaxPool2dGradFunctor; +template class MaxPool2dGradFunctor; template class Pool2dFunctor, float>; @@ -216,6 +231,11 @@ template class Pool2dGradFunctor< template class Pool2dGradFunctor< platform::CPUPlace, paddle::operators::math::AvgPoolGrad, double>; +/* + * All tensors are in NCDHW format. + * Ksize, strides, paddings are three elements. These three elements represent + * depth, height and width, respectively. + */ template class Pool3dFunctor { public: @@ -286,6 +306,11 @@ class Pool3dFunctor { } }; +/* + * All tensors are in NCDHW format. + * Ksize, strides, paddings are three elements. These three elements represent + * depth, height and width, respectively. + */ template class Pool3dGradFunctor { public: @@ -364,6 +389,11 @@ class Pool3dGradFunctor { } }; +/* + * All tensors are in NCDHW format. + * Ksize, strides, paddings are three elements. These three elements represent + * depth, height and width, respectively. + */ template class MaxPool3dGradFunctor { public: @@ -440,7 +470,7 @@ class MaxPool3dGradFunctor { }; template class MaxPool3dGradFunctor; -// template class MaxPool3dGradFunctor; +template class MaxPool3dGradFunctor; template class Pool3dFunctor, float>; @@ -459,6 +489,11 @@ template class Pool3dGradFunctor< template class Pool3dGradFunctor< platform::CPUPlace, paddle::operators::math::AvgPoolGrad, double>; +/* + * All tensors are in NCHW format. + * Ksize, strides, paddings are two elements. These two elements represent + * height and width, respectively. + */ template class MaxPool2dWithIndexFunctor { public: @@ -519,6 +554,11 @@ class MaxPool2dWithIndexFunctor { } }; +/* + * All tensors are in NCHW format. + * Ksize, strides, paddings are two elements. These two elements represent + * height and width, respectively. + */ template class MaxPool2dWithIndexGradFunctor { public: @@ -563,6 +603,11 @@ template class MaxPool2dWithIndexGradFunctor; template class MaxPool2dWithIndexFunctor; template class MaxPool2dWithIndexGradFunctor; +/* + * All tensors are in NCDHW format. + * Ksize, strides, paddings are three elements. These three elements represent + * depth, height and width, respectively. + */ template class MaxPool3dWithIndexFunctor { public: @@ -637,6 +682,11 @@ class MaxPool3dWithIndexFunctor { } }; +/* + * All tensors are in NCDHW format. + * Ksize, strides, paddings are three elements. These three elements represent + * depth, height and width, respectively. + */ template class MaxPool3dWithIndexGradFunctor { public: diff --git a/paddle/operators/math/pooling.cu b/paddle/operators/math/pooling.cu index 4d50121de4..736327f4b7 100644 --- a/paddle/operators/math/pooling.cu +++ b/paddle/operators/math/pooling.cu @@ -149,6 +149,11 @@ __global__ void KernelMaxPool2DGrad( } } +/* + * All tensors are in NCHW format. + * Ksize, strides, paddings are two elements. These two elements represent + * height and width, respectively. + */ template class Pool2dFunctor { public: @@ -190,6 +195,11 @@ class Pool2dFunctor { } }; +/* + * All tensors are in NCHW format. + * Ksize, strides, paddings are two elements. These two elements represent + * height and width, respectively. + */ template class Pool2dGradFunctor { public: @@ -234,6 +244,11 @@ class Pool2dGradFunctor { } }; +/* + * All tensors are in NCHW format. + * Ksize, strides, paddings are two elements. These two elements represent + * height and width, respectively. + */ template class MaxPool2dGradFunctor { public: @@ -456,6 +471,11 @@ __global__ void KernelMaxPool3DGrad( } } +/* + * All tensors are in NCDHW format. + * Ksize, strides, paddings are three elements. These three elements represent + * depth, height and width, respectively. + */ template class Pool3dFunctor { public: @@ -504,6 +524,11 @@ class Pool3dFunctor { } }; +/* + * All tensors are in NCDHW format. + * Ksize, strides, paddings are three elements. These three elements represent + * depth, height and width, respectively. + */ template class Pool3dGradFunctor { public: @@ -556,6 +581,11 @@ class Pool3dGradFunctor { } }; +/* + * All tensors are in NCDHW format. + * Ksize, strides, paddings are three elements. These three elements represent + * depth, height and width, respectively. + */ template class MaxPool3dGradFunctor { public: @@ -709,6 +739,11 @@ __global__ void KernelMaxPool2DWithIdxGrad( } } +/* + * All tensors are in NCHW format. + * Ksize, strides, paddings are two elements. These two elements represent + * height and width, respectively. + */ template class MaxPool2dWithIndexFunctor { public: @@ -750,6 +785,11 @@ class MaxPool2dWithIndexFunctor { } }; +/* + * All tensors are in NCHW format. + * Ksize, strides, paddings are two elements. These two elements represent + * height and width, respectively. + */ template class MaxPool2dWithIndexGradFunctor { public: @@ -903,6 +943,11 @@ __global__ void KernelMaxPool3DWithIdxGrad( } } +/* + * All tensors are in NCDHW format. + * Ksize, strides, paddings are three elements. These three elements represent + * depth, height and width, respectively. + */ template class MaxPool3dWithIndexFunctor { public: @@ -951,6 +996,11 @@ class MaxPool3dWithIndexFunctor { } }; +/* + * All tensors are in NCDHW format. + * Ksize, strides, paddings are three elements. These three elements represent + * depth, height and width, respectively. + */ template class MaxPool3dWithIndexGradFunctor { public: From f2e7cf21415fbdc0ae2f34b88b6cf307b37966f0 Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Tue, 10 Oct 2017 22:08:33 +0800 Subject: [PATCH 25/26] fix InferShapeContextBase to InferShapeContext --- paddle/operators/pool_with_index_op.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/paddle/operators/pool_with_index_op.cc b/paddle/operators/pool_with_index_op.cc index ab933a3400..7b6afcfd1f 100644 --- a/paddle/operators/pool_with_index_op.cc +++ b/paddle/operators/pool_with_index_op.cc @@ -28,7 +28,7 @@ class MaxPoolWithIndexOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase *ctx) const override { + void InferShape(framework::InferShapeContext *ctx) const override { PADDLE_ENFORCE(ctx->HasInput("X"), "X(Input) of Pooling should not be null."); PADDLE_ENFORCE(ctx->HasOutput("Out"), @@ -73,7 +73,7 @@ class MaxPoolWithIndexOpGrad : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase *ctx) const override { + void InferShape(framework::InferShapeContext *ctx) const override { PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must not be null."); PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")), "Input(X@GRAD) should not be null."); From a281b38393597e9c6342d365b3e0b7371194b97e Mon Sep 17 00:00:00 2001 From: Markus Kliegl Date: Tue, 10 Oct 2017 10:53:02 -0700 Subject: [PATCH 26/26] Conv Shift Operator (#4591) * conv_shift_op: initial implementation using Eigen Limitations: - both gradient outputs must be specified and are always computed - explicit for loops => could be optimized in various ways (e.g., different memory layout) * conv shift - gradient fixes fix case when not all output gradients desired * conv shift: minor cleanup * conv shift - more minor cleanup * conv shift: clean up & initial GPU implementation * fix rebase issue --- paddle/operators/conv_shift_op.cc | 206 ++++++++++++++++++ paddle/operators/conv_shift_op.cu | 194 +++++++++++++++++ paddle/operators/conv_shift_op.h | 33 +++ .../v2/framework/tests/test_conv_shift_op.py | 47 ++++ 4 files changed, 480 insertions(+) create mode 100644 paddle/operators/conv_shift_op.cc create mode 100644 paddle/operators/conv_shift_op.cu create mode 100644 paddle/operators/conv_shift_op.h create mode 100644 python/paddle/v2/framework/tests/test_conv_shift_op.py diff --git a/paddle/operators/conv_shift_op.cc b/paddle/operators/conv_shift_op.cc new file mode 100644 index 0000000000..e1e321ed5f --- /dev/null +++ b/paddle/operators/conv_shift_op.cc @@ -0,0 +1,206 @@ +/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/operators/conv_shift_op.h" +#include "paddle/framework/eigen.h" + +namespace paddle { +namespace operators { + +using framework::Tensor; +template +using EigenMatrix = framework::EigenMatrix; + +class ConvShiftOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should be not null."); + PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should be not null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) should be not null."); + + auto x_dims = ctx->GetInputDim("X"); + auto y_dims = ctx->GetInputDim("Y"); + PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank should be 2."); + PADDLE_ENFORCE_EQ(y_dims.size(), 2, "Input(Y)'s rank should be 2."); + PADDLE_ENFORCE_EQ(x_dims[0], y_dims[0], + "The 1st dimension of Input(X) and Input(Y) should " + "be equal."); + PADDLE_ENFORCE_EQ(y_dims[1] % 2, 1, + "The 2nd dimension of Input(Y) should be odd."); + PADDLE_ENFORCE_LE(y_dims[1], x_dims[1], + "The 2nd dimension of Input(Y) should be less than or " + "equal to the 2nd dimension of Input(X)."); + ctx->SetOutputDim("Out", x_dims); + ctx->ShareLoD("X", /*->*/ "Out"); + } +}; + +class ConvShiftGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should be not null."); + PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should be not null."); + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), + "Input(Out@GRAD) should be not null."); + + auto x_grad_name = framework::GradVarName("X"); + if (ctx->HasOutput(x_grad_name)) { + auto x_dims = ctx->GetInputDim("X"); + ctx->SetOutputDim(x_grad_name, x_dims); + } + + auto y_grad_name = framework::GradVarName("Y"); + if (ctx->HasOutput(y_grad_name)) { + auto y_dims = ctx->GetInputDim("Y"); + ctx->SetOutputDim(y_grad_name, y_dims); + } + } +}; + +class ConvShiftOpMaker : public framework::OpProtoAndCheckerMaker { + public: + ConvShiftOpMaker(framework::OpProto *proto, + framework::OpAttrChecker *op_checker) + : framework::OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", + "(Tensor, default Tensor), a 2-D tensor with shape B x M, " + "where B is the batch size and M is the data dimension."); + AddInput("Y", + "(Tensor, default Tensor), a 2-D tensor with shape B x N, " + "where B is the batch size and N is the data dimension. N must " + "be odd."); + AddOutput("Out", + "(Tensor, default Tensor), a 2-D tensor with shape B x M, " + "i.e., the same shape as X."); + AddComment(R"DOC( +ConvShift Operator. + +A layer for circular convolution of two vectors, +as used in the Neural Turing Machine: https://arxiv.org/abs/1410.5401 + +The equation is: + + \f[ + Out[i] = \sum_{j=-(N-1)/2}^{(N-1)/2} X_{i+j} * Y_{j} + \f] + +where X's index is computed modulo M, and b's index is computed modulo N. + +Both of the input `X` and `Y` can carry LoD (Level of Details) information. +However, the output only shares the LoD information with input `X`. +)DOC"); + } +}; + +template +class ConvShiftKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &context) const override { + auto *X = context.Input("X"); + auto *Y = context.Input("Y"); + auto *Out = context.Output("Out"); + Out->mutable_data(context.GetPlace()); + + auto x = EigenMatrix::From(*X); + auto y = EigenMatrix::From(*Y); + auto out = EigenMatrix::From(*Out); + out.setZero(); + + size_t batch_size = X->dims()[0]; + size_t x_width = X->dims()[1]; + size_t y_width = Y->dims()[1]; + size_t y_half_width = (y_width - 1) / 2; + + for (size_t k = 0; k < batch_size; ++k) { + for (size_t i = 0; i < x_width; ++i) { + for (size_t j = 0; j < y_width; ++j) { + int index = (i + j - y_half_width + x_width) % x_width; + out(k, i) += x(k, index) * y(k, j); + } + } + } + } +}; + +template +class ConvShiftGradKernel + : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &context) const override { + auto *X = context.Input("X"); + auto *Y = context.Input("Y"); + auto *dOut = context.Input(framework::GradVarName("Out")); + auto *dX = context.Output(framework::GradVarName("X")); + auto *dY = context.Output(framework::GradVarName("Y")); + + auto x = EigenMatrix::From(*X); + auto y = EigenMatrix::From(*Y); + auto dout = EigenMatrix::From(*dOut); + + auto x_dims = X->dims(); + auto y_dims = Y->dims(); + size_t batch_size = x_dims[0]; + size_t x_width = x_dims[1]; + size_t y_width = y_dims[1]; + size_t y_half_width = (y_width - 1) / 2; + + // The below trades code duplication for efficiency (keeping the if + // statement outside of the loop). + if (dX) { + dX->mutable_data(context.GetPlace()); + auto dx = EigenMatrix::From(*dX); + dx.setZero(); + for (size_t k = 0; k < batch_size; ++k) { + for (size_t i = 0; i < x_width; ++i) { + for (size_t j = 0; j < y_width; ++j) { + int index = (i + j - y_half_width + x_width) % x_width; + dx(k, index) += dout(k, i) * y(k, j); + } + } + } + } + + if (dY) { + dY->mutable_data(context.GetPlace()); + auto dy = EigenMatrix::From(*dY); + dy.setZero(); + for (size_t k = 0; k < batch_size; ++k) { + for (size_t i = 0; i < x_width; ++i) { + for (size_t j = 0; j < y_width; ++j) { + int index = (i + j - y_half_width + x_width) % x_width; + dy(k, j) += x(k, index) * dout(k, i); + } + } + } + } + } +}; +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP(conv_shift, ops::ConvShiftOp, ops::ConvShiftOpMaker, + conv_shift_grad, ops::ConvShiftGradOp); +REGISTER_OP_CPU_KERNEL(conv_shift, + ops::ConvShiftKernel); +REGISTER_OP_CPU_KERNEL( + conv_shift_grad, + ops::ConvShiftGradKernel); diff --git a/paddle/operators/conv_shift_op.cu b/paddle/operators/conv_shift_op.cu new file mode 100644 index 0000000000..145e966fe9 --- /dev/null +++ b/paddle/operators/conv_shift_op.cu @@ -0,0 +1,194 @@ +/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/operators/conv_shift_op.h" +#include "paddle/platform/cuda_helper.h" + +namespace paddle { +namespace operators { + +using framework::Tensor; + +namespace { + +inline int div_up(int x, int y) { return (x + y - 1) / y; } + +// Some notes on the design: +// +// Each thread is responsible for computing a single output out[k, i]. +// Thread blocks are based on tiles of x with height 1 in the batch dimension. +// +// This design is based on the typical use case where the filter +// y is fairly small. For large y, it would probably be more efficient +// to also tile across y. +template +__global__ void conv_shift_forward(const T *x, const T *y, T *out, int x_width, + int y_width, int y_half_width, + int batch_size) { + extern __shared__ T mem[]; + + int tx = threadIdx.x; + int i = blockIdx.x * blockDim.x + tx; // global x index + int k = blockIdx.y; // batch index + + // Check if we are in a boundary block with fewer x's to process than + // blockDim.x. + int num_x = + (blockIdx.x == gridDim.x - 1) ? (x_width % blockDim.x) : blockDim.x; + + T *sx = mem; + T *sx_pad = &mem[num_x]; + T *sy = &mem[blockDim.x + y_width]; + + // Collaboratively load y[k, :] and length-y padding of x into shared memory. + int pad_start = blockIdx.x * blockDim.x + num_x + x_width - y_half_width; + for (int j = tx; j < y_width; j += blockDim.x) { + sy[j] = y[k * y_width + j]; + sx_pad[j] = x[k * x_width + (pad_start + j) % x_width]; + } + + // Load a cyclically shifted slice of x into shared memory. + if (tx < num_x) { + int load_i = (i - y_half_width + x_width) % x_width; + sx[tx] = x[k * x_width + load_i]; + } else { + return; + } + __syncthreads(); + + // Compute dot product of sx[tx:tx + y_width] and sy. + T sum = 0; + for (int j = 0; j < y_width; ++j) { + sum += sx[tx + j] * sy[j]; + } + + // Save to out[k, i]. + out[k * x_width + i] = sum; +} + +// Compute x gradient - initial naive implementation with atomic add. +template +__global__ void conv_shift_dx(const T *dout, const T *y, T *dx, int x_width, + int y_width, int y_half_width, int batch_size) { + int i = blockIdx.x * blockDim.x + threadIdx.x; // x index + int j = blockIdx.y; // y index + int k = blockIdx.z; // batch index + + if (i < x_width) { + int index = (i + j - y_half_width + x_width) % x_width; + atomicAdd(&dx[k * x_width + index], + dout[k * x_width + i] * y[k * y_width + j]); + } +} + +// Compute y gradient - initial naive implementation with atomic add. +template +__global__ void conv_shift_dy(const T *x, const T *dout, T *dy, int x_width, + int y_width, int y_half_width, int batch_size) { + int i = blockIdx.x * blockDim.x + threadIdx.x; // x index + int j = blockIdx.y; // y index + int k = blockIdx.z; // batch index + + if (i < x_width) { + int index = (i + j - y_half_width + x_width) % x_width; + atomicAdd(&dy[k * y_width + j], + x[k * x_width + index] * dout[k * x_width + i]); + } +} +} // namespace + +template +class ConvShiftKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &context) const override { + const Tensor *X = context.Input("X"); + const Tensor *Y = context.Input("Y"); + Tensor *Out = context.Output("Out"); + const T *x_data = X->data(); + const T *y_data = Y->data(); + T *out_data = Out->mutable_data(context.GetPlace()); + + int batch_size = X->dims()[0]; + int x_width = X->dims()[1]; + int y_width = Y->dims()[1]; + int y_half_width = (y_width - 1) / 2; + + const int x_per_block = 256; + int num_x_blocks = div_up(x_width, x_per_block); + int mem_per_block = (x_per_block + 2 * y_width) * sizeof(T); + + dim3 grid_dim(num_x_blocks, batch_size); + + auto stream = reinterpret_cast( + context.device_context()) + .stream(); + + conv_shift_forward<<>>( + x_data, y_data, out_data, x_width, y_width, y_half_width, batch_size); + } +}; + +template +class ConvShiftGradKernel + : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &context) const override { + const Tensor *X = context.Input("X"); + const Tensor *Y = context.Input("Y"); + const Tensor *dOut = context.Input(framework::GradVarName("Out")); + const T *x_data = X->data(); + const T *y_data = Y->data(); + const T *dout_data = dOut->data(); + + Tensor *dX = context.Output(framework::GradVarName("X")); + Tensor *dY = context.Output(framework::GradVarName("Y")); + + int batch_size = X->dims()[0]; + int x_width = X->dims()[1]; + int y_width = Y->dims()[1]; + int y_half_width = (y_width - 1) / 2; + + auto stream = reinterpret_cast( + context.device_context()) + .stream(); + + const int x_per_block = 256; + int num_x_blocks = div_up(x_width, x_per_block); + dim3 grid_dim(num_x_blocks, y_width, batch_size); + + if (dX) { + T *dx_data = dX->mutable_data(context.GetPlace()); + cudaMemsetAsync(dx_data, 0, dX->numel() * sizeof(T), stream); + conv_shift_dx<<>>( + dout_data, y_data, dx_data, x_width, y_width, y_half_width, + batch_size); + } + if (dY) { + T *dy_data = dY->mutable_data(context.GetPlace()); + cudaMemsetAsync(dy_data, 0, dY->numel() * sizeof(T), stream); + conv_shift_dy<<>>( + x_data, dout_data, dy_data, x_width, y_width, y_half_width, + batch_size); + } + } +}; +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_GPU_KERNEL(conv_shift, + ops::ConvShiftKernel); +REGISTER_OP_GPU_KERNEL( + conv_shift_grad, + ops::ConvShiftGradKernel); diff --git a/paddle/operators/conv_shift_op.h b/paddle/operators/conv_shift_op.h new file mode 100644 index 0000000000..5a160b0f16 --- /dev/null +++ b/paddle/operators/conv_shift_op.h @@ -0,0 +1,33 @@ +/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#pragma once +#include "paddle/framework/op_registry.h" + +namespace paddle { +namespace operators { + +template +class ConvShiftKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &context) const override; +}; + +template +class ConvShiftGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &context) const override; +}; +} // namespace operators +} // namespace paddle diff --git a/python/paddle/v2/framework/tests/test_conv_shift_op.py b/python/paddle/v2/framework/tests/test_conv_shift_op.py new file mode 100644 index 0000000000..b9ab21a06a --- /dev/null +++ b/python/paddle/v2/framework/tests/test_conv_shift_op.py @@ -0,0 +1,47 @@ +import unittest +import numpy as np +from op_test import OpTest + + +def conv_shift_forward(x, y): + out = np.zeros_like(x) + M = x.shape[1] + N = y.shape[1] + y_half_width = (N - 1) / 2 + for i in xrange(M): + for j in xrange(N): + out[:, i] += x[:, (i + j + M - y_half_width) % M] * y[:, j] + return out + + +class TestConvShiftOp(OpTest): + def setUp(self): + self.op_type = "conv_shift" + + batch_size = 4 + x_dim = 17 + y_dim = 3 # must be odd and <= x_dim + x = np.random.random((batch_size, x_dim)).astype("float32") + y = np.random.random((batch_size, y_dim)).astype("float32") + self.inputs = {'X': x, 'Y': y} + + out = conv_shift_forward(x, y) + self.outputs = {'Out': out} + + def test_check_output(self): + self.check_output() + + def test_check_grad_normal(self): + self.check_grad(['X', 'Y'], 'Out', max_relative_error=0.05) + + def test_check_grad_ignore_x(self): + self.check_grad( + ['Y'], 'Out', max_relative_error=0.05, no_grad_set=set("X")) + + def test_check_grad_ignore_y(self): + self.check_grad( + ['X'], 'Out', max_relative_error=0.05, no_grad_set=set('Y')) + + +if __name__ == '__main__': + unittest.main()