fixed Conv1dTranspose

pull/3198/head
jiangjinsheng 5 years ago
parent 8ee091fb92
commit 0397b609b3

@ -742,10 +742,10 @@ class Conv1dTranspose(_Conv):
self.padding[0] + self.padding[1])
w_out = self._deconv_output_length(w, self.kernel_size[1], self.stride[1], self.dilation[1],
self.padding[2] + self.padding[3])
if self.has_bias:
return self.bias_add(self.conv2d_transpose(x, self.weight, (n, self.out_channels, h_out, w_out)),
self.bias)
output = self.conv2d_transpose(x, self.weight, (n, self.out_channels, h_out, w_out))
if self.has_bias:
output = self.bias_add(output, self.bias)
if len(x_shape) == 3:
output = self.squeeze(output)
return output

@ -283,6 +283,7 @@ class AvgPool1d(_PoolNd):
self.reduce_mean = P.ReduceMean(keep_dims=True)
self.slice = P.Slice()
self.expand = P.ExpandDims()
self.squeeze = P.Squeeze(2)
def construct(self, x):
_shape_check(self.shape(x))
@ -295,4 +296,5 @@ class AvgPool1d(_PoolNd):
else:
x = self.expand(x, 2)
x = self.avg_pool(x)
x = self.squeeze(x)
return x

Loading…
Cancel
Save