|
|
|
@ -126,11 +126,9 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
|
|
|
|
|
// padding_size > 1
|
|
|
|
|
for (int ic = 0; ic < im_channels; ++ic) {
|
|
|
|
|
// TODO(TJ): use add and resue stride
|
|
|
|
|
T* dst_data_ic =
|
|
|
|
|
col_data + ic * filter_width * filter_height * col_matrix_width;
|
|
|
|
|
T* dst_data_ic = col_data + ic * col_block_ic;
|
|
|
|
|
for (int kh = 0; kh < filter_height; ++kh) {
|
|
|
|
|
T* dst_data_kh =
|
|
|
|
|
dst_data_ic + kh * filter_width * col_matrix_width;
|
|
|
|
|
T* dst_data_kh = dst_data_ic + kh * col_block_fh;
|
|
|
|
|
for (int kw = 0; kw < plw; ++kw) {
|
|
|
|
|
// TODO(TJ): reuse array outside this for
|
|
|
|
|
size_t sz = sizeof(T) * (plw - kw);
|
|
|
|
@ -158,6 +156,67 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// fill im_data
|
|
|
|
|
// padding cover two cases:
|
|
|
|
|
// 1. kw > 2*pw: kw = 3, pw = 1
|
|
|
|
|
// 0 x x x x ... x x x x 0
|
|
|
|
|
// 1 1 1 1 1 1
|
|
|
|
|
// ==>
|
|
|
|
|
// 0 x ... x x
|
|
|
|
|
// x x ... x x
|
|
|
|
|
// x x ... x 0
|
|
|
|
|
// 2. kw < 2*pw: kw = 3, pw = 2
|
|
|
|
|
// 0 0 x x x ... x x x 0 0
|
|
|
|
|
// 1 1 1 1 1 1
|
|
|
|
|
// ==>
|
|
|
|
|
// 0 0 x ... x x x
|
|
|
|
|
// 0 x x ... x x 0
|
|
|
|
|
// x x x ... x 0 0
|
|
|
|
|
|
|
|
|
|
// TODO(TJ): use array like: size_t copy_size[kw]={sizeof(T) *
|
|
|
|
|
// (output_width-1)}
|
|
|
|
|
// length of copy_size is equal kw.
|
|
|
|
|
if (plw + prw < filter_width) {
|
|
|
|
|
for (int oh = 0; oh < output_height; ++oh) {
|
|
|
|
|
const T* im_data_start =
|
|
|
|
|
im_data + (oh - plh > 0 ? oh - plh : 0) * im_width;
|
|
|
|
|
T* dst_data = col_data + oh * output_width;
|
|
|
|
|
for (int ic = 0; ic < im_channels; ++ic) {
|
|
|
|
|
const T* src_data = im_data_start + ic * im_size;
|
|
|
|
|
for (int kh = 0; kh < filter_height; ++kh) {
|
|
|
|
|
if ((oh < plh && kh < plh) ||
|
|
|
|
|
(oh > (output_height - prh - 1) &&
|
|
|
|
|
kh > (filter_height - prh - 1))) {
|
|
|
|
|
dst_data = dst_data + filter_width * col_matrix_width;
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
// TODO(TJ): reuse plw-kw outside this for
|
|
|
|
|
// try to unify
|
|
|
|
|
for (int kw = 0; kw < plw; ++kw) {
|
|
|
|
|
std::memcpy(dst_data + (plw - kw), src_data,
|
|
|
|
|
sizeof(T) * (output_width - (plw - kw)));
|
|
|
|
|
dst_data = dst_data + col_matrix_width;
|
|
|
|
|
}
|
|
|
|
|
for (int kw = plw; kw < filter_width - prw; ++kw) {
|
|
|
|
|
std::memcpy(dst_data, src_data + (kw - plw),
|
|
|
|
|
sizeof(T) * output_width);
|
|
|
|
|
dst_data = dst_data + col_matrix_width;
|
|
|
|
|
}
|
|
|
|
|
int i = 1;
|
|
|
|
|
for (int kw = filter_width - prw; kw < filter_width;
|
|
|
|
|
++kw, ++i) {
|
|
|
|
|
std::memcpy(dst_data, src_data + (kw - plw),
|
|
|
|
|
sizeof(T) * (output_width - i));
|
|
|
|
|
dst_data = dst_data + col_matrix_width;
|
|
|
|
|
}
|
|
|
|
|
src_data = src_data + im_width;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
LOG(FATAL) << "Not implement yet";
|
|
|
|
|
}
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|