From d643c228d751e96379f3fe4f0999c079eaca189c Mon Sep 17 00:00:00 2001
From: zhangxinfeng3 <zhangxinfeng3@huawei.com>
Date: Mon, 17 Aug 2020 18:44:53 +0800
Subject: [PATCH] update variational interfaces and uncertainty toolbox

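Add argument validation to the variational interfaces: ConditionalVAE
and VAE now check that the encoder and decoder are Cell and that
latent_size does not exceed hidden_size; SVI checks that the loss
function is an ELBO instance, the optimizer is a Cell, and epochs is a
positive integer.

UncertaintyEvaluation now takes epi_train_dataset and ale_train_dataset
instead of a single train_dataset, so the epistemic and aleatoric
uncertainty models are trained on their own iterators, and num_classes
is only required for classification. A minimal call with the new
signature (network and the dataset iterators are the ones built in the
updated test_uncertainty.py):

    evaluation = UncertaintyEvaluation(model=network,
                                       epi_train_dataset=epi_ds_train,
                                       ale_train_dataset=ale_ds_train,
                                       task_type='classification',
                                       num_classes=10,
                                       epochs=1)

The VAE-GAN test is reworked for the current VAE outputs
(recon_x, x, mu, std), samples the prior with C.normal, and passes
output_prior to ELBO; the SVI and VAE-GAN st scripts are converted from
__main__ blocks to test functions.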
---
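Because ConditionalVAE.construct takes both x and y, the plain
nn.WithLossCell does not fit and callers rewrite it so the label also
reaches the backbone; a minimal sketch, taken from the updated
test_gpu_svi_cvae.py:

    class CVAEWithLossCell(nn.WithLossCell):
        def construct(self, data, label):
            out = self._backbone(data, label)
            return self._loss_fn(out, label)
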
 mindspore/nn/probability/dpn/vae/cvae.py      |  7 ++
 mindspore/nn/probability/dpn/vae/vae.py       |  4 +
 .../nn/probability/infer/variational/svi.py   |  9 ++
 .../toolbox/uncertainty_evaluation.py         | 87 ++++++++++---------
 tests/st/probability/test_gpu_svi_cvae.py     | 19 ++--
 tests/st/probability/test_gpu_svi_vae.py      |  7 +-
 tests/st/probability/test_gpu_vae_gan.py      | 18 ++--
 tests/st/probability/test_uncertainty.py      |  6 +-
 8 files changed, 95 insertions(+), 62 deletions(-)

diff --git a/mindspore/nn/probability/dpn/vae/cvae.py b/mindspore/nn/probability/dpn/vae/cvae.py
index 81ed36c610..ddb6279a17 100644
--- a/mindspore/nn/probability/dpn/vae/cvae.py
+++ b/mindspore/nn/probability/dpn/vae/cvae.py
@@ -52,8 +52,12 @@ class ConditionalVAE(Cell):
         super(ConditionalVAE, self).__init__()
         self.encoder = encoder
         self.decoder = decoder
+        if (not isinstance(encoder, Cell)) or (not isinstance(decoder, Cell)):
+            raise TypeError('The encoder and decoder should be Cell type.')
         self.hidden_size = check_int_positive(hidden_size)
         self.latent_size = check_int_positive(latent_size)
+        if hidden_size < latent_size:
+            raise ValueError('The latent_size should be less than or equal to the hidden_size.')
         self.num_classes = check_int_positive(num_classes)
         self.normal = C.normal
         self.exp = P.Exp()
@@ -78,6 +82,9 @@ class ConditionalVAE(Cell):
         return recon_x
 
     def construct(self, x, y):
+        """
+        The inputs are x and y, so the WithLossCell class needs to be rewritten when using the CVAE interface.
+        """
         mu, log_var = self._encode(x, y)
         std = self.exp(0.5 * log_var)
         z = self.normal(self.shape(mu), mu, std, seed=0)
diff --git a/mindspore/nn/probability/dpn/vae/vae.py b/mindspore/nn/probability/dpn/vae/vae.py
index 9e47b5d14d..3137d6a4e1 100644
--- a/mindspore/nn/probability/dpn/vae/vae.py
+++ b/mindspore/nn/probability/dpn/vae/vae.py
@@ -49,8 +49,12 @@ class VAE(Cell):
         super(VAE, self).__init__()
         self.encoder = encoder
         self.decoder = decoder
+        if (not isinstance(encoder, Cell)) or (not isinstance(decoder, Cell)):
+            raise TypeError('The encoder and decoder should be Cell type.')
         self.hidden_size = check_int_positive(hidden_size)
         self.latent_size = check_int_positive(latent_size)
+        if hidden_size < latent_size:
+            raise ValueError('The latent_size should be less than or equal to the hidden_size.')
         self.normal = C.normal
         self.exp = P.Exp()
         self.reshape = P.Reshape()
diff --git a/mindspore/nn/probability/infer/variational/svi.py b/mindspore/nn/probability/infer/variational/svi.py
index 8aca1221ac..f9c1b96f21 100644
--- a/mindspore/nn/probability/infer/variational/svi.py
+++ b/mindspore/nn/probability/infer/variational/svi.py
@@ -15,7 +15,10 @@
 """Stochastic Variational Inference(SVI)."""
 import mindspore.common.dtype as mstype
 from mindspore.common.tensor import Tensor
+from mindspore._checkparam import check_int_positive
+from ....cell import Cell
 from ....wrap.cell_wrapper import TrainOneStepCell
+from .elbo import ELBO
 
 
 class SVI:
@@ -35,7 +38,12 @@ class SVI:
 
     def __init__(self, net_with_loss, optimizer):
         self.net_with_loss = net_with_loss
+        self.loss_fn = getattr(net_with_loss, '_loss_fn')
+        if not isinstance(self.loss_fn, ELBO):
+            raise TypeError('The loss function for variational inference should be ELBO.')
         self.optimizer = optimizer
+        if not isinstance(optimizer, Cell):
+            raise TypeError('The optimizer should be Cell type.')
         self._loss = 0.0
 
     def run(self, train_dataset, epochs=10):
@@ -49,6 +57,7 @@ class SVI:
         Outputs:
             Cell, the trained probability network.
         """
+        epochs = check_int_positive(epochs)
         train_net = TrainOneStepCell(self.net_with_loss, self.optimizer)
         train_net.set_train()
         for _ in range(1, epochs+1):
diff --git a/mindspore/nn/probability/toolbox/uncertainty_evaluation.py b/mindspore/nn/probability/toolbox/uncertainty_evaluation.py
index a61b19cd13..4467eb9b2a 100644
--- a/mindspore/nn/probability/toolbox/uncertainty_evaluation.py
+++ b/mindspore/nn/probability/toolbox/uncertainty_evaluation.py
@@ -15,7 +15,7 @@
 """Toolbox for Uncertainty Evaluation."""
 import numpy as np
 
-from mindspore._checkparam import check_int_positive
+from mindspore._checkparam import check_int_positive, check_bool
 from mindspore.ops import composite as C
 from mindspore.ops import operations as P
 from mindspore.train import Model
@@ -36,7 +36,8 @@ class UncertaintyEvaluation:
 
     Args:
         model (Cell): The model for uncertainty evaluation.
-        train_dataset (Dataset): A dataset iterator.
+        epi_train_dataset (Dataset): A dataset iterator used to train the model for epistemic uncertainty evaluation.
+        ale_train_dataset (Dataset): A dataset iterator used to train the model for aleatoric uncertainty evaluation.
         task_type (str): Option for the task types of model
             - regression: A regression model.
             - classification: A classification model.
@@ -55,9 +56,11 @@ class UncertaintyEvaluation:
         >>> network = LeNet()
         >>> param_dict = load_checkpoint('checkpoint_lenet.ckpt')
         >>> load_param_into_net(network, param_dict)
-        >>> ds_train = create_dataset('workspace/mnist/train')
+        >>> epi_ds_train = create_dataset('workspace/mnist/train')
+        >>> ale_ds_train = create_dataset('workspace/mnist/train')
         >>> evaluation = UncertaintyEvaluation(model=network,
-        >>>                                    train_dataset=ds_train,
+        >>>                                    epi_train_dataset=epi_ds_train,
+        >>>                                    ale_train_dataset=ale_ds_train,
         >>>                                    task_type='classification',
         >>>                                    num_classes=10,
         >>>                                    epochs=1,
@@ -68,28 +71,30 @@ class UncertaintyEvaluation:
         >>> aleatoric_uncertainty = evaluation.eval_aleatoric_uncertainty(eval_data)
     """
 
-    def __init__(self, model, train_dataset, task_type, num_classes=None, epochs=1,
+    def __init__(self, model, epi_train_dataset, ale_train_dataset, task_type, num_classes=None, epochs=1,
                  epi_uncer_model_path=None, ale_uncer_model_path=None, save_model=False):
-        self.model = model
-        self.train_dataset = train_dataset
+        self.epi_model = model
+        self.ale_model = model
+        self.epi_train_dataset = epi_train_dataset
+        self.ale_train_dataset = ale_train_dataset
         self.task_type = task_type
-        self.num_classes = check_int_positive(num_classes)
-        self.epochs = epochs
+        self.epochs = check_int_positive(epochs)
         self.epi_uncer_model_path = epi_uncer_model_path
         self.ale_uncer_model_path = ale_uncer_model_path
-        self.save_model = save_model
+        self.save_model = check_bool(save_model)
         self.epi_uncer_model = None
         self.ale_uncer_model = None
         self.concat = P.Concat(axis=0)
         self.sum = P.ReduceSum()
         self.pow = P.Pow()
-        if self.task_type not in ('regression', 'classification'):
+        if not isinstance(model, Cell):
+            raise TypeError('The model should be Cell type.')
+        if task_type not in ('regression', 'classification'):
             raise ValueError('The task should be regression or classification.')
-        if self.task_type == 'classification':
-            if self.num_classes is None:
-                raise ValueError("Classification task needs to input labels.")
-        if self.save_model:
-            if self.epi_uncer_model_path is None or self.ale_uncer_model_path is None:
+        # num_classes is validated for classification and unused for regression
+        self.num_classes = check_int_positive(num_classes) if task_type == 'classification' else num_classes
+        if save_model:
+            if epi_uncer_model_path is None or ale_uncer_model_path is None:
                 raise ValueError("If save_model is True, the epi_uncer_model_path and "
                                  "ale_uncer_model_path should not be None.")
 
@@ -102,7 +107,7 @@ class UncertaintyEvaluation:
         Get the model which can obtain the epistemic uncertainty.
         """
         if self.epi_uncer_model is None:
-            self.epi_uncer_model = EpistemicUncertaintyModel(self.model)
+            self.epi_uncer_model = EpistemicUncertaintyModel(self.epi_model)
             if self.epi_uncer_model.drop_count == 0:
                 if self.task_type == 'classification':
                     net_loss = SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
@@ -117,9 +122,9 @@ class UncertaintyEvaluation:
                     ckpoint_cb = ModelCheckpoint(prefix='checkpoint_epi_uncer_model',
                                                  directory=self.epi_uncer_model_path,
                                                  config=config_ck)
-                    model.train(self.epochs, self.train_dataset, callbacks=[ckpoint_cb, LossMonitor()])
+                    model.train(self.epochs, self.epi_train_dataset, callbacks=[ckpoint_cb, LossMonitor()])
                 elif self.epi_uncer_model_path is None:
-                    model.train(self.epochs, self.train_dataset, callbacks=[LossMonitor()])
+                    model.train(self.epochs, self.epi_train_dataset, callbacks=[LossMonitor()])
                 else:
                     uncer_param_dict = load_checkpoint(self.epi_uncer_model_path)
                     load_param_into_net(self.epi_uncer_model, uncer_param_dict)
@@ -148,7 +153,7 @@ class UncertaintyEvaluation:
         Get the model which can obtain the aleatoric uncertainty.
         """
         if self.ale_uncer_model is None:
-            self.ale_uncer_model = AleatoricUncertaintyModel(self.model, self.num_classes, self.task_type)
+            self.ale_uncer_model = AleatoricUncertaintyModel(self.ale_model, self.num_classes, self.task_type)
             net_loss = AleatoricLoss(self.task_type)
             net_opt = Adam(self.ale_uncer_model.trainable_params())
             if self.task_type == 'classification':
@@ -160,9 +165,9 @@ class UncertaintyEvaluation:
                 ckpoint_cb = ModelCheckpoint(prefix='checkpoint_ale_uncer_model',
                                              directory=self.ale_uncer_model_path,
                                              config=config_ck)
-                model.train(self.epochs, self.train_dataset, callbacks=[ckpoint_cb, LossMonitor()])
+                model.train(self.epochs, self.ale_train_dataset, callbacks=[ckpoint_cb, LossMonitor()])
             elif self.ale_uncer_model_path is None:
-                model.train(self.epochs, self.train_dataset, callbacks=[LossMonitor()])
+                model.train(self.epochs, self.ale_train_dataset, callbacks=[LossMonitor()])
             else:
                 uncer_param_dict = load_checkpoint(self.ale_uncer_model_path)
                 load_param_into_net(self.ale_uncer_model, uncer_param_dict)
@@ -216,31 +221,31 @@ class EpistemicUncertaintyModel(Cell):
     <https://arxiv.org/abs/1506.02142>`.
     """
 
-    def __init__(self, model):
+    def __init__(self, epi_model):
         super(EpistemicUncertaintyModel, self).__init__()
         self.drop_count = 0
-        self.model = self._make_epistemic(model)
+        self.epi_model = self._make_epistemic(epi_model)
 
     def construct(self, x):
-        x = self.model(x)
+        x = self.epi_model(x)
         return x
 
-    def _make_epistemic(self, model, dropout_rate=0.5):
+    def _make_epistemic(self, epi_model, dropout_rate=0.5):
         """
         The dropout rate is set to 0.5 by default.
         """
-        for (name, layer) in model.name_cells().items():
+        for (name, layer) in epi_model.name_cells().items():
             if isinstance(layer, Dropout):
                 self.drop_count += 1
-                return model
-        for (name, layer) in model.name_cells().items():
+                return epi_model
+        for (name, layer) in epi_model.name_cells().items():
             if isinstance(layer, (Conv2d, Dense)):
                 uncertainty_layer = layer
                 uncertainty_name = name
                 drop = Dropout(keep_prob=dropout_rate)
                 bnn_drop = SequentialCell([uncertainty_layer, drop])
-                setattr(model, uncertainty_name, bnn_drop)
-                return model
+                setattr(epi_model, uncertainty_name, bnn_drop)
+                return epi_model
         raise ValueError("The model has not Dense Layer or Convolution Layer, "
                          "it can not evaluate epistemic uncertainty so far.")
 
@@ -254,40 +259,40 @@ class AleatoricUncertaintyModel(Cell):
     <https://arxiv.org/abs/1703.04977>`.
     """
 
-    def __init__(self, model, labels, task):
+    def __init__(self, ale_model, num_classes, task):
         super(AleatoricUncertaintyModel, self).__init__()
         self.task = task
         if task == 'classification':
-            self.model = model
-            self.var_layer = Dense(labels, labels)
+            self.ale_model = ale_model
+            self.var_layer = Dense(num_classes, num_classes)
         else:
-            self.model, self.var_layer, self.pred_layer = self._make_aleatoric(model)
+            self.ale_model, self.var_layer, self.pred_layer = self._make_aleatoric(ale_model)
 
     def construct(self, x):
         if self.task == 'classification':
-            pred = self.model(x)
+            pred = self.ale_model(x)
             var = self.var_layer(pred)
         else:
-            x = self.model(x)
+            x = self.ale_model(x)
             pred = self.pred_layer(x)
             var = self.var_layer(x)
         return pred, var
 
-    def _make_aleatoric(self, model):
+    def _make_aleatoric(self, ale_model):
         """
         In order to add variance into original loss, add var Layer after the original network.
         """
         dense_layer = dense_name = None
-        for (name, layer) in model.name_cells().items():
+        for (name, layer) in ale_model.name_cells().items():
             if isinstance(layer, Dense):
                 dense_layer = layer
                 dense_name = name
         if dense_layer is None:
             raise ValueError("The model has not Dense Layer, "
                              "it can not evaluate aleatoric uncertainty so far.")
-        setattr(model, dense_name, Flatten())
+        setattr(ale_model, dense_name, Flatten())
         var_layer = Dense(dense_layer.in_channels, dense_layer.out_channels)
-        return model, var_layer, dense_layer
+        return ale_model, var_layer, dense_layer
 
 
 class AleatoricLoss(Cell):
diff --git a/tests/st/probability/test_gpu_svi_cvae.py b/tests/st/probability/test_gpu_svi_cvae.py
index 44f6c040fa..aefd27c675 100644
--- a/tests/st/probability/test_gpu_svi_cvae.py
+++ b/tests/st/probability/test_gpu_svi_cvae.py
@@ -60,12 +60,10 @@ class Decoder(nn.Cell):
         return z
 
 
-class WithLossCell(nn.Cell):
-    def __init__(self, backbone, loss_fn):
-        super(WithLossCell, self).__init__(auto_prefix=False)
-        self._backbone = backbone
-        self._loss_fn = loss_fn
-
+class CVAEWithLossCell(nn.WithLossCell):
+    """
+    Rewrite WithLossCell for CVAE so that both data and label are passed to the backbone.
+    """
     def construct(self, data, label):
         out = self._backbone(data, label)
         return self._loss_fn(out, label)
@@ -100,7 +98,7 @@ def create_dataset(data_path, batch_size=32, repeat_size=1,
     return mnist_ds
 
 
-if __name__ == "__main__":
+def test_svi_cvae():
     # define the encoder and decoder
     encoder = Encoder(num_classes=10)
     decoder = Decoder()
@@ -113,11 +111,11 @@ if __name__ == "__main__":
     # define the training dataset
     ds_train = create_dataset(image_path, 128, 1)
     # define the WithLossCell modified
-    net_with_loss = WithLossCell(cvae, net_loss)
+    net_with_loss = CVAEWithLossCell(cvae, net_loss)
     # define the variational inference
     vi = SVI(net_with_loss=net_with_loss, optimizer=optimizer)
     # run the vi to return the trained network.
-    cvae = vi.run(train_dataset=ds_train, epochs=10)
+    cvae = vi.run(train_dataset=ds_train, epochs=5)
     # get the trained loss
     trained_loss = vi.get_train_loss()
     # test function: generate_sample
@@ -128,3 +126,6 @@ if __name__ == "__main__":
         sample_x = Tensor(sample['image'], dtype=mstype.float32)
         sample_y = Tensor(sample['label'], dtype=mstype.int32)
         reconstructed_sample = cvae.reconstruct_sample(sample_x, sample_y)
+    print('The loss of the trained network is ', trained_loss)
+    print('The shape of the generated sample is ', generated_sample.shape)
+    print('The shape of the reconstructed sample is ', reconstructed_sample.shape)
diff --git a/tests/st/probability/test_gpu_svi_vae.py b/tests/st/probability/test_gpu_svi_vae.py
index a175a4ae4c..6e1b23ee1b 100644
--- a/tests/st/probability/test_gpu_svi_vae.py
+++ b/tests/st/probability/test_gpu_svi_vae.py
@@ -88,7 +88,7 @@ def create_dataset(data_path, batch_size=32, repeat_size=1,
     return mnist_ds
 
 
-if __name__ == "__main__":
+def test_svi_vae():
     # define the encoder and decoder
     encoder = Encoder()
     decoder = Decoder()
@@ -104,7 +104,7 @@ if __name__ == "__main__":
     # define the variational inference
     vi = SVI(net_with_loss=net_with_loss, optimizer=optimizer)
     # run the vi to return the trained network.
-    vae = vi.run(train_dataset=ds_train, epochs=10)
+    vae = vi.run(train_dataset=ds_train, epochs=5)
     # get the trained loss
     trained_loss = vi.get_train_loss()
     # test function: generate_sample
@@ -113,3 +113,6 @@ if __name__ == "__main__":
     for sample in ds_train.create_dict_iterator():
         sample_x = Tensor(sample['image'], dtype=mstype.float32)
         reconstructed_sample = vae.reconstruct_sample(sample_x)
+    print('The loss of the trained network is ', trained_loss)
+    print('The shape of the generated sample is ', generated_sample.shape)
+    print('The shape of the reconstructed sample is ', reconstructed_sample.shape)
diff --git a/tests/st/probability/test_gpu_vae_gan.py b/tests/st/probability/test_gpu_vae_gan.py
index b4f62d10e8..adf0927569 100644
--- a/tests/st/probability/test_gpu_vae_gan.py
+++ b/tests/st/probability/test_gpu_vae_gan.py
@@ -22,6 +22,7 @@ import mindspore.dataset.transforms.vision.c_transforms as CV
 import mindspore.nn as nn
 from mindspore import context
 from mindspore.ops import operations as P
+from mindspore.ops import composite as C
 from mindspore.nn.probability.dpn import VAE
 from mindspore.nn.probability.infer import ELBO, SVI
 
@@ -93,17 +94,18 @@ class VaeGan(nn.Cell):
         self.dense = nn.Dense(20, 400)
         self.vae = VAE(self.E, self.G, 400, 20)
         self.shape = P.Shape()
+        self.normal = C.normal
         self.to_tensor = P.ScalarToArray()
 
     def construct(self, x):
-        recon_x, x, mu, std, z, prior = self.vae(x)
-        z_p = prior('sample', self.shape(mu), self.to_tensor(0.0), self.to_tensor(1.0))
+        recon_x, x, mu, std = self.vae(x)
+        z_p = self.normal(self.shape(mu), self.to_tensor(0.0), self.to_tensor(1.0), seed=0)
         z_p = self.dense(z_p)
         x_p = self.G(z_p)
         ld_real = self.D(x)
         ld_fake = self.D(recon_x)
         ld_p = self.D(x_p)
-        return ld_real, ld_fake, ld_p, recon_x, x, mu, std, z, prior
+        return ld_real, ld_fake, ld_p, recon_x, x, mu, std
 
 
 class VaeGanLoss(nn.Cell):
@@ -111,13 +113,13 @@ class VaeGanLoss(nn.Cell):
         super(VaeGanLoss, self).__init__()
         self.zeros = P.ZerosLike()
         self.mse = nn.MSELoss(reduction='sum')
-        self.elbo = ELBO(latent_prior='Normal', output_dis='Normal')
+        self.elbo = ELBO(latent_prior='Normal', output_prior='Normal')
 
     def construct(self, data, label):
-        ld_real, ld_fake, ld_p, recon_x, x, mean, std, z, prior = data
+        ld_real, ld_fake, ld_p, recon_x, x, mean, std = data
         y_real = self.zeros(ld_real) + 1
         y_fake = self.zeros(ld_fake)
-        elbo_data = (recon_x, x, mean, std, z, prior)
+        elbo_data = (recon_x, x, mean, std)
         loss_D = self.mse(ld_real, y_real)
         loss_GD = self.mse(ld_p, y_fake)
         loss_G = self.mse(ld_fake, y_real)
@@ -154,11 +156,11 @@ def create_dataset(data_path, batch_size=32, repeat_size=1,
     return mnist_ds
 
 
-if __name__ == "__main__":
+def test_vae_gan():
     vae_gan = VaeGan()
     net_loss = VaeGanLoss()
     optimizer = nn.Adam(params=vae_gan.trainable_params(), learning_rate=0.001)
     ds_train = create_dataset(image_path, 128, 1)
     net_with_loss = nn.WithLossCell(vae_gan, net_loss)
     vi = SVI(net_with_loss=net_with_loss, optimizer=optimizer)
-    vae_gan = vi.run(train_dataset=ds_train, epochs=10)
+    vae_gan = vi.run(train_dataset=ds_train, epochs=5)
diff --git a/tests/st/probability/test_uncertainty.py b/tests/st/probability/test_uncertainty.py
index 92850141eb..ed3f45883f 100644
--- a/tests/st/probability/test_uncertainty.py
+++ b/tests/st/probability/test_uncertainty.py
@@ -119,10 +119,12 @@ if __name__ == '__main__':
     param_dict = load_checkpoint('checkpoint_lenet.ckpt')
     load_param_into_net(network, param_dict)
     # get train and eval dataset
-    ds_train = create_dataset('workspace/mnist/train')
+    epi_ds_train = create_dataset('workspace/mnist/train')
+    ale_ds_train = create_dataset('workspace/mnist/train')
     ds_eval = create_dataset('workspace/mnist/test')
     evaluation = UncertaintyEvaluation(model=network,
-                                       train_dataset=ds_train,
+                                       epi_train_dataset=epi_ds_train,
+                                       ale_train_dataset=ale_ds_train,
                                        task_type='classification',
                                        num_classes=10,
                                        epochs=1,