!4881 Fix param check

Merge pull request !4881 from byweng/fix_param_check
pull/4881/MERGE
mindspore-ci-bot, committed by Gitee, 5 years ago
commit f69b4e03b7

@@ -65,9 +65,9 @@ class WithBNNLossCell:
     """
     def __init__(self, backbone, loss_fn, dnn_factor=1, bnn_factor=1):
-        if not isinstance(dnn_factor, (int, float)):
+        if isinstance(dnn_factor, bool) or not isinstance(dnn_factor, (int, float)):
             raise TypeError('The type of `dnn_factor` should be `int` or `float`')
-        if not isinstance(bnn_factor, (int, float)):
+        if isinstance(bnn_factor, bool) or not isinstance(bnn_factor, (int, float)):
             raise TypeError('The type of `bnn_factor` should be `int` or `float`')
         self.backbone = backbone
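
The extra `isinstance(..., bool)` guard is needed because Python treats bool as a subclass of int, so booleans pass a plain int/float check. A minimal illustration of the behaviour the new condition relies on:

>>> isinstance(True, int)          # bool is a subclass of int
True
>>> isinstance(True, bool) or not isinstance(True, (int, float))
True                               # guard fires, so a bool now raises TypeError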

@@ -173,13 +173,12 @@ class ConvReparam(_ConvVariational):
     r"""
     Convolutional variational layers with Reparameterization.
-    See more details in paper `Auto-Encoding Variational Bayes
-    <https://arxiv.org/abs/1312.6114>`
+    See more details in paper `Auto-Encoding Variational Bayes <https://arxiv.org/abs/1312.6114>`_.
     Args:
         in_channels (int): The number of input channel :math:`C_{in}`.
         out_channels (int): The number of output channel :math:`C_{out}`.
         kernel_size (Union[int, tuple[int]]): The data type is int or
             tuple with 2 integers. Specifies the height and width of the 2D
             convolution window. Single int means the value if for both
             height and width of the kernel. A tuple of 2 ints means the

@@ -132,8 +132,7 @@ class DenseReparam(_DenseVariational):
     r"""
     Dense variational layers with Reparameterization.
-    See more details in paper `Auto-Encoding Variational Bayes
-    <https://arxiv.org/abs/1312.6114>`
+    See more details in paper `Auto-Encoding Variational Bayes <https://arxiv.org/abs/1312.6114>`_.
     Applies dense-connected layer for the input. This layer implements the operation as:

@@ -78,16 +78,17 @@ class NormalPosterior(Cell):
         if not isinstance(shape, (tuple, list)):
             raise TypeError('The type of `shape` should be `tuple` or `list`')
-        if not isinstance(loc_mean, (int, float)):
+        if isinstance(loc_mean, bool) or not isinstance(loc_mean, (int, float)):
             raise TypeError('The type of `loc_mean` should be `int` or `float`')
-        if not isinstance(untransformed_scale_mean, (int, float)):
+        if isinstance(untransformed_scale_mean, bool) or not isinstance(untransformed_scale_mean, (int, float)):
             raise TypeError('The type of `untransformed_scale_mean` should be `int` or `float`')
-        if not (isinstance(loc_std, (int, float)) and loc_std >= 0):
+        if isinstance(loc_std, bool) or not (isinstance(loc_std, (int, float)) and loc_std >= 0):
             raise TypeError('The type of `loc_std` should be `int` or `float` and its value should > 0')
-        if not (isinstance(untransformed_scale_std, (int, float)) and untransformed_scale_std >= 0):
+        if isinstance(loc_std, bool) or not (isinstance(untransformed_scale_std, (int, float)) and
+                                             untransformed_scale_std >= 0):
             raise TypeError('The type of `untransformed_scale_std` should be `int` or `float` and '
                             'its value should > 0')
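
The same bool-then-numeric guard is repeated for each scalar argument. A hypothetical helper, not part of this commit (name and signature assumed for illustration), could capture the pattern once:

def check_scalar(name, value, non_negative=False):
    # Hypothetical helper, not in commit f69b4e03b7: reject bool explicitly,
    # then require int/float, optionally also requiring value >= 0.
    if isinstance(value, bool) or not isinstance(value, (int, float)):
        raise TypeError('The type of `{}` should be `int` or `float`'.format(name))
    if non_negative and value < 0:
        raise TypeError('The value of `{}` should be >= 0'.format(name))
    return value

# e.g. loc_std = check_scalar('loc_std', loc_std, non_negative=True)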

@@ -61,9 +61,9 @@ class TransformToBNN:
     """
     def __init__(self, trainable_dnn, dnn_factor=1, bnn_factor=1):
-        if not isinstance(dnn_factor, (int, float)):
+        if isinstance(dnn_factor, bool) or not isinstance(dnn_factor, (int, float)):
             raise TypeError('The type of `dnn_factor` should be `int` or `float`')
-        if not isinstance(bnn_factor, (int, float)):
+        if isinstance(bnn_factor, bool) or not isinstance(bnn_factor, (int, float)):
             raise TypeError('The type of `bnn_factor` should be `int` or `float`')
         net_with_loss = trainable_dnn.network
