@@ -14,7 +14,6 @@ limitations under the License. */
 
 #pragma once
 
-#include "glog/logging.h"
 #include "paddle/framework/eigen.h"
 #include "paddle/framework/op_registry.h"
 #include "paddle/operators/math/im2col.h"
@@ -62,7 +61,8 @@ class GemmConv2DTransposeKernel : public framework::OpKernel<T> {
 
     std::vector<int> strides = context.Attr<std::vector<int>>("strides");
 
-    // no paddings and groups allowed in deconv
+    // TODO(Zhuoyuan): Paddings can be added in future.
+    // groups will alway be disabled in conv2dtranspose.
 
     const int batch_size = input->dims()[0];
     const int m = input->dims()[1];
@@ -91,7 +91,8 @@ class GemmConv2DTransposeKernel : public framework::OpKernel<T> {
     // col_matrix shares the same piece of data with col,
     // but will be reshaped into a two-dimensional matrix shape
     // to call the matrix multiplication interface.
-    Tensor col_matrix = col;
+    Tensor col_matrix;
+    col_matrix.ShareDataWith(col);
     col_matrix.Resize(col_matrix_shape);
 
     DDim output_shape = {c, o_h, o_w};
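
Note on the change above: assigning col directly to col_matrix relied on Tensor's copy semantics to share the underlying storage; the new lines make the aliasing explicit with ShareDataWith and then reshape only the logical view. A minimal sketch of the intended effect, using only Tensor calls that already appear in this file:

    Tensor col_matrix;                    // empty tensor, owns no memory yet
    col_matrix.ShareDataWith(col);        // alias col's buffer, no data copy
    col_matrix.Resize(col_matrix_shape);  // reinterpret the same data as the
                                          // 2-D matrix used by the matrix
                                          // multiplication interface
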
@@ -100,7 +101,7 @@ class GemmConv2DTransposeKernel : public framework::OpKernel<T> {
     DDim filter_matrix_shape = {m, c * k_h * k_w};
     filter.Resize(filter_matrix_shape);
 
-    // deconvolution: gemm + col2im (similar to conv-backward on input)
+    // convolution transpose: gemm + col2im (similar to conv-backward on input)
 
     output->mutable_data<T>(context.GetPlace());
     auto t = framework::EigenVector<T>::Flatten(*output);
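
The updated comment is the one-line summary of the forward algorithm: conv2dtranspose is computed per sample as a gemm followed by col2im, the same arithmetic as the input-gradient pass of an ordinary convolution. A rough shape sketch (the gemm/col2im call sites are outside the lines shown, and the no-padding output sizes below are assumed, not quoted from the patch):

    input_batch : (m, h * w)              one sample with m input channels
    filter      : (m, c * k_h * k_w)      after filter.Resize(filter_matrix_shape)
    col_matrix  : (c * k_h * k_w, h * w)  = filter^T * input_batch        (gemm)
    output      : (c, o_h, o_w)           col2im folds col_matrix back, with
                                          o_h = (h - 1) * stride_h + k_h and
                                          o_w = (w - 1) * stride_w + k_w
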
@@ -142,7 +143,7 @@ class GemmConv2DTransposeGradKernel : public framework::OpKernel<T> {
         context.Output<Tensor>(framework::GradVarName("Filter"));
 
     std::vector<int> strides = context.Attr<std::vector<int>>("strides");
-    // Actually, no paddings and groups allowed in deconv.
+    // Actually, no paddings and groups allowed in conv transpose.
     std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
 
     const int batch_size = input->dims()[0];
@@ -180,11 +181,12 @@ class GemmConv2DTransposeGradKernel : public framework::OpKernel<T> {
     DDim filter_matrix_shape = {m, c * k_h * k_w};
     filter.Resize(filter_matrix_shape);
 
-    // deconvolution grad on input:
+    // convolution transpose grad on input:
     // im2col + gemm (similar to conv-forward)
     // input need to compute gradient
     if (input_grad) {
-      Tensor col_matrix = col;
+      Tensor col_matrix;
+      col_matrix.ShareDataWith(col);
       DDim col_matrix_shape = {c * k_h * k_w, h * w};
       col_matrix.Resize(col_matrix_shape);
 
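
These comments describe the input-gradient path as the mirror image of the forward pass: each sample of the output gradient is expanded by im2col into the (c * k_h * k_w, h * w) col_matrix set up here, and a gemm with the (m, c * k_h * k_w) filter matrix yields the (m, h * w) slice of the input gradient, exactly a normal convolution forward. A rough shape sketch (the im2col/gemm calls themselves sit outside the lines shown):

    output_grad_batch : (c, o_h, o_w)
    im2col            : col_matrix (c * k_h * k_w, h * w)
    input_grad_batch  : (m, h * w) = filter * col_matrix               (gemm)
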
@@ -216,7 +218,8 @@ class GemmConv2DTransposeGradKernel : public framework::OpKernel<T> {
 
     // filter gradient required
     if (filter_grad) {
-      Tensor col_matrix_f = col;
+      Tensor col_matrix_f;
+      col_matrix_f.ShareDataWith(col);
       DDim col_matrix_shape_f = {c * h * w, k_h * k_w};
       col_matrix_f.Resize(col_matrix_shape_f);
 