|
|
|
@ -66,8 +66,8 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
|
|
|
|
|
public:
|
|
|
|
|
void operator()(const platform::DeviceContext& context,
|
|
|
|
|
const framework::Tensor& im, framework::Tensor& col,
|
|
|
|
|
int stride_height, int stride_width, int padding_height,
|
|
|
|
|
int padding_width) {
|
|
|
|
|
int stride_height, int stride_width, int padding_up,
|
|
|
|
|
int padding_down, int padding_left, int padding_right) {
|
|
|
|
|
PADDLE_ENFORCE(im.dims().size() == 3);
|
|
|
|
|
PADDLE_ENFORCE(col.dims().size() == 5);
|
|
|
|
|
|
|
|
|
@ -79,6 +79,15 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
|
|
|
|
|
int output_height = col.dims()[3];
|
|
|
|
|
int output_width = col.dims()[4];
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE((input_height + padding_up + padding_down - filter_height) /
|
|
|
|
|
stride_height +
|
|
|
|
|
1 ==
|
|
|
|
|
output_height);
|
|
|
|
|
PADDLE_ENFORCE((input_width + padding_left + padding_right - filter_width) /
|
|
|
|
|
stride_width +
|
|
|
|
|
1 ==
|
|
|
|
|
output_width);
|
|
|
|
|
|
|
|
|
|
int num_outputs = input_channels * output_height * output_width;
|
|
|
|
|
int blocks = (num_outputs + 1024 - 1) / 1024;
|
|
|
|
|
int block_x = 512;
|
|
|
|
@ -89,8 +98,8 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
|
|
|
|
|
reinterpret_cast<const platform::CUDADeviceContext&>(context)
|
|
|
|
|
.stream()>>>(
|
|
|
|
|
im.data<T>(), num_outputs, input_height, input_width, filter_height,
|
|
|
|
|
filter_width, stride_height, stride_width, padding_height,
|
|
|
|
|
padding_width, output_height, output_width, col.data<T>());
|
|
|
|
|
filter_width, stride_height, stride_width, padding_up, padding_left,
|
|
|
|
|
output_height, output_width, col.data<T>());
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
@ -152,7 +161,8 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
|
|
|
|
|
public:
|
|
|
|
|
void operator()(const platform::DeviceContext& context, framework::Tensor& im,
|
|
|
|
|
const framework::Tensor& col, int stride_height,
|
|
|
|
|
int stride_width, int padding_height, int padding_width) {
|
|
|
|
|
int stride_width, int padding_up, int padding_down,
|
|
|
|
|
int padding_left, int padding_right) {
|
|
|
|
|
PADDLE_ENFORCE(im.dims().size() == 3);
|
|
|
|
|
PADDLE_ENFORCE(col.dims().size() == 5);
|
|
|
|
|
|
|
|
|
@ -164,8 +174,18 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
|
|
|
|
|
int output_height = col.dims()[3];
|
|
|
|
|
int output_width = col.dims()[4];
|
|
|
|
|
|
|
|
|
|
size_t num_kernels = input_channels * (input_height + 2 * padding_height) *
|
|
|
|
|
(input_width + 2 * padding_width);
|
|
|
|
|
PADDLE_ENFORCE((input_height + padding_up + padding_down - filter_height) /
|
|
|
|
|
stride_height +
|
|
|
|
|
1 ==
|
|
|
|
|
output_height);
|
|
|
|
|
PADDLE_ENFORCE((input_width + padding_left + padding_right - filter_width) /
|
|
|
|
|
stride_width +
|
|
|
|
|
1 ==
|
|
|
|
|
output_width);
|
|
|
|
|
|
|
|
|
|
size_t num_kernels = input_channels *
|
|
|
|
|
(input_height + padding_up + padding_down) *
|
|
|
|
|
(input_width + padding_left + padding_right);
|
|
|
|
|
|
|
|
|
|
size_t blocks = (num_kernels + 1024 - 1) / 1024;
|
|
|
|
|
size_t block_x = 512;
|
|
|
|
@ -178,10 +198,10 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
|
|
|
|
|
col2im<T><<<grid, threads, 0,
|
|
|
|
|
reinterpret_cast<const platform::CUDADeviceContext&>(context)
|
|
|
|
|
.stream()>>>(
|
|
|
|
|
num_kernels, col.data<T>(), input_height + 2 * padding_height,
|
|
|
|
|
input_width + 2 * padding_width, input_channels, filter_height,
|
|
|
|
|
filter_width, stride_height, stride_width, padding_height,
|
|
|
|
|
padding_width, output_height, output_width, im.data<T>());
|
|
|
|
|
num_kernels, col.data<T>(), input_height + padding_up + padding_down,
|
|
|
|
|
// Padded image width = input width + left AND right padding, mirroring the
// padded-height argument (padding_up + padding_down) and the num_kernels
// computation; using padding_left twice breaks asymmetric padding.
input_width + padding_left + padding_right, input_channels,
|
|
|
|
|
filter_height, filter_width, stride_height, stride_width, padding_up,
|
|
|
|
|
padding_left, output_height, output_width, im.data<T>());
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
@ -199,8 +219,7 @@ __global__ void im2colOCF(const T* im_data, T* col_data, int input_channels,
|
|
|
|
|
int input_height, int input_width, int filter_height,
|
|
|
|
|
int filter_width, int stride_height, int stride_width,
|
|
|
|
|
int padding_height, int padding_width,
|
|
|
|
|
int output_height, int output_width, int row_begin,
|
|
|
|
|
int row_end) {
|
|
|
|
|
int output_height, int output_width) {
|
|
|
|
|
int swid = blockIdx.x;
|
|
|
|
|
int shid = blockIdx.y;
|
|
|
|
|
for (int channelid = threadIdx.z; channelid < input_channels;
|
|
|
|
@ -208,8 +227,7 @@ __global__ void im2colOCF(const T* im_data, T* col_data, int input_channels,
|
|
|
|
|
for (int idy = threadIdx.y; idy < filter_height; idy += blockDim.y) {
|
|
|
|
|
for (int idx = threadIdx.x; idx < filter_width; idx += blockDim.x) {
|
|
|
|
|
int width_offset = idx + swid * stride_width - padding_width;
|
|
|
|
|
int height_offset =
|
|
|
|
|
idy + (shid + row_begin) * stride_height - padding_height;
|
|
|
|
|
int height_offset = idy + shid * stride_height - padding_height;
|
|
|
|
|
int im_offset = width_offset + height_offset * input_width +
|
|
|
|
|
channelid * input_height * input_width;
|
|
|
|
|
|
|
|
|
@ -240,8 +258,8 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
|
|
|
|
|
public:
|
|
|
|
|
void operator()(const platform::DeviceContext& context,
|
|
|
|
|
const framework::Tensor& im, framework::Tensor& col,
|
|
|
|
|
int stride_height, int stride_width, int up_pad,
|
|
|
|
|
int down_pad) {
|
|
|
|
|
int stride_height, int stride_width, int padding_up,
|
|
|
|
|
int padding_down, int padding_left, int padding_right) {
|
|
|
|
|
PADDLE_ENFORCE(im.dims().size() == 3);
|
|
|
|
|
PADDLE_ENFORCE(col.dims().size() == 5);
|
|
|
|
|
int input_channels = im.dims()[0];
|
|
|
|
@ -249,22 +267,18 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
|
|
|
|
|
int input_width = im.dims()[2];
|
|
|
|
|
int filter_height = col.dims()[3];
|
|
|
|
|
int filter_width = col.dims()[4];
|
|
|
|
|
|
|
|
|
|
int row_begin, row_end;
|
|
|
|
|
int padding_height = std::max(up_pad, down_pad);
|
|
|
|
|
int padding_width = 0;
|
|
|
|
|
if (up_pad >= down_pad) {
|
|
|
|
|
row_begin = 0;
|
|
|
|
|
} else {
|
|
|
|
|
row_begin = down_pad - up_pad;
|
|
|
|
|
}
|
|
|
|
|
row_end = row_begin + ((input_height + up_pad + down_pad - filter_height) /
|
|
|
|
|
stride_height +
|
|
|
|
|
1);
|
|
|
|
|
|
|
|
|
|
int output_height = row_end - row_begin; // col.dims()[0];
|
|
|
|
|
int output_height = col.dims()[0];
|
|
|
|
|
int output_width = col.dims()[1];
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE((input_height + padding_up + padding_down - filter_height) /
|
|
|
|
|
stride_height +
|
|
|
|
|
1 ==
|
|
|
|
|
output_height);
|
|
|
|
|
PADDLE_ENFORCE((input_width + padding_left + padding_right - filter_width) /
|
|
|
|
|
stride_width +
|
|
|
|
|
1 ==
|
|
|
|
|
output_width);
|
|
|
|
|
|
|
|
|
|
int block_dim_x = 0;
|
|
|
|
|
int block_dim_y = 0;
|
|
|
|
|
if (filter_height <= 4 && filter_width <= 4) {
|
|
|
|
@ -289,9 +303,8 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
|
|
|
|
|
reinterpret_cast<const platform::CUDADeviceContext&>(context)
|
|
|
|
|
.stream()>>>(
|
|
|
|
|
im.data<T>(), col.data<T>(), input_channels, input_height, input_width,
|
|
|
|
|
filter_height, filter_width, stride_height, stride_width,
|
|
|
|
|
padding_height, padding_width, output_height, output_width, row_begin,
|
|
|
|
|
row_end);
|
|
|
|
|
filter_height, filter_width, stride_height, stride_width, padding_up,
|
|
|
|
|
padding_left, output_height, output_width);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
@ -300,8 +313,7 @@ __global__ void col2imOCF(T* im_data, const T* col_data, int input_channels,
|
|
|
|
|
int input_height, int input_width, int filter_height,
|
|
|
|
|
int filter_width, int stride_height, int stride_width,
|
|
|
|
|
int padding_height, int padding_width,
|
|
|
|
|
int output_height, int output_width, int row_begin,
|
|
|
|
|
int row_end) {
|
|
|
|
|
int output_height, int output_width) {
|
|
|
|
|
int swid = blockIdx.x;
|
|
|
|
|
int shid = blockIdx.y;
|
|
|
|
|
for (int channelid = threadIdx.z; channelid < input_channels;
|
|
|
|
@ -309,8 +321,7 @@ __global__ void col2imOCF(T* im_data, const T* col_data, int input_channels,
|
|
|
|
|
for (int idy = threadIdx.y; idy < filter_height; idy += blockDim.y) {
|
|
|
|
|
for (int idx = threadIdx.x; idx < filter_width; idx += blockDim.x) {
|
|
|
|
|
int width_offset = idx + swid * stride_width - padding_width;
|
|
|
|
|
int height_offset =
|
|
|
|
|
idy + (shid + row_begin) * stride_height - padding_height;
|
|
|
|
|
int height_offset = idy + shid * stride_height - padding_height;
|
|
|
|
|
int im_offset = width_offset + height_offset * input_width +
|
|
|
|
|
channelid * input_height * input_width;
|
|
|
|
|
|
|
|
|
@ -340,7 +351,8 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
|
|
|
|
|
public:
|
|
|
|
|
void operator()(const platform::DeviceContext& context, framework::Tensor& im,
|
|
|
|
|
const framework::Tensor& col, int stride_height,
|
|
|
|
|
int stride_width, int up_pad, int down_pad) {
|
|
|
|
|
int stride_width, int padding_up, int padding_down,
|
|
|
|
|
int padding_left, int padding_right) {
|
|
|
|
|
PADDLE_ENFORCE(im.dims().size() == 3);
|
|
|
|
|
PADDLE_ENFORCE(col.dims().size() == 5);
|
|
|
|
|
int input_channels = im.dims()[0];
|
|
|
|
@ -348,22 +360,18 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
|
|
|
|
|
int input_width = im.dims()[2];
|
|
|
|
|
int filter_height = col.dims()[3];
|
|
|
|
|
int filter_width = col.dims()[4];
|
|
|
|
|
|
|
|
|
|
int row_begin, row_end;
|
|
|
|
|
int padding_height = std::max(up_pad, down_pad);
|
|
|
|
|
int padding_width = 0;
|
|
|
|
|
if (up_pad >= down_pad) {
|
|
|
|
|
row_begin = 0;
|
|
|
|
|
} else {
|
|
|
|
|
row_begin = down_pad - up_pad;
|
|
|
|
|
}
|
|
|
|
|
row_end = row_begin + ((input_height + up_pad + down_pad - filter_height) /
|
|
|
|
|
stride_height +
|
|
|
|
|
1);
|
|
|
|
|
|
|
|
|
|
int output_height = row_end - row_begin; // col.dims()[0];
|
|
|
|
|
int output_height = col.dims()[0];
|
|
|
|
|
int output_width = col.dims()[1];
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE((input_height + padding_up + padding_down - filter_height) /
|
|
|
|
|
stride_height +
|
|
|
|
|
1 ==
|
|
|
|
|
output_height);
|
|
|
|
|
PADDLE_ENFORCE((input_width + padding_left + padding_right - filter_width) /
|
|
|
|
|
stride_width +
|
|
|
|
|
1 ==
|
|
|
|
|
output_width);
|
|
|
|
|
|
|
|
|
|
int block_dim_x = 0;
|
|
|
|
|
int block_dim_y = 0;
|
|
|
|
|
if (filter_height <= 4 && filter_width <= 4) {
|
|
|
|
@ -388,9 +396,8 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
|
|
|
|
|
reinterpret_cast<const platform::CUDADeviceContext&>(context)
|
|
|
|
|
.stream()>>>(
|
|
|
|
|
im.data<T>(), col.data<T>(), input_channels, input_height, input_width,
|
|
|
|
|
filter_height, filter_width, stride_height, stride_width,
|
|
|
|
|
padding_height, padding_width, output_height, output_width, row_begin,
|
|
|
|
|
row_end);
|
|
|
|
|
filter_height, filter_width, stride_height, stride_width, padding_up,
|
|
|
|
|
padding_left, output_height, output_width);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|