!6551 Pull request from Tsinghua Zhusuan Team
Merge pull request !6551 from mcgrady00h/mindspore-zhusuanpull/6551/MERGE
commit
041d9a65e4
@ -0,0 +1,18 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
""" Zhusuan package: a probabilistic programming library """
|
||||
|
||||
from .framework import *
|
||||
from .variational import *
|
@ -0,0 +1,18 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
""" Core functionality for Zhusuan """
|
||||
|
||||
from .bn import *
|
@ -0,0 +1,92 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
""" Bayesian Network """
|
||||
|
||||
import mindspore.nn as nn
|
||||
|
||||
import mindspore.nn.probability.distribution as msd
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
class BayesianNet(nn.Cell):
    """
    Base cell for Bayesian networks.

    We currently support 3 types of variables: x = observation, z = latent, y = condition.
    A Bayesian Network models a generative process for certain variables:
    p(x,z|y) or p(z|x,y) or p(x|z,y)

    Subclasses implement `construct` and use the `Normal` / `Bernoulli`
    wrappers to sample (or observe) a variable and obtain its log-probability.
    """

    def __init__(self):
        super().__init__()
        self.normal_dist = msd.Normal(dtype=mstype.float32)
        self.bernoulli_dist = msd.Bernoulli(dtype=mstype.float32)

        self.reduce_sum = P.ReduceSum(keep_dims=True)
        # Backs the default ones()/zeros() helpers that Normal() relies on, so
        # subclasses are not forced to define them themselves.
        self.fill = P.Fill()

    def ones(self, shape):
        """ Return a float32 tensor of the given shape filled with 1. """
        return self.fill(mstype.float32, shape, 1.)

    def zeros(self, shape):
        """ Return a float32 tensor of the given shape filled with 0. """
        return self.fill(mstype.float32, shape, 0.)

    def Normal(self,
               name,
               observation=None,
               mean=None,
               std=None,
               seed=0,
               dtype=mstype.float32,
               shape=(),
               reparameterize=True):
        """
        Normal distribution wrapper.

        Args:
            name: variable name; only checked to be not None.
            observation: observed value; when given it is used as the sample
                and nothing is drawn.
            mean: mean of the distribution (needed to sample and to score).
            std: standard deviation of the distribution.
            seed: only checked to be not None; not otherwise used here.
            dtype: only checked to be not None; the underlying distribution
                is created as float32 in `__init__`.
            shape: sample shape forwarded to the distribution's 'sample' call.
            reparameterize: if True, sample as `mean + std * epsilon` with
                epsilon drawn from a standard normal; otherwise sample
                directly with `mean` and `std`.

        Returns:
            Tuple `(sample, log_prob)`; `log_prob` is the log-density of
            `sample` summed over axis 1 (dimension kept).
        """
        assert name is not None
        assert seed is not None
        assert dtype is not None

        if observation is None:
            if reparameterize:
                # Draw standard-normal noise, then shift/scale it so the
                # sample is an explicit function of mean and std.
                epsilon = self.normal_dist('sample', shape, self.zeros(mean.shape), self.ones(std.shape))
                sample = mean + std * epsilon
            else:
                sample = self.normal_dist('sample', shape, mean, std)
        else:
            sample = observation

        log_prob = self.reduce_sum(self.normal_dist('log_prob', sample, mean, std), 1)
        return sample, log_prob

    def Bernoulli(self,
                  name,
                  observation=None,
                  probs=None,
                  seed=0,
                  dtype=mstype.float32,
                  shape=()):
        """
        Bernoulli distribution wrapper.

        Args:
            name: variable name; only checked to be not None.
            observation: observed value; when given it is used as the sample
                and nothing is drawn.
            probs: probability parameter forwarded to the distribution.
            seed: only checked to be not None; not otherwise used here.
            dtype: only checked to be not None; the underlying distribution
                is created as float32 in `__init__`.
            shape: sample shape forwarded to the distribution's 'sample' call.

        Returns:
            Tuple `(sample, log_prob)`; `log_prob` is the log-probability of
            `sample` summed over axis 1 (dimension kept).
        """
        assert name is not None
        assert seed is not None
        assert dtype is not None

        if observation is None:
            sample = self.bernoulli_dist('sample', shape, probs)
        else:
            sample = observation

        log_prob = self.reduce_sum(self.bernoulli_dist('log_prob', sample, probs), 1)
        return sample, log_prob

    def construct(self, *inputs, **kwargs):
        """
        We currently fix the parameters of the construct function.

        Args:
            the inputs must consist of 3 variables in order.
            x: data sample, observation
            z: latent variable
            y: conditional information

        Raises:
            NotImplementedError: always; subclasses must override this method.
        """
        raise NotImplementedError
