diff --git a/mindspore/nn/probability/zhusuan/__init__.py b/mindspore/nn/probability/zhusuan/__init__.py new file mode 100644 index 0000000000..d4330d1162 --- /dev/null +++ b/mindspore/nn/probability/zhusuan/__init__.py @@ -0,0 +1,18 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" Zhusuan package: a probabilistic programming library """ + +from .framework import * +from .variational import * diff --git a/mindspore/nn/probability/zhusuan/framework/__init__.py b/mindspore/nn/probability/zhusuan/framework/__init__.py new file mode 100644 index 0000000000..680eee8d88 --- /dev/null +++ b/mindspore/nn/probability/zhusuan/framework/__init__.py @@ -0,0 +1,18 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
# ---- mindspore/nn/probability/zhusuan/framework/__init__.py ----
""" Core functionality for Zhusuan """

from .bn import *

# ---- mindspore/nn/probability/zhusuan/framework/bn.py ----
""" Bayesian Network """

import mindspore.nn as nn

import mindspore.nn.probability.distribution as msd
from mindspore.common import dtype as mstype
from mindspore.ops import operations as P


class BayesianNet(nn.Cell):
    """
    Base class for Bayesian networks.

    We currently support 3 types of variables: x = observation, z = latent, y = condition.
    A Bayesian Network models a generative process for certain variables:
    p(x,z|y) or p(z|x,y) or p(x|z,y).

    Subclasses implement `construct` and call the `Normal` / `Bernoulli` wrappers,
    each of which returns a value for a stochastic node together with its
    log-probability.
    """
    def __init__(self):
        super().__init__()
        self.normal_dist = msd.Normal(dtype=mstype.float32)
        self.bernoulli_dist = msd.Bernoulli(dtype=mstype.float32)

        # `Normal` needs zero/one tensors for the reparameterization noise, so the
        # fill op and the `ones`/`zeros` helpers live on the base class instead of
        # being required from every subclass.
        self.fill = P.Fill()
        self.reduce_sum = P.ReduceSum(keep_dims=True)

    def ones(self, shape):
        """ Return a float32 tensor of the given shape filled with 1. """
        return self.fill(mstype.float32, shape, 1.)

    def zeros(self, shape):
        """ Return a float32 tensor of the given shape filled with 0. """
        return self.fill(mstype.float32, shape, 0.)

    def Normal(self,
               name,
               observation=None,
               mean=None,
               std=None,
               seed=0,
               dtype=mstype.float32,
               shape=(),
               reparameterize=True):
        """
        Normal distribution wrapper.

        Args:
            name: label of the node; only checked to be non-None.
            observation: if given, used as the node value instead of sampling.
            mean: mean passed to the underlying normal distribution.
            std: standard deviation passed to the underlying normal distribution.
            seed: only checked to be non-None; not otherwise used here.
            dtype: only checked to be non-None; not otherwise used here.
            shape: sample shape forwarded to the distribution's 'sample' call.
            reparameterize: when sampling, draw noise from a normal with zero mean
                and unit std (shaped like `mean` / `std`) and return
                `mean + std * noise` instead of sampling with `mean` / `std` directly.

        Returns:
            (sample, log_prob): the node value and its log-probability summed over
            axis 1 (the reduced axis is kept, since `keep_dims=True`).
        """

        # Kept as asserts so callers see the same exception type as before.
        assert name is not None
        assert seed is not None
        assert dtype is not None

        if observation is None:
            if reparameterize:
                epsilon = self.normal_dist('sample', shape, self.zeros(mean.shape), self.ones(std.shape))
                sample = mean + std * epsilon
            else:
                sample = self.normal_dist('sample', shape, mean, std)
        else:
            sample = observation

        log_prob = self.reduce_sum(self.normal_dist('log_prob', sample, mean, std), 1)
        return sample, log_prob

    def Bernoulli(self,
                  name,
                  observation=None,
                  probs=None,
                  seed=0,
                  dtype=mstype.float32,
                  shape=()):
        """
        Bernoulli distribution wrapper.

        Args:
            name: label of the node; only checked to be non-None.
            observation: if given, used as the node value instead of sampling.
            probs: probabilities passed to the underlying Bernoulli distribution.
            seed: only checked to be non-None; not otherwise used here.
            dtype: only checked to be non-None; not otherwise used here.
            shape: sample shape forwarded to the distribution's 'sample' call.

        Returns:
            (sample, log_prob): the node value and its log-probability summed over
            axis 1 (the reduced axis is kept, since `keep_dims=True`).
        """

        # Kept as asserts so callers see the same exception type as before.
        assert name is not None
        assert seed is not None
        assert dtype is not None

        if observation is None:
            sample = self.bernoulli_dist('sample', shape, probs)
        else:
            sample = observation

        log_prob = self.reduce_sum(self.bernoulli_dist('log_prob', sample, probs), 1)
        return sample, log_prob

    def construct(self, *inputs, **kwargs):
        """
        We currently fix the parameters of the construct function.
        Args:
            the inputs must consist of 3 variables in order.
            x: data sample, observation
            z: latent variable
            y: conditional information

        Raises:
            NotImplementedError: always; subclasses must override this method.
        """
        raise NotImplementedError


