mindspore/tests/st/probability/zhusuan/vae/vae_mnist.py

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" VAE """
import os
import numpy as np
from utils import create_dataset, save_img
import mindspore.nn as nn
from mindspore import context
from mindspore import Tensor
from mindspore.train import Model
from mindspore.train.callback import LossMonitor
from mindspore.ops import operations as P
from mindspore.common import dtype as mstype
import zhusuan as zs


class ReduceMeanLoss(nn.L1Loss):
    """ Identity loss: the network's output already is the loss. """

    def construct(self, base, target):
        # The ELBO cell returns a scalar objective; pass it through as-is.
        return base
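

# Generative network sketch (inferred from the code below): a fixed prior
# p(z) = N(0, I) over the z_dim latent, and a Bernoulli likelihood p(x|z)
# whose per-pixel means are sigmoid(MLP(z)). Per the BayesianNet convention
# used here, observation=None draws a sample from the named distribution,
# while a given tensor is scored and its log-probability returned with it.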
class Generator(zs.BayesianNet):
    """ Generator: the generative network p(x|z) with prior p(z). """

    def __init__(self, x_dim, z_dim, batch_size):
        super().__init__()
        self.x_dim = x_dim
        self.z_dim = z_dim
        self.batch_size = batch_size
        self.fc1 = nn.Dense(z_dim, 500)
        self.act1 = nn.ReLU()
        self.fc2 = nn.Dense(500, 500)
        self.act2 = nn.ReLU()
        self.fc3 = nn.Dense(500, x_dim)
        self.fill = P.Fill()
        self.sigmoid = P.Sigmoid()
        self.reshape_op = P.Reshape()

    def ones(self, shape):
        return self.fill(mstype.float32, shape, 1.)

    def zeros(self, shape):
        return self.fill(mstype.float32, shape, 0.)

    def construct(self, x, z, y):
        """ construct """
        assert y is None  # no conditional information is used
        if x is not None:
            x = self.reshape_op(x, (32, 32 * 32))
        # Prior p(z) = N(0, I).
        z_mean = self.zeros((self.batch_size, self.z_dim))
        z_std = self.ones((self.batch_size, self.z_dim))
        z, log_prob_z = self.Normal('latent', observation=z, mean=z_mean,
                                    std=z_std, shape=(), reparameterize=False)
        # Decode z into per-pixel Bernoulli means.
        x_mean = self.sigmoid(self.fc3(self.act2(self.fc2(self.act1(self.fc1(z))))))
        if x is None:
            x = x_mean
        x, log_prob_x = self.Bernoulli('data', observation=x, shape=(), probs=x_mean)
        return x, log_prob_x, z, log_prob_z
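

# Inference network sketch: a diagonal Gaussian q(z|x) whose mean and
# log-std come from two heads (fc3, fc4) of a shared MLP. reparameterize=True
# requests the reparameterized sample z = mean + std * eps, eps ~ N(0, I),
# so that gradients can flow back to mean and std during training.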
class Variational(zs.BayesianNet):
    """ Variational: the inference network q(z|x). """

    def __init__(self, x_dim, z_dim, batch_size):
        super().__init__()
        self.x_dim = x_dim
        self.z_dim = z_dim
        self.batch_size = batch_size
        self.reshape_op = P.Reshape()
        self.fc1 = nn.Dense(x_dim, 500)
        self.act1 = nn.ReLU()
        self.fc2 = nn.Dense(500, 500)
        self.act2 = nn.ReLU()
        self.fc3 = nn.Dense(500, z_dim)
        self.fc4 = nn.Dense(500, z_dim)
        self.fill = P.Fill()
        self.exp = P.Exp()

    def ones(self, shape):
        return self.fill(mstype.float32, shape, 1.)

    def zeros(self, shape):
        return self.fill(mstype.float32, shape, 0.)

    def construct(self, x, z, y):
        """ construct """
        assert y is None  # no conditional information is used
        x = self.reshape_op(x, (32, 32 * 32))
        z_logit = self.act2(self.fc2(self.act1(self.fc1(x))))
        z_mean = self.fc3(z_logit)
        # fc4 predicts log-std; exponentiate to keep std positive.
        z_std = self.exp(self.fc4(z_logit))
        z, log_prob_z = self.Normal('latent', observation=z, mean=z_mean,
                                    std=z_std, shape=(), reparameterize=True)
        return z, log_prob_z


def main():
    # We currently support PyNative mode with device GPU.
    context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')

    epoch_size = 1
    batch_size = 32
    mnist_path = "/data/chengzi/zhusuan-mindspore/data/MNIST"
    repeat_size = 1

    # Define model parameters.
    z_dim = 40
    x_dim = 32 * 32

    # Create the networks and tie them together with the ELBO objective.
    generator = Generator(x_dim, z_dim, batch_size)
    variational = Variational(x_dim, z_dim, batch_size)
    network = zs.variational.ELBO(generator, variational)
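    # zs.variational.ELBO combines the two nets into the evidence lower bound
    #     ELBO(x) = E_{q(z|x)}[log p(x|z) + log p(z) - log q(z|x)],
    # the standard VAE objective. Model minimizes whatever ReduceMeanLoss
    # passes through, so the cell presumably emits the negative bound.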

    # Define the loss and learning rate.
    lr = 0.001
    net_loss = ReduceMeanLoss()

    # Define the optimizer. The first trainable parameter is printed before
    # and after training to verify that the weights were actually updated.
    print(network.trainable_params()[0])
    net_opt = nn.Adam(network.trainable_params(), lr)

    model = Model(network, net_loss, net_opt)
    ds_train = create_dataset(os.path.join(mnist_path, "train"), batch_size, repeat_size)
    model.train(epoch_size, ds_train, callbacks=[LossMonitor()], dataset_sink_mode=False)
    print(network.trainable_params()[0])
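
    # Reconstruction: encode one training batch to latent codes z with
    # q(z|x), then decode with p(x|z); comparing origin_x.png with
    # reconstruct_x.png shows how much detail the latent space preserves.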
    iterator = ds_train.create_tuple_iterator()
    for item in iterator:
        batch_x = item[0].reshape(32, 32 * 32)
        break
    z, _ = network.variational(Tensor(batch_x), None, None)
    sample, _, _, _ = network.generator(None, z, None)
    sample = sample.asnumpy()
    save_img(batch_x, 'result/origin_x.png')
    save_img(sample, 'result/reconstruct_x.png')
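
    # Unconditional generation: with x=None and z=None the generator samples
    # z from its N(0, I) prior and returns decoded Bernoulli means; four
    # batches are stacked into one grid of 4 * batch_size digits.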
    for i in range(4):
        sample, _, _, _ = network.generator(None, None, None)
        sample = sample.asnumpy()
        samples = sample if i == 0 else np.concatenate([samples, sample], axis=0)
    save_img(samples, 'result/sample_x.png', num=4 * batch_size)


if __name__ == '__main__':
    main()