commit 61dbb1b17c (parent fb2f888ec8)
@@ -0,0 +1,31 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Bayesian Layer.

The high-level components (Cells) used to construct the Bayesian neural network.
"""
from . import conv_variational, dense_variational, layer_distribution, bnn_cell_wrapper
from .conv_variational import ConvReparam
from .dense_variational import DenseReparam
from .layer_distribution import NormalPrior, NormalPosterior
from .bnn_cell_wrapper import WithBNNLossCell

__all__ = []
__all__.extend(conv_variational.__all__)
__all__.extend(dense_variational.__all__)
__all__.extend(layer_distribution.__all__)
__all__.extend(bnn_cell_wrapper.__all__)
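Taken together, these exports make the layers available at package level. A minimal import sketch; the absolute package path below is inferred from the relative imports (three levels under mindspore.nn) and is an assumption, not something this diff confirms:

    # Assumed package path; inferred from the relative imports above.
    from mindspore.nn.probability.bnn_layers import (
        DenseReparam, ConvReparam, NormalPrior, NormalPosterior, WithBNNLossCell)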
@@ -0,0 +1,92 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Generate WithLossCell suitable for BNN."""
from .conv_variational import _ConvVariational
from .dense_variational import _DenseVariational
from ..transforms.bnn_loss.generate_kl_loss import gain_bnn_with_loss

__all__ = ['WithBNNLossCell']

class ClassWrap:
    """Decorator of WithBNNLossCell."""

    def __init__(self, cls):
        self._cls = cls
        self.bnn_loss_file = None

    def __call__(self, backbone, loss_fn, dnn_factor=1, bnn_factor=1):
        obj = self._cls(backbone, loss_fn, dnn_factor, bnn_factor)
        bnn_with_loss = obj()
        self.bnn_loss_file = obj.bnn_loss_file
        return bnn_with_loss
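
# Note: with @ClassWrap applied below, "WithBNNLossCell(...)" builds the
# wrapper and returns the generated loss cell directly, not the wrapper
# object; see the runnable sketch of this pattern after the file's diff.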
@ClassWrap
class WithBNNLossCell:
    r"""
    Generate WithLossCell suitable for BNN.

    Args:
        backbone (Cell): The target network.
        loss_fn (Cell): The loss function used to compute loss.
        dnn_factor (int, float): The coefficient of the backbone's loss, which is computed by the loss
            function. Default: 1.
        bnn_factor (int, float): The coefficient of the KL loss, which is the KL divergence of the
            Bayesian layers. Default: 1.

    Inputs:
        - **data** (Tensor) - Tensor of shape :math:`(N, \ldots)`.
        - **label** (Tensor) - Tensor of shape :math:`(N, \ldots)`.

    Outputs:
        Tensor, a scalar tensor with shape :math:`()`.

    Examples:
        >>> net = Net()
        >>> loss_fn = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
        >>> net_with_criterion = WithBNNLossCell(net, loss_fn)
        >>>
        >>> batch_size = 2
        >>> data = Tensor(np.ones([batch_size, 3, 64, 64]).astype(np.float32) * 0.01)
        >>> label = Tensor(np.ones([batch_size, 1, 1, 1]).astype(np.int32))
        >>>
        >>> net_with_criterion(data, label)
    """

    def __init__(self, backbone, loss_fn, dnn_factor=1, bnn_factor=1):
        self.backbone = backbone
        self.loss_fn = loss_fn
        self.dnn_factor = dnn_factor
        self.bnn_factor = bnn_factor
        self.bnn_loss_file = None

    def _generate_loss_cell(self):
        """Generate WithBNNLossCell by ast."""
        layer_count = self._kl_loss_count(self.backbone)
        bnn_with_loss, self.bnn_loss_file = gain_bnn_with_loss(layer_count, self.backbone, self.loss_fn,
                                                               self.dnn_factor, self.bnn_factor)
        return bnn_with_loss

    def _kl_loss_count(self, net):
        """Count the Bayesian layers, recursing into sub-cells."""
        count = 0
        for (_, layer) in net.name_cells().items():
            if isinstance(layer, (_DenseVariational, _ConvVariational)):
                count += 1
            else:
                count += self._kl_loss_count(layer)
        return count

    def __call__(self):
        return self._generate_loss_cell()
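Because of the @ClassWrap decorator, WithBNNLossCell behaves as a factory: "instantiating" it returns the generated loss cell rather than a wrapper object. A self-contained sketch of the same pattern in plain Python (Greeter and its names are hypothetical, for illustration only):

    class ClassWrap:
        # Replace a class with a factory: calling the name builds an instance
        # and immediately returns whatever its __call__ produces.
        def __init__(self, cls):
            self._cls = cls

        def __call__(self, *args, **kwargs):
            obj = self._cls(*args, **kwargs)
            return obj()  # hand back the product, not the wrapper object


    @ClassWrap
    class Greeter:
        def __init__(self, name):
            self.name = name

        def __call__(self):
            return "hello, " + self.name


    print(Greeter("world"))  # prints "hello, world", not a Greeter instance

The returned loss cell can then be fed to the usual MindSpore training wrappers (for example nn.TrainOneStepCell together with an optimizer), like any other WithLossCell.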
File diff suppressed because it is too large
@@ -0,0 +1,188 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Dense variational layers."""
from mindspore.ops import operations as P
from mindspore._checkparam import check_int_positive, check_bool
from ...cell import Cell
from ...layer.activation import get_activation
from .layer_distribution import NormalPrior, NormalPosterior

__all__ = ['DenseReparam']

class _DenseVariational(Cell):
    """
    Base class for all dense variational layers.
    """

    def __init__(
            self,
            in_channels,
            out_channels,
            activation=None,
            has_bias=True,
            weight_prior_fn=NormalPrior,
            weight_posterior_fn=lambda name, shape: NormalPosterior(name=name, shape=shape),
            bias_prior_fn=NormalPrior,
            bias_posterior_fn=lambda name, shape: NormalPosterior(name=name, shape=shape)):
        super(_DenseVariational, self).__init__()
        self.in_channels = check_int_positive(in_channels)
        self.out_channels = check_int_positive(out_channels)
        self.has_bias = check_bool(has_bias)

        if isinstance(weight_prior_fn, Cell):
            self.weight_prior = weight_prior_fn
        else:
            self.weight_prior = weight_prior_fn()

        self.weight_posterior = weight_posterior_fn(shape=[self.out_channels, self.in_channels], name='bnn_weight')

        if self.has_bias:
            if isinstance(bias_prior_fn, Cell):
                self.bias_prior = bias_prior_fn
            else:
                self.bias_prior = bias_prior_fn()

            self.bias_posterior = bias_posterior_fn(shape=[self.out_channels], name='bnn_bias')

        self.activation = activation
        if isinstance(self.activation, str):
            self.activation = get_activation(activation)
        self.activation_flag = self.activation is not None

        self.matmul = P.MatMul(transpose_b=True)
        self.bias_add = P.BiasAdd()
        self.sum = P.ReduceSum()
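
    # Note: a prior supplied as a callable is invoked with no arguments, so a
    # non-default prior can be injected per layer, e.g. (illustrative sketch):
    #
    #     DenseReparam(3, 4, weight_prior_fn=lambda: NormalPrior(mean=0, std=1.0))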
    def construct(self, x):
        outputs = self._apply_variational_weight(x)
        if self.has_bias:
            outputs = self._apply_variational_bias(outputs)
        if self.activation_flag:
            outputs = self.activation(outputs)
        return outputs

    def extend_repr(self):
        str_info = 'in_channels={}, out_channels={}, weight_mean={}, weight_std={}, has_bias={}' \
            .format(self.in_channels, self.out_channels, self.weight_posterior.mean,
                    self.weight_posterior.untransformed_std, self.has_bias)
        if self.has_bias:
            str_info = str_info + ', bias_mean={}, bias_std={}' \
                .format(self.bias_posterior.mean, self.bias_posterior.untransformed_std)

        if self.activation_flag:
            str_info = str_info + ', activation={}'.format(self.activation)
        return str_info
    def _apply_variational_bias(self, inputs):
        bias_posterior_tensor = self.bias_posterior("sample")
        return self.bias_add(inputs, bias_posterior_tensor)

    def compute_kl_loss(self):
        """Compute the KL loss of the weight (and bias) posteriors against their priors."""
        weight_post_mean = self.weight_posterior("mean")
        weight_post_sd = self.weight_posterior("sd")

        kl = self.weight_prior("kl_loss", "Normal", weight_post_mean, weight_post_sd)
        kl_loss = self.sum(kl)
        if self.has_bias:
            bias_post_mean = self.bias_posterior("mean")
            bias_post_sd = self.bias_posterior("sd")

            kl = self.bias_prior("kl_loss", "Normal", bias_post_mean, bias_post_sd)
            kl = self.sum(kl)
            kl_loss += kl
        return kl_loss
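
# Note: for a Normal prior and posterior, the "kl_loss" evaluated in
# compute_kl_loss has the standard closed form (summed elementwise here):
#
#     KL(q || p) = log(sigma_p / sigma_q)
#                  + (sigma_q ** 2 + (mu_q - mu_p) ** 2) / (2 * sigma_p ** 2)
#                  - 1 / 2
#
# where q = N(mu_q, sigma_q) is the posterior and p = N(mu_p, sigma_p) the prior.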
class DenseReparam(_DenseVariational):
    r"""
    Dense variational layers with reparameterization.

    See more details in the paper `Auto-Encoding Variational Bayes
    <https://arxiv.org/abs/1312.6114>`_.

    Applies a dense-connected layer to the input. This layer implements the operation as:

    .. math::
        \text{outputs} = \text{activation}(\text{inputs} * \text{weight} + \text{bias}),

    where :math:`\text{activation}` is the activation function passed as the activation
    argument (if passed in), :math:`\text{weight}` is a weight matrix with the same
    data type as the inputs, sampled from the posterior distribution of the weight,
    and :math:`\text{bias}` is a bias vector with the same data type as the inputs
    (only if has_bias is True), sampled from the posterior distribution of
    :math:`\text{bias}`.

    Args:
        in_channels (int): The number of input channels.
        out_channels (int): The number of output channels.
        has_bias (bool): Specifies whether the layer uses a bias vector. Default: True.
        activation (str): Activation function applied to the output of the layer, e.g. 'relu'. Default: None.
        weight_prior_fn: Prior distribution for the weight.
            It should return a mindspore distribution instance.
            Default: NormalPrior (which creates an instance of a standard
            normal distribution).
        weight_posterior_fn: Posterior distribution for sampling the weight.
            It should be a function handle which returns a mindspore
            distribution instance.
            Default: NormalPosterior.
        bias_prior_fn: Prior distribution for the bias vector. It should return
            a mindspore distribution instance.
            Default: NormalPrior (which creates an instance of a standard
            normal distribution).
        bias_posterior_fn: Posterior distribution for sampling the bias vector.
            It should be a function handle which returns a mindspore
            distribution instance.
            Default: NormalPosterior.

    Inputs:
        - **input** (Tensor) - Tensor of shape :math:`(N, in\_channels)`.

    Outputs:
        Tensor of shape :math:`(N, out\_channels)`.

    Examples:
        >>> net = DenseReparam(3, 4)
        >>> input = Tensor(np.random.randint(0, 255, [2, 3]), mindspore.float32)
        >>> net(input)
    """

    def __init__(
            self,
            in_channels,
            out_channels,
            activation=None,
            has_bias=True,
            weight_prior_fn=NormalPrior,
            weight_posterior_fn=lambda name, shape: NormalPosterior(name=name, shape=shape),
            bias_prior_fn=NormalPrior,
            bias_posterior_fn=lambda name, shape: NormalPosterior(name=name, shape=shape)):
        super(DenseReparam, self).__init__(
            in_channels,
            out_channels,
            activation=activation,
            has_bias=has_bias,
            weight_prior_fn=weight_prior_fn,
            weight_posterior_fn=weight_posterior_fn,
            bias_prior_fn=bias_prior_fn,
            bias_posterior_fn=bias_posterior_fn
        )

    def _apply_variational_weight(self, inputs):
        weight_posterior_tensor = self.weight_posterior("sample")
        outputs = self.matmul(inputs, weight_posterior_tensor)
        return outputs
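For intuition, the sampling performed by self.weight_posterior("sample") is the reparameterization trick from the paper cited above. A NumPy-only sketch under assumed shapes (an illustration, not the MindSpore implementation):

    import numpy as np

    rng = np.random.default_rng(0)
    out_channels, in_channels, batch = 4, 3, 2

    mu = np.zeros((out_channels, in_channels))        # trainable posterior mean
    rho = np.full((out_channels, in_channels), -5.0)  # trainable untransformed std
    sigma = 1e-6 + np.log1p(np.exp(rho))              # softplus keeps sigma > 0

    eps = rng.standard_normal((out_channels, in_channels))  # the only randomness
    weight = mu + sigma * eps                         # differentiable in mu and rho

    x = rng.standard_normal((batch, in_channels))
    y = x @ weight.T                                  # matches MatMul(transpose_b=True)
    print(y.shape)                                    # (2, 4)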
@@ -0,0 +1,96 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Initialize normal distributions."""
import numpy as np
import mindspore.common.dtype as mstype
from mindspore.common.tensor import Tensor
from mindspore.common.parameter import Parameter
from mindspore.ops import operations as P
from ...cell import Cell
from ..distribution.normal import Normal

__all__ = ['NormalPrior', 'NormalPosterior']

class NormalPrior(Cell):
    r"""
    Initialize a normal distribution with a mean of 0 and a standard deviation of 0.1.

    Args:
        dtype (class `mindspore.dtype`): The data type of the output tensor.
            Default: mindspore.float32.
        mean (int, float): Mean of the normal distribution. Default: 0.
        std (int, float): Standard deviation of the normal distribution. Default: 0.1.

    Returns:
        Cell, a normal distribution.
    """

    def __init__(self, dtype=mstype.float32, mean=0, std=0.1):
        super(NormalPrior, self).__init__()
        self.normal = Normal(mean, std, dtype=dtype)

    def construct(self, *inputs):
        return self.normal(*inputs)

class NormalPosterior(Cell):
    r"""
    Build normal distributions with trainable parameters.

    Args:
        name (str): Name prepended to each trainable parameter.
        shape (list): Shape of the mean and standard deviation.
        dtype (class `mindspore.dtype`): The data type of the output tensor.
            Default: mindspore.float32.
        loc_mean (float, array_like of floats): Mean of the distribution used to initialize the
            trainable mean. Default: 0.
        loc_std (float, array_like of floats): Standard deviation of the distribution used to
            initialize the trainable mean. Default: 0.1.
        untransformed_scale_mean (float, array_like of floats): Mean of the distribution used to
            initialize the trainable (untransformed) standard deviation. Default: -5.
        untransformed_scale_std (float, array_like of floats): Standard deviation of the distribution
            used to initialize the trainable (untransformed) standard deviation. Default: 0.1.

    Returns:
        Cell, a normal distribution.
    """

    def __init__(self,
                 name,
                 shape,
                 dtype=mstype.float32,
                 loc_mean=0,
                 loc_std=0.1,
                 untransformed_scale_mean=-5,
                 untransformed_scale_std=0.1):
        super(NormalPosterior, self).__init__()
        if not isinstance(name, str):
            raise ValueError('The type of `name` should be `str`')

        self.mean = Parameter(
            Tensor(np.random.normal(loc_mean, loc_std, shape), dtype=dtype), name=name + '_mean')

        self.untransformed_std = Parameter(
            Tensor(np.random.normal(untransformed_scale_mean, untransformed_scale_std, shape), dtype=dtype),
            name=name + '_untransformed_std')

        self.normal = Normal()

    def std_trans(self, std_pre):
        """Transform std_pre through a softplus so the standard deviation stays positive."""
        std = 1e-6 + P.Log()(P.Exp()(std_pre) + 1)
        return std

    def construct(self, *inputs):
        std = self.std_trans(self.untransformed_std)
        return self.normal(*inputs, mean=self.mean, sd=std)
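As a quick sanity check of std_trans, the default untransformed_scale_mean of -5 makes the initial posterior standard deviation about 0.0067, so posteriors start out nearly deterministic. A NumPy sketch mirroring the transform:

    import numpy as np

    def std_trans(std_pre):
        # mirrors 1e-6 + log(exp(x) + 1): a softplus with a small floor
        return 1e-6 + np.log1p(np.exp(std_pre))

    print(std_trans(-5.0))  # ~0.0067
    print(std_trans(0.0))   # ~0.6931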