diff --git a/mindspore/nn/layer/conv.py b/mindspore/nn/layer/conv.py index 730b5e3398..a102a394be 100644 --- a/mindspore/nn/layer/conv.py +++ b/mindspore/nn/layer/conv.py @@ -358,6 +358,7 @@ class Conv2dTranspose(_Conv): def _deconv_output_length(self, input_length, filter_size, stride_size, dilation_size): """Calculate the width and height of output.""" length = 0 + filter_size = filter_size + (filter_size - 1) * (dilation_size - 1) if self.is_valid: if filter_size - stride_size > 0: length = input_length * stride_size + filter_size - stride_size @@ -366,8 +367,7 @@ class Conv2dTranspose(_Conv): elif self.is_same: length = input_length * stride_size elif self.is_pad: - length = input_length * stride_size - 2 * self.padding + filter_size + \ - (filter_size - 1) * (dilation_size - 1) - stride_size + length = input_length * stride_size - 2 * self.padding + filter_size - stride_size return length diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py index a2b6457fbd..049a0aa34d 100644 --- a/mindspore/ops/operations/nn_ops.py +++ b/mindspore/ops/operations/nn_ops.py @@ -1186,16 +1186,18 @@ class Conv2DBackpropInput(PrimitiveWithInfer): kernel_w = self.kernel_size[1] stride_h = self.stride[0] stride_w = self.stride[1] + dilation_h = self.dilation[2] + dilation_w = self.dilation[3] # default pad mode is valid pad_list = (0, 0, 0, 0) if self.pad_list: pad_list = tuple(self.pad_list) elif self.pad_mode == "SAME": - pad_needed_h = max(0, (dout_shape[2] - 1) * stride_h + kernel_h - x_size_v[2]) + pad_needed_h = max(0, (dout_shape[2] - 1) * stride_h + dilation_h * (kernel_h - 1) + 1 - x_size_v[2]) pad_top = math.floor(pad_needed_h / 2) pad_bottom = pad_needed_h - pad_top - pad_needed_w = max(0, (dout_shape[3] - 1) * stride_w + kernel_w - x_size_v[3]) + pad_needed_w = max(0, (dout_shape[3] - 1) * stride_w + dilation_w * (kernel_w - 1) + 1 - x_size_v[3]) pad_left = math.floor(pad_needed_w / 2) pad_right = pad_needed_w - pad_left pad_list = (pad_top, pad_bottom, pad_left, pad_right) diff --git a/tests/ut/python/nn/test_conv.py b/tests/ut/python/nn/test_conv.py index a76c789a79..360aab6733 100644 --- a/tests/ut/python/nn/test_conv.py +++ b/tests/ut/python/nn/test_conv.py @@ -190,3 +190,15 @@ def test_compile_transpose_stride2(): net = NetConv2dTranspose(3, 64, 4, stride=2, weight_init='normal') input_data = Tensor(np.ones([1, 3, 16, 50], dtype=np.float32)) net(input_data) + + +def test_compile_transpose_dilation_2(): + net = NetConv2dTranspose(3, 64, 4, stride=2, dilation=2, pad_mode='same', weight_init='normal') + input_data = Tensor(np.ones([1, 3, 16, 50], dtype=np.float32)) + net(input_data) + + +def test_compile_transpose_dilation_2_pad_mode_pad(): + net = NetConv2dTranspose(3, 64, 4, stride=2, dilation=2, pad_mode='pad', weight_init='normal') + input_data = Tensor(np.ones([1, 3, 16, 50], dtype=np.float32)) + net(input_data)