|
|
|
@ -18,8 +18,8 @@ from mindspore.ops import operations as P
|
|
|
|
|
from mindspore.ops import composite as C
|
|
|
|
|
from mindspore.common import dtype as mstype
|
|
|
|
|
from .distribution import Distribution
|
|
|
|
|
from ._utils.utils import convert_to_batch, check_greater_zero, check_type
|
|
|
|
|
|
|
|
|
|
from ._utils.utils import convert_to_batch, check_greater_zero, check_type, check_distribution_name,\
|
|
|
|
|
raise_none_error
|
|
|
|
|
|
|
|
|
|
class Normal(Distribution):
|
|
|
|
|
"""
|
|
|
|
@ -103,9 +103,10 @@ class Normal(Distribution):
|
|
|
|
|
valid_dtype = mstype.float_type
|
|
|
|
|
check_type(dtype, valid_dtype, "Normal")
|
|
|
|
|
super(Normal, self).__init__(seed, dtype, name, param)
|
|
|
|
|
self.parameter_type = dtype
|
|
|
|
|
if mean is not None and sd is not None:
|
|
|
|
|
self._mean_value = convert_to_batch(mean, self.broadcast_shape, dtype)
|
|
|
|
|
self._sd_value = convert_to_batch(sd, self.broadcast_shape, dtype)
|
|
|
|
|
self._mean_value = convert_to_batch(mean, self.broadcast_shape, self.parameter_type)
|
|
|
|
|
self._sd_value = convert_to_batch(sd, self.broadcast_shape, self.parameter_type)
|
|
|
|
|
check_greater_zero(self._sd_value, "Standard deviation")
|
|
|
|
|
else:
|
|
|
|
|
self._mean_value = mean
|
|
|
|
@ -113,6 +114,7 @@ class Normal(Distribution):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#ops needed for the class
|
|
|
|
|
self.cast = P.Cast()
|
|
|
|
|
self.const = P.ScalarToArray()
|
|
|
|
|
self.erf = P.Erf()
|
|
|
|
|
self.exp = P.Exp()
|
|
|
|
@ -141,31 +143,51 @@ class Normal(Distribution):
|
|
|
|
|
"""
|
|
|
|
|
Mean of the distribution.
|
|
|
|
|
"""
|
|
|
|
|
mean = self._mean_value if mean is None or sd is None else mean
|
|
|
|
|
mean = self.cast(mean, self.parameter_type) if mean is not None else self._mean_value
|
|
|
|
|
if mean is None:
|
|
|
|
|
raise_none_error("mean")
|
|
|
|
|
sd = self.cast(sd, self.parameter_type) if sd is not None else self._sd_value
|
|
|
|
|
if sd is None:
|
|
|
|
|
raise_none_error("sd")
|
|
|
|
|
return mean
|
|
|
|
|
|
|
|
|
|
def _mode(self, mean=None, sd=None):
|
|
|
|
|
"""
|
|
|
|
|
Mode of the distribution.
|
|
|
|
|
"""
|
|
|
|
|
mean = self._mean_value if mean is None or sd is None else mean
|
|
|
|
|
mean = self.cast(mean, self.parameter_type) if mean is not None else self._mean_value
|
|
|
|
|
if mean is None:
|
|
|
|
|
raise_none_error("mean")
|
|
|
|
|
sd = self.cast(sd, self.parameter_type) if sd is not None else self._sd_value
|
|
|
|
|
if sd is None:
|
|
|
|
|
raise_none_error("sd")
|
|
|
|
|
return mean
|
|
|
|
|
|
|
|
|
|
def _sd(self, mean=None, sd=None):
|
|
|
|
|
"""
|
|
|
|
|
Standard deviation of the distribution.
|
|
|
|
|
"""
|
|
|
|
|
sd = self._sd_value if mean is None or sd is None else sd
|
|
|
|
|
mean = self.cast(mean, self.parameter_type) if mean is not None else self._mean_value
|
|
|
|
|
if mean is None:
|
|
|
|
|
raise_none_error("mean")
|
|
|
|
|
sd = self.cast(sd, self.parameter_type) if sd is not None else self._sd_value
|
|
|
|
|
if sd is None:
|
|
|
|
|
raise_none_error("sd")
|
|
|
|
|
return sd
|
|
|
|
|
|
|
|
|
|
def _entropy(self, sd=None):
|
|
|
|
|
def _entropy(self, mean=None, sd=None):
|
|
|
|
|
r"""
|
|
|
|
|
Evaluate entropy.
|
|
|
|
|
|
|
|
|
|
.. math::
|
|
|
|
|
H(X) = \log(\sqrt(numpy.e * 2. * numpy.pi * \sq(\sigma)))
|
|
|
|
|
"""
|
|
|
|
|
sd = self._sd_value if sd is None else sd
|
|
|
|
|
mean = self.cast(mean, self.parameter_type) if mean is not None else self._mean_value
|
|
|
|
|
if mean is None:
|
|
|
|
|
raise_none_error("mean")
|
|
|
|
|
sd = self.cast(sd, self.parameter_type) if sd is not None else self._sd_value
|
|
|
|
|
if sd is None:
|
|
|
|
|
raise_none_error("sd")
|
|
|
|
|
return self.log(self.sqrt(self.const(np.e * 2. * np.pi))) + self.log(sd)
|
|
|
|
|
|
|
|
|
|
def _cross_entropy(self, dist, mean_b, sd_b, mean_a=None, sd_a=None):
|
|
|
|
@ -179,9 +201,8 @@ class Normal(Distribution):
|
|
|
|
|
mean_a (Tensor): mean of distribution a. Default: self._mean_value.
|
|
|
|
|
sd_a (Tensor): standard deviation distribution a. Default: self._sd_value.
|
|
|
|
|
"""
|
|
|
|
|
if dist == 'Normal':
|
|
|
|
|
return self._entropy(sd=sd_a) + self._kl_loss(dist, mean_b, sd_b, mean_a, sd_a)
|
|
|
|
|
return None
|
|
|
|
|
check_distribution_name(dist, 'Normal')
|
|
|
|
|
return self._entropy(mean=mean_a, sd=sd_a) + self._kl_loss(dist, mean_b, sd_b, mean_a, sd_a)
|
|
|
|
|
|
|
|
|
|
def _log_prob(self, value, mean=None, sd=None):
|
|
|
|
|
r"""
|
|
|
|
@ -195,10 +216,17 @@ class Normal(Distribution):
|
|
|
|
|
.. math::
|
|
|
|
|
L(x) = -1* \frac{(x - \mu)^2}{2. * \sigma^2} - \log(\sqrt(2* \pi * \sigma^2))
|
|
|
|
|
"""
|
|
|
|
|
mean = self._mean_value if mean is None else mean
|
|
|
|
|
sd = self._sd_value if sd is None else sd
|
|
|
|
|
if value is None:
|
|
|
|
|
raise_none_error("value")
|
|
|
|
|
value = self.cast(value, self.dtype)
|
|
|
|
|
mean = self.cast(mean, self.parameter_type) if mean is not None else self._mean_value
|
|
|
|
|
if mean is None:
|
|
|
|
|
raise_none_error("mean")
|
|
|
|
|
sd = self.cast(sd, self.parameter_type) if sd is not None else self._sd_value
|
|
|
|
|
if sd is None:
|
|
|
|
|
raise_none_error("sd")
|
|
|
|
|
unnormalized_log_prob = -1. * (self.sq(value - mean)) / (2. * self.sq(sd))
|
|
|
|
|
neg_normalization = -1. * self.log(self.sqrt(self.const(2. * np.pi))) - self.log(sd)
|
|
|
|
|
neg_normalization = -1. * self.log(self.const(2. * np.pi)) / 2. - self.log(sd)
|
|
|
|
|
return unnormalized_log_prob + neg_normalization
|
|
|
|
|
|
|
|
|
|
def _cdf(self, value, mean=None, sd=None):
|
|
|
|
@ -213,8 +241,15 @@ class Normal(Distribution):
|
|
|
|
|
.. math::
|
|
|
|
|
cdf(x) = 0.5 * (1+ Erf((x - \mu) / ( \sigma * \sqrt(2))))
|
|
|
|
|
"""
|
|
|
|
|
mean = self._mean_value if mean is None else mean
|
|
|
|
|
sd = self._sd_value if sd is None else sd
|
|
|
|
|
if value is None:
|
|
|
|
|
raise_none_error("value")
|
|
|
|
|
value = self.cast(value, self.dtype)
|
|
|
|
|
mean = self.cast(mean, self.parameter_type) if mean is not None else self._mean_value
|
|
|
|
|
if mean is None:
|
|
|
|
|
raise_none_error("mean")
|
|
|
|
|
sd = self.cast(sd, self.parameter_type) if sd is not None else self._sd_value
|
|
|
|
|
if sd is None:
|
|
|
|
|
raise_none_error("sd")
|
|
|
|
|
sqrt2 = self.sqrt(self.const(2.0))
|
|
|
|
|
adjusted = (value - mean) / (sd * sqrt2)
|
|
|
|
|
return 0.5 * (1.0 + self.erf(adjusted))
|
|
|
|
@ -234,13 +269,23 @@ class Normal(Distribution):
|
|
|
|
|
KL(a||b) = 0.5 * (\frac{MEAN(a)}{STD(b)} - \frac{MEAN(b)}{STD(b)}) ^ 2 +
|
|
|
|
|
0.5 * EXPM1(2 * (\log(STD(a)) - \log(STD(b))) - (\log(STD(a)) - \log(STD(b)))
|
|
|
|
|
"""
|
|
|
|
|
if dist == 'Normal':
|
|
|
|
|
mean_a = self._mean_value if mean_a is None else mean_a
|
|
|
|
|
sd_a = self._sd_value if sd_a is None else sd_a
|
|
|
|
|
diff_log_scale = self.log(sd_a) - self.log(sd_b)
|
|
|
|
|
squared_diff = self.sq(mean_a / sd_b - mean_b / sd_b)
|
|
|
|
|
return 0.5 * squared_diff + 0.5 * self.expm1(2 * diff_log_scale) - diff_log_scale
|
|
|
|
|
return None
|
|
|
|
|
check_distribution_name(dist, 'Normal')
|
|
|
|
|
if mean_b is None:
|
|
|
|
|
raise_none_error("mean_b")
|
|
|
|
|
if sd_b is None:
|
|
|
|
|
raise_none_error("sd_b")
|
|
|
|
|
mean_b = self.cast(mean_b, self.parameter_type)
|
|
|
|
|
sd_b = self.cast(sd_b, self.parameter_type)
|
|
|
|
|
mean_a = self.cast(mean_a, self.parameter_type) if mean_a is not None else self._mean_value
|
|
|
|
|
sd_a = self.cast(sd_a, self.parameter_type) if sd_a is not None else self._sd_value
|
|
|
|
|
if mean_a is None:
|
|
|
|
|
raise_none_error("mean_a")
|
|
|
|
|
if sd_a is None:
|
|
|
|
|
raise_none_error("sd_a")
|
|
|
|
|
diff_log_scale = self.log(sd_a) - self.log(sd_b)
|
|
|
|
|
squared_diff = self.sq(mean_a / sd_b - mean_b / sd_b)
|
|
|
|
|
return 0.5 * squared_diff + 0.5 * self.expm1(2 * diff_log_scale) - diff_log_scale
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _sample(self, shape=(), mean=None, sd=None):
|
|
|
|
|
"""
|
|
|
|
@ -254,8 +299,12 @@ class Normal(Distribution):
|
|
|
|
|
Returns:
|
|
|
|
|
Tensor, shape is shape + batch_shape.
|
|
|
|
|
"""
|
|
|
|
|
mean = self._mean_value if mean is None else mean
|
|
|
|
|
sd = self._sd_value if sd is None else sd
|
|
|
|
|
mean = self.cast(mean, self.parameter_type) if mean is not None else self._mean_value
|
|
|
|
|
if mean is None:
|
|
|
|
|
raise_none_error("mean")
|
|
|
|
|
sd = self.cast(sd, self.parameter_type) if sd is not None else self._sd_value
|
|
|
|
|
if sd is None:
|
|
|
|
|
raise_none_error("sd")
|
|
|
|
|
batch_shape = self.shape(self.zeroslike(mean) + self.zeroslike(sd))
|
|
|
|
|
sample_shape = shape + batch_shape
|
|
|
|
|
sample_norm = C.normal(sample_shape, mean, sd, self.seed)
|
|
|
|
|