modify the method for calculating kl loss

pull/7859/head
bingyaweng 4 years ago
parent 78f795971b
commit 11aa3f6f5f

@ -15,29 +15,12 @@
"""Generate WithLossCell suitable for BNN."""
from .conv_variational import _ConvVariational
from .dense_variational import _DenseVariational
from ..transforms.bnn_loss.generate_kl_loss import gain_bnn_with_loss
from ...cell import Cell
__all__ = ['WithBNNLossCell']
class ClassWrap:
"""Decorator of WithBNNLossCell"""
def __init__(self, cls):
self._cls = cls
self.bnn_loss_file = None
self.__doc__ = cls.__doc__
self.__name__ = cls.__name__
self.__bases__ = cls.__bases__
def __call__(self, backbone, loss_fn, dnn_factor, bnn_factor):
obj = self._cls(backbone, loss_fn, dnn_factor, bnn_factor)
bnn_with_loss = obj()
self.bnn_loss_file = obj.bnn_loss_file
return bnn_with_loss
@ClassWrap
class WithBNNLossCell:
class WithBNNLossCell(Cell):
r"""
Generate a suitable WithLossCell for BNN to wrap the bayesian network with loss function.
@ -68,6 +51,7 @@ class WithBNNLossCell:
"""
def __init__(self, backbone, loss_fn, dnn_factor=1, bnn_factor=1):
super(WithBNNLossCell, self).__init__(auto_prefix=False)
if isinstance(dnn_factor, bool) or not isinstance(dnn_factor, (int, float)):
raise TypeError('The type of `dnn_factor` should be `int` or `float`')
if dnn_factor < 0:
@ -78,28 +62,36 @@ class WithBNNLossCell:
if bnn_factor < 0:
raise ValueError('The value of `bnn_factor` should >= 0')
self.backbone = backbone
self.loss_fn = loss_fn
self._backbone = backbone
self._loss_fn = loss_fn
self.dnn_factor = dnn_factor
self.bnn_factor = bnn_factor
self.bnn_loss_file = None
def _generate_loss_cell(self):
"""Generate WithBNNLossCell by ast."""
layer_count = self._kl_loss_count(self.backbone)
bnn_with_loss, self.bnn_loss_file = gain_bnn_with_loss(layer_count, self.backbone, self.loss_fn,
self.dnn_factor, self.bnn_factor)
return bnn_with_loss
def _kl_loss_count(self, net):
""" Calculate the number of Bayesian layers."""
count = 0
self.kl_loss = []
self._add_kl_loss(self._backbone)
def construct(self, x, label):
y_pred = self._backbone(x)
backbone_loss = self._loss_fn(y_pred, label)
kl_loss = 0
for i in range(len(self.kl_loss)):
kl_loss += self.kl_loss[i]()
loss = backbone_loss * self.dnn_factor + kl_loss * self.bnn_factor
return loss
def _add_kl_loss(self, net):
"""Collect kl loss of each Bayesian layer."""
for (_, layer) in net.name_cells().items():
if isinstance(layer, (_DenseVariational, _ConvVariational)):
count += 1
self.kl_loss.append(layer.compute_kl_loss)
else:
count += self._kl_loss_count(layer)
return count
self._add_kl_loss(layer)
@property
def backbone_network(self):
"""
Returns the backbone network.
def __call__(self):
return self._generate_loss_cell()
Returns:
Cell, the backbone network.
"""
return self._backbone

