Added lognormal distribution

pull/6746/head
peixu_ren 5 years ago
parent 3eff68f8aa
commit 23ff21edd8

@@ -24,6 +24,7 @@ from .exponential import Exponential
from .uniform import Uniform
from .geometric import Geometric
from .categorical import Categorical
+from .log_normal import LogNormal
__all__ = ['Distribution',
'TransformedDistribution',
@@ -32,4 +33,6 @@ __all__ = ['Distribution',
'Exponential',
'Uniform',
'Categorical',
-'Geometric',]
+'Geometric',
+'LogNormal',
+]
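With the export in place, the new class is reachable from the package namespace. A minimal usage sketch, assuming a MindSpore build that includes this patch (not part of the commit itself):

# Hedged sketch: requires a MindSpore build containing this patch.
import mindspore.nn.probability.distribution as msd
from mindspore import dtype as mstype

ln = msd.LogNormal(3.0, 4.0, dtype=mstype.float32)  # loc=3.0, scale=4.0
samples = ln.sample((2, 3))  # draws a (2, 3)-shaped Tensor of samples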

@@ -76,7 +76,10 @@ class Distribution(Cell):
self._parameters[k] = param[k]
# some attributes
-        self.parameter_type = set_param_type(self.parameters['param_dict'], dtype)
+        if 'distribution' in self.parameters.keys():
+            self.parameter_type = self.parameters['distribution'].parameter_type
+        else:
+            self.parameter_type = set_param_type(self.parameters['param_dict'], dtype)
self._broadcast_shape = self._calc_broadcast_shape()
self._is_scalar_batch = self._check_is_scalar_batch()
@@ -206,8 +209,8 @@ class Distribution(Cell):
"""
Check if the parameters used during initialization are scalars.
"""
-        if hasattr(self, 'distribution'):
-            return self._distribution.is_scalar_batch
+        if 'distribution' in self.parameters.keys():
+            return self.parameters['distribution'].is_scalar_batch
param_dict = self.parameters['param_dict']
for value in param_dict.values():
if value is None:
@@ -220,8 +223,8 @@ class Distribution(Cell):
"""
Calculate the broadcast shape of the parameters used during initialization.
"""
-        if hasattr(self, 'distribution'):
-            return self._distribution.broadcast_shape
+        if 'distribution' in self.parameters.keys():
+            return self.parameters['distribution'].broadcast_shape
param_dict = self.parameters['param_dict']
broadcast_shape_tensor = None
for value in param_dict.values():
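The change in all three hunks is the same: when the parameters contain a wrapped 'distribution' (as they do for TransformedDistribution), the metadata is delegated to that inner instance instead of being recomputed from param_dict, and the presence check goes through self.parameters rather than hasattr. A minimal pure-Python sketch of the delegation pattern (illustrative names, not MindSpore code):

# Illustrative sketch of delegating metadata to a wrapped distribution.
class Dist:
    def __init__(self, **params):
        self.parameters = params
        if 'distribution' in self.parameters:
            # Wrapper case (e.g. a transformed distribution): delegate to the inner instance.
            self.is_scalar_batch = self.parameters['distribution'].is_scalar_batch
        else:
            # Leaf case: scalar batch iff every parameter is a plain scalar.
            self.is_scalar_batch = all(
                not hasattr(v, '__len__') for v in self.parameters.values())

leaf = Dist(loc=0.0, scale=1.0)
wrapped = Dist(distribution=leaf)
assert wrapped.is_scalar_batch and leaf.is_scalar_batch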

@@ -0,0 +1,235 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""LogNormal Distribution"""
import numpy as np
from mindspore.ops import operations as P
from mindspore.common import dtype as mstype
import mindspore.nn.probability.bijector as msb
import mindspore.nn.probability.distribution as msd
from ._utils.utils import check_distribution_name
from ._utils.custom_ops import exp_generic, expm1_generic, log_generic
class LogNormal(msd.TransformedDistribution):
"""
LogNormal distribution.
A log-normal (or lognormal) distribution is a continuous probability distribution of a random variable whose
logarithm is normally distributed. It is constructed as the exponential transformation of a Normal distribution.
Args:
loc (int, float, list, numpy.ndarray, Tensor, Parameter): The mean of the underlying Normal distribution.
scale (int, float, list, numpy.ndarray, Tensor, Parameter): The standard deviation of the underlying
Normal distribution.
seed (int): the seed used in sampling. The global seed is used if it is None. Default: 0.
dtype (mindspore.dtype): type of the distribution. Default: mstype.float32.
name (str): the name of the distribution. Default: 'LogNormal'.
Note:
`scale` must be greater than zero.
`dist_spec_args` are `loc` and `scale`.
`dtype` must be a float type because LogNormal distributions are continuous.
Examples:
>>> # To initialize a LogNormal distribution of `loc` 3.0 and `scale` 4.0.
>>> n = msd.LogNormal(3.0, 4.0, dtype=mstype.float32)
>>>
>>> # The following creates two independent LogNormal distributions.
>>> n = msd.LogNormal([3.0, 3.0], [4.0, 4.0], dtype=mstype.float32)
>>>
>>> # A LogNormal distribution can be initialized without arguments.
>>> # In this case, `loc` and `scale` must be passed in during function calls.
>>> n = msd.LogNormal(dtype=mstype.float32)
>>>
>>> # To use a LogNormal distribution in a network.
>>> class net(Cell):
>>> def __init__(self):
>>> super(net, self).__init__()
>>> self.n1 = msd.LogNormal(0.0, 1.0, dtype=mstype.float32)
>>> self.n2 = msd.LogNormal(dtype=mstype.float32)
>>>
>>> # The following calls are valid in construct.
>>> def construct(self, value, loc_b, scale_b, loc_a, scale_a):
>>>
>>> # Private interfaces of probability functions corresponding to public interfaces, including
>>> # `prob`, `log_prob`, `cdf`, `log_cdf`, `survival_function`, and `log_survival`, have the same
>>> # arguments as follows.
>>> # Args:
>>> # value (Tensor): the value to be evaluated.
>>> # loc (Tensor): the loc of distribution. Default: None. If `loc` is passed in as None,
>>> # the mean of the underlying Normal distribution will be used.
>>> # scale (Tensor): the scale of distribution. Default: None. If `scale` is passed in as None,
>>> # the standard deviation of the underlying Normal distribution will be used.
>>>
>>> # Examples of `prob`.
>>> # Similar calls can be made to other probability functions
>>> # by replacing 'prob' by the name of the function.
>>> ans = self.n1.prob(value)
>>> # Evaluate with respect to distribution b.
>>> ans = self.n1.prob(value, loc_b, scale_b)
>>> # `loc` and `scale` must be passed in during function calls since they were not passed in during initialization.
>>> ans = self.n2.prob(value, loc_a, scale_a)
>>>
>>>
>>> # Functions `mean`, `sd`, `var`, and `entropy` have the same arguments.
>>> # Args:
>>> # loc (Tensor): the loc of distribution. Default: None. If `loc` is passed in as None,
>>> # the mean of the underlying Normal distribution will be used.
>>> # scale (Tensor): the scale of distribution. Default: None. If `scale` is passed in as None,
>>> # the standard deviation of the underlying Normal distribution will be used.
>>>
>>> # Example of `mean`. `sd`, `var`, and `entropy` are similar.
>>> ans = self.n1.mean() # return exp(0.5), the mean of LogNormal(0.0, 1.0)
>>> ans = self.n1.mean(loc_b, scale_b) # return the mean evaluated with loc_b and scale_b
>>> # `loc` and `scale` must be passed in during function calls since they were not passed in during initialization.
>>> ans = self.n2.mean(loc_a, scale_a)
>>>
>>>
>>> # Interfaces of 'kl_loss' and 'cross_entropy' are the same:
>>> # Args:
>>> # dist (str): the type of the distributions. Only "LogNormal" is supported.
>>> # loc_b (Tensor): the loc of distribution b.
>>> # scale_b (Tensor): the scale of distribution b.
>>> # loc_a (Tensor): the loc of distribution a. Default: None. If `loc` is passed in as None,
>>> # the mean of the underlying Normal distribution will be used.
>>> # scale_a (Tensor): the scale of distribution a. Default: None. If `scale` is passed in as None,
>>> # the standard deviation of the underlying Normal distribution will be used.
>>>
>>> # Examples of `kl_loss`. `cross_entropy` is similar.
>>> ans = self.n1.kl_loss('LogNormal', loc_b, scale_b)
>>> ans = self.n1.kl_loss('LogNormal', loc_b, scale_b, loc_a, scale_a)
>>> # Additional `loc` and `scale` must be passed in since they were not passed in during initialization.
>>> ans = self.n2.kl_loss('LogNormal', loc_b, scale_b, loc_a, scale_a)
>>>
>>> # Examples of `sample`.
>>> # Args:
>>> # shape (tuple): the shape of the sample. Default: ()
>>> # loc (Tensor): the loc of the distribution. Default: None. If `loc` is passed in as None,
>>> # the mean of the underlying Normal distribution will be used.
>>> # scale (Tensor): the scale of the distribution. Default: None. If `scale` is passed in as None,
>>> # the standard deviation of the underlying Normal distribution will be used.
>>> ans = self.n1.sample()
>>> ans = self.n1.sample((2,3))
>>> ans = self.n1.sample((2,3), loc_b, scale_b)
>>> ans = self.n2.sample((2,3), loc_a, scale_a)
"""
def __init__(self,
loc=None,
scale=None,
seed=0,
dtype=mstype.float32,
name="LogNormal"):
"""
Constructor of LogNormal distribution.
"""
super(LogNormal, self).__init__(distribution=msd.Normal(loc, scale, dtype=dtype),
bijector=msb.Exp(),
dtype=dtype, seed=seed, name=name)
self.log_2pi = np.log(2 * np.pi)
# ops needed for the class
self.exp = exp_generic
self.expm1 = expm1_generic
self.log = log_generic
self.const = P.ScalarToArray()
self.erf = P.Erf()
self.fill = P.Fill()
self.shape = P.Shape()
self.sq = P.Square()
self.sqrt = P.Sqrt()
self.zeroslike = P.ZerosLike()
@property
def loc(self):
"""Distribution parameter for the pre-transformed mean."""
return self.distribution("mean")
@property
def scale(self):
"""Distribution parameter for the pre-transformed standard deviation."""
return self.distribution("sd")
def extend_repr(self):
if self.is_scalar_batch:
str_info = f'loc = {self._mean_value}, scale = {self._sd_value}'
else:
str_info = f'batch_shape = {self._broadcast_shape}'
return str_info
def _mean(self, loc=None, scale=None):
"""
The mean of the distribution.
"""
mean, sd = self._check_param_type(loc, scale)
var = self.distribution("var", mean=mean, sd=sd)
return self.exp(mean + 0.5 * var)
def _mode(self, loc=None, scale=None):
"""
The mode of the distribution.
"""
mean, sd = self._check_param_type(loc, scale)
var = self.distribution("var", mean=mean, sd=sd)
return self.exp(mean - var)
def _var(self, loc=None, scale=None):
"""
The variance of the distribution.
"""
mean, sd = self._check_param_type(loc, scale)
var = self.distribution("var", mean=mean, sd=sd)
return self.expm1(var) * self.exp(2. * mean + var)
def _entropy(self, loc=None, scale=None):
r"""
Evaluate entropy.
.. math::
H(X) = \mu + 0.5 + \log(\sigma) + 0.5 * \log(2\pi)
"""
mean, sd = self._check_param_type(loc, scale)
return mean + 0.5 + self.log(sd) + 0.5 * self.log_2pi
def _cross_entropy(self, dist, loc_b, scale_b, loc_a=None, scale_a=None):
r"""
Evaluate cross entropy between lognormal distributions.
Args:
dist (str): The type of the distributions. Should be "LogNormal" in this case.
loc_b (Tensor): The loc of distribution b.
scale_b (Tensor): The scale of distribution b.
loc_a (Tensor): The loc of distribution a. Default: None.
scale_a (Tensor): The scale of distribution a. Default: None.
"""
check_distribution_name(dist, 'LogNormal')
return self._entropy(loc_a, scale_a) + self._kl_loss(dist, loc_b, scale_b, loc_a, scale_a)
def _kl_loss(self, dist, loc_b, scale_b, loc_a=None, scale_a=None):
r"""
Evaluate LogNormal-LogNormal KL divergence, i.e., KL(a||b).
Args:
dist (str): The type of the distributions. Should be "LogNormal" in this case.
loc_b (Tensor): The loc of distribution b.
scale_b (Tensor): The scale of distribution b.
loc_a (Tensor): The loc of distribution a. Default: None.
scale_a (Tensor): The scale of distribution a. Default: None.
.. math::
KL(a||b) = 0.5 * (\frac{MEAN(a)}{STD(b)} - \frac{MEAN(b)}{STD(b)})^2 +
           0.5 * EXPM1(2 * (\log(STD(a)) - \log(STD(b)))) - (\log(STD(a)) - \log(STD(b)))
"""
check_distribution_name(dist, 'LogNormal')
return self.distribution("kl_loss", 'Normal', loc_b, scale_b, loc_a, scale_a)
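The closed forms above (mean = exp(μ + σ²/2), mode = exp(μ − σ²), var = expm1(σ²)·exp(2μ + σ²), the entropy, and the Normal-equivalent KL) can be spot-checked numerically. A hedged sketch using NumPy and SciPy (assumed available; independent of MindSpore) — scipy.stats.lognorm with s=σ and scale=exp(μ) is the same parameterization as LogNormal(loc=μ, scale=σ):

import numpy as np
from scipy.integrate import quad
from scipy.stats import lognorm

mu, sigma = 0.3, 0.7
d = lognorm(s=sigma, scale=np.exp(mu))  # LogNormal(loc=mu, scale=sigma)

assert np.isclose(d.mean(), np.exp(mu + 0.5 * sigma ** 2))                       # _mean
assert np.isclose(d.var(), np.expm1(sigma ** 2) * np.exp(2 * mu + sigma ** 2))   # _var
assert np.isclose(d.entropy(), mu + 0.5 + np.log(sigma) + 0.5 * np.log(2 * np.pi))  # _entropy
grid = np.linspace(0.01, 5, 100001)
assert np.isclose(grid[np.argmax(d.pdf(grid))], np.exp(mu - sigma ** 2), atol=1e-3)  # _mode

# _kl_loss: KL between two LogNormals equals KL between the underlying Normals.
mu_a, s_a, mu_b, s_b = 0.3, 0.7, -0.1, 1.2
delta = np.log(s_a) - np.log(s_b)
kl = 0.5 * ((mu_a - mu_b) / s_b) ** 2 + 0.5 * np.expm1(2 * delta) - delta
a = lognorm(s=s_a, scale=np.exp(mu_a))
b = lognorm(s=s_b, scale=np.exp(mu_b))
kl_num, _ = quad(lambda x: a.pdf(x) * (a.logpdf(x) - b.logpdf(x)), 1e-9, np.inf)
assert np.isclose(kl, kl_num, atol=1e-5)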

@@ -30,6 +30,8 @@ class TransformedDistribution(Distribution):
Args:
bijector (Bijector): The transformation to perform.
distribution (Distribution): The original distribution.
+        dtype (mindspore.dtype): The type of the event samples.
+        seed (int): The seed used in sampling. The global seed is used if it is None.
name (str): The name of the transformed distribution. Default: 'transformed_distribution'.
Note:
@@ -98,38 +100,38 @@ class TransformedDistribution(Distribution):
def is_linear_transformation(self):
return self._is_linear_transformation
-    def _cdf(self, *args, **kwargs):
+    def _cdf(self, value, *args, **kwargs):
r"""
.. math::
Y = g(X)
P(Y <= a) = P(X <= g^{-1}(a))
"""
-        inverse_value = self.bijector("inverse", *args, **kwargs)
-        return self.distribution("cdf", inverse_value)
+        inverse_value = self.bijector("inverse", value)
+        return self.distribution("cdf", inverse_value, *args, **kwargs)
-    def _log_cdf(self, *args, **kwargs):
-        return self.log(self._cdf(*args, **kwargs))
+    def _log_cdf(self, value, *args, **kwargs):
+        return self.log(self._cdf(value, *args, **kwargs))
-    def _survival_function(self, *args, **kwargs):
-        return 1.0 - self._cdf(*args, **kwargs)
+    def _survival_function(self, value, *args, **kwargs):
+        return 1.0 - self._cdf(value, *args, **kwargs)
-    def _log_survival(self, *args, **kwargs):
-        return self.log(self._survival_function(*args, **kwargs))
+    def _log_survival(self, value, *args, **kwargs):
+        return self.log(self._survival_function(value, *args, **kwargs))
-    def _log_prob(self, *args, **kwargs):
+    def _log_prob(self, value, *args, **kwargs):
r"""
.. math::
Y = g(X)
Py(a) = Px(g^{-1}(a)) * (g^{-1})'(a)
\log(Py(a)) = \log(Px(g^{-1}(a))) + \log((g^{-1})'(a))
"""
-        inverse_value = self.bijector("inverse", *args, **kwargs)
-        unadjust_prob = self.distribution("log_prob", inverse_value)
-        log_jacobian = self.bijector("inverse_log_jacobian", *args, **kwargs)
+        inverse_value = self.bijector("inverse", value)
+        unadjust_prob = self.distribution("log_prob", inverse_value, *args, **kwargs)
+        log_jacobian = self.bijector("inverse_log_jacobian", value)
return unadjust_prob + log_jacobian
-    def _prob(self, *args, **kwargs):
-        return self.exp(self._log_prob(*args, **kwargs))
+    def _prob(self, value, *args, **kwargs):
+        return self.exp(self._log_prob(value, *args, **kwargs))
def _sample(self, *args, **kwargs):
org_sample = self.distribution("sample", *args, **kwargs)
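The reworked methods separate the evaluated `value` (which goes through the bijector) from the distribution parameters in *args/**kwargs (which are forwarded to the underlying distribution). For the Exp bijector used by LogNormal, `inverse` is log(y) and `inverse_log_jacobian` is -log(y), so `_log_prob` reduces to the usual change-of-variables identity. A hedged NumPy/SciPy check of that identity (independent of MindSpore):

import numpy as np
from scipy.stats import lognorm, norm

mu, sigma = 0.5, 1.3
y = np.array([0.2, 1.0, 4.5])

# log p_Y(y) = log p_X(g^{-1}(y)) + log|(g^{-1})'(y)| with g = exp:
log_prob = norm(mu, sigma).logpdf(np.log(y)) - np.log(y)
assert np.allclose(log_prob, lognorm(s=sigma, scale=np.exp(mu)).logpdf(y))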

File diff suppressed because it is too large.

@@ -0,0 +1,216 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Test nn.probability.distribution.LogNormal.
"""
import numpy as np
import pytest
import mindspore.nn as nn
import mindspore.nn.probability.distribution as msd
from mindspore import dtype
from mindspore import Tensor
def test_lognormal_shape_error():
"""
Invalid shapes.
"""
with pytest.raises(ValueError):
msd.LogNormal([[2.], [1.]], [[2.], [3.], [4.]], dtype=dtype.float32)
def test_type():
with pytest.raises(TypeError):
msd.LogNormal(0., 1., dtype=dtype.int32)
def test_name():
with pytest.raises(TypeError):
msd.LogNormal(0., 1., name=1.0)
def test_seed():
with pytest.raises(TypeError):
msd.LogNormal(0., 1., seed='seed')
def test_sd():
with pytest.raises(ValueError):
msd.LogNormal(0., 0.)
with pytest.raises(ValueError):
msd.LogNormal(0., -1.)
def test_arguments():
"""
Test argument passing during initialization.
"""
n = msd.LogNormal()
assert isinstance(n, msd.Distribution)
n = msd.LogNormal([3.0], [4.0], dtype=dtype.float32)
assert isinstance(n, msd.Distribution)
class LogNormalProb(nn.Cell):
"""
LogNormal distribution: initialize with mean/sd.
"""
def __init__(self):
super(LogNormalProb, self).__init__()
self.lognormal = msd.LogNormal(3.0, 4.0, dtype=dtype.float32)
def construct(self, value):
prob = self.lognormal.prob(value)
log_prob = self.lognormal.log_prob(value)
cdf = self.lognormal.cdf(value)
log_cdf = self.lognormal.log_cdf(value)
sf = self.lognormal.survival_function(value)
log_sf = self.lognormal.log_survival(value)
return prob + log_prob + cdf + log_cdf + sf + log_sf
def test_lognormal_prob():
"""
Test probability functions: passing value through construct.
"""
net = LogNormalProb()
value = Tensor([0.5, 1.0], dtype=dtype.float32)
ans = net(value)
assert isinstance(ans, Tensor)
class LogNormalProb1(nn.Cell):
"""
LogNormal distribution: initialize without mean/sd.
"""
def __init__(self):
super(LogNormalProb1, self).__init__()
self.lognormal = msd.LogNormal()
def construct(self, value, mean, sd):
prob = self.lognormal.prob(value, mean, sd)
log_prob = self.lognormal.log_prob(value, mean, sd)
cdf = self.lognormal.cdf(value, mean, sd)
log_cdf = self.lognormal.log_cdf(value, mean, sd)
sf = self.lognormal.survival_function(value, mean, sd)
log_sf = self.lognormal.log_survival(value, mean, sd)
return prob + log_prob + cdf + log_cdf + sf + log_sf
def test_lognormal_prob1():
"""
Test probability functions: passing mean/sd, value through construct.
"""
net = LogNormalProb1()
value = Tensor([0.5, 1.0], dtype=dtype.float32)
mean = Tensor([0.0], dtype=dtype.float32)
sd = Tensor([1.0], dtype=dtype.float32)
ans = net(value, mean, sd)
assert isinstance(ans, Tensor)
class LogNormalKl(nn.Cell):
"""
Test class: kl_loss of LogNormal distribution.
"""
def __init__(self):
super(LogNormalKl, self).__init__()
self.n1 = msd.LogNormal(np.array([3.0]), np.array([4.0]), dtype=dtype.float32)
self.n2 = msd.LogNormal(dtype=dtype.float32)
def construct(self, mean_b, sd_b, mean_a, sd_a):
kl1 = self.n1.kl_loss('LogNormal', mean_b, sd_b)
kl2 = self.n2.kl_loss('LogNormal', mean_b, sd_b, mean_a, sd_a)
return kl1 + kl2
def test_kl():
"""
Test kl_loss.
"""
net = LogNormalKl()
mean_b = Tensor(np.array([1.0]).astype(np.float32), dtype=dtype.float32)
sd_b = Tensor(np.array([1.0]).astype(np.float32), dtype=dtype.float32)
mean_a = Tensor(np.array([2.0]).astype(np.float32), dtype=dtype.float32)
sd_a = Tensor(np.array([3.0]).astype(np.float32), dtype=dtype.float32)
ans = net(mean_b, sd_b, mean_a, sd_a)
assert isinstance(ans, Tensor)
class LogNormalCrossEntropy(nn.Cell):
"""
Test class: cross_entropy of LogNormal distribution.
"""
def __init__(self):
super(LogNormalCrossEntropy, self).__init__()
self.n1 = msd.LogNormal(np.array([3.0]), np.array([4.0]), dtype=dtype.float32)
self.n2 = msd.LogNormal(dtype=dtype.float32)
def construct(self, mean_b, sd_b, mean_a, sd_a):
h1 = self.n1.cross_entropy('LogNormal', mean_b, sd_b)
h2 = self.n2.cross_entropy('LogNormal', mean_b, sd_b, mean_a, sd_a)
return h1 + h2
def test_cross_entropy():
"""
Test cross entropy between LogNormal distributions.
"""
net = LogNormalCrossEntropy()
mean_b = Tensor(np.array([1.0]).astype(np.float32), dtype=dtype.float32)
sd_b = Tensor(np.array([1.0]).astype(np.float32), dtype=dtype.float32)
mean_a = Tensor(np.array([2.0]).astype(np.float32), dtype=dtype.float32)
sd_a = Tensor(np.array([3.0]).astype(np.float32), dtype=dtype.float32)
ans = net(mean_b, sd_b, mean_a, sd_a)
assert isinstance(ans, Tensor)
class LogNormalBasics(nn.Cell):
"""
Test class: basic mean/sd function.
"""
def __init__(self):
super(LogNormalBasics, self).__init__()
self.n = msd.LogNormal(3.0, 4.0, dtype=dtype.float32)
def construct(self):
mean = self.n.mean()
sd = self.n.sd()
mode = self.n.mode()
entropy = self.n.entropy()
return mean + sd + mode + entropy
def test_basics():
"""
Test mean/sd/mode/entropy functionality of LogNormal.
"""
net = LogNormalBasics()
ans = net()
assert isinstance(ans, Tensor)
class LogNormalConstruct(nn.Cell):
"""
LogNormal distribution: going through construct.
"""
def __init__(self):
super(LogNormalConstruct, self).__init__()
self.lognormal = msd.LogNormal(3.0, 4.0)
self.lognormal1 = msd.LogNormal()
def construct(self, value, mean, sd):
prob = self.lognormal('prob', value)
prob1 = self.lognormal('prob', value, mean, sd)
prob2 = self.lognormal1('prob', value, mean, sd)
return prob + prob1 + prob2
def test_lognormal_construct():
"""
Test probability function going through construct.
"""
net = LogNormalConstruct()
value = Tensor([0.5, 1.0], dtype=dtype.float32)
mean = Tensor([0.0], dtype=dtype.float32)
sd = Tensor([1.0], dtype=dtype.float32)
ans = net(value, mean, sd)
assert isinstance(ans, Tensor)
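The tests above only assert the output type. If value-level checks were wanted, the closed forms could be pinned against SciPy; a hedged sketch (not part of this commit; requires SciPy but not MindSpore):

import numpy as np
from scipy.stats import lognorm

def reference_mean_sd(loc, scale):
    """Reference mean/sd of LogNormal(loc, scale) via SciPy's parameterization."""
    d = lognorm(s=scale, scale=np.exp(loc))
    return d.mean(), d.std()

mean, sd = reference_mean_sd(3.0, 4.0)
assert np.isclose(mean, np.exp(3.0 + 0.5 * 4.0 ** 2))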