!4409 update variational inference and toolbox

Merge pull request !4409 from zhangxinfeng3/master
pull/4409/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit 17a3afe305

@ -16,7 +16,6 @@
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore._checkparam import check_int_positive
from ...distribution.normal import Normal
from ....cell import Cell
from ....layer.basic import Dense, OneHot
@ -46,7 +45,7 @@ class ConditionalVAE(Cell):
- **input_y** (Tensor) - the tensor of the target data, the shape is :math:`(N, 1)`.
Outputs:
- **output** (tuple) - (recon_x(Tensor), x(Tensor), mu(Tensor), std(Tensor), z(Tensor), prior(Cell)).
- **output** (tuple) - (recon_x(Tensor), x(Tensor), mu(Tensor), std(Tensor)).
"""
def __init__(self, encoder, decoder, hidden_size, latent_size, num_classes):
@ -59,11 +58,10 @@ class ConditionalVAE(Cell):
self.normal = C.normal
self.exp = P.Exp()
self.reshape = P.Reshape()
self.shape = P.Shape()
self.concat = P.Concat(axis=1)
self.to_tensor = P.ScalarToArray()
self.normal_dis = Normal()
self.one_hot = OneHot(depth=num_classes)
self.standard_normal_dis = Normal([0] * self.latent_size, [1] * self.latent_size)
self.dense1 = Dense(self.hidden_size, self.latent_size)
self.dense2 = Dense(self.hidden_size, self.latent_size)
self.dense3 = Dense(self.latent_size + self.num_classes, self.hidden_size)
@ -82,11 +80,11 @@ class ConditionalVAE(Cell):
def construct(self, x, y):
mu, log_var = self._encode(x, y)
std = self.exp(0.5 * log_var)
z = self.normal_dis('sample', mean=mu, sd=std)
z = self.normal(self.shape(mu), mu, std, seed=0)
y = self.one_hot(y)
z_c = self.concat((z, y))
recon_x = self._decode(z_c)
return recon_x, x, mu, std, z, self.standard_normal_dis
return recon_x, x, mu, std
def generate_sample(self, sample_y, generate_nums=None, shape=None):
"""

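Note on the sampling change above: ConditionalVAE.construct (and VAE.construct in the following file) now draws the latent code with the functional C.normal(shape, mean, stddev, seed) instead of calling a Normal distribution cell, and the forward output drops z and the prior because ELBO now computes the KL term in closed form (see the ELBO hunk below). A minimal NumPy sketch of what the new sampling line computes; sample_latent and the shapes are illustrative only, not part of the patch:

# z = self.normal(self.shape(mu), mu, std, seed=0) draws z ~ N(mu, std) elementwise,
# which matches in distribution the usual reparameterization z = mu + std * eps.
import numpy as np

def sample_latent(mu, log_var, rng=np.random.default_rng(0)):
    std = np.exp(0.5 * log_var)          # same transform as in construct()
    eps = rng.standard_normal(mu.shape)  # eps ~ N(0, I)
    return mu + std * eps                # z ~ N(mu, std)

mu = np.zeros((4, 20))       # hypothetical batch of 4, latent_size of 20
log_var = np.zeros((4, 20))
z = sample_latent(mu, log_var)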
@ -16,7 +16,6 @@
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore._checkparam import check_int_positive
from ...distribution.normal import Normal
from ....cell import Cell
from ....layer.basic import Dense
@ -43,7 +42,7 @@ class VAE(Cell):
- **input** (Tensor) - the same shape as the input of encoder.
Outputs:
- **output** (Tuple) - (recon_x(Tensor), x(Tensor), mu(Tensor), std(Tensor), z(Tensor), prior(Cell)).
- **output** (Tuple) - (recon_x(Tensor), x(Tensor), mu(Tensor), std(Tensor)).
"""
def __init__(self, encoder, decoder, hidden_size, latent_size):
@ -55,9 +54,8 @@ class VAE(Cell):
self.normal = C.normal
self.exp = P.Exp()
self.reshape = P.Reshape()
self.shape = P.Shape()
self.to_tensor = P.ScalarToArray()
self.normal_dis = Normal()
self.standard_normal_dis = Normal([0]*self.latent_size, [1]*self.latent_size)
self.dense1 = Dense(self.hidden_size, self.latent_size)
self.dense2 = Dense(self.hidden_size, self.latent_size)
self.dense3 = Dense(self.latent_size, self.hidden_size)
@ -76,9 +74,9 @@ class VAE(Cell):
def construct(self, x):
mu, log_var = self._encode(x)
std = self.exp(0.5 * log_var)
z = self.normal_dis('sample', mean=mu, sd=std)
z = self.normal(self.shape(mu), mu, std, seed=0)
recon_x = self._decode(z)
return recon_x, x, mu, std, z, self.standard_normal_dis
return recon_x, x, mu, std
def generate_sample(self, generate_nums, shape):
"""

@ -36,7 +36,7 @@ class ELBO(Cell):
- Normal: If the distribution of output data is Normal, the reconstruction loss is MSELoss.
Inputs:
- **input_data** (Tuple) - (recon_x(Tensor), x(Tensor), mu(Tensor), std(Tensor), z(Tensor), prior(Cell)).
- **input_data** (Tuple) - (recon_x(Tensor), x(Tensor), mu(Tensor), std(Tensor)).
- **target_data** (Tensor) - the target tensor.
Outputs:
@ -46,6 +46,7 @@ class ELBO(Cell):
def __init__(self, latent_prior='Normal', output_prior='Normal'):
super(ELBO, self).__init__()
self.sum = P.ReduceSum()
self.zeros = P.ZerosLike()
if latent_prior == 'Normal':
self.posterior = Normal()
else:
@ -56,9 +57,8 @@ class ELBO(Cell):
raise ValueError('The values of output_dis now only support Normal')
def construct(self, data, label):
recon_x, x, mu, std, z, prior = data
recon_x, x, mu, std = data
reconstruct_loss = self.recon_loss(x, recon_x)
kl_loss = -(prior('log_prob', z) - self.posterior('log_prob', z, mu, std)) \
* self.posterior('prob', z, mu, std)
kl_loss = self.posterior('kl_loss', 'Normal', self.zeros(mu), self.zeros(mu)+1, mu, std)
elbo = reconstruct_loss + self.sum(kl_loss)
return elbo
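The KL term above is now computed in closed form from (mu, std) against a standard normal prior (the self.zeros(mu) and self.zeros(mu)+1 arguments), replacing the earlier sample-based expression that needed z and the prior cell. Assuming the 'kl_loss' call evaluates KL(N(mu, std) || N(0, 1)), the usual VAE regularizer, a hedged NumPy sketch of the quantity being summed:

import numpy as np

def kl_standard_normal(mu, std):
    # KL(N(mu, std) || N(0, 1)) per dimension, summed over all dimensions:
    # 0.5 * (std^2 + mu^2 - 1) - log(std)
    return np.sum(0.5 * (std**2 + mu**2 - 1.0) - np.log(std))

mu = np.array([[0.3, -0.1], [0.0, 0.5]])   # hypothetical posterior means
std = np.array([[0.9, 1.1], [1.0, 0.8]])   # hypothetical posterior stddevs
kl = kl_standard_normal(mu, std)            # combined with the reconstruction loss as in construct()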

@ -19,7 +19,7 @@ from mindspore._checkparam import check_int_positive
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore.train import Model
from mindspore.train.callback import LossMonitor
from mindspore.train.callback import LossMonitor, ModelCheckpoint, CheckpointConfig
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from ...cell import Cell
from ...layer.basic import Dense, Flatten, Dropout
@ -43,8 +43,13 @@ class UncertaintyEvaluation:
num_classes (int): The number of labels of classification.
If the task type is classification, it must be set; otherwise, it does not need to be set.
Default: None.
epochs (int): Total number of iterations on the data. Default: None.
uncertainty_model_path (str): The save or read path of the uncertainty model.
epochs (int): Total number of iterations on the data. Default: 1.
epi_uncer_model_path (str): The save or read path of the epistemic uncertainty model.
ale_uncer_model_path (str): The save or read path of the aleatoric uncertainty model.
save_model (bool): Whether to save the uncertainty model. If True, epi_uncer_model_path
and ale_uncer_model_path must not be None. If False and a model path is given,
the saved uncertainty model will be loaded from it for evaluation; if no path is
given, the uncertainty model will be neither saved nor loaded.
Examples:
>>> network = LeNet()
@ -55,21 +60,26 @@ class UncertaintyEvaluation:
>>> train_dataset=ds_train,
>>> task_type='classification',
>>> num_classes=10,
>>> epochs=5,
>>> uncertainty_model_path=None)
>>> epistemic_uncertainty = evaluation.eval_epistemic(eval_data)
>>> aleatoric_uncertainty = evaluation.eval_aleatoric(eval_data)
>>> epochs=1,
>>> epi_uncer_model_path=None,
>>> ale_uncer_model_path=None,
>>> save_model=False)
>>> epistemic_uncertainty = evaluation.eval_epistemic_uncertainty(eval_data)
>>> aleatoric_uncertainty = evaluation.eval_aleatoric_uncertainty(eval_data)
"""
def __init__(self, model, train_dataset, task_type, num_classes=None, epochs=None,
uncertainty_model_path=None):
def __init__(self, model, train_dataset, task_type, num_classes=None, epochs=1,
epi_uncer_model_path=None, ale_uncer_model_path=None, save_model=False):
self.model = model
self.train_dataset = train_dataset
self.task_type = task_type
self.num_classes = check_int_positive(num_classes)
self.epochs = epochs
self.uncer_model_path = uncertainty_model_path
self.uncer_model = None
self.epi_uncer_model_path = epi_uncer_model_path
self.ale_uncer_model_path = ale_uncer_model_path
self.save_model = save_model
self.epi_uncer_model = None
self.ale_uncer_model = None
self.concat = P.Concat(axis=0)
self.sum = P.ReduceSum()
self.pow = P.Pow()
@ -78,6 +88,10 @@ class UncertaintyEvaluation:
if self.task_type == 'classification':
if self.num_classes is None:
raise ValueError("Classification task needs to input labels.")
if self.save_model:
if self.epi_uncer_model_path is None or self.ale_uncer_model_path is None:
raise ValueError("If save_model is True, the epi_uncer_model_path and "
"ale_uncer_model_path should not be None.")
def _uncertainty_normalize(self, data):
area = np.max(data) - np.min(data)
@ -87,31 +101,38 @@ class UncertaintyEvaluation:
"""
Get the model which can obtain the epistemic uncertainty.
"""
if self.uncer_model and self.uncer_model_path is None:
self.uncer_model = EpistemicUncertaintyModel(self.model)
if self.uncer_model.drop_count == 0:
if self.epi_uncer_model is None:
self.epi_uncer_model = EpistemicUncertaintyModel(self.model)
if self.epi_uncer_model.drop_count == 0:
if self.task_type == 'classification':
net_loss = SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
net_opt = Adam(self.uncer_model.trainable_params())
model = Model(self.uncer_model, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
net_opt = Adam(self.epi_uncer_model.trainable_params())
model = Model(self.epi_uncer_model, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
else:
net_loss = MSELoss()
net_opt = Adam(self.uncer_model.trainable_params())
model = Model(self.uncer_model, net_loss, net_opt, metrics={"MSE": MSE()})
model.train(self.epochs, self.train_dataset, callbacks=[LossMonitor()])
elif self.uncer_model is None:
uncer_param_dict = load_checkpoint(self.uncer_model_path)
load_param_into_net(self.uncer_model, uncer_param_dict)
net_opt = Adam(self.epi_uncer_model.trainable_params())
model = Model(self.epi_uncer_model, net_loss, net_opt, metrics={"MSE": MSE()})
if self.save_model:
config_ck = CheckpointConfig(keep_checkpoint_max=self.epochs)
ckpoint_cb = ModelCheckpoint(prefix='checkpoint_epi_uncer_model',
directory=self.epi_uncer_model_path,
config=config_ck)
model.train(self.epochs, self.train_dataset, callbacks=[ckpoint_cb, LossMonitor()])
elif self.epi_uncer_model_path is None:
model.train(self.epochs, self.train_dataset, callbacks=[LossMonitor()])
else:
uncer_param_dict = load_checkpoint(self.epi_uncer_model_path)
load_param_into_net(self.epi_uncer_model, uncer_param_dict)
def _eval_epistemic_uncertainty(self, eval_data, mc=10):
"""
Evaluate the epistemic uncertainty of classification and regression models using MC dropout.
"""
self._get_epistemic_uncertainty_model()
self.uncer_model.set_train(True)
self.epi_uncer_model.set_train(True)
outputs = [None] * mc
for i in range(mc):
pred = self.uncer_model(eval_data)
pred = self.epi_uncer_model(eval_data)
outputs[i] = pred.asnumpy()
if self.task_type == 'classification':
outputs = np.stack(outputs, axis=2)
@ -126,30 +147,37 @@ class UncertaintyEvaluation:
"""
Get the model which can obtain the aleatoric uncertainty.
"""
if self.uncer_model and self.uncer_model_path is None:
self.uncer_model = AleatoricUncertaintyModel(self.model, self.num_classes, self.task_type)
if self.ale_uncer_model is None:
self.ale_uncer_model = AleatoricUncertaintyModel(self.model, self.num_classes, self.task_type)
net_loss = AleatoricLoss(self.task_type)
net_opt = Adam(self.uncer_model.trainable_params())
net_opt = Adam(self.ale_uncer_model.trainable_params())
if self.task_type == 'classification':
model = Model(self.uncer_model, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
model = Model(self.ale_uncer_model, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
else:
model = Model(self.ale_uncer_model, net_loss, net_opt, metrics={"MSE": MSE()})
if self.save_model:
config_ck = CheckpointConfig(keep_checkpoint_max=self.epochs)
ckpoint_cb = ModelCheckpoint(prefix='checkpoint_ale_uncer_model',
directory=self.ale_uncer_model_path,
config=config_ck)
model.train(self.epochs, self.train_dataset, callbacks=[ckpoint_cb, LossMonitor()])
elif self.ale_uncer_model_path is None:
model.train(self.epochs, self.train_dataset, callbacks=[LossMonitor()])
else:
model = Model(self.uncer_model, net_loss, net_opt, metrics={"MSE": MSE()})
model.train(self.epochs, self.train_dataset, callbacks=[LossMonitor()])
elif self.uncer_model is None:
uncer_param_dict = load_checkpoint(self.uncer_model_path)
load_param_into_net(self.uncer_model, uncer_param_dict)
uncer_param_dict = load_checkpoint(self.ale_uncer_model_path)
load_param_into_net(self.ale_uncer_model, uncer_param_dict)
def _eval_aleatoric_uncertainty(self, eval_data):
"""
Evaluate the aleatoric uncertainty of classification and regression models.
"""
self._get_aleatoric_uncertainty_model()
_, var = self.uncer_model(eval_data)
_, var = self.ale_uncer_model(eval_data)
ale_uncertainty = self.sum(self.pow(var, 2), 1)
ale_uncertainty = self._uncertainty_normalize(ale_uncertainty.asnumpy())
return ale_uncertainty
def eval_epistemic(self, eval_data):
def eval_epistemic_uncertainty(self, eval_data):
"""
Evaluate the epistemic uncertainty of inference results, which is also called model uncertainty.
@ -159,10 +187,10 @@ class UncertaintyEvaluation:
Returns:
numpy.dtype, the epistemic uncertainty of inference results of data samples.
"""
uncertainty = self._eval_aleatoric_uncertainty(eval_data)
uncertainty = self._eval_epistemic_uncertainty(eval_data)
return uncertainty
def eval_aleatoric(self, eval_data):
def eval_aleatoric_uncertainty(self, eval_data):
"""
Evaluate the aleatoric uncertainty of inference results, which is also called data uncertainty.
@ -172,7 +200,7 @@ class UncertaintyEvaluation:
Returns:
numpy.dtype, the aleatoric uncertainty of inference results of data samples.
"""
uncertainty = self._eval_epistemic_uncertainty(eval_data)
uncertainty = self._eval_aleatoric_uncertainty(eval_data)
return uncertainty
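
How the two public estimates relate to the code above: eval_epistemic_uncertainty runs the dropout-enabled model mc times in training mode and, as is standard for MC dropout, takes the spread of the stacked predictions as the model uncertainty, while eval_aleatoric_uncertainty reads the variance head of the aleatoric model and sums var**2 per sample; both are then min-max normalized by _uncertainty_normalize. A schematic NumPy sketch with illustrative shapes and helper names (the exact reduction applied after np.stack is outside the hunks shown):

import numpy as np

def minmax_normalize(x):
    # mirrors _uncertainty_normalize: scale values to [0, 1] by the data range
    return (x - np.min(x)) / (np.max(x) - np.min(x))

def epistemic_uncertainty(mc_preds):
    # mc_preds: (mc, batch, num_classes) predictions from mc stochastic forward passes;
    # the variance across passes is the usual MC-dropout estimate of model uncertainty
    return minmax_normalize(np.var(mc_preds, axis=0))

def aleatoric_uncertainty(pred_var):
    # pred_var: (batch, num_classes) variance predicted by the aleatoric model,
    # reduced per sample as in _eval_aleatoric_uncertainty (sum of var**2)
    return minmax_normalize(np.sum(pred_var**2, axis=1))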

@ -107,7 +107,7 @@ if __name__ == "__main__":
# define the cvae model
cvae = ConditionalVAE(encoder, decoder, hidden_size=400, latent_size=20, num_classes=10)
# define the loss function
net_loss = ELBO(latent_prior='Normal', output_dis='Normal')
net_loss = ELBO(latent_prior='Normal', output_prior='Normal')
# define the optimizer
optimizer = nn.Adam(params=cvae.trainable_params(), learning_rate=0.001)
# define the training dataset

@ -95,7 +95,7 @@ if __name__ == "__main__":
# define the vae model
vae = VAE(encoder, decoder, hidden_size=400, latent_size=20)
# define the loss function
net_loss = ELBO(latent_prior='Normal', output_dis='Normal')
net_loss = ELBO(latent_prior='Normal', output_prior='Normal')
# define the optimizer
optimizer = nn.Adam(params=vae.trainable_params(), learning_rate=0.001)
# define the training dataset

@ -125,9 +125,11 @@ if __name__ == '__main__':
train_dataset=ds_train,
task_type='classification',
num_classes=10,
epochs=5,
uncertainty_model_path=None)
epochs=1,
epi_uncer_model_path=None,
ale_uncer_model_path=None,
save_model=False)
for eval_data in ds_eval.create_dict_iterator():
eval_data = Tensor(eval_data['image'], mstype.float32)
epistemic_uncertainty = evaluation.eval_epistemic(eval_data)
aleatoric_uncertainty = evaluation.eval_aleatoric(eval_data)
epistemic_uncertainty = evaluation.eval_epistemic_uncertainty(eval_data)
aleatoric_uncertainty = evaluation.eval_aleatoric_uncertainty(eval_data)