@ -1,19 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
bnn loss.
"""
from . import generate_kl_loss
from .generate_kl_loss import gain_bnn_with_loss

@ -1,89 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Gain bnn_with_loss by rewrite WithLossCell as WithBNNLossCell to suit for BNN model"""
import ast
import importlib
import os
import sys
import tempfile
import astunparse
import mindspore
class _CodeTransformer(ast.NodeTransformer):
"""
Add kl_loss computation by analyzing the python code structure with the help of the AST module.
Args:
layer_count (int): The number of kl loss to be generated, namely the number of Bayesian layers.
"""
def __init__(self, layer_count):
self.layer_count = layer_count
def visit_FunctionDef(self, node):
"""visit function and add kl_loss computation."""
self.generic_visit(node)
if node.name == 'cal_kl_loss':
for i in range(self.layer_count):
func = ast.Assign(targets=[ast.Name(id='loss', ctx=ast.Store())],
value=ast.BinOp(left=ast.Name(id='loss', ctx=ast.Load()), op=ast.Add(),
right=ast.Call(func=ast.Name(id='self.kl_loss' + '[' + str(i) + ']',
ctx=ast.Load()),
args=[], keywords=[])))
node.body.insert(-1, func)
return node
def _generate_kl_loss_func(layer_count):
"""Rewrite WithLossCell as WithBNNLossCell to suit for BNN model."""
path = os.path.dirname(mindspore.__file__) + '/nn/probability/transforms/bnn_loss/withLossCell.py'
with open(path, 'r') as fp:
srclines = fp.readlines()
src = ''.join(srclines)
if src.startswith((' ', '\t')):
src = 'if 1:\n' + src
expr_ast = ast.parse(src, mode='exec')
transformer = _CodeTransformer(layer_count)
modify = transformer.visit(expr_ast)
modify = ast.fix_missing_locations(modify)
func = astunparse.unparse(modify)
return func
def gain_bnn_with_loss(layer_count, backbone, loss_fn, dnn_factor, bnn_factor):
"""
Gain bnn_with_loss, which wraps bnn network with loss function and kl loss of each bayesian layer.
Args:
layer_count (int): The number of kl loss to be generated, namely the number of Bayesian layers.
backbone (Cell): The target network to wrap.
loss_fn (Cell): The loss function used to compute loss.
dnn_factor (int, float): The coefficient of backbone's loss, which is computed by loss function.
bnn_factor (int, float): The coefficient of kl loss, which is kl divergence of Bayesian layer.
"""
bnn_loss_func = _generate_kl_loss_func(layer_count)
path = os.path.dirname(mindspore.__file__)
bnn_loss_file = tempfile.NamedTemporaryFile(mode='w+t', suffix='.py', delete=True,
dir=path + '/nn/probability/transforms/bnn_loss')
bnn_loss_file.write(bnn_loss_func)
bnn_loss_file.seek(0)
sys.path.append(path + '/nn/probability/transforms/bnn_loss')
module_name = os.path.basename(bnn_loss_file.name)[0:-3]
bnn_loss_module = importlib.import_module(module_name, __package__)
bnn_with_loss = bnn_loss_module.WithBNNLossCell(backbone, loss_fn, dnn_factor, bnn_factor)
return bnn_with_loss, bnn_loss_file

