|
|
|
@ -63,6 +63,7 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> {
|
|
|
|
|
|
|
|
|
|
std::vector<int> strides = context.Attr<std::vector<int>>("strides");
|
|
|
|
|
std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
|
|
|
|
|
std::vector<int> dilations = context.Attr<std::vector<int>>("dilations");
|
|
|
|
|
// groups will alway be disabled in conv2dtranspose.
|
|
|
|
|
|
|
|
|
|
const int batch_size = static_cast<int>(input->dims()[0]);
|
|
|
|
@ -114,7 +115,6 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> {
|
|
|
|
|
|
|
|
|
|
math::Col2ImFunctor<math::ColFormat::kCFO, Place, T> col2im;
|
|
|
|
|
math::Col2VolFunctor<Place, T> col2vol;
|
|
|
|
|
std::vector<int> dilations({1, 1, 1});
|
|
|
|
|
|
|
|
|
|
// convolution transpose: gemm + col2im or col2vol (similar to conv-backward
|
|
|
|
|
// on input)
|
|
|
|
@ -134,8 +134,7 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> {
|
|
|
|
|
if (data_dim == 2U) {
|
|
|
|
|
// col2im: col_matrix -> dy
|
|
|
|
|
// from (c * k_h * k_w, h * w) to (c, o_h, o_w)
|
|
|
|
|
col2im(context.device_context(), col,
|
|
|
|
|
std::vector<int>{dilations[0], dilations[1]}, strides,
|
|
|
|
|
col2im(context.device_context(), col, dilations, strides,
|
|
|
|
|
std::vector<int>{paddings[0], paddings[1], paddings[0],
|
|
|
|
|
paddings[1]},
|
|
|
|
|
&output_batch);
|
|
|
|
@ -168,6 +167,7 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
|
|
|
|
|
|
|
|
|
|
std::vector<int> strides = context.Attr<std::vector<int>>("strides");
|
|
|
|
|
std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
|
|
|
|
|
std::vector<int> dilations = context.Attr<std::vector<int>>("dilations");
|
|
|
|
|
|
|
|
|
|
const int batch_size = static_cast<int>(input->dims()[0]);
|
|
|
|
|
|
|
|
|
@ -221,7 +221,6 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
|
|
|
|
|
|
|
|
|
|
math::Im2ColFunctor<math::ColFormat::kCFO, Place, T> im2col;
|
|
|
|
|
math::Vol2ColFunctor<Place, T> vol2col;
|
|
|
|
|
std::vector<int> dilations({1, 1, 1});
|
|
|
|
|
|
|
|
|
|
if (input_grad) {
|
|
|
|
|
input_grad->mutable_data<T>(context.GetPlace());
|
|
|
|
@ -242,10 +241,9 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
|
|
|
|
|
if (data_dim == 2U) {
|
|
|
|
|
// im2col: dy -> col matrix
|
|
|
|
|
// from (c, o_h, o_w) to (c * k_h * k_w, h * w)
|
|
|
|
|
im2col(context.device_context(), output_grad_batch,
|
|
|
|
|
std::vector<int>{dilations[0], dilations[1]}, strides,
|
|
|
|
|
std::vector<int>{paddings[0], paddings[1], paddings[0],
|
|
|
|
|
paddings[1]},
|
|
|
|
|
im2col(context.device_context(), output_grad_batch, dilations,
|
|
|
|
|
strides, std::vector<int>{paddings[0], paddings[1],
|
|
|
|
|
paddings[0], paddings[1]},
|
|
|
|
|
&col);
|
|
|
|
|
} else if (data_dim == 3U) {
|
|
|
|
|
// vol2col: dy -> col_matrix
|
|
|
|
|