# ---- mindspore/nn/probability/zhusuan/variational/__init__.py ----
""" Variational inference related codes """

from .elbo import *
# ---- mindspore/nn/probability/zhusuan/variational/elbo.py ----
""" ELBO """

import mindspore.nn as nn

from mindspore.ops import operations as P


class ELBO(nn.Cell):
    """
    Negative evidence lower bound computed from a generator / variational pair.

    Args:
        generator: cell called as `generator(x, z, y)`; its 2nd and 4th outputs
            are used as the log-probabilities of x and z.
        variational: cell called as `variational(x, None, y)`; returns the latent
            z and its log-probability.
    """
    def __init__(self, generator, variational):
        super().__init__()
        self.generator = generator
        self.variational = variational
        self.reshape_op = P.Reshape()
        self.reduce_mean = P.ReduceMean(keep_dims=False)
        self.square = P.Square()

    def construct(self, *inputs, **kwargs):
        """
        Return the negated ELBO.

        Args:
            inputs: `x` alone, or `x` followed by the condition `y`
                (`y` defaults to None when only one input is given).

        Returns:
            -(mean(log p(x)) + mean(log p(z)) - mean(log q(z))), where the first
            two terms come from the generator and the last from the variational cell.
        """
        x = inputs[0]
        y = inputs[1] if len(inputs) >= 2 else None

        # Latent value and its log-probability under the variational cell.
        z, log_qz = self.variational(x, None, y)
        # Log-probabilities of x and z under the generator, evaluated at that z.
        _, log_px, _, log_pz = self.generator(x, z, y)

        mean_log_px = self.reduce_mean(log_px)
        mean_log_pz = self.reduce_mean(log_pz)
        mean_log_qz = self.reduce_mean(log_qz)

        elbo = mean_log_px + mean_log_pz - mean_log_qz
        return -elbo
+# ============================================================================ + +""" Zhusuan examples """ diff --git a/tests/st/probability/zhusuan/vae/__init__.py b/tests/st/probability/zhusuan/vae/__init__.py new file mode 100755 index 0000000000..83d04fdddd --- /dev/null +++ b/tests/st/probability/zhusuan/vae/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" VAE examples """ diff --git a/tests/st/probability/zhusuan/vae/utils.py b/tests/st/probability/zhusuan/vae/utils.py new file mode 100755 index 0000000000..937148b756 --- /dev/null +++ b/tests/st/probability/zhusuan/vae/utils.py @@ -0,0 +1,87 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
# ---- tests/st/probability/zhusuan/vae/utils.py ----
""" Utils """

import os

from PIL import Image
import numpy as np

from mindspore.common import dtype as mstype
import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as C
import mindspore.dataset.transforms.vision.c_transforms as CV
from mindspore.dataset.transforms.vision import Inter


def create_dataset(data_path, batch_size=32, repeat_size=1,
                   num_parallel_workers=1):
    """ create dataset for train or test
    Args:
        data_path: Data path
        batch_size: The number of data records in each group
        repeat_size: The number of replicated data records
        num_parallel_workers: The number of parallel workers
    Returns:
        The MNIST dataset with images resized to 32x32, rescaled by 1/255 and
        converted to CHW, labels cast to int32, then shuffled, batched
        (incomplete batches dropped) and repeated.
    """
    # define dataset
    mnist_ds = ds.MnistDataset(data_path)

    # define operation parameters
    resize_height, resize_width = 32, 32
    rescale = 1.0 / 255.0
    shift = 0.0

    # define map operations
    resize_op = CV.Resize((resize_height, resize_width), interpolation=Inter.LINEAR)  # resize images to (32, 32)
    rescale_op = CV.Rescale(rescale, shift)  # rescale images
    hwc2chw_op = CV.HWC2CHW()  # (height, width, channel) -> (channel, height, width) to fit network
    type_cast_op = C.TypeCast(mstype.int32)  # change data type of label to int32 to fit network

    # apply map operations: label first, then the image ops in order
    mnist_ds = mnist_ds.map(input_columns="label", operations=type_cast_op,
                            num_parallel_workers=num_parallel_workers)
    for image_op in (resize_op, rescale_op, hwc2chw_op):
        mnist_ds = mnist_ds.map(input_columns="image", operations=image_op,
                                num_parallel_workers=num_parallel_workers)

    # apply DatasetOps
    buffer_size = 10000
    mnist_ds = mnist_ds.shuffle(buffer_size=buffer_size)  # 10000 as in LeNet train script
    mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)
    mnist_ds = mnist_ds.repeat(repeat_size)

    return mnist_ds


