|
|
|
@ -29,7 +29,7 @@ __global__ void KernelPool2D(const int nthreads, const T* input_data,
|
|
|
|
|
const int ksize_width, const int stride_height,
|
|
|
|
|
const int stride_width, const int padding_height,
|
|
|
|
|
const int padding_width, PoolProcess pool_process,
|
|
|
|
|
T* output_data) {
|
|
|
|
|
bool exclusive, T* output_data) {
|
|
|
|
|
for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
|
|
|
|
|
index += blockDim.x * gridDim.x) {
|
|
|
|
|
int pw = index % output_width;
|
|
|
|
@ -52,7 +52,8 @@ __global__ void KernelPool2D(const int nthreads, const T* input_data,
|
|
|
|
|
pool_process.compute(input_data[h * input_width + w], &ele);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
int pool_size = (hend - hstart) * (wend - wstart);
|
|
|
|
|
int pool_size = exclusive ? (hend - hstart) * (wend - wstart)
|
|
|
|
|
: ksize_height * ksize_width;
|
|
|
|
|
pool_process.finalize(static_cast<T>(pool_size), &ele);
|
|
|
|
|
output_data[index] = ele;
|
|
|
|
|
}
|
|
|
|
@ -65,7 +66,7 @@ __global__ void KernelPool2DGrad(
|
|
|
|
|
const int input_width, const int output_height, const int output_width,
|
|
|
|
|
const int ksize_height, const int ksize_width, const int stride_height,
|
|
|
|
|
const int stride_width, const int padding_height, const int padding_width,
|
|
|
|
|
PoolProcess pool_process, T* input_grad) {
|
|
|
|
|
PoolProcess pool_process, bool exclusive, T* input_grad) {
|
|
|
|
|
for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
|
|
|
|
|
index += blockDim.x * gridDim.x) {
|
|
|
|
|
int offsetW = index % input_width + padding_width;
|
|
|
|
@ -95,7 +96,8 @@ __global__ void KernelPool2DGrad(
|
|
|
|
|
int wend = min(wstart + ksize_width, input_width);
|
|
|
|
|
hstart = max(hstart, 0);
|
|
|
|
|
wstart = max(wstart, 0);
|
|
|
|
|
int pool_size = (hend - hstart) * (wend - wstart);
|
|
|
|
|
int pool_size = exclusive ? (hend - hstart) * (wend - wstart)
|
|
|
|
|
: ksize_height * ksize_width;
|
|
|
|
|
int output_sub_idx = ph * output_width + pw;
|
|
|
|
|
pool_process.compute(input, output_data[output_sub_idx],
|
|
|
|
|
output_grad[output_sub_idx],
|
|
|
|
@ -163,7 +165,7 @@ class Pool2dFunctor<platform::CUDADeviceContext, PoolProcess, T> {
|
|
|
|
|
const framework::Tensor& input, const std::vector<int>& ksize,
|
|
|
|
|
const std::vector<int>& strides,
|
|
|
|
|
const std::vector<int>& paddings, PoolProcess pool_process,
|
|
|
|
|
framework::Tensor* output) {
|
|
|
|
|
bool exclusive, framework::Tensor* output) {
|
|
|
|
|
const int batch_size = input.dims()[0];
|
|
|
|
|
const int input_channels = input.dims()[1];
|
|
|
|
|
const int input_height = input.dims()[2];
|
|
|
|
@ -189,7 +191,8 @@ class Pool2dFunctor<platform::CUDADeviceContext, PoolProcess, T> {
|
|
|
|
|
KernelPool2D<PoolProcess, T><<<grid, threads, 0, context.stream()>>>(
|
|
|
|
|
nthreads, input_data, input_channels, input_height, input_width,
|
|
|
|
|
output_height, output_width, ksize_height, ksize_width, stride_height,
|
|
|
|
|
stride_width, padding_height, padding_width, pool_process, output_data);
|
|
|
|
|
stride_width, padding_height, padding_width, pool_process, exclusive,
|
|
|
|
|
output_data);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
@ -208,7 +211,7 @@ class Pool2dGradFunctor<platform::CUDADeviceContext, PoolProcess, T> {
|
|
|
|
|
const std::vector<int>& ksize,
|
|
|
|
|
const std::vector<int>& strides,
|
|
|
|
|
const std::vector<int>& paddings, PoolProcess pool_process,
|
|
|
|
|
framework::Tensor* input_grad) {
|
|
|
|
|
bool exclusive, framework::Tensor* input_grad) {
|
|
|
|
|
const int batch_size = input.dims()[0];
|
|
|
|
|
const int input_channels = input.dims()[1];
|
|
|
|
|
const int input_height = input.dims()[2];
|
|
|
|
@ -236,7 +239,7 @@ class Pool2dGradFunctor<platform::CUDADeviceContext, PoolProcess, T> {
|
|
|
|
|
nthreads, input_data, output_data, output_grad_data, input_channels,
|
|
|
|
|
input_height, input_width, output_height, output_width, ksize_height,
|
|
|
|
|
ksize_width, stride_height, stride_width, padding_height, padding_width,
|
|
|
|
|
pool_process, input_grad_data);
|
|
|
|
|
pool_process, exclusive, input_grad_data);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
@ -313,16 +316,14 @@ template class Pool2dGradFunctor<platform::CUDADeviceContext,
|
|
|
|
|
double>;
|
|
|
|
|
|
|
|
|
|
template <typename PoolProcess, typename T>
|
|
|
|
|
__global__ void KernelPool3D(const int nthreads, const T* input_data,
|
|
|
|
|
const int channels, const int input_depth,
|
|
|
|
|
const int input_height, const int input_width,
|
|
|
|
|
const int output_depth, const int output_height,
|
|
|
|
|
const int output_width, const int ksize_depth,
|
|
|
|
|
const int ksize_height, const int ksize_width,
|
|
|
|
|
const int stride_depth, const int stride_height,
|
|
|
|
|
const int stride_width, const int padding_depth,
|
|
|
|
|
const int padding_height, const int padding_width,
|
|
|
|
|
PoolProcess pool_process, T* output_data) {
|
|
|
|
|
__global__ void KernelPool3D(
|
|
|
|
|
const int nthreads, const T* input_data, const int channels,
|
|
|
|
|
const int input_depth, const int input_height, const int input_width,
|
|
|
|
|
const int output_depth, const int output_height, const int output_width,
|
|
|
|
|
const int ksize_depth, const int ksize_height, const int ksize_width,
|
|
|
|
|
const int stride_depth, const int stride_height, const int stride_width,
|
|
|
|
|
const int padding_depth, const int padding_height, const int padding_width,
|
|
|
|
|
PoolProcess pool_process, bool exclusive, T* output_data) {
|
|
|
|
|
for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
|
|
|
|
|
index += blockDim.x * gridDim.x) {
|
|
|
|
|
int pw = index % output_width;
|
|
|
|
@ -351,7 +352,9 @@ __global__ void KernelPool3D(const int nthreads, const T* input_data,
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
int pool_size = (dend - dstart) * (hend - hstart) * (wend - wstart);
|
|
|
|
|
int pool_size = exclusive ?
|
|
|
|
|
(dend - dstart) * (hend - hstart) * (wend - wstart)
|
|
|
|
|
: ksize_depth * ksize_height * ksize_width;
|
|
|
|
|
pool_process.finalize(static_cast<T>(pool_size), &ele);
|
|
|
|
|
output_data[index] = ele;
|
|
|
|
|
}
|
|
|
|
@ -366,7 +369,7 @@ __global__ void KernelPool3DGrad(
|
|
|
|
|
const int ksize_height, const int ksize_width, const int stride_depth,
|
|
|
|
|
const int stride_height, const int stride_width, const int padding_depth,
|
|
|
|
|
const int padding_height, const int padding_width, PoolProcess pool_process,
|
|
|
|
|
T* input_grad) {
|
|
|
|
|
bool exclusive, T* input_grad) {
|
|
|
|
|
for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
|
|
|
|
|
index += blockDim.x * gridDim.x) {
|
|
|
|
|
int offsetW = index % input_width + padding_width;
|
|
|
|
@ -409,7 +412,9 @@ __global__ void KernelPool3DGrad(
|
|
|
|
|
dstart = max(dstart, 0);
|
|
|
|
|
hstart = max(hstart, 0);
|
|
|
|
|
wstart = max(wstart, 0);
|
|
|
|
|
int pool_size = (dend - dstart) * (hend - hstart) * (wend - wstart);
|
|
|
|
|
int pool_size = exclusive ?
|
|
|
|
|
(dend - dstart) * (hend - hstart) * (wend - wstart)
|
|
|
|
|
: ksize_depth * ksize_height * ksize_width;
|
|
|
|
|
int output_sub_idx = (pd * output_height + ph) * output_width + pw;
|
|
|
|
|
pool_process.compute(input, output_data[output_sub_idx],
|
|
|
|
|
output_grad[output_sub_idx],
|
|
|
|
@ -484,7 +489,7 @@ class Pool3dFunctor<platform::CUDADeviceContext, PoolProcess, T> {
|
|
|
|
|
const framework::Tensor& input, const std::vector<int>& ksize,
|
|
|
|
|
const std::vector<int>& strides,
|
|
|
|
|
const std::vector<int>& paddings, PoolProcess pool_process,
|
|
|
|
|
framework::Tensor* output) {
|
|
|
|
|
bool exclusive, framework::Tensor* output) {
|
|
|
|
|
const int batch_size = input.dims()[0];
|
|
|
|
|
const int input_channels = input.dims()[1];
|
|
|
|
|
const int input_depth = input.dims()[2];
|
|
|
|
@ -518,7 +523,7 @@ class Pool3dFunctor<platform::CUDADeviceContext, PoolProcess, T> {
|
|
|
|
|
input_width, output_depth, output_height, output_width, ksize_depth,
|
|
|
|
|
ksize_height, ksize_width, stride_depth, stride_height, stride_width,
|
|
|
|
|
padding_depth, padding_height, padding_width, pool_process,
|
|
|
|
|
output_data);
|
|
|
|
|
exclusive, output_data);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
@ -537,7 +542,7 @@ class Pool3dGradFunctor<platform::CUDADeviceContext, PoolProcess, T> {
|
|
|
|
|
const std::vector<int>& ksize,
|
|
|
|
|
const std::vector<int>& strides,
|
|
|
|
|
const std::vector<int>& paddings, PoolProcess pool_process,
|
|
|
|
|
framework::Tensor* input_grad) {
|
|
|
|
|
bool exclusive, framework::Tensor* input_grad) {
|
|
|
|
|
const int batch_size = input.dims()[0];
|
|
|
|
|
const int input_channels = input.dims()[1];
|
|
|
|
|
const int input_depth = input.dims()[2];
|
|
|
|
@ -573,7 +578,7 @@ class Pool3dGradFunctor<platform::CUDADeviceContext, PoolProcess, T> {
|
|
|
|
|
input_depth, input_height, input_width, output_depth, output_height,
|
|
|
|
|
output_width, ksize_depth, ksize_height, ksize_width, stride_depth,
|
|
|
|
|
stride_height, stride_width, padding_depth, padding_height,
|
|
|
|
|
padding_width, pool_process, input_grad_data);
|
|
|
|
|
padding_width, pool_process, exclusive, input_grad_data);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|