!4194 Add infer and dpn to nn/probability

Merge pull request !4194 from zhangxinfeng3/master
pull/4194/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit 1b63c76c61

@ -20,3 +20,5 @@ The high-level components used to construct the probabilistic network.
from . import bijector
from . import distribution
from . import infer
from . import dpn

@ -0,0 +1,24 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Deep Probability Network(dpn).
Deep probability network such as BNN and VAE network.
"""
# Import every public name of the vae subpackage into this namespace.
from .vae import *
__all__ = []
# The relative import above also binds the `vae` submodule here, so its __all__ can be re-exported as is.
__all__.extend(vae.__all__)

@ -0,0 +1,25 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Variational auto-encoder (VAE).
The VAE interface, which allows constructing a probability model in the same way as a DNN model.
"""
from .vae import VAE
from .cvae import ConditionalVAE
# Names exported by `from <this package> import *`: the two auto-encoder classes imported above.
__all__ = ['VAE',
'ConditionalVAE']

@ -0,0 +1,127 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Conditional Variational auto-encoder (CVAE)."""
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore._checkparam import check_int_positive
from ...distribution.normal import Normal
from ....cell import Cell
from ....layer.basic import Dense, OneHot
class ConditionalVAE(Cell):
    r"""
    Conditional Variational auto-encoder (CVAE).

    The difference with VAE is that CVAE uses labels information.
    see more details in `Learning Structured Output Representation using Deep Conditional Generative Models
    <http://papers.nips.cc/paper/5775-learning-structured-output-representation-using-deep-conditional-generative-models>`_.

    Note:
        When define the encoder and decoder, the shape of the encoder's output tensor and decoder's input tensor
        should be :math:`(N, hidden_size)`.
        The latent_size should be less than or equal to the hidden_size.

    Args:
        encoder(Cell): The DNN model defined as encoder. It is called as `encoder(x, y)`.
        decoder(Cell): The DNN model defined as decoder.
        hidden_size(int): The size of encoder's output tensor.
        latent_size(int): The size of the latent space.
        num_classes(int): The number of classes.

    Inputs:
        - **input_x** (Tensor) - the same shape as the input of encoder.
        - **input_y** (Tensor) - the tensor of the target data, the shape is :math:`(N,)`. The labels are
          one-hot encoded and concatenated with the latent code along axis 1, so they must be 1-D.

    Outputs:
        - **output** (tuple) - (recon_x(Tensor), x(Tensor), mu(Tensor), std(Tensor), z(Tensor), prior(Cell)).
    """

    def __init__(self, encoder, decoder, hidden_size, latent_size, num_classes):
        super(ConditionalVAE, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.hidden_size = check_int_positive(hidden_size)
        self.latent_size = check_int_positive(latent_size)
        self.num_classes = check_int_positive(num_classes)
        self.normal = C.normal
        self.exp = P.Exp()
        self.reshape = P.Reshape()
        self.concat = P.Concat(axis=1)
        self.to_tensor = P.ScalarToArray()
        self.normal_dis = Normal()
        # Use the validated class count so the one-hot depth always matches dense3's input width.
        self.one_hot = OneHot(depth=self.num_classes)
        # Zero-mean / unit-sd Normal over the latent space, returned from construct as the prior.
        self.standard_normal_dis = Normal([0] * self.latent_size, [1] * self.latent_size)
        # dense1 -> mu, dense2 -> log-variance, dense3 maps [z, one_hot(y)] back to the decoder's input width.
        self.dense1 = Dense(self.hidden_size, self.latent_size)
        self.dense2 = Dense(self.hidden_size, self.latent_size)
        self.dense3 = Dense(self.latent_size + self.num_classes, self.hidden_size)

    def _encode(self, x, y):
        """Run the encoder on (x, y) and project its output to the posterior mean and log-variance."""
        en_x = self.encoder(x, y)
        mu = self.dense1(en_x)
        log_var = self.dense2(en_x)
        return mu, log_var

    def _decode(self, z):
        """Project the label-conditioned latent code to the hidden size and run the decoder."""
        z = self.dense3(z)
        recon_x = self.decoder(z)
        return recon_x

    def construct(self, x, y):
        mu, log_var = self._encode(x, y)
        # log_var is the log-variance, so exp(0.5 * log_var) is the standard deviation.
        std = self.exp(0.5 * log_var)
        z = self.normal_dis('sample', mean=mu, sd=std)
        y = self.one_hot(y)
        # Condition the decoder on the label by appending its one-hot encoding to the latent code.
        z_c = self.concat((z, y))
        recon_x = self._decode(z_c)
        return recon_x, x, mu, std, z, self.standard_normal_dis

    def generate_sample(self, sample_y, generate_nums=None, shape=None):
        """
        Randomly sample from latent space to generate sample.

        Args:
            sample_y (Tensor): Define the label of sample, int tensor. Its length has to match
                `generate_nums` because it is concatenated with the sampled latent codes.
            generate_nums (int): The number of samples to generate. Required, must be a positive int.
            shape(tuple): The shape of sample, it should be :math:`(generate_nums, C, H, W)`. Required.

        Returns:
            Tensor, the generated sample.

        Raises:
            ValueError: If `generate_nums` or `shape` is not given.
        """
        # Both arguments default to None only for signature compatibility; fail early with a clear
        # message instead of letting None reach the sampling / reshape operators.
        if generate_nums is None or shape is None:
            raise ValueError("The generate_nums and shape of generate_sample should both be specified.")
        generate_nums = check_int_positive(generate_nums)
        sample_z = self.normal((generate_nums, self.latent_size), self.to_tensor(0.0), self.to_tensor(1.0), seed=0)
        sample_y = self.one_hot(sample_y)
        sample_c = self.concat((sample_z, sample_y))
        sample = self._decode(sample_c)
        sample = self.reshape(sample, shape)
        return sample

    def reconstruct_sample(self, x, y):
        """
        Reconstruct sample from original data.

        Args:
            x (Tensor): The input tensor to be reconstructed.
            y (Tensor): The label of the input tensor.

        Returns:
            Tensor, the reconstructed sample.
        """
        mu, log_var = self._encode(x, y)
        std = self.exp(0.5 * log_var)
        z = self.normal(mu.shape, mu, std, seed=0)
        y = self.one_hot(y)
        z_c = self.concat((z, y))
        recon_x = self._decode(z_c)
        return recon_x

@ -0,0 +1,113 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Variational auto-encoder (VAE)"""
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore._checkparam import check_int_positive
from ...distribution.normal import Normal
from ....cell import Cell
from ....layer.basic import Dense
class VAE(Cell):
    r"""
    Variational auto-encoder (VAE).

    A generative model: the latent variable `Z` is drawn from the prior and a decoder reconstructs `X` from it.
    For details see `Auto-Encoding Variational Bayes <https://arxiv.org/abs/1312.6114>`_.

    Note:
        The encoder's output tensor and the decoder's input tensor should both have
        shape :math:`(N, hidden_size)`.
        The latent_size should be less than or equal to the hidden_size.

    Args:
        encoder(Cell): The DNN model defined as encoder.
        decoder(Cell): The DNN model defined as decoder.
        hidden_size(int): The size of encoder's output tensor.
        latent_size(int): The size of the latent space.

    Inputs:
        - **input** (Tensor) - the same shape as the input of encoder.

    Outputs:
        - **output** (Tuple) - (recon_x(Tensor), x(Tensor), mu(Tensor), std(Tensor), z(Tensor), prior(Cell)).
    """

    def __init__(self, encoder, decoder, hidden_size, latent_size):
        super(VAE, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.hidden_size = check_int_positive(hidden_size)
        self.latent_size = check_int_positive(latent_size)
        # Operators and distributions used for sampling.
        self.normal = C.normal
        self.exp = P.Exp()
        self.reshape = P.Reshape()
        self.to_tensor = P.ScalarToArray()
        self.normal_dis = Normal()
        prior_mean = [0] * self.latent_size
        prior_sd = [1] * self.latent_size
        self.standard_normal_dis = Normal(prior_mean, prior_sd)
        # dense1 -> posterior mean, dense2 -> posterior log-variance, dense3 -> latent back to hidden size.
        self.dense1 = Dense(self.hidden_size, self.latent_size)
        self.dense2 = Dense(self.hidden_size, self.latent_size)
        self.dense3 = Dense(self.latent_size, self.hidden_size)

    def _encode(self, x):
        """Return the posterior mean and log-variance computed from the encoder output."""
        hidden = self.encoder(x)
        return self.dense1(hidden), self.dense2(hidden)

    def _decode(self, z):
        """Lift the latent code to the hidden size and pass it through the decoder."""
        return self.decoder(self.dense3(z))

    def construct(self, x):
        mu, log_var = self._encode(x)
        # Standard deviation from log-variance.
        std = self.exp(0.5 * log_var)
        z = self.normal_dis('sample', mean=mu, sd=std)
        return self._decode(z), x, mu, std, z, self.standard_normal_dis

    def generate_sample(self, generate_nums, shape):
        """
        Draw latent codes from a standard normal and decode them into samples.

        Args:
            generate_nums (int): The number of samples to generate.
            shape(tuple): The shape of sample, it should be :math:`(generate_nums, C, H, W)`.

        Returns:
            Tensor, the generated sample.
        """
        latent_shape = (generate_nums, self.latent_size)
        sample_z = self.normal(latent_shape, self.to_tensor(0.0), self.to_tensor(1.0), seed=0)
        decoded = self._decode(sample_z)
        return self.reshape(decoded, shape)

    def reconstruct_sample(self, x):
        """
        Encode the input, sample a latent code from the resulting posterior and decode it.

        Args:
            x (Tensor): The input tensor to be reconstructed.

        Returns:
            Tensor, the reconstructed sample.
        """
        mu, log_var = self._encode(x)
        std = self.exp(0.5 * log_var)
        z = self.normal(mu.shape, mu, std, seed=0)
        return self._decode(z)

@ -0,0 +1,22 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Infer algorithms in Probabilistic Programming.
"""
# Import every public name of the variational subpackage into this namespace.
from .variational import *
__all__ = []
# The relative import above also binds the `variational` submodule here, so its __all__ can be re-exported.
__all__.extend(variational.__all__)

@ -0,0 +1,26 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
SVI and ELBO.
The SVI interface is for variational inference.
The ELBO interface is used as the loss function during model training.
"""
from .svi import SVI
from .elbo import ELBO
# Names exported by `from <this package> import *`: the two classes imported above.
__all__ = ['SVI',
'ELBO']

@ -0,0 +1,64 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""The Evidence Lower Bound (ELBO)."""
from mindspore.ops import operations as P
from ...distribution.normal import Normal
from ....cell import Cell
from ....loss.loss import MSELoss
class ELBO(Cell):
    r"""
    The Evidence Lower Bound (ELBO).

    Variational inference minimizes the Kullback-Leibler (KL) divergence from the variational distribution to
    the posterior distribution. It maximizes the evidence lower bound (ELBO), a lower bound on the logarithm of
    the marginal probability of the observations log p(x). The ELBO is equal to the negative KL divergence up to
    an additive constant.
    see more details in `Variational Inference: A Review for Statisticians <https://arxiv.org/abs/1601.00670>`_.

    Args:
        latent_prior(str): The prior distribution of latent space. Default: Normal.
            - Normal: The prior distribution of latent space is Normal.
        output_prior(str): The distribution of output data. Default: Normal.
            - Normal: If the distribution of output data is Normal, the reconstruct loss is MSELoss.

    Inputs:
        - **input_data** (Tuple) - (recon_x(Tensor), x(Tensor), mu(Tensor), std(Tensor), z(Tensor), prior(Cell)).
        - **target_data** (Tensor) - the target tensor. It is accepted so the cell can be called as
          `loss_fn(out, label)`, but it is not used in the computation.

    Outputs:
        Tensor, loss float tensor.

    Raises:
        ValueError: If `latent_prior` or `output_prior` is not 'Normal'.
    """

    def __init__(self, latent_prior='Normal', output_prior='Normal'):
        super(ELBO, self).__init__()
        self.sum = P.ReduceSum()
        if latent_prior == 'Normal':
            self.posterior = Normal()
        else:
            raise ValueError('The values of latent_prior now only support Normal')
        if output_prior == 'Normal':
            self.recon_loss = MSELoss(reduction='sum')
        else:
            # The message names the actual parameter so the caller can tell which argument was rejected.
            raise ValueError('The values of output_prior now only support Normal')

    def construct(self, data, label):
        # `label` is intentionally unused: everything the loss needs is carried in `data`.
        recon_x, x, mu, std, z, prior = data
        reconstruct_loss = self.recon_loss(x, recon_x)
        # KL term evaluated at the sampled z: (log q(z) - log p(z)) weighted by the posterior density q(z).
        # NOTE(review): this is a single-sample, density-weighted estimate rather than a closed-form KL
        # between two Normals - confirm that this is the intended formulation.
        kl_loss = -(prior('log_prob', z) - self.posterior('log_prob', z, mu, std)) \
                  * self.posterior('prob', z, mu, std)
        elbo = reconstruct_loss + self.sum(kl_loss)
        return elbo

@ -0,0 +1,72 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Stochastic Variational Inference(SVI)."""
import mindspore.common.dtype as mstype
from mindspore.common.tensor import Tensor
from ....wrap.cell_wrapper import TrainOneStepCell
class SVI:
    r"""
    Stochastic Variational Inference(SVI).

    Variational inference turns the inference problem into an optimization problem: a family of distributions
    over the hidden variables is indexed by free parameters, and those parameters are optimized so that the
    distribution gets as close as possible to the posterior of interest.
    For details see `Variational Inference: A Review for Statisticians <https://arxiv.org/abs/1601.00670>`_.

    Args:
        net_with_loss(Cell): Cell with loss function.
        optimizer (Cell): Optimizer for updating the weights.
    """

    def __init__(self, net_with_loss, optimizer):
        self.net_with_loss = net_with_loss
        self.optimizer = optimizer
        # Average loss of the most recently finished epoch; 0.0 until run() completes one.
        self._loss = 0.0

    def run(self, train_dataset, epochs=10):
        """
        Train the probability network and return it.

        Args:
            epochs (int): Total number of iterations on the data. Default: 10.
            train_dataset (Dataset): A training dataset iterator. Each item must provide the
                'image' and 'label' columns.

        Outputs:
            Cell, the trained probability network (`net_with_loss.backbone_network`).
        """
        trainer = TrainOneStepCell(self.net_with_loss, self.optimizer)
        trainer.set_train()
        for _ in range(epochs):
            epoch_loss = 0
            sample_count = 0
            for batch in train_dataset.create_dict_iterator():
                images = Tensor(batch['image'], dtype=mstype.float32)
                labels = Tensor(batch['label'], dtype=mstype.int32)
                sample_count += len(images)
                epoch_loss += trainer(images, labels).asnumpy()
            # Summed batch losses of this epoch divided by the number of samples seen in it.
            self._loss = epoch_loss / sample_count
        return self.net_with_loss.backbone_network

    def get_train_loss(self):
        """
        Returns:
            The per-sample loss of the last trained epoch (0.0 if no epoch has been run).
        """
        return self._loss

@ -0,0 +1,130 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import os
import mindspore.common.dtype as mstype
import mindspore.dataset as ds
import mindspore.dataset.transforms.vision.c_transforms as CV
import mindspore.nn as nn
from mindspore import context, Tensor
from mindspore.ops import operations as P
from mindspore.nn.probability.dpn import ConditionalVAE
from mindspore.nn.probability.infer import ELBO, SVI
context.set_context(mode=context.GRAPH_MODE, save_graphs=False, device_target="GPU")
IMAGE_SHAPE = (-1, 1, 32, 32)
image_path = os.path.join('/home/workspace/mindspore_dataset/mnist', "train")
class Encoder(nn.Cell):
    """Label-conditioned encoder: flattened image and one-hot label -> 400 features."""

    def __init__(self, num_classes):
        super(Encoder, self).__init__()
        # 1024 image features (presumably the flattened 1x32x32 input) plus the one-hot label.
        self.fc1 = nn.Dense(1024 + num_classes, 400)
        self.relu = nn.ReLU()
        self.flatten = nn.Flatten()
        self.concat = P.Concat(axis=1)
        self.one_hot = nn.OneHot(depth=num_classes)

    def construct(self, x, y):
        flat_image = self.flatten(x)
        label_one_hot = self.one_hot(y)
        features = self.concat((flat_image, label_one_hot))
        return self.relu(self.fc1(features))
class Decoder(nn.Cell):
    """Decoder: 400 hidden features -> sigmoid-activated tensor reshaped to IMAGE_SHAPE."""

    def __init__(self):
        super(Decoder, self).__init__()
        self.fc2 = nn.Dense(400, 1024)
        self.sigmoid = nn.Sigmoid()
        self.reshape = P.Reshape()

    def construct(self, z):
        pixels = self.fc2(z)
        image = self.reshape(pixels, IMAGE_SHAPE)
        return self.sigmoid(image)
class WithLossCell(nn.Cell):
    """
    Loss wrapper for the CVAE: unlike a single-input backbone, the backbone here is called with
    both the data and the label.

    Args:
        backbone (Cell): The network to train; called as `backbone(data, label)`.
        loss_fn (Cell): The loss; called as `loss_fn(out, label)`.
    """

    def __init__(self, backbone, loss_fn):
        super(WithLossCell, self).__init__(auto_prefix=False)
        self._backbone = backbone
        self._loss_fn = loss_fn

    def construct(self, data, label):
        out = self._backbone(data, label)
        return self._loss_fn(out, label)

    @property
    def backbone_network(self):
        """
        Return the wrapped backbone network.

        SVI.run reads `net_with_loss.backbone_network` to hand back the trained model, so this
        wrapper has to expose it.
        """
        return self._backbone
def create_dataset(data_path, batch_size=32, repeat_size=1,
                   num_parallel_workers=1):
    """
    Build the MNIST pipeline used for train or test.

    Images are resized to 32x32, rescaled by 1/255 with zero shift and converted from HWC to CHW;
    the result is batched and then repeated.
    """
    dataset = ds.MnistDataset(data_path)

    target_size = (32, 32)
    scale, offset = 1.0 / 255.0, 0.0
    image_ops = (
        CV.Resize(target_size),  # Bilinear mode
        CV.Rescale(scale, offset),
        CV.HWC2CHW(),
    )
    # One map call per operation, applied to the "image" column in the order listed above.
    for image_op in image_ops:
        dataset = dataset.map(input_columns="image", operations=image_op,
                              num_parallel_workers=num_parallel_workers)

    dataset = dataset.batch(batch_size)
    return dataset.repeat(repeat_size)
if __name__ == "__main__":
    # define the encoder and decoder
    encoder = Encoder(num_classes=10)
    decoder = Decoder()
    # define the cvae model
    cvae = ConditionalVAE(encoder, decoder, hidden_size=400, latent_size=20, num_classes=10)
    # define the loss function; the ELBO keyword is `output_prior` (it has no `output_dis` parameter)
    net_loss = ELBO(latent_prior='Normal', output_prior='Normal')
    # define the optimizer
    optimizer = nn.Adam(params=cvae.trainable_params(), learning_rate=0.001)
    # define the training dataset
    ds_train = create_dataset(image_path, 128, 1)
    # wrap the model with the two-input WithLossCell defined in this file
    net_with_loss = WithLossCell(cvae, net_loss)
    # define the variational inference
    vi = SVI(net_with_loss=net_with_loss, optimizer=optimizer)
    # run the vi to return the trained network.
    cvae = vi.run(train_dataset=ds_train, epochs=10)
    # get the trained loss
    trained_loss = vi.get_train_loss()
    # test function: generate_sample, with 64 labels cycling through 0..7
    sample_label = Tensor(list(range(8)) * 8, dtype=mstype.int32)
    generated_sample = cvae.generate_sample(sample_label, 64, IMAGE_SHAPE)
    # test function: reconstruct_sample
    for sample in ds_train.create_dict_iterator():
        sample_x = Tensor(sample['image'], dtype=mstype.float32)
        sample_y = Tensor(sample['label'], dtype=mstype.int32)
        reconstructed_sample = cvae.reconstruct_sample(sample_x, sample_y)

@ -0,0 +1,115 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import os
import mindspore.common.dtype as mstype
import mindspore.dataset as ds
import mindspore.dataset.transforms.vision.c_transforms as CV
import mindspore.nn as nn
from mindspore import context, Tensor
from mindspore.ops import operations as P
from mindspore.nn.probability.dpn import VAE
from mindspore.nn.probability.infer import ELBO, SVI
context.set_context(mode=context.GRAPH_MODE, save_graphs=False, device_target="GPU")
IMAGE_SHAPE = (-1, 1, 32, 32)
image_path = os.path.join('/home/workspace/mindspore_dataset/mnist', "train")
class Encoder(nn.Cell):
    """Encoder: flattened image -> 800 -> 400 features, with ReLU after each dense layer."""

    def __init__(self):
        super(Encoder, self).__init__()
        self.fc1 = nn.Dense(1024, 800)
        self.fc2 = nn.Dense(800, 400)
        self.relu = nn.ReLU()
        self.flatten = nn.Flatten()

    def construct(self, x):
        flat = self.flatten(x)
        hidden = self.relu(self.fc1(flat))
        return self.relu(self.fc2(hidden))
class Decoder(nn.Cell):
    """Decoder: 400 hidden features -> sigmoid-activated tensor reshaped to IMAGE_SHAPE."""

    def __init__(self):
        super(Decoder, self).__init__()
        self.fc1 = nn.Dense(400, 1024)
        self.sigmoid = nn.Sigmoid()
        self.reshape = P.Reshape()

    def construct(self, z):
        pixels = self.fc1(z)
        image = self.reshape(pixels, IMAGE_SHAPE)
        return self.sigmoid(image)
def create_dataset(data_path, batch_size=32, repeat_size=1,
                   num_parallel_workers=1):
    """
    Build the MNIST pipeline used for train or test.

    Images are resized to 32x32, rescaled by 1/255 with zero shift and converted from HWC to CHW;
    the result is batched and then repeated.
    """
    dataset = ds.MnistDataset(data_path)

    target_size = (32, 32)
    scale, offset = 1.0 / 255.0, 0.0
    image_ops = (
        CV.Resize(target_size),  # Bilinear mode
        CV.Rescale(scale, offset),
        CV.HWC2CHW(),
    )
    # One map call per operation, applied to the "image" column in the order listed above.
    for image_op in image_ops:
        dataset = dataset.map(input_columns="image", operations=image_op,
                              num_parallel_workers=num_parallel_workers)

    dataset = dataset.batch(batch_size)
    return dataset.repeat(repeat_size)
if __name__ == "__main__":
    # define the encoder and decoder
    encoder = Encoder()
    decoder = Decoder()
    # define the vae model
    vae = VAE(encoder, decoder, hidden_size=400, latent_size=20)
    # define the loss function; the ELBO keyword is `output_prior` (it has no `output_dis` parameter)
    net_loss = ELBO(latent_prior='Normal', output_prior='Normal')
    # define the optimizer
    optimizer = nn.Adam(params=vae.trainable_params(), learning_rate=0.001)
    # define the training dataset
    ds_train = create_dataset(image_path, 128, 1)
    net_with_loss = nn.WithLossCell(vae, net_loss)
    # define the variational inference
    vi = SVI(net_with_loss=net_with_loss, optimizer=optimizer)
    # run the vi to return the trained network.
    vae = vi.run(train_dataset=ds_train, epochs=10)
    # get the trained loss
    trained_loss = vi.get_train_loss()
    # test function: generate_sample
    generated_sample = vae.generate_sample(64, IMAGE_SHAPE)
    # test function: reconstruct_sample
    for sample in ds_train.create_dict_iterator():
        sample_x = Tensor(sample['image'], dtype=mstype.float32)
        reconstructed_sample = vae.reconstruct_sample(sample_x)

@ -0,0 +1,164 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
The VAE interface can be called to construct VAE-GAN network.
"""
import os
import mindspore.dataset as ds
import mindspore.dataset.transforms.vision.c_transforms as CV
import mindspore.nn as nn
from mindspore import context
from mindspore.ops import operations as P
from mindspore.nn.probability.dpn import VAE
from mindspore.nn.probability.infer import ELBO, SVI
context.set_context(mode=context.GRAPH_MODE, save_graphs=False, device_target="GPU")
IMAGE_SHAPE = (-1, 1, 32, 32)
image_path = os.path.join('/home/workspace/mindspore_dataset/mnist', "train")
class Encoder(nn.Cell):
    """Encoder: flattened image -> 400 features through one dense layer and a ReLU."""

    def __init__(self):
        super(Encoder, self).__init__()
        self.fc1 = nn.Dense(1024, 400)
        self.relu = nn.ReLU()
        self.flatten = nn.Flatten()

    def construct(self, x):
        flat = self.flatten(x)
        return self.relu(self.fc1(flat))
class Decoder(nn.Cell):
    """Decoder (generator): 400 hidden features -> sigmoid-activated tensor reshaped to IMAGE_SHAPE."""

    def __init__(self):
        super(Decoder, self).__init__()
        self.fc1 = nn.Dense(400, 1024)
        # NOTE: relu is created here but construct does not use it.
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
        self.reshape = P.Reshape()

    def construct(self, z):
        pixels = self.fc1(z)
        image = self.reshape(pixels, IMAGE_SHAPE)
        return self.sigmoid(image)
class Discriminator(nn.Cell):
    """
    The Discriminator of the GAN network.

    Three dense layers (1024 -> 400 -> 720 -> 1024) with ReLU between them and a sigmoid on the output.
    """

    def __init__(self):
        super(Discriminator, self).__init__()
        self.fc1 = nn.Dense(1024, 400)
        self.fc2 = nn.Dense(400, 720)
        self.fc3 = nn.Dense(720, 1024)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
        self.flatten = nn.Flatten()

    def construct(self, x):
        flat = self.flatten(x)
        hidden = self.relu(self.fc1(flat))
        hidden = self.relu(self.fc2(hidden))
        return self.sigmoid(self.fc3(hidden))
class VaeGan(nn.Cell):
    """
    VAE-GAN network: a VAE built from the encoder E and generator G, plus a discriminator D.

    construct returns the discriminator outputs for the real input, for the VAE reconstruction and
    for a sample generated from the prior, followed by the VAE outputs.
    """

    def __init__(self):
        super(VaeGan, self).__init__()
        self.E = Encoder()
        self.G = Decoder()
        self.D = Discriminator()
        # Maps a 20-dim latent sample to the 400-dim input expected by the generator.
        self.dense = nn.Dense(20, 400)
        self.vae = VAE(self.E, self.G, 400, 20)
        self.shape = P.Shape()
        self.to_tensor = P.ScalarToArray()

    def construct(self, x):
        recon_x, x, mu, std, z, prior = self.vae(x)
        # Draw a latent code from the prior with the same shape as mu and decode it with G.
        prior_z = prior('sample', self.shape(mu), self.to_tensor(0.0), self.to_tensor(1.0))
        x_p = self.G(self.dense(prior_z))
        ld_real = self.D(x)
        ld_fake = self.D(recon_x)
        ld_p = self.D(x_p)
        return ld_real, ld_fake, ld_p, recon_x, x, mu, std, z, prior
class VaeGanLoss(nn.Cell):
    """
    Loss of the VAE-GAN network: three summed-MSE discriminator terms plus the ELBO of the VAE.

    Inputs:
        - **data** (tuple) - the output of VaeGan.construct:
          (ld_real, ld_fake, ld_p, recon_x, x, mean, std, z, prior).
        - **label** (Tensor) - forwarded to the ELBO cell.

    Outputs:
        Tensor, the sum of the four loss terms.
    """

    def __init__(self):
        super(VaeGanLoss, self).__init__()
        self.zeros = P.ZerosLike()
        self.mse = nn.MSELoss(reduction='sum')
        # The ELBO keyword is `output_prior`; it has no `output_dis` parameter.
        self.elbo = ELBO(latent_prior='Normal', output_prior='Normal')

    def construct(self, data, label):
        ld_real, ld_fake, ld_p, recon_x, x, mean, std, z, prior = data
        # Targets: all ones for "real", all zeros for "fake".
        y_real = self.zeros(ld_real) + 1
        y_fake = self.zeros(ld_fake)
        elbo_data = (recon_x, x, mean, std, z, prior)
        # Discriminator output on the real input vs. ones.
        loss_D = self.mse(ld_real, y_real)
        # Discriminator output on the prior-generated sample vs. zeros.
        loss_GD = self.mse(ld_p, y_fake)
        # Discriminator output on the reconstruction vs. ones.
        loss_G = self.mse(ld_fake, y_real)
        elbo_loss = self.elbo(elbo_data, label)
        return loss_D + loss_G + loss_GD + elbo_loss
def create_dataset(data_path, batch_size=32, repeat_size=1,
                   num_parallel_workers=1):
    """
    Build the MNIST pipeline used for train or test.

    Images are resized to 32x32, rescaled by 1/255 with zero shift and converted from HWC to CHW;
    the result is batched and then repeated.
    """
    dataset = ds.MnistDataset(data_path)

    target_size = (32, 32)
    scale, offset = 1.0 / 255.0, 0.0
    image_ops = (
        CV.Resize(target_size),  # Bilinear mode
        CV.Rescale(scale, offset),
        CV.HWC2CHW(),
    )
    # One map call per operation, applied to the "image" column in the order listed above.
    for image_op in image_ops:
        dataset = dataset.map(input_columns="image", operations=image_op,
                              num_parallel_workers=num_parallel_workers)

    dataset = dataset.batch(batch_size)
    return dataset.repeat(repeat_size)
if __name__ == "__main__":
    # build the VAE-GAN network and its loss
    vae_gan = VaeGan()
    net_loss = VaeGanLoss()
    # optimizer over all trainable parameters of the network
    optimizer = nn.Adam(params=vae_gan.trainable_params(), learning_rate=0.001)
    # training data: batches of 128, a single repeat
    ds_train = create_dataset(image_path, batch_size=128, repeat_size=1)
    net_with_loss = nn.WithLossCell(vae_gan, net_loss)
    # train through the SVI interface, which returns the trained backbone
    vi = SVI(net_with_loss=net_with_loss, optimizer=optimizer)
    vae_gan = vi.run(train_dataset=ds_train, epochs=10)

@ -0,0 +1,57 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test VAE interface """
import numpy as np
import mindspore.common.dtype as mstype
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common.api import _executor
from mindspore.nn.probability.dpn import VAE
class Encoder(nn.Cell):
    """Minimal encoder for the compile test: Dense(6, 3) followed by ReLU."""

    def __init__(self):
        super(Encoder, self).__init__()
        self.fc1 = nn.Dense(6, 3)
        self.relu = nn.ReLU()

    def construct(self, x):
        return self.relu(self.fc1(x))
class Decoder(nn.Cell):
    """Minimal decoder for the compile test: Dense(3, 6) followed by a sigmoid."""

    def __init__(self):
        super(Decoder, self).__init__()
        self.fc1 = nn.Dense(3, 6)
        self.sigmoid = nn.Sigmoid()

    def construct(self, z):
        return self.sigmoid(self.fc1(z))
def test_vae():
    """
    Check that a VAE built from the small DNN encoder/decoder above compiles for a (32, 6) float32 input.
    """
    net = VAE(Encoder(), Decoder(), hidden_size=3, latent_size=2)
    random_batch = np.random.rand(32, 6)
    input_data = Tensor(random_batch, dtype=mstype.float32)
    _executor.compile(net, input_data)
Loading…
Cancel
Save