!4881 Fix param check

Merge pull request !4881 from byweng/fix_param_check
pull/4881/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit f69b4e03b7

@ -65,9 +65,9 @@ class WithBNNLossCell:
"""
def __init__(self, backbone, loss_fn, dnn_factor=1, bnn_factor=1):
if not isinstance(dnn_factor, (int, float)):
if isinstance(dnn_factor, bool) or not isinstance(dnn_factor, (int, float)):
raise TypeError('The type of `dnn_factor` should be `int` or `float`')
if not isinstance(bnn_factor, (int, float)):
if isinstance(bnn_factor, bool) or not isinstance(bnn_factor, (int, float)):
raise TypeError('The type of `bnn_factor` should be `int` or `float`')
self.backbone = backbone

@ -173,8 +173,7 @@ class ConvReparam(_ConvVariational):
r"""
Convolutional variational layers with Reparameterization.
See more details in paper `Auto-Encoding Variational Bayes
<https://arxiv.org/abs/1312.6114>`
See more details in paper `Auto-Encoding Variational Bayes <https://arxiv.org/abs/1312.6114>`_.
Args:
in_channels (int): The number of input channel :math:`C_{in}`.

@ -132,8 +132,7 @@ class DenseReparam(_DenseVariational):
r"""
Dense variational layers with Reparameterization.
See more details in paper `Auto-Encoding Variational Bayes
<https://arxiv.org/abs/1312.6114>`
See more details in paper `Auto-Encoding Variational Bayes <https://arxiv.org/abs/1312.6114>`_.
Applies dense-connected layer for the input. This layer implements the operation as:

@ -78,16 +78,17 @@ class NormalPosterior(Cell):
if not isinstance(shape, (tuple, list)):
raise TypeError('The type of `shape` should be `tuple` or `list`')
if not isinstance(loc_mean, (int, float)):
if isinstance(loc_mean, bool) or not isinstance(loc_mean, (int, float)):
raise TypeError('The type of `loc_mean` should be `int` or `float`')
if not isinstance(untransformed_scale_mean, (int, float)):
if isinstance(untransformed_scale_mean, bool) or not isinstance(untransformed_scale_mean, (int, float)):
raise TypeError('The type of `untransformed_scale_mean` should be `int` or `float`')
if not (isinstance(loc_std, (int, float)) and loc_std >= 0):
if isinstance(loc_std, bool) or not (isinstance(loc_std, (int, float)) and loc_std >= 0):
raise TypeError('The type of `loc_std` should be `int` or `float` and its value should > 0')
if not (isinstance(untransformed_scale_std, (int, float)) and untransformed_scale_std >= 0):
if isinstance(loc_std, bool) or not (isinstance(untransformed_scale_std, (int, float)) and
untransformed_scale_std >= 0):
raise TypeError('The type of `untransformed_scale_std` should be `int` or `float` and '
'its value should > 0')

@ -61,9 +61,9 @@ class TransformToBNN:
"""
def __init__(self, trainable_dnn, dnn_factor=1, bnn_factor=1):
if not isinstance(dnn_factor, (int, float)):
if isinstance(dnn_factor, bool) or not isinstance(dnn_factor, (int, float)):
raise TypeError('The type of `dnn_factor` should be `int` or `float`')
if not isinstance(bnn_factor, (int, float)):
if isinstance(bnn_factor, bool) or not isinstance(bnn_factor, (int, float)):
raise TypeError('The type of `bnn_factor` should be `int` or `float`')
net_with_loss = trainable_dnn.network

Loading…
Cancel
Save