|
@ -0,0 +1,18 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
""" Variational inference related codes """
|
||||
|
||||
from .elbo import *
|
@ -0,0 +1,43 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
""" ELBO """
|
||||
|
||||
import mindspore.nn as nn
|
||||
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
|
||||
class ELBO(nn.Cell):
    """
    Negative evidence lower bound used as a training objective.

    Args:
        generator: cell called as `generator(x, z, y)`; must return a 4-tuple
            whose 2nd and 4th items are the log-probabilities of x and z.
        variational: cell called as `variational(x, None, y)`; must return
            `(z, log_prob_z)`.
    """

    def __init__(self, generator, variational):
        super().__init__()
        self.generator = generator
        self.variational = variational
        self.reshape_op = P.Reshape()
        self.reduce_mean = P.ReduceMean(keep_dims=False)
        self.square = P.Square()

    def construct(self, *inputs, **kwargs):
        """
        Compute the negative ELBO.

        Args:
            inputs: `(x,)` or `(x, y, ...)`; a missing `y` is treated as None.

        Returns:
            `-(mean(log p(x|z)) + mean(log p(z)) - mean(log q(z|x)))`.
        """
        x = inputs[0]
        y = inputs[1] if len(inputs) >= 2 else None

        # Sample z from the variational posterior, then score x and z
        # under the generative model.
        z, log_q_z = self.variational(x, None, y)
        _, log_p_x, _, log_p_z = self.generator(x, z, y)

        likelihood_term = self.reduce_mean(log_p_x)
        prior_term = self.reduce_mean(log_p_z)
        posterior_term = self.reduce_mean(log_q_z)
        return -(likelihood_term + prior_term - posterior_term)
|
@ -0,0 +1,16 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
""" Zhusuan examples """
|
@ -0,0 +1,15 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
""" VAE examples """
|
@ -0,0 +1,87 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
""" Utils """
|
||||
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
|
||||
from mindspore.common import dtype as mstype
|
||||
import mindspore.dataset as ds
|
||||
import mindspore.dataset.transforms.c_transforms as C
|
||||
import mindspore.dataset.transforms.vision.c_transforms as CV
|
||||
from mindspore.dataset.transforms.vision import Inter
|
||||
|
||||
|
||||
def create_dataset(data_path, batch_size=32, repeat_size=1,
                   num_parallel_workers=1):
    """ Build the MNIST input pipeline for train or test.

    Labels are cast to int32; images are resized to 32x32 with linear
    interpolation, scaled by 1/255 and converted from HWC to CHW. The result
    is shuffled, batched (incomplete final batch dropped) and repeated.

    Args:
        data_path: Directory passed to `MnistDataset`
        batch_size: The number of data records in each batch
        repeat_size: Count passed to the dataset's `repeat`
        num_parallel_workers: The number of parallel workers per map operation

    Returns:
        The configured dataset object.
    """
    dataset = ds.MnistDataset(data_path)

    # image transforms, applied in this order
    image_ops = (
        CV.Resize((32, 32), interpolation=Inter.LINEAR),  # resize images to (32, 32)
        CV.Rescale(1.0 / 255.0, 0.0),                     # rescale pixel values
        CV.HWC2CHW(),                                     # (H, W, C) -> (C, H, W) to fit network
    )
    # label transform: int32 to fit network
    label_cast = C.TypeCast(mstype.int32)

    dataset = dataset.map(input_columns="label", operations=label_cast,
                          num_parallel_workers=num_parallel_workers)
    for image_op in image_ops:
        dataset = dataset.map(input_columns="image", operations=image_op,
                              num_parallel_workers=num_parallel_workers)

    # shuffle buffer of 10000 as in LeNet train script
    dataset = dataset.shuffle(buffer_size=10000)
    dataset = dataset.batch(batch_size, drop_remainder=True)
    dataset = dataset.repeat(repeat_size)

    return dataset
