|
|
|
@ -54,6 +54,7 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
|
|
|
|
|
if (stride[0] == 1 && stride[1] == 1 && dilation[0] == 1 &&
|
|
|
|
|
dilation[1] == 1 && padding[0] == 0 && padding[1] == 0) {
|
|
|
|
|
int col_matrix_width = output_width * output_height;
|
|
|
|
|
size_t copy_size = sizeof(T) * output_width;
|
|
|
|
|
for (int oh = 0; oh < output_height; ++oh) {
|
|
|
|
|
const T* im_data_start = im_data + oh * im_width;
|
|
|
|
|
T* dst_data = col_data + oh * output_width;
|
|
|
|
@ -61,7 +62,7 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
|
|
|
|
|
const T* src_data = im_data_start + ic * im_height * im_width;
|
|
|
|
|
for (int kh = 0; kh < filter_height; ++kh) {
|
|
|
|
|
for (int kw = 0; kw < filter_width; ++kw) {
|
|
|
|
|
std::memcpy(dst_data, src_data + kw, sizeof(T) * output_width);
|
|
|
|
|
std::memcpy(dst_data, src_data + kw, copy_size);
|
|
|
|
|
dst_data = dst_data + col_matrix_width;
|
|
|
|
|
}
|
|
|
|
|
src_data = src_data + im_width;
|
|
|
|
|