update uncertainty toolbox

pull/4845/head
zhangxinfeng3 5 years ago
parent c170ccbf33
commit 27ff97a555

@ -93,18 +93,21 @@ class ConditionalVAE(Cell):
recon_x = self._decode(z_c) recon_x = self._decode(z_c)
return recon_x, x, mu, std return recon_x, x, mu, std
def generate_sample(self, sample_y, generate_nums=None, shape=None): def generate_sample(self, sample_y, generate_nums, shape):
""" """
Randomly sample from latent space to generate sample. Randomly sample from latent space to generate sample.
Args: Args:
sample_y (Tensor): Define the label of sample, int tensor. sample_y (Tensor): Define the label of sample, int tensor.
generate_nums (int): The number of samples to generate. generate_nums (int): The number of samples to generate.
shape(tuple): The shape of sample, it should be math:`(generate_nums, C, H, W)`. shape(tuple): The shape of sample, it should be math:`(generate_nums, C, H, W)` or math:`(-1, C, H, W)`.
Returns: Returns:
Tensor, the generated sample. Tensor, the generated sample.
""" """
generate_nums = check_int_positive(generate_nums)
if not isinstance(shape, tuple) or len(shape) != 4 or shape[0] != generate_nums or shape[0] != -1:
raise ValueError('The shape should be (generate_nums, C, H, W) or (-1, C, H, W).')
sample_z = self.normal((generate_nums, self.latent_size), self.to_tensor(0.0), self.to_tensor(1.0), seed=0) sample_z = self.normal((generate_nums, self.latent_size), self.to_tensor(0.0), self.to_tensor(1.0), seed=0)
sample_y = self.one_hot(sample_y) sample_y = self.one_hot(sample_y)
sample_c = self.concat((sample_z, sample_y)) sample_c = self.concat((sample_z, sample_y))

@ -88,11 +88,14 @@ class VAE(Cell):
Args: Args:
generate_nums (int): The number of samples to generate. generate_nums (int): The number of samples to generate.
shape(tuple): The shape of sample, it should be math:`(generate_nums, C, H, W)`. shape(tuple): The shape of sample, it should be math:`(generate_nums, C, H, W)` or math:`(-1, C, H, W)`.
Returns: Returns:
Tensor, the generated sample. Tensor, the generated sample.
""" """
generate_nums = check_int_positive(generate_nums)
if not isinstance(shape, tuple) or len(shape) != 4 or shape[0] != generate_nums or shape[0] != -1:
raise ValueError('The shape should be (generate_nums, C, H, W) or (-1, C, H, W).')
sample_z = self.normal((generate_nums, self.latent_size), self.to_tensor(0.0), self.to_tensor(1.0), seed=0) sample_z = self.normal((generate_nums, self.latent_size), self.to_tensor(0.0), self.to_tensor(1.0), seed=0)
sample = self._decode(sample_z) sample = self._decode(sample_z)
sample = self.reshape(sample, shape) sample = self.reshape(sample, shape)

