From 37d40bc4954ada82d8be32696789fec85a210ab8 Mon Sep 17 00:00:00 2001
From: peixu_ren
Date: Thu, 19 Nov 2020 18:04:46 -0500
Subject: [PATCH] Add Beta distribution

---
 .../nn/probability/distribution/__init__.py   |   2 +
 mindspore/nn/probability/distribution/beta.py | 333 ++++++++++++++++++
 .../nn/probability/distribution/gamma.py      |   4 +-
 .../nn/probability/distribution/poisson.py    |   4 +-
 .../st/probability/distribution/test_beta.py  | 245 +++++++++++++
 .../st/probability/distribution/test_gamma.py |   6 +-
 .../nn/probability/distribution/test_beta.py  | 212 +++++++++++
 7 files changed, 799 insertions(+), 7 deletions(-)
 create mode 100644 mindspore/nn/probability/distribution/beta.py
 create mode 100644 tests/st/probability/distribution/test_beta.py
 create mode 100644 tests/ut/python/nn/probability/distribution/test_beta.py

diff --git a/mindspore/nn/probability/distribution/__init__.py b/mindspore/nn/probability/distribution/__init__.py
index c4077376a5..8a5f066181 100644
--- a/mindspore/nn/probability/distribution/__init__.py
+++ b/mindspore/nn/probability/distribution/__init__.py
@@ -19,6 +19,7 @@ Distributions are the high-level components used to construct the probabilistic
 from .distribution import Distribution
 from .transformed_distribution import TransformedDistribution
 from .bernoulli import Bernoulli
+from .beta import Beta
 from .categorical import Categorical
 from .cauchy import Cauchy
 from .exponential import Exponential
@@ -34,6 +35,7 @@ from .uniform import Uniform
 __all__ = ['Distribution',
            'TransformedDistribution',
            'Bernoulli',
+           'Beta',
            'Categorical',
            'Cauchy',
            'Exponential',
diff --git a/mindspore/nn/probability/distribution/beta.py b/mindspore/nn/probability/distribution/beta.py
new file mode 100644
index 0000000000..6d2d31b6ad
--- /dev/null
+++ b/mindspore/nn/probability/distribution/beta.py
@@ -0,0 +1,333 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Beta Distribution"""
+import numpy as np
+from mindspore.ops import operations as P
+from mindspore.ops import composite as C
+import mindspore.nn as nn
+from mindspore._checkparam import Validator
+from mindspore.common import dtype as mstype
+from .distribution import Distribution
+from ._utils.utils import check_greater_zero, check_distribution_name
+from ._utils.custom_ops import log_generic
+
+
+class Beta(Distribution):
+    """
+    Beta distribution.
+
+    Args:
+        concentration1 (int, float, list, numpy.ndarray, Tensor, Parameter): The concentration1,
+          also known as the alpha of the Beta distribution.
+        concentration0 (int, float, list, numpy.ndarray, Tensor, Parameter): The concentration0,
+          also known as the beta of the Beta distribution.
+        seed (int): The seed used in sampling. The global seed is used if it is None. Default: None.
+        dtype (mindspore.dtype): The type of the event samples. Default: mstype.float32.
+        name (str): The name of the distribution. Default: 'Beta'.
+
+    Note:
+        `concentration1` and `concentration0` must be greater than zero.
+        `dist_spec_args` are `concentration1` and `concentration0`.
+        `dtype` must be a float type because Beta distributions are continuous.
+
+    Examples:
+        >>> # To initialize a Beta distribution of the concentration1 3.0 and the concentration0 4.0.
+        >>> import mindspore.nn.probability.distribution as msd
+        >>> b = msd.Beta(3.0, 4.0, dtype=mstype.float32)
+        >>>
+        >>> # The following creates two independent Beta distributions.
+        >>> b = msd.Beta([3.0, 3.0], [4.0, 4.0], dtype=mstype.float32)
+        >>>
+        >>> # A Beta distribution can be initialized without arguments.
+        >>> # In this case, `concentration1` and `concentration0` must be passed in through arguments.
+        >>> b = msd.Beta(dtype=mstype.float32)
+        >>>
+        >>> # To use a Beta distribution in a network.
+        >>> class net(Cell):
+        ...     def __init__(self):
+        ...         super(net, self).__init__()
+        ...         self.b1 = msd.Beta(1.0, 1.0, dtype=mstype.float32)
+        ...         self.b2 = msd.Beta(dtype=mstype.float32)
+        ...
+        ...     # The following calls are valid in construct.
+        ...     def construct(self, value, concentration1_b, concentration0_b, concentration1_a, concentration0_a):
+        ...
+        ...         # Private interfaces of probability functions corresponding to public interfaces, including
+        ...         # `prob` and `log_prob`, have the same arguments as follows.
+        ...         # Args:
+        ...         #     value (Tensor): the value to be evaluated.
+        ...         #     concentration1 (Tensor): the concentration1 of the distribution. Default: self._concentration1.
+        ...         #     concentration0 (Tensor): the concentration0 of the distribution. Default: self._concentration0.
+        ...
+        ...         # Examples of `prob`.
+        ...         # Similar calls can be made to other probability functions
+        ...         # by replacing 'prob' by the name of the function.
+        ...         ans = self.b1.prob(value)
+        ...         # Evaluate with respect to the distribution b.
+        ...         ans = self.b1.prob(value, concentration1_b, concentration0_b)
+        ...         # `concentration1` and `concentration0` must be passed in during function calls.
+        ...         ans = self.b2.prob(value, concentration1_a, concentration0_a)
+        ...
+        ...
+        ...         # Functions `mean`, `sd`, `mode`, `var`, and `entropy` have the same arguments.
+        ...         # Args:
+        ...         #     concentration1 (Tensor): the concentration1 of the distribution. Default: self._concentration1.
+        ...         #     concentration0 (Tensor): the concentration0 of the distribution. Default: self._concentration0.
+        ...
+        ...         # Examples of `mean`, `sd`, `mode`, `var`, and `entropy` are similar.
+        ...         ans = self.b1.mean()    # return 0.5
+        ...         ans = self.b1.mean(concentration1_b, concentration0_b)    # return the mean under b's parameters
+        ...         # `concentration1` and `concentration0` must be passed in during function calls.
+        ...         ans = self.b2.mean(concentration1_a, concentration0_a)
+        ...
+        ...
+        ...         # Interfaces of 'kl_loss' and 'cross_entropy' are the same:
+        ...         # Args:
+        ...         #     dist (str): the type of the distributions. Only "Beta" is supported.
+        ...         #     concentration1_b (Tensor): the concentration1 of distribution b.
+        ...         #     concentration0_b (Tensor): the concentration0 of distribution b.
+        ...         #     concentration1_a (Tensor): the concentration1 of distribution a.
+        ...         #       Default: self._concentration1.
+        ...         #     concentration0_a (Tensor): the concentration0 of distribution a.
+        ...         #       Default: self._concentration0.
+        ...
+        ...         # Examples of `kl_loss`. `cross_entropy` is similar.
+        ...         ans = self.b1.kl_loss('Beta', concentration1_b, concentration0_b)
+        ...         ans = self.b1.kl_loss('Beta', concentration1_b, concentration0_b,
+        ...                               concentration1_a, concentration0_a)
+        ...         # Additional `concentration1` and `concentration0` must be passed in.
+        ...         ans = self.b2.kl_loss('Beta', concentration1_b, concentration0_b,
+        ...                               concentration1_a, concentration0_a)
+        ...
+        ...
+        ...         # Examples of `sample`.
+        ...         # Args:
+        ...         #     shape (tuple): the shape of the sample. Default: ()
+        ...         #     concentration1 (Tensor): the concentration1 of the distribution. Default: self._concentration1.
+        ...         #     concentration0 (Tensor): the concentration0 of the distribution. Default: self._concentration0.
+        ...         ans = self.b1.sample()
+        ...         ans = self.b1.sample((2,3))
+        ...         ans = self.b1.sample((2,3), concentration1_b, concentration0_b)
+        ...         ans = self.b2.sample((2,3), concentration1_a, concentration0_a)
+    """
+
+    def __init__(self,
+                 concentration1=None,
+                 concentration0=None,
+                 seed=None,
+                 dtype=mstype.float32,
+                 name="Beta"):
+        """
+        Constructor of Beta.
+        """
+        param = dict(locals())
+        param['param_dict'] = {'concentration1': concentration1, 'concentration0': concentration0}
+        valid_dtype = mstype.float_type
+        Validator.check_type_name("dtype", dtype, valid_dtype, type(self).__name__)
+        super(Beta, self).__init__(seed, dtype, name, param)
+
+        self._concentration1 = self._add_parameter(concentration1, 'concentration1')
+        self._concentration0 = self._add_parameter(concentration0, 'concentration0')
+        if self._concentration1 is not None:
+            check_greater_zero(self._concentration1, "concentration1")
+        if self._concentration0 is not None:
+            check_greater_zero(self._concentration0, "concentration0")
+
+        # ops needed for the class
+        self.log = log_generic
+        self.log1p = P.Log1p()
+        self.neg = P.Neg()
+        self.pow = P.Pow()
+        self.squeeze = P.Squeeze(0)
+        self.cast = P.Cast()
+        self.fill = P.Fill()
+        self.shape = P.Shape()
+        self.select = P.Select()
+        self.logicaland = P.LogicalAnd()
+        self.greater = P.Greater()
+        self.digamma = nn.DiGamma()
+        self.lbeta = nn.LBeta()
+
+    def extend_repr(self):
+        if self.is_scalar_batch:
+            s = f'concentration1 = {self._concentration1}, concentration0 = {self._concentration0}'
+        else:
+            s = f'batch_shape = {self._broadcast_shape}'
+        return s
+
+    @property
+    def concentration1(self):
+        """
+        Return the concentration1, also known as the alpha of the Beta distribution.
+        """
+        return self._concentration1
+
+    @property
+    def concentration0(self):
+        """
+        Return the concentration0, also known as the beta of the Beta distribution.
+        """
+        return self._concentration0
+
+    def _get_dist_type(self):
+        return "Beta"
+
+    def _get_dist_args(self, concentration1=None, concentration0=None):
+        if concentration1 is not None:
+            self.checktensor(concentration1, 'concentration1')
+        else:
+            concentration1 = self._concentration1
+        if concentration0 is not None:
+            self.checktensor(concentration0, 'concentration0')
+        else:
+            concentration0 = self._concentration0
+        return concentration1, concentration0
+
+    def _mean(self, concentration1=None, concentration0=None):
+        """
+        The mean of the distribution.
+        """
+        concentration1, concentration0 = self._check_param_type(concentration1, concentration0)
+        return concentration1 / (concentration1 + concentration0)
+
+    def _var(self, concentration1=None, concentration0=None):
+        r"""
+        The variance of the distribution.
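+
+        .. math::
+            VAR(X) = \frac{\alpha * \beta}{(\alpha + \beta)^2 * (\alpha + \beta + 1)}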
+        """
+        concentration1, concentration0 = self._check_param_type(concentration1, concentration0)
+        total_concentration = concentration1 + concentration0
+        return concentration1 * concentration0 / (self.pow(total_concentration, 2) * (total_concentration + 1.))
+
+    def _mode(self, concentration1=None, concentration0=None):
+        """
+        The mode of the distribution.
+        """
+        concentration1, concentration0 = self._check_param_type(concentration1, concentration0)
+        # The mode is well-defined only when both concentrations are greater than 1;
+        # otherwise NaN of the broadcast shape is returned.
+        comp1 = self.greater(concentration1, 1.)
+        comp2 = self.greater(concentration0, 1.)
+        cond = self.logicaland(comp1, comp2)
+        nan = self.fill(self.dtype, self.broadcast_shape, np.nan)
+        mode = (concentration1 - 1.) / (concentration1 + concentration0 - 2.)
+        return self.select(cond, mode, nan)
+
+    def _entropy(self, concentration1=None, concentration0=None):
+        r"""
+        Evaluate entropy.
+
+        .. math::
+            H(X) = \log(\Beta(\alpha, \beta)) - (\alpha - 1) * \digamma(\alpha)
+                   - (\beta - 1) * \digamma(\beta) + (\alpha + \beta - 2) * \digamma(\alpha + \beta)
+        """
+        concentration1, concentration0 = self._check_param_type(concentration1, concentration0)
+        total_concentration = concentration1 + concentration0
+        return self.lbeta(concentration1, concentration0) \
+               - (concentration1 - 1.) * self.digamma(concentration1) \
+               - (concentration0 - 1.) * self.digamma(concentration0) \
+               + (total_concentration - 2.) * self.digamma(total_concentration)
+
+    def _cross_entropy(self, dist, concentration1_b, concentration0_b, concentration1=None, concentration0=None):
+        r"""
+        Evaluate cross entropy between Beta distributions.
+
+        Args:
+            dist (str): Type of the distributions. Should be "Beta" in this case.
+            concentration1_b (Tensor): concentration1 of distribution b.
+            concentration0_b (Tensor): concentration0 of distribution b.
+            concentration1_a (Tensor): concentration1 of distribution a. Default: self._concentration1.
+            concentration0_a (Tensor): concentration0 of distribution a. Default: self._concentration0.
+        """
+        check_distribution_name(dist, 'Beta')
+        return self._entropy(concentration1, concentration0) \
+               + self._kl_loss(dist, concentration1_b, concentration0_b, concentration1, concentration0)
+
+    def _log_prob(self, value, concentration1=None, concentration0=None):
+        r"""
+        Evaluate log probability.
+
+        Args:
+            value (Tensor): The value to be evaluated.
+            concentration1 (Tensor): The concentration1 of the distribution. Default: self._concentration1.
+            concentration0 (Tensor): The concentration0 of the distribution. Default: self._concentration0.
+
+        .. math::
+            L(x) = (\alpha - 1) * \log(x) + (\beta - 1) * \log(1 - x) - \log(\Beta(\alpha, \beta))
+        """
+        value = self._check_value(value, 'value')
+        value = self.cast(value, self.dtype)
+        concentration1, concentration0 = self._check_param_type(concentration1, concentration0)
+        log_unnormalized_prob = (concentration1 - 1.) * self.log(value) \
+                                + (concentration0 - 1.) * self.log1p(self.neg(value))
+        return log_unnormalized_prob - self.lbeta(concentration1, concentration0)
+
+    def _kl_loss(self, dist, concentration1_b, concentration0_b, concentration1=None, concentration0=None):
+        r"""
+        Evaluate Beta-Beta KL divergence, i.e. KL(a||b).
+
+        Args:
+            dist (str): The type of the distributions. Should be "Beta" in this case.
+            concentration1_b (Tensor): The concentration1 of distribution b.
+            concentration0_b (Tensor): The concentration0 of distribution b.
+            concentration1_a (Tensor): The concentration1 of distribution a. Default: self._concentration1.
+            concentration0_a (Tensor): The concentration0 of distribution a.
+                Default: self._concentration0.
+
+        .. math::
+            KL(a||b) = \log(\Beta(\alpha_{b}, \beta_{b})) - \log(\Beta(\alpha_{a}, \beta_{a}))
+                       - \digamma(\alpha_{a}) * (\alpha_{b} - \alpha_{a})
+                       - \digamma(\beta_{a}) * (\beta_{b} - \beta_{a})
+                       + \digamma(\alpha_{a} + \beta_{a}) * (\alpha_{b} + \beta_{b} - \alpha_{a} - \beta_{a})
+        """
+        check_distribution_name(dist, 'Beta')
+        concentration1_b = self._check_value(concentration1_b, 'concentration1_b')
+        concentration0_b = self._check_value(concentration0_b, 'concentration0_b')
+        concentration1_b = self.cast(concentration1_b, self.parameter_type)
+        concentration0_b = self.cast(concentration0_b, self.parameter_type)
+        concentration1_a, concentration0_a = self._check_param_type(concentration1, concentration0)
+        total_concentration_a = concentration1_a + concentration0_a
+        total_concentration_b = concentration1_b + concentration0_b
+        log_normalization_a = self.lbeta(concentration1_a, concentration0_a)
+        log_normalization_b = self.lbeta(concentration1_b, concentration0_b)
+        return (log_normalization_b - log_normalization_a) \
+               - (self.digamma(concentration1_a) * (concentration1_b - concentration1_a)) \
+               - (self.digamma(concentration0_a) * (concentration0_b - concentration0_a)) \
+               + (self.digamma(total_concentration_a) * (total_concentration_b - total_concentration_a))
+
+    def _sample(self, shape=(), concentration1=None, concentration0=None):
+        """
+        Sampling.
+
+        Args:
+            shape (tuple): The shape of the sample. Default: ().
+            concentration1 (Tensor): The concentration1 of the samples. Default: self._concentration1.
+            concentration0 (Tensor): The concentration0 of the samples. Default: self._concentration0.
+
+        Returns:
+            Tensor, with the shape being shape + batch_shape.
+        """
+        shape = self.checktuple(shape, 'shape')
+        concentration1, concentration0 = self._check_param_type(concentration1, concentration0)
+        batch_shape = self.shape(concentration1 + concentration0)
+        origin_shape = shape + batch_shape
+        if origin_shape == ():
+            sample_shape = (1,)
+        else:
+            sample_shape = origin_shape
+        ones = self.fill(self.dtype, sample_shape, 1.0)
+        # A Beta sample is built from two independent Gamma draws:
+        # if X1 ~ Gamma(alpha, 1) and X2 ~ Gamma(beta, 1), then X1 / (X1 + X2) ~ Beta(alpha, beta).
+        sample_gamma1 = C.gamma(sample_shape, alpha=concentration1, beta=ones, seed=self.seed)
+        sample_gamma2 = C.gamma(sample_shape, alpha=concentration0, beta=ones, seed=self.seed)
+        sample_beta = sample_gamma1 / (sample_gamma1 + sample_gamma2)
+        value = self.cast(sample_beta, self.dtype)
+        if origin_shape == ():
+            value = self.squeeze(value)
+        return value
diff --git a/mindspore/nn/probability/distribution/gamma.py b/mindspore/nn/probability/distribution/gamma.py
index 17e946aca4..93c3fea834 100644
--- a/mindspore/nn/probability/distribution/gamma.py
+++ b/mindspore/nn/probability/distribution/gamma.py
@@ -81,12 +81,12 @@ class Gamma(Distribution):
     ...         ans = self.g2.prob(value, concentration_a, rate_a)
     ...
     ...
-    ...         # Functions `concentration`, `rate`, `mean`, `sd`, `var`, and `entropy` have the same arguments.
+    ...         # Functions `mean`, `sd`, `mode`, `var`, and `entropy` have the same arguments.
     ...         # Args:
     ...         #     concentration (Tensor): the concentration of the distribution. Default: self._concentration.
     ...         #     rate (Tensor): the rate of the distribution. Default: self._rate.
     ...
-    ...         # Example of `concentration`, `rate`, `mean`. `sd`, `var`, and `entropy` are similar.
+    ...         # Examples of `mean`, `sd`, `mode`, `var`, and `entropy` are similar.
     ...         ans = self.g1.concentration()    # return 1.0
     ...         ans = self.g1.concentration(concentration_b, rate_b)    # return concentration_b
     ...         # `concentration` and `rate` must be passed in during function calls.
diff --git a/mindspore/nn/probability/distribution/poisson.py b/mindspore/nn/probability/distribution/poisson.py
index f9f77cfd17..d726f197e0 100644
--- a/mindspore/nn/probability/distribution/poisson.py
+++ b/mindspore/nn/probability/distribution/poisson.py
@@ -76,11 +76,11 @@ class Poisson(Distribution):
     ...         ans = self.p2.prob(value, rate_a)
     ...
     ...
-    ...         # Functions `mean`, `sd`, and 'var' have the same arguments as follows.
+    ...         # Functions `mean`, `mode`, `sd`, and `var` have the same arguments as follows.
     ...         # Args:
     ...         #     rate (Tensor): the rate of the distribution. Default: self.rate.
     ...
-    ...         # Examples of `mean`. `sd`, `var`, and `entropy` are similar.
+    ...         # Examples of `mean`, `sd`, `mode`, `var`, and `entropy` are similar.
     ...         ans = self.p1.mean() # return 2
     ...         ans = self.p1.mean(rate_b) # return 1 / rate_b
     ...         # `rate` must be passed in during function calls.
diff --git a/tests/st/probability/distribution/test_beta.py b/tests/st/probability/distribution/test_beta.py
new file mode 100644
index 0000000000..60a7c5f12f
--- /dev/null
+++ b/tests/st/probability/distribution/test_beta.py
@@ -0,0 +1,245 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""test cases for Beta distribution"""
+import numpy as np
+from scipy import stats
+from scipy import special
+import mindspore.context as context
+import mindspore.nn as nn
+import mindspore.nn.probability.distribution as msd
+from mindspore import Tensor
+from mindspore import dtype
+
+context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
+
+class Prob(nn.Cell):
+    """
+    Test class: probability of Beta distribution.
+    """
+    def __init__(self):
+        super(Prob, self).__init__()
+        self.b = msd.Beta(np.array([3.0]), np.array([1.0]), dtype=dtype.float32)
+
+    def construct(self, x_):
+        return self.b.prob(x_)
+
+def test_pdf():
+    """
+    Test pdf.
+    """
+    beta_benchmark = stats.beta(np.array([3.0]), np.array([1.0]))
+    expect_pdf = beta_benchmark.pdf([0.25, 0.75]).astype(np.float32)
+    pdf = Prob()
+    output = pdf(Tensor([0.25, 0.75], dtype=dtype.float32))
+    tol = 1e-6
+    assert (np.abs(output.asnumpy() - expect_pdf) < tol).all()
+
+class LogProb(nn.Cell):
+    """
+    Test class: log probability of Beta distribution.
+    """
+    def __init__(self):
+        super(LogProb, self).__init__()
+        self.b = msd.Beta(np.array([3.0]), np.array([1.0]), dtype=dtype.float32)
+
+    def construct(self, x_):
+        return self.b.log_prob(x_)
+
+def test_log_likelihood():
+    """
+    Test log_pdf.
+    """
+    beta_benchmark = stats.beta(np.array([3.0]), np.array([1.0]))
+    expect_logpdf = beta_benchmark.logpdf([0.25, 0.75]).astype(np.float32)
+    logprob = LogProb()
+    output = logprob(Tensor([0.25, 0.75], dtype=dtype.float32))
+    tol = 1e-6
+    assert (np.abs(output.asnumpy() - expect_logpdf) < tol).all()
+
+class KL(nn.Cell):
+    """
+    Test class: kl_loss of Beta distribution.
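+    The result is checked in test_kl_loss below against the closed-form
+    Beta-Beta KL divergence evaluated with scipy.special.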
+ """ + def __init__(self): + super(KL, self).__init__() + self.b = msd.Beta(np.array([3.0]), np.array([4.0]), dtype=dtype.float32) + + def construct(self, x_, y_): + return self.b.kl_loss('Beta', x_, y_) + +def test_kl_loss(): + """ + Test kl_loss. + """ + concentration1_a = np.array([3.0]).astype(np.float32) + concentration0_a = np.array([4.0]).astype(np.float32) + + concentration1_b = np.array([1.0]).astype(np.float32) + concentration0_b = np.array([1.0]).astype(np.float32) + + total_concentration_a = concentration1_a + concentration0_a + total_concentration_b = concentration1_b + concentration0_b + log_normalization_a = np.log(special.beta(concentration1_a, concentration0_a)) + log_normalization_b = np.log(special.beta(concentration1_b, concentration0_b)) + expect_kl_loss = (log_normalization_b - log_normalization_a) \ + - (special.digamma(concentration1_a) * (concentration1_b - concentration1_a)) \ + - (special.digamma(concentration0_a) * (concentration0_b - concentration0_a)) \ + + (special.digamma(total_concentration_a) * (total_concentration_b - total_concentration_a)) + + kl_loss = KL() + concentration1 = Tensor(concentration1_b, dtype=dtype.float32) + concentration0 = Tensor(concentration0_b, dtype=dtype.float32) + output = kl_loss(concentration1, concentration0) + tol = 1e-6 + assert (np.abs(output.asnumpy() - expect_kl_loss) < tol).all() + +class Basics(nn.Cell): + """ + Test class: mean/sd/mode of Beta distribution. + """ + def __init__(self): + super(Basics, self).__init__() + self.b = msd.Beta(np.array([3.0]), np.array([3.0]), dtype=dtype.float32) + + def construct(self): + return self.b.mean(), self.b.sd(), self.b.mode() + +def test_basics(): + """ + Test mean/standard deviation/mode. + """ + basics = Basics() + mean, sd, mode = basics() + beta_benchmark = stats.beta(np.array([3.0]), np.array([3.0])) + expect_mean = beta_benchmark.mean().astype(np.float32) + expect_sd = beta_benchmark.std().astype(np.float32) + expect_mode = [0.5] + tol = 1e-6 + assert (np.abs(mean.asnumpy() - expect_mean) < tol).all() + assert (np.abs(mode.asnumpy() - expect_mode) < tol).all() + assert (np.abs(sd.asnumpy() - expect_sd) < tol).all() + +class Sampling(nn.Cell): + """ + Test class: sample of Beta distribution. + """ + def __init__(self, shape, seed=0): + super(Sampling, self).__init__() + self.b = msd.Beta(np.array([3.0]), np.array([1.0]), seed=seed, dtype=dtype.float32) + self.shape = shape + + def construct(self, concentration1=None, concentration0=None): + return self.b.sample(self.shape, concentration1, concentration0) + +def test_sample(): + """ + Test sample. + """ + shape = (2, 3) + seed = 10 + concentration1 = Tensor([2.0], dtype=dtype.float32) + concentration0 = Tensor([2.0, 2.0, 2.0], dtype=dtype.float32) + sample = Sampling(shape, seed=seed) + output = sample(concentration1, concentration0) + assert output.shape == (2, 3, 3) + +class EntropyH(nn.Cell): + """ + Test class: entropy of Beta distribution. + """ + def __init__(self): + super(EntropyH, self).__init__() + self.b = msd.Beta(np.array([3.0]), np.array([1.0]), dtype=dtype.float32) + + def construct(self): + return self.b.entropy() + +def test_entropy(): + """ + Test entropy. + """ + beta_benchmark = stats.beta(np.array([3.0]), np.array([1.0])) + expect_entropy = beta_benchmark.entropy().astype(np.float32) + entropy = EntropyH() + output = entropy() + tol = 1e-6 + assert (np.abs(output.asnumpy() - expect_entropy) < tol).all() + +class CrossEntropy(nn.Cell): + """ + Test class: cross entropy between Beta distributions. 
+ """ + def __init__(self): + super(CrossEntropy, self).__init__() + self.b = msd.Beta(np.array([3.0]), np.array([1.0]), dtype=dtype.float32) + + def construct(self, x_, y_): + entropy = self.b.entropy() + kl_loss = self.b.kl_loss('Beta', x_, y_) + h_sum_kl = entropy + kl_loss + cross_entropy = self.b.cross_entropy('Beta', x_, y_) + return h_sum_kl - cross_entropy + +def test_cross_entropy(): + """ + Test cross_entropy. + """ + cross_entropy = CrossEntropy() + concentration1 = Tensor([3.0], dtype=dtype.float32) + concentration0 = Tensor([2.0], dtype=dtype.float32) + diff = cross_entropy(concentration1, concentration0) + tol = 1e-6 + assert (np.abs(diff.asnumpy() - np.zeros(diff.shape)) < tol).all() + +class Net(nn.Cell): + """ + Test class: expand single distribution instance to multiple graphs + by specifying the attributes. + """ + + def __init__(self): + super(Net, self).__init__() + self.beta = msd.Beta(np.array([3.0]), np.array([1.0]), dtype=dtype.float32) + + def construct(self, x_, y_): + kl = self.beta.kl_loss('Beta', x_, y_) + prob = self.beta.prob(kl) + return prob + +def test_multiple_graphs(): + """ + Test multiple graphs case. + """ + prob = Net() + concentration1_a = np.array([3.0]).astype(np.float32) + concentration0_a = np.array([1.0]).astype(np.float32) + concentration1_b = np.array([2.0]).astype(np.float32) + concentration0_b = np.array([1.0]).astype(np.float32) + ans = prob(Tensor(concentration1_b), Tensor(concentration0_b)) + + total_concentration_a = concentration1_a + concentration0_a + total_concentration_b = concentration1_b + concentration0_b + log_normalization_a = np.log(special.beta(concentration1_a, concentration0_a)) + log_normalization_b = np.log(special.beta(concentration1_b, concentration0_b)) + expect_kl_loss = (log_normalization_b - log_normalization_a) \ + - (special.digamma(concentration1_a) * (concentration1_b - concentration1_a)) \ + - (special.digamma(concentration0_a) * (concentration0_b - concentration0_a)) \ + + (special.digamma(total_concentration_a) * (total_concentration_b - total_concentration_a)) + + beta_benchmark = stats.beta(np.array([3.0]), np.array([1.0])) + expect_prob = beta_benchmark.pdf(expect_kl_loss).astype(np.float32) + + tol = 1e-6 + assert (np.abs(ans.asnumpy() - expect_prob) < tol).all() diff --git a/tests/st/probability/distribution/test_gamma.py b/tests/st/probability/distribution/test_gamma.py index 57e45844e4..6fb796ed09 100644 --- a/tests/st/probability/distribution/test_gamma.py +++ b/tests/st/probability/distribution/test_gamma.py @@ -298,11 +298,11 @@ class Net(nn.Cell): def __init__(self): super(Net, self).__init__() - self.Gamma = msd.Gamma(np.array([3.0]), np.array([1.0]), dtype=dtype.float32) + self.get_flags = msd.Gamma(np.array([3.0]), np.array([1.0]), dtype=dtype.float32) def construct(self, x_, y_): - kl = self.Gamma.kl_loss('Gamma', x_, y_) - prob = self.Gamma.prob(kl) + kl = self.g.kl_loss('Gamma', x_, y_) + prob = self.g.prob(kl) return prob def test_multiple_graphs(): diff --git a/tests/ut/python/nn/probability/distribution/test_beta.py b/tests/ut/python/nn/probability/distribution/test_beta.py new file mode 100644 index 0000000000..fb378ee44a --- /dev/null +++ b/tests/ut/python/nn/probability/distribution/test_beta.py @@ -0,0 +1,212 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+Test nn.probability.distribution.Beta.
+"""
+import numpy as np
+import pytest
+
+import mindspore.nn as nn
+import mindspore.nn.probability.distribution as msd
+from mindspore import dtype
+from mindspore import Tensor
+
+def test_beta_shape_error():
+    """
+    Invalid shapes.
+    """
+    with pytest.raises(ValueError):
+        msd.Beta([[2.], [1.]], [[2.], [3.], [4.]], dtype=dtype.float32)
+
+def test_type():
+    with pytest.raises(TypeError):
+        msd.Beta(0., 1., dtype=dtype.int32)
+
+def test_name():
+    with pytest.raises(TypeError):
+        msd.Beta(0., 1., name=1.0)
+
+def test_seed():
+    with pytest.raises(TypeError):
+        msd.Beta(0., 1., seed='seed')
+
+def test_concentration1():
+    with pytest.raises(ValueError):
+        msd.Beta(0., 1.)
+    with pytest.raises(ValueError):
+        msd.Beta(-1., 1.)
+
+def test_concentration0():
+    with pytest.raises(ValueError):
+        msd.Beta(1., 0.)
+    with pytest.raises(ValueError):
+        msd.Beta(1., -1.)
+
+def test_arguments():
+    """
+    args passing during initialization.
+    """
+    b = msd.Beta()
+    assert isinstance(b, msd.Distribution)
+    b = msd.Beta([3.0], [4.0], dtype=dtype.float32)
+    assert isinstance(b, msd.Distribution)
+
+
+class BetaProb(nn.Cell):
+    """
+    Beta distribution: initialize with concentration1/concentration0.
+    """
+    def __init__(self):
+        super(BetaProb, self).__init__()
+        self.b = msd.Beta([3.0, 4.0], [1.0, 1.0], dtype=dtype.float32)
+
+    def construct(self, value):
+        prob = self.b.prob(value)
+        log_prob = self.b.log_prob(value)
+        return prob + log_prob
+
+def test_beta_prob():
+    """
+    Test probability functions: passing value through construct.
+    """
+    net = BetaProb()
+    value = Tensor([0.5, 1.0], dtype=dtype.float32)
+    ans = net(value)
+    assert isinstance(ans, Tensor)
+
+
+class BetaProb1(nn.Cell):
+    """
+    Beta distribution: initialize without concentration1/concentration0.
+    """
+    def __init__(self):
+        super(BetaProb1, self).__init__()
+        self.b = msd.Beta()
+
+    def construct(self, value, concentration1, concentration0):
+        prob = self.b.prob(value, concentration1, concentration0)
+        log_prob = self.b.log_prob(value, concentration1, concentration0)
+        return prob + log_prob
+
+def test_beta_prob1():
+    """
+    Test probability functions: passing concentration1/concentration0, value through construct.
+    """
+    net = BetaProb1()
+    value = Tensor([0.5, 1.0], dtype=dtype.float32)
+    concentration1 = Tensor([2.0, 3.0], dtype=dtype.float32)
+    concentration0 = Tensor([1.0], dtype=dtype.float32)
+    ans = net(value, concentration1, concentration0)
+    assert isinstance(ans, Tensor)
+
+class BetaKl(nn.Cell):
+    """
+    Test class: kl_loss of Beta distribution.
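+    b1 uses its own parameters as distribution a; b2 receives both parameter
+    sets through construct.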
+ """ + def __init__(self): + super(GammaKl, self).__init__() + self.g1 = msd.Gamma(np.array([3.0]), np.array([4.0]), dtype=dtype.float32) + self.g2 = msd.Gamma(dtype=dtype.float32) + + def construct(self, concentration1_b, concentration0_b, concentration1_a, concentration0_a): + kl1 = self.g1.kl_loss('Gamma', concentration1_b, concentration0_b) + kl2 = self.g2.kl_loss('Gamma', concentration1_b, concentration0_b, concentration1_a, concentration0_a) + return kl1 + kl2 + +def test_kl(): + """ + Test kl_loss. + """ + net = GammaKl() + concentration1_b = Tensor(np.array([1.0]).astype(np.float32), dtype=dtype.float32) + concentration0_b = Tensor(np.array([1.0]).astype(np.float32), dtype=dtype.float32) + concentration1_a = Tensor(np.array([2.0]).astype(np.float32), dtype=dtype.float32) + concentration0_a = Tensor(np.array([3.0]).astype(np.float32), dtype=dtype.float32) + ans = net(concentration1_b, concentration0_b, concentration1_a, concentration0_a) + assert isinstance(ans, Tensor) + +class GammaCrossEntropy(nn.Cell): + """ + Test class: cross_entropy of Gamma distribution. + """ + def __init__(self): + super(GammaCrossEntropy, self).__init__() + self.g1 = msd.Gamma(np.array([3.0]), np.array([4.0]), dtype=dtype.float32) + self.g2 = msd.Gamma(dtype=dtype.float32) + + def construct(self, concentration1_b, concentration0_b, concentration1_a, concentration0_a): + h1 = self.g1.cross_entropy('Gamma', concentration1_b, concentration0_b) + h2 = self.g2.cross_entropy('Gamma', concentration1_b, concentration0_b, concentration1_a, concentration0_a) + return h1 + h2 + +def test_cross_entropy(): + """ + Test cross entropy between Gamma distributions. + """ + net = GammaCrossEntropy() + concentration1_b = Tensor(np.array([1.0]).astype(np.float32), dtype=dtype.float32) + concentration0_b = Tensor(np.array([1.0]).astype(np.float32), dtype=dtype.float32) + concentration1_a = Tensor(np.array([2.0]).astype(np.float32), dtype=dtype.float32) + concentration0_a = Tensor(np.array([3.0]).astype(np.float32), dtype=dtype.float32) + ans = net(concentration1_b, concentration0_b, concentration1_a, concentration0_a) + assert isinstance(ans, Tensor) + +class GammaBasics(nn.Cell): + """ + Test class: basic mean/sd function. + """ + def __init__(self): + super(GammaBasics, self).__init__() + self.g = msd.Gamma(np.array([3.0, 4.0]), np.array([4.0, 6.0]), dtype=dtype.float32) + + def construct(self): + mean = self.g.mean() + sd = self.g.sd() + mode = self.g.mode() + return mean + sd + mode + +def test_bascis(): + """ + Test mean/sd/mode/entropy functionality of Gamma. + """ + net = GammaBasics() + ans = net() + assert isinstance(ans, Tensor) + +class GammaConstruct(nn.Cell): + """ + Gamma distribution: going through construct. + """ + def __init__(self): + super(GammaConstruct, self).__init__() + self.gamma = msd.Gamma([3.0], [4.0]) + self.gamma1 = msd.Gamma() + + def construct(self, value, concentration1, concentration0): + prob = self.gamma('prob', value) + prob1 = self.gamma('prob', value, concentration1, concentration0) + prob2 = self.gamma1('prob', value, concentration1, concentration0) + return prob + prob1 + prob2 + +def test_gamma_construct(): + """ + Test probability function going through construct. + """ + net = GammaConstruct() + value = Tensor([0.5, 1.0], dtype=dtype.float32) + concentration1 = Tensor([0.0], dtype=dtype.float32) + concentration0 = Tensor([1.0], dtype=dtype.float32) + ans = net(value, concentration1, concentration0) + assert isinstance(ans, Tensor)