|
|
|
@ -120,7 +120,7 @@ class GemmConv2DTransposeKernel : public framework::OpKernel<T> {
|
|
|
|
|
math::matmul<Place, T>(context.device_context(), filter, true,
|
|
|
|
|
input_batch, false, T(1.0), &col_matrix, T(0.0));
|
|
|
|
|
col2im(context.device_context(), output_batch, col, strides[0],
|
|
|
|
|
strides[1], 0, 0);
|
|
|
|
|
strides[1], 0, 0, 0, 0);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
@ -206,7 +206,7 @@ class GemmConv2DTransposeGradKernel : public framework::OpKernel<T> {
|
|
|
|
|
|
|
|
|
|
// im2col: dy from (c, o_h, o_w) -> (c * k_h * k_w, h * w)
|
|
|
|
|
im2col(context.device_context(), output_grad_batch, col, strides[0],
|
|
|
|
|
strides[1], paddings[0], paddings[1]);
|
|
|
|
|
strides[1], paddings[0], paddings[0], paddings[1], paddings[1]);
|
|
|
|
|
|
|
|
|
|
// gemm: dx = filter * dy
|
|
|
|
|
// (m, c * k_h * k_w) * (c * k_h * k_w, h * w) -> (m, c, h)
|
|
|
|
@ -238,7 +238,7 @@ class GemmConv2DTransposeGradKernel : public framework::OpKernel<T> {
|
|
|
|
|
|
|
|
|
|
// im2col: (c * h * w, k_h * k_w)
|
|
|
|
|
im2col(context.device_context(), output_grad_batch, col, strides[0],
|
|
|
|
|
strides[1], paddings[0], paddings[1]);
|
|
|
|
|
strides[1], paddings[0], paddings[0], paddings[1], paddings[1]);
|
|
|
|
|
|
|
|
|
|
// gemm: d_filter = x * y_grad^T
|
|
|
|
|
// (m, c * h * w) * (k_h * k_w, c * h * w) -> (m, c, h)
|
|
|
|
|