!906 fix a bug that support dilation greater than 1 in conv2dTranspose ops

Merge pull request !906 from yangyongjie/master
pull/906/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit 8a45ab1125

@ -358,6 +358,7 @@ class Conv2dTranspose(_Conv):
def _deconv_output_length(self, input_length, filter_size, stride_size, dilation_size):
"""Calculate the width and height of output."""
length = 0
filter_size = filter_size + (filter_size - 1) * (dilation_size - 1)
if self.is_valid:
if filter_size - stride_size > 0:
length = input_length * stride_size + filter_size - stride_size
@ -366,8 +367,7 @@ class Conv2dTranspose(_Conv):
elif self.is_same:
length = input_length * stride_size
elif self.is_pad:
length = input_length * stride_size - 2 * self.padding + filter_size + \
(filter_size - 1) * (dilation_size - 1) - stride_size
length = input_length * stride_size - 2 * self.padding + filter_size - stride_size
return length

@ -1186,16 +1186,18 @@ class Conv2DBackpropInput(PrimitiveWithInfer):
kernel_w = self.kernel_size[1]
stride_h = self.stride[0]
stride_w = self.stride[1]
dilation_h = self.dilation[2]
dilation_w = self.dilation[3]
# default pad mode is valid
pad_list = (0, 0, 0, 0)
if self.pad_list:
pad_list = tuple(self.pad_list)
elif self.pad_mode == "SAME":
pad_needed_h = max(0, (dout_shape[2] - 1) * stride_h + kernel_h - x_size_v[2])
pad_needed_h = max(0, (dout_shape[2] - 1) * stride_h + dilation_h * (kernel_h - 1) + 1 - x_size_v[2])
pad_top = math.floor(pad_needed_h / 2)
pad_bottom = pad_needed_h - pad_top
pad_needed_w = max(0, (dout_shape[3] - 1) * stride_w + kernel_w - x_size_v[3])
pad_needed_w = max(0, (dout_shape[3] - 1) * stride_w + dilation_w * (kernel_w - 1) + 1 - x_size_v[3])
pad_left = math.floor(pad_needed_w / 2)
pad_right = pad_needed_w - pad_left
pad_list = (pad_top, pad_bottom, pad_left, pad_right)

@ -190,3 +190,15 @@ def test_compile_transpose_stride2():
net = NetConv2dTranspose(3, 64, 4, stride=2, weight_init='normal')
input_data = Tensor(np.ones([1, 3, 16, 50], dtype=np.float32))
net(input_data)
def test_compile_transpose_dilation_2():
net = NetConv2dTranspose(3, 64, 4, stride=2, dilation=2, pad_mode='same', weight_init='normal')
input_data = Tensor(np.ones([1, 3, 16, 50], dtype=np.float32))
net(input_data)
def test_compile_transpose_dilation_2_pad_mode_pad():
net = NetConv2dTranspose(3, 64, 4, stride=2, dilation=2, pad_mode='pad', weight_init='normal')
input_data = Tensor(np.ones([1, 3, 16, 50], dtype=np.float32))
net(input_data)

Loading…
Cancel
Save