follow comments

revert-4814-Add_sequence_project_op
chengduoZH 7 years ago
parent dc7d07358c
commit 2947f5678e

@ -42,14 +42,20 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
int output_height = col.dims()[3]; int output_height = col.dims()[3];
int output_width = col.dims()[4]; int output_width = col.dims()[4];
PADDLE_ENFORCE((input_height + padding_up + padding_down - filter_height) / PADDLE_ENFORCE_EQ(
stride_height + (input_height + padding_up + padding_down - filter_height) /
1 == stride_height +
output_height); 1,
PADDLE_ENFORCE((input_width + padding_left + padding_right - filter_width) / output_height,
stride_width + "Output_height and padding(padding_up, padding_down) are "
1 == "inconsistent.");
output_width); PADDLE_ENFORCE_EQ(
(input_width + padding_left + padding_right - filter_width) /
stride_width +
1,
output_width,
"output_width and padding(padding_left, padding_right) are "
"inconsistent.");
int channels_col = input_channels * filter_height * filter_width; int channels_col = input_channels * filter_height * filter_width;
@ -62,16 +68,14 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
int c_im = c / filter_width / filter_height; int c_im = c / filter_width / filter_height;
for (int h = 0; h < output_height; ++h) { for (int h = 0; h < output_height; ++h) {
for (int w = 0; w < output_width; ++w) { for (int w = 0; w < output_width; ++w) {
int im_row_idx = h * stride_height + h_offset; int im_row_idx = h * stride_height + h_offset - padding_up;
int im_col_idx = w * stride_width + w_offset; int im_col_idx = w * stride_width + w_offset - padding_left;
if ((im_row_idx - padding_up) < 0 ||
(im_row_idx - padding_up) >= input_height || if (im_row_idx < 0 || im_row_idx >= input_height || im_col_idx < 0 ||
(im_col_idx - padding_left) < 0 || im_col_idx >= input_width) {
(im_col_idx - padding_left) >= input_width) {
col_data[(c * output_height + h) * output_width + w] = T(0); col_data[(c * output_height + h) * output_width + w] = T(0);
} else { } else {
im_row_idx += c_im * input_height - padding_up; im_row_idx += c_im * input_height;
im_col_idx -= padding_left;
col_data[(c * output_height + h) * output_width + w] = col_data[(c * output_height + h) * output_width + w] =
im_data[im_row_idx * input_width + im_col_idx]; im_data[im_row_idx * input_width + im_col_idx];
} }
@ -104,14 +108,20 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
int output_height = col.dims()[3]; int output_height = col.dims()[3];
int output_width = col.dims()[4]; int output_width = col.dims()[4];
PADDLE_ENFORCE((input_height + padding_up + padding_down - filter_height) / PADDLE_ENFORCE_EQ(
stride_height + (input_height + padding_up + padding_down - filter_height) /
1 == stride_height +
output_height); 1,
PADDLE_ENFORCE((input_width + padding_left + padding_right - filter_width) / output_height,
stride_width + "Output_height and padding(padding_up, padding_down) are "
1 == "inconsistent.");
output_width); PADDLE_ENFORCE_EQ(
(input_width + padding_left + padding_right - filter_width) /
stride_width +
1,
output_width,
"output_width and padding(padding_left, padding_right) are "
"inconsistent.");
int channels_col = input_channels * filter_height * filter_width; int channels_col = input_channels * filter_height * filter_width;
@ -124,14 +134,12 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
int c_im = c / filter_width / filter_height; int c_im = c / filter_width / filter_height;
for (int h = 0; h < output_height; ++h) { for (int h = 0; h < output_height; ++h) {
for (int w = 0; w < output_width; ++w) { for (int w = 0; w < output_width; ++w) {
int im_row_idx = h * stride_height + h_offset; int im_row_idx = h * stride_height + h_offset - padding_up;
int im_col_idx = w * stride_width + w_offset; int im_col_idx = w * stride_width + w_offset - padding_left;
if ((im_row_idx - padding_up) >= 0 &&
(im_row_idx - padding_up) < input_height && if ((im_row_idx) >= 0 && (im_row_idx) < input_height &&
(im_col_idx - padding_left) >= 0 && (im_col_idx) >= 0 && (im_col_idx) < input_width) {
(im_col_idx - padding_left) < input_width) { im_row_idx += c_im * input_height;
im_row_idx += c_im * input_height - padding_up;
im_col_idx -= padding_left;
im_data[im_row_idx * input_width + im_col_idx] += im_data[im_row_idx * input_width + im_col_idx] +=
col_data[(c * output_height + h) * output_width + w]; col_data[(c * output_height + h) * output_width + w];
} }
@ -173,14 +181,20 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
int output_height = col.dims()[0]; int output_height = col.dims()[0];
int output_width = col.dims()[1]; int output_width = col.dims()[1];
PADDLE_ENFORCE((input_height + padding_up + padding_down - filter_height) / PADDLE_ENFORCE_EQ(
stride_height + (input_height + padding_up + padding_down - filter_height) /
1 == stride_height +
output_height); 1,
PADDLE_ENFORCE((input_width + padding_left + padding_right - filter_width) / output_height,
stride_width + "Output_height and padding(padding_up, padding_down) are "
1 == "inconsistent.");
output_width); PADDLE_ENFORCE_EQ(
(input_width + padding_left + padding_right - filter_width) /
stride_width +
1,
output_width,
"output_width and padding(padding_left, padding_right) are "
"inconsistent.");
const T* im_data = im.data<T>(); const T* im_data = im.data<T>();
T* col_data = col.data<T>(); T* col_data = col.data<T>();
@ -243,14 +257,20 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
int output_height = col.dims()[0]; int output_height = col.dims()[0];
int output_width = col.dims()[1]; int output_width = col.dims()[1];
PADDLE_ENFORCE((input_height + padding_up + padding_down - filter_height) / PADDLE_ENFORCE_EQ(
stride_height + (input_height + padding_up + padding_down - filter_height) /
1 == stride_height +
output_height); 1,
PADDLE_ENFORCE((input_width + padding_left + padding_right - filter_width) / output_height,
stride_width + "Output_height and padding(padding_up, padding_down) are "
1 == "inconsistent.");
output_width); PADDLE_ENFORCE_EQ(
(input_width + padding_left + padding_right - filter_width) /
stride_width +
1,
output_width,
"output_width and padding(padding_left, padding_right) are "
"inconsistent.");
T* im_data = im.data<T>(); T* im_data = im.data<T>();
const T* col_data = col.data<T>(); const T* col_data = col.data<T>();

Loading…
Cancel
Save