|
|
|
@ -15,7 +15,7 @@
|
|
|
|
|
"""Toolbox for Uncertainty Evaluation."""
|
|
|
|
|
import numpy as np
|
|
|
|
|
|
|
|
|
|
from mindspore._checkparam import check_int_positive
|
|
|
|
|
from mindspore._checkparam import check_int_positive, check_bool
|
|
|
|
|
from mindspore.ops import composite as C
|
|
|
|
|
from mindspore.ops import operations as P
|
|
|
|
|
from mindspore.train import Model
|
|
|
|
@ -36,7 +36,8 @@ class UncertaintyEvaluation:
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
model (Cell): The model for uncertainty evaluation.
|
|
|
|
|
train_dataset (Dataset): A dataset iterator.
|
|
|
|
|
epi_train_dataset (Dataset): A dataset iterator used to train the model for obtaining epistemic uncertainty.
|
|
|
|
|
ale_train_dataset (Dataset): A dataset iterator used to train the model for obtaining aleatoric uncertainty.
|
|
|
|
|
task_type (str): Option for the task types of model
|
|
|
|
|
- regression: A regression model.
|
|
|
|
|
- classification: A classification model.
|
|
|
|
@ -55,9 +56,11 @@ class UncertaintyEvaluation:
|
|
|
|
|
>>> network = LeNet()
|
|
|
|
|
>>> param_dict = load_checkpoint('checkpoint_lenet.ckpt')
|
|
|
|
|
>>> load_param_into_net(network, param_dict)
|
|
|
|
|
>>> ds_train = create_dataset('workspace/mnist/train')
|
|
|
|
|
>>> epi_ds_train = create_dataset('workspace/mnist/train')
|
|
|
|
|
>>> ale_ds_train = create_dataset('workspace/mnist/train')
|
|
|
|
|
>>> evaluation = UncertaintyEvaluation(model=network,
|
|
|
|
|
>>> train_dataset=ds_train,
|
|
|
|
|
>>> epi_train_dataset=epi_ds_train,
|
|
|
|
|
>>> ale_train_dataset=ale_ds_train,
|
|
|
|
|
>>> task_type='classification',
|
|
|
|
|
>>> num_classes=10,
|
|
|
|
|
>>> epochs=1,
|
|
|
|
@ -68,28 +71,30 @@ class UncertaintyEvaluation:
|
|
|
|
|
>>> aleatoric_uncertainty = evaluation.eval_aleatoric_uncertainty(eval_data)
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def __init__(self, model, epi_train_dataset, ale_train_dataset, task_type, num_classes=None, epochs=1,
             epi_uncer_model_path=None, ale_uncer_model_path=None, save_model=False):
    """
    Validate the arguments and store the configuration for uncertainty evaluation.

    Args:
        model (Cell): The model for uncertainty evaluation.
        epi_train_dataset (Dataset): Dataset used when training the epistemic uncertainty model.
        ale_train_dataset (Dataset): Dataset used when training the aleatoric uncertainty model.
        task_type (str): Either 'regression' or 'classification'.
        num_classes (int): Number of classes; validated as a positive int for classification tasks.
        epochs (int): Number of training epochs; must be a positive int.
        epi_uncer_model_path (str): Checkpoint location for the epistemic uncertainty model.
        ale_uncer_model_path (str): Checkpoint location for the aleatoric uncertainty model.
        save_model (bool): Whether the uncertainty models are saved; requires both paths.

    Raises:
        TypeError: If `model` is not a Cell.
        ValueError: If `task_type` is not supported, or if `save_model` is True while
            either checkpoint path is None.
    """
    # Function-scope import so this block does not depend on the file's import section.
    from copy import deepcopy

    if not isinstance(model, Cell):
        raise TypeError('The model should be Cell type.')
    if task_type not in ('regression', 'classification'):
        raise ValueError('The task should be regression or classification.')

    # _make_epistemic and _make_aleatoric both rewrite layers of the network they receive
    # via setattr, so each uncertainty model gets its own copy instead of sharing one
    # object (and the caller's network is left untouched).
    self.epi_model = deepcopy(model)
    self.ale_model = deepcopy(model)
    self.epi_train_dataset = epi_train_dataset
    self.ale_train_dataset = ale_train_dataset
    self.task_type = task_type
    self.epochs = check_int_positive(epochs)
    self.epi_uncer_model_path = epi_uncer_model_path
    self.ale_uncer_model_path = ale_uncer_model_path
    self.save_model = check_bool(save_model)
    self.epi_uncer_model = None
    self.ale_uncer_model = None
    self.concat = P.Concat(axis=0)
    self.sum = P.ReduceSum()
    self.pow = P.Pow()

    if task_type == 'classification':
        self.num_classes = check_int_positive(num_classes)
    else:
        # The attribute is read when the aleatoric model is built, so it must exist
        # for regression tasks as well.
        self.num_classes = num_classes

    if save_model:
        if epi_uncer_model_path is None or ale_uncer_model_path is None:
            raise ValueError("If save_model is True, the epi_uncer_model_path and "
                             "ale_uncer_model_path should not be None.")
|
|
|
|
|
|
|
|
|
@ -102,7 +107,7 @@ class UncertaintyEvaluation:
|
|
|
|
|
Get the model which can obtain the epistemic uncertainty.
|
|
|
|
|
"""
|
|
|
|
|
if self.epi_uncer_model is None:
|
|
|
|
|
self.epi_uncer_model = EpistemicUncertaintyModel(self.model)
|
|
|
|
|
self.epi_uncer_model = EpistemicUncertaintyModel(self.epi_model)
|
|
|
|
|
if self.epi_uncer_model.drop_count == 0:
|
|
|
|
|
if self.task_type == 'classification':
|
|
|
|
|
net_loss = SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
|
|
|
|
@ -117,9 +122,9 @@ class UncertaintyEvaluation:
|
|
|
|
|
ckpoint_cb = ModelCheckpoint(prefix='checkpoint_epi_uncer_model',
|
|
|
|
|
directory=self.epi_uncer_model_path,
|
|
|
|
|
config=config_ck)
|
|
|
|
|
model.train(self.epochs, self.train_dataset, callbacks=[ckpoint_cb, LossMonitor()])
|
|
|
|
|
model.train(self.epochs, self.epi_train_dataset, callbacks=[ckpoint_cb, LossMonitor()])
|
|
|
|
|
elif self.epi_uncer_model_path is None:
|
|
|
|
|
model.train(self.epochs, self.train_dataset, callbacks=[LossMonitor()])
|
|
|
|
|
model.train(self.epochs, self.epi_train_dataset, callbacks=[LossMonitor()])
|
|
|
|
|
else:
|
|
|
|
|
uncer_param_dict = load_checkpoint(self.epi_uncer_model_path)
|
|
|
|
|
load_param_into_net(self.epi_uncer_model, uncer_param_dict)
|
|
|
|
@ -148,7 +153,7 @@ class UncertaintyEvaluation:
|
|
|
|
|
Get the model which can obtain the aleatoric uncertainty.
|
|
|
|
|
"""
|
|
|
|
|
if self.ale_uncer_model is None:
|
|
|
|
|
self.ale_uncer_model = AleatoricUncertaintyModel(self.model, self.num_classes, self.task_type)
|
|
|
|
|
self.ale_uncer_model = AleatoricUncertaintyModel(self.ale_model, self.num_classes, self.task_type)
|
|
|
|
|
net_loss = AleatoricLoss(self.task_type)
|
|
|
|
|
net_opt = Adam(self.ale_uncer_model.trainable_params())
|
|
|
|
|
if self.task_type == 'classification':
|
|
|
|
@ -160,9 +165,9 @@ class UncertaintyEvaluation:
|
|
|
|
|
ckpoint_cb = ModelCheckpoint(prefix='checkpoint_ale_uncer_model',
|
|
|
|
|
directory=self.ale_uncer_model_path,
|
|
|
|
|
config=config_ck)
|
|
|
|
|
model.train(self.epochs, self.train_dataset, callbacks=[ckpoint_cb, LossMonitor()])
|
|
|
|
|
model.train(self.epochs, self.ale_train_dataset, callbacks=[ckpoint_cb, LossMonitor()])
|
|
|
|
|
elif self.ale_uncer_model_path is None:
|
|
|
|
|
model.train(self.epochs, self.train_dataset, callbacks=[LossMonitor()])
|
|
|
|
|
model.train(self.epochs, self.ale_train_dataset, callbacks=[LossMonitor()])
|
|
|
|
|
else:
|
|
|
|
|
uncer_param_dict = load_checkpoint(self.ale_uncer_model_path)
|
|
|
|
|
load_param_into_net(self.ale_uncer_model, uncer_param_dict)
|
|
|
|
@ -216,31 +221,31 @@ class EpistemicUncertaintyModel(Cell):
|
|
|
|
|
<https://arxiv.org/abs/1506.02142>`.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def __init__(self, epi_model):
    """
    Wrap `epi_model` so that it can be used for epistemic uncertainty estimation.

    Args:
        epi_model (Cell): The network to adapt; it is passed to `_make_epistemic`.
    """
    super(EpistemicUncertaintyModel, self).__init__()
    # Must be initialised before _make_epistemic runs: that method increments it
    # when the network already contains a Dropout layer.
    self.drop_count = 0
    self.epi_model = self._make_epistemic(epi_model)
|
|
|
|
|
|
|
|
|
|
def construct(self, x):
    """
    Forward `x` through the network stored in `self.epi_model`.

    Args:
        x: Input passed unchanged to the wrapped network.

    Returns:
        The output of `self.epi_model(x)`.
    """
    x = self.epi_model(x)
    return x
|
|
|
|
|
|
|
|
|
|
def _make_epistemic(self, epi_model, dropout_rate=0.5):
    """
    Ensure the network contains a Dropout layer.

    If a Dropout cell is already present, `self.drop_count` is incremented and the
    network is returned unchanged. Otherwise the first Conv2d or Dense cell found is
    replaced by a SequentialCell of that cell followed by a new Dropout.

    Args:
        epi_model (Cell): The network to inspect; it is modified in place.
        dropout_rate (float): Probability of dropping a unit. Default: 0.5.

    Returns:
        Cell, the same `epi_model` object.

    Raises:
        ValueError: If the network has neither a Dropout nor a Conv2d/Dense cell.
    """
    cells = epi_model.name_cells()

    if any(isinstance(layer, Dropout) for layer in cells.values()):
        self.drop_count += 1
        return epi_model

    for name, layer in cells.items():
        if isinstance(layer, (Conv2d, Dense)):
            # Dropout is configured by keep probability, which is the complement of
            # the dropout rate; the two only coincide at the default of 0.5.
            drop = Dropout(keep_prob=1 - dropout_rate)
            setattr(epi_model, name, SequentialCell([layer, drop]))
            return epi_model

    raise ValueError("The model has not Dense Layer or Convolution Layer, "
                     "it can not evaluate epistemic uncertainty so far.")
|
|
|
|
|
|
|
|
|
@ -254,40 +259,40 @@ class AleatoricUncertaintyModel(Cell):
|
|
|
|
|
<https://arxiv.org/abs/1703.04977>`.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def __init__(self, ale_model, num_classes, task):
    """
    Build the layers needed to predict a variance alongside the network output.

    Args:
        ale_model (Cell): The network to adapt.
        num_classes (int): Size of the variance layer; only used for classification.
        task (str): 'classification' keeps the network as is and adds a
            Dense(num_classes, num_classes) variance layer; any other value delegates
            to `_make_aleatoric`.
    """
    super(AleatoricUncertaintyModel, self).__init__()
    self.task = task
    if task == 'classification':
        self.ale_model = ale_model
        self.var_layer = Dense(num_classes, num_classes)
    else:
        self.ale_model, self.var_layer, self.pred_layer = self._make_aleatoric(ale_model)
|
|
|
|
|
|
|
|
|
|
def construct(self, x):
    """
    Compute the prediction and its variance for input `x`.

    For classification the variance layer is applied to the network's prediction;
    otherwise the prediction and variance layers are both applied to the network's
    output.

    Args:
        x: Input passed to `self.ale_model`.

    Returns:
        Tuple `(pred, var)`.
    """
    if self.task == 'classification':
        pred = self.ale_model(x)
        var = self.var_layer(pred)
    else:
        x = self.ale_model(x)
        pred = self.pred_layer(x)
        var = self.var_layer(x)
    return pred, var
|
|
|
|
|
|
|
|
|
|
def _make_aleatoric(self, ale_model):
    """
    In order to add variance into original loss, add var Layer after the original network.

    The last Dense cell returned by `name_cells()` is detached from the network
    (replaced in place by Flatten) and a new Dense layer with the same in/out
    channels is created as the variance layer.

    Args:
        ale_model (Cell): The network to adapt; it is modified in place.

    Returns:
        Tuple `(ale_model, var_layer, dense_layer)`, where `dense_layer` is the
        Dense cell that was removed from the network.

    Raises:
        ValueError: If the network contains no Dense cell.
    """
    dense_layer = dense_name = None
    # No early exit: later matches overwrite earlier ones, so the last Dense wins.
    for name, layer in ale_model.name_cells().items():
        if isinstance(layer, Dense):
            dense_layer = layer
            dense_name = name
    if dense_layer is None:
        raise ValueError("The model has not Dense Layer, "
                         "it can not evaluate aleatoric uncertainty so far.")
    setattr(ale_model, dense_name, Flatten())
    var_layer = Dense(dense_layer.in_channels, dense_layer.out_channels)
    return ale_model, var_layer, dense_layer
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class AleatoricLoss(Cell):
|
|
|
|
|