diff --git a/mindspore/nn/probability/bnn_layers/bnn_cell_wrapper.py b/mindspore/nn/probability/bnn_layers/bnn_cell_wrapper.py
index aedbd10ad6..3aef08133e 100644
--- a/mindspore/nn/probability/bnn_layers/bnn_cell_wrapper.py
+++ b/mindspore/nn/probability/bnn_layers/bnn_cell_wrapper.py
@@ -26,8 +26,8 @@ class ClassWrap:
         self._cls = cls
         self.bnn_loss_file = None
 
-    def __call__(self, backbone, loss_fn, backbone_factor, kl_factor):
-        obj = self._cls(backbone, loss_fn, backbone_factor, kl_factor)
+    def __call__(self, backbone, loss_fn, dnn_factor, bnn_factor):
+        obj = self._cls(backbone, loss_fn, dnn_factor, bnn_factor)
         bnn_with_loss = obj()
         self.bnn_loss_file = obj.bnn_loss_file
         return bnn_with_loss
@@ -65,6 +65,11 @@ class WithBNNLossCell:
     """
 
     def __init__(self, backbone, loss_fn, dnn_factor=1, bnn_factor=1):
+        if not isinstance(dnn_factor, (int, float)):
+            raise TypeError('The type of `dnn_factor` should be `int` or `float`')
+        if not isinstance(bnn_factor, (int, float)):
+            raise TypeError('The type of `bnn_factor` should be `int` or `float`')
+
         self.backbone = backbone
         self.loss_fn = loss_fn
         self.dnn_factor = dnn_factor
diff --git a/mindspore/nn/probability/bnn_layers/conv_variational.py b/mindspore/nn/probability/bnn_layers/conv_variational.py
index 65b5ec0c38..4434fee9e0 100644
--- a/mindspore/nn/probability/bnn_layers/conv_variational.py
+++ b/mindspore/nn/probability/bnn_layers/conv_variational.py
@@ -79,20 +79,40 @@ class _ConvVariational(_Conv):
         self.weight.requires_grad = False
 
         if isinstance(weight_prior_fn, Cell):
+            if weight_prior_fn.__class__.__name__ != 'NormalPrior':
+                raise TypeError('The type of `weight_prior_fn` should be `NormalPrior`')
             self.weight_prior = weight_prior_fn
         else:
+            if weight_prior_fn.__name__ != 'NormalPrior':
+                raise TypeError('The type of `weight_prior_fn` should be `NormalPrior`')
             self.weight_prior = weight_prior_fn()
 
+        if isinstance(weight_posterior_fn, Cell):
+            if weight_posterior_fn.__class__.__name__ != 'NormalPosterior':
+                raise TypeError('The type of `weight_posterior_fn` should be `NormalPosterior`')
+        else:
+            if weight_posterior_fn.__name__ != 'NormalPosterior':
+                raise TypeError('The type of `weight_posterior_fn` should be `NormalPosterior`')
         self.weight_posterior = weight_posterior_fn(shape=self.shape, name='bnn_weight')
 
         if self.has_bias:
             self.bias.requires_grad = False
 
             if isinstance(bias_prior_fn, Cell):
+                if bias_prior_fn.__class__.__name__ != 'NormalPrior':
+                    raise TypeError('The type of `bias_prior_fn` should be `NormalPrior`')
                 self.bias_prior = bias_prior_fn
             else:
+                if bias_prior_fn.__name__ != 'NormalPrior':
+                    raise TypeError('The type of `bias_prior_fn` should be `NormalPrior`')
                 self.bias_prior = bias_prior_fn()
 
+            if isinstance(bias_posterior_fn, Cell):
+                if bias_posterior_fn.__class__.__name__ != 'NormalPosterior':
+                    raise TypeError('The type of `bias_posterior_fn` should be `NormalPosterior`')
+            else:
+                if bias_posterior_fn.__name__ != 'NormalPosterior':
+                    raise TypeError('The type of `bias_posterior_fn` should be `NormalPosterior`')
             self.bias_posterior = bias_posterior_fn(shape=[self.out_channels], name='bnn_bias')
 
         # mindspore operations
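The prior/posterior validation above accepts either an instance of the expected class or the class itself, and compares class names by string equality rather than with isinstance. A minimal standalone sketch of that pattern (Cell, NormalPrior, and check_prior_fn below are local stand-ins for this sketch, not MindSpore imports):

class Cell:  # stand-in for mindspore.nn.Cell
    pass

class NormalPrior(Cell):
    pass

def check_prior_fn(prior_fn):
    # Mirrors the name-based checks the diff adds for all four *_fn arguments.
    if isinstance(prior_fn, Cell):
        # Already an instance: verify its class name and use it as-is.
        if prior_fn.__class__.__name__ != 'NormalPrior':
            raise TypeError('The type of `prior_fn` should be `NormalPrior`')
        return prior_fn
    # Otherwise assume a class: verify its name, then instantiate it.
    if prior_fn.__name__ != 'NormalPrior':
        raise TypeError('The type of `prior_fn` should be `NormalPrior`')
    return prior_fn()

check_prior_fn(NormalPrior)      # a class: accepted and instantiated
check_prior_fn(NormalPrior())    # an instance: accepted as-is
try:
    check_prior_fn(int)          # anything else is rejected
except TypeError as err:
    print(err)

Note that comparing __name__ strings is stricter than isinstance: it pins the exact class name, so subclasses of NormalPrior are rejected too. That behavior is the diff's, reproduced unchanged here.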
diff --git a/mindspore/nn/probability/bnn_layers/dense_variational.py b/mindspore/nn/probability/bnn_layers/dense_variational.py
index d5a43944d2..4e28c125a8 100644
--- a/mindspore/nn/probability/bnn_layers/dense_variational.py
+++ b/mindspore/nn/probability/bnn_layers/dense_variational.py
@@ -43,18 +43,38 @@ class _DenseVariational(Cell):
         self.has_bias = check_bool(has_bias)
 
         if isinstance(weight_prior_fn, Cell):
+            if weight_prior_fn.__class__.__name__ != 'NormalPrior':
+                raise TypeError('The type of `weight_prior_fn` should be `NormalPrior`')
             self.weight_prior = weight_prior_fn
         else:
+            if weight_prior_fn.__name__ != 'NormalPrior':
+                raise TypeError('The type of `weight_prior_fn` should be `NormalPrior`')
             self.weight_prior = weight_prior_fn()
 
+        if isinstance(weight_posterior_fn, Cell):
+            if weight_posterior_fn.__class__.__name__ != 'NormalPosterior':
+                raise TypeError('The type of `weight_posterior_fn` should be `NormalPosterior`')
+        else:
+            if weight_posterior_fn.__name__ != 'NormalPosterior':
+                raise TypeError('The type of `weight_posterior_fn` should be `NormalPosterior`')
         self.weight_posterior = weight_posterior_fn(shape=[self.out_channels, self.in_channels], name='bnn_weight')
 
         if self.has_bias:
             if isinstance(bias_prior_fn, Cell):
+                if bias_prior_fn.__class__.__name__ != 'NormalPrior':
+                    raise TypeError('The type of `bias_prior_fn` should be `NormalPrior`')
                 self.bias_prior = bias_prior_fn
             else:
+                if bias_prior_fn.__name__ != 'NormalPrior':
+                    raise TypeError('The type of `bias_prior_fn` should be `NormalPrior`')
                 self.bias_prior = bias_prior_fn()
 
+            if isinstance(bias_posterior_fn, Cell):
+                if bias_posterior_fn.__class__.__name__ != 'NormalPosterior':
+                    raise TypeError('The type of `bias_posterior_fn` should be `NormalPosterior`')
+            else:
+                if bias_posterior_fn.__name__ != 'NormalPosterior':
+                    raise TypeError('The type of `bias_posterior_fn` should be `NormalPosterior`')
             self.bias_posterior = bias_posterior_fn(shape=[self.out_channels], name='bnn_bias')
 
         self.activation = activation
diff --git a/mindspore/nn/probability/bnn_layers/layer_distribution.py b/mindspore/nn/probability/bnn_layers/layer_distribution.py
index 2c4acec8d7..f36e1dbbe5 100644
--- a/mindspore/nn/probability/bnn_layers/layer_distribution.py
+++ b/mindspore/nn/probability/bnn_layers/layer_distribution.py
@@ -75,7 +75,18 @@ class NormalPosterior(Cell):
                  untransformed_scale_std=0.1):
         super(NormalPosterior, self).__init__()
         if not isinstance(name, str):
-            raise ValueError('The type of `name` should be `str`')
+            raise TypeError('The type of `name` should be `str`')
+
+        if not isinstance(shape, (tuple, list)):
+            raise TypeError('The type of `shape` should be `tuple` or `list`')
+
+        if not (np.array(shape) > 0).all():
+            raise ValueError('Negative dimensions are not allowed')
+
+        if not (np.array(loc_std) >= 0).all():
+            raise ValueError('The value of `loc_std` < 0')
+        if not (np.array(untransformed_scale_std) >= 0).all():
+            raise ValueError('The value of `untransformed_scale_std` < 0')
         self.mean = Parameter(
             Tensor(np.random.normal(loc_mean, loc_std, shape), dtype=dtype),
             name=name + '_mean')
diff --git a/mindspore/nn/probability/transforms/__init__.py b/mindspore/nn/probability/transforms/__init__.py
index a42f233e92..7b8524359c 100644
--- a/mindspore/nn/probability/transforms/__init__.py
+++ b/mindspore/nn/probability/transforms/__init__.py
@@ -15,7 +15,7 @@
 """
 Transforms.
 
-The high-level components used to transform model between DNN and DNN.
+The high-level components used to transform models between DNN and BNN.
 """
 from . import transform_bnn
 from .transform_bnn import TransformToBNN
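The new NormalPosterior checks route shape and standard-deviation validation through NumPy, so scalar and sequence inputs are handled uniformly by the same .all() test. A runnable standalone illustration of those checks (the function name is this sketch's, not the MindSpore API):

import numpy as np

def validate_posterior_args(shape, loc_std, untransformed_scale_std):
    # Same checks as NormalPosterior.__init__ in the diff above.
    if not isinstance(shape, (tuple, list)):
        raise TypeError('The type of `shape` should be `tuple` or `list`')
    if not (np.array(shape) > 0).all():
        raise ValueError('Negative dimensions are not allowed')
    if not (np.array(loc_std) >= 0).all():
        raise ValueError('The value of `loc_std` < 0')
    if not (np.array(untransformed_scale_std) >= 0).all():
        raise ValueError('The value of `untransformed_scale_std` < 0')

validate_posterior_args([64, 32], loc_std=0.1, untransformed_scale_std=0.1)  # passes
try:
    validate_posterior_args([64, -1], loc_std=0.1, untransformed_scale_std=0.1)
except ValueError as err:
    print(err)  # Negative dimensions are not allowed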
diff --git a/mindspore/nn/probability/transforms/bnn_loss/withLossCell.py b/mindspore/nn/probability/transforms/bnn_loss/withLossCell.py
index baf2a61f4d..14176e1d20 100644
--- a/mindspore/nn/probability/transforms/bnn_loss/withLossCell.py
+++ b/mindspore/nn/probability/transforms/bnn_loss/withLossCell.py
@@ -54,3 +54,13 @@ class WithBNNLossCell(nn.Cell):
                 self.kl_loss.append(layer.compute_kl_loss)
             else:
                 self._add_kl_loss(layer)
+
+    @property
+    def backbone_network(self):
+        """
+        Returns the backbone network.
+
+        Returns:
+            Cell, the backbone network.
+        """
+        return self._backbone
diff --git a/mindspore/nn/probability/transforms/transform_bnn.py b/mindspore/nn/probability/transforms/transform_bnn.py
index debbbc7179..3853c48b23 100644
--- a/mindspore/nn/probability/transforms/transform_bnn.py
+++ b/mindspore/nn/probability/transforms/transform_bnn.py
@@ -61,6 +61,11 @@ class TransformToBNN:
     """
 
    def __init__(self, trainable_dnn, dnn_factor=1, bnn_factor=1):
+        if not isinstance(dnn_factor, (int, float)):
+            raise TypeError('The type of `dnn_factor` should be `int` or `float`')
+        if not isinstance(bnn_factor, (int, float)):
+            raise TypeError('The type of `bnn_factor` should be `int` or `float`')
+
         net_with_loss = trainable_dnn.network
         self.optimizer = trainable_dnn.optimizer
         self.backbone = net_with_loss.backbone_network
@@ -88,8 +93,10 @@ class TransformToBNN:
             get_conv_args (function): The arguments gotten from the DNN convolutional layer. Default: lambda dp:
                 {"in_channels": dp.in_channels, "out_channels": dp.out_channels, "pad_mode": dp.pad_mode,
                 "kernel_size": dp.kernel_size, "stride": dp.stride, "has_bias": dp.has_bias}.
-            add_dense_args (dict): The new arguments added to BNN full connection layer. Default: {}.
-            add_conv_args (dict): The new arguments added to BNN convolutional layer. Default: {}.
+            add_dense_args (dict): The new arguments added to BNN full connection layer. Note that the arguments in
+                `add_dense_args` should not duplicate arguments in `get_dense_args`. Default: {}.
+            add_conv_args (dict): The new arguments added to BNN convolutional layer. Note that the arguments in
+                `add_conv_args` should not duplicate arguments in `get_conv_args`. Default: {}.
 
         Returns:
             Cell, a trainable BNN model wrapped by TrainOneStepCell.
@@ -131,7 +138,8 @@ class TransformToBNN:
             bnn_layer_type (Cell): The type of BNN layer to be transformed to. The optional values are
                 DenseReparameterization, ConvReparameterization.
             get_args (dict): The arguments gotten from the DNN layer. Default: None.
-            add_args (dict): The new arguments added to BNN layer. Default: None.
+            add_args (dict): The new arguments added to BNN layer. Note that the arguments in `add_args` should not
+                duplicate arguments in `get_args`. Default: None.
 
         Returns:
             Cell, a trainable model wrapped by TrainOneStepCell, whose specific type of layer is transformed to the
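The docstring notes added above forbid duplicate keys between the get_* and add_* argument dicts. Assuming the transform combines the two when constructing each BNN layer, roughly as bnn_layer_type(**get_args, **add_args) (an assumption about the implementation, not something this diff shows), a duplicated key fails at call time with a standard Python TypeError. A minimal standalone illustration with a hypothetical constructor:

def make_layer(in_channels, out_channels, activation=None):
    # Hypothetical constructor standing in for a BNN layer type.
    return (in_channels, out_channels, activation)

get_args = {"in_channels": 64, "out_channels": 10}   # extracted from the DNN layer
add_args = {"activation": "relu"}                    # new and disjoint: fine
make_layer(**get_args, **add_args)

bad_add_args = {"out_channels": 10, "activation": "relu"}
try:
    make_layer(**get_args, **bad_add_args)           # 'out_channels' duplicated
except TypeError as err:
    print(err)  # got multiple values for keyword argument 'out_channels'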