|
|
|
@ -13,18 +13,20 @@
|
|
|
|
|
# limitations under the License.
|
|
|
|
|
# ============================================================================
|
|
|
|
|
"""Toolbox for Uncertainty Evaluation."""
|
|
|
|
|
import numpy as np
|
|
|
|
|
from copy import deepcopy
|
|
|
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
|
from mindspore._checkparam import check_int_positive, check_bool
|
|
|
|
|
from mindspore.ops import composite as C
|
|
|
|
|
from mindspore.ops import operations as P
|
|
|
|
|
from mindspore.train import Model
|
|
|
|
|
from mindspore.train.callback import LossMonitor, ModelCheckpoint, CheckpointConfig
|
|
|
|
|
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
|
|
|
|
|
|
|
|
|
from ...cell import Cell
|
|
|
|
|
from ...layer.basic import Dense, Flatten, Dropout
|
|
|
|
|
from ...layer.conv import Conv2d
|
|
|
|
|
from ...layer.container import SequentialCell
|
|
|
|
|
from ...layer.conv import Conv2d
|
|
|
|
|
from ...loss import SoftmaxCrossEntropyWithLogits, MSELoss
|
|
|
|
|
from ...metrics import Accuracy, MSE
|
|
|
|
|
from ...optim import Adam
|
|
|
|
@ -36,8 +38,7 @@ class UncertaintyEvaluation:
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
model (Cell): The model for uncertainty evaluation.
|
|
|
|
|
epi_train_dataset (Dataset): A dataset iterator to train model for obtain epistemic uncertainty.
|
|
|
|
|
ale_train_dataset (Dataset): A dataset iterator to train model for obtain aleatoric uncertainty.
|
|
|
|
|
train_dataset (Dataset): A dataset iterator to train model.
|
|
|
|
|
task_type (str): Option for the task types of model
|
|
|
|
|
- regression: A regression model.
|
|
|
|
|
- classification: A classification model.
|
|
|
|
@ -45,22 +46,20 @@ class UncertaintyEvaluation:
|
|
|
|
|
If the task type is classification, it must be set; if not classification, it need not to be set.
|
|
|
|
|
Default: None.
|
|
|
|
|
epochs (int): Total number of iterations on the data. Default: 1.
|
|
|
|
|
epi_uncer_model_path (str): The save or read path of the epistemic uncertainty model.
|
|
|
|
|
ale_uncer_model_path (str): The save or read path of the aleatoric uncertainty model.
|
|
|
|
|
epi_uncer_model_path (str): The save or read path of the epistemic uncertainty model. Default: None.
|
|
|
|
|
ale_uncer_model_path (str): The save or read path of the aleatoric uncertainty model. Default: None.
|
|
|
|
|
save_model (bool): Save the uncertainty model or not, if True, the epi_uncer_model_path
|
|
|
|
|
and ale_uncer_model_path should not be None. If False, give the path of
|
|
|
|
|
the uncertainty model, it will load the model to evaluate, if not given
|
|
|
|
|
the path, it will not save or load the uncertainty model.
|
|
|
|
|
the path, it will not save or load the uncertainty model. Default: False.
|
|
|
|
|
|
|
|
|
|
Examples:
|
|
|
|
|
>>> network = LeNet()
|
|
|
|
|
>>> param_dict = load_checkpoint('checkpoint_lenet.ckpt')
|
|
|
|
|
>>> load_param_into_net(network, param_dict)
|
|
|
|
|
>>> epi_ds_train = create_dataset('workspace/mnist/train')
|
|
|
|
|
>>> ale_ds_train = create_dataset('workspace/mnist/train')
|
|
|
|
|
>>> ds_train = create_dataset('workspace/mnist/train')
|
|
|
|
|
>>> evaluation = UncertaintyEvaluation(model=network,
|
|
|
|
|
>>> epi_train_dataset=epi_ds_train,
|
|
|
|
|
>>> ale_train_dataset=ale_ds_train,
|
|
|
|
|
>>> train_dataset=ds_train,
|
|
|
|
|
>>> task_type='classification',
|
|
|
|
|
>>> num_classes=10,
|
|
|
|
|
>>> epochs=1,
|
|
|
|
@ -71,12 +70,12 @@ class UncertaintyEvaluation:
|
|
|
|
|
>>> aleatoric_uncertainty = evaluation.eval_aleatoric_uncertainty(eval_data)
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def __init__(self, model, epi_train_dataset, ale_train_dataset, task_type, num_classes=None, epochs=1,
|
|
|
|
|
def __init__(self, model, train_dataset, task_type, num_classes=None, epochs=1,
|
|
|
|
|
epi_uncer_model_path=None, ale_uncer_model_path=None, save_model=False):
|
|
|
|
|
self.epi_model = model
|
|
|
|
|
self.ale_model = model
|
|
|
|
|
self.epi_train_dataset = epi_train_dataset
|
|
|
|
|
self.ale_train_dataset = ale_train_dataset
|
|
|
|
|
self.epi_train_dataset = train_dataset
|
|
|
|
|
self.ale_train_dataset = deepcopy(train_dataset)
|
|
|
|
|
self.task_type = task_type
|
|
|
|
|
self.epochs = check_int_positive(epochs)
|
|
|
|
|
self.epi_uncer_model_path = epi_uncer_model_path
|
|
|
|
@ -93,6 +92,8 @@ class UncertaintyEvaluation:
|
|
|
|
|
raise ValueError('The task should be regression or classification.')
|
|
|
|
|
if task_type == 'classification':
|
|
|
|
|
self.num_classes = check_int_positive(num_classes)
|
|
|
|
|
else:
|
|
|
|
|
self.num_classes = num_classes
|
|
|
|
|
if save_model:
|
|
|
|
|
if epi_uncer_model_path is None or ale_uncer_model_path is None:
|
|
|
|
|
raise ValueError("If save_model is True, the epi_uncer_model_path and "
|
|
|
|
|