@@ -19,7 +19,7 @@ from mindspore._checkparam import check_int_positive
 from mindspore.ops import composite as C
 from mindspore.ops import operations as P
 from mindspore.train import Model
-from mindspore.train.callback import LossMonitor
+from mindspore.train.callback import LossMonitor, ModelCheckpoint, CheckpointConfig
 from mindspore.train.serialization import load_checkpoint, load_param_into_net
 from ...cell import Cell
 from ...layer.basic import Dense, Flatten, Dropout
@@ -43,8 +43,13 @@ class UncertaintyEvaluation:
         num_classes (int): The number of labels of classification.
             If the task type is classification, it must be set; if not classification, it need not to be set.
             Default: None.
-        epochs (int): Total number of iterations on the data. Default: None.
-        uncertainty_model_path (str): The save or read path of the uncertainty model.
+        epochs (int): Total number of iterations on the data. Default: 1.
+        epi_uncer_model_path (str): The save or read path of the epistemic uncertainty model.
+        ale_uncer_model_path (str): The save or read path of the aleatoric uncertainty model.
+        save_model (bool): Save the uncertainty model or not, if True, the epi_uncer_model_path
+            and ale_uncer_model_path should not be None. If False, give the path of
+            the uncertainty model, it will load the model to evaluate, if not given
+            the path, it will not save or load the uncertainty model.

     Examples:
         >>> network = LeNet()
@@ -55,21 +60,26 @@ class UncertaintyEvaluation:
         >>>                                    train_dataset=ds_train,
         >>>                                    task_type='classification',
         >>>                                    num_classes=10,
-        >>>                                    epochs=5,
-        >>>                                    uncertainty_model_path=None)
-        >>> epistemic_uncertainty = evaluation.eval_epistemic(eval_data)
-        >>> aleatoric_uncertainty = evaluation.eval_aleatoric(eval_data)
+        >>>                                    epochs=1,
+        >>>                                    epi_uncer_model_path=None,
+        >>>                                    ale_uncer_model_path=None,
+        >>>                                    save_model=False)
+        >>> epistemic_uncertainty = evaluation.eval_epistemic_uncertainty(eval_data)
+        >>> aleatoric_uncertainty = evaluation.eval_aleatoric_uncertainty(eval_data)
     """

-    def __init__(self, model, train_dataset, task_type, num_classes=None, epochs=None,
-                 uncertainty_model_path=None):
+    def __init__(self, model, train_dataset, task_type, num_classes=None, epochs=1,
+                 epi_uncer_model_path=None, ale_uncer_model_path=None, save_model=False):
         self.model = model
         self.train_dataset = train_dataset
         self.task_type = task_type
         self.num_classes = check_int_positive(num_classes)
         self.epochs = epochs
-        self.uncer_model_path = uncertainty_model_path
-        self.uncer_model = None
+        self.epi_uncer_model_path = epi_uncer_model_path
+        self.ale_uncer_model_path = ale_uncer_model_path
+        self.save_model = save_model
+        self.epi_uncer_model = None
+        self.ale_uncer_model = None
         self.concat = P.Concat(axis=0)
         self.sum = P.ReduceSum()
         self.pow = P.Pow()
@@ -78,6 +88,10 @@ class UncertaintyEvaluation:
         if self.task_type == 'classification':
             if self.num_classes is None:
                 raise ValueError("Classification task needs to input labels.")
+        if self.save_model:
+            if self.epi_uncer_model_path is None or self.ale_uncer_model_path is None:
+                raise ValueError("If save_model is True, the epi_uncer_model_path and "
+                                 "ale_uncer_model_path should not be None.")

     def _uncertainty_normalize(self, data):
         area = np.max(data) - np.min(data)
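The hunk above shows only the first statement of `_uncertainty_normalize`; the rest of its body lies outside the diff. For orientation, a minimal NumPy sketch of the min-max scaling such a helper would perform (the division step is an assumption, not taken from the hunk):

```python
import numpy as np

def uncertainty_normalize(data):
    # Min-max scaling: map raw uncertainty scores onto [0, 1].
    area = np.max(data) - np.min(data)
    return (data - np.min(data)) / area
```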
@@ -87,31 +101,38 @@ class UncertaintyEvaluation:
         """
         Get the model which can obtain the epistemic uncertainty.
         """
-        if self.uncer_model and self.uncer_model_path is None:
-            self.uncer_model = EpistemicUncertaintyModel(self.model)
-            if self.uncer_model.drop_count == 0:
+        if self.epi_uncer_model is None:
+            self.epi_uncer_model = EpistemicUncertaintyModel(self.model)
+            if self.epi_uncer_model.drop_count == 0:
                 if self.task_type == 'classification':
                     net_loss = SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
-                    net_opt = Adam(self.uncer_model.trainable_params())
-                    model = Model(self.uncer_model, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
+                    net_opt = Adam(self.epi_uncer_model.trainable_params())
+                    model = Model(self.epi_uncer_model, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
                 else:
                     net_loss = MSELoss()
-                    net_opt = Adam(self.uncer_model.trainable_params())
-                    model = Model(self.uncer_model, net_loss, net_opt, metrics={"MSE": MSE()})
-                model.train(self.epochs, self.train_dataset, callbacks=[LossMonitor()])
-        elif self.uncer_model is None:
-            uncer_param_dict = load_checkpoint(self.uncer_model_path)
-            load_param_into_net(self.uncer_model, uncer_param_dict)
+                    net_opt = Adam(self.epi_uncer_model.trainable_params())
+                    model = Model(self.epi_uncer_model, net_loss, net_opt, metrics={"MSE": MSE()})
+                if self.save_model:
+                    config_ck = CheckpointConfig(keep_checkpoint_max=self.epochs)
+                    ckpoint_cb = ModelCheckpoint(prefix='checkpoint_epi_uncer_model',
+                                                 directory=self.epi_uncer_model_path,
+                                                 config=config_ck)
+                    model.train(self.epochs, self.train_dataset, callbacks=[ckpoint_cb, LossMonitor()])
+                elif self.epi_uncer_model_path is None:
+                    model.train(self.epochs, self.train_dataset, callbacks=[LossMonitor()])
+                else:
+                    uncer_param_dict = load_checkpoint(self.epi_uncer_model_path)
+                    load_param_into_net(self.epi_uncer_model, uncer_param_dict)

     def _eval_epistemic_uncertainty(self, eval_data, mc=10):
         """
         Evaluate the epistemic uncertainty of classification and regression models using MC dropout.
         """
         self._get_epistemic_uncertainty_model()
-        self.uncer_model.set_train(True)
+        self.epi_uncer_model.set_train(True)
         outputs = [None] * mc
         for i in range(mc):
-            pred = self.uncer_model(eval_data)
+            pred = self.epi_uncer_model(eval_data)
             outputs[i] = pred.asnumpy()
         if self.task_type == 'classification':
             outputs = np.stack(outputs, axis=2)
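`_eval_epistemic_uncertainty` keeps the network in train mode so dropout stays active, runs `mc` stochastic forward passes and stacks the predictions; the spread across those passes serves as the epistemic (model) uncertainty. A minimal NumPy sketch of that idea, independent of MindSpore (`predict_fn` is a hypothetical stand-in for the dropout-enabled model call, and the variance reduction below is an assumption since it lies outside the hunk):

```python
import numpy as np

def mc_dropout_epistemic(predict_fn, eval_data, mc=10):
    # Each pass returns per-sample predictions of shape (N, C);
    # dropout must stay active so the passes differ.
    outputs = np.stack([predict_fn(eval_data) for _ in range(mc)], axis=0)  # (mc, N, C)
    # Disagreement between passes (variance over the MC axis) is the
    # epistemic uncertainty; average over classes for one score per sample.
    return outputs.var(axis=0).mean(axis=-1)  # shape (N,)
```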
@@ -126,30 +147,37 @@ class UncertaintyEvaluation:
         """
         Get the model which can obtain the aleatoric uncertainty.
         """
-        if self.uncer_model and self.uncer_model_path is None:
-            self.uncer_model = AleatoricUncertaintyModel(self.model, self.num_classes, self.task_type)
+        if self.ale_uncer_model is None:
+            self.ale_uncer_model = AleatoricUncertaintyModel(self.model, self.num_classes, self.task_type)
             net_loss = AleatoricLoss(self.task_type)
-            net_opt = Adam(self.uncer_model.trainable_params())
+            net_opt = Adam(self.ale_uncer_model.trainable_params())
             if self.task_type == 'classification':
-                model = Model(self.uncer_model, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
+                model = Model(self.ale_uncer_model, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
             else:
+                model = Model(self.ale_uncer_model, net_loss, net_opt, metrics={"MSE": MSE()})
+            if self.save_model:
+                config_ck = CheckpointConfig(keep_checkpoint_max=self.epochs)
+                ckpoint_cb = ModelCheckpoint(prefix='checkpoint_ale_uncer_model',
+                                             directory=self.ale_uncer_model_path,
+                                             config=config_ck)
+                model.train(self.epochs, self.train_dataset, callbacks=[ckpoint_cb, LossMonitor()])
+            elif self.ale_uncer_model_path is None:
+                model.train(self.epochs, self.train_dataset, callbacks=[LossMonitor()])
+            else:
-                model = Model(self.uncer_model, net_loss, net_opt, metrics={"MSE": MSE()})
-            model.train(self.epochs, self.train_dataset, callbacks=[LossMonitor()])
-        elif self.uncer_model is None:
-            uncer_param_dict = load_checkpoint(self.uncer_model_path)
-            load_param_into_net(self.uncer_model, uncer_param_dict)
+                uncer_param_dict = load_checkpoint(self.ale_uncer_model_path)
+                load_param_into_net(self.ale_uncer_model, uncer_param_dict)

     def _eval_aleatoric_uncertainty(self, eval_data):
         """
         Evaluate the aleatoric uncertainty of classification and regression models.
         """
         self._get_aleatoric_uncertainty_model()
-        _, var = self.uncer_model(eval_data)
+        _, var = self.ale_uncer_model(eval_data)
         ale_uncertainty = self.sum(self.pow(var, 2), 1)
         ale_uncertainty = self._uncertainty_normalize(ale_uncertainty.asnumpy())
         return ale_uncertainty

-    def eval_epistemic(self, eval_data):
+    def eval_epistemic_uncertainty(self, eval_data):
         """
         Evaluate the epistemic uncertainty of inference results, which also called model uncertainty.

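In `_eval_aleatoric_uncertainty` above, the aleatoric model returns a `(prediction, variance)` pair; the score is the per-sample sum of squared variances (`self.sum(self.pow(var, 2), 1)`), which is then min-max normalized. The same reduction in plain NumPy, with a synthetic `var` standing in for the model's variance head:

```python
import numpy as np

var = np.random.default_rng(0).random((4, 10))       # hypothetical variance head output, shape (N, C)
ale = (var ** 2).sum(axis=1)                          # sum of squared variances per sample
ale = (ale - ale.min()) / (ale.max() - ale.min())     # min-max normalize onto [0, 1]
print(ale)
```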
@@ -159,10 +187,10 @@ class UncertaintyEvaluation:
         Returns:
             numpy.dtype, the epistemic uncertainty of inference results of data samples.
         """
-        uncertainty = self._eval_aleatoric_uncertainty(eval_data)
+        uncertainty = self._eval_epistemic_uncertainty(eval_data)
         return uncertainty

-    def eval_aleatoric(self, eval_data):
+    def eval_aleatoric_uncertainty(self, eval_data):
         """
         Evaluate the aleatoric uncertainty of inference results, which also called data uncertainty.

@@ -172,7 +200,7 @@ class UncertaintyEvaluation:
         Returns:
             numpy.dtype, the aleatoric uncertainty of inference results of data samples.
         """
-        uncertainty = self._eval_epistemic_uncertainty(eval_data)
+        uncertainty = self._eval_aleatoric_uncertainty(eval_data)
         return uncertainty
