|
|
|
@ -21,11 +21,12 @@ namespace math {
|
|
|
|
|
|
|
|
|
|
template <class T>
|
|
|
|
|
__global__ void vol2col(int num_kernels, const T* data_vol, int depth,
|
|
|
|
|
int height, int width, int filter_depth,
|
|
|
|
|
int filter_height, int filter_width, int stride_depth,
|
|
|
|
|
int stride_height, int stride_width, int padding_depth,
|
|
|
|
|
int padding_height, int padding_width, int output_detph,
|
|
|
|
|
int output_height, int output_width, T* data_col) {
|
|
|
|
|
int height, int width, int dilation_d, int dilation_h,
|
|
|
|
|
int dilation_w, int filter_depth, int filter_height,
|
|
|
|
|
int filter_width, int stride_depth, int stride_height,
|
|
|
|
|
int stride_width, int padding_depth, int padding_height,
|
|
|
|
|
int padding_width, int output_detph, int output_height,
|
|
|
|
|
int output_width, T* data_col) {
|
|
|
|
|
for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < num_kernels;
|
|
|
|
|
index += blockDim.x * gridDim.x) {
|
|
|
|
|
int w_out = index % output_width;
|
|
|
|
@ -44,12 +45,14 @@ __global__ void vol2col(int num_kernels, const T* data_vol, int depth,
|
|
|
|
|
for (int k = 0; k < filter_depth; ++k) {
|
|
|
|
|
for (int i = 0; i < filter_height; ++i) {
|
|
|
|
|
for (int j = 0; j < filter_width; ++j) {
|
|
|
|
|
int d = d_in + k;
|
|
|
|
|
int h = h_in + i;
|
|
|
|
|
int w = w_in + j;
|
|
|
|
|
int d = d_in + k * dilation_d;
|
|
|
|
|
int h = h_in + i * dilation_h;
|
|
|
|
|
int w = w_in + j * dilation_w;
|
|
|
|
|
int col_idx = (k * dilation_d * height + i * dilation_h) * width +
|
|
|
|
|
j * dilation_w;
|
|
|
|
|
*data_col = (d >= 0 && d < depth && h >= 0 && h < height && w >= 0 &&
|
|
|
|
|
w < width)
|
|
|
|
|
? data_vol[(k * height + i) * width + j]
|
|
|
|
|
? data_vol[col_idx]
|
|
|
|
|
: 0;
|
|
|
|
|
data_col += output_detph * output_height * output_width;
|
|
|
|
|
}
|
|
|
|
@ -69,6 +72,7 @@ class Vol2ColFunctor<platform::GPUPlace, T> {
|
|
|
|
|
public:
|
|
|
|
|
void operator()(const platform::DeviceContext& context,
|
|
|
|
|
const framework::Tensor& vol, framework::Tensor& col,
|
|
|
|
|
int dilation_d, int dilation_h, int dilation_w,
|
|
|
|
|
int stride_depth, int stride_height, int stride_width,
|
|
|
|
|
int padding_depth, int padding_height,
|
|
|
|
|
int padding_width) const {
|
|
|
|
@ -86,6 +90,28 @@ class Vol2ColFunctor<platform::GPUPlace, T> {
|
|
|
|
|
int output_height = col.dims()[5];
|
|
|
|
|
int output_width = col.dims()[6];
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ((input_depth + 2 * padding_depth -
|
|
|
|
|
((dilation_d * (filter_depth - 1) + 1))) /
|
|
|
|
|
stride_depth +
|
|
|
|
|
1,
|
|
|
|
|
output_depth,
|
|
|
|
|
"input_depth and output_depth are "
|
|
|
|
|
"Mismatching.");
|
|
|
|
|
PADDLE_ENFORCE_EQ((input_height + 2 * padding_height -
|
|
|
|
|
((dilation_h * (filter_height - 1) + 1))) /
|
|
|
|
|
stride_height +
|
|
|
|
|
1,
|
|
|
|
|
output_height,
|
|
|
|
|
"input_height and output_height are "
|
|
|
|
|
"Mismatching.");
|
|
|
|
|
PADDLE_ENFORCE_EQ((input_width + 2 * padding_width -
|
|
|
|
|
((dilation_w * (filter_width - 1) + 1))) /
|
|
|
|
|
stride_width +
|
|
|
|
|
1,
|
|
|
|
|
output_width,
|
|
|
|
|
"input_width and output_width are "
|
|
|
|
|
"Mismatching.");
|
|
|
|
|
|
|
|
|
|
int num_outputs =
|
|
|
|
|
input_channels * output_depth * output_height * output_width;
|
|
|
|
|
|
|
|
|
@ -95,19 +121,25 @@ class Vol2ColFunctor<platform::GPUPlace, T> {
|
|
|
|
|
reinterpret_cast<const platform::CUDADeviceContext&>(context)
|
|
|
|
|
.stream()>>>(
|
|
|
|
|
num_outputs, vol.data<T>(), input_depth, input_height, input_width,
|
|
|
|
|
filter_depth, filter_height, filter_width, stride_depth, stride_height,
|
|
|
|
|
stride_width, padding_depth, padding_height, padding_width,
|
|
|
|
|
output_depth, output_height, output_width, col.data<T>());
|
|
|
|
|
dilation_d, dilation_h, dilation_w, filter_depth, filter_height,
|
|
|
|
|
filter_width, stride_depth, stride_height, stride_width, padding_depth,
|
|
|
|
|
padding_height, padding_width, output_depth, output_height,
|
|
|
|
|
output_width, col.data<T>());
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <class T>
|
|
|
|
|
__global__ void col2vol(int num_kernels, const T* data_col, int depth,
|
|
|
|
|
int height, int width, int filter_depth,
|
|
|
|
|
int filter_height, int filter_width, int stride_depth,
|
|
|
|
|
int stride_height, int stride_width, int padding_depth,
|
|
|
|
|
int padding_height, int padding_width, int output_detph,
|
|
|
|
|
int output_height, int output_width, T* data_vol) {
|
|
|
|
|
int height, int width, int dilation_d, int dilation_h,
|
|
|
|
|
int dilation_w, int filter_depth, int filter_height,
|
|
|
|
|
int filter_width, int stride_depth, int stride_height,
|
|
|
|
|
int stride_width, int padding_depth, int padding_height,
|
|
|
|
|
int padding_width, int output_detph, int output_height,
|
|
|
|
|
int output_width, T* data_vol) {
|
|
|
|
|
const int d_filter_depth = dilation_d * (filter_depth - 1) + 1;
|
|
|
|
|
const int d_filter_height = dilation_h * (filter_height - 1) + 1;
|
|
|
|
|
const int d_filter_width = dilation_w * (filter_width - 1) + 1;
|
|
|
|
|
|
|
|
|
|
for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < num_kernels;
|
|
|
|
|
index += blockDim.x * gridDim.x) {
|
|
|
|
|
T src_val = 0;
|
|
|
|
@ -115,35 +147,42 @@ __global__ void col2vol(int num_kernels, const T* data_col, int depth,
|
|
|
|
|
int h = (index / width) % height + padding_height;
|
|
|
|
|
int d = (index / width / height) % depth + padding_depth;
|
|
|
|
|
int c = index / width / height / depth;
|
|
|
|
|
|
|
|
|
|
// compute the start and end of the output
|
|
|
|
|
int w_col_start =
|
|
|
|
|
(w < filter_width) ? 0 : (w - filter_width) / stride_width + 1;
|
|
|
|
|
(w < d_filter_width) ? 0 : (w - d_filter_width) / stride_width + 1;
|
|
|
|
|
int w_col_end = min(w / stride_width + 1, output_width);
|
|
|
|
|
int h_col_start =
|
|
|
|
|
(h < filter_height) ? 0 : (h - filter_height) / stride_height + 1;
|
|
|
|
|
(h < d_filter_height) ? 0 : (h - d_filter_height) / stride_height + 1;
|
|
|
|
|
int h_col_end = min(h / stride_height + 1, output_height);
|
|
|
|
|
int d_col_start =
|
|
|
|
|
(d < filter_depth) ? 0 : (d - filter_depth) / stride_depth + 1;
|
|
|
|
|
(d < d_filter_depth) ? 0 : (d - d_filter_depth) / stride_depth + 1;
|
|
|
|
|
int d_col_end = min(d / stride_depth + 1, output_detph);
|
|
|
|
|
|
|
|
|
|
int offset = (c * filter_depth * filter_height * filter_width +
|
|
|
|
|
d * filter_width * filter_height + h * filter_width + w) *
|
|
|
|
|
output_detph * output_height * output_width;
|
|
|
|
|
|
|
|
|
|
int coeff_d_col =
|
|
|
|
|
(1 - stride_depth * filter_width * filter_height * output_detph) *
|
|
|
|
|
output_height * output_width;
|
|
|
|
|
int coeff_h_col =
|
|
|
|
|
(1 - stride_height * filter_width * output_detph * output_height) *
|
|
|
|
|
output_width;
|
|
|
|
|
int coeff_w_col =
|
|
|
|
|
(1 - stride_width * output_detph * output_height * output_width);
|
|
|
|
|
|
|
|
|
|
for (int d_col = d_col_start; d_col < d_col_end; ++d_col) {
|
|
|
|
|
for (int h_col = h_col_start; h_col < h_col_end; ++h_col) {
|
|
|
|
|
for (int w_col = w_col_start; w_col < w_col_end; ++w_col) {
|
|
|
|
|
src_val += data_col[offset + d_col * coeff_d_col +
|
|
|
|
|
h_col * coeff_h_col + w_col * coeff_w_col];
|
|
|
|
|
int d_off = (d - d_col * stride_depth);
|
|
|
|
|
int h_off = (h - h_col * stride_height);
|
|
|
|
|
int w_off = (w - w_col * stride_width);
|
|
|
|
|
if (d_off % dilation_d == 0 && h_off % dilation_h == 0 &&
|
|
|
|
|
w_off % dilation_w == 0) {
|
|
|
|
|
d_off /= dilation_d;
|
|
|
|
|
h_off /= dilation_h;
|
|
|
|
|
w_off /= dilation_w;
|
|
|
|
|
|
|
|
|
|
int data_col_index =
|
|
|
|
|
(((((c * filter_depth + d_off) * filter_height + h_off) *
|
|
|
|
|
filter_width +
|
|
|
|
|
w_off) *
|
|
|
|
|
output_detph +
|
|
|
|
|
d_col) *
|
|
|
|
|
output_height +
|
|
|
|
|
h_col) *
|
|
|
|
|
output_width +
|
|
|
|
|
w_col;
|
|
|
|
|
src_val += data_col[data_col_index];
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -162,6 +201,7 @@ class Col2VolFunctor<platform::GPUPlace, T> {
|
|
|
|
|
public:
|
|
|
|
|
void operator()(const platform::DeviceContext& context,
|
|
|
|
|
framework::Tensor& vol, const framework::Tensor& col,
|
|
|
|
|
int dilation_d, int dilation_h, int dilation_w,
|
|
|
|
|
int stride_depth, int stride_height, int stride_width,
|
|
|
|
|
int padding_depth, int padding_height,
|
|
|
|
|
int padding_width) const {
|
|
|
|
@ -179,6 +219,28 @@ class Col2VolFunctor<platform::GPUPlace, T> {
|
|
|
|
|
int output_height = col.dims()[5];
|
|
|
|
|
int output_width = col.dims()[6];
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ((input_depth + 2 * padding_depth -
|
|
|
|
|
((dilation_d * (filter_depth - 1) + 1))) /
|
|
|
|
|
stride_depth +
|
|
|
|
|
1,
|
|
|
|
|
output_depth,
|
|
|
|
|
"input_depth and output_depth are "
|
|
|
|
|
"Mismatching.");
|
|
|
|
|
PADDLE_ENFORCE_EQ((input_height + 2 * padding_height -
|
|
|
|
|
((dilation_h * (filter_height - 1) + 1))) /
|
|
|
|
|
stride_height +
|
|
|
|
|
1,
|
|
|
|
|
output_height,
|
|
|
|
|
"input_height and output_height are "
|
|
|
|
|
"Mismatching.");
|
|
|
|
|
PADDLE_ENFORCE_EQ((input_width + 2 * padding_width -
|
|
|
|
|
((dilation_w * (filter_width - 1) + 1))) /
|
|
|
|
|
stride_width +
|
|
|
|
|
1,
|
|
|
|
|
output_width,
|
|
|
|
|
"input_width and output_width are "
|
|
|
|
|
"Mismatching.");
|
|
|
|
|
|
|
|
|
|
int num_kernels = input_channels * input_depth * input_height * input_width;
|
|
|
|
|
|
|
|
|
|
const int threads = 1024;
|
|
|
|
@ -188,9 +250,10 @@ class Col2VolFunctor<platform::GPUPlace, T> {
|
|
|
|
|
reinterpret_cast<const platform::CUDADeviceContext&>(context)
|
|
|
|
|
.stream()>>>(
|
|
|
|
|
num_kernels, col.data<T>(), input_depth, input_height, input_width,
|
|
|
|
|
filter_depth, filter_height, filter_width, stride_depth, stride_height,
|
|
|
|
|
stride_width, padding_depth, padding_height, padding_width,
|
|
|
|
|
output_depth, output_height, output_width, vol.data<T>());
|
|
|
|
|
dilation_d, dilation_h, dilation_w, filter_depth, filter_height,
|
|
|
|
|
filter_width, stride_depth, stride_height, stride_width, padding_depth,
|
|
|
|
|
padding_height, padding_width, output_depth, output_height,
|
|
|
|
|
output_width, vol.data<T>());
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|