|
|
|
@ -158,9 +158,6 @@ class GemmDeconvGrad2DKernel : public framework::OpKernel<T> {
|
|
|
|
|
int O_W = output_grad->dims()[3];
|
|
|
|
|
|
|
|
|
|
// Two functors required to get to the right shape
|
|
|
|
|
paddle::operators::math::Col2ImFunctor<
|
|
|
|
|
paddle::operators::math::ColFormat::kCFO, Place, T>
|
|
|
|
|
col2im;
|
|
|
|
|
paddle::operators::math::Im2ColFunctor<
|
|
|
|
|
paddle::operators::math::ColFormat::kCFO, Place, T>
|
|
|
|
|
im2col;
|
|
|
|
@ -231,7 +228,7 @@ class GemmDeconvGrad2DKernel : public framework::OpKernel<T> {
|
|
|
|
|
strides[0], strides[1], paddings[0], paddings[1]);
|
|
|
|
|
// gemm: d_filter = x * y_grad^T
|
|
|
|
|
math::matmul<Place, T>(context.device_context(), in_batch, false,
|
|
|
|
|
col_matrix, true, T(1.0), &filter_grad, T(1.0));
|
|
|
|
|
col_matrix, true, T(1.0), &filter_grad_, T(1.0));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|