From 1b1ad52e7c7cef532edb79f53ffcc885c0cc996b Mon Sep 17 00:00:00 2001
From: bingyaweng
Date: Tue, 11 Aug 2020 15:45:09 +0800
Subject: [PATCH] add transforms to nn.probability

---
 .../nn/probability/transforms/__init__.py    |  24 ++
 .../transforms/bnn_loss/__init__.py          |  19 ++
 .../transforms/bnn_loss/generate_kl_loss.py  |  92 +++++++
 .../transforms/bnn_loss/withLossCell.py      |  56 ++++
 .../probability/transforms/transform_bnn.py  | 248 ++++++++++++++++++
 5 files changed, 439 insertions(+)
 create mode 100644 mindspore/nn/probability/transforms/__init__.py
 create mode 100644 mindspore/nn/probability/transforms/bnn_loss/__init__.py
 create mode 100644 mindspore/nn/probability/transforms/bnn_loss/generate_kl_loss.py
 create mode 100644 mindspore/nn/probability/transforms/bnn_loss/withLossCell.py
 create mode 100644 mindspore/nn/probability/transforms/transform_bnn.py

diff --git a/mindspore/nn/probability/transforms/__init__.py b/mindspore/nn/probability/transforms/__init__.py
new file mode 100644
index 0000000000..a42f233e92
--- /dev/null
+++ b/mindspore/nn/probability/transforms/__init__.py
@@ -0,0 +1,24 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+Transforms.
+
+The high-level components used to transform a model between DNN and BNN.
+"""
+from . import transform_bnn
+from .transform_bnn import TransformToBNN
+
+__all__ = []
+__all__.extend(transform_bnn.__all__)
diff --git a/mindspore/nn/probability/transforms/bnn_loss/__init__.py b/mindspore/nn/probability/transforms/bnn_loss/__init__.py
new file mode 100644
index 0000000000..c10f1a4578
--- /dev/null
+++ b/mindspore/nn/probability/transforms/bnn_loss/__init__.py
@@ -0,0 +1,19 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+BNN loss.
+"""
+from . import generate_kl_loss
+from .generate_kl_loss import gain_bnn_with_loss
diff --git a/mindspore/nn/probability/transforms/bnn_loss/generate_kl_loss.py b/mindspore/nn/probability/transforms/bnn_loss/generate_kl_loss.py
new file mode 100644
index 0000000000..bd516d7129
--- /dev/null
+++ b/mindspore/nn/probability/transforms/bnn_loss/generate_kl_loss.py
@@ -0,0 +1,92 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Generate bnn_with_loss by rewriting the WithBNNLossCell template with the AST module to suit the BNN model."""
+import ast
+import importlib
+import os
+import sys
+import tempfile
+import astunparse
+import mindspore
+
+
+class _CodeTransformer(ast.NodeTransformer):
+    """
+    Add kl loss computation by analyzing the python code structure with the help of the AST module.
+
+    Args:
+        layer_count (int): The number of kl loss terms to be generated, namely the number of Bayesian layers.
+    """
+
+    def __init__(self, layer_count):
+        self.layer_count = layer_count
+
+    def visit_FunctionDef(self, node):
+        """Visit a function definition and add the kl loss computation to `compute_kl_loss`."""
+        self.generic_visit(node)
+        if node.name == 'compute_kl_loss':
+            for i in range(self.layer_count):
+                # 'self.kl_loss[i]' is not a valid ast.Name id, but astunparse emits the id
+                # verbatim, so the generated source reads `loss = loss + self.kl_loss[i]()`.
+                func = ast.Assign(targets=[ast.Name(id='loss', ctx=ast.Store())],
+                                  value=ast.BinOp(left=ast.Name(id='loss', ctx=ast.Load()), op=ast.Add(),
+                                                  right=ast.Call(func=ast.Name(id='self.kl_loss' + '[' + str(i) + ']',
+                                                                               ctx=ast.Load()),
+                                                                 args=[], keywords=[])))
+                node.body.insert(-1, func)  # insert before the final `return loss`
+        return node
+
+
+def _generate_kl_loss_func(layer_count):
+    """Rewrite the WithBNNLossCell template so that compute_kl_loss sums the kl loss of every Bayesian layer."""
+    path = os.path.dirname(mindspore.__file__) + '/nn/probability/transforms/bnn_loss/withLossCell.py'
+    with open(path, 'r') as fp:
+        srclines = fp.readlines()
+    src = ''.join(srclines)
+    if src.startswith((' ', '\t')):  # guard the parser against source that begins with an indent
+        src = 'if 1:\n' + src
+    expr_ast = ast.parse(src, mode='exec')
+    transformer = _CodeTransformer(layer_count)
+    modify = transformer.visit(expr_ast)
+    modify = ast.fix_missing_locations(modify)
+    func = astunparse.unparse(modify)
+    return func
+
+
+def gain_bnn_with_loss(layer_count, backbone, loss_fn, dnn_factor, bnn_factor):
+    """
+    Gain bnn_with_loss, which wraps the bnn network with the loss function and the kl loss of each Bayesian layer.
+
+    Args:
+        layer_count (int): The number of kl loss terms to be generated, namely the number of Bayesian layers.
+        backbone (Cell): The target network to wrap.
+        loss_fn (Cell): The loss function used to compute loss.
+        dnn_factor (int, float): The coefficient of backbone's loss, which is computed by the loss function.
+        bnn_factor (int, float): The coefficient of the kl loss, which is the kl divergence of the Bayesian layers.
+    """
+    bnn_loss_func = _generate_kl_loss_func(layer_count)
+    path = os.path.dirname(mindspore.__file__)
+    bnn_loss_file = tempfile.NamedTemporaryFile(mode='w+t', suffix='.py', delete=True,
+                                                dir=path + '/nn/probability/transforms/bnn_loss')
+    bnn_loss_file.write(bnn_loss_func)
+    bnn_loss_file.seek(0)
+
+    sys.path.append(path + '/nn/probability/transforms/bnn_loss')
+
+    module_name = os.path.basename(bnn_loss_file.name)[0:-3]  # strip the '.py' suffix
+    bnn_loss_module = importlib.import_module(module_name, __package__)
+    bnn_with_loss = bnn_loss_module.WithBNNLossCell(backbone, loss_fn, dnn_factor, bnn_factor)
+    # the temporary file is deleted once it is garbage collected, so return it to keep it alive
+    return bnn_with_loss, bnn_loss_file
diff --git a/mindspore/nn/probability/transforms/bnn_loss/withLossCell.py b/mindspore/nn/probability/transforms/bnn_loss/withLossCell.py
new file mode 100644
index 0000000000..baf2a61f4d
--- /dev/null
+++ b/mindspore/nn/probability/transforms/bnn_loss/withLossCell.py
@@ -0,0 +1,56 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Original WithBNNLossCell for the AST pass to rewrite."""
+
+import mindspore.nn as nn
+from mindspore.nn.probability.bnn_layers.conv_variational import _ConvVariational
+from mindspore.nn.probability.bnn_layers.dense_variational import _DenseVariational
+
+
+class WithBNNLossCell(nn.Cell):
+    """
+    Cell with loss function.
+
+    Wraps the network with a loss function. This Cell accepts data and label as inputs and
+    returns the computed loss.
+    """
+    def __init__(self, backbone, loss_fn, backbone_factor=1, kl_factor=1):
+        super(WithBNNLossCell, self).__init__(auto_prefix=False)
+        self._backbone = backbone
+        self._loss_fn = loss_fn
+        self.backbone_factor = backbone_factor
+        self.kl_factor = kl_factor
+        self.kl_loss = []
+        self._add_kl_loss(self._backbone)
+
+    def construct(self, x, label):
+        y_pred = self._backbone(x)
+        backbone_loss = self._loss_fn(y_pred, label)
+        kl_loss = self.compute_kl_loss()
+        loss = backbone_loss * self.backbone_factor + kl_loss * self.kl_factor
+        return loss
+
+    def compute_kl_loss(self):
+        """Compute kl loss; the AST pass in generate_kl_loss.py inserts the per-layer accumulation here."""
+        loss = 0.0
+        return loss
+
+    def _add_kl_loss(self, net):
+        """Collect the kl loss of each Bayesian layer."""
+        for (_, layer) in net.name_cells().items():
+            if isinstance(layer, (_DenseVariational, _ConvVariational)):
+                self.kl_loss.append(layer.compute_kl_loss)
+            else:
+                self._add_kl_loss(layer)
diff --git a/mindspore/nn/probability/transforms/transform_bnn.py b/mindspore/nn/probability/transforms/transform_bnn.py
new file mode 100644
index 0000000000..debbbc7179
--- /dev/null
+++ b/mindspore/nn/probability/transforms/transform_bnn.py
@@ -0,0 +1,248 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Transform DNN to BNN."""
+import mindspore.nn as nn
+from ...wrap.cell_wrapper import TrainOneStepCell
+from ....nn import optim
+from ....nn import layer
+from .bnn_loss.generate_kl_loss import gain_bnn_with_loss
+from ...probability import bnn_layers
+from ..bnn_layers.conv_variational import ConvReparam
+from ..bnn_layers.dense_variational import DenseReparam
+
+__all__ = ['TransformToBNN']
+
+
+class TransformToBNN:
+    r"""
+    Transform a Deep Neural Network (DNN) model into a Bayesian Neural Network (BNN) model.
+
+    Args:
+        trainable_dnn (Cell): A trainable DNN model (backbone) wrapped by TrainOneStepCell.
+        dnn_factor (int, float): The coefficient of backbone's loss, which is computed by the loss function.
+        bnn_factor (int, float): The coefficient of the kl loss, which is the kl divergence of the Bayesian layers.
+
+    Examples:
+        >>> class Net(nn.Cell):
+        >>>     def __init__(self):
+        >>>         super(Net, self).__init__()
+        >>>         self.conv = nn.Conv2d(3, 64, 3, has_bias=False, weight_init='normal')
+        >>>         self.bn = nn.BatchNorm2d(64)
+        >>>         self.relu = nn.ReLU()
+        >>>         self.flatten = nn.Flatten()
+        >>>         self.fc = nn.Dense(64*224*224, 12) # padding=0
+        >>>
+        >>>     def construct(self, x):
+        >>>         x = self.conv(x)
+        >>>         x = self.bn(x)
+        >>>         x = self.relu(x)
+        >>>         x = self.flatten(x)
+        >>>         out = self.fc(x)
+        >>>         return out
+        >>>
+        >>> net = Net()
+        >>> criterion = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
+        >>> optim = Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
+        >>> net_with_loss = WithLossCell(net, criterion)
+        >>> train_network = TrainOneStepCell(net_with_loss, optim)
+        >>> bnn_transformer = TransformToBNN(train_network, 60000, 0.1)
+    """
+
+    def __init__(self, trainable_dnn, dnn_factor=1, bnn_factor=1):
+        net_with_loss = trainable_dnn.network
+        self.optimizer = trainable_dnn.optimizer
+        self.backbone = net_with_loss.backbone_network
+        self.loss_fn = getattr(net_with_loss, "_loss_fn")
+        self.dnn_factor = dnn_factor
+        self.bnn_factor = bnn_factor
+        self.bnn_loss_file = None
+
+    def transform_to_bnn_model(self,
+                               get_dense_args=lambda dp: {"in_channels": dp.in_channels, "has_bias": dp.has_bias,
+                                                          "out_channels": dp.out_channels, "activation": dp.activation},
+                               get_conv_args=lambda dp: {"in_channels": dp.in_channels, "out_channels": dp.out_channels,
+                                                         "pad_mode": dp.pad_mode, "kernel_size": dp.kernel_size,
+                                                         "stride": dp.stride, "has_bias": dp.has_bias,
+                                                         "padding": dp.padding, "dilation": dp.dilation,
+                                                         "group": dp.group},
+                               add_dense_args=None,
+                               add_conv_args=None):
+        r"""
+        Transform the whole DNN model to a BNN model, and wrap the BNN model by TrainOneStepCell.
+
+        Args:
+            get_dense_args (function): The arguments obtained from the DNN fully connected layer. Default:
+                lambda dp: {"in_channels": dp.in_channels, "out_channels": dp.out_channels, "has_bias": dp.has_bias,
+                "activation": dp.activation}.
+            get_conv_args (function): The arguments obtained from the DNN convolutional layer. Default: lambda dp:
+                {"in_channels": dp.in_channels, "out_channels": dp.out_channels, "pad_mode": dp.pad_mode,
+                "kernel_size": dp.kernel_size, "stride": dp.stride, "has_bias": dp.has_bias,
+                "padding": dp.padding, "dilation": dp.dilation, "group": dp.group}.
+            add_dense_args (dict): The new arguments added to the BNN fully connected layer. Default: None.
+            add_conv_args (dict): The new arguments added to the BNN convolutional layer. Default: None.
+
+        Returns:
+            Cell, a trainable BNN model wrapped by TrainOneStepCell.
+
+        Examples:
+            >>> net = Net()
+            >>> criterion = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
+            >>> optim = Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
+            >>> net_with_loss = WithLossCell(net, criterion)
+            >>> train_network = TrainOneStepCell(net_with_loss, optim)
+            >>> bnn_transformer = TransformToBNN(train_network, 60000, 0.1)
+            >>> train_bnn_network = bnn_transformer.transform_to_bnn_model()
+        """
+        if not add_dense_args:
+            add_dense_args = {}
+        if not add_conv_args:
+            add_conv_args = {}
+
+        layer_count = self._replace_all_bnn_layers(self.backbone, get_dense_args, get_conv_args, add_dense_args,
+                                                   add_conv_args)
+
+        # rename the parameters of the BNN model to prevent duplicated names
+        for value, param in self.backbone.parameters_and_names():
+            param.name = value
+
+        bnn_with_loss, self.bnn_loss_file = gain_bnn_with_loss(layer_count, self.backbone, self.loss_fn,
+                                                               self.dnn_factor, self.bnn_factor)
+        bnn_optimizer = self._create_optimizer_with_bnn_params()
+        train_bnn_network = TrainOneStepCell(bnn_with_loss, bnn_optimizer)
+        return train_bnn_network
+
+    def transform_to_bnn_layer(self, dnn_layer_type, bnn_layer_type, get_args=None, add_args=None):
+        r"""
+        Transform a specific type of layer in the DNN model to the corresponding BNN layer.
+
+        Args:
+            dnn_layer_type (Cell): The type of DNN layer to be transformed to a BNN layer. The optional values are
+                nn.Dense and nn.Conv2d.
+            bnn_layer_type (Cell): The type of BNN layer to be transformed to. The optional values are
+                DenseReparam and ConvReparam.
+            get_args (function): The arguments obtained from the DNN layer. Default: None.
+            add_args (dict): The new arguments added to the BNN layer. Default: None.
+
+        Returns:
+            Cell, a trainable model wrapped by TrainOneStepCell, in which the specified type of layer is transformed
+                to the corresponding Bayesian layer.
+
+        Examples:
+            >>> net = Net()
+            >>> criterion = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
+            >>> optim = Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
+            >>> net_with_loss = WithLossCell(net, criterion)
+            >>> train_network = TrainOneStepCell(net_with_loss, optim)
+            >>> bnn_transformer = TransformToBNN(train_network, 60000, 0.1)
+            >>> train_bnn_network = bnn_transformer.transform_to_bnn_layer(Dense, DenseReparam)
+        """
+        if dnn_layer_type.__name__ not in ["Dense", "Conv2d"]:
+            raise ValueError("'dnn_layer_type' should be nn.Dense or nn.Conv2d, "
+                             'but got {}.'.format(str(dnn_layer_type)))
+
+        if bnn_layer_type.__name__ not in ["DenseReparam", "ConvReparam"]:
+            raise ValueError("'bnn_layer_type' should be DenseReparam or ConvReparam, "
+                             'but got {}.'.format(str(bnn_layer_type)))
+
+        dnn_layer_type = getattr(layer, dnn_layer_type.__name__)
+        bnn_layer_type = getattr(bnn_layers, bnn_layer_type.__name__)
+
+        if not get_args:
+            if dnn_layer_type.__name__ == "Dense":
+                get_args = self._get_dense_args
+            else:
+                get_args = self._get_conv_args
+
+        if not add_args:
+            add_args = {}
+
+        layer_count = self._replace_specified_dnn_layers(self.backbone, dnn_layer_type, bnn_layer_type, get_args,
+                                                         add_args)
+        for value, param in self.backbone.parameters_and_names():
+            param.name = value
+
+        bnn_with_loss, self.bnn_loss_file = gain_bnn_with_loss(layer_count, self.backbone, self.loss_fn,
+                                                               self.dnn_factor, self.bnn_factor)
+        bnn_optimizer = self._create_optimizer_with_bnn_params()
+
+        train_bnn_network = TrainOneStepCell(bnn_with_loss, bnn_optimizer)
+        return train_bnn_network
+
+    def _get_dense_args(self, dense_layer):
+        """Get arguments from the dense layer."""
+        dense_args = {"in_channels": dense_layer.in_channels, "has_bias": dense_layer.has_bias,
+                      "out_channels": dense_layer.out_channels, "activation": dense_layer.activation}
+        return dense_args
+
+    def _get_conv_args(self, conv_layer):
+        """Get arguments from the conv2d layer."""
+        conv_args = {"in_channels": conv_layer.in_channels, "out_channels": conv_layer.out_channels,
+                     "pad_mode": conv_layer.pad_mode, "kernel_size": conv_layer.kernel_size,
+                     "stride": conv_layer.stride, "has_bias": conv_layer.has_bias,
+                     "padding": conv_layer.padding, "dilation": conv_layer.dilation,
+                     "group": conv_layer.group}
+        return conv_args
+
+    def _create_optimizer_with_bnn_params(self):
+        """Create a new optimizer that contains the BNN trainable parameters."""
+        name = self.optimizer.__class__.__name__
+        modules = optim.__all__
+
+        if name not in modules:
+            raise TypeError('The optimizer should be one of {}, but got {}'.format(str(modules), name))
+
+        optimizer = getattr(optim, name)
+
+        args = {'params': self.backbone.trainable_params()}
+        params = optimizer.__init__.__code__.co_varnames
+        _params = self.optimizer.__dict__['_params']
+        for param in params:
+            if param in _params:
+                args[param] = self.optimizer.__getattr__(param).data.asnumpy().tolist()  # reuse old hyper-parameters
+
+        new_optimizer = optimizer(**args)
+        return new_optimizer
+
+    def _replace_all_bnn_layers(self, backbone, get_dense_args, get_conv_args, add_dense_args, add_conv_args):
+        """Replace all dense and conv2d layers in the DNN model with Bayesian layers."""
+        count = 0
+        for name, cell in backbone.name_cells().items():
+            if isinstance(cell, nn.Dense):
+                dense_args = get_dense_args(cell)
+                new_layer = DenseReparam(**dense_args, **add_dense_args)
+                setattr(backbone, name, new_layer)
+                count += 1
+            elif isinstance(cell, nn.Conv2d):
+                conv_args = get_conv_args(cell)
+                new_layer = ConvReparam(**conv_args, **add_conv_args)
+                setattr(backbone, name, new_layer)
+                count += 1
+            else:
+                count += self._replace_all_bnn_layers(cell, get_dense_args, get_conv_args, add_dense_args,
+                                                      add_conv_args)
+        return count
+
+    def _replace_specified_dnn_layers(self, backbone, dnn_layer, bnn_layer, get_args, add_args):
+        """Convert a specific type of layer in the DNN model to the corresponding Bayesian layer."""
+        count = 0
+        for name, cell in backbone.name_cells().items():
+            if isinstance(cell, dnn_layer):
+                args = get_args(cell)
+                new_layer = bnn_layer(**args, **add_args)
+                setattr(backbone, name, new_layer)
+                count += 1
+            else:
+                count += self._replace_specified_dnn_layers(cell, dnn_layer, bnn_layer, get_args, add_args)
+        return count
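
Usage sketch: a minimal end-to-end example of the API this patch adds. This is an illustrative sketch, not part of the patch; it assumes the 2020-era MindSpore API (nn.WithLossCell, nn.TrainOneStepCell, nn.Momentum, SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)) and reuses the hypothetical Net from the TransformToBNN docstring above.

    import mindspore.nn as nn
    from mindspore.nn.probability.transforms import TransformToBNN
    from mindspore.nn.probability.bnn_layers import DenseReparam

    net = Net()  # the example backbone from the TransformToBNN docstring
    criterion = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
    optimizer = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
    net_with_loss = nn.WithLossCell(net, criterion)
    train_network = nn.TrainOneStepCell(net_with_loss, optimizer)

    # dnn_factor scales the backbone loss; bnn_factor scales the summed kl loss
    bnn_transformer = TransformToBNN(train_network, dnn_factor=60000, bnn_factor=0.1)

    # transform every Dense/Conv2d layer ...
    train_bnn_network = bnn_transformer.transform_to_bnn_model()
    # ... or, alternatively, only the Dense layers:
    # train_bnn_network = bnn_transformer.transform_to_bnn_layer(nn.Dense, DenseReparam)

The returned train_bnn_network is a TrainOneStepCell, so each call on a (data, label) batch runs one optimization step over the combined loss backbone_loss * dnn_factor + kl_loss * bnn_factor.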