@@ -83,7 +83,7 @@ class GemmDeconv2DKernel : public framework::OpKernel<T> {
     DDim col_shape = {C, K_H, K_W, H, W};
 
     // use col_matrix_shape in the gemm calculation
-    DDim col_matrix_shape = {M * K_H * K_W, H * W};
+    DDim col_matrix_shape = {C * K_H * K_W, H * W};
 
     Tensor col;
     col.mutable_data<T>(col_shape, context.GetPlace());
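Note on this hunk: col is the buffer the gemm writes into before col2im scatters it to the (C, O_H, O_W) output, so its 2-D view has to be (C * K_H * K_W) x (H * W), matching filter^T times an (M, H * W) input batch. A standalone sketch of that dimension bookkeeping, with made-up sizes (plain C++, not Paddle code):

```cpp
#include <cassert>
#include <cstdio>

int main() {
  // Hypothetical sizes, for illustration only.
  const int M = 4, C = 3, K_H = 3, K_W = 3, H = 5, W = 5;

  // filter is viewed as an M x (C * K_H * K_W) matrix,
  // one input batch as an M x (H * W) matrix.
  // col_matrix = filter^T * input_batch, so it must be
  // (C * K_H * K_W) x (H * W) -- the same element count as the
  // 5-D col buffer {C, K_H, K_W, H, W}.
  const long long col_elems_5d = 1LL * C * K_H * K_W * H * W;
  const long long col_elems_2d = 1LL * (C * K_H * K_W) * (H * W);
  assert(col_elems_5d == col_elems_2d);
  std::printf("col_matrix: %d x %d (%lld elements)\n", C * K_H * K_W, H * W,
              col_elems_2d);
  return 0;
}
```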
@@ -108,11 +108,11 @@ class GemmDeconv2DKernel : public framework::OpKernel<T> {
     for (int i = 0; i < N; i++) {
       // batch with size (M, H * W)
       Tensor input_batch = input->Slice<T>(i, i + 1).Resize(input_matrix_shape);
+      // filter size: (M, C * K_H * K_W)
 
       // output size: (C, O_H, O_W)
       Tensor output_batch = output->Slice<T>(i, i + 1).Resize(output_shape);
 
-      // filter size: (Co, Ci * Hf * Wf)
       // col_matrix = filter * input_batch
       // of shape (C * K_H * K_W, H * W)
       math::matmul<Place, T>(context.device_context(), filter, true,
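The loop above is the whole forward pass: one gemm per batch (col_matrix = filter^T * input_batch) followed by a col2im scatter-add into the (C, O_H, O_W) output. Below is a plain-C++ sketch of that scatter-add, assuming the usual transposed-convolution output size O_H = (H - 1) * stride_h - 2 * pad_h + K_H; it is an illustration only, not Paddle's math::Col2ImFunctor:

```cpp
#include <vector>

// col is laid out as [c][k_h][k_w][h][w] (the gemm result reshaped back to
// 5-D); out is one batch of the output, laid out as [c][o_h][o_w].
void naive_col2im(const std::vector<float>& col, std::vector<float>* out,
                  int C, int K_H, int K_W, int H, int W, int O_H, int O_W,
                  int stride_h, int stride_w, int pad_h, int pad_w) {
  for (int c = 0; c < C; ++c)
    for (int kh = 0; kh < K_H; ++kh)
      for (int kw = 0; kw < K_W; ++kw)
        for (int h = 0; h < H; ++h)
          for (int w = 0; w < W; ++w) {
            // Each column of col_matrix (one input position) scatters a
            // K_H x K_W patch into the output; overlaps accumulate.
            const int oh = h * stride_h - pad_h + kh;
            const int ow = w * stride_w - pad_w + kw;
            if (oh < 0 || oh >= O_H || ow < 0 || ow >= O_W) continue;
            const size_t col_idx =
                (((static_cast<size_t>(c) * K_H + kh) * K_W + kw) * H + h) * W +
                w;
            (*out)[(static_cast<size_t>(c) * O_H + oh) * O_W + ow] +=
                col[col_idx];
          }
}

int main() {
  const int C = 2, K_H = 3, K_W = 3, H = 4, W = 4, stride = 2, pad = 0;
  const int O_H = (H - 1) * stride - 2 * pad + K_H;
  const int O_W = (W - 1) * stride - 2 * pad + K_W;
  std::vector<float> col(static_cast<size_t>(C) * K_H * K_W * H * W, 1.0f);
  std::vector<float> out(static_cast<size_t>(C) * O_H * O_W, 0.0f);
  naive_col2im(col, &out, C, K_H, K_W, H, W, O_H, O_W, stride, stride, pad,
               pad);
  return 0;
}
```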
@@ -132,8 +132,8 @@ class GemmDeconvGrad2DKernel : public framework::OpKernel<T> {
     const Tensor* output_grad =
         context.Input<Tensor>(framework::GradVarName("Output"));
 
-    // For filter, we do not use const pointer
-    // but we should avoid
+    // For filter, we do not use const pointer b/c we will do reshape
+    // but we should avoid modifying its value
     Tensor filter = *context.Input<Tensor>("Filter");
 
     Tensor* input_grad =
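The reworded comment is about a common pattern: the kernel copies the Tensor handle so it can reshape it locally, while the underlying buffer stays shared and must be treated as read-only. A self-contained illustration of that idea with a made-up View type (not Paddle's framework::Tensor):

```cpp
#include <cassert>
#include <memory>
#include <vector>

// A made-up "view" type: copies share the buffer but own their shape.
struct View {
  std::shared_ptr<std::vector<float>> data;
  std::vector<int> dims;
  void Resize(std::vector<int> d) { dims = std::move(d); }
};

int main() {
  View filter{std::make_shared<std::vector<float>>(4 * 3 * 3 * 3, 0.5f),
              {4, 3, 3, 3}};

  // Copying the handle and reshaping the copy leaves the original view's
  // dims untouched; only the shape metadata is per-copy.
  View filter_matrix = filter;
  filter_matrix.Resize({4, 3 * 3 * 3});

  assert(filter.dims.size() == 4);            // original shape unchanged
  assert(filter_matrix.data == filter.data);  // same underlying buffer
  return 0;
}
```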
@@ -157,7 +157,7 @@ class GemmDeconvGrad2DKernel : public framework::OpKernel<T> {
     int O_H = output_grad->dims()[2];
     int O_W = output_grad->dims()[3];
 
-    // Two functors required to get to the right shape
+    // Only im2col functor required for bp to get to the right shape
     paddle::operators::math::Im2ColFunctor<
         paddle::operators::math::ColFormat::kCFO, Place, T>
         im2col;
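For the backward pass only the gather direction is needed: im2col reads the (C, O_H, O_W) gradient and packs, for every input position, the K_H x K_W window it influenced. A plain-C++ sketch of that gather, the mirror image of the col2im sketch after the forward hunk (an illustration, not Paddle's Im2ColFunctor with ColFormat::kCFO):

```cpp
#include <vector>

// dy is one batch of the output gradient, laid out as [c][o_h][o_w];
// col is filled as [c][k_h][k_w][h][w], i.e. the (C * K_H * K_W, H * W) view.
void naive_im2col(const std::vector<float>& dy, std::vector<float>* col,
                  int C, int K_H, int K_W, int H, int W, int O_H, int O_W,
                  int stride_h, int stride_w, int pad_h, int pad_w) {
  for (int c = 0; c < C; ++c)
    for (int kh = 0; kh < K_H; ++kh)
      for (int kw = 0; kw < K_W; ++kw)
        for (int h = 0; h < H; ++h)
          for (int w = 0; w < W; ++w) {
            const int oh = h * stride_h - pad_h + kh;
            const int ow = w * stride_w - pad_w + kw;
            const size_t col_idx =
                (((static_cast<size_t>(c) * K_H + kh) * K_W + kw) * H + h) * W +
                w;
            (*col)[col_idx] =
                (oh < 0 || oh >= O_H || ow < 0 || ow >= O_W)
                    ? 0.0f
                    : dy[(static_cast<size_t>(c) * O_H + oh) * O_W + ow];
          }
}

int main() {
  const int C = 2, K_H = 3, K_W = 3, H = 4, W = 4, stride = 2, pad = 0;
  const int O_H = (H - 1) * stride - 2 * pad + K_H;
  const int O_W = (W - 1) * stride - 2 * pad + K_W;
  std::vector<float> dy(static_cast<size_t>(C) * O_H * O_W, 1.0f);
  std::vector<float> col(static_cast<size_t>(C) * K_H * K_W * H * W, 0.0f);
  naive_im2col(dy, &col, C, K_H, K_W, H, W, O_H, O_W, stride, stride, pad, pad);
  return 0;
}
```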
@@ -166,15 +166,13 @@ class GemmDeconvGrad2DKernel : public framework::OpKernel<T> {
     DDim col_shape = {C, K_H, K_W, H, W};
     // use col_matrix_shape in the gemm calculation
     DDim col_matrix_shape = {C * K_H * K_W, H * W};
     DDim col_matrix_shape_f = {C * H * W, K_H * K_W};
 
     Tensor col;
     col.mutable_data<T>(col_shape, context.GetPlace());
     // col_matrix shares the same piece of data with col,
     // but will be reshaped into a two-dimensional matrix shape
     // to call the matrix multiplication interface.
-    Tensor col_matrix = col;
-    col_matrix.Resize(col_matrix_shape);
 
     DDim output_shape = {C, O_H, O_W};
     DDim input_matrix_shape = {M, H * W};
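Note on the two shapes above: col_matrix_shape and col_matrix_shape_f are both 2-D views of the same {C, K_H, K_W, H, W} buffer, with the same element count but a different split of the axes, which is what the comment about col_matrix sharing col's data relies on. A small sketch of that reshape idea with a raw buffer and two hypothetical index helpers (illustration only):

```cpp
#include <cassert>
#include <vector>

int main() {
  const int C = 3, K_H = 3, K_W = 3, H = 5, W = 5;
  std::vector<float> col(static_cast<size_t>(C) * K_H * K_W * H * W);

  // View 1: (C * K_H * K_W) rows, (H * W) columns -- used for the input grad.
  auto at_v1 = [&](int row, int column) -> float& {
    return col[static_cast<size_t>(row) * (H * W) + column];
  };
  // View 2: (C * H * W) rows, (K_H * K_W) columns -- used for the filter grad.
  auto at_v2 = [&](int row, int column) -> float& {
    return col[static_cast<size_t>(row) * (K_H * K_W) + column];
  };

  // Same storage, same element count; only the row/column split differs.
  assert(static_cast<size_t>(C * K_H * K_W) * (H * W) == col.size());
  assert(static_cast<size_t>(C * H * W) * (K_H * K_W) == col.size());
  at_v1(0, 0) = 1.0f;
  assert(at_v2(0, 0) == 1.0f);
  return 0;
}
```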
@@ -186,6 +184,10 @@ class GemmDeconvGrad2DKernel : public framework::OpKernel<T> {
     // im2col + gemm (similar to conv-forward)
     // input need to compute gradient
     if (input_grad) {
+      Tensor col_matrix = col;
+      DDim col_matrix_shape = {C * K_H * K_W, H * W};
+      col_matrix.Resize(col_matrix_shape);
+
       input_grad->mutable_data<T>(context.GetPlace());
       auto t = framework::EigenVector<T>::Flatten(*input_grad);
       t.device(context.GetEigenDevice<Place>()) = t.constant(static_cast<T>(0));
@@ -194,14 +196,18 @@ class GemmDeconvGrad2DKernel : public framework::OpKernel<T> {
         // batch with size (C, O_H * O_W)
         Tensor output_grad_batch =
             output_grad->Slice<T>(i, i + 1).Resize(output_shape);
+        // filter of size (M, C * K_H * K_W)
 
         // batch with size (M, H, W)
         Tensor input_grad_batch =
             input_grad->Slice<T>(i, i + 1).Resize(input_matrix_shape);
 
-        // im2col: (C * K_H * K_W, H * W)
+        // im2col: dy from (C, O_H, O_W) -> (C * K_H * K_W, H * W)
         im2col(context.device_context(), output_grad_batch, col_matrix,
                strides[0], strides[1], paddings[0], paddings[1]);
+
+        // gemm: dx = filter * dy
+        // (M, C * K_H * K_W) * (C * K_H * K_W, H * W) -> (M, C, H)
         math::matmul<Place, T>(context.device_context(), filter, false,
                                col_matrix, false, T(1.0), &input_grad_batch,
                                T(0.0));
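Shape check for the gemm in this loop: filter as (M, C * K_H * K_W) times col_matrix as (C * K_H * K_W, H * W) gives (M, H * W), i.e. one input-gradient batch that is then viewed as (M, H, W); the "(M, C, H)" in the comment above looks like a typo for that. The same bookkeeping as a standalone check with made-up sizes:

```cpp
#include <cstdio>

int main() {
  // Hypothetical sizes, for the dimension check only.
  constexpr int M = 4, C = 3, K_H = 3, K_W = 3, H = 5, W = 5;

  // (M, C*K_H*K_W) x (C*K_H*K_W, H*W): inner dimensions must agree,
  // and the product has M rows and H*W columns.
  constexpr int a_rows = M, a_cols = C * K_H * K_W;
  constexpr int b_rows = C * K_H * K_W, b_cols = H * W;
  static_assert(a_cols == b_rows, "gemm operands must conform");
  constexpr int dx_rows = a_rows, dx_cols = b_cols;

  std::printf("dx batch: %d x %d (viewed as %d x %d x %d)\n", dx_rows, dx_cols,
              M, H, W);
  return 0;
}
```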
@@ -210,6 +216,10 @@ class GemmDeconvGrad2DKernel : public framework::OpKernel<T> {
 
     // filter gradient required
     if (filter_grad) {
+      Tensor col_matrix_f = col;
+      DDim col_matrix_shape_f = {C * H * W, K_H * K_W};
+      col_matrix_f.Resize(col_matrix_shape_f);
+
       filter_grad->mutable_data<T>(context.GetPlace());
       Tensor filter_grad_ = *filter_grad;
       filter_grad_.Resize(filter_matrix_shape);
@@ -223,10 +233,12 @@ class GemmDeconvGrad2DKernel : public framework::OpKernel<T> {
         // input batch
         Tensor in_batch = input->Slice<T>(i, i + 1).Resize(input_matrix_shape);
 
-        // im2col: (C * K_H * K_W, H * W)
-        im2col(context.device_context(), output_grad_batch, col_matrix,
+        // im2col: (C * H * W, K_H * K_W)
+        im2col(context.device_context(), output_grad_batch, col_matrix_f,
                strides[0], strides[1], paddings[0], paddings[1]);
 
+        // gemm: d_filter = x * y_grad^T
+        // (M, C * H * W) * (K_H * K_W, C * H * W) -> (M, C, H)
         math::matmul<Place, T>(context.device_context(), in_batch, false,
                                col_matrix, true, T(1.0), &filter_grad_, T(1.0));
       }
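Shape check for the filter-gradient gemm: with the col buffer viewed as (C * K_H * K_W, H * W), in_batch (M, H * W) times the transposed col view (H * W, C * K_H * K_W) gives (M, C * K_H * K_W), which is exactly filter_matrix_shape, and beta = T(1.0) accumulates the per-batch contributions instead of overwriting them. A minimal plain-C++ sketch of that C += A * B^T pattern (an illustration of the trans-B + accumulate semantics assumed here, not Paddle's math::matmul):

```cpp
#include <vector>

// C (rows_a x rows_b) += A (rows_a x inner) * B^T, where B is rows_b x inner.
void gemm_nt_accumulate(const std::vector<float>& A,
                        const std::vector<float>& B, std::vector<float>* C,
                        int rows_a, int rows_b, int inner) {
  for (int i = 0; i < rows_a; ++i)
    for (int j = 0; j < rows_b; ++j) {
      float acc = 0.0f;
      for (int k = 0; k < inner; ++k)
        acc += A[static_cast<size_t>(i) * inner + k] *
               B[static_cast<size_t>(j) * inner + k];
      // beta == 1: add to whatever is already in C (per-batch accumulation).
      (*C)[static_cast<size_t>(i) * rows_b + j] += acc;
    }
}

int main() {
  const int M = 4, C_K = 3 * 3 * 3, HW = 5 * 5;  // hypothetical sizes
  std::vector<float> in_batch(static_cast<size_t>(M) * HW, 1.0f);
  std::vector<float> col(static_cast<size_t>(C_K) * HW, 1.0f);
  std::vector<float> d_filter(static_cast<size_t>(M) * C_K, 0.0f);
  for (int i = 0; i < 2; ++i)  // two "batches" accumulate into d_filter
    gemm_nt_accumulate(in_batch, col, &d_filter, M, C_K, HW);
  return 0;
}
```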