add bnn_layers to nn.probability

pull/4403/head
bingyaweng 5 years ago
parent fb2f888ec8
commit 61dbb1b17c

@@ -0,0 +1,31 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Bayesian Layer.
The high-level components(Cells) used to construct the bayesian neural network.
"""
from . import conv_variational, dense_variational, layer_distribution, bnn_cell_wrapper
from .conv_variational import ConvReparam
from .dense_variational import DenseReparam
from .layer_distribution import NormalPrior, NormalPosterior
from .bnn_cell_wrapper import WithBNNLossCell
__all__ = []
__all__.extend(conv_variational.__all__)
__all__.extend(dense_variational.__all__)
__all__.extend(layer_distribution.__all__)
__all__.extend(bnn_cell_wrapper.__all__)
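
The package interface above pulls the variational layers, the normal prior/posterior cells, and the loss-cell wrapper into one namespace. Below is a minimal usage sketch, not part of the commit, assuming the package is exposed as mindspore.nn.probability.bnn_layers (per the commit title) and is run in PyNative mode; only DenseReparam's signature is taken from the diff further down.

```python
# A sketch, not part of this commit. The import path is an assumption based
# on the commit title ("add bnn_layers to nn.probability").
import numpy as np
from mindspore import Tensor, context
from mindspore.nn.probability.bnn_layers import DenseReparam

context.set_context(mode=context.PYNATIVE_MODE)

# 3 input features -> 4 output features; the weight and bias are
# distributions, and a fresh sample is drawn from the posterior on
# every forward pass.
dense = DenseReparam(3, 4, activation='relu')
x = Tensor(np.random.randn(2, 3).astype(np.float32))
y1 = dense(x)
y2 = dense(x)   # generally differs from y1 because new weights are sampled
```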

@@ -0,0 +1,92 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Generate WithLossCell suitable for BNN."""
from .conv_variational import _ConvVariational
from .dense_variational import _DenseVariational
from ..transforms.bnn_loss.generate_kl_loss import gain_bnn_with_loss
__all__ = ['WithBNNLossCell']
class ClassWrap:
    """Decorator of WithBNNLossCell."""

    def __init__(self, cls):
        self._cls = cls
        self.bnn_loss_file = None

    def __call__(self, backbone, loss_fn, backbone_factor, kl_factor):
        obj = self._cls(backbone, loss_fn, backbone_factor, kl_factor)
        bnn_with_loss = obj()
        self.bnn_loss_file = obj.bnn_loss_file
        return bnn_with_loss
@ClassWrap
class WithBNNLossCell:
    r"""
    Generate WithLossCell suitable for BNN.

    Args:
        backbone (Cell): The target network.
        loss_fn (Cell): The loss function used to compute loss.
        dnn_factor (int, float): The coefficient of the backbone's loss, which is computed by the loss function.
            Default: 1.
        bnn_factor (int, float): The coefficient of the KL loss, which is the KL divergence of the Bayesian layers.
            Default: 1.

    Inputs:
        - **data** (Tensor) - Tensor of shape :math:`(N, \ldots)`.
        - **label** (Tensor) - Tensor of shape :math:`(N, \ldots)`.

    Outputs:
        Tensor, a scalar tensor with shape :math:`()`.

    Examples:
        >>> net = Net()
        >>> loss_fn = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
        >>> # The decorated name returns the generated loss cell directly,
        >>> # so all four arguments are passed in a single call.
        >>> net_with_criterion = WithBNNLossCell(net, loss_fn, 1, 1)
        >>>
        >>> batch_size = 2
        >>> data = Tensor(np.ones([batch_size, 3, 64, 64]).astype(np.float32) * 0.01)
        >>> label = Tensor(np.ones([batch_size, 1, 1, 1]).astype(np.int32))
        >>>
        >>> net_with_criterion(data, label)
    """
    def __init__(self, backbone, loss_fn, dnn_factor=1, bnn_factor=1):
        self.backbone = backbone
        self.loss_fn = loss_fn
        self.dnn_factor = dnn_factor
        self.bnn_factor = bnn_factor
        self.bnn_loss_file = None

    def _generate_loss_cell(self):
        """Generate WithBNNLossCell by ast."""
        layer_count = self._kl_loss_count(self.backbone)
        bnn_with_loss, self.bnn_loss_file = gain_bnn_with_loss(layer_count, self.backbone, self.loss_fn,
                                                               self.dnn_factor, self.bnn_factor)
        return bnn_with_loss

    def _kl_loss_count(self, net):
        """Calculate the number of Bayesian layers."""
        count = 0
        for (_, layer) in net.name_cells().items():
            if isinstance(layer, (_DenseVariational, _ConvVariational)):
                count += 1
            else:
                count += self._kl_loss_count(layer)
        return count

    def __call__(self):
        return self._generate_loss_cell()
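
The ClassWrap decorator above replaces the WithBNNLossCell name with a small factory: calling the name constructs the inner helper, invokes it, and hands back the generated loss cell in one step. A standalone toy sketch of that pattern (hypothetical names, not MindSpore code):

```python
# Toy illustration of the class-decorator pattern used by ClassWrap above.
class ClassWrapDemo:
    """Replace a class name with a factory that returns the built product."""

    def __init__(self, cls):
        self._cls = cls

    def __call__(self, *args):
        obj = self._cls(*args)
        return obj()              # hand back the generated product, not the builder

@ClassWrapDemo
class Builder:
    def __init__(self, x):
        self.x = x

    def __call__(self):
        return self.x * 2         # stands in for the generated loss cell

print(Builder(21))                # 42: one call does "construct + generate"
```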

File diff suppressed because it is too large

@@ -0,0 +1,188 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""dense_variational"""
from mindspore.ops import operations as P
from mindspore._checkparam import check_int_positive, check_bool
from ...cell import Cell
from ...layer.activation import get_activation
from .layer_distribution import NormalPrior, NormalPosterior
__all__ = ['DenseReparam']
class _DenseVariational(Cell):
    """
    Base class for all dense variational layers.
    """

    def __init__(
            self,
            in_channels,
            out_channels,
            activation=None,
            has_bias=True,
            weight_prior_fn=NormalPrior,
            weight_posterior_fn=lambda name, shape: NormalPosterior(name=name, shape=shape),
            bias_prior_fn=NormalPrior,
            bias_posterior_fn=lambda name, shape: NormalPosterior(name=name, shape=shape)):
        super(_DenseVariational, self).__init__()
        self.in_channels = check_int_positive(in_channels)
        self.out_channels = check_int_positive(out_channels)
        self.has_bias = check_bool(has_bias)

        if isinstance(weight_prior_fn, Cell):
            self.weight_prior = weight_prior_fn
        else:
            self.weight_prior = weight_prior_fn()
        self.weight_posterior = weight_posterior_fn(shape=[self.out_channels, self.in_channels], name='bnn_weight')

        if self.has_bias:
            if isinstance(bias_prior_fn, Cell):
                self.bias_prior = bias_prior_fn
            else:
                self.bias_prior = bias_prior_fn()
            self.bias_posterior = bias_posterior_fn(shape=[self.out_channels], name='bnn_bias')

        self.activation = activation
        if isinstance(self.activation, str):
            self.activation = get_activation(activation)
        self.activation_flag = self.activation is not None

        self.matmul = P.MatMul(transpose_b=True)
        self.bias_add = P.BiasAdd()
        self.sum = P.ReduceSum()

    def construct(self, x):
        outputs = self._apply_variational_weight(x)
        if self.has_bias:
            outputs = self._apply_variational_bias(outputs)
        if self.activation_flag:
            outputs = self.activation(outputs)
        return outputs

    def extend_repr(self):
        str_info = 'in_channels={}, out_channels={}, weight_mean={}, weight_std={}, has_bias={}' \
            .format(self.in_channels, self.out_channels, self.weight_posterior.mean,
                    self.weight_posterior.untransformed_std, self.has_bias)
        if self.has_bias:
            str_info = str_info + ', bias_mean={}, bias_std={}' \
                .format(self.bias_posterior.mean, self.bias_posterior.untransformed_std)
        if self.activation_flag:
            str_info = str_info + ', activation={}'.format(self.activation)
        return str_info

    def _apply_variational_bias(self, inputs):
        bias_posterior_tensor = self.bias_posterior("sample")
        return self.bias_add(inputs, bias_posterior_tensor)

    def compute_kl_loss(self):
        """Compute kl loss."""
        weight_post_mean = self.weight_posterior("mean")
        weight_post_sd = self.weight_posterior("sd")
        kl = self.weight_prior("kl_loss", "Normal", weight_post_mean, weight_post_sd)
        kl_loss = self.sum(kl)
        if self.has_bias:
            bias_post_mean = self.bias_posterior("mean")
            bias_post_sd = self.bias_posterior("sd")
            kl = self.bias_prior("kl_loss", "Normal", bias_post_mean, bias_post_sd)
            kl = self.sum(kl)
            kl_loss += kl
        return kl_loss
class DenseReparam(_DenseVariational):
    r"""
    Dense variational layer with reparameterization.

    For more details, see the paper `Auto-Encoding Variational Bayes
    <https://arxiv.org/abs/1312.6114>`_.

    Applies a dense-connected layer to the input. This layer implements the operation as:

    .. math::
        \text{outputs} = \text{activation}(\text{inputs} * \text{kernel} + \text{bias}),

    where :math:`\text{activation}` is the activation function passed as the activation
    argument (if passed in), :math:`\text{kernel}` is a weight matrix with the same data
    type as the inputs, sampled from the posterior distribution of the weight, and
    :math:`\text{bias}` is a bias vector with the same data type as the inputs, sampled
    from the posterior distribution of the bias (only if has_bias is True).

    Args:
        in_channels (int): The number of input channels.
        out_channels (int): The number of output channels.
        has_bias (bool): Specifies whether the layer uses a bias vector. Default: True.
        activation (str): Activation function applied to the output of the layer, e.g. 'relu'. Default: None.
        weight_prior_fn: The prior distribution for the weight.
            It should return a mindspore distribution instance.
            Default: NormalPrior (which creates a normal distribution with
            mean 0 and standard deviation 0.1).
        weight_posterior_fn: The posterior distribution for sampling the weight.
            It should be a function handle which returns a mindspore
            distribution instance.
            Default: NormalPosterior.
        bias_prior_fn: The prior distribution for the bias vector. It should return
            a mindspore distribution instance.
            Default: NormalPrior (which creates a normal distribution with
            mean 0 and standard deviation 0.1).
        bias_posterior_fn: The posterior distribution for sampling the bias vector.
            It should be a function handle which returns a mindspore
            distribution instance.
            Default: NormalPosterior.

    Inputs:
        - **input** (Tensor) - Tensor of shape :math:`(N, in\_channels)`.

    Outputs:
        Tensor of shape :math:`(N, out\_channels)`.

    Examples:
        >>> net = DenseReparam(3, 4)
        >>> input = Tensor(np.random.randint(0, 255, [2, 3]), mindspore.float32)
        >>> net(input)
    """
    def __init__(
            self,
            in_channels,
            out_channels,
            activation=None,
            has_bias=True,
            weight_prior_fn=NormalPrior,
            weight_posterior_fn=lambda name, shape: NormalPosterior(name=name, shape=shape),
            bias_prior_fn=NormalPrior,
            bias_posterior_fn=lambda name, shape: NormalPosterior(name=name, shape=shape)):
        super(DenseReparam, self).__init__(
            in_channels,
            out_channels,
            activation=activation,
            has_bias=has_bias,
            weight_prior_fn=weight_prior_fn,
            weight_posterior_fn=weight_posterior_fn,
            bias_prior_fn=bias_prior_fn,
            bias_posterior_fn=bias_posterior_fn
        )

    def _apply_variational_weight(self, inputs):
        weight_posterior_tensor = self.weight_posterior("sample")
        outputs = self.matmul(inputs, weight_posterior_tensor)
        return outputs
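
For reference, a NumPy sketch (not from the commit) of what one DenseReparam forward pass and its KL term amount to, using the Gaussian posterior parameterization and the Normal(0, 0.1) prior defined in this commit. The real layer delegates the KL computation to the Normal distribution's kl_loss; the closed form below is the standard result for two Gaussians.

```python
import numpy as np

def softplus(x):
    return np.log1p(np.exp(x))

rng = np.random.default_rng(0)
in_c, out_c, batch = 3, 4, 2

# Trainable posterior parameters: a mean and an untransformed std per weight.
w_mean = rng.normal(0.0, 0.1, size=(out_c, in_c))
w_rho = rng.normal(-5.0, 0.1, size=(out_c, in_c))
w_std = 1e-6 + softplus(w_rho)            # matches NormalPosterior.std_trans

# Reparameterized sample: w = mean + std * eps, with eps ~ N(0, 1).
w = w_mean + w_std * rng.standard_normal((out_c, in_c))

x = rng.standard_normal((batch, in_c))
y = x @ w.T                               # MatMul with transpose_b=True

# KL(N(w_mean, w_std) || N(0, 0.1)) summed over all weights.
p_mean, p_std = 0.0, 0.1
kl = np.sum(np.log(p_std / w_std)
            + (w_std**2 + (w_mean - p_mean)**2) / (2 * p_std**2) - 0.5)
print(y.shape, kl)
```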

@@ -0,0 +1,96 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Initialize normal distributions"""
import numpy as np
import mindspore.common.dtype as mstype
from mindspore.common.tensor import Tensor
from mindspore.common.parameter import Parameter
from mindspore.ops import operations as P
from ...cell import Cell
from ..distribution.normal import Normal
__all__ = ['NormalPrior', 'NormalPosterior']
class NormalPrior(Cell):
    r"""
    Initialize a normal distribution with mean 0 and standard deviation 0.1.

    Args:
        dtype (class `mindspore.dtype`): The argument is used to define the data type of the output tensor.
            Default: mindspore.float32.
        mean (int, float): Mean of the normal distribution. Default: 0.
        std (int, float): Standard deviation of the normal distribution. Default: 0.1.

    Returns:
        Cell, a normal distribution.
    """

    def __init__(self, dtype=mstype.float32, mean=0, std=0.1):
        super(NormalPrior, self).__init__()
        self.normal = Normal(mean, std, dtype=dtype)

    def construct(self, *inputs):
        return self.normal(*inputs)
class NormalPosterior(Cell):
    r"""
    Build a normal distribution with trainable parameters.

    Args:
        name (str): Name prepended to the trainable parameters.
        shape (list): Shape of the mean and standard deviation.
        dtype (class `mindspore.dtype`): The argument is used to define the data type of the output tensor.
            Default: mindspore.float32.
        loc_mean (float, array_like of floats): Mean of the distribution used to initialize the trainable
            location parameters. Default: 0.
        loc_std (float, array_like of floats): Standard deviation of the distribution used to initialize the
            trainable location parameters. Default: 0.1.
        untransformed_scale_mean (float, array_like of floats): Mean of the distribution used to initialize the
            trainable untransformed scale parameters. Default: -5.
        untransformed_scale_std (float, array_like of floats): Standard deviation of the distribution used to
            initialize the trainable untransformed scale parameters. Default: 0.1.

    Returns:
        Cell, a normal distribution.
    """

    def __init__(self,
                 name,
                 shape,
                 dtype=mstype.float32,
                 loc_mean=0,
                 loc_std=0.1,
                 untransformed_scale_mean=-5,
                 untransformed_scale_std=0.1):
        super(NormalPosterior, self).__init__()
        if not isinstance(name, str):
            raise ValueError('The type of `name` should be `str`')
        self.mean = Parameter(
            Tensor(np.random.normal(loc_mean, loc_std, shape), dtype=dtype), name=name + '_mean')
        self.untransformed_std = Parameter(
            Tensor(np.random.normal(untransformed_scale_mean, untransformed_scale_std, shape), dtype=dtype),
            name=name + '_untransformed_std')
        self.normal = Normal()

    def std_trans(self, std_pre):
        """Map the untransformed std through a softplus (plus a small epsilon) so it stays strictly positive."""
        std = 1e-6 + P.Log()(P.Exp()(std_pre) + 1)
        return std

    def construct(self, *inputs):
        std = self.std_trans(self.untransformed_std)
        return self.normal(*inputs, mean=self.mean, sd=std)
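
A quick numeric check (not from the commit) of what std_trans does: the softplus keeps the standard deviation strictly positive, and the default untransformed_scale_mean of -5 makes the posterior start out nearly deterministic.

```python
import numpy as np

def std_trans(rho):
    # softplus + epsilon, always > 0, mirroring NormalPosterior.std_trans
    return 1e-6 + np.log1p(np.exp(rho))

# With untransformed_scale_mean = -5, the initial std is about 6.7e-3.
print(std_trans(np.array([-5.0, 0.0, 5.0])))   # ~[0.0067, 0.693, 5.007]
```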

@@ -21,6 +21,7 @@ from mindspore.common import dtype as mstype
from mindspore.ops import operations as P
from mindspore.ops import composite as C
import mindspore.nn as nn
import mindspore.nn.probability as msp
def cast_to_tensor(t, hint_dtype=mstype.float32):
"""
@@ -84,7 +85,7 @@ def check_scalar_from_param(params):
        Notes: String parameters are excluded.
    """
    for value in params.values():
        if isinstance(value, (nn.probability.bijector.Bijector, nn.probability.distribution.Distribution)):
        if isinstance(value, (msp.bijector.Bijector, msp.distribution.Distribution)):
            return params['distribution'].is_scalar_batch
        if isinstance(value, Parameter):
            return False
@@ -109,7 +110,7 @@ def calc_broadcast_shape_from_param(params):
    """
    broadcast_shape = []
    for value in params.values():
        if isinstance(value, (nn.probability.bijector.Bijector, nn.probability.distribution.Distribution)):
        if isinstance(value, (msp.bijector.Bijector, msp.distribution.Distribution)):
            return params['distribution'].broadcast_shape
        if isinstance(value, (str, type(params['dtype']))):
            continue

@@ -36,7 +36,7 @@ class _CodeTransformer(ast.NodeTransformer):
    def visit_FunctionDef(self, node):
        """visit function and add kl_loss computation."""
        self.generic_visit(node)
        if node.name == 'compute_kl_loss':
        if node.name == 'cal_kl_loss':
            for i in range(self.layer_count):
                func = ast.Assign(targets=[ast.Name(id='loss', ctx=ast.Store())],
                                  value=ast.BinOp(left=ast.Name(id='loss', ctx=ast.Load()), op=ast.Add(),
@@ -71,7 +71,7 @@ def gain_bnn_with_loss(layer_count, backbone, loss_fn, dnn_factor, bnn_factor):
        layer_count (int): The number of kl loss to be generated, namely the number of Bayesian layers.
        backbone (Cell): The target network to wrap.
        loss_fn (Cell): The loss function used to compute loss.
        dnn_factor ((int, float): The coefficient of backbone's loss, which is computed by loss function.
        dnn_factor (int, float): The coefficient of backbone's loss, which is computed by loss function.
        bnn_factor (int, float): The coefficient of kl loss, which is kl divergence of Bayesian layer.
    """
    bnn_loss_func = _generate_kl_loss_func(layer_count)
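
The hunk above tweaks the AST pass that generates the combined loss: _CodeTransformer walks the generated source and appends one `loss = loss + ...` statement per Bayesian layer inside the target function. A self-contained toy of that idea (hypothetical names, not the MindSpore implementation):

```python
import ast

class AddKLTerms(ast.NodeTransformer):
    """Append `loss = loss + kl_terms[i]` statements to a target function."""

    def __init__(self, layer_count):
        self.layer_count = layer_count

    def visit_FunctionDef(self, node):
        self.generic_visit(node)
        if node.name == 'cal_kl_loss':
            for i in range(self.layer_count):
                # ast.parse builds the same Assign/BinOp nodes the commit
                # constructs by hand with ast.Assign and ast.BinOp.
                node.body.append(ast.parse('loss = loss + kl_terms[%d]' % i).body[0])
        return node

src = 'def cal_kl_loss(kl_terms):\n    loss = 0\n'
tree = AddKLTerms(layer_count=2).visit(ast.parse(src))
ast.fix_missing_locations(tree)
print(ast.dump(tree))
# astunparse.unparse(tree) would turn the rewritten tree back into source,
# which is presumably why astunparse is added to the requirements below.
```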

@@ -14,3 +14,4 @@ opencv-python >= 4.1.2.30 # for ut test
sklearn >= 0.0 # for st test
pandas >= 1.0.2 # for ut test
bs4
astunparse

@@ -92,7 +92,8 @@ required_package = [
'easydict >= 1.9',
'sympy >= 1.4',
'cffi >= 1.13.2',
'decorator >= 4.4.0'
'decorator >= 4.4.0',
'astunparse >= 1.6.3'
]
package_data = {