|
||||
|
||||
|
||||
def save_img(data, name, size=32, num=32):
    """
    Visualize data as a grid of grayscale tiles and save it to a file.

    Tiles are laid out column by column, 8 tiles per column: image `i` goes
    to column `i // 8`, row `i % 8`.

    Args:
        data: indexable collection of at least `num` arrays, each resized to
            (size, size); values are multiplied by 255 and cast to uint8
        name: output file name
        size: edge length of each tile in pixels
        num: number of images to draw
    """
    row = 8
    # Round up so a trailing, partially filled column still fits on the
    # canvas (truncating would paste those tiles outside the image).
    col = (num + row - 1) // row

    imgs = Image.new('L', (size * col, size * row))
    for i in range(num):
        img_data = np.resize(data[i], (size, size))
        img_data = (img_data * 255).astype(np.uint8)
        im = Image.fromarray(img_data, 'L')
        imgs.paste(im, ((i // row) * size, (i % row) * size))
    imgs.save(name)
|
@ -0,0 +1,165 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
""" VAE """
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
|
||||
from utils import create_dataset, save_img
|
||||
|
||||
import mindspore.nn as nn
|
||||
|
||||
from mindspore import context
|
||||
from mindspore import Tensor
|
||||
from mindspore.train import Model
|
||||
from mindspore.train.callback import LossMonitor
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.common import dtype as mstype
|
||||
|
||||
import zhusuan as zs
|
||||
|
||||
class ReduceMeanLoss(nn.L1Loss):
    """ Pass-through loss: hands back the network output as the loss value. """

    def construct(self, base, target):
        """ Return `base` unchanged; `target` is ignored. """
        return base
|
||||
|
||||
class Generator(zs.BayesianNet):
    """
    Generator: standard-normal latent z, decoded by a 3-layer MLP into
    Bernoulli probabilities for x.

    Args:
        x_dim: flattened size of one observation
        z_dim: size of the latent variable
        batch_size: number of samples per batch (fixes the latent shape)
    """

    def __init__(self, x_dim, z_dim, batch_size):
        super().__init__()
        self.x_dim = x_dim
        self.z_dim = z_dim
        self.batch_size = batch_size

        self.fc1 = nn.Dense(z_dim, 500)
        self.act1 = nn.ReLU()
        self.fc2 = nn.Dense(500, 500)
        self.act2 = nn.ReLU()
        self.fc3 = nn.Dense(500, x_dim)
        self.fill = P.Fill()
        self.sigmoid = P.Sigmoid()
        self.reshape_op = P.Reshape()

    def ones(self, shape):
        """ Return a float32 tensor of the given shape filled with 1. """
        return self.fill(mstype.float32, shape, 1.)

    def zeros(self, shape):
        """ Return a float32 tensor of the given shape filled with 0. """
        return self.fill(mstype.float32, shape, 0.)

    def construct(self, x, z, y):
        """
        Run the generative model.

        Args:
            x: observed data, or None to generate it; reshaped to
                (batch_size, x_dim) when given
            z: observed latent, or None to sample it from the prior
            y: conditional information; must be None

        Returns:
            Tuple `(x, log_prob_x, z, log_prob_z)`. When `x` is None the
            decoder's Bernoulli mean is returned in its place (no sampling).
        """
        assert y is None  # we have no conditional information

        if x is not None:
            # Flatten using the configured sizes rather than fixed constants.
            x = self.reshape_op(x, (self.batch_size, self.x_dim))

        z_mean = self.zeros((self.batch_size, self.z_dim))
        z_std = self.ones((self.batch_size, self.z_dim))
        z, log_prob_z = self.Normal('latent', observation=z, mean=z_mean, std=z_std, shape=(), reparameterize=False)

        x_mean = self.sigmoid(self.fc3(self.act2(self.fc2(self.act1(self.fc1(z))))))
        if x is None:
            # Use the mean itself as the "observation" instead of a draw.
            x = x_mean
        x, log_prob_x = self.Bernoulli('data', observation=x, shape=(), probs=x_mean)

        return x, log_prob_x, z, log_prob_z
|
||||
|
||||
class Variational(zs.BayesianNet):
    """
    Variational posterior q(z|x): an MLP maps x to the mean and standard
    deviation of a Normal over z.

    Args:
        x_dim: flattened size of one observation
        z_dim: size of the latent variable
        batch_size: number of samples per batch
    """

    def __init__(self, x_dim, z_dim, batch_size):
        super().__init__()
        self.x_dim = x_dim
        self.z_dim = z_dim
        self.batch_size = batch_size
        self.reshape_op = P.Reshape()

        self.fc1 = nn.Dense(x_dim, 500)
        self.act1 = nn.ReLU()
        self.fc2 = nn.Dense(500, 500)
        self.act2 = nn.ReLU()
        self.fc3 = nn.Dense(500, z_dim)
        self.fc4 = nn.Dense(500, z_dim)
        self.fill = P.Fill()
        self.exp = P.Exp()

    def ones(self, shape):
        """ Return a float32 tensor of the given shape filled with 1. """
        return self.fill(mstype.float32, shape, 1.)

    def zeros(self, shape):
        """ Return a float32 tensor of the given shape filled with 0. """
        return self.fill(mstype.float32, shape, 0.)

    def construct(self, x, z, y):
        """
        Run the variational model.

        Args:
            x: observed data; reshaped to (batch_size, x_dim)
            z: observed latent, or None to sample it (reparameterized)
            y: conditional information; must be None

        Returns:
            Tuple `(z, log_prob_z)`.
        """
        assert y is None  # we have no conditional information

        # Flatten using the configured sizes rather than fixed constants.
        x = self.reshape_op(x, (self.batch_size, self.x_dim))
        z_logit = self.act2(self.fc2(self.act1(self.fc1(x))))
        z_mean = self.fc3(z_logit)
        # exp keeps the standard deviation positive
        z_std = self.exp(self.fc4(z_logit))
        z, log_prob_z = self.Normal('latent', observation=z, mean=z_mean, std=z_std, shape=(), reparameterize=True)
        return z, log_prob_z
|
||||
|
||||
def main():
    """
    Train the VAE on MNIST for one epoch, then write three images into
    `result/`: a batch of originals, their reconstructions, and samples
    drawn from the prior.
    """
    # We currently support pynative mode with device GPU
    context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
    epoch_size = 1
    batch_size = 32
    # NOTE(review): developer-specific absolute path; adjust to the local MNIST copy.
    mnist_path = "/data/chengzi/zhusuan-mindspore/data/MNIST"
    repeat_size = 1

    # Define model parameters
    z_dim = 40
    x_dim = 32*32

    # Make sure the output directory exists before any image is saved.
    os.makedirs('result', exist_ok=True)

    # create the network
    generator = Generator(x_dim, z_dim, batch_size)
    variational = Variational(x_dim, z_dim, batch_size)
    network = zs.variational.ELBO(generator, variational)

    # define loss
    # learning rate setting
    lr = 0.001
    net_loss = ReduceMeanLoss()

    # define the optimizer
    print(network.trainable_params()[0])
    net_opt = nn.Adam(network.trainable_params(), lr)

    model = Model(network, net_loss, net_opt)

    ds_train = create_dataset(os.path.join(mnist_path, "train"), batch_size, repeat_size)
    model.train(epoch_size, ds_train, callbacks=[LossMonitor()], dataset_sink_mode=False)

    print(network.trainable_params()[0])

    # Take the first batch only, flattened with the configured sizes.
    iterator = ds_train.create_tuple_iterator()
    for item in iterator:
        batch_x = item[0].reshape(batch_size, x_dim)
        break

    # Reconstruction: encode the batch, then decode the inferred latent.
    z, _ = network.variational(Tensor(batch_x), None, None)
    sample, _, _, _ = network.generator(None, z, None)
    sample = sample.asnumpy()
    save_img(batch_x, 'result/origin_x.png')
    save_img(sample, 'result/reconstruct_x.png')

    # Generation: four batches decoded from latents sampled from the prior.
    generated = []
    for _ in range(4):
        sample, _, _, _ = network.generator(None, None, None)
        generated.append(sample.asnumpy())
    samples = np.concatenate(generated, axis=0)
    save_img(samples, 'result/sample_x.png', num=4*batch_size)


if __name__ == '__main__':
    main()
|
Loading…
Reference in new issue