@ -13,18 +13,20 @@
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
"""Toolbox for Uncertainty Evaluation.""" """Toolbox for Uncertainty Evaluation."""
import numpy as np from copy import deepcopy
import numpy as np
from mindspore._checkparam import check_int_positive, check_bool from mindspore._checkparam import check_int_positive, check_bool
from mindspore.ops import composite as C from mindspore.ops import composite as C
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore.train import Model from mindspore.train import Model
from mindspore.train.callback import LossMonitor, ModelCheckpoint, CheckpointConfig from mindspore.train.callback import LossMonitor, ModelCheckpoint, CheckpointConfig
from mindspore.train.serialization import load_checkpoint, load_param_into_net from mindspore.train.serialization import load_checkpoint, load_param_into_net
from ...cell import Cell from ...cell import Cell
from ...layer.basic import Dense, Flatten, Dropout from ...layer.basic import Dense, Flatten, Dropout
from ...layer.conv import Conv2d
from ...layer.container import SequentialCell from ...layer.container import SequentialCell
from ...layer.conv import Conv2d
from ...loss import SoftmaxCrossEntropyWithLogits, MSELoss from ...loss import SoftmaxCrossEntropyWithLogits, MSELoss
from ...metrics import Accuracy, MSE from ...metrics import Accuracy, MSE
from ...optim import Adam from ...optim import Adam
@ -36,8 +38,7 @@ class UncertaintyEvaluation:
Args: Args:
model (Cell): The model for uncertainty evaluation. model (Cell): The model for uncertainty evaluation.
epi_train_dataset (Dataset): A dataset iterator to train model for obtain epistemic uncertainty. train_dataset (Dataset): A dataset iterator to train model.
ale_train_dataset (Dataset): A dataset iterator to train model for obtain aleatoric uncertainty.
task_type (str): Option for the task types of model task_type (str): Option for the task types of model
- regression: A regression model. - regression: A regression model.
- classification: A classification model. - classification: A classification model.
@ -45,22 +46,20 @@ class UncertaintyEvaluation:
If the task type is classification, it must be set; if not classification, it need not to be set. If the task type is classification, it must be set; if not classification, it need not to be set.
Default: None. Default: None.
epochs (int): Total number of iterations on the data. Default: 1. epochs (int): Total number of iterations on the data. Default: 1.
epi_uncer_model_path (str): The save or read path of the epistemic uncertainty model. epi_uncer_model_path (str): The save or read path of the epistemic uncertainty model. Default: None.
ale_uncer_model_path (str): The save or read path of the aleatoric uncertainty model. ale_uncer_model_path (str): The save or read path of the aleatoric uncertainty model. Default: None.
save_model (bool): Save the uncertainty model or not, if True, the epi_uncer_model_path save_model (bool): Save the uncertainty model or not, if True, the epi_uncer_model_path
and ale_uncer_model_path should not be None. If False, give the path of and ale_uncer_model_path should not be None. If False, give the path of
the uncertainty model, it will load the model to evaluate, if not given the uncertainty model, it will load the model to evaluate, if not given
the path, it will not save or load the uncertainty model. the path, it will not save or load the uncertainty model. Default: False.
Examples: Examples:
>>> network = LeNet() >>> network = LeNet()
>>> param_dict = load_checkpoint('checkpoint_lenet.ckpt') >>> param_dict = load_checkpoint('checkpoint_lenet.ckpt')
>>> load_param_into_net(network, param_dict) >>> load_param_into_net(network, param_dict)
>>> epi_ds_train = create_dataset('workspace/mnist/train') >>> ds_train = create_dataset('workspace/mnist/train')
>>> ale_ds_train = create_dataset('workspace/mnist/train')
>>> evaluation = UncertaintyEvaluation(model=network, >>> evaluation = UncertaintyEvaluation(model=network,
>>> epi_train_dataset=epi_ds_train, >>> train_dataset=ds_train,
>>> ale_train_dataset=ale_ds_train,
>>> task_type='classification', >>> task_type='classification',
>>> num_classes=10, >>> num_classes=10,
>>> epochs=1, >>> epochs=1,
@ -71,12 +70,12 @@ class UncertaintyEvaluation:
>>> aleatoric_uncertainty = evaluation.eval_aleatoric_uncertainty(eval_data) >>> aleatoric_uncertainty = evaluation.eval_aleatoric_uncertainty(eval_data)
""" """
def __init__(self, model, epi_train_dataset, ale_train_dataset, task_type, num_classes=None, epochs=1, def __init__(self, model, train_dataset, task_type, num_classes=None, epochs=1,
epi_uncer_model_path=None, ale_uncer_model_path=None, save_model=False): epi_uncer_model_path=None, ale_uncer_model_path=None, save_model=False):
self.epi_model = model self.epi_model = model
self.ale_model = model self.ale_model = model
self.epi_train_dataset = epi_train_dataset self.epi_train_dataset = train_dataset
self.ale_train_dataset = ale_train_dataset self.ale_train_dataset = deepcopy(train_dataset)
self.task_type = task_type self.task_type = task_type
self.epochs = check_int_positive(epochs) self.epochs = check_int_positive(epochs)
self.epi_uncer_model_path = epi_uncer_model_path self.epi_uncer_model_path = epi_uncer_model_path
@ -93,6 +92,8 @@ class UncertaintyEvaluation:
raise ValueError('The task should be regression or classification.') raise ValueError('The task should be regression or classification.')
if task_type == 'classification': if task_type == 'classification':
self.num_classes = check_int_positive(num_classes) self.num_classes = check_int_positive(num_classes)
else:
self.num_classes = num_classes
if save_model: if save_model:
if epi_uncer_model_path is None or ale_uncer_model_path is None: if epi_uncer_model_path is None or ale_uncer_model_path is None:
raise ValueError("If save_model is True, the epi_uncer_model_path and " raise ValueError("If save_model is True, the epi_uncer_model_path and "

@ -119,12 +119,10 @@ if __name__ == '__main__':
param_dict = load_checkpoint('checkpoint_lenet.ckpt') param_dict = load_checkpoint('checkpoint_lenet.ckpt')
load_param_into_net(network, param_dict) load_param_into_net(network, param_dict)
# get train and eval dataset # get train and eval dataset
epi_ds_train = create_dataset('workspace/mnist/train') ds_train = create_dataset('workspace/mnist/train')
ale_ds_train = create_dataset('workspace/mnist/train')
ds_eval = create_dataset('workspace/mnist/test') ds_eval = create_dataset('workspace/mnist/test')
evaluation = UncertaintyEvaluation(model=network, evaluation = UncertaintyEvaluation(model=network,
epi_train_dataset=epi_ds_train, train_dataset=ds_train,
ale_train_dataset=ale_ds_train,
task_type='classification', task_type='classification',
num_classes=10, num_classes=10,
epochs=1, epochs=1,

Loading…
Cancel
Save