!4718 fix avgpoolgrad

Merge pull request !4718 from fangzehua/avg0819
pull/4718/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit 09fd453d61

@ -232,6 +232,8 @@ def _get_mean_matrix(x_shape, ksize, stride, padding, x_dtype):
n_input, c_input, h_input, w_input = x_shape
h_ksize, w_ksize = ksize[2], ksize[3]
if h_ksize == h_input and w_ksize == w_input and padding == "VALID":
return None
h_stride, w_stride = stride[2], stride[3]
n_output = n_input
c_output = c_input
@ -268,7 +270,11 @@ def _get_mean_matrix(x_shape, ksize, stride, padding, x_dtype):
@constexpr
def _get_kernel_matrix(kernel_matrix_shape, x_dtype):
def _get_kernel_matrix(x_shape_nchw, kernel_matrix_shape, padding, x_dtype):
if x_shape_nchw[2] == kernel_matrix_shape[2] \
and x_shape_nchw[3] == kernel_matrix_shape[3] \
and padding == 'VALID':
return None
kernel_matrix = np.ones(kernel_matrix_shape)
return Tensor(kernel_matrix, x_dtype)
@ -319,7 +325,7 @@ def get_bprop_avg_pool_grad(self):
k_size_nchw[2],
k_size_nchw[3])
mean_matrix = _get_mean_matrix(x_shape_nchw, k_size_nchw, stride_nchw, padding, x_dtype)
kernel_matrix = _get_kernel_matrix(kernel_matrix_shape, x_dtype)
kernel_matrix = _get_kernel_matrix(x_shape_nchw, kernel_matrix_shape, padding, x_dtype)
dx = avgpool_grad_vm(x_shape_nchw, dout, mean_matrix, kernel_matrix)
return (dx,)

Loading…
Cancel
Save