refine conv

del_some_in_makelist
chengduoZH 7 years ago
parent 5ba231d80b
commit a6ef875885

@ -260,8 +260,11 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
if (input_grad) {
input_grad->mutable_data<T>(context.GetPlace());
set_zero(context.device_context(), input_grad, static_cast<T>(0));
// if is_expand is false, the operation of set_zero is unnecessary,
// because math::matmul will reset input_grad.
if (is_expand) {
set_zero(context.device_context(), input_grad, static_cast<T>(0));
}
math::Col2VolFunctor<Place, T> col2vol;
math::Col2ImFunctor<math::ColFormat::kCFO, Place, T> col2im;

@ -225,7 +225,7 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
if (input_grad) {
input_grad->mutable_data<T>(context.GetPlace());
set_zero(context.device_context(), input_grad, static_cast<T>(0));
// set_zero is unnecessary, math::matmul will reset input_grad.
}
if (filter_grad) { // filter size (m, c, k_h, k_w)
filter_grad->mutable_data<T>(context.GetPlace());

Loading…
Cancel
Save