@@ -17,7 +17,6 @@ import numpy as np
 from mindspore.ops import operations as P
 from mindspore.ops import composite as C
 from mindspore.common import dtype as mstype
-from mindspore.context import get_context
 from .distribution import Distribution
 from ._utils.utils import convert_to_batch, check_greater_equal_zero

@@ -39,55 +38,56 @@ class Normal(Distribution):
     Examples:
         >>> # To initialize a Normal distribution of mean 3.0 and standard deviation 4.0
-        >>> n = nn.Normal(3.0, 4.0, dtype=mstype.float32)
+        >>> import mindspore.nn.probability.distribution as msd
+        >>> n = msd.Normal(3.0, 4.0, dtype=mstype.float32)
         >>>
         >>> # The following creates two independent Normal distributions
-        >>> n = nn.Normal([3.0, 3.0], [4.0, 4.0], dtype=mstype.float32)
+        >>> n = msd.Normal([3.0, 3.0], [4.0, 4.0], dtype=mstype.float32)
         >>>
-        >>> # A normal distribution can be initilize without arguments
-        >>> # In this case, mean and sd must be passed in through construct.
-        >>> n = nn.Normal(dtype=mstype.float32)
+        >>> # A Normal distribution can be initialized without arguments.
+        >>> # In this case, mean and sd must be passed in through args.
+        >>> n = msd.Normal(dtype=mstype.float32)
         >>>
-        >>> # To use normal in a network
+        >>> # To use Normal in a network
         >>> class net(Cell):
         >>>     def __init__(self):
         >>>         super(net, self).__init__()
-        >>>         self.n1 = nn.Normal(0.0, 1.0, dtype=mstype.float32)
-        >>>         self.n2 = nn.Normal(dtype=mstype.float32)
+        >>>         self.n1 = msd.Normal(0.0, 1.0, dtype=mstype.float32)
+        >>>         self.n2 = msd.Normal(dtype=mstype.float32)
         >>>
         >>>     # The following calls are valid in construct
         >>>     def construct(self, value, mean_b, sd_b, mean_a, sd_a):
         >>>
         >>>         # Similar calls can be made to other probability functions
         >>>         # by replacing 'prob' with the name of the function
-        >>>         ans = self.n1('prob', value)
+        >>>         ans = self.n1.prob(value)
         >>>         # Evaluate with respect to distribution b
-        >>>         ans = self.n1('prob', value, mean_b, sd_b)
+        >>>         ans = self.n1.prob(value, mean_b, sd_b)
         >>>
-        >>>         # mean and sd must be passed in through construct
-        >>>         ans = self.n2('prob', value, mean_a, sd_a)
+        >>>         # mean and sd must be passed in during function calls
+        >>>         ans = self.n2.prob(value, mean_a, sd_a)
         >>>
-        >>>         # Functions 'sd', 'var', 'entropy' have the same usage with 'mean'
-        >>>         # Will return [0.0]
-        >>>         ans = self.n1('mean')
-        >>>         # Will return mean_b
-        >>>         ans = self.n1('mean', mean_b, sd_b)
+        >>>         # Functions 'sd', 'var', 'entropy' have the same usage as 'mean'
+        >>>         # will return [0.0]
+        >>>         ans = self.n1.mean()
+        >>>         # will return mean_b
+        >>>         ans = self.n1.mean(mean_b, sd_b)
         >>>
-        >>>         # mean and sd must be passed in through construct
-        >>>         ans = self.n2('mean', mean_a, sd_a)
+        >>>         # mean and sd must be passed during function calls
+        >>>         ans = self.n2.mean(mean_a, sd_a)
         >>>
         >>>         # Usage of 'kl_loss' and 'cross_entropy' is similar
-        >>>         ans = self.n1('kl_loss', 'Normal', mean_b, sd_b)
-        >>>         ans = self.n1('kl_loss', 'Normal', mean_b, sd_b, mean_a, sd_a)
+        >>>         ans = self.n1.kl_loss('Normal', mean_b, sd_b)
+        >>>         ans = self.n1.kl_loss('Normal', mean_b, sd_b, mean_a, sd_a)
         >>>
-        >>>         # Additional mean and sd must be passed in through construct
-        >>>         ans = self.n2('kl_loss', 'Normal', mean_b, sd_b, mean_a, sd_a)
+        >>>         # Additional mean and sd must be passed in
+        >>>         ans = self.n2.kl_loss('Normal', mean_b, sd_b, mean_a, sd_a)
         >>>
-        >>>         # Sample Usage
-        >>>         ans = self.n1('sample')
-        >>>         ans = self.n1('sample', (2,3))
-        >>>         ans = self.n1('sample', (2,3), mean_b, sd_b)
-        >>>         ans = self.n2('sample', (2,3), mean_a, sd_a)
+        >>>         # Sample
+        >>>         ans = self.n1.sample()
+        >>>         ans = self.n1.sample((2,3))
+        >>>         ans = self.n1.sample((2,3), mean_b, sd_b)
+        >>>         ans = self.n2.sample((2,3), mean_a, sd_a)
     """

     def __init__(self,
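The docstring change above captures the whole refactor: name-string dispatch such as self.n1('prob', value) is replaced by direct method calls such as self.n1.prob(value). A minimal end-to-end sketch of the new pattern, assuming a MindSpore build that ships mindspore.nn.probability.distribution:

import mindspore.nn as nn
import mindspore.nn.probability.distribution as msd
from mindspore import Tensor
from mindspore.common import dtype as mstype

class Net(nn.Cell):
    def __init__(self):
        super(Net, self).__init__()
        # Probability functions are now ordinary methods on the distribution.
        self.n = msd.Normal(3.0, 4.0, dtype=mstype.float32)

    def construct(self, value):
        return self.n.prob(value)

net = Net()
print(net(Tensor([2.0, 3.0, 4.0], mstype.float32)))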
@@ -114,7 +114,7 @@ class Normal(Distribution):
         self.const = P.ScalarToArray()
         self.erf = P.Erf()
         self.exp = P.Exp()
-        self.expm1 = P.Expm1() if get_context('device_target') == 'Ascend' else self._expm1_by_step
+        self.expm1 = self._expm1_by_step
         self.fill = P.Fill()
         self.log = P.Log()
         self.shape = P.Shape()
@@ -135,67 +135,57 @@ class Normal(Distribution):
         """
         return self.exp(x) - 1.0
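_expm1_by_step computes exp(x) - 1.0 directly, which suffers catastrophic cancellation for small x; that is precisely what the fused P.Expm1() kernel avoided, so the substitution above trades accuracy near zero for backend portability. A NumPy illustration of the effect (not part of this module):

import numpy as np

x = np.array([1e-7], dtype=np.float32)
naive = np.exp(x) - np.float32(1.0)  # float32 cancellation, roughly 1.19e-07
fused = np.expm1(x)                  # accurate, roughly 1.0e-07
print(naive, fused)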

-    def _mean(self, name='mean', mean=None, sd=None):
+    def _mean(self, mean=None, sd=None):
         """
         Mean of the distribution.
         """
-        if name == 'mean':
-            mean = self._mean_value if mean is None or sd is None else mean
-            return mean
-        return None
+        mean = self._mean_value if mean is None or sd is None else mean
+        return mean

-    def _mode(self, name='mode', mean=None, sd=None):
+    def _mode(self, mean=None, sd=None):
         """
         Mode of the distribution.
         """
-        if name == 'mode':
-            mean = self._mean_value if mean is None or sd is None else mean
-            return mean
-        return None
+        mean = self._mean_value if mean is None or sd is None else mean
+        return mean

-    def _sd(self, name='sd', mean=None, sd=None):
+    def _sd(self, mean=None, sd=None):
         """
         Standard deviation of the distribution.
         """
-        if name in self._variance_functions:
-            sd = self._sd_value if mean is None or sd is None else sd
-            return sd
-        return None
+        sd = self._sd_value if mean is None or sd is None else sd
+        return sd

-    def _entropy(self, name='entropy', sd=None):
+    def _entropy(self, sd=None):
         r"""
         Evaluate entropy.

         .. math::
             H(X) = \log(\sqrt{2 \pi e \sigma^2})
         """
-        if name == 'entropy':
-            sd = self._sd_value if sd is None else sd
-            return self.log(self.sqrt(np.e * 2. * np.pi * self.sq(sd)))
-        return None
+        sd = self._sd_value if sd is None else sd
+        return self.log(self.sqrt(np.e * 2. * np.pi * self.sq(sd)))
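The closed form used in _entropy can be cross-checked numerically; a small sketch, assuming SciPy is available (it is not a dependency of this module):

import numpy as np
from scipy.stats import norm

sd = 4.0
closed_form = np.log(np.sqrt(2. * np.pi * np.e * sd ** 2))
print(np.isclose(closed_form, norm(loc=3.0, scale=sd).entropy()))  # True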

-    def _cross_entropy(self, name, dist, mean_b, sd_b, mean_a=None, sd_a=None):
+    def _cross_entropy(self, dist, mean_b, sd_b, mean_a=None, sd_a=None):
         r"""
         Evaluate the cross entropy between two Normal distributions.

         Args:
-            name (str): name of the funtion passed in from construct. Should always be "cross_entropy".
             dist (str): type of the distributions. Should be "Normal" in this case.
             mean_b (Tensor): mean of distribution b.
             sd_b (Tensor): standard deviation of distribution b.
             mean_a (Tensor): mean of distribution a. Default: self._mean_value.
             sd_a (Tensor): standard deviation of distribution a. Default: self._sd_value.
         """
-        if name == 'cross_entropy' and dist == 'Normal':
-            return self._entropy(sd=sd_a) + self._kl_loss(name, dist, mean_b, sd_b, mean_a, sd_a)
+        if dist == 'Normal':
+            return self._entropy(sd=sd_a) + self._kl_loss(dist, mean_b, sd_b, mean_a, sd_a)
         return None
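_cross_entropy leans on the identity H(a, b) = H(a) + KL(a || b) rather than evaluating the integral directly. A NumPy sanity check of that identity for two normals (illustrative sketch):

import numpy as np

mean_a, sd_a, mean_b, sd_b = 0.0, 1.0, 1.0, 2.0
entropy_a = np.log(np.sqrt(2. * np.pi * np.e * sd_a ** 2))
kl = (0.5 * ((mean_a - mean_b) / sd_b) ** 2
      + 0.5 * np.expm1(2. * (np.log(sd_a) - np.log(sd_b)))
      - (np.log(sd_a) - np.log(sd_b)))
# Direct closed form of E_a[-log p_b(X)]:
cross = (np.log(np.sqrt(2. * np.pi * sd_b ** 2))
         + (sd_a ** 2 + (mean_a - mean_b) ** 2) / (2. * sd_b ** 2))
print(np.isclose(entropy_a + kl, cross))  # True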

-    def _log_prob(self, name, value, mean=None, sd=None):
+    def _log_prob(self, value, mean=None, sd=None):
         r"""
         Evaluate log probability.

         Args:
-            name (str): name of the funtion passed in from construct.
             value (Tensor): value to be evaluated.
             mean (Tensor): mean of the distribution. Default: self._mean_value.
             sd (Tensor): standard deviation of the distribution. Default: self._sd_value.
@@ -203,20 +193,17 @@ class Normal(Distribution):

         .. math::
             L(x) = -\frac{(x - \mu)^2}{2 \sigma^2} - \log(\sqrt{2 \pi \sigma^2})
         """
-        if name in self._prob_functions:
-            mean = self._mean_value if mean is None else mean
-            sd = self._sd_value if sd is None else sd
-            unnormalized_log_prob = -1. * (self.sq(value - mean)) / (2. * self.sq(sd))
-            neg_normalization = -1. * self.log(self.sqrt(2. * np.pi * self.sq(sd)))
-            return unnormalized_log_prob + neg_normalization
-        return None
+        mean = self._mean_value if mean is None else mean
+        sd = self._sd_value if sd is None else sd
+        unnormalized_log_prob = -1. * (self.sq(value - mean)) / (2. * self.sq(sd))
+        neg_normalization = -1. * self.log(self.sqrt(2. * np.pi * self.sq(sd)))
+        return unnormalized_log_prob + neg_normalization
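The two terms above are the standard normal log-density split into its unnormalized part and normalization constant; a quick check against SciPy (illustrative, assumes SciPy):

import numpy as np
from scipy.stats import norm

value, mean, sd = 1.5, 3.0, 4.0
log_prob = -(value - mean) ** 2 / (2. * sd ** 2) - np.log(np.sqrt(2. * np.pi * sd ** 2))
print(np.isclose(log_prob, norm(loc=mean, scale=sd).logpdf(value)))  # True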

-    def _cdf(self, name, value, mean=None, sd=None):
+    def _cdf(self, value, mean=None, sd=None):
         r"""
         Evaluate the cdf of a given value.

         Args:
-            name (str): name of the funtion passed in from construct. Should always be "cdf".
             value (Tensor): value to be evaluated.
             mean (Tensor): mean of the distribution. Default: self._mean_value.
             sd (Tensor): standard deviation of the distribution. Default: self._sd_value.
@@ -224,20 +211,17 @@ class Normal(Distribution):

         .. math::
             cdf(x) = 0.5 * (1 + erf(\frac{x - \mu}{\sigma \sqrt{2}}))
         """
-        if name in self._cdf_survival_functions:
-            mean = self._mean_value if mean is None else mean
-            sd = self._sd_value if sd is None else sd
-            sqrt2 = self.sqrt(self.const(2.0))
-            adjusted = (value - mean) / (sd * sqrt2)
-            return 0.5 * (1.0 + self.erf(adjusted))
-        return None
+        mean = self._mean_value if mean is None else mean
+        sd = self._sd_value if sd is None else sd
+        sqrt2 = self.sqrt(self.const(2.0))
+        adjusted = (value - mean) / (sd * sqrt2)
+        return 0.5 * (1.0 + self.erf(adjusted))
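_cdf expresses the cdf through the error function, for which the module wires up an elementwise P.Erf() kernel in __init__. The same formula in NumPy/SciPy terms (illustrative):

import numpy as np
from scipy.special import erf
from scipy.stats import norm

value, mean, sd = 1.5, 3.0, 4.0
cdf = 0.5 * (1. + erf((value - mean) / (sd * np.sqrt(2.))))
print(np.isclose(cdf, norm(loc=mean, scale=sd).cdf(value)))  # True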

-    def _kl_loss(self, name, dist, mean_b, sd_b, mean_a=None, sd_a=None):
+    def _kl_loss(self, dist, mean_b, sd_b, mean_a=None, sd_a=None):
         r"""
         Evaluate Normal-Normal KL divergence, i.e. KL(a||b).

         Args:
-            name (str): name of the funtion passed in from construct.
             dist (str): type of the distributions. Should be "Normal" in this case.
             mean_b (Tensor): mean of distribution b.
             sd_b (Tensor): standard deviation of distribution b.
@@ -248,7 +232,7 @@ class Normal(Distribution):
             KL(a||b) = 0.5 * (\frac{MEAN(a) - MEAN(b)}{STD(b)})^2 +
                        0.5 * EXPM1(2 * (\log(STD(a)) - \log(STD(b)))) - (\log(STD(a)) - \log(STD(b)))
         """
-        if name in self._divergence_functions and dist == 'Normal':
+        if dist == 'Normal':
             mean_a = self._mean_value if mean_a is None else mean_a
             sd_a = self._sd_value if sd_a is None else sd_a
             diff_log_scale = self.log(sd_a) - self.log(sd_b)
@@ -256,12 +240,11 @@ class Normal(Distribution):
             return 0.5 * squared_diff + 0.5 * self.expm1(2 * diff_log_scale) - diff_log_scale
         return None
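The expm1 form returned above is algebraically equal to the textbook expression KL(a||b) = log(sd_b/sd_a) + (sd_a^2 + (mean_a - mean_b)^2) / (2 * sd_b^2) - 1/2, and avoids cancellation when sd_a is close to sd_b. A NumPy sanity check (illustrative):

import numpy as np

mean_a, sd_a, mean_b, sd_b = 3.0, 4.0, 1.0, 1.0
diff_log_scale = np.log(sd_a) - np.log(sd_b)
squared_diff = np.square((mean_a - mean_b) / sd_b)
as_implemented = 0.5 * squared_diff + 0.5 * np.expm1(2. * diff_log_scale) - diff_log_scale
textbook = np.log(sd_b / sd_a) + (sd_a ** 2 + (mean_a - mean_b) ** 2) / (2. * sd_b ** 2) - 0.5
print(np.isclose(as_implemented, textbook))  # True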

-    def _sample(self, name, shape=(), mean=None, sd=None):
+    def _sample(self, shape=(), mean=None, sd=None):
         """
         Sampling.

         Args:
-            name (str): name of the function. Should always be 'sample' when passed in from construct.
             shape (tuple): shape of the sample. Default: ().
             mean (Tensor): mean of the samples. Default: self._mean_value.
             sd (Tensor): standard deviation of the samples. Default: self._sd_value.
@@ -269,7 +252,6 @@ class Normal(Distribution):
         Returns:
             Tensor, shape is shape + batch_shape.
         """
-        if name == 'sample':
-            mean = self._mean_value if mean is None else mean
-            sd = self._sd_value if sd is None else sd
-            batch_shape = self.shape(self.zeroslike(mean) + self.zeroslike(sd))
+        mean = self._mean_value if mean is None else mean
+        sd = self._sd_value if sd is None else sd
+        batch_shape = self.shape(self.zeroslike(mean) + self.zeroslike(sd))
@@ -279,4 +261,3 @@ class Normal(Distribution):
-            sample_norm = C.normal(sample_shape, mean_zero, sd_one, self.seed)
-            sample = mean + sample_norm * sd
-            return sample
-        return None
+        sample_norm = C.normal(sample_shape, mean_zero, sd_one, self.seed)
+        sample = mean + sample_norm * sd
+        return sample
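_sample draws standard normal noise with C.normal and then applies the affine transform mean + sample_norm * sd (the reparameterization trick), which keeps mean and sd on a differentiable path. A NumPy equivalent of the transform (illustrative):

import numpy as np

rng = np.random.default_rng(0)
mean, sd, shape = 3.0, 4.0, (2, 3)
sample = mean + sd * rng.standard_normal(shape)  # same shift-and-scale as above
print(sample.shape)  # (2, 3)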