|
|
|
@ -34,16 +34,15 @@ class Vol2ColFunctor<platform::CPUDeviceContext, T> {
|
|
|
|
|
const std::vector<int>& strides,
|
|
|
|
|
const std::vector<int>& paddings, framework::Tensor* col,
|
|
|
|
|
const DataLayout data_layout) const {
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
vol.dims().size(), 4,
|
|
|
|
|
platform::errors::InvalidArgument("The dimension of"
|
|
|
|
|
" vol should be 4, but received %d.",
|
|
|
|
|
vol.dims().size()));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
col->dims().size(), 7,
|
|
|
|
|
platform::errors::InvalidArgument("The dimension of"
|
|
|
|
|
"col should be 7, but received %d.",
|
|
|
|
|
col->dims().size()));
|
|
|
|
|
PADDLE_ENFORCE_EQ(vol.dims().size(), 4,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The dimension of vol should be 4, but received %d.",
|
|
|
|
|
vol.dims().size()));
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(col->dims().size(), 7,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The dimension of col should be 7, but received %d.",
|
|
|
|
|
col->dims().size()));
|
|
|
|
|
|
|
|
|
|
int input_channels =
|
|
|
|
|
(data_layout != DataLayout::kNHWC ? vol.dims()[0] : vol.dims()[3]);
|
|
|
|
@ -152,16 +151,15 @@ class Col2VolFunctor<platform::CPUDeviceContext, T> {
|
|
|
|
|
const std::vector<int>& strides,
|
|
|
|
|
const std::vector<int>& paddings, framework::Tensor* vol,
|
|
|
|
|
const DataLayout data_layout) const {
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
vol->dims().size(), 4,
|
|
|
|
|
platform::errors::InvalidArgument("The dimension of vol"
|
|
|
|
|
" should be 4, but received %d.",
|
|
|
|
|
vol->dims().size()));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
col.dims().size(), 7,
|
|
|
|
|
platform::errors::InvalidArgument("The dimension of col"
|
|
|
|
|
" should be 7, but received %d.",
|
|
|
|
|
col.dims().size()));
|
|
|
|
|
PADDLE_ENFORCE_EQ(vol->dims().size(), 4,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The dimension of vol should be 4, but received %d.",
|
|
|
|
|
vol->dims().size()));
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(col.dims().size(), 7,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The dimension of col should be 7, but received %d.",
|
|
|
|
|
col.dims().size()));
|
|
|
|
|
|
|
|
|
|
int input_channels =
|
|
|
|
|
(data_layout != DataLayout::kNHWC ? vol->dims()[0] : vol->dims()[3]);
|
|
|
|
@ -192,29 +190,29 @@ class Col2VolFunctor<platform::CPUDeviceContext, T> {
|
|
|
|
|
((dilations[0] * (filter_depth - 1) + 1))) /
|
|
|
|
|
strides[0] +
|
|
|
|
|
1;
|
|
|
|
|
PADDLE_ENFORCE_EQ(input_depth_tmp, output_depth,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"input_depth(%d)"
|
|
|
|
|
" and output_depth(%d) are mismatching.",
|
|
|
|
|
input_depth_tmp, output_depth));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
input_depth_tmp, output_depth,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"input_depth(%d) and output_depth(%d) are mismatching.",
|
|
|
|
|
input_depth_tmp, output_depth));
|
|
|
|
|
auto input_height_tmp = (input_height + pad_h_up + pad_h_down -
|
|
|
|
|
((dilations[1] * (filter_height - 1) + 1))) /
|
|
|
|
|
strides[1] +
|
|
|
|
|
1;
|
|
|
|
|
PADDLE_ENFORCE_EQ(input_height_tmp, output_height,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"input_height(%d)"
|
|
|
|
|
" and output_height(%d) are mismatching.",
|
|
|
|
|
input_height_tmp, output_height));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
input_height_tmp, output_height,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"input_height(%d) and output_height(%d) are mismatching.",
|
|
|
|
|
input_height_tmp, output_height));
|
|
|
|
|
auto input_width_tmp = (input_width + pad_w_left + pad_w_right -
|
|
|
|
|
((dilations[2] * (filter_width - 1) + 1))) /
|
|
|
|
|
strides[2] +
|
|
|
|
|
1;
|
|
|
|
|
PADDLE_ENFORCE_EQ(input_width_tmp, output_width,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"input_width(%d)"
|
|
|
|
|
" and output_width(%d) are mismatching.",
|
|
|
|
|
input_width_tmp, output_width));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
input_width_tmp, output_width,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"input_width(%d) and output_width(%d) are mismatching.",
|
|
|
|
|
input_width_tmp, output_width));
|
|
|
|
|
T* vol_data = vol->data<T>();
|
|
|
|
|
const T* col_data = col.data<T>();
|
|
|
|
|
|
|
|
|
|