@ -1,66 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Original WithBNNLossCell for ast to rewrite."""
import mindspore.nn as nn
from mindspore.nn.probability.bnn_layers.conv_variational import _ConvVariational
from mindspore.nn.probability.bnn_layers.dense_variational import _DenseVariational
class WithBNNLossCell(nn.Cell):
"""
Cell with loss function.
Wraps the network with loss function. This Cell accepts data, label, backbone_factor and kl_factor as inputs and
the computed loss will be returned.
"""
def __init__(self, backbone, loss_fn, backbone_factor=1, kl_factor=1):
super(WithBNNLossCell, self).__init__(auto_prefix=False)
self._backbone = backbone
self._loss_fn = loss_fn
self.backbone_factor = backbone_factor
self.kl_factor = kl_factor
self.kl_loss = []
self._add_kl_loss(self._backbone)
def construct(self, x, label):
y_pred = self._backbone(x)
backbone_loss = self._loss_fn(y_pred, label)
kl_loss = self.cal_kl_loss()
loss = backbone_loss*self.backbone_factor + kl_loss*self.kl_factor
return loss
def cal_kl_loss(self):
"""Calculate kl loss."""
loss = 0.0
return loss
def _add_kl_loss(self, net):
"""Collect kl loss of each Bayesian layer."""
for (_, layer) in net.name_cells().items():
if isinstance(layer, (_DenseVariational, _ConvVariational)):
self.kl_loss.append(layer.compute_kl_loss)
else:
self._add_kl_loss(layer)
@property
def backbone_network(self):
"""
Returns the backbone network.
Returns:
Cell, the backbone network.
"""
return self._backbone

@ -17,8 +17,8 @@ import mindspore.nn as nn
from ...wrap.cell_wrapper import TrainOneStepCell
from ....nn import optim
from ....nn import layer
from .bnn_loss.generate_kl_loss import gain_bnn_with_loss
from ...probability import bnn_layers
from ..bnn_layers.bnn_cell_wrapper import WithBNNLossCell
from ..bnn_layers.conv_variational import ConvReparam
from ..bnn_layers.dense_variational import DenseReparam
@ -77,7 +77,6 @@ class TransformToBNN:
self.loss_fn = getattr(net_with_loss, "_loss_fn")
self.dnn_factor = dnn_factor
self.bnn_factor = bnn_factor
self.bnn_loss_file = None
def transform_to_bnn_model(self,
get_dense_args=lambda dp: {"in_channels": dp.in_channels, "has_bias": dp.has_bias,
@ -120,15 +119,13 @@ class TransformToBNN:
if not add_conv_args:
add_conv_args = {}
layer_count = self._replace_all_bnn_layers(self.backbone, get_dense_args, get_conv_args, add_dense_args,
add_conv_args)
self._replace_all_bnn_layers(self.backbone, get_dense_args, get_conv_args, add_dense_args, add_conv_args)
# rename layers of BNN model to prevent duplication of names
for value, param in self.backbone.parameters_and_names():
param.name = value
bnn_with_loss, self.bnn_loss_file = gain_bnn_with_loss(layer_count, self.backbone, self.loss_fn,
self.dnn_factor, self.bnn_factor)
bnn_with_loss = WithBNNLossCell(self.backbone, self.loss_fn, self.dnn_factor, self.bnn_factor)
bnn_optimizer = self._create_optimizer_with_bnn_params()
train_bnn_network = TrainOneStepCell(bnn_with_loss, bnn_optimizer)
return train_bnn_network
@ -179,13 +176,11 @@ class TransformToBNN:
if not add_args:
add_args = {}
layer_count = self._replace_specified_dnn_layers(self.backbone, dnn_layer_type, bnn_layer_type, get_args,
add_args)
self._replace_specified_dnn_layers(self.backbone, dnn_layer_type, bnn_layer_type, get_args, add_args)
for value, param in self.backbone.parameters_and_names():
param.name = value
bnn_with_loss, self.bnn_loss_file = gain_bnn_with_loss(layer_count, self.backbone, self.loss_fn,
self.dnn_factor, self.bnn_factor)
bnn_with_loss = WithBNNLossCell(self.backbone, self.loss_fn, self.dnn_factor, self.bnn_factor)
bnn_optimizer = self._create_optimizer_with_bnn_params()
train_bnn_network = TrainOneStepCell(bnn_with_loss, bnn_optimizer)
@ -228,32 +223,25 @@ class TransformToBNN:
def _replace_all_bnn_layers(self, backbone, get_dense_args, get_conv_args, add_dense_args, add_conv_args):
"""Replace both dense layer and conv2d layer in DNN model to bayesian layers."""
count = 0
for name, cell in backbone.name_cells().items():
if isinstance(cell, nn.Dense):
dense_args = get_dense_args(cell)
new_layer = DenseReparam(**dense_args, **add_dense_args)
setattr(backbone, name, new_layer)
count += 1
elif isinstance(cell, nn.Conv2d):
conv_args = get_conv_args(cell)
new_layer = ConvReparam(**conv_args, **add_conv_args)
setattr(backbone, name, new_layer)
count += 1
else:
count += self._replace_all_bnn_layers(cell, get_dense_args, get_conv_args, add_dense_args,
add_conv_args)
return count
self._replace_all_bnn_layers(cell, get_dense_args, get_conv_args, add_dense_args,
add_conv_args)
def _replace_specified_dnn_layers(self, backbone, dnn_layer, bnn_layer, get_args, add_args):
"""Convert a specific type of layers in DNN model to corresponding bayesian layers."""
count = 0
for name, cell in backbone.name_cells().items():
if isinstance(cell, dnn_layer):
args = get_args(cell)
new_layer = bnn_layer(**args, **add_args)
setattr(backbone, name, new_layer)
count += 1
else:
count += self._replace_specified_dnn_layers(cell, dnn_layer, bnn_layer, get_args, add_args)
return count
self._replace_specified_dnn_layers(cell, dnn_layer, bnn_layer, get_args, add_args)

Loading…
Cancel
Save