|
|
|
@ -27,9 +27,12 @@ using Tensor = framework::Tensor;
|
|
|
|
|
|
|
|
|
|
// Base convolution operator definations for other conv
|
|
|
|
|
// like operators to reuse the implementation.
|
|
|
|
|
inline int OutputSize(int input_size, int filter_size, int padding,
|
|
|
|
|
int stride) {
|
|
|
|
|
int output_size = (input_size - filter_size + 2 * padding) / stride + 1;
|
|
|
|
|
inline int OutputSize(int input_size, int filter_size, int dilation,
|
|
|
|
|
int padding_up, int padding_down, int stride) {
|
|
|
|
|
int output_size = (input_size + padding_up + padding_down -
|
|
|
|
|
(dilation * (filter_size - 1) + 1)) /
|
|
|
|
|
stride +
|
|
|
|
|
1;
|
|
|
|
|
return output_size;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -76,6 +79,7 @@ class GemmConvKernel : public framework::OpKernel<T> {
|
|
|
|
|
std::vector<int> strides = context.Attr<std::vector<int>>("strides");
|
|
|
|
|
std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
|
|
|
|
|
int groups = context.Attr<int>("groups");
|
|
|
|
|
std::vector<int> dilations = context.Attr<std::vector<int>>("dilations");
|
|
|
|
|
|
|
|
|
|
const int batch_size = static_cast<int>(input->dims()[0]);
|
|
|
|
|
|
|
|
|
@ -139,9 +143,9 @@ class GemmConvKernel : public framework::OpKernel<T> {
|
|
|
|
|
if (filter_shape_vec.size() == 2) {
|
|
|
|
|
// im2col
|
|
|
|
|
math::Im2ColFunctor<math::ColFormat::kCFO, Place, T> im2col;
|
|
|
|
|
im2col(context.device_context(), in_slice, col, strides[0],
|
|
|
|
|
strides[1], paddings[0], paddings[0], paddings[1],
|
|
|
|
|
paddings[1]);
|
|
|
|
|
im2col(context.device_context(), in_slice, col, dilations[0],
|
|
|
|
|
dilations[1], strides[0], strides[1], paddings[0], paddings[0],
|
|
|
|
|
paddings[1], paddings[1]);
|
|
|
|
|
} else if (filter_shape_vec.size() == 3) {
|
|
|
|
|
// vol2col
|
|
|
|
|
math::Vol2ColFunctor<Place, T> vol2col;
|
|
|
|
@ -181,6 +185,7 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
|
|
|
|
|
std::vector<int> strides = context.Attr<std::vector<int>>("strides");
|
|
|
|
|
std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
|
|
|
|
|
int groups = context.Attr<int>("groups");
|
|
|
|
|
std::vector<int> dilations = context.Attr<std::vector<int>>("dilations");
|
|
|
|
|
|
|
|
|
|
const int batch_size = static_cast<int>(input->dims()[0]);
|
|
|
|
|
|
|
|
|
@ -263,9 +268,9 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
|
|
|
|
|
|
|
|
|
|
if (filter_shape_vec.size() == 2) {
|
|
|
|
|
math::Col2ImFunctor<math::ColFormat::kCFO, Place, T> col2im;
|
|
|
|
|
col2im(context.device_context(), in_grad_slice, col, strides[0],
|
|
|
|
|
strides[1], paddings[0], paddings[0], paddings[1],
|
|
|
|
|
paddings[1]);
|
|
|
|
|
col2im(context.device_context(), in_grad_slice, col, dilations[0],
|
|
|
|
|
dilations[1], strides[0], strides[1], paddings[0],
|
|
|
|
|
paddings[0], paddings[1], paddings[1]);
|
|
|
|
|
|
|
|
|
|
} else if (filter_shape_vec.size() == 3) {
|
|
|
|
|
math::Col2VolFunctor<Place, T> col2vol;
|
|
|
|
@ -295,9 +300,9 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
|
|
|
|
|
|
|
|
|
|
if (filter_shape_vec.size() == 2) {
|
|
|
|
|
math::Im2ColFunctor<math::ColFormat::kCFO, Place, T> im2col;
|
|
|
|
|
im2col(context.device_context(), in_slice, col, strides[0],
|
|
|
|
|
strides[1], paddings[0], paddings[0], paddings[1],
|
|
|
|
|
paddings[1]);
|
|
|
|
|
im2col(context.device_context(), in_slice, col, dilations[0],
|
|
|
|
|
dilations[1], strides[0], strides[1], paddings[0],
|
|
|
|
|
paddings[0], paddings[1], paddings[1]);
|
|
|
|
|
} else if (filter_shape_vec.size() == 3) {
|
|
|
|
|
math::Vol2ColFunctor<Place, T> vol2col;
|
|
|
|
|
vol2col(context.device_context(), in_slice, col, strides[0],
|
|
|
|
|