|
|
|
@ -61,10 +61,10 @@ class _ConvVariational(_Conv):
|
|
|
|
|
raise ValueError('Attr \'pad_mode\' of \'Conv2d\' Op passed '
|
|
|
|
|
+ str(pad_mode) + ', should be one of values in \'valid\', \'same\', \'pad\'.')
|
|
|
|
|
|
|
|
|
|
if not isinstance(stride, (int, tuple)):
|
|
|
|
|
if isinstance(stride, bool) or not isinstance(stride, (int, tuple)):
|
|
|
|
|
raise TypeError('The type of `stride` should be `int` of `tuple`')
|
|
|
|
|
|
|
|
|
|
if not isinstance(dilation, (int, tuple)):
|
|
|
|
|
if isinstance(dilation, bool) or not isinstance(dilation, (int, tuple)):
|
|
|
|
|
raise TypeError('The type of `dilation` should be `int` of `tuple`')
|
|
|
|
|
|
|
|
|
|
# convolution args
|
|
|
|
@ -136,8 +136,8 @@ class _ConvVariational(_Conv):
|
|
|
|
|
return outputs
|
|
|
|
|
|
|
|
|
|
def extend_repr(self):
|
|
|
|
|
str_info = 'in_channels={}, out_channels={}, kernel_size={}, weight_mean={}, stride={}, pad_mode={}, ' \
|
|
|
|
|
'padding={}, dilation={}, group={}, weight_std={}, has_bias={}'\
|
|
|
|
|
str_info = 'in_channels={}, out_channels={}, kernel_size={}, stride={}, pad_mode={}, ' \
|
|
|
|
|
'padding={}, dilation={}, group={}, weight_mean={}, weight_std={}, has_bias={}'\
|
|
|
|
|
.format(self.in_channels, self.out_channels, self.kernel_size, self.stride, self.pad_mode, self.padding,
|
|
|
|
|
self.dilation, self.group, self.weight_posterior.mean, self.weight_posterior.untransformed_std,
|
|
|
|
|
self.has_bias)
|
|
|
|
|