!4717 Fix error in bnn_layer and transforms

Merge pull request !4717 from byweng/fix_error
pull/4717/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit b283a1fe10

@ -26,8 +26,8 @@ class ClassWrap:
self._cls = cls
self.bnn_loss_file = None
def __call__(self, backbone, loss_fn, backbone_factor, kl_factor):
obj = self._cls(backbone, loss_fn, backbone_factor, kl_factor)
def __call__(self, backbone, loss_fn, dnn_factor, bnn_factor):
obj = self._cls(backbone, loss_fn, dnn_factor, bnn_factor)
bnn_with_loss = obj()
self.bnn_loss_file = obj.bnn_loss_file
return bnn_with_loss
@ -65,6 +65,11 @@ class WithBNNLossCell:
"""
def __init__(self, backbone, loss_fn, dnn_factor=1, bnn_factor=1):
if not isinstance(dnn_factor, (int, float)):
raise TypeError('The type of `dnn_factor` should be `int` or `float`')
if not isinstance(bnn_factor, (int, float)):
raise TypeError('The type of `bnn_factor` should be `int` or `float`')
self.backbone = backbone
self.loss_fn = loss_fn
self.dnn_factor = dnn_factor

@ -79,20 +79,40 @@ class _ConvVariational(_Conv):
self.weight.requires_grad = False
if isinstance(weight_prior_fn, Cell):
if weight_prior_fn.__class__.__name__ != 'NormalPrior':
raise TypeError('The type of `weight_prior_fn` should be `NormalPrior`')
self.weight_prior = weight_prior_fn
else:
if weight_prior_fn.__name__ != 'NormalPrior':
raise TypeError('The type of `weight_prior_fn` should be `NormalPrior`')
self.weight_prior = weight_prior_fn()
if isinstance(weight_posterior_fn, Cell):
if weight_posterior_fn.__class__.__name__ != 'NormalPosterior':
raise TypeError('The type of `weight_posterior_fn` should be `NormalPosterior`')
else:
if weight_posterior_fn.__name__ != 'NormalPosterior':
raise TypeError('The type of `weight_posterior_fn` should be `NormalPosterior`')
self.weight_posterior = weight_posterior_fn(shape=self.shape, name='bnn_weight')
if self.has_bias:
self.bias.requires_grad = False
if isinstance(bias_prior_fn, Cell):
if bias_prior_fn.__class__.__name__ != 'NormalPrior':
raise TypeError('The type of `bias_prior_fn` should be `NormalPrior`')
self.bias_prior = bias_prior_fn
else:
if bias_prior_fn.__name__ != 'NormalPrior':
raise TypeError('The type of `bias_prior_fn` should be `NormalPrior`')
self.bias_prior = bias_prior_fn()
if isinstance(bias_posterior_fn, Cell):
if bias_posterior_fn.__class__.__name__ != 'NormalPosterior':
raise TypeError('The type of `bias_posterior_fn` should be `NormalPosterior`')
else:
if bias_posterior_fn.__name__ != 'NormalPosterior':
raise TypeError('The type of `bias_posterior_fn` should be `NormalPosterior`')
self.bias_posterior = bias_posterior_fn(shape=[self.out_channels], name='bnn_bias')
# mindspore operations

@ -43,18 +43,38 @@ class _DenseVariational(Cell):
self.has_bias = check_bool(has_bias)
if isinstance(weight_prior_fn, Cell):
if weight_prior_fn.__class__.__name__ != 'NormalPrior':
raise TypeError('The type of `weight_prior_fn` should be `NormalPrior`')
self.weight_prior = weight_prior_fn
else:
if weight_prior_fn.__name__ != 'NormalPrior':
raise TypeError('The type of `weight_prior_fn` should be `NormalPrior`')
self.weight_prior = weight_prior_fn()
if isinstance(weight_posterior_fn, Cell):
if weight_posterior_fn.__class__.__name__ != 'NormalPosterior':
raise TypeError('The type of `weight_posterior_fn` should be `NormalPosterior`')
else:
if weight_posterior_fn.__name__ != 'NormalPosterior':
raise TypeError('The type of `weight_posterior_fn` should be `NormalPosterior`')
self.weight_posterior = weight_posterior_fn(shape=[self.out_channels, self.in_channels], name='bnn_weight')
if self.has_bias:
if isinstance(bias_prior_fn, Cell):
if bias_prior_fn.__class__.__name__ != 'NormalPrior':
raise TypeError('The type of `bias_prior_fn` should be `NormalPrior`')
self.bias_prior = bias_prior_fn
else:
if bias_prior_fn.__name__ != 'NormalPrior':
raise TypeError('The type of `bias_prior_fn` should be `NormalPrior`')
self.bias_prior = bias_prior_fn()
if isinstance(bias_posterior_fn, Cell):
if bias_posterior_fn.__class__.__name__ != 'NormalPosterior':
raise TypeError('The type of `bias_posterior_fn` should be `NormalPosterior`')
else:
if bias_posterior_fn.__name__ != 'NormalPosterior':
raise TypeError('The type of `bias_posterior_fn` should be `NormalPosterior`')
self.bias_posterior = bias_posterior_fn(shape=[self.out_channels], name='bnn_bias')
self.activation = activation

@ -75,7 +75,18 @@ class NormalPosterior(Cell):
untransformed_scale_std=0.1):
super(NormalPosterior, self).__init__()
if not isinstance(name, str):
raise ValueError('The type of `name` should be `str`')
raise TypeError('The type of `name` should be `str`')
if not isinstance(shape, (tuple, list)):
raise TypeError('The type of `shape` should be `tuple` or `list`')
if not (np.array(shape) > 0).all():
raise ValueError('Negative dimensions are not allowed')
if not (np.array(loc_std) >= 0).all():
raise ValueError('The value of `loc_std` < 0')
if not (np.array(untransformed_scale_std) >= 0).all():
raise ValueError('The value of `untransformed_scale_std` < 0')
self.mean = Parameter(
Tensor(np.random.normal(loc_mean, loc_std, shape), dtype=dtype), name=name + '_mean')

@ -15,7 +15,7 @@
"""
Transforms.
The high-level components used to transform model between DNN and DNN.
The high-level components used to transform model between DNN and BNN.
"""
from . import transform_bnn
from .transform_bnn import TransformToBNN

@ -54,3 +54,13 @@ class WithBNNLossCell(nn.Cell):
self.kl_loss.append(layer.compute_kl_loss)
else:
self._add_kl_loss(layer)
@property
def backbone_network(self):
"""
Returns the backbone network.
Returns:
Cell, the backbone network.
"""
return self._backbone

@ -61,6 +61,11 @@ class TransformToBNN:
"""
def __init__(self, trainable_dnn, dnn_factor=1, bnn_factor=1):
if not isinstance(dnn_factor, (int, float)):
raise TypeError('The type of `dnn_factor` should be `int` or `float`')
if not isinstance(bnn_factor, (int, float)):
raise TypeError('The type of `bnn_factor` should be `int` or `float`')
net_with_loss = trainable_dnn.network
self.optimizer = trainable_dnn.optimizer
self.backbone = net_with_loss.backbone_network
@ -88,8 +93,10 @@ class TransformToBNN:
get_conv_args (function): The arguments gotten from the DNN convolutional layer. Default: lambda dp:
{"in_channels": dp.in_channels, "out_channels": dp.out_channels, "pad_mode": dp.pad_mode,
"kernel_size": dp.kernel_size, "stride": dp.stride, "has_bias": dp.has_bias}.
add_dense_args (dict): The new arguments added to BNN full connection layer. Default: {}.
add_conv_args (dict): The new arguments added to BNN convolutional layer. Default: {}.
add_dense_args (dict): The new arguments added to BNN full connection layer. Note that the arguments in
`add_dense_args` should not duplicate arguments in `get_dense_args`. Default: {}.
add_conv_args (dict): The new arguments added to BNN convolutional layer. Note that the arguments in
`add_conv_args` should not duplicate arguments in `get_conv_args`. Default: {}.
Returns:
Cell, a trainable BNN model wrapped by TrainOneStepCell.
@ -131,7 +138,8 @@ class TransformToBNN:
bnn_layer_type (Cell): The type of BNN layer to be transformed to. The optional values are
DenseReparameterization, ConvReparameterization.
get_args (dict): The arguments gotten from the DNN layer. Default: None.
add_args (dict): The new arguments added to BNN layer. Default: None.
add_args (dict): The new arguments added to BNN layer. Note that the arguments in `add_args` should not
duplicate arguments in `get_args`. Default: None.
Returns:
Cell, a trainable model wrapped by TrainOneStepCell, whose sprcific type of layer is transformed to the

Loading…
Cancel
Save