def save_img(data, name, size=32, num=32):
    """
    Visualize data and save to target files

    The first `num` entries of `data` are laid out column by column on a grid
    with 8 rows and written as a single grayscale image.

    Args:
        data: nparray of size (num, size, size); each entry is resized to
            (size, size) with `np.resize` and scaled by 255
        name: output file name; its parent directory is created if missing
        size: image size
        num: number of images
    """
    row = 8
    # Round up so a partially filled last column still fits on the canvas
    # (and the canvas is never zero-width when num < 8).
    col = -(-num // row)

    imgs = Image.new('L', (size * col, size * row))
    for i in range(num):
        img_data = np.resize(data[i], (size, size))
        # Clip before the cast so out-of-range values saturate instead of wrapping.
        img_data = np.clip(img_data * 255, 0, 255).astype(np.uint8)
        im = Image.fromarray(img_data, 'L')
        imgs.paste(im, ((i // row) * size, (i % row) * size))

    out_dir = os.path.dirname(name)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    imgs.save(name)
# ---- tests/st/probability/zhusuan/vae/vae_mnist.py ----
""" VAE """

import os
import numpy as np

from utils import create_dataset, save_img

import mindspore.nn as nn

from mindspore import context
from mindspore import Tensor
from mindspore.train import Model
from mindspore.train.callback import LossMonitor
from mindspore.ops import operations as P
from mindspore.common import dtype as mstype

import zhusuan as zs


class ReduceMeanLoss(nn.L1Loss):
    """ Pass-through loss: returns `base` unchanged and ignores `target`. """
    def construct(self, base, target):
        # The wrapped network (ELBO) already returns the scalar to minimize.
        return base


class Generator(zs.BayesianNet):
    """
    Generator: maps a latent z through a 3-layer MLP with a sigmoid output to
    Bernoulli probabilities for x.
    """
    def __init__(self, x_dim, z_dim, batch_size):
        super().__init__()
        self.x_dim = x_dim
        self.z_dim = z_dim
        self.batch_size = batch_size

        self.fc1 = nn.Dense(z_dim, 500)
        self.act1 = nn.ReLU()
        self.fc2 = nn.Dense(500, 500)
        self.act2 = nn.ReLU()
        self.fc3 = nn.Dense(500, x_dim)
        self.fill = P.Fill()
        self.sigmoid = P.Sigmoid()
        self.reshape_op = P.Reshape()

    def ones(self, shape):
        """ Return a float32 tensor of the given shape filled with 1. """
        return self.fill(mstype.float32, shape, 1.)

    def zeros(self, shape):
        """ Return a float32 tensor of the given shape filled with 0. """
        return self.fill(mstype.float32, shape, 0.)

    def construct(self, x, z, y):
        """
        Args:
            x: observation, or None to generate; reshaped to (batch_size, x_dim).
            z: latent value, or None to sample it from the zero-mean, unit-std prior.
            y: must be None (no conditional information is supported).

        Returns:
            (x, log_prob_x, z, log_prob_z). When `x` is None the returned x is
            the Bernoulli mean itself rather than a binary sample.
        """
        assert y is None ## we have no conditional information

        if x is not None:
            # Use the configured sizes instead of hard-coded 32 / 32*32.
            x = self.reshape_op(x, (self.batch_size, self.x_dim))

        z_mean = self.zeros((self.batch_size, self.z_dim))
        z_std = self.ones((self.batch_size, self.z_dim))
        z, log_prob_z = self.Normal('latent', observation=z, mean=z_mean, std=z_std, shape=(), reparameterize=False)

        x_mean = self.sigmoid(self.fc3(self.act2(self.fc2(self.act1(self.fc1(z))))))
        if x is None:
            x = x_mean
        x, log_prob_x = self.Bernoulli('data', observation=x, shape=(), probs=x_mean)

        return x, log_prob_x, z, log_prob_z


class Variational(zs.BayesianNet):
    """
    Variational: maps x through a 2-layer MLP to the mean and std of a normal
    distribution over the latent z.
    """
    def __init__(self, x_dim, z_dim, batch_size):
        super().__init__()
        self.x_dim = x_dim
        self.z_dim = z_dim
        self.batch_size = batch_size
        self.reshape_op = P.Reshape()

        self.fc1 = nn.Dense(x_dim, 500)
        self.act1 = nn.ReLU()
        self.fc2 = nn.Dense(500, 500)
        self.act2 = nn.ReLU()
        self.fc3 = nn.Dense(500, z_dim)
        self.fc4 = nn.Dense(500, z_dim)
        self.fill = P.Fill()
        self.exp = P.Exp()

    def ones(self, shape):
        """ Return a float32 tensor of the given shape filled with 1. """
        return self.fill(mstype.float32, shape, 1.)

    def zeros(self, shape):
        """ Return a float32 tensor of the given shape filled with 0. """
        return self.fill(mstype.float32, shape, 0.)

    def construct(self, x, z, y):
        """
        Args:
            x: observation; reshaped to (batch_size, x_dim).
            z: latent value, or None to draw a reparameterized sample.
            y: must be None (no conditional information is supported).

        Returns:
            (z, log_prob_z)
        """
        assert y is None ## we have no conditional information
        # Use the configured sizes instead of hard-coded 32 / 32*32.
        x = self.reshape_op(x, (self.batch_size, self.x_dim))
        z_logit = self.act2(self.fc2(self.act1(self.fc1(x))))
        z_mean = self.fc3(z_logit)
        # exp keeps the standard deviation positive
        z_std = self.exp(self.fc4(z_logit))
        z, log_prob_z = self.Normal('latent', observation=z, mean=z_mean, std=z_std, shape=(), reparameterize=True)
        return z, log_prob_z


def main(mnist_path="/data/chengzi/zhusuan-mindspore/data/MNIST", result_dir='result'):
    """
    Train the VAE on MNIST for one epoch and write reconstruction / sample images.

    Args:
        mnist_path: MNIST root directory; the "train" sub-directory is used.
        result_dir: directory for the output images; created if missing.
    """
    # We currently support pynative mode with device GPU
    context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
    epoch_size = 1
    batch_size = 32
    repeat_size = 1

    # Define model parameters
    z_dim = 40
    x_dim = 32*32

    # create the network
    generator = Generator(x_dim, z_dim, batch_size)
    variational = Variational(x_dim, z_dim, batch_size)
    network = zs.variational.ELBO(generator, variational)

    # define loss
    # learning rate setting
    lr = 0.001
    net_loss = ReduceMeanLoss()

    # define the optimizer
    print(network.trainable_params()[0])
    net_opt = nn.Adam(network.trainable_params(), lr)

    model = Model(network, net_loss, net_opt)

    ds_train = create_dataset(os.path.join(mnist_path, "train"), batch_size, repeat_size)
    model.train(epoch_size, ds_train, callbacks=[LossMonitor()], dataset_sink_mode=False)

    print(network.trainable_params()[0])

    # Take the first batch only; fail clearly if the dataset yields nothing.
    first_item = next(iter(ds_train.create_tuple_iterator()), None)
    if first_item is None:
        raise RuntimeError("training dataset produced no batches")
    batch_x = first_item[0].reshape(batch_size, x_dim)

    # The output directory must exist before images are written into it.
    os.makedirs(result_dir, exist_ok=True)

    z, _ = network.variational(Tensor(batch_x), None, None)
    sample, _, _, _ = network.generator(None, z, None)
    sample = sample.asnumpy()
    save_img(batch_x, os.path.join(result_dir, 'origin_x.png'))
    save_img(sample, os.path.join(result_dir, 'reconstruct_x.png'))

    # Draw 4 batches from the prior and concatenate them once.
    sample_batches = []
    for _ in range(4):
        sample, _, _, _ = network.generator(None, None, None)
        sample_batches.append(sample.asnumpy())
    samples = np.concatenate(sample_batches, axis=0)
    save_img(samples, os.path.join(result_dir, 'sample_x.png'), num=4*batch_size)


if __name__ == '__main__':
    main()