conv_transpose supports channel_last input, test=develop, test=document_preview (#20072)

fix-python-transpose
Zhang Ting 6 years ago committed by hong
parent c9139c3db3
commit cf6919bf6e

@ -153,8 +153,8 @@ paddle.fluid.layers.batch_norm (ArgSpec(args=['input', 'act', 'is_test', 'moment
paddle.fluid.layers.instance_norm (ArgSpec(args=['input', 'epsilon', 'param_attr', 'bias_attr', 'name'], varargs=None, keywords=None, defaults=(1e-05, None, None, None)), ('document', '02972097e089629efdb0ed9404fd36ae'))
paddle.fluid.layers.data_norm (ArgSpec(args=['input', 'act', 'epsilon', 'param_attr', 'data_layout', 'in_place', 'name', 'moving_mean_name', 'moving_variance_name', 'do_model_average_for_mean_and_var'], varargs=None, keywords=None, defaults=(None, 1e-05, None, 'NCHW', False, None, None, None, False)), ('document', '2460b30fb87037555208fa8ac6fc1787'))
paddle.fluid.layers.beam_search_decode (ArgSpec(args=['ids', 'scores', 'beam_size', 'end_id', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '83e08f21af41ac8bac37aeab1f86fdd0'))
paddle.fluid.layers.conv2d_transpose (ArgSpec(args=['input', 'num_filters', 'output_size', 'filter_size', 'padding', 'stride', 'dilation', 'groups', 'param_attr', 'bias_attr', 'use_cudnn', 'act', 'name'], varargs=None, keywords=None, defaults=(None, None, 0, 1, 1, None, None, None, True, None, None)), ('document', 'ab58296b567bf0c686084add7f3280a4'))
paddle.fluid.layers.conv3d_transpose (ArgSpec(args=['input', 'num_filters', 'output_size', 'filter_size', 'padding', 'stride', 'dilation', 'groups', 'param_attr', 'bias_attr', 'use_cudnn', 'act', 'name'], varargs=None, keywords=None, defaults=(None, None, 0, 1, 1, None, None, None, True, None, None)), ('document', 'fe15dbfb17d97d3d29b2fa7ee6390ee6'))
paddle.fluid.layers.conv2d_transpose (ArgSpec(args=['input', 'num_filters', 'output_size', 'filter_size', 'padding', 'stride', 'dilation', 'groups', 'param_attr', 'bias_attr', 'use_cudnn', 'act', 'name', 'data_format'], varargs=None, keywords=None, defaults=(None, None, 0, 1, 1, None, None, None, True, None, None, 'NCHW')), ('document', '9391d75358b6cba0cc5d22a01a223420'))
paddle.fluid.layers.conv3d_transpose (ArgSpec(args=['input', 'num_filters', 'output_size', 'filter_size', 'padding', 'stride', 'dilation', 'groups', 'param_attr', 'bias_attr', 'use_cudnn', 'act', 'name', 'data_format'], varargs=None, keywords=None, defaults=(None, None, 0, 1, 1, None, None, None, True, None, None, 'NCDHW')), ('document', '74bce3cd4224e6ff133d54508dc7f150'))
paddle.fluid.layers.sequence_expand (ArgSpec(args=['x', 'y', 'ref_level', 'name'], varargs=None, keywords=None, defaults=(-1, None)), ('document', '10e122eb755c2bd1f78ef2332b28f1a0'))
paddle.fluid.layers.sequence_expand_as (ArgSpec(args=['x', 'y', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '858c432e7cbd8bb952cc2eb555457d50'))
paddle.fluid.layers.sequence_pad (ArgSpec(args=['x', 'pad_value', 'maxlen', 'name'], varargs=None, keywords=None, defaults=(None, None)), ('document', 'df08b9c499ab3a90f95d08ab5b6c6c62'))

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

@ -22,6 +22,8 @@ namespace paddle {
namespace operators {
namespace math {
using DataLayout = framework::DataLayout;
/*
* \brief Compute the depthwise convolution which include
* forward process and backpropagation process
@ -34,7 +36,8 @@ class DepthwiseConvFunctor {
const framework::Tensor& filter,
const std::vector<int>& strides,
const std::vector<int>& paddings,
const std::vector<int>& dilations, framework::Tensor* output);
const std::vector<int>& dilations, framework::Tensor* output,
const DataLayout data_layout = DataLayout::kNCHW);
};
template <typename DeviceContext, typename T,
@ -47,7 +50,8 @@ class DepthwiseConvInputGradFunctor {
const std::vector<int>& strides,
const std::vector<int>& paddings,
const std::vector<int>& dilations,
framework::Tensor* input_grad);
framework::Tensor* input_grad,
const DataLayout data_layout = DataLayout::kNCHW);
};
template <typename DeviceContext, typename T,
@ -59,7 +63,8 @@ class DepthwiseConvFilterGradFunctor {
const std::vector<int>& strides,
const std::vector<int>& paddings,
const std::vector<int>& dilations,
framework::Tensor* filter_grad);
framework::Tensor* filter_grad,
const DataLayout data_layout = DataLayout::kNCHW);
};
} // namespace math

@ -32,7 +32,8 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
void operator()(const platform::CPUDeviceContext& context,
const framework::Tensor& im, const std::vector<int>& dilation,
const std::vector<int>& stride,
const std::vector<int>& padding, framework::Tensor* col) {
const std::vector<int>& padding, framework::Tensor* col,
const DataLayout data_layout) {
PADDLE_ENFORCE_EQ(im.dims().size(), 3, "The dimension of im should be 3.");
PADDLE_ENFORCE_EQ(col->dims().size(), 5,
"The dimension of col should be 5.");
@ -41,16 +42,16 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
dilation[1] == 1) {
if (padding[0] == 0 && padding[1] == 0 && padding[2] == 0 &&
padding[3] == 0) {
im2col_sh1sw1dh1dw1ph0pw0<T>(im, col);
im2col_sh1sw1dh1dw1ph0pw0<T>(im, col, data_layout);
return;
} else if (padding[0] == 1 && padding[1] == 1 && padding[2] == 1 &&
padding[3] == 1) {
im2col_sh1sw1dh1dw1ph1pw1<T>(im, col);
im2col_sh1sw1dh1dw1ph1pw1<T>(im, col, data_layout);
return;
}
// TODO(TJ): complete padding >=2
}
im2col_common<T>(im, dilation, stride, padding, col);
im2col_common<T>(im, dilation, stride, padding, col, data_layout);
}
};
@ -67,13 +68,17 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
const framework::Tensor& col,
const std::vector<int>& dilation,
const std::vector<int>& stride,
const std::vector<int>& padding, framework::Tensor* im) {
const std::vector<int>& padding, framework::Tensor* im,
const DataLayout data_layout) {
PADDLE_ENFORCE_EQ(im->dims().size(), 3, "The dimension of im should be 3.");
PADDLE_ENFORCE_EQ(col.dims().size(), 5,
"The dimension of col should be 5.");
int im_channels = im->dims()[0];
int im_height = im->dims()[1];
int im_width = im->dims()[2];
int im_channels =
(data_layout == DataLayout::kNCHW ? im->dims()[0] : im->dims()[2]);
int im_height =
(data_layout == DataLayout::kNCHW ? im->dims()[1] : im->dims()[0]);
int im_width =
(data_layout == DataLayout::kNCHW ? im->dims()[2] : im->dims()[1]);
int filter_height = col.dims()[1];
int filter_width = col.dims()[2];
int col_height = col.dims()[3];
@ -109,7 +114,15 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
int im_col_idx = w * stride[1] - padding[1] + w_offset * dilation[1];
if ((im_row_idx) >= 0 && (im_row_idx) < im_height &&
(im_col_idx) >= 0 && (im_col_idx) < im_width) {
im_data[(im_row_idx + c_im * im_height) * im_width + im_col_idx] +=
int im_offset;
if (data_layout == DataLayout::kNCHW) {
im_offset =
(c_im * im_height + im_row_idx) * im_width + im_col_idx;
} else {
im_offset =
(im_row_idx * im_width + im_col_idx) * im_channels + c_im;
}
im_data[im_offset] +=
col_data[(c * col_height + h) * col_width + w];
}
}
@ -139,7 +152,8 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
void operator()(const platform::CPUDeviceContext& context,
const framework::Tensor& im, const std::vector<int>& dilation,
const std::vector<int>& stride,
const std::vector<int>& padding, framework::Tensor* col) {
const std::vector<int>& padding, framework::Tensor* col,
const DataLayout data_layout) {
PADDLE_ENFORCE_EQ(im.dims().size(), 3, "The dimension of im should be 3.");
PADDLE_ENFORCE_EQ(col->dims().size(), 5,
"The dimension of col should be 5.");
@ -202,7 +216,8 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
const framework::Tensor& col,
const std::vector<int>& dilation,
const std::vector<int>& stride,
const std::vector<int>& padding, framework::Tensor* im) {
const std::vector<int>& padding, framework::Tensor* im,
const DataLayout data_layout) {
PADDLE_ENFORCE_EQ(im->dims().size(), 3, "The dimension of im should be 3.");
PADDLE_ENFORCE_EQ(col.dims().size(), 5,
"The dimension of col should be 5.");

@ -26,27 +26,41 @@ __global__ void im2col(const T* data_im, int num_outs, int im_height,
int im_width, int dilation_h, int dilation_w,
int filter_height, int filter_width, int stride_height,
int stride_width, int padding_height, int padding_width,
int col_height, int col_width, T* data_col) {
int col_height, int col_width, T* data_col,
const DataLayout data_layout) {
int input_channels = num_outs / col_height / col_width;
int channels_col = input_channels * filter_height * filter_width;
const int index =
(blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x;
if (index < num_outs) {
int w_out = index % col_width;
int h_out = (index / col_width) % col_height;
int channel_in = index / col_width / col_height;
int w_out = (data_layout == DataLayout::kNCHW
? index % col_width
: (index / input_channels) % col_width);
int h_out = (data_layout == DataLayout::kNCHW
? (index / col_width) % col_height
: (index / input_channels / col_width) % col_height);
int channel_in =
(data_layout == DataLayout::kNCHW ? index / col_width / col_height
: index % input_channels);
int channel_out = channel_in * filter_height * filter_width;
int h_in = h_out * stride_height - padding_height;
int w_in = w_out * stride_width - padding_width;
data_col += (channel_out * col_height + h_out) * col_width + w_out;
data_im += (channel_in * im_height + h_in) * im_width + w_in;
for (int i = 0; i < filter_height; ++i) {
for (int j = 0; j < filter_width; ++j) {
int rIdx = h_in + i * dilation_h;
int cIdx = w_in + j * dilation_w;
int im_idx;
if (data_layout == DataLayout::kNCHW) {
im_idx = (channel_in * im_height + rIdx) * im_width + cIdx;
} else {
im_idx = (rIdx * im_width + cIdx) * input_channels + channel_in;
}
*data_col =
(rIdx >= im_height || rIdx < 0 || cIdx >= im_width || cIdx < 0)
? 0
: data_im[i * dilation_h * im_width + j * dilation_w];
: data_im[im_idx];
data_col += col_height * col_width;
}
}
@ -65,13 +79,18 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
void operator()(const platform::CUDADeviceContext& context,
const framework::Tensor& im, const std::vector<int>& dilation,
const std::vector<int>& stride,
const std::vector<int>& padding, framework::Tensor* col) {
PADDLE_ENFORCE_EQ(im.dims().size(), 3);
PADDLE_ENFORCE_EQ(col->dims().size(), 5);
int im_channels = im.dims()[0];
int im_height = im.dims()[1];
int im_width = im.dims()[2];
const std::vector<int>& padding, framework::Tensor* col,
const DataLayout data_layout) {
PADDLE_ENFORCE_EQ(im.dims().size(), 3, "The dimension of im should be 3.");
PADDLE_ENFORCE_EQ(col->dims().size(), 5,
"The dimension of col should be 5.");
int im_channels =
(data_layout == DataLayout::kNCHW ? im.dims()[0] : im.dims()[2]);
int im_height =
(data_layout == DataLayout::kNCHW ? im.dims()[1] : im.dims()[0]);
int im_width =
(data_layout == DataLayout::kNCHW ? im.dims()[2] : im.dims()[1]);
int filter_height = col->dims()[1];
int filter_width = col->dims()[2];
int col_height = col->dims()[3];
@ -86,7 +105,8 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
im2col<T><<<grid, threads, 0, context.stream()>>>(
im.data<T>(), num_outputs, im_height, im_width, dilation[0],
dilation[1], filter_height, filter_width, stride[0], stride[1],
padding[0], padding[1], col_height, col_width, col->data<T>());
padding[0], padding[1], col_height, col_width, col->data<T>(),
data_layout);
}
};
@ -95,18 +115,27 @@ __global__ void col2im(int n, const T* data_col, int im_height, int im_width,
int dilation_h, int dilation_w, int filter_height,
int filter_width, int stride_height, int stride_width,
int padding_height, int padding_width, int col_height,
int col_width, T* data_im) {
int col_width, T* data_im,
const DataLayout data_layout) {
const int index =
(blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x;
const int d_filter_height = dilation_h * (filter_height - 1) + 1;
const int d_filter_width = dilation_w * (filter_width - 1) + 1;
int input_channels = n / im_height / im_width;
if (index < n) {
T val = 0;
int w = index % im_width + padding_width;
int h = (index / im_width) % im_height + padding_height;
int c = index / (im_width * im_height);
int w = (data_layout == DataLayout::kNCHW
? index % im_width + padding_width
: (index / input_channels) % im_width + padding_width);
int h = (data_layout == DataLayout::kNCHW
? (index / im_width) % im_height + padding_height
: (index / input_channels / im_width) % im_height +
padding_height);
int c = (data_layout == DataLayout::kNCHW ? index / im_width / im_height
: index % input_channels);
// compute the start and end of the output
int w_col_start =
@ -151,13 +180,18 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
const framework::Tensor& col,
const std::vector<int>& dilation,
const std::vector<int>& stride,
const std::vector<int>& padding, framework::Tensor* im) {
PADDLE_ENFORCE_EQ(im->dims().size(), 3);
PADDLE_ENFORCE_EQ(col.dims().size(), 5);
int im_channels = im->dims()[0];
int im_height = im->dims()[1];
int im_width = im->dims()[2];
const std::vector<int>& padding, framework::Tensor* im,
const DataLayout data_layout) {
PADDLE_ENFORCE_EQ(im->dims().size(), 3, "The dimension of im should be 3.");
PADDLE_ENFORCE_EQ(col.dims().size(), 5,
"The dimension of col should be 5.");
int im_channels =
(data_layout == DataLayout::kNCHW ? im->dims()[0] : im->dims()[2]);
int im_height =
(data_layout == DataLayout::kNCHW ? im->dims()[1] : im->dims()[0]);
int im_width =
(data_layout == DataLayout::kNCHW ? im->dims()[2] : im->dims()[1]);
int filter_height = col.dims()[1];
int filter_width = col.dims()[2];
int col_height = col.dims()[3];
@ -191,7 +225,8 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
col2im<T><<<grid, threads, 0, context.stream()>>>(
num_kernels, col.data<T>(), im_height, im_width, dilation[0],
dilation[1], filter_height, filter_width, stride[0], stride[1],
padding[0], padding[2], col_height, col_width, im->data<T>());
padding[0], padding[1], col_height, col_width, im->data<T>(),
data_layout);
}
};
@ -248,9 +283,12 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
void operator()(const platform::CUDADeviceContext& context,
const framework::Tensor& im, const std::vector<int>& dilation,
const std::vector<int>& stride,
const std::vector<int>& padding, framework::Tensor* col) {
PADDLE_ENFORCE_EQ(im.dims().size(), 3);
PADDLE_ENFORCE_EQ(col->dims().size(), 5);
const std::vector<int>& padding, framework::Tensor* col,
const DataLayout data_layout) {
PADDLE_ENFORCE_EQ(im.dims().size(), 3, "The dimension of im should be 3.");
PADDLE_ENFORCE_EQ(col->dims().size(), 5,
"The dimension of col should be 5.");
int im_channels = im.dims()[0];
int im_height = im.dims()[1];
int im_width = im.dims()[2];
@ -330,9 +368,12 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
const framework::Tensor& col,
const std::vector<int>& dilation,
const std::vector<int>& stride,
const std::vector<int>& padding, framework::Tensor* im) {
PADDLE_ENFORCE_EQ(im->dims().size(), 3);
PADDLE_ENFORCE_EQ(col.dims().size(), 5);
const std::vector<int>& padding, framework::Tensor* im,
const DataLayout data_layout) {
PADDLE_ENFORCE_EQ(im->dims().size(), 3, "The dimension of im should be 3.");
PADDLE_ENFORCE_EQ(col.dims().size(), 5,
"The dimension of col should be 5.");
int im_channels = im->dims()[0];
int im_height = im->dims()[1];
int im_width = im->dims()[2];

@ -23,6 +23,8 @@ namespace paddle {
namespace operators {
namespace math {
using DataLayout = framework::DataLayout;
/* The storage format of the coldata in the Im2ColFunctor and Col2ImFunctor. */
enum class ColFormat { kCFO = 0, kOCF = 1 };
@ -86,7 +88,8 @@ class Im2ColFunctor {
void operator()(const DeviceContext& context, const framework::Tensor& im,
const std::vector<int>& dilation,
const std::vector<int>& stride,
const std::vector<int>& padding, framework::Tensor* col);
const std::vector<int>& padding, framework::Tensor* col,
const DataLayout data_layout = DataLayout::kNCHW);
};
template <ColFormat Format, typename DeviceContext, typename T>
@ -95,7 +98,8 @@ class Col2ImFunctor {
void operator()(const DeviceContext& context, const framework::Tensor& col,
const std::vector<int>& dilation,
const std::vector<int>& stride,
const std::vector<int>& padding, framework::Tensor* im);
const std::vector<int>& padding, framework::Tensor* im,
const DataLayout data_layout = DataLayout::kNCHW);
};
} // namespace math

@ -30,10 +30,14 @@ inline void im2col_common(const framework::Tensor& im,
const std::vector<int>& dilation,
const std::vector<int>& stride,
const std::vector<int>& padding,
framework::Tensor* col) {
int im_channels = im.dims()[0];
int im_height = im.dims()[1];
int im_width = im.dims()[2];
framework::Tensor* col,
const DataLayout data_layout = DataLayout::kNCHW) {
int im_channels =
(data_layout == DataLayout::kNCHW ? im.dims()[0] : im.dims()[2]);
int im_height =
(data_layout == DataLayout::kNCHW ? im.dims()[1] : im.dims()[0]);
int im_width =
(data_layout == DataLayout::kNCHW ? im.dims()[2] : im.dims()[1]);
int filter_height = col->dims()[1];
int filter_width = col->dims()[2];
int output_height = col->dims()[3];
@ -50,8 +54,14 @@ inline void im2col_common(const framework::Tensor& im,
int im_row_idx = h * stride[0] - padding[0] + h_offset * dilation[0];
for (int w = 0; w < output_width; ++w) {
int im_col_idx = w * stride[1] - padding[1] + w_offset * dilation[1];
int im_idx;
if (data_layout == DataLayout::kNCHW) {
im_idx = (im_row_idx + c_im * im_height) * im_width + im_col_idx;
} else {
im_idx = (im_row_idx * im_width + im_col_idx) * im_channels + c_im;
}
int col_idx = (c * output_height + h) * output_width + w;
int im_idx = (im_row_idx + c_im * im_height) * im_width + im_col_idx;
col_data[col_idx] = (im_row_idx < 0 || im_row_idx >= im_height ||
im_col_idx < 0 || im_col_idx >= im_width)
? static_cast<T>(0)
@ -65,11 +75,15 @@ inline void im2col_common(const framework::Tensor& im,
* im2col algorithm with strides == 1, dilations == 1, paddings == 0
*/
template <typename T>
inline void im2col_sh1sw1dh1dw1ph0pw0(const framework::Tensor& im,
framework::Tensor* col) {
int im_channels = im.dims()[0];
int im_height = im.dims()[1];
int im_width = im.dims()[2];
inline void im2col_sh1sw1dh1dw1ph0pw0(
const framework::Tensor& im, framework::Tensor* col,
const DataLayout data_layout = DataLayout::kNCHW) {
int im_channels =
(data_layout == DataLayout::kNCHW ? im.dims()[0] : im.dims()[2]);
int im_height =
(data_layout == DataLayout::kNCHW ? im.dims()[1] : im.dims()[0]);
int im_width =
(data_layout == DataLayout::kNCHW ? im.dims()[2] : im.dims()[1]);
int filter_height = col->dims()[1];
int filter_width = col->dims()[2];
int output_height = col->dims()[3];
@ -89,7 +103,14 @@ inline void im2col_sh1sw1dh1dw1ph0pw0(const framework::Tensor& im,
const T* src_data = src_data_ic;
for (int kh = 0; kh < filter_height; ++kh) {
for (int kw = 0; kw < filter_width; ++kw) {
std::memcpy(dst_data, src_data + kw, copy_size);
if (data_layout == DataLayout::kNCHW) {
std::memcpy(dst_data, src_data + kw, copy_size);
} else {
for (int kow = 0; kow < output_width; ++kow) {
dst_data[kow] =
im_data[((oh + kh) * im_width + kw + kow) * im_channels + ic];
}
}
dst_data = dst_data + col_matrix_width;
}
src_data = src_data + im_width;
@ -107,10 +128,14 @@ inline void im2col_sh1sw1dh1dw1ph0pw0(const framework::Tensor& im,
*/
template <typename T>
inline void im2col_sh1sw1dh1dw1ph1pw1(const framework::Tensor& im,
framework::Tensor* col) {
int im_channels = im.dims()[0];
int im_height = im.dims()[1];
int im_width = im.dims()[2];
framework::Tensor* col,
const DataLayout data_layout) {
int im_channels =
(data_layout == DataLayout::kNCHW ? im.dims()[0] : im.dims()[2]);
int im_height =
(data_layout == DataLayout::kNCHW ? im.dims()[1] : im.dims()[0]);
int im_width =
(data_layout == DataLayout::kNCHW ? im.dims()[2] : im.dims()[1]);
int filter_height = col->dims()[1];
int filter_width = col->dims()[2];
int output_height = col->dims()[3];
@ -180,7 +205,17 @@ inline void im2col_sh1sw1dh1dw1ph1pw1(const framework::Tensor& im,
dst_data = dst_data + col_matrix_width;
continue;
}
std::memcpy(dst_data + plw, src_data, copy_size);
if (data_layout == DataLayout::kNCHW) {
std::memcpy(dst_data + plw, src_data, copy_size);
} else {
for (int kow = 0; kow < output_width - plw - prw; ++kow) {
dst_data[plw + kow] =
im_data[(((oh - plh > 0 ? oh - plh : 0) + kh) * im_width +
kow) *
im_channels +
ic];
}
}
dst_data = dst_data + col_matrix_width;
src_data = src_data + im_width;
}
@ -226,19 +261,49 @@ inline void im2col_sh1sw1dh1dw1ph1pw1(const framework::Tensor& im,
// TODO(TJ): reuse plw-kw outside this for
// try to unify
for (int kw = 0; kw < plw; ++kw) {
std::memcpy(dst_data + (plw - kw), src_data,
sizeof(T) * (output_width - (plw - kw)));
if (data_layout == DataLayout::kNCHW) {
std::memcpy(dst_data + (plw - kw), src_data,
sizeof(T) * (output_width - (plw - kw)));
} else {
for (int kow = 0; kow < output_width - (plw - kw); ++kow) {
dst_data[plw - kw + kow] =
im_data[(((oh - plh > 0 ? oh - plh : 0) + kh) * im_width +
kow) *
im_channels +
ic];
}
}
dst_data = dst_data + col_matrix_width;
}
for (int kw = plw; kw < filter_width - prw; ++kw) {
std::memcpy(dst_data, src_data + (kw - plw),
sizeof(T) * output_width);
if (data_layout == DataLayout::kNCHW) {
std::memcpy(dst_data, src_data + (kw - plw),
sizeof(T) * output_width);
} else {
for (int kow = 0; kow < output_width; ++kow) {
dst_data[kow] =
im_data[(((oh - plh > 0 ? oh - plh : 0) + kh) * im_width +
kw - plw + kow) *
im_channels +
ic];
}
}
dst_data = dst_data + col_matrix_width;
}
int i = 1;
for (int kw = filter_width - prw; kw < filter_width; ++kw, ++i) {
std::memcpy(dst_data, src_data + (kw - plw),
sizeof(T) * (output_width - i));
if (data_layout == DataLayout::kNCHW) {
std::memcpy(dst_data, src_data + (kw - plw),
sizeof(T) * (output_width - i));
} else {
for (int kow = 0; kow < output_width - i; ++kow) {
dst_data[kow] =
im_data[(((oh - plh > 0 ? oh - plh : 0) + kh) * im_width +
kw - plw + kow) *
im_channels +
ic];
}
}
dst_data = dst_data + col_matrix_width;
}
src_data = src_data + im_width;

@ -32,16 +32,21 @@ class Vol2ColFunctor<platform::CPUDeviceContext, T> {
const framework::Tensor& vol,
const std::vector<int>& dilations,
const std::vector<int>& strides,
const std::vector<int>& paddings,
framework::Tensor* col) const {
const std::vector<int>& paddings, framework::Tensor* col,
const DataLayout data_layout) const {
PADDLE_ENFORCE_EQ(vol.dims().size(), 4,
"The dimension of vol should be 4.");
PADDLE_ENFORCE_EQ(col->dims().size(), 7,
"The dimension of col should be 7.");
int input_channels = vol.dims()[0];
int input_depth = vol.dims()[1];
int input_height = vol.dims()[2];
int input_width = vol.dims()[3];
int input_channels =
(data_layout == DataLayout::kNCHW ? vol.dims()[0] : vol.dims()[3]);
int input_depth =
(data_layout == DataLayout::kNCHW ? vol.dims()[1] : vol.dims()[0]);
int input_height =
(data_layout == DataLayout::kNCHW ? vol.dims()[2] : vol.dims()[1]);
int input_width =
(data_layout == DataLayout::kNCHW ? vol.dims()[3] : vol.dims()[2]);
int filter_depth = col->dims()[1];
int filter_height = col->dims()[2];
int filter_width = col->dims()[3];
@ -59,6 +64,7 @@ class Vol2ColFunctor<platform::CPUDeviceContext, T> {
int pad_h_down = paddings_size_is_6 ? paddings[3] : paddings[1];
int pad_w_left = paddings_size_is_6 ? paddings[4] : paddings[2];
int pad_w_right = paddings_size_is_6 ? paddings[5] : paddings[2];
PADDLE_ENFORCE_EQ((input_depth + pad_d_forth + pad_d_back -
((dilations[0] * (filter_depth - 1) + 1))) /
strides[0] +
@ -97,10 +103,16 @@ class Vol2ColFunctor<platform::CPUDeviceContext, T> {
int col_idx =
((c * output_depth + d) * output_height + h) * output_width + w;
int vol_idx =
((c_in * input_depth + d_pad) * input_height + h_pad) *
input_width +
w_pad;
int vol_idx;
if (data_layout == DataLayout::kNCHW) {
vol_idx = ((c_in * input_depth + d_pad) * input_height + h_pad) *
input_width +
w_pad;
} else {
vol_idx = ((d_pad * input_height + h_pad) * input_width + w_pad) *
input_channels +
c_in;
}
col_data[col_idx] =
(h_pad < 0 || h_pad >= input_height || w_pad < 0 ||
w_pad >= input_width || d_pad < 0 || d_pad >= input_depth)
@ -126,16 +138,21 @@ class Col2VolFunctor<platform::CPUDeviceContext, T> {
const framework::Tensor& col,
const std::vector<int>& dilations,
const std::vector<int>& strides,
const std::vector<int>& paddings,
framework::Tensor* vol) const {
const std::vector<int>& paddings, framework::Tensor* vol,
const DataLayout data_layout) const {
PADDLE_ENFORCE_EQ(vol->dims().size(), 4,
"The dimension of vol should be 4.");
PADDLE_ENFORCE_EQ(col.dims().size(), 7,
"The dimension of col should be 7.");
int input_channels = vol->dims()[0];
int input_depth = vol->dims()[1];
int input_height = vol->dims()[2];
int input_width = vol->dims()[3];
int input_channels =
(data_layout == DataLayout::kNCHW ? vol->dims()[0] : vol->dims()[3]);
int input_depth =
(data_layout == DataLayout::kNCHW ? vol->dims()[1] : vol->dims()[0]);
int input_height =
(data_layout == DataLayout::kNCHW ? vol->dims()[2] : vol->dims()[1]);
int input_width =
(data_layout == DataLayout::kNCHW ? vol->dims()[3] : vol->dims()[2]);
int filter_depth = col.dims()[1];
int filter_height = col.dims()[2];
int filter_width = col.dims()[3];
@ -191,11 +208,17 @@ class Col2VolFunctor<platform::CPUDeviceContext, T> {
if (h_pad >= 0 && h_pad < input_height && w_pad >= 0 &&
w_pad < input_width && d_pad >= 0 && d_pad < input_depth) {
int vol_idx =
((cIm * input_depth + d_pad) * input_height + h_pad) *
input_width +
w_pad;
int vol_idx;
if (data_layout == DataLayout::kNCHW) {
vol_idx = ((cIm * input_depth + d_pad) * input_height + h_pad) *
input_width +
w_pad;
} else {
vol_idx =
((d_pad * input_height + h_pad) * input_width + w_pad) *
input_channels +
cIm;
}
int col_idx =
((c * output_depth + d) * output_height + h) * output_width +
w;

@ -28,7 +28,12 @@ __global__ void vol2col(int num_kernels, const T* data_vol, int depth,
int filter_width, int stride_depth, int stride_height,
int stride_width, int padding_depth, int padding_height,
int padding_width, int output_detph, int output_height,
int output_width, T* data_col) {
int output_width, T* data_col,
const DataLayout data_layout) {
int input_channels =
num_kernels / output_detph / output_height / output_width;
int channels_col =
input_channels * filter_depth * filter_height * filter_width;
for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < num_kernels;
index += blockDim.x * gridDim.x) {
int w_out = index % output_width;
@ -43,18 +48,22 @@ __global__ void vol2col(int num_kernels, const T* data_vol, int depth,
data_col += ((channel_out * output_detph + d_out) * output_height + h_out) *
output_width +
w_out;
data_vol += ((channel_in * depth + d_in) * height + h_in) * width + w_in;
for (int k = 0; k < filter_depth; ++k) {
for (int i = 0; i < filter_height; ++i) {
for (int j = 0; j < filter_width; ++j) {
int d = d_in + k * dilation_d;
int h = h_in + i * dilation_h;
int w = w_in + j * dilation_w;
int col_idx = (k * dilation_d * height + i * dilation_h) * width +
j * dilation_w;
int vol_idx;
if (data_layout == DataLayout::kNCHW) {
vol_idx = ((channel_in * depth + d) * height + h) * width + w;
} else {
vol_idx =
((d * height + h) * width + w) * input_channels + channel_in;
}
*data_col = (d >= 0 && d < depth && h >= 0 && h < height && w >= 0 &&
w < width)
? data_vol[col_idx]
? data_vol[vol_idx]
: 0;
data_col += output_detph * output_height * output_width;
}
@ -64,7 +73,10 @@ __global__ void vol2col(int num_kernels, const T* data_vol, int depth,
}
/*
* im = [input_channels,intpu_depth, input_height, input_width]
* im = [input_channels,intpu_depth, input_height, input_width] for
* channels_first
* im = [input_depth, input_height, input_width, input_channels] for
* channels_last
* col =
* [input_channels, filter_depth, filter_height, filter_width,
* output_depth, output_height, output_width]
@ -76,15 +88,21 @@ class Vol2ColFunctor<platform::CUDADeviceContext, T> {
const framework::Tensor& vol,
const std::vector<int>& dilations,
const std::vector<int>& strides,
const std::vector<int>& paddings,
framework::Tensor* col) const {
PADDLE_ENFORCE_EQ(vol.dims().size(), 4);
PADDLE_ENFORCE_EQ(col->dims().size(), 7);
const std::vector<int>& paddings, framework::Tensor* col,
const DataLayout data_layout) const {
PADDLE_ENFORCE_EQ(vol.dims().size(), 4,
"The dimension of vol should be 4.");
PADDLE_ENFORCE_EQ(col->dims().size(), 7,
"The dimension of col should be 7.");
int input_channels = vol.dims()[0];
int input_depth = vol.dims()[1];
int input_height = vol.dims()[2];
int input_width = vol.dims()[3];
int input_channels =
(data_layout == DataLayout::kNCHW ? vol.dims()[0] : vol.dims()[3]);
int input_depth =
(data_layout == DataLayout::kNCHW ? vol.dims()[1] : vol.dims()[0]);
int input_height =
(data_layout == DataLayout::kNCHW ? vol.dims()[2] : vol.dims()[1]);
int input_width =
(data_layout == DataLayout::kNCHW ? vol.dims()[3] : vol.dims()[2]);
int filter_depth = col->dims()[1];
int filter_height = col->dims()[2];
int filter_width = col->dims()[3];
@ -130,7 +148,8 @@ class Vol2ColFunctor<platform::CUDADeviceContext, T> {
num_outputs, vol.data<T>(), input_depth, input_height, input_width,
dilations[0], dilations[1], dilations[2], filter_depth, filter_height,
filter_width, strides[0], strides[1], strides[2], pad_d_forth, pad_h_up,
pad_w_left, output_depth, output_height, output_width, col->data<T>());
pad_w_left, output_depth, output_height, output_width, col->data<T>(),
data_layout);
}
};
@ -141,18 +160,27 @@ __global__ void col2vol(int num_kernels, const T* data_col, int depth,
int filter_width, int stride_depth, int stride_height,
int stride_width, int padding_depth, int padding_height,
int padding_width, int output_detph, int output_height,
int output_width, T* data_vol) {
int output_width, T* data_vol,
const DataLayout data_layout) {
const int d_filter_depth = dilation_d * (filter_depth - 1) + 1;
const int d_filter_height = dilation_h * (filter_height - 1) + 1;
const int d_filter_width = dilation_w * (filter_width - 1) + 1;
int input_channels = num_kernels / depth / height / width;
for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < num_kernels;
index += blockDim.x * gridDim.x) {
T src_val = 0;
int w = index % width + padding_width;
int h = (index / width) % height + padding_height;
int d = (index / width / height) % depth + padding_depth;
int c = index / width / height / depth;
int w = (data_layout == DataLayout::kNCHW
? index % width + padding_width
: (index / input_channels) % width + padding_width);
int h = (data_layout == DataLayout::kNCHW
? (index / width) % height + padding_height
: (index / input_channels / width) % height + padding_height);
int d = (data_layout == DataLayout::kNCHW
? (index / width / height) % depth + padding_depth
: index / input_channels / width / height + padding_depth);
int c = (data_layout == DataLayout::kNCHW ? index / width / height / depth
: index % input_channels);
// compute the start and end of the output
int w_col_start =
@ -196,7 +224,10 @@ __global__ void col2vol(int num_kernels, const T* data_col, int depth,
}
/*
* im = [input_channels, input_depth, input_height, input_width]
* im = [input_channels,intpu_depth, input_height, input_width] for
* channels_first
* im = [input_depth, input_height, input_width, input_channels] for
* channels_last
* col =
* [input_channels, filter_depth, filter_height, filter_width,
* output_depth, output_height, output_width]
@ -208,15 +239,21 @@ class Col2VolFunctor<platform::CUDADeviceContext, T> {
const framework::Tensor& col,
const std::vector<int>& dilations,
const std::vector<int>& strides,
const std::vector<int>& paddings,
framework::Tensor* vol) const {
PADDLE_ENFORCE_EQ(vol->dims().size(), 4);
PADDLE_ENFORCE_EQ(col.dims().size(), 7);
const std::vector<int>& paddings, framework::Tensor* vol,
const DataLayout data_layout) const {
PADDLE_ENFORCE_EQ(vol->dims().size(), 4,
"The dimension of vol should be 4.");
PADDLE_ENFORCE_EQ(col.dims().size(), 7,
"The dimension of col should be 7.");
int input_channels = vol->dims()[0];
int input_depth = vol->dims()[1];
int input_height = vol->dims()[2];
int input_width = vol->dims()[3];
int input_channels =
(data_layout == DataLayout::kNCHW ? vol->dims()[0] : vol->dims()[3]);
int input_depth =
(data_layout == DataLayout::kNCHW ? vol->dims()[1] : vol->dims()[0]);
int input_height =
(data_layout == DataLayout::kNCHW ? vol->dims()[2] : vol->dims()[1]);
int input_width =
(data_layout == DataLayout::kNCHW ? vol->dims()[3] : vol->dims()[2]);
int filter_depth = col.dims()[1];
int filter_height = col.dims()[2];
int filter_width = col.dims()[3];
@ -263,7 +300,8 @@ class Col2VolFunctor<platform::CUDADeviceContext, T> {
num_kernels, col.data<T>(), input_depth, input_height, input_width,
dilations[0], dilations[1], dilations[2], filter_depth, filter_height,
filter_width, strides[0], strides[1], strides[2], pad_d_forth, pad_h_up,
pad_w_left, output_depth, output_height, output_width, vol->data<T>());
pad_w_left, output_depth, output_height, output_width, vol->data<T>(),
data_layout);
}
};

@ -22,6 +22,9 @@ limitations under the License. */
namespace paddle {
namespace operators {
namespace math {
using DataLayout = framework::DataLayout;
/*
* \brief Converts the feature data of four dimensions(CDHW) into a colData of
* seven dimensions in the Vol2ColFunctor calculation,
@ -70,8 +73,8 @@ class Vol2ColFunctor {
void operator()(const DeviceContext& context, const framework::Tensor& vol,
const std::vector<int>& dilations,
const std::vector<int>& strides,
const std::vector<int>& paddings,
framework::Tensor* col) const;
const std::vector<int>& paddings, framework::Tensor* col,
const DataLayout data_layout = DataLayout::kNCHW) const;
};
template <typename DeviceContext, typename T>
@ -80,8 +83,8 @@ class Col2VolFunctor {
void operator()(const DeviceContext& context, const framework::Tensor& col,
const std::vector<int>& dilations,
const std::vector<int>& strides,
const std::vector<int>& paddings,
framework::Tensor* vol) const;
const std::vector<int>& paddings, framework::Tensor* vol,
const DataLayout data_layout = DataLayout::kNCHW) const;
};
} // namespace math

File diff suppressed because it is too large Load Diff
Loading…
Cancel
Save