diff --git a/mindspore/nn/probability/toolbox/__init__.py b/mindspore/nn/probability/toolbox/__init__.py
new file mode 100644
index 0000000000..8391cd9185
--- /dev/null
+++ b/mindspore/nn/probability/toolbox/__init__.py
@@ -0,0 +1,21 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+Uncertainty toolbox.
+"""
+
+from .uncertainty_evaluation import UncertaintyEvaluation
+
+__all__ = ['UncertaintyEvaluation']
diff --git a/mindspore/nn/probability/toolbox/uncertainty_evaluation.py b/mindspore/nn/probability/toolbox/uncertainty_evaluation.py
new file mode 100644
index 0000000000..173f641a05
--- /dev/null
+++ b/mindspore/nn/probability/toolbox/uncertainty_evaluation.py
@@ -0,0 +1,298 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Toolbox for Uncertainty Evaluation."""
+import numpy as np
+
+from mindspore._checkparam import check_int_positive
+from mindspore.ops import composite as C
+from mindspore.ops import operations as P
+from mindspore.train import Model
+from mindspore.train.callback import LossMonitor
+from mindspore.train.serialization import load_checkpoint, load_param_into_net
+from ...cell import Cell
+from ...layer.basic import Dense, Flatten, Dropout
+from ...layer.conv import Conv2d
+from ...layer.container import SequentialCell
+from ...loss import SoftmaxCrossEntropyWithLogits, MSELoss
+from ...metrics import Accuracy, MSE
+from ...optim import Adam
+
+
+class UncertaintyEvaluation:
+    r"""
+    Toolbox for Uncertainty Evaluation.
+
+    Args:
+        model (Cell): The model for uncertainty evaluation.
+        train_dataset (Dataset): A dataset iterator.
+        task_type (str): Option for the task type of the model:
+            - regression: A regression model.
+            - classification: A classification model.
+        num_classes (int): The number of labels of classification.
+            If the task type is classification, it must be set; otherwise, it does not need to be set.
+            Default: None.
+        epochs (int): Total number of iterations on the data. Default: None.
+        uncertainty_model_path (str): The save or read path of the uncertainty model. Default: None.
+
+    Examples:
+        >>> network = LeNet()
+        >>> param_dict = load_checkpoint('checkpoint_lenet.ckpt')
+        >>> load_param_into_net(network, param_dict)
+        >>> ds_train = create_dataset('workspace/mnist/train')
+        >>> evaluation = UncertaintyEvaluation(model=network,
+        >>>                                    train_dataset=ds_train,
+        >>>                                    task_type='classification',
+        >>>                                    num_classes=10,
+        >>>                                    epochs=5,
+        >>>                                    uncertainty_model_path=None)
+        >>> epistemic_uncertainty = evaluation.eval_epistemic(eval_data)
+        >>> aleatoric_uncertainty = evaluation.eval_aleatoric(eval_data)
+    """
+
+    def __init__(self, model, train_dataset, task_type, num_classes=None, epochs=None,
+                 uncertainty_model_path=None):
+        self.model = model
+        self.train_dataset = train_dataset
+        self.task_type = task_type
+        self.epochs = epochs
+        self.uncer_model_path = uncertainty_model_path
+        self.epi_uncer_model = None
+        self.ale_uncer_model = None
+        self.concat = P.Concat(axis=0)
+        self.sum = P.ReduceSum()
+        self.pow = P.Pow()
+        if self.task_type not in ('regression', 'classification'):
+            raise ValueError('The task should be regression or classification.')
+        if self.task_type == 'classification':
+            if num_classes is None:
+                raise ValueError("num_classes must be set for a classification task.")
+            self.num_classes = check_int_positive(num_classes)
+        else:
+            self.num_classes = num_classes
+
+    def _uncertainty_normalize(self, data):
+        area = np.max(data) - np.min(data)
+        return (data - np.min(data)) / area
+
+    def _get_epistemic_uncertainty_model(self):
+        """
+        Get the model which can obtain the epistemic uncertainty.
+        """
+        if self.epi_uncer_model is None and self.uncer_model_path is None:
+            self.epi_uncer_model = EpistemicUncertaintyModel(self.model)
+            # If no Dropout layer existed in the original model, the wrapped model is retrained.
+            if self.epi_uncer_model.drop_count == 0:
+                if self.task_type == 'classification':
+                    net_loss = SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
+                    net_opt = Adam(self.epi_uncer_model.trainable_params())
+                    model = Model(self.epi_uncer_model, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
+                else:
+                    net_loss = MSELoss()
+                    net_opt = Adam(self.epi_uncer_model.trainable_params())
+                    model = Model(self.epi_uncer_model, net_loss, net_opt, metrics={"MSE": MSE()})
+                model.train(self.epochs, self.train_dataset, callbacks=[LossMonitor()])
+        elif self.epi_uncer_model is None:
+            self.epi_uncer_model = EpistemicUncertaintyModel(self.model)
+            uncer_param_dict = load_checkpoint(self.uncer_model_path)
+            load_param_into_net(self.epi_uncer_model, uncer_param_dict)
+
+    def _eval_epistemic_uncertainty(self, eval_data, mc=10):
+        """
+        Evaluate the epistemic uncertainty of classification and regression models using MC dropout.
+        """
+        self._get_epistemic_uncertainty_model()
+        # Keep dropout active at evaluation time so repeated forward passes differ.
+        self.epi_uncer_model.set_train(True)
+        outputs = [None] * mc
+        for i in range(mc):
+            pred = self.epi_uncer_model(eval_data)
+            outputs[i] = pred.asnumpy()
+        if self.task_type == 'classification':
+            outputs = np.stack(outputs, axis=2)
+            epi_uncertainty = outputs.var(axis=2)
+        else:
+            outputs = np.stack(outputs, axis=1)
+            epi_uncertainty = outputs.var(axis=1)
+        epi_uncertainty = self._uncertainty_normalize(np.array(epi_uncertainty))
+        return epi_uncertainty
+
+    def _get_aleatoric_uncertainty_model(self):
+        """
+        Get the model which can obtain the aleatoric uncertainty.
+        """
+        if self.ale_uncer_model is None and self.uncer_model_path is None:
+            self.ale_uncer_model = AleatoricUncertaintyModel(self.model, self.num_classes, self.task_type)
+            net_loss = AleatoricLoss(self.task_type)
+            net_opt = Adam(self.ale_uncer_model.trainable_params())
+            if self.task_type == 'classification':
+                model = Model(self.ale_uncer_model, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
+            else:
+                model = Model(self.ale_uncer_model, net_loss, net_opt, metrics={"MSE": MSE()})
+            model.train(self.epochs, self.train_dataset, callbacks=[LossMonitor()])
+        elif self.ale_uncer_model is None:
+            self.ale_uncer_model = AleatoricUncertaintyModel(self.model, self.num_classes, self.task_type)
+            uncer_param_dict = load_checkpoint(self.uncer_model_path)
+            load_param_into_net(self.ale_uncer_model, uncer_param_dict)
+
+    def _eval_aleatoric_uncertainty(self, eval_data):
+        """
+        Evaluate the aleatoric uncertainty of classification and regression models.
+        """
+        self._get_aleatoric_uncertainty_model()
+        # The aleatoric model returns a (prediction, variance) pair.
+        _, var = self.ale_uncer_model(eval_data)
+        ale_uncertainty = self.sum(self.pow(var, 2), 1)
+        ale_uncertainty = self._uncertainty_normalize(ale_uncertainty.asnumpy())
+        return ale_uncertainty
+
+    def eval_epistemic(self, eval_data):
+        """
+        Evaluate the epistemic uncertainty of inference results, which is also called model uncertainty.
+
+        Args:
+            eval_data (Tensor): The data samples to be evaluated, the shape should be (N,C,H,W).
+
+        Returns:
+            numpy.ndarray, the epistemic uncertainty of the inference results of the data samples.
+        """
+        uncertainty = self._eval_epistemic_uncertainty(eval_data)
+        return uncertainty
+
+    def eval_aleatoric(self, eval_data):
+        """
+        Evaluate the aleatoric uncertainty of inference results, which is also called data uncertainty.
+
+        Args:
+            eval_data (Tensor): The data samples to be evaluated, the shape should be (N,C,H,W).
+
+        Returns:
+            numpy.ndarray, the aleatoric uncertainty of the inference results of the data samples.
+        """
+        uncertainty = self._eval_aleatoric_uncertainty(eval_data)
+        return uncertainty
+
+
+class EpistemicUncertaintyModel(Cell):
+    """
+    Keep dropout active during both training and evaluation, which approximates Bayesian inference
+    and allows the epistemic uncertainty (also called model uncertainty) to be estimated.
+
+    If the original model already has a Dropout layer, it is simply kept active at evaluation time;
+    if not, a Dropout layer is added after a Dense or Conv2d layer and used during both training
+    and evaluation.
+
+    See more details in `Dropout as a Bayesian Approximation: Representing Model Uncertainty in Deep Learning
+    `.
+    """
+
+    def __init__(self, model):
+        super(EpistemicUncertaintyModel, self).__init__()
+        self.drop_count = 0
+        self.model = self._make_epistemic(model)
+
+    def construct(self, x):
+        x = self.model(x)
+        return x
+
+    def _make_epistemic(self, model, dropout_rate=0.5):
+        """
+        Reuse an existing Dropout layer if the model has one; otherwise insert a Dropout layer
+        (rate 0.5 by default) after the first Dense or Conv2d layer.
+        """
+        for (name, layer) in model.name_cells().items():
+            if isinstance(layer, Dropout):
+                self.drop_count += 1
+                return model
+        for (name, layer) in model.name_cells().items():
+            if isinstance(layer, (Conv2d, Dense)):
+                uncertainty_layer = layer
+                uncertainty_name = name
+                drop = Dropout(keep_prob=dropout_rate)
+                bnn_drop = SequentialCell([uncertainty_layer, drop])
+                setattr(model, uncertainty_name, bnn_drop)
+                return model
+        raise ValueError("The model has no Dense or Conv2d layer, "
+                         "so the epistemic uncertainty can not be evaluated.")
+
+
+class AleatoricUncertaintyModel(Cell):
+    """
+    The aleatoric uncertainty (also called data uncertainty) is caused by the input data. To obtain this
+    uncertainty, the loss function is modified so that the predicted variance is added into the loss.
+
+    See more details in `What Uncertainties Do We Need in Bayesian Deep Learning for Computer Vision?
+    `.
+    """
+
+    def __init__(self, model, labels, task):
+        super(AleatoricUncertaintyModel, self).__init__()
+        self.task = task
+        if task == 'classification':
+            self.model = model
+            self.var_layer = Dense(labels, labels)
+        else:
+            self.model, self.var_layer, self.pred_layer = self._make_aleatoric(model)
+
+    def construct(self, x):
+        if self.task == 'classification':
+            pred = self.model(x)
+            var = self.var_layer(pred)
+        else:
+            x = self.model(x)
+            pred = self.pred_layer(x)
+            var = self.var_layer(x)
+        return pred, var
+
+    def _make_aleatoric(self, model):
+        """
+        To add the variance into the original loss, append a variance layer after the original network.
+        """
+        dense_layer = dense_name = None
+        for (name, layer) in model.name_cells().items():
+            if isinstance(layer, Dense):
+                dense_layer = layer
+                dense_name = name
+        if dense_layer is None:
+            raise ValueError("The model has no Dense layer, "
+                             "so the aleatoric uncertainty can not be evaluated.")
+        setattr(model, dense_name, Flatten())
+        var_layer = Dense(dense_layer.in_channels, dense_layer.out_channels)
+        return model, var_layer, dense_layer
+
+
+class AleatoricLoss(Cell):
+    """
+    The loss function of the aleatoric model; different modifications are adopted for
+    classification and regression.
+    """
+
+    def __init__(self, task):
+        super(AleatoricLoss, self).__init__()
+        self.task = task
+        if self.task == 'classification':
+            self.sum = P.ReduceSum()
+            self.exp = P.Exp()
+            self.normal = C.normal
+            self.to_tensor = P.ScalarToArray()
+            self.entropy = SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
+        else:
+            self.mean = P.ReduceMean()
+            self.exp = P.Exp()
+            self.pow = P.Pow()
+
+    def construct(self, data_pred, y):
+        y_pred, var = data_pred
+        if self.task == 'classification':
+            # Corrupt the logits with Gaussian noise scaled by the predicted variance,
+            # then average the cross-entropy over the samples.
+            sample_times = 10
+            epsilon = self.normal((1, sample_times), self.to_tensor(0.0), self.to_tensor(1.0), 0)
+            total_loss = 0
+            for i in range(sample_times):
+                y_pred_i = y_pred + epsilon[0][i] * var
+                loss = self.entropy(y_pred_i, y)
+                total_loss += loss
+            avg_loss = total_loss / sample_times
+            return avg_loss
+        loss = self.mean(0.5 * self.exp(-var) * self.pow(y - y_pred, 2) + 0.5 * var)
+        return loss
diff --git a/tests/st/probability/test_uncertainty.py b/tests/st/probability/test_uncertainty.py
new file mode 100644
index 0000000000..1037669f92
--- /dev/null
+++ b/tests/st/probability/test_uncertainty.py
@@ -0,0 +1,133 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================ +""" test uncertainty toolbox """ +import mindspore.dataset as ds +import mindspore.dataset.transforms.c_transforms as C +import mindspore.dataset.transforms.vision.c_transforms as CV +import mindspore.nn as nn +from mindspore import context, Tensor +from mindspore.common import dtype as mstype +from mindspore.common.initializer import TruncatedNormal +from mindspore.dataset.transforms.vision import Inter +from mindspore.nn.probability.toolbox.uncertainty_evaluation import UncertaintyEvaluation +from mindspore.train.serialization import load_checkpoint, load_param_into_net + +context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + + +def conv(in_channels, out_channels, kernel_size, stride=1, padding=0): + """weight initial for conv layer""" + weight = weight_variable() + return nn.Conv2d(in_channels, out_channels, + kernel_size=kernel_size, stride=stride, padding=padding, + weight_init=weight, has_bias=False, pad_mode="valid") + + +def fc_with_initialize(input_channels, out_channels): + """weight initial for fc layer""" + weight = weight_variable() + bias = weight_variable() + return nn.Dense(input_channels, out_channels, weight, bias) + + +def weight_variable(): + """weight initial""" + return TruncatedNormal(0.02) + + +class LeNet5(nn.Cell): + def __init__(self, num_class=10, channel=1): + super(LeNet5, self).__init__() + self.num_class = num_class + self.conv1 = conv(channel, 6, 5) + self.conv2 = conv(6, 16, 5) + self.fc1 = fc_with_initialize(16 * 5 * 5, 120) + self.fc2 = fc_with_initialize(120, 84) + self.fc3 = fc_with_initialize(84, self.num_class) + self.relu = nn.ReLU() + self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2) + self.flatten = nn.Flatten() + + def construct(self, x): + x = self.conv1(x) + x = self.relu(x) + x = self.max_pool2d(x) + x = self.conv2(x) + x = self.relu(x) + x = self.max_pool2d(x) + x = self.flatten(x) + x = self.fc1(x) + x = self.relu(x) + x = self.fc2(x) + x = self.relu(x) + x = self.fc3(x) + return x + + +def create_dataset(data_path, batch_size=32, repeat_size=1, + num_parallel_workers=1): + """ + create dataset for train or test + """ + # define dataset + mnist_ds = ds.MnistDataset(data_path) + + resize_height, resize_width = 32, 32 + rescale = 1.0 / 255.0 + shift = 0.0 + rescale_nml = 1 / 0.3081 + shift_nml = -1 * 0.1307 / 0.3081 + + # define map operations + resize_op = CV.Resize((resize_height, resize_width), interpolation=Inter.LINEAR) # Bilinear mode + rescale_nml_op = CV.Rescale(rescale_nml, shift_nml) + rescale_op = CV.Rescale(rescale, shift) + hwc2chw_op = CV.HWC2CHW() + type_cast_op = C.TypeCast(mstype.int32) + + # apply map operations on images + mnist_ds = mnist_ds.map(input_columns="label", operations=type_cast_op, num_parallel_workers=num_parallel_workers) + mnist_ds = mnist_ds.map(input_columns="image", operations=resize_op, num_parallel_workers=num_parallel_workers) + mnist_ds = mnist_ds.map(input_columns="image", operations=rescale_op, num_parallel_workers=num_parallel_workers) + mnist_ds = mnist_ds.map(input_columns="image", operations=rescale_nml_op, num_parallel_workers=num_parallel_workers) + mnist_ds = mnist_ds.map(input_columns="image", operations=hwc2chw_op, num_parallel_workers=num_parallel_workers) + + # apply DatasetOps + buffer_size = 10000 + mnist_ds = mnist_ds.shuffle(buffer_size=buffer_size) # 10000 as in LeNet train script + mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True) + mnist_ds = mnist_ds.repeat(repeat_size) + + 
return mnist_ds + + +if __name__ == '__main__': + # get trained model + network = LeNet5() + param_dict = load_checkpoint('checkpoint_lenet.ckpt') + load_param_into_net(network, param_dict) + # get train and eval dataset + ds_train = create_dataset('workspace/mnist/train') + ds_eval = create_dataset('workspace/mnist/test') + evaluation = UncertaintyEvaluation(model=network, + train_dataset=ds_train, + task_type='classification', + num_classes=10, + epochs=5, + uncertainty_model_path=None) + for eval_data in ds_eval.create_dict_iterator(): + eval_data = Tensor(eval_data['image'], mstype.float32) + epistemic_uncertainty = evaluation.eval_epistemic(eval_data) + aleatoric_uncertainty = evaluation.eval_aleatoric(eval_data)
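For readers who want to see the statistic behind `_eval_epistemic_uncertainty` in isolation, here is a minimal NumPy-only sketch: run several stochastic (dropout-enabled) forward passes, take the per-output variance across passes, and min-max normalize it. The `stochastic_forward` callable and the toy `fake_forward` model below are hypothetical stand-ins for illustration only; they are not part of the toolbox.

    import numpy as np

    def mc_dropout_epistemic(stochastic_forward, eval_batch, mc=10):
        # Stack `mc` stochastic forward passes along a new last axis: (N, C, mc).
        outputs = np.stack([stochastic_forward(eval_batch) for _ in range(mc)], axis=-1)
        # Per-sample, per-output variance across the passes is the raw epistemic score.
        epi = outputs.var(axis=-1)
        # Min-max normalize, guarding against a constant score.
        span = epi.max() - epi.min()
        return (epi - epi.min()) / span if span > 0 else np.zeros_like(epi)

    # Toy usage: weights re-drawn on every call stand in for dropout noise.
    rng = np.random.default_rng(0)
    def fake_forward(x):  # hypothetical stochastic "model"
        return x @ rng.normal(size=(4, 3))
    print(mc_dropout_epistemic(fake_forward, rng.normal(size=(8, 4))).shape)  # (8, 3)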