|
|
|
@ -261,8 +261,12 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
|
|
|
|
|
|
|
|
|
|
if (input_grad) {
|
|
|
|
|
input_grad->mutable_data<T>(context.GetPlace());
|
|
|
|
|
set_zero(dev_ctx, input_grad, static_cast<T>(0));
|
|
|
|
|
|
|
|
|
|
// if is_expand is false, the operation of set_zero is unnecessary,
|
|
|
|
|
// because math::matmul will reset input_grad.
|
|
|
|
|
if (is_expand) {
|
|
|
|
|
set_zero(dev_ctx, input_grad, static_cast<T>(0));
|
|
|
|
|
}
|
|
|
|
|
math::Col2VolFunctor<DeviceContext, T> col2vol;
|
|
|
|
|
math::Col2ImFunctor<math::ColFormat::kCFO, DeviceContext, T> col2im;
|
|
|
|
|
|
|
|
|
|