changed distribution api

pull/3756/head
Xun Deng 5 years ago
parent 6945eb2821
commit e87e1fc6bc

@ -34,55 +34,56 @@ class Bernoulli(Distribution):
Examples: Examples:
>>> # To initialize a Bernoulli distribution of prob 0.5 >>> # To initialize a Bernoulli distribution of prob 0.5
>>> n = nn.Bernoulli(0.5, dtype=mstype.int32) >>> import mindspore.nn.probability.distribution as msd
>>> b = msd.Bernoulli(0.5, dtype=mstype.int32)
>>> >>>
>>> # The following creates two independent Bernoulli distributions >>> # The following creates two independent Bernoulli distributions
>>> n = nn.Bernoulli([0.5, 0.5], dtype=mstype.int32) >>> b = msd.Bernoulli([0.5, 0.5], dtype=mstype.int32)
>>> >>>
>>> # A Bernoulli distribution can be initilized without arguments >>> # A Bernoulli distribution can be initilized without arguments
>>> # In this case, probs must be passed in through construct. >>> # In this case, probs must be passed in through args during function calls.
>>> n = nn.Bernoulli(dtype=mstype.int32) >>> b = msd.Bernoulli(dtype=mstype.int32)
>>> >>>
>>> # To use Bernoulli distribution in a network >>> # To use Bernoulli in a network
>>> class net(Cell): >>> class net(Cell):
>>> def __init__(self): >>> def __init__(self):
>>> super(net, self).__init__(): >>> super(net, self).__init__():
>>> self.b1 = nn.Bernoulli(0.5, dtype=mstype.int32) >>> self.b1 = msd.Bernoulli(0.5, dtype=mstype.int32)
>>> self.b2 = nn.Bernoulli(dtype=mstype.int32) >>> self.b2 = msd.Bernoulli(dtype=mstype.int32)
>>> >>>
>>> # All the following calls in construct are valid >>> # All the following calls in construct are valid
>>> def construct(self, value, probs_b, probs_a): >>> def construct(self, value, probs_b, probs_a):
>>> >>>
>>> # Similar calls can be made to other probability functions >>> # Similar calls can be made to other probability functions
>>> # by replacing 'prob' with the name of the function >>> # by replacing 'prob' with the name of the function
>>> ans = self.b1('prob', value) >>> ans = self.b1.prob(value)
>>> # Evaluate with the respect to distribution b >>> # Evaluate with the respect to distribution b
>>> ans = self.b1('prob', value, probs_b) >>> ans = self.b1.prob(value, probs_b)
>>> >>>
>>> # probs must be passed in through construct >>> # probs must be passed in during function calls
>>> ans = self.b2('prob', value, probs_a) >>> ans = self.b2.prob(value, probs_a)
>>> >>>
>>> # Functions 'sd', 'var', 'entropy' have the same usage like 'mean' >>> # Functions 'sd', 'var', 'entropy' have the same usage as 'mean'
>>> # Will return [0.0] >>> # Will return 0.5
>>> ans = self.b1('mean') >>> ans = self.b1.mean()
>>> # Will return mean_b >>> # Will return probs_b
>>> ans = self.b1('mean', probs_b) >>> ans = self.b1.mean(probs_b)
>>> >>>
>>> # probs must be passed in through construct >>> # probs must be passed in during function calls
>>> ans = self.b2('mean', probs_a) >>> ans = self.b2.mean(probs_a)
>>> >>>
>>> # Usage of 'kl_loss' and 'cross_entropy' are similar >>> # Usage of 'kl_loss' and 'cross_entropy' are similar
>>> ans = self.b1('kl_loss', 'Bernoulli', probs_b) >>> ans = self.b1.kl_loss('Bernoulli', probs_b)
>>> ans = self.b1('kl_loss', 'Bernoulli', probs_b, probs_a) >>> ans = self.b1.kl_loss('Bernoulli', probs_b, probs_a)
>>> >>>
>>> # Additional probs_a must be passed in through construct >>> # Additional probs_a must be passed in through
>>> ans = self.b2('kl_loss', 'Bernoulli', probs_b, probs_a) >>> ans = self.b2.kl_loss('Bernoulli', probs_b, probs_a)
>>> >>>
>>> # Sample Usage >>> # Sample
>>> ans = self.b1('sample') >>> ans = self.b1.sample()
>>> ans = self.b1('sample', (2,3)) >>> ans = self.b1.sample((2,3))
>>> ans = self.b1('sample', (2,3), probs_b) >>> ans = self.b1.sample((2,3), probs_b)
>>> ans = self.b2('sample', (2,3), probs_a) >>> ans = self.b2.sample((2,3), probs_a)
""" """
def __init__(self, def __init__(self,
@ -130,71 +131,61 @@ class Bernoulli(Distribution):
""" """
return self._probs return self._probs
def _mean(self, name='mean', probs1=None): def _mean(self, probs1=None):
r""" r"""
.. math:: .. math::
MEAN(B) = probs1 MEAN(B) = probs1
""" """
if name == 'mean':
return self.probs if probs1 is None else probs1 return self.probs if probs1 is None else probs1
return None
def _mode(self, name='mode', probs1=None): def _mode(self, probs1=None):
r""" r"""
.. math:: .. math::
MODE(B) = 1 if probs1 > 0.5 else = 0 MODE(B) = 1 if probs1 > 0.5 else = 0
""" """
if name == 'mode':
probs1 = self.probs if probs1 is None else probs1 probs1 = self.probs if probs1 is None else probs1
prob_type = self.dtypeop(probs1) prob_type = self.dtypeop(probs1)
zeros = self.fill(prob_type, self.shape(probs1), 0.0) zeros = self.fill(prob_type, self.shape(probs1), 0.0)
ones = self.fill(prob_type, self.shape(probs1), 1.0) ones = self.fill(prob_type, self.shape(probs1), 1.0)
comp = self.less(0.5, probs1) comp = self.less(0.5, probs1)
return self.select(comp, ones, zeros) return self.select(comp, ones, zeros)
return None
def _var(self, name='var', probs1=None): def _var(self, probs1=None):
r""" r"""
.. math:: .. math::
VAR(B) = probs1 * probs0 VAR(B) = probs1 * probs0
""" """
if name in self._variance_functions:
probs1 = self.probs if probs1 is None else probs1 probs1 = self.probs if probs1 is None else probs1
probs0 = 1.0 - probs1 probs0 = 1.0 - probs1
return probs0 * probs1 return probs0 * probs1
return None
def _entropy(self, name='entropy', probs=None): def _entropy(self, probs=None):
r""" r"""
.. math:: .. math::
H(B) = -probs0 * \log(probs0) - probs1 * \log(probs1) H(B) = -probs0 * \log(probs0) - probs1 * \log(probs1)
""" """
if name == 'entropy':
probs1 = self.probs if probs is None else probs probs1 = self.probs if probs is None else probs
probs0 = 1 - probs1 probs0 = 1 - probs1
return -1 * (probs0 * self.log(probs0)) - (probs1 * self.log(probs1)) return -1 * (probs0 * self.log(probs0)) - (probs1 * self.log(probs1))
return None
def _cross_entropy(self, name, dist, probs1_b, probs1_a=None): def _cross_entropy(self, dist, probs1_b, probs1_a=None):
""" """
Evaluate cross_entropy between Bernoulli distributions. Evaluate cross_entropy between Bernoulli distributions.
Args: Args:
name (str): name of the funtion.
dist (str): type of the distributions. Should be "Bernoulli" in this case. dist (str): type of the distributions. Should be "Bernoulli" in this case.
probs1_b (Tensor): probs1 of distribution b. probs1_b (Tensor): probs1 of distribution b.
probs1_a (Tensor): probs1 of distribution a. Default: self.probs. probs1_a (Tensor): probs1 of distribution a. Default: self.probs.
""" """
if name == 'cross_entropy' and dist == 'Bernoulli': if dist == 'Bernoulli':
return self._entropy(probs=probs1_a) + self._kl_loss(name, dist, probs1_b, probs1_a) return self._entropy(probs=probs1_a) + self._kl_loss(dist, probs1_b, probs1_a)
return None return None
def _prob(self, name, value, probs=None): def _prob(self, value, probs=None):
r""" r"""
pmf of Bernoulli distribution. pmf of Bernoulli distribution.
Args: Args:
name (str): name of the function. Should be "prob" when passed in from construct.
value (Tensor): a Tensor composed of only zeros and ones. value (Tensor): a Tensor composed of only zeros and ones.
probs (Tensor): probability of outcome is 1. Default: self.probs. probs (Tensor): probability of outcome is 1. Default: self.probs.
@ -202,18 +193,15 @@ class Bernoulli(Distribution):
pmf(k) = probs1 if k = 1; pmf(k) = probs1 if k = 1;
pmf(k) = probs0 if k = 0; pmf(k) = probs0 if k = 0;
""" """
if name in self._prob_functions:
probs1 = self.probs if probs is None else probs probs1 = self.probs if probs is None else probs
probs0 = 1.0 - probs1 probs0 = 1.0 - probs1
return (probs1 * value) + (probs0 * (1.0 - value)) return (probs1 * value) + (probs0 * (1.0 - value))
return None
def _cdf(self, name, value, probs=None): def _cdf(self, value, probs=None):
r""" r"""
cdf of Bernoulli distribution. cdf of Bernoulli distribution.
Args: Args:
name (str): name of the function.
value (Tensor): value to be evaluated. value (Tensor): value to be evaluated.
probs (Tensor): probability of outcome is 1. Default: self.probs. probs (Tensor): probability of outcome is 1. Default: self.probs.
@ -222,7 +210,6 @@ class Bernoulli(Distribution):
cdf(k) = probs0 if 0 <= k <1; cdf(k) = probs0 if 0 <= k <1;
cdf(k) = 1 if k >=1; cdf(k) = 1 if k >=1;
""" """
if name in self._cdf_survival_functions:
probs1 = self.probs if probs is None else probs probs1 = self.probs if probs is None else probs
prob_type = self.dtypeop(probs1) prob_type = self.dtypeop(probs1)
value = value * self.fill(prob_type, self.shape(probs1), 1.0) value = value * self.fill(prob_type, self.shape(probs1), 1.0)
@ -233,14 +220,12 @@ class Bernoulli(Distribution):
ones = self.fill(prob_type, self.shape(value), 1.0) ones = self.fill(prob_type, self.shape(value), 1.0)
less_than_zero = self.select(comp_zero, zeros, probs0) less_than_zero = self.select(comp_zero, zeros, probs0)
return self.select(comp_one, less_than_zero, ones) return self.select(comp_one, less_than_zero, ones)
return None
def _kl_loss(self, name, dist, probs1_b, probs1_a=None): def _kl_loss(self, dist, probs1_b, probs1_a=None):
r""" r"""
Evaluate bernoulli-bernoulli kl divergence, i.e. KL(a||b). Evaluate bernoulli-bernoulli kl divergence, i.e. KL(a||b).
Args: Args:
name (str): name of the funtion.
dist (str): type of the distributions. Should be "Bernoulli" in this case. dist (str): type of the distributions. Should be "Bernoulli" in this case.
probs1_b (Tensor): probs1 of distribution b. probs1_b (Tensor): probs1 of distribution b.
probs1_a (Tensor): probs1 of distribution a. Default: self.probs. probs1_a (Tensor): probs1 of distribution a. Default: self.probs.
@ -249,26 +234,24 @@ class Bernoulli(Distribution):
KL(a||b) = probs1_a * \log(\fract{probs1_a}{probs1_b}) + KL(a||b) = probs1_a * \log(\fract{probs1_a}{probs1_b}) +
probs0_a * \log(\fract{probs0_a}{probs0_b}) probs0_a * \log(\fract{probs0_a}{probs0_b})
""" """
if name in self._divergence_functions and dist == 'Bernoulli': if dist == 'Bernoulli':
probs1_a = self.probs if probs1_a is None else probs1_a probs1_a = self.probs if probs1_a is None else probs1_a
probs0_a = 1.0 - probs1_a probs0_a = 1.0 - probs1_a
probs0_b = 1.0 - probs1_b probs0_b = 1.0 - probs1_b
return probs1_a * self.log(probs1_a / probs1_b) + probs0_a * self.log(probs0_a / probs0_b) return probs1_a * self.log(probs1_a / probs1_b) + probs0_a * self.log(probs0_a / probs0_b)
return None return None
def _sample(self, name, shape=(), probs=None): def _sample(self, shape=(), probs=None):
""" """
Sampling. Sampling.
Args: Args:
name (str): name of the function. Should always be 'sample' when passed in from construct.
shape (tuple): shape of the sample. Default: (). shape (tuple): shape of the sample. Default: ().
probs (Tensor): probs1 of the samples. Default: self.probs. probs (Tensor): probs1 of the samples. Default: self.probs.
Returns: Returns:
Tensor, shape is shape + batch_shape. Tensor, shape is shape + batch_shape.
""" """
if name == 'sample':
probs1 = self.probs if probs is None else probs probs1 = self.probs if probs is None else probs
l_zero = self.const(0.0) l_zero = self.const(0.0)
h_one = self.const(1.0) h_one = self.const(1.0)
@ -276,4 +259,3 @@ class Bernoulli(Distribution):
sample = self.less(sample_uniform, probs1) sample = self.less(sample_uniform, probs1)
sample = self.cast(sample, self.dtype) sample = self.cast(sample, self.dtype)
return sample return sample
return None

@ -27,11 +27,7 @@ class Distribution(Cell):
Note: Note:
Derived class should override operations such as ,_mean, _prob, Derived class should override operations such as ,_mean, _prob,
and _log_prob. Functions should be called through construct when and _log_prob. Arguments should be passed in through *args.
used inside a network. Arguments should be passed in through *args
in the form of function name followed by additional arguments.
Functions such as cdf and prob, require a value to be passed in while
functions such as mean and sd do not require arguments other than name.
Dist_spec_args are unique for each type of distribution. For example, mean and sd Dist_spec_args are unique for each type of distribution. For example, mean and sd
are the dist_spec_args for a Normal distribution. are the dist_spec_args for a Normal distribution.
@ -73,11 +69,6 @@ class Distribution(Cell):
self._set_log_survival() self._set_log_survival()
self._set_cross_entropy() self._set_cross_entropy()
self._prob_functions = ('prob', 'log_prob')
self._cdf_survival_functions = ('cdf', 'log_cdf', 'survival_function', 'log_survival')
self._variance_functions = ('var', 'sd')
self._divergence_functions = ('kl_loss', 'cross_entropy')
@property @property
def name(self): def name(self):
return self._name return self._name
@ -185,7 +176,7 @@ class Distribution(Cell):
Evaluate the log probability(pdf or pmf) at the given value. Evaluate the log probability(pdf or pmf) at the given value.
Note: Note:
Args must include name of the function and value. Args must include value.
Dist_spec_args are optional. Dist_spec_args are optional.
""" """
return self._call_log_prob(*args) return self._call_log_prob(*args)
@ -204,7 +195,7 @@ class Distribution(Cell):
Evaluate the probability (pdf or pmf) at given value. Evaluate the probability (pdf or pmf) at given value.
Note: Note:
Args must include name of the function and value. Args must include value.
Dist_spec_args are optional. Dist_spec_args are optional.
""" """
return self._call_prob(*args) return self._call_prob(*args)
@ -223,7 +214,7 @@ class Distribution(Cell):
Evaluate the cdf at given value. Evaluate the cdf at given value.
Note: Note:
Args must include name of the function and value. Args must include value.
Dist_spec_args are optional. Dist_spec_args are optional.
""" """
return self._call_cdf(*args) return self._call_cdf(*args)
@ -260,7 +251,7 @@ class Distribution(Cell):
Evaluate the log cdf at given value. Evaluate the log cdf at given value.
Note: Note:
Args must include name of the function and value. Args must include value.
Dist_spec_args are optional. Dist_spec_args are optional.
""" """
return self._call_log_cdf(*args) return self._call_log_cdf(*args)
@ -279,7 +270,7 @@ class Distribution(Cell):
Evaluate the survival function at given value. Evaluate the survival function at given value.
Note: Note:
Args must include name of the function and value. Args must include value.
Dist_spec_args are optional. Dist_spec_args are optional.
""" """
return self._call_survival(*args) return self._call_survival(*args)
@ -307,7 +298,7 @@ class Distribution(Cell):
Evaluate the log survival function at given value. Evaluate the log survival function at given value.
Note: Note:
Args must include name of the function and value. Args must include value.
Dist_spec_args are optional. Dist_spec_args are optional.
""" """
return self._call_log_survival(*args) return self._call_log_survival(*args)
@ -326,7 +317,7 @@ class Distribution(Cell):
Evaluate the KL divergence, i.e. KL(a||b). Evaluate the KL divergence, i.e. KL(a||b).
Note: Note:
Args must include name of the function, type of the distribution, parameters of distribution b. Args must include type of the distribution, parameters of distribution b.
Parameters for distribution a are optional. Parameters for distribution a are optional.
""" """
return self._kl_loss(*args) return self._kl_loss(*args)
@ -336,7 +327,7 @@ class Distribution(Cell):
Evaluate the mean. Evaluate the mean.
Note: Note:
Args must include the name of function. Dist_spec_args are optional. Dist_spec_args are optional.
""" """
return self._mean(*args) return self._mean(*args)
@ -345,7 +336,7 @@ class Distribution(Cell):
Evaluate the mode. Evaluate the mode.
Note: Note:
Args must include the name of function. Dist_spec_args are optional. Dist_spec_args are optional.
""" """
return self._mode(*args) return self._mode(*args)
@ -354,7 +345,7 @@ class Distribution(Cell):
Evaluate the standard deviation. Evaluate the standard deviation.
Note: Note:
Args must include the name of function. Dist_spec_args are optional. Dist_spec_args are optional.
""" """
return self._call_sd(*args) return self._call_sd(*args)
@ -363,7 +354,7 @@ class Distribution(Cell):
Evaluate the variance. Evaluate the variance.
Note: Note:
Args must include the name of function. Dist_spec_args are optional. Dist_spec_args are optional.
""" """
return self._call_var(*args) return self._call_var(*args)
@ -390,7 +381,7 @@ class Distribution(Cell):
Evaluate the entropy. Evaluate the entropy.
Note: Note:
Args must include the name of function. Dist_spec_args are optional. Dist_spec_args are optional.
""" """
return self._entropy(*args) return self._entropy(*args)
@ -399,7 +390,7 @@ class Distribution(Cell):
Evaluate the cross_entropy between distribution a and b. Evaluate the cross_entropy between distribution a and b.
Note: Note:
Args must include name of the function, type of the distribution, parameters of distribution b. Args must include type of the distribution, parameters of distribution b.
Parameters for distribution a are optional. Parameters for distribution a are optional.
""" """
return self._call_cross_entropy(*args) return self._call_cross_entropy(*args)
@ -421,13 +412,13 @@ class Distribution(Cell):
*args (list): arguments passed in through construct. *args (list): arguments passed in through construct.
Note: Note:
Args must include name of the function. Shape of the sample is default to ().
Shape of the sample and dist_spec_args are optional. Dist_spec_args are optional.
""" """
return self._sample(*args) return self._sample(*args)
def construct(self, *inputs): def construct(self, name, *args):
""" """
Override construct in Cell. Override construct in Cell.
@ -437,35 +428,36 @@ class Distribution(Cell):
'var', 'sd', 'entropy', 'kl_loss', 'cross_entropy', 'sample'. 'var', 'sd', 'entropy', 'kl_loss', 'cross_entropy', 'sample'.
Args: Args:
*inputs (list): inputs[0] is always the name of the function. name (str): name of the function.
""" *args (list): list of arguments needed for the function.
"""
if inputs[0] == 'log_prob':
return self._call_log_prob(*inputs) if name == 'log_prob':
if inputs[0] == 'prob': return self._call_log_prob(*args)
return self._call_prob(*inputs) if name == 'prob':
if inputs[0] == 'cdf': return self._call_prob(*args)
return self._call_cdf(*inputs) if name == 'cdf':
if inputs[0] == 'log_cdf': return self._call_cdf(*args)
return self._call_log_cdf(*inputs) if name == 'log_cdf':
if inputs[0] == 'survival_function': return self._call_log_cdf(*args)
return self._call_survival(*inputs) if name == 'survival_function':
if inputs[0] == 'log_survival': return self._call_survival(*args)
return self._call_log_survival(*inputs) if name == 'log_survival':
if inputs[0] == 'kl_loss': return self._call_log_survival(*args)
return self._kl_loss(*inputs) if name == 'kl_loss':
if inputs[0] == 'mean': return self._kl_loss(*args)
return self._mean(*inputs) if name == 'mean':
if inputs[0] == 'mode': return self._mean(*args)
return self._mode(*inputs) if name == 'mode':
if inputs[0] == 'sd': return self._mode(*args)
return self._call_sd(*inputs) if name == 'sd':
if inputs[0] == 'var': return self._call_sd(*args)
return self._call_var(*inputs) if name == 'var':
if inputs[0] == 'entropy': return self._call_var(*args)
return self._entropy(*inputs) if name == 'entropy':
if inputs[0] == 'cross_entropy': return self._entropy(*args)
return self._call_cross_entropy(*inputs) if name == 'cross_entropy':
if inputs[0] == 'sample': return self._call_cross_entropy(*args)
return self._sample(*inputs) if name == 'sample':
return self._sample(*args)
return None return None

@ -35,55 +35,56 @@ class Exponential(Distribution):
Examples: Examples:
>>> # To initialize an Exponential distribution of rate 0.5 >>> # To initialize an Exponential distribution of rate 0.5
>>> n = nn.Exponential(0.5, dtype=mstype.float32) >>> import mindspore.nn.probability.distribution as msd
>>> e = msd.Exponential(0.5, dtype=mstype.float32)
>>> >>>
>>> # The following creates two independent Exponential distributions >>> # The following creates two independent Exponential distributions
>>> n = nn.Exponential([0.5, 0.5], dtype=mstype.float32) >>> e = msd.Exponential([0.5, 0.5], dtype=mstype.float32)
>>> >>>
>>> # A Exponential distribution can be initilized without arguments >>> # An Exponential distribution can be initilized without arguments
>>> # In this case, rate must be passed in through construct. >>> # In this case, rate must be passed in through args during function calls
>>> n = nn.Exponential(dtype=mstype.float32) >>> e = msd.Exponential(dtype=mstype.float32)
>>> >>>
>>> # To use Exponential distribution in a network >>> # To use Exponential in a network
>>> class net(Cell): >>> class net(Cell):
>>> def __init__(self): >>> def __init__(self):
>>> super(net, self).__init__(): >>> super(net, self).__init__():
>>> self.e1 = nn.Exponential(0.5, dtype=mstype.float32) >>> self.e1 = msd.Exponential(0.5, dtype=mstype.float32)
>>> self.e2 = nn.Exponential(dtype=mstype.float32) >>> self.e2 = msd.Exponential(dtype=mstype.float32)
>>> >>>
>>> # All the following calls in construct are valid >>> # All the following calls in construct are valid
>>> def construct(self, value, rate_b, rate_a): >>> def construct(self, value, rate_b, rate_a):
>>> >>>
>>> # Similar calls can be made to other probability functions >>> # Similar calls can be made to other probability functions
>>> # by replacing 'prob' with the name of the function >>> # by replacing 'prob' with the name of the function
>>> ans = self.e1('prob', value) >>> ans = self.e1.prob(value)
>>> # Evaluate with the respect to distribution b >>> # Evaluate with the respect to distribution b
>>> ans = self.e1('prob', value, rate_b) >>> ans = self.e1.prob(value, rate_b)
>>> >>>
>>> # Rate must be passed in through construct >>> # Rate must be passed in during function calls
>>> ans = self.e2('prob', value, rate_a) >>> ans = self.e2.prob(value, rate_a)
>>> >>>
>>> # Functions 'sd', 'var', 'entropy' have the same usage with 'mean' >>> # Functions 'sd', 'var', 'entropy' have the same usage as'mean'
>>> # Will return [0.0] >>> # Will return 2
>>> ans = self.e1('mean') >>> ans = self.e1.mean()
>>> # Will return mean_b >>> # Will return 1 / rate_b
>>> ans = self.e1('mean', rate_b) >>> ans = self.e1.mean(rate_b)
>>> >>>
>>> # Rate must be passed in through construct >>> # Rate must be passed in during function calls
>>> ans = self.e2('mean', rate_a) >>> ans = self.e2.mean(rate_a)
>>> >>>
>>> # Usage of 'kl_loss' and 'cross_entropy' are similar >>> # Usage of 'kl_loss' and 'cross_entropy' are similar
>>> ans = self.e1('kl_loss', 'Exponential', rate_b) >>> ans = self.e1.kl_loss('Exponential', rate_b)
>>> ans = self.e1('kl_loss', 'Exponential', rate_b, rate_a) >>> ans = self.e1.kl_loss('Exponential', rate_b, rate_a)
>>> >>>
>>> # Additional rate must be passed in through construct >>> # Additional rate must be passed in
>>> ans = self.e2('kl_loss', 'Exponential', rate_b, rate_a) >>> ans = self.e2.kl_loss('Exponential', rate_b, rate_a)
>>> >>>
>>> # Sample Usage >>> # Sample
>>> ans = self.e1('sample') >>> ans = self.e1.sample()
>>> ans = self.e1('sample', (2,3)) >>> ans = self.e1.sample((2,3))
>>> ans = self.e1('sample', (2,3), rate_b) >>> ans = self.e1.sample((2,3), rate_b)
>>> ans = self.e2('sample', (2,3), rate_a) >>> ans = self.e2.sample((2,3), rate_a)
""" """
def __init__(self, def __init__(self,
@ -131,67 +132,59 @@ class Exponential(Distribution):
""" """
return self._rate return self._rate
def _mean(self, name='mean', rate=None): def _mean(self, rate=None):
r""" r"""
.. math:: .. math::
MEAN(EXP) = \fract{1.0}{\lambda}. MEAN(EXP) = \fract{1.0}{\lambda}.
""" """
if name == 'mean':
rate = self.rate if rate is None else rate rate = self.rate if rate is None else rate
return 1.0 / rate return 1.0 / rate
return None
def _mode(self, name='mode', rate=None):
def _mode(self, rate=None):
r""" r"""
.. math:: .. math::
MODE(EXP) = 0. MODE(EXP) = 0.
""" """
if name == 'mode':
rate = self.rate if rate is None else rate rate = self.rate if rate is None else rate
return self.fill(self.dtype, self.shape(rate), 0.) return self.fill(self.dtype, self.shape(rate), 0.)
return None
def _sd(self, name='sd', rate=None): def _sd(self, rate=None):
r""" r"""
.. math:: .. math::
sd(EXP) = \fract{1.0}{\lambda}. sd(EXP) = \fract{1.0}{\lambda}.
""" """
if name in self._variance_functions:
rate = self.rate if rate is None else rate rate = self.rate if rate is None else rate
return 1.0 / rate return 1.0 / rate
return None
def _entropy(self, name='entropy', rate=None): def _entropy(self, rate=None):
r""" r"""
.. math:: .. math::
H(Exp) = 1 - \log(\lambda). H(Exp) = 1 - \log(\lambda).
""" """
rate = self.rate if rate is None else rate rate = self.rate if rate is None else rate
if name == 'entropy':
return 1.0 - self.log(rate) return 1.0 - self.log(rate)
return None
def _cross_entropy(self, name, dist, rate_b, rate_a=None):
def _cross_entropy(self, dist, rate_b, rate_a=None):
""" """
Evaluate cross_entropy between Exponential distributions. Evaluate cross_entropy between Exponential distributions.
Args: Args:
name (str): name of the funtion. Should always be "cross_entropy" when passed in from construct.
dist (str): type of the distributions. Should be "Exponential" in this case. dist (str): type of the distributions. Should be "Exponential" in this case.
rate_b (Tensor): rate of distribution b. rate_b (Tensor): rate of distribution b.
rate_a (Tensor): rate of distribution a. Default: self.rate. rate_a (Tensor): rate of distribution a. Default: self.rate.
""" """
if name == 'cross_entropy' and dist == 'Exponential': if dist == 'Exponential':
return self._entropy(rate=rate_a) + self._kl_loss(name, dist, rate_b, rate_a) return self._entropy(rate=rate_a) + self._kl_loss(dist, rate_b, rate_a)
return None return None
def _prob(self, name, value, rate=None): def _prob(self, value, rate=None):
r""" r"""
pdf of Exponential distribution. pdf of Exponential distribution.
Args: Args:
Args: Args:
name (str): name of the function.
value (Tensor): value to be evaluated. value (Tensor): value to be evaluated.
rate (Tensor): rate of the distribution. Default: self.rate. rate (Tensor): rate of the distribution. Default: self.rate.
@ -201,20 +194,17 @@ class Exponential(Distribution):
.. math:: .. math::
pdf(x) = rate * \exp(-1 * \lambda * x) if x >= 0 else 0 pdf(x) = rate * \exp(-1 * \lambda * x) if x >= 0 else 0
""" """
if name in self._prob_functions:
rate = self.rate if rate is None else rate rate = self.rate if rate is None else rate
prob = rate * self.exp(-1. * rate * value) prob = rate * self.exp(-1. * rate * value)
zeros = self.fill(self.dtypeop(prob), self.shape(prob), 0.0) zeros = self.fill(self.dtypeop(prob), self.shape(prob), 0.0)
comp = self.less(value, zeros) comp = self.less(value, zeros)
return self.select(comp, zeros, prob) return self.select(comp, zeros, prob)
return None
def _cdf(self, name, value, rate=None): def _cdf(self, value, rate=None):
r""" r"""
cdf of Exponential distribution. cdf of Exponential distribution.
Args: Args:
name (str): name of the function.
value (Tensor): value to be evaluated. value (Tensor): value to be evaluated.
rate (Tensor): rate of the distribution. Default: self.rate. rate (Tensor): rate of the distribution. Default: self.rate.
@ -224,45 +214,40 @@ class Exponential(Distribution):
.. math:: .. math::
cdf(x) = 1.0 - \exp(-1 * \lambda * x) if x >= 0 else 0 cdf(x) = 1.0 - \exp(-1 * \lambda * x) if x >= 0 else 0
""" """
if name in self._cdf_survival_functions:
rate = self.rate if rate is None else rate rate = self.rate if rate is None else rate
cdf = 1.0 - self.exp(-1. * rate * value) cdf = 1.0 - self.exp(-1. * rate * value)
zeros = self.fill(self.dtypeop(cdf), self.shape(cdf), 0.0) zeros = self.fill(self.dtypeop(cdf), self.shape(cdf), 0.0)
comp = self.less(value, zeros) comp = self.less(value, zeros)
return self.select(comp, zeros, cdf) return self.select(comp, zeros, cdf)
return None
def _kl_loss(self, name, dist, rate_b, rate_a=None):
def _kl_loss(self, dist, rate_b, rate_a=None):
""" """
Evaluate exp-exp kl divergence, i.e. KL(a||b). Evaluate exp-exp kl divergence, i.e. KL(a||b).
Args: Args:
name (str): name of the funtion.
dist (str): type of the distributions. Should be "Exponential" in this case. dist (str): type of the distributions. Should be "Exponential" in this case.
rate_b (Tensor): rate of distribution b. rate_b (Tensor): rate of distribution b.
rate_a (Tensor): rate of distribution a. Default: self.rate. rate_a (Tensor): rate of distribution a. Default: self.rate.
""" """
if name in self._divergence_functions and dist == 'Exponential': if dist == 'Exponential':
rate_a = self.rate if rate_a is None else rate_a rate_a = self.rate if rate_a is None else rate_a
return self.log(rate_a) - self.log(rate_b) + rate_b / rate_a - 1.0 return self.log(rate_a) - self.log(rate_b) + rate_b / rate_a - 1.0
return None return None
def _sample(self, name, shape=(), rate=None): def _sample(self, shape=(), rate=None):
""" """
Sampling. Sampling.
Args: Args:
name (str): name of the function.
shape (tuple): shape of the sample. Default: (). shape (tuple): shape of the sample. Default: ().
rate (Tensor): rate of the distribution. Default: self.rate. rate (Tensor): rate of the distribution. Default: self.rate.
Returns: Returns:
Tensor, shape is shape + batch_shape. Tensor, shape is shape + batch_shape.
""" """
if name == 'sample':
rate = self.rate if rate is None else rate rate = self.rate if rate is None else rate
minval = self.const(self.minval) minval = self.const(self.minval)
maxval = self.const(1.0) maxval = self.const(1.0)
sample = self.uniform(shape + self.shape(rate), minval, maxval) sample = self.uniform(shape + self.shape(rate), minval, maxval)
return -self.log(sample) / rate return -self.log(sample) / rate
return None

@ -36,55 +36,56 @@ class Geometric(Distribution):
Examples: Examples:
>>> # To initialize a Geometric distribution of prob 0.5 >>> # To initialize a Geometric distribution of prob 0.5
>>> n = nn.Geometric(0.5, dtype=mstype.int32) >>> import mindspore.nn.probability.distribution as msd
>>> n = msd.Geometric(0.5, dtype=mstype.int32)
>>> >>>
>>> # The following creates two independent Geometric distributions >>> # The following creates two independent Geometric distributions
>>> n = nn.Geometric([0.5, 0.5], dtype=mstype.int32) >>> n = msd.Geometric([0.5, 0.5], dtype=mstype.int32)
>>> >>>
>>> # A Geometric distribution can be initilized without arguments >>> # A Geometric distribution can be initilized without arguments
>>> # In this case, probs must be passed in through construct. >>> # In this case, probs must be passed in through args during function calls.
>>> n = nn.Geometric(dtype=mstype.int32) >>> n = msd.Geometric(dtype=mstype.int32)
>>> >>>
>>> # To use Geometric distribution in a network >>> # To use Geometric in a network
>>> class net(Cell): >>> class net(Cell):
>>> def __init__(self): >>> def __init__(self):
>>> super(net, self).__init__(): >>> super(net, self).__init__():
>>> self.g1 = nn.Geometric(0.5, dtype=mstype.int32) >>> self.g1 = msd.Geometric(0.5, dtype=mstype.int32)
>>> self.g2 = nn.Geometric(dtype=mstype.int32) >>> self.g2 = msd.Geometric(dtype=mstype.int32)
>>> >>>
>>> # Tthe following calls are valid in construct >>> # Tthe following calls are valid in construct
>>> def construct(self, value, probs_b, probs_a): >>> def construct(self, value, probs_b, probs_a):
>>> >>>
>>> # Similar calls can be made to other probability functions >>> # Similar calls can be made to other probability functions
>>> # by replacing 'prob' with the name of the function >>> # by replacing 'prob' with the name of the function
>>> ans = self.g1('prob', value) >>> ans = self.g1.prob(value)
>>> # Evaluate with the respect to distribution b >>> # Evaluate with the respect to distribution b
>>> ans = self.g1('prob', value, probs_b) >>> ans = self.g1.prob(value, probs_b)
>>> >>>
>>> # Probs must be passed in through construct >>> # Probs must be passed in during function calls
>>> ans = self.g2('prob', value, probs_a) >>> ans = self.g2.prob(value, probs_a)
>>> >>>
>>> # Functions 'sd', 'var', 'entropy' have the same usage with 'mean' >>> # Functions 'sd', 'var', 'entropy' have the same usage as 'mean'
>>> # Will return [0.0] >>> # Will return 1.0
>>> ans = self.g1('mean') >>> ans = self.g1.mean()
>>> # Will return mean_b >>> # Another possible usage
>>> ans = self.g1('mean', probs_b) >>> ans = self.g1.mean(probs_b)
>>> >>>
>>> # Probs must be passed in through construct >>> # Probs must be passed in during function calls
>>> ans = self.g2('mean', probs_a) >>> ans = self.g2.mean(probs_a)
>>> >>>
>>> # Usage of 'kl_loss' and 'cross_entropy' are similar >>> # Usage of 'kl_loss' and 'cross_entropy' are similar
>>> ans = self.g1('kl_loss', 'Geometric', probs_b) >>> ans = self.g1.kl_loss('Geometric', probs_b)
>>> ans = self.g1('kl_loss', 'Geometric', probs_b, probs_a) >>> ans = self.g1.kl_loss('Geometric', probs_b, probs_a)
>>> >>>
>>> # Additional probs must be passed in through construct >>> # Additional probs must be passed in
>>> ans = self.g2('kl_loss', 'Geometric', probs_b, probs_a) >>> ans = self.g2.kl_loss('Geometric', probs_b, probs_a)
>>> >>>
>>> # Sample Usage >>> # Sample
>>> ans = self.g1('sample') >>> ans = self.g1.sample()
>>> ans = self.g1('sample', (2,3)) >>> ans = self.g1.sample((2,3))
>>> ans = self.g1('sample', (2,3), probs_b) >>> ans = self.g1.sample((2,3), probs_b)
>>> ans = self.g2('sample', (2,3), probs_a) >>> ans = self.g2.sample((2,3), probs_a)
""" """
def __init__(self, def __init__(self,
@ -134,67 +135,57 @@ class Geometric(Distribution):
""" """
return self._probs return self._probs
def _mean(self, name='mean', probs1=None): def _mean(self, probs1=None):
r""" r"""
.. math:: .. math::
MEAN(Geo) = \fratc{1 - probs1}{probs1} MEAN(Geo) = \fratc{1 - probs1}{probs1}
""" """
if name == 'mean':
probs1 = self.probs if probs1 is None else probs1 probs1 = self.probs if probs1 is None else probs1
return (1. - probs1) / probs1 return (1. - probs1) / probs1
return None
def _mode(self, name='mode', probs1=None): def _mode(self, probs1=None):
r""" r"""
.. math:: .. math::
MODE(Geo) = 0 MODE(Geo) = 0
""" """
if name == 'mode':
probs1 = self.probs if probs1 is None else probs1 probs1 = self.probs if probs1 is None else probs1
return self.fill(self.dtypeop(probs1), self.shape(probs1), 0.) return self.fill(self.dtypeop(probs1), self.shape(probs1), 0.)
return None
def _var(self, name='var', probs1=None): def _var(self, probs1=None):
r""" r"""
.. math:: .. math::
VAR(Geo) = \fract{1 - probs1}{probs1 ^ {2}} VAR(Geo) = \fract{1 - probs1}{probs1 ^ {2}}
""" """
if name in self._variance_functions:
probs1 = self.probs if probs1 is None else probs1 probs1 = self.probs if probs1 is None else probs1
return (1.0 - probs1) / self.sq(probs1) return (1.0 - probs1) / self.sq(probs1)
return None
def _entropy(self, name='entropy', probs=None): def _entropy(self, probs=None):
r""" r"""
.. math:: .. math::
H(Geo) = \fract{-1 * probs0 \log_2 (1-probs0)\ - prob1 * \log_2 (1-probs1)\ }{probs1} H(Geo) = \fract{-1 * probs0 \log_2 (1-probs0)\ - prob1 * \log_2 (1-probs1)\ }{probs1}
""" """
if name == 'entropy':
probs1 = self.probs if probs is None else probs probs1 = self.probs if probs is None else probs
probs0 = 1.0 - probs1 probs0 = 1.0 - probs1
return (-probs0 * self.log(probs0) - probs1 * self.log(probs1)) / probs1 return (-probs0 * self.log(probs0) - probs1 * self.log(probs1)) / probs1
return None
def _cross_entropy(self, name, dist, probs1_b, probs1_a=None): def _cross_entropy(self, dist, probs1_b, probs1_a=None):
r""" r"""
Evaluate cross_entropy between Geometric distributions. Evaluate cross_entropy between Geometric distributions.
Args: Args:
name (str): name of the funtion. Should always be "cross_entropy" when passed in from construct.
dist (str): type of the distributions. Should be "Geometric" in this case. dist (str): type of the distributions. Should be "Geometric" in this case.
probs1_b (Tensor): probability of success of distribution b. probs1_b (Tensor): probability of success of distribution b.
probs1_a (Tensor): probability of success of distribution a. Default: self.probs. probs1_a (Tensor): probability of success of distribution a. Default: self.probs.
""" """
if name == 'cross_entropy' and dist == 'Geometric': if dist == 'Geometric':
return self._entropy(probs=probs1_a) + self._kl_loss(name, dist, probs1_b, probs1_a) return self._entropy(probs=probs1_a) + self._kl_loss(dist, probs1_b, probs1_a)
return None return None
def _prob(self, name, value, probs=None): def _prob(self, value, probs=None):
r""" r"""
pmf of Geometric distribution. pmf of Geometric distribution.
Args: Args:
name (str): name of the function. Should be "prob" when passed in from construct.
value (Tensor): a Tensor composed of only natural numbers. value (Tensor): a Tensor composed of only natural numbers.
probs (Tensor): probability of success. Default: self.probs. probs (Tensor): probability of success. Default: self.probs.
@ -202,7 +193,6 @@ class Geometric(Distribution):
pmf(k) = probs0 ^k * probs1 if k >= 0; pmf(k) = probs0 ^k * probs1 if k >= 0;
pmf(k) = 0 if k < 0. pmf(k) = 0 if k < 0.
""" """
if name in self._prob_functions:
probs1 = self.probs if probs is None else probs probs1 = self.probs if probs is None else probs
dtype = self.dtypeop(value) dtype = self.dtypeop(value)
if self.issubclass(dtype, mstype.int_): if self.issubclass(dtype, mstype.int_):
@ -215,14 +205,12 @@ class Geometric(Distribution):
zeros = self.fill(self.dtypeop(probs1), self.shape(pmf), 0.0) zeros = self.fill(self.dtypeop(probs1), self.shape(pmf), 0.0)
comp = self.less(value, zeros) comp = self.less(value, zeros)
return self.select(comp, zeros, pmf) return self.select(comp, zeros, pmf)
return None
def _cdf(self, name, value, probs=None): def _cdf(self, value, probs=None):
r""" r"""
cdf of Geometric distribution. cdf of Geometric distribution.
Args: Args:
name (str): name of the function.
value (Tensor): a Tensor composed of only natural numbers. value (Tensor): a Tensor composed of only natural numbers.
probs (Tensor): probability of success. Default: self.probs. probs (Tensor): probability of success. Default: self.probs.
@ -231,7 +219,6 @@ class Geometric(Distribution):
cdf(k) = 0 if k < 0. cdf(k) = 0 if k < 0.
""" """
if name in self._cdf_survival_functions:
probs1 = self.probs if probs is None else probs probs1 = self.probs if probs is None else probs
probs0 = 1.0 - probs1 probs0 = 1.0 - probs1
dtype = self.dtypeop(value) dtype = self.dtypeop(value)
@ -245,14 +232,13 @@ class Geometric(Distribution):
zeros = self.fill(self.dtypeop(probs1), self.shape(cdf), 0.0) zeros = self.fill(self.dtypeop(probs1), self.shape(cdf), 0.0)
comp = self.less(value, zeros) comp = self.less(value, zeros)
return self.select(comp, zeros, cdf) return self.select(comp, zeros, cdf)
return None
def _kl_loss(self, name, dist, probs1_b, probs1_a=None):
def _kl_loss(self, dist, probs1_b, probs1_a=None):
r""" r"""
Evaluate Geometric-Geometric kl divergence, i.e. KL(a||b). Evaluate Geometric-Geometric kl divergence, i.e. KL(a||b).
Args: Args:
name (str): name of the funtion.
dist (str): type of the distributions. Should be "Geometric" in this case. dist (str): type of the distributions. Should be "Geometric" in this case.
probs1_b (Tensor): probability of success of distribution b. probs1_b (Tensor): probability of success of distribution b.
probs1_a (Tensor): probability of success of distribution a. Default: self.probs. probs1_a (Tensor): probability of success of distribution a. Default: self.probs.
@ -260,29 +246,26 @@ class Geometric(Distribution):
.. math:: .. math::
KL(a||b) = \log(\fract{probs1_a}{probs1_b}) + \fract{probs0_a}{probs1_a} * \log(\fract{probs0_a}{probs0_b}) KL(a||b) = \log(\fract{probs1_a}{probs1_b}) + \fract{probs0_a}{probs1_a} * \log(\fract{probs0_a}{probs0_b})
""" """
if name in self._divergence_functions and dist == 'Geometric': if dist == 'Geometric':
probs1_a = self.probs if probs1_a is None else probs1_a probs1_a = self.probs if probs1_a is None else probs1_a
probs0_a = 1.0 - probs1_a probs0_a = 1.0 - probs1_a
probs0_b = 1.0 - probs1_b probs0_b = 1.0 - probs1_b
return self.log(probs1_a / probs1_b) + (probs0_a / probs1_a) * self.log(probs0_a / probs0_b) return self.log(probs1_a / probs1_b) + (probs0_a / probs1_a) * self.log(probs0_a / probs0_b)
return None return None
def _sample(self, name, shape=(), probs=None): def _sample(self, shape=(), probs=None):
""" """
Sampling. Sampling.
Args: Args:
name (str): name of the function. Should always be 'sample' when passed in from construct.
shape (tuple): shape of the sample. Default: (). shape (tuple): shape of the sample. Default: ().
probs (Tensor): probability of success. Default: self.probs. probs (Tensor): probability of success. Default: self.probs.
Returns: Returns:
Tensor, shape is shape + batch_shape. Tensor, shape is shape + batch_shape.
""" """
if name == 'sample':
probs = self.probs if probs is None else probs probs = self.probs if probs is None else probs
minval = self.const(self.minval) minval = self.const(self.minval)
maxval = self.const(1.0) maxval = self.const(1.0)
sample_uniform = self.uniform(shape + self.shape(probs), minval, maxval) sample_uniform = self.uniform(shape + self.shape(probs), minval, maxval)
return self.floor(self.log(sample_uniform) / self.log(1.0 - probs)) return self.floor(self.log(sample_uniform) / self.log(1.0 - probs))
return None

@ -17,7 +17,6 @@ import numpy as np
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore.ops import composite as C from mindspore.ops import composite as C
from mindspore.common import dtype as mstype from mindspore.common import dtype as mstype
from mindspore.context import get_context
from .distribution import Distribution from .distribution import Distribution
from ._utils.utils import convert_to_batch, check_greater_equal_zero from ._utils.utils import convert_to_batch, check_greater_equal_zero
@ -39,55 +38,56 @@ class Normal(Distribution):
Examples: Examples:
>>> # To initialize a Normal distribution of mean 3.0 and standard deviation 4.0 >>> # To initialize a Normal distribution of mean 3.0 and standard deviation 4.0
>>> n = nn.Normal(3.0, 4.0, dtype=mstype.float32) >>> import mindspore.nn.probability.distribution as msd
>>> n = msd.Normal(3.0, 4.0, dtype=mstype.float32)
>>> >>>
>>> # The following creates two independent Normal distributions >>> # The following creates two independent Normal distributions
>>> n = nn.Normal([3.0, 3.0], [4.0, 4.0], dtype=mstype.float32) >>> n = msd.Normal([3.0, 3.0], [4.0, 4.0], dtype=mstype.float32)
>>> >>>
>>> # A normal distribution can be initilize without arguments >>> # A Normal distribution can be initilize without arguments
>>> # In this case, mean and sd must be passed in through construct. >>> # In this case, mean and sd must be passed in through args.
>>> n = nn.Normal(dtype=mstype.float32) >>> n = msd.Normal(dtype=mstype.float32)
>>> >>>
>>> # To use normal in a network >>> # To use Normal in a network
>>> class net(Cell): >>> class net(Cell):
>>> def __init__(self): >>> def __init__(self):
>>> super(net, self).__init__(): >>> super(net, self).__init__():
>>> self.n1 = nn.Normal(0.0, 1.0, dtype=mstype.float32) >>> self.n1 = msd.Nomral(0.0, 1.0, dtype=mstype.float32)
>>> self.n2 = nn.Normal(dtype=mstype.float32) >>> self.n2 = msd.Normal(dtype=mstype.float32)
>>> >>>
>>> # The following calls are valid in construct >>> # The following calls are valid in construct
>>> def construct(self, value, mean_b, sd_b, mean_a, sd_a): >>> def construct(self, value, mean_b, sd_b, mean_a, sd_a):
>>> >>>
>>> # Similar calls can be made to other probability functions >>> # Similar calls can be made to other probability functions
>>> # by replacing 'prob' with the name of the function >>> # by replacing 'prob' with the name of the function
>>> ans = self.n1('prob', value) >>> ans = self.n1.prob(value)
>>> # Evaluate with the respect to distribution b >>> # Evaluate with the respect to distribution b
>>> ans = self.n1('prob', value, mean_b, sd_b) >>> ans = self.n1.prob(value, mean_b, sd_b)
>>> >>>
>>> # mean and sd must be passed in through construct >>> # mean and sd must be passed in during function calls
>>> ans = self.n2('prob', value, mean_a, sd_a) >>> ans = self.n2.prob(value, mean_a, sd_a)
>>> >>>
>>> # Functions 'sd', 'var', 'entropy' have the same usage with 'mean' >>> # Functions 'sd', 'var', 'entropy' have the same usage as 'mean'
>>> # Will return [0.0] >>> # will return [0.0]
>>> ans = self.n1('mean') >>> ans = self.n1.mean()
>>> # Will return mean_b >>> # will return mean_b
>>> ans = self.n1('mean', mean_b, sd_b) >>> ans = self.n1.mean(mean_b, sd_b)
>>> >>>
>>> # mean and sd must be passed in through construct >>> # mean and sd must be passed during function calls
>>> ans = self.n2('mean', mean_a, sd_a) >>> ans = self.n2.mean(mean_a, sd_a)
>>> >>>
>>> # Usage of 'kl_loss' and 'cross_entropy' are similar >>> # Usage of 'kl_loss' and 'cross_entropy' are similar
>>> ans = self.n1('kl_loss', 'Normal', mean_b, sd_b) >>> ans = self.n1.kl_loss('Normal', mean_b, sd_b)
>>> ans = self.n1('kl_loss', 'Normal', mean_b, sd_b, mean_a, sd_a) >>> ans = self.n1.kl_loss('Normal', mean_b, sd_b, mean_a, sd_a)
>>> >>>
>>> # Additional mean and sd must be passed in through construct >>> # Additional mean and sd must be passed
>>> ans = self.n2('kl_loss', 'Normal', mean_b, sd_b, mean_a, sd_a) >>> ans = self.n2.kl_loss('Normal', mean_b, sd_b, mean_a, sd_a)
>>> >>>
>>> # Sample Usage >>> # Sample
>>> ans = self.n1('sample') >>> ans = self.n1.sample()
>>> ans = self.n1('sample', (2,3)) >>> ans = self.n1.sample((2,3))
>>> ans = self.n1('sample', (2,3), mean_b, sd_b) >>> ans = self.n1.sample((2,3), mean_b, sd_b)
>>> ans = self.n2('sample', (2,3), mean_a, sd_a) >>> ans = self.n2.sample((2,3), mean_a, sd_a)
""" """
def __init__(self, def __init__(self,
@ -114,7 +114,7 @@ class Normal(Distribution):
self.const = P.ScalarToArray() self.const = P.ScalarToArray()
self.erf = P.Erf() self.erf = P.Erf()
self.exp = P.Exp() self.exp = P.Exp()
self.expm1 = P.Expm1() if get_context('device_target') == 'Ascend' else self._expm1_by_step self.expm1 = self._expm1_by_step
self.fill = P.Fill() self.fill = P.Fill()
self.log = P.Log() self.log = P.Log()
self.shape = P.Shape() self.shape = P.Shape()
@ -135,67 +135,57 @@ class Normal(Distribution):
""" """
return self.exp(x) - 1.0 return self.exp(x) - 1.0
def _mean(self, name='mean', mean=None, sd=None): def _mean(self, mean=None, sd=None):
""" """
Mean of the distribution. Mean of the distribution.
""" """
if name == 'mean':
mean = self._mean_value if mean is None or sd is None else mean mean = self._mean_value if mean is None or sd is None else mean
return mean return mean
return None
def _mode(self, name='mode', mean=None, sd=None): def _mode(self, mean=None, sd=None):
""" """
Mode of the distribution. Mode of the distribution.
""" """
if name == 'mode':
mean = self._mean_value if mean is None or sd is None else mean mean = self._mean_value if mean is None or sd is None else mean
return mean return mean
return None
def _sd(self, name='sd', mean=None, sd=None): def _sd(self, mean=None, sd=None):
""" """
Standard deviation of the distribution. Standard deviation of the distribution.
""" """
if name in self._variance_functions:
sd = self._sd_value if mean is None or sd is None else sd sd = self._sd_value if mean is None or sd is None else sd
return sd return sd
return None
def _entropy(self, name='entropy', sd=None): def _entropy(self, sd=None):
r""" r"""
Evaluate entropy. Evaluate entropy.
.. math:: .. math::
H(X) = \log(\sqrt(numpy.e * 2. * numpy.pi * \sq(\sigma))) H(X) = \log(\sqrt(numpy.e * 2. * numpy.pi * \sq(\sigma)))
""" """
if name == 'entropy':
sd = self._sd_value if sd is None else sd sd = self._sd_value if sd is None else sd
return self.log(self.sqrt(np.e * 2. * np.pi * self.sq(sd))) return self.log(self.sqrt(np.e * 2. * np.pi * self.sq(sd)))
return None
def _cross_entropy(self, name, dist, mean_b, sd_b, mean_a=None, sd_a=None): def _cross_entropy(self, dist, mean_b, sd_b, mean_a=None, sd_a=None):
r""" r"""
Evaluate cross_entropy between normal distributions. Evaluate cross_entropy between normal distributions.
Args: Args:
name (str): name of the funtion passed in from construct. Should always be "cross_entropy".
dist (str): type of the distributions. Should be "Normal" in this case. dist (str): type of the distributions. Should be "Normal" in this case.
mean_b (Tensor): mean of distribution b. mean_b (Tensor): mean of distribution b.
sd_b (Tensor): standard deviation distribution b. sd_b (Tensor): standard deviation distribution b.
mean_a (Tensor): mean of distribution a. Default: self._mean_value. mean_a (Tensor): mean of distribution a. Default: self._mean_value.
sd_a (Tensor): standard deviation distribution a. Default: self._sd_value. sd_a (Tensor): standard deviation distribution a. Default: self._sd_value.
""" """
if name == 'cross_entropy' and dist == 'Normal': if dist == 'Normal':
return self._entropy(sd=sd_a) + self._kl_loss(name, dist, mean_b, sd_b, mean_a, sd_a) return self._entropy(sd=sd_a) + self._kl_loss(dist, mean_b, sd_b, mean_a, sd_a)
return None return None
def _log_prob(self, name, value, mean=None, sd=None): def _log_prob(self, value, mean=None, sd=None):
r""" r"""
Evaluate log probability. Evaluate log probability.
Args: Args:
name (str): name of the funtion passed in from construct.
value (Tensor): value to be evaluated. value (Tensor): value to be evaluated.
mean (Tensor): mean of the distribution. Default: self._mean_value. mean (Tensor): mean of the distribution. Default: self._mean_value.
sd (Tensor): standard deviation the distribution. Default: self._sd_value. sd (Tensor): standard deviation the distribution. Default: self._sd_value.
@ -203,20 +193,17 @@ class Normal(Distribution):
.. math:: .. math::
L(x) = -1* \fract{(x - \mu)^2}{2. * \sigma^2} - \log(\sqrt(2* \pi * \sigma^2)) L(x) = -1* \fract{(x - \mu)^2}{2. * \sigma^2} - \log(\sqrt(2* \pi * \sigma^2))
""" """
if name in self._prob_functions:
mean = self._mean_value if mean is None else mean mean = self._mean_value if mean is None else mean
sd = self._sd_value if sd is None else sd sd = self._sd_value if sd is None else sd
unnormalized_log_prob = -1. * (self.sq(value - mean)) / (2. * self.sq(sd)) unnormalized_log_prob = -1. * (self.sq(value - mean)) / (2. * self.sq(sd))
neg_normalization = -1. * self.log(self.sqrt(2. * np.pi * self.sq(sd))) neg_normalization = -1. * self.log(self.sqrt(2. * np.pi * self.sq(sd)))
return unnormalized_log_prob + neg_normalization return unnormalized_log_prob + neg_normalization
return None
def _cdf(self, name, value, mean=None, sd=None): def _cdf(self, value, mean=None, sd=None):
r""" r"""
Evaluate cdf of given value. Evaluate cdf of given value.
Args: Args:
name (str): name of the funtion passed in from construct. Should always be "cdf".
value (Tensor): value to be evaluated. value (Tensor): value to be evaluated.
mean (Tensor): mean of the distribution. Default: self._mean_value. mean (Tensor): mean of the distribution. Default: self._mean_value.
sd (Tensor): standard deviation the distribution. Default: self._sd_value. sd (Tensor): standard deviation the distribution. Default: self._sd_value.
@ -224,20 +211,17 @@ class Normal(Distribution):
.. math:: .. math::
cdf(x) = 0.5 * (1+ Erf((x - \mu) / ( \sigma * \sqrt(2)))) cdf(x) = 0.5 * (1+ Erf((x - \mu) / ( \sigma * \sqrt(2))))
""" """
if name in self._cdf_survival_functions:
mean = self._mean_value if mean is None else mean mean = self._mean_value if mean is None else mean
sd = self._sd_value if sd is None else sd sd = self._sd_value if sd is None else sd
sqrt2 = self.sqrt(self.const(2.0)) sqrt2 = self.sqrt(self.const(2.0))
adjusted = (value - mean) / (sd * sqrt2) adjusted = (value - mean) / (sd * sqrt2)
return 0.5 * (1.0 + self.erf(adjusted)) return 0.5 * (1.0 + self.erf(adjusted))
return None
def _kl_loss(self, name, dist, mean_b, sd_b, mean_a=None, sd_a=None): def _kl_loss(self, dist, mean_b, sd_b, mean_a=None, sd_a=None):
r""" r"""
Evaluate Normal-Normal kl divergence, i.e. KL(a||b). Evaluate Normal-Normal kl divergence, i.e. KL(a||b).
Args: Args:
name (str): name of the funtion passed in from construct.
dist (str): type of the distributions. Should be "Normal" in this case. dist (str): type of the distributions. Should be "Normal" in this case.
mean_b (Tensor): mean of distribution b. mean_b (Tensor): mean of distribution b.
sd_b (Tensor): standard deviation distribution b. sd_b (Tensor): standard deviation distribution b.
@ -248,7 +232,7 @@ class Normal(Distribution):
KL(a||b) = 0.5 * (\fract{MEAN(a)}{STD(b)} - \fract{MEAN(b)}{STD(b)}) ^ 2 + KL(a||b) = 0.5 * (\fract{MEAN(a)}{STD(b)} - \fract{MEAN(b)}{STD(b)}) ^ 2 +
0.5 * EXPM1(2 * (\log(STD(a)) - \log(STD(b))) - (\log(STD(a)) - \log(STD(b))) 0.5 * EXPM1(2 * (\log(STD(a)) - \log(STD(b))) - (\log(STD(a)) - \log(STD(b)))
""" """
if name in self._divergence_functions and dist == 'Normal': if dist == 'Normal':
mean_a = self._mean_value if mean_a is None else mean_a mean_a = self._mean_value if mean_a is None else mean_a
sd_a = self._sd_value if sd_a is None else sd_a sd_a = self._sd_value if sd_a is None else sd_a
diff_log_scale = self.log(sd_a) - self.log(sd_b) diff_log_scale = self.log(sd_a) - self.log(sd_b)
@ -256,12 +240,11 @@ class Normal(Distribution):
return 0.5 * squared_diff + 0.5 * self.expm1(2 * diff_log_scale) - diff_log_scale return 0.5 * squared_diff + 0.5 * self.expm1(2 * diff_log_scale) - diff_log_scale
return None return None
def _sample(self, name, shape=(), mean=None, sd=None): def _sample(self, shape=(), mean=None, sd=None):
""" """
Sampling. Sampling.
Args: Args:
name (str): name of the function. Should always be 'sample' when passed in from construct.
shape (tuple): shape of the sample. Default: (). shape (tuple): shape of the sample. Default: ().
mean (Tensor): mean of the samples. Default: self._mean_value. mean (Tensor): mean of the samples. Default: self._mean_value.
sd (Tensor): standard deviation of the samples. Default: self._sd_value. sd (Tensor): standard deviation of the samples. Default: self._sd_value.
@ -269,7 +252,6 @@ class Normal(Distribution):
Returns: Returns:
Tensor, shape is shape + batch_shape. Tensor, shape is shape + batch_shape.
""" """
if name == 'sample':
mean = self._mean_value if mean is None else mean mean = self._mean_value if mean is None else mean
sd = self._sd_value if sd is None else sd sd = self._sd_value if sd is None else sd
batch_shape = self.shape(self.zeroslike(mean) + self.zeroslike(sd)) batch_shape = self.shape(self.zeroslike(mean) + self.zeroslike(sd))
@ -279,4 +261,3 @@ class Normal(Distribution):
sample_norm = C.normal(sample_shape, mean_zero, sd_one, self.seed) sample_norm = C.normal(sample_shape, mean_zero, sd_one, self.seed)
sample = mean + sample_norm * sd sample = mean + sample_norm * sd
return sample return sample
return None

@ -35,55 +35,56 @@ class Uniform(Distribution):
Examples: Examples:
>>> # To initialize a Uniform distribution of mean 3.0 and standard deviation 4.0 >>> # To initialize a Uniform distribution of mean 3.0 and standard deviation 4.0
>>> n = nn.Uniform(0.0, 1.0, dtype=mstype.float32) >>> import mindspore.nn.probability.distribution as msd
>>> u = msd.Uniform(0.0, 1.0, dtype=mstype.float32)
>>> >>>
>>> # The following creates two independent Uniform distributions >>> # The following creates two independent Uniform distributions
>>> n = nn.Uniform([0.0, 0.0], [1.0, 2.0], dtype=mstype.float32) >>> u = msd.Uniform([0.0, 0.0], [1.0, 2.0], dtype=mstype.float32)
>>> >>>
>>> # A Uniform distribution can be initilized without arguments >>> # A Uniform distribution can be initilized without arguments
>>> # In this case, high and low must be passed in through construct. >>> # In this case, high and low must be passed in through args during function calls.
>>> n = nn.Uniform(dtype=mstype.float32) >>> u = msd.Uniform(dtype=mstype.float32)
>>> >>>
>>> # To use Uniform in a network >>> # To use Uniform in a network
>>> class net(Cell): >>> class net(Cell):
>>> def __init__(self) >>> def __init__(self)
>>> super(net, self).__init__(): >>> super(net, self).__init__():
>>> self.u1 = nn.Uniform(0.0, 1.0, dtype=mstype.float32) >>> self.u1 = msd.Uniform(0.0, 1.0, dtype=mstype.float32)
>>> self.u2 = nn.Uniform(dtype=mstype.float32) >>> self.u2 = msd.Uniform(dtype=mstype.float32)
>>> >>>
>>> # All the following calls in construct are valid >>> # All the following calls in construct are valid
>>> def construct(self, value, low_b, high_b, low_a, high_a): >>> def construct(self, value, low_b, high_b, low_a, high_a):
>>> >>>
>>> # Similar calls can be made to other probability functions >>> # Similar calls can be made to other probability functions
>>> # by replacing 'prob' with the name of the function >>> # by replacing 'prob' with the name of the function
>>> ans = self.u1('prob', value) >>> ans = self.u1.prob(value)
>>> # Evaluate with the respect to distribution b >>> # Evaluate with the respect to distribution b
>>> ans = self.u1('prob', value, low_b, high_b) >>> ans = self.u1.prob(value, low_b, high_b)
>>> >>>
>>> # High and low must be passed in through construct >>> # High and low must be passed in during function calls
>>> ans = self.u2('prob', value, low_a, high_a) >>> ans = self.u2.prob(value, low_a, high_a)
>>> >>>
>>> # Functions 'sd', 'var', 'entropy' have the same usage with 'mean' >>> # Functions 'sd', 'var', 'entropy' have the same usage as 'mean'
>>> # Will return [0.0] >>> # Will return 0.5
>>> ans = self.u1('mean') >>> ans = self.u1.mean()
>>> # Will return low_b >>> # Will return (low_b + high_b) / 2
>>> ans = self.u1('mean', low_b, high_b) >>> ans = self.u1.mean(low_b, high_b)
>>> >>>
>>> # High and low must be passed in through construct >>> # High and low must be passed in during function calls
>>> ans = self.u2('mean', low_a, high_a) >>> ans = self.u2.mean(low_a, high_a)
>>> >>>
>>> # Usage of 'kl_loss' and 'cross_entropy' are similar >>> # Usage of 'kl_loss' and 'cross_entropy' are similar
>>> ans = self.u1('kl_loss', 'Uniform', low_b, high_b) >>> ans = self.u1.kl_loss('Uniform', low_b, high_b)
>>> ans = self.u1('kl_loss', 'Uniform', low_b, high_b, low_a, high_a) >>> ans = self.u1.kl_loss('Uniform', low_b, high_b, low_a, high_a)
>>> >>>
>>> # Additional high and low must be passed in through construct >>> # Additional high and low must be passed
>>> ans = self.u2('kl_loss', 'Uniform', low_b, high_b, low_a, high_a) >>> ans = self.u2.kl_loss('Uniform', low_b, high_b, low_a, high_a)
>>> >>>
>>> # Sample Usage >>> # Sample
>>> ans = self.u1('sample') >>> ans = self.u1.sample()
>>> ans = self.u1('sample', (2,3)) >>> ans = self.u1.sample((2,3))
>>> ans = self.u1('sample', (2,3), low_b, high_b) >>> ans = self.u1.sample((2,3), low_b, high_b)
>>> ans = self.u2('sample', (2,3), low_a, high_a) >>> ans = self.u2.sample((2,3), low_a, high_a)
""" """
def __init__(self, def __init__(self,
@ -142,73 +143,64 @@ class Uniform(Distribution):
""" """
return self._high return self._high
def _range(self, name='range', low=None, high=None): def _range(self, low=None, high=None):
r""" r"""
Return the range of the distribution. Return the range of the distribution.
.. math:: .. math::
range(U) = high -low range(U) = high -low
""" """
if name == 'range':
low = self.low if low is None else low low = self.low if low is None else low
high = self.high if high is None else high high = self.high if high is None else high
return high - low return high - low
return None
def _mean(self, name='mean', low=None, high=None): def _mean(self, low=None, high=None):
r""" r"""
.. math:: .. math::
MEAN(U) = \fract{low + high}{2}. MEAN(U) = \fract{low + high}{2}.
""" """
if name == 'mean':
low = self.low if low is None else low low = self.low if low is None else low
high = self.high if high is None else high high = self.high if high is None else high
return (low + high) / 2. return (low + high) / 2.
return None
def _var(self, name='var', low=None, high=None):
def _var(self, low=None, high=None):
r""" r"""
.. math:: .. math::
VAR(U) = \fract{(high -low) ^ 2}{12}. VAR(U) = \fract{(high -low) ^ 2}{12}.
""" """
if name in self._variance_functions:
low = self.low if low is None else low low = self.low if low is None else low
high = self.high if high is None else high high = self.high if high is None else high
return self.sq(high - low) / 12.0 return self.sq(high - low) / 12.0
return None
def _entropy(self, name='entropy', low=None, high=None): def _entropy(self, low=None, high=None):
r""" r"""
.. math:: .. math::
H(U) = \log(high - low). H(U) = \log(high - low).
""" """
if name == 'entropy':
low = self.low if low is None else low low = self.low if low is None else low
high = self.high if high is None else high high = self.high if high is None else high
return self.log(high - low) return self.log(high - low)
return None
def _cross_entropy(self, name, dist, low_b, high_b, low_a=None, high_a=None): def _cross_entropy(self, dist, low_b, high_b, low_a=None, high_a=None):
""" """
Evaluate cross_entropy between Uniform distributoins. Evaluate cross_entropy between Uniform distributoins.
Args: Args:
name (str): name of the funtion.
dist (str): type of the distributions. Should be "Uniform" in this case. dist (str): type of the distributions. Should be "Uniform" in this case.
low_b (Tensor): lower bound of distribution b. low_b (Tensor): lower bound of distribution b.
high_b (Tensor): upper bound of distribution b. high_b (Tensor): upper bound of distribution b.
low_a (Tensor): lower bound of distribution a. Default: self.low. low_a (Tensor): lower bound of distribution a. Default: self.low.
high_a (Tensor): upper bound of distribution a. Default: self.high. high_a (Tensor): upper bound of distribution a. Default: self.high.
""" """
if name == 'cross_entropy' and dist == 'Uniform': if dist == 'Uniform':
return self._entropy(low=low_a, high=high_a) + self._kl_loss(name, dist, low_b, high_b, low_a, high_a) return self._entropy(low=low_a, high=high_a) + self._kl_loss(dist, low_b, high_b, low_a, high_a)
return None return None
def _prob(self, name, value, low=None, high=None): def _prob(self, value, low=None, high=None):
r""" r"""
pdf of Uniform distribution. pdf of Uniform distribution.
Args: Args:
name (str): name of the function.
value (Tensor): value to be evaluated. value (Tensor): value to be evaluated.
low (Tensor): lower bound of the distribution. Default: self.low. low (Tensor): lower bound of the distribution. Default: self.low.
high (Tensor): upper bound of the distribution. Default: self.high. high (Tensor): upper bound of the distribution. Default: self.high.
@ -218,7 +210,6 @@ class Uniform(Distribution):
pdf(x) = \fract{1.0}{high -low} if low <= x <= high; pdf(x) = \fract{1.0}{high -low} if low <= x <= high;
pdf(x) = 0 if x > high; pdf(x) = 0 if x > high;
""" """
if name in self._prob_functions:
low = self.low if low is None else low low = self.low if low is None else low
high = self.high if high is None else high high = self.high if high is None else high
ones = self.fill(self.dtype, self.shape(value), 1.0) ones = self.fill(self.dtype, self.shape(value), 1.0)
@ -229,21 +220,19 @@ class Uniform(Distribution):
comp_hi = self.lessequal(value, high) comp_hi = self.lessequal(value, high)
less_than_low = self.select(comp_lo, zeros, prob) less_than_low = self.select(comp_lo, zeros, prob)
return self.select(comp_hi, less_than_low, zeros) return self.select(comp_hi, less_than_low, zeros)
return None
def _kl_loss(self, name, dist, low_b, high_b, low_a=None, high_a=None): def _kl_loss(self, dist, low_b, high_b, low_a=None, high_a=None):
""" """
Evaluate uniform-uniform kl divergence, i.e. KL(a||b). Evaluate uniform-uniform kl divergence, i.e. KL(a||b).
Args: Args:
name (str): name of the funtion.
dist (str): type of the distributions. Should be "Uniform" in this case. dist (str): type of the distributions. Should be "Uniform" in this case.
low_b (Tensor): lower bound of distribution b. low_b (Tensor): lower bound of distribution b.
high_b (Tensor): upper bound of distribution b. high_b (Tensor): upper bound of distribution b.
low_a (Tensor): lower bound of distribution a. Default: self.low. low_a (Tensor): lower bound of distribution a. Default: self.low.
high_a (Tensor): upper bound of distribution a. Default: self.high. high_a (Tensor): upper bound of distribution a. Default: self.high.
""" """
if name in self._divergence_functions and dist == 'Uniform': if dist == 'Uniform':
low_a = self.low if low_a is None else low_a low_a = self.low if low_a is None else low_a
high_a = self.high if high_a is None else high_a high_a = self.high if high_a is None else high_a
kl = self.log(high_b - low_b) / self.log(high_a - low_a) kl = self.log(high_b - low_b) / self.log(high_a - low_a)
@ -251,12 +240,11 @@ class Uniform(Distribution):
return self.select(comp, kl, self.log(self.zeroslike(kl))) return self.select(comp, kl, self.log(self.zeroslike(kl)))
return None return None
def _cdf(self, name, value, low=None, high=None): def _cdf(self, value, low=None, high=None):
r""" r"""
cdf of Uniform distribution. cdf of Uniform distribution.
Args: Args:
name (str): name of the function.
value (Tensor): value to be evaluated. value (Tensor): value to be evaluated.
low (Tensor): lower bound of the distribution. Default: self.low. low (Tensor): lower bound of the distribution. Default: self.low.
high (Tensor): upper bound of the distribution. Default: self.high. high (Tensor): upper bound of the distribution. Default: self.high.
@ -266,7 +254,6 @@ class Uniform(Distribution):
cdf(x) = \fract{x - low}{high -low} if low <= x <= high; cdf(x) = \fract{x - low}{high -low} if low <= x <= high;
cdf(x) = 1 if x > high; cdf(x) = 1 if x > high;
""" """
if name in self._cdf_survival_functions:
low = self.low if low is None else low low = self.low if low is None else low
high = self.high if high is None else high high = self.high if high is None else high
prob = (value - low) / (high - low) prob = (value - low) / (high - low)
@ -277,14 +264,12 @@ class Uniform(Distribution):
comp_hi = self.less(value, high) comp_hi = self.less(value, high)
less_than_low = self.select(comp_lo, zeros, prob) less_than_low = self.select(comp_lo, zeros, prob)
return self.select(comp_hi, less_than_low, ones) return self.select(comp_hi, less_than_low, ones)
return None
def _sample(self, name, shape=(), low=None, high=None): def _sample(self, shape=(), low=None, high=None):
""" """
Sampling. Sampling.
Args: Args:
name (str): name of the function. Should always be 'sample' when passed in from construct.
shape (tuple): shape of the sample. Default: (). shape (tuple): shape of the sample. Default: ().
low (Tensor): lower bound of the distribution. Default: self.low. low (Tensor): lower bound of the distribution. Default: self.low.
high (Tensor): upper bound of the distribution. Default: self.high. high (Tensor): upper bound of the distribution. Default: self.high.
@ -292,7 +277,6 @@ class Uniform(Distribution):
Returns: Returns:
Tensor, shape is shape + batch_shape. Tensor, shape is shape + batch_shape.
""" """
if name == 'sample':
low = self.low if low is None else low low = self.low if low is None else low
high = self.high if high is None else high high = self.high if high is None else high
broadcast_shape = self.shape(low + high) broadcast_shape = self.shape(low + high)
@ -301,4 +285,3 @@ class Uniform(Distribution):
sample_uniform = self.uniform(shape + broadcast_shape, l_zero, h_one) sample_uniform = self.uniform(shape + broadcast_shape, l_zero, h_one)
sample = (high - low) * sample_uniform + low sample = (high - low) * sample_uniform + low
return sample return sample
return None

@ -19,7 +19,6 @@ import mindspore.context as context
import mindspore.nn as nn import mindspore.nn as nn
import mindspore.nn.probability.distribution as msd import mindspore.nn.probability.distribution as msd
from mindspore import Tensor from mindspore import Tensor
from mindspore.common.api import ms_function
from mindspore import dtype from mindspore import dtype
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
@ -32,9 +31,8 @@ class Prob(nn.Cell):
super(Prob, self).__init__() super(Prob, self).__init__()
self.b = msd.Bernoulli(0.7, dtype=dtype.int32) self.b = msd.Bernoulli(0.7, dtype=dtype.int32)
@ms_function
def construct(self, x_): def construct(self, x_):
return self.b('prob', x_) return self.b.prob(x_)
def test_pmf(): def test_pmf():
""" """
@ -57,9 +55,8 @@ class LogProb(nn.Cell):
super(LogProb, self).__init__() super(LogProb, self).__init__()
self.b = msd.Bernoulli(0.7, dtype=dtype.int32) self.b = msd.Bernoulli(0.7, dtype=dtype.int32)
@ms_function
def construct(self, x_): def construct(self, x_):
return self.b('log_prob', x_) return self.b.log_prob(x_)
def test_log_likelihood(): def test_log_likelihood():
""" """
@ -81,9 +78,8 @@ class KL(nn.Cell):
super(KL, self).__init__() super(KL, self).__init__()
self.b = msd.Bernoulli(0.7, dtype=dtype.int32) self.b = msd.Bernoulli(0.7, dtype=dtype.int32)
@ms_function
def construct(self, x_): def construct(self, x_):
return self.b('kl_loss', 'Bernoulli', x_) return self.b.kl_loss('Bernoulli', x_)
def test_kl_loss(): def test_kl_loss():
""" """
@ -107,9 +103,8 @@ class Basics(nn.Cell):
super(Basics, self).__init__() super(Basics, self).__init__()
self.b = msd.Bernoulli([0.3, 0.5, 0.7], dtype=dtype.int32) self.b = msd.Bernoulli([0.3, 0.5, 0.7], dtype=dtype.int32)
@ms_function
def construct(self): def construct(self):
return self.b('mean'), self.b('sd'), self.b('mode') return self.b.mean(), self.b.sd(), self.b.mode()
def test_basics(): def test_basics():
""" """
@ -134,9 +129,8 @@ class Sampling(nn.Cell):
self.b = msd.Bernoulli([0.7, 0.5], seed=seed, dtype=dtype.int32) self.b = msd.Bernoulli([0.7, 0.5], seed=seed, dtype=dtype.int32)
self.shape = shape self.shape = shape
@ms_function
def construct(self, probs=None): def construct(self, probs=None):
return self.b('sample', self.shape, probs) return self.b.sample(self.shape, probs)
def test_sample(): def test_sample():
""" """
@ -155,9 +149,8 @@ class CDF(nn.Cell):
super(CDF, self).__init__() super(CDF, self).__init__()
self.b = msd.Bernoulli(0.7, dtype=dtype.int32) self.b = msd.Bernoulli(0.7, dtype=dtype.int32)
@ms_function
def construct(self, x_): def construct(self, x_):
return self.b('cdf', x_) return self.b.cdf(x_)
def test_cdf(): def test_cdf():
""" """
@ -171,7 +164,6 @@ def test_cdf():
tol = 1e-6 tol = 1e-6
assert (np.abs(output.asnumpy() - expect_cdf) < tol).all() assert (np.abs(output.asnumpy() - expect_cdf) < tol).all()
class LogCDF(nn.Cell): class LogCDF(nn.Cell):
""" """
Test class: log cdf of bernoulli distributions. Test class: log cdf of bernoulli distributions.
@ -180,9 +172,8 @@ class LogCDF(nn.Cell):
super(LogCDF, self).__init__() super(LogCDF, self).__init__()
self.b = msd.Bernoulli(0.7, dtype=dtype.int32) self.b = msd.Bernoulli(0.7, dtype=dtype.int32)
@ms_function
def construct(self, x_): def construct(self, x_):
return self.b('log_cdf', x_) return self.b.log_cdf(x_)
def test_logcdf(): def test_logcdf():
""" """
@ -205,9 +196,8 @@ class SF(nn.Cell):
super(SF, self).__init__() super(SF, self).__init__()
self.b = msd.Bernoulli(0.7, dtype=dtype.int32) self.b = msd.Bernoulli(0.7, dtype=dtype.int32)
@ms_function
def construct(self, x_): def construct(self, x_):
return self.b('survival_function', x_) return self.b.survival_function(x_)
def test_survival(): def test_survival():
""" """
@ -230,9 +220,8 @@ class LogSF(nn.Cell):
super(LogSF, self).__init__() super(LogSF, self).__init__()
self.b = msd.Bernoulli(0.7, dtype=dtype.int32) self.b = msd.Bernoulli(0.7, dtype=dtype.int32)
@ms_function
def construct(self, x_): def construct(self, x_):
return self.b('log_survival', x_) return self.b.log_survival(x_)
def test_log_survival(): def test_log_survival():
""" """
@ -254,9 +243,8 @@ class EntropyH(nn.Cell):
super(EntropyH, self).__init__() super(EntropyH, self).__init__()
self.b = msd.Bernoulli(0.7, dtype=dtype.int32) self.b = msd.Bernoulli(0.7, dtype=dtype.int32)
@ms_function
def construct(self): def construct(self):
return self.b('entropy') return self.b.entropy()
def test_entropy(): def test_entropy():
""" """
@ -277,12 +265,11 @@ class CrossEntropy(nn.Cell):
super(CrossEntropy, self).__init__() super(CrossEntropy, self).__init__()
self.b = msd.Bernoulli(0.7, dtype=dtype.int32) self.b = msd.Bernoulli(0.7, dtype=dtype.int32)
@ms_function
def construct(self, x_): def construct(self, x_):
entropy = self.b('entropy') entropy = self.b.entropy()
kl_loss = self.b('kl_loss', 'Bernoulli', x_) kl_loss = self.b.kl_loss('Bernoulli', x_)
h_sum_kl = entropy + kl_loss h_sum_kl = entropy + kl_loss
cross_entropy = self.b('cross_entropy', 'Bernoulli', x_) cross_entropy = self.b.cross_entropy('Bernoulli', x_)
return h_sum_kl - cross_entropy return h_sum_kl - cross_entropy
def test_cross_entropy(): def test_cross_entropy():

@ -19,7 +19,6 @@ import mindspore.context as context
import mindspore.nn as nn import mindspore.nn as nn
import mindspore.nn.probability.distribution as msd import mindspore.nn.probability.distribution as msd
from mindspore import Tensor from mindspore import Tensor
from mindspore.common.api import ms_function
from mindspore import dtype from mindspore import dtype
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
@ -32,9 +31,8 @@ class Prob(nn.Cell):
super(Prob, self).__init__() super(Prob, self).__init__()
self.e = msd.Exponential([[1.0], [0.5]], dtype=dtype.float32) self.e = msd.Exponential([[1.0], [0.5]], dtype=dtype.float32)
@ms_function
def construct(self, x_): def construct(self, x_):
return self.e('prob', x_) return self.e.prob(x_)
def test_pdf(): def test_pdf():
""" """
@ -56,9 +54,8 @@ class LogProb(nn.Cell):
super(LogProb, self).__init__() super(LogProb, self).__init__()
self.e = msd.Exponential([[1.0], [0.5]], dtype=dtype.float32) self.e = msd.Exponential([[1.0], [0.5]], dtype=dtype.float32)
@ms_function
def construct(self, x_): def construct(self, x_):
return self.e('log_prob', x_) return self.e.log_prob(x_)
def test_log_likelihood(): def test_log_likelihood():
""" """
@ -80,9 +77,8 @@ class KL(nn.Cell):
super(KL, self).__init__() super(KL, self).__init__()
self.e = msd.Exponential([1.5], dtype=dtype.float32) self.e = msd.Exponential([1.5], dtype=dtype.float32)
@ms_function
def construct(self, x_): def construct(self, x_):
return self.e('kl_loss', 'Exponential', x_) return self.e.kl_loss('Exponential', x_)
def test_kl_loss(): def test_kl_loss():
""" """
@ -104,9 +100,8 @@ class Basics(nn.Cell):
super(Basics, self).__init__() super(Basics, self).__init__()
self.e = msd.Exponential([0.5], dtype=dtype.float32) self.e = msd.Exponential([0.5], dtype=dtype.float32)
@ms_function
def construct(self): def construct(self):
return self.e('mean'), self.e('sd'), self.e('mode') return self.e.mean(), self.e.sd(), self.e.mode()
def test_basics(): def test_basics():
""" """
@ -131,9 +126,8 @@ class Sampling(nn.Cell):
self.e = msd.Exponential([[1.0], [0.5]], seed=seed, dtype=dtype.float32) self.e = msd.Exponential([[1.0], [0.5]], seed=seed, dtype=dtype.float32)
self.shape = shape self.shape = shape
@ms_function
def construct(self, rate=None): def construct(self, rate=None):
return self.e('sample', self.shape, rate) return self.e.sample(self.shape, rate)
def test_sample(): def test_sample():
""" """
@ -154,9 +148,8 @@ class CDF(nn.Cell):
super(CDF, self).__init__() super(CDF, self).__init__()
self.e = msd.Exponential([[1.0], [0.5]], dtype=dtype.float32) self.e = msd.Exponential([[1.0], [0.5]], dtype=dtype.float32)
@ms_function
def construct(self, x_): def construct(self, x_):
return self.e('cdf', x_) return self.e.cdf(x_)
def test_cdf(): def test_cdf():
""" """
@ -178,9 +171,8 @@ class LogCDF(nn.Cell):
super(LogCDF, self).__init__() super(LogCDF, self).__init__()
self.e = msd.Exponential([[1.0], [0.5]], dtype=dtype.float32) self.e = msd.Exponential([[1.0], [0.5]], dtype=dtype.float32)
@ms_function
def construct(self, x_): def construct(self, x_):
return self.e('log_cdf', x_) return self.e.log_cdf(x_)
def test_log_cdf(): def test_log_cdf():
""" """
@ -202,9 +194,8 @@ class SF(nn.Cell):
super(SF, self).__init__() super(SF, self).__init__()
self.e = msd.Exponential([[1.0], [0.5]], dtype=dtype.float32) self.e = msd.Exponential([[1.0], [0.5]], dtype=dtype.float32)
@ms_function
def construct(self, x_): def construct(self, x_):
return self.e('survival_function', x_) return self.e.survival_function(x_)
def test_survival(): def test_survival():
""" """
@ -226,9 +217,8 @@ class LogSF(nn.Cell):
super(LogSF, self).__init__() super(LogSF, self).__init__()
self.e = msd.Exponential([[1.0], [0.5]], dtype=dtype.float32) self.e = msd.Exponential([[1.0], [0.5]], dtype=dtype.float32)
@ms_function
def construct(self, x_): def construct(self, x_):
return self.e('log_survival', x_) return self.e.log_survival(x_)
def test_log_survival(): def test_log_survival():
""" """
@ -250,9 +240,8 @@ class EntropyH(nn.Cell):
super(EntropyH, self).__init__() super(EntropyH, self).__init__()
self.e = msd.Exponential([[1.0], [0.5]], dtype=dtype.float32) self.e = msd.Exponential([[1.0], [0.5]], dtype=dtype.float32)
@ms_function
def construct(self): def construct(self):
return self.e('entropy') return self.e.entropy()
def test_entropy(): def test_entropy():
""" """
@ -273,12 +262,11 @@ class CrossEntropy(nn.Cell):
super(CrossEntropy, self).__init__() super(CrossEntropy, self).__init__()
self.e = msd.Exponential([1.0], dtype=dtype.float32) self.e = msd.Exponential([1.0], dtype=dtype.float32)
@ms_function
def construct(self, x_): def construct(self, x_):
entropy = self.e('entropy') entropy = self.e.entropy()
kl_loss = self.e('kl_loss', 'Exponential', x_) kl_loss = self.e.kl_loss('Exponential', x_)
h_sum_kl = entropy + kl_loss h_sum_kl = entropy + kl_loss
cross_entropy = self.e('cross_entropy', 'Exponential', x_) cross_entropy = self.e.cross_entropy('Exponential', x_)
return h_sum_kl - cross_entropy return h_sum_kl - cross_entropy
def test_cross_entropy(): def test_cross_entropy():

@ -19,7 +19,6 @@ import mindspore.context as context
import mindspore.nn as nn import mindspore.nn as nn
import mindspore.nn.probability.distribution as msd import mindspore.nn.probability.distribution as msd
from mindspore import Tensor from mindspore import Tensor
from mindspore.common.api import ms_function
from mindspore import dtype from mindspore import dtype
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
@ -32,9 +31,8 @@ class Prob(nn.Cell):
super(Prob, self).__init__() super(Prob, self).__init__()
self.g = msd.Geometric(0.7, dtype=dtype.int32) self.g = msd.Geometric(0.7, dtype=dtype.int32)
@ms_function
def construct(self, x_): def construct(self, x_):
return self.g('prob', x_) return self.g.prob(x_)
def test_pmf(): def test_pmf():
""" """
@ -56,9 +54,8 @@ class LogProb(nn.Cell):
super(LogProb, self).__init__() super(LogProb, self).__init__()
self.g = msd.Geometric(0.7, dtype=dtype.int32) self.g = msd.Geometric(0.7, dtype=dtype.int32)
@ms_function
def construct(self, x_): def construct(self, x_):
return self.g('log_prob', x_) return self.g.log_prob(x_)
def test_log_likelihood(): def test_log_likelihood():
""" """
@ -80,9 +77,8 @@ class KL(nn.Cell):
super(KL, self).__init__() super(KL, self).__init__()
self.g = msd.Geometric(0.7, dtype=dtype.int32) self.g = msd.Geometric(0.7, dtype=dtype.int32)
@ms_function
def construct(self, x_): def construct(self, x_):
return self.g('kl_loss', 'Geometric', x_) return self.g.kl_loss('Geometric', x_)
def test_kl_loss(): def test_kl_loss():
""" """
@ -106,9 +102,8 @@ class Basics(nn.Cell):
super(Basics, self).__init__() super(Basics, self).__init__()
self.g = msd.Geometric([0.5, 0.5], dtype=dtype.int32) self.g = msd.Geometric([0.5, 0.5], dtype=dtype.int32)
@ms_function
def construct(self): def construct(self):
return self.g('mean'), self.g('sd'), self.g('mode') return self.g.mean(), self.g.sd(), self.g.mode()
def test_basics(): def test_basics():
""" """
@ -133,9 +128,8 @@ class Sampling(nn.Cell):
self.g = msd.Geometric([0.7, 0.5], seed=seed, dtype=dtype.int32) self.g = msd.Geometric([0.7, 0.5], seed=seed, dtype=dtype.int32)
self.shape = shape self.shape = shape
@ms_function
def construct(self, probs=None): def construct(self, probs=None):
return self.g('sample', self.shape, probs) return self.g.sample(self.shape, probs)
def test_sample(): def test_sample():
""" """
@ -154,9 +148,8 @@ class CDF(nn.Cell):
super(CDF, self).__init__() super(CDF, self).__init__()
self.g = msd.Geometric(0.7, dtype=dtype.int32) self.g = msd.Geometric(0.7, dtype=dtype.int32)
@ms_function
def construct(self, x_): def construct(self, x_):
return self.g('cdf', x_) return self.g.cdf(x_)
def test_cdf(): def test_cdf():
""" """
@ -178,9 +171,8 @@ class LogCDF(nn.Cell):
super(LogCDF, self).__init__() super(LogCDF, self).__init__()
self.g = msd.Geometric(0.7, dtype=dtype.int32) self.g = msd.Geometric(0.7, dtype=dtype.int32)
@ms_function
def construct(self, x_): def construct(self, x_):
return self.g('log_cdf', x_) return self.g.log_cdf(x_)
def test_logcdf(): def test_logcdf():
""" """
@ -202,9 +194,8 @@ class SF(nn.Cell):
super(SF, self).__init__() super(SF, self).__init__()
self.g = msd.Geometric(0.7, dtype=dtype.int32) self.g = msd.Geometric(0.7, dtype=dtype.int32)
@ms_function
def construct(self, x_): def construct(self, x_):
return self.g('survival_function', x_) return self.g.survival_function(x_)
def test_survival(): def test_survival():
""" """
@ -226,9 +217,8 @@ class LogSF(nn.Cell):
super(LogSF, self).__init__() super(LogSF, self).__init__()
self.g = msd.Geometric(0.7, dtype=dtype.int32) self.g = msd.Geometric(0.7, dtype=dtype.int32)
@ms_function
def construct(self, x_): def construct(self, x_):
return self.g('log_survival', x_) return self.g.log_survival(x_)
def test_log_survival(): def test_log_survival():
""" """
@ -250,9 +240,8 @@ class EntropyH(nn.Cell):
super(EntropyH, self).__init__() super(EntropyH, self).__init__()
self.g = msd.Geometric(0.7, dtype=dtype.int32) self.g = msd.Geometric(0.7, dtype=dtype.int32)
@ms_function
def construct(self): def construct(self):
return self.g('entropy') return self.g.entropy()
def test_entropy(): def test_entropy():
""" """
@ -273,12 +262,11 @@ class CrossEntropy(nn.Cell):
super(CrossEntropy, self).__init__() super(CrossEntropy, self).__init__()
self.g = msd.Geometric(0.7, dtype=dtype.int32) self.g = msd.Geometric(0.7, dtype=dtype.int32)
@ms_function
def construct(self, x_): def construct(self, x_):
entropy = self.g('entropy') entropy = self.g.entropy()
kl_loss = self.g('kl_loss', 'Geometric', x_) kl_loss = self.g.kl_loss('Geometric', x_)
h_sum_kl = entropy + kl_loss h_sum_kl = entropy + kl_loss
ans = self.g('cross_entropy', 'Geometric', x_) ans = self.g.cross_entropy('Geometric', x_)
return h_sum_kl - ans return h_sum_kl - ans
def test_cross_entropy(): def test_cross_entropy():

@ -19,7 +19,6 @@ import mindspore.context as context
import mindspore.nn as nn import mindspore.nn as nn
import mindspore.nn.probability.distribution as msd import mindspore.nn.probability.distribution as msd
from mindspore import Tensor from mindspore import Tensor
from mindspore.common.api import ms_function
from mindspore import dtype from mindspore import dtype
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
@ -32,9 +31,8 @@ class Prob(nn.Cell):
super(Prob, self).__init__() super(Prob, self).__init__()
self.n = msd.Normal(np.array([3.0]), np.array([[2.0], [4.0]]), dtype=dtype.float32) self.n = msd.Normal(np.array([3.0]), np.array([[2.0], [4.0]]), dtype=dtype.float32)
@ms_function
def construct(self, x_): def construct(self, x_):
return self.n('prob', x_) return self.n.prob(x_)
def test_pdf(): def test_pdf():
""" """
@ -55,9 +53,8 @@ class LogProb(nn.Cell):
super(LogProb, self).__init__() super(LogProb, self).__init__()
self.n = msd.Normal(np.array([3.0]), np.array([[2.0], [4.0]]), dtype=dtype.float32) self.n = msd.Normal(np.array([3.0]), np.array([[2.0], [4.0]]), dtype=dtype.float32)
@ms_function
def construct(self, x_): def construct(self, x_):
return self.n('log_prob', x_) return self.n.log_prob(x_)
def test_log_likelihood(): def test_log_likelihood():
""" """
@ -79,9 +76,8 @@ class KL(nn.Cell):
super(KL, self).__init__() super(KL, self).__init__()
self.n = msd.Normal(np.array([3.0]), np.array([4.0]), dtype=dtype.float32) self.n = msd.Normal(np.array([3.0]), np.array([4.0]), dtype=dtype.float32)
@ms_function
def construct(self, x_, y_): def construct(self, x_, y_):
return self.n('kl_loss', 'Normal', x_, y_) return self.n.kl_loss('Normal', x_, y_)
def test_kl_loss(): def test_kl_loss():
@ -113,9 +109,8 @@ class Basics(nn.Cell):
super(Basics, self).__init__() super(Basics, self).__init__()
self.n = msd.Normal(np.array([3.0]), np.array([2.0, 4.0]), dtype=dtype.float32) self.n = msd.Normal(np.array([3.0]), np.array([2.0, 4.0]), dtype=dtype.float32)
@ms_function
def construct(self): def construct(self):
return self.n('mean'), self.n('sd'), self.n('mode') return self.n.mean(), self.n.sd(), self.n.mode()
def test_basics(): def test_basics():
""" """
@ -139,9 +134,8 @@ class Sampling(nn.Cell):
self.n = msd.Normal(np.array([3.0]), np.array([[2.0], [4.0]]), seed=seed, dtype=dtype.float32) self.n = msd.Normal(np.array([3.0]), np.array([[2.0], [4.0]]), seed=seed, dtype=dtype.float32)
self.shape = shape self.shape = shape
@ms_function
def construct(self, mean=None, sd=None): def construct(self, mean=None, sd=None):
return self.n('sample', self.shape, mean, sd) return self.n.sample(self.shape, mean, sd)
def test_sample(): def test_sample():
""" """
@ -163,9 +157,8 @@ class CDF(nn.Cell):
super(CDF, self).__init__() super(CDF, self).__init__()
self.n = msd.Normal(np.array([3.0]), np.array([[2.0], [4.0]]), dtype=dtype.float32) self.n = msd.Normal(np.array([3.0]), np.array([[2.0], [4.0]]), dtype=dtype.float32)
@ms_function
def construct(self, x_): def construct(self, x_):
return self.n('cdf', x_) return self.n.cdf(x_)
def test_cdf(): def test_cdf():
@ -187,9 +180,8 @@ class LogCDF(nn.Cell):
super(LogCDF, self).__init__() super(LogCDF, self).__init__()
self.n = msd.Normal(np.array([3.0]), np.array([[2.0], [4.0]]), dtype=dtype.float32) self.n = msd.Normal(np.array([3.0]), np.array([[2.0], [4.0]]), dtype=dtype.float32)
@ms_function
def construct(self, x_): def construct(self, x_):
return self.n('log_cdf', x_) return self.n.log_cdf(x_)
def test_log_cdf(): def test_log_cdf():
""" """
@ -210,9 +202,8 @@ class SF(nn.Cell):
super(SF, self).__init__() super(SF, self).__init__()
self.n = msd.Normal(np.array([3.0]), np.array([[2.0], [4.0]]), dtype=dtype.float32) self.n = msd.Normal(np.array([3.0]), np.array([[2.0], [4.0]]), dtype=dtype.float32)
@ms_function
def construct(self, x_): def construct(self, x_):
return self.n('survival_function', x_) return self.n.survival_function(x_)
def test_survival(): def test_survival():
""" """
@ -233,9 +224,8 @@ class LogSF(nn.Cell):
super(LogSF, self).__init__() super(LogSF, self).__init__()
self.n = msd.Normal(np.array([3.0]), np.array([[2.0], [4.0]]), dtype=dtype.float32) self.n = msd.Normal(np.array([3.0]), np.array([[2.0], [4.0]]), dtype=dtype.float32)
@ms_function
def construct(self, x_): def construct(self, x_):
return self.n('log_survival', x_) return self.n.log_survival(x_)
def test_log_survival(): def test_log_survival():
""" """
@ -256,9 +246,8 @@ class EntropyH(nn.Cell):
super(EntropyH, self).__init__() super(EntropyH, self).__init__()
self.n = msd.Normal(np.array([3.0]), np.array([[2.0], [4.0]]), dtype=dtype.float32) self.n = msd.Normal(np.array([3.0]), np.array([[2.0], [4.0]]), dtype=dtype.float32)
@ms_function
def construct(self): def construct(self):
return self.n('entropy') return self.n.entropy()
def test_entropy(): def test_entropy():
""" """
@ -279,12 +268,11 @@ class CrossEntropy(nn.Cell):
super(CrossEntropy, self).__init__() super(CrossEntropy, self).__init__()
self.n = msd.Normal(np.array([3.0]), np.array([4.0]), dtype=dtype.float32) self.n = msd.Normal(np.array([3.0]), np.array([4.0]), dtype=dtype.float32)
@ms_function
def construct(self, x_, y_): def construct(self, x_, y_):
entropy = self.n('entropy') entropy = self.n.entropy()
kl_loss = self.n('kl_loss', 'Normal', x_, y_) kl_loss = self.n.kl_loss('Normal', x_, y_)
h_sum_kl = entropy + kl_loss h_sum_kl = entropy + kl_loss
cross_entropy = self.n('cross_entropy', 'Normal', x_, y_) cross_entropy = self.n.cross_entropy('Normal', x_, y_)
return h_sum_kl - cross_entropy return h_sum_kl - cross_entropy
def test_cross_entropy(): def test_cross_entropy():
@ -297,3 +285,40 @@ def test_cross_entropy():
diff = cross_entropy(mean, sd) diff = cross_entropy(mean, sd)
tol = 1e-6 tol = 1e-6
assert (np.abs(diff.asnumpy() - np.zeros(diff.shape)) < tol).all() assert (np.abs(diff.asnumpy() - np.zeros(diff.shape)) < tol).all()
class Net(nn.Cell):
"""
Test class: expand single distribution instance to multiple graphs
by specifying the attributes.
"""
def __init__(self):
super(Net, self).__init__()
self.normal = msd.Normal(0., 1., dtype=dtype.float32)
def construct(self, x_, y_):
kl = self.normal.kl_loss('Normal', x_, y_)
prob = self.normal.prob(kl)
return prob
def test_multiple_graphs():
"""
Test multiple graphs case.
"""
prob = Net()
mean_a = np.array([0.0]).astype(np.float32)
sd_a = np.array([1.0]).astype(np.float32)
mean_b = np.array([1.0]).astype(np.float32)
sd_b = np.array([1.0]).astype(np.float32)
ans = prob(Tensor(mean_b), Tensor(sd_b))
diff_log_scale = np.log(sd_a) - np.log(sd_b)
squared_diff = np.square(mean_a / sd_b - mean_b / sd_b)
expect_kl_loss = 0.5 * squared_diff + 0.5 * \
np.expm1(2 * diff_log_scale) - diff_log_scale
norm_benchmark = stats.norm(np.array([0.0]), np.array([1.0]))
expect_prob = norm_benchmark.pdf(expect_kl_loss).astype(np.float32)
tol = 1e-6
assert (np.abs(ans.asnumpy() - expect_prob) < tol).all()

@ -1,62 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test cases for new api of normal distribution"""
import numpy as np
from scipy import stats
import mindspore.nn as nn
import mindspore.nn.probability.distribution as msd
from mindspore import dtype
from mindspore import Tensor
import mindspore.context as context
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
class Net(nn.Cell):
"""
Test class: new api of normal distribution.
"""
def __init__(self):
super(Net, self).__init__()
self.normal = msd.Normal(0., 1., dtype=dtype.float32)
def construct(self, x_, y_):
kl = self.normal.kl_loss('kl_loss', 'Normal', x_, y_)
prob = self.normal.prob('prob', kl)
return prob
def test_new_api():
"""
Test new api of normal distribution.
"""
prob = Net()
mean_a = np.array([0.0]).astype(np.float32)
sd_a = np.array([1.0]).astype(np.float32)
mean_b = np.array([1.0]).astype(np.float32)
sd_b = np.array([1.0]).astype(np.float32)
ans = prob(Tensor(mean_b), Tensor(sd_b))
diff_log_scale = np.log(sd_a) - np.log(sd_b)
squared_diff = np.square(mean_a / sd_b - mean_b / sd_b)
expect_kl_loss = 0.5 * squared_diff + 0.5 * \
np.expm1(2 * diff_log_scale) - diff_log_scale
norm_benchmark = stats.norm(np.array([0.0]), np.array([1.0]))
expect_prob = norm_benchmark.pdf(expect_kl_loss).astype(np.float32)
tol = 1e-6
assert (np.abs(ans.asnumpy() - expect_prob) < tol).all()

@ -19,7 +19,6 @@ import mindspore.context as context
import mindspore.nn as nn import mindspore.nn as nn
import mindspore.nn.probability.distribution as msd import mindspore.nn.probability.distribution as msd
from mindspore import Tensor from mindspore import Tensor
from mindspore.common.api import ms_function
from mindspore import dtype from mindspore import dtype
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
@ -32,9 +31,8 @@ class Prob(nn.Cell):
super(Prob, self).__init__() super(Prob, self).__init__()
self.u = msd.Uniform([0.0], [[1.0], [2.0]], dtype=dtype.float32) self.u = msd.Uniform([0.0], [[1.0], [2.0]], dtype=dtype.float32)
@ms_function
def construct(self, x_): def construct(self, x_):
return self.u('prob', x_) return self.u.prob(x_)
def test_pdf(): def test_pdf():
""" """
@ -56,9 +54,8 @@ class LogProb(nn.Cell):
super(LogProb, self).__init__() super(LogProb, self).__init__()
self.u = msd.Uniform([0.0], [[1.0], [2.0]], dtype=dtype.float32) self.u = msd.Uniform([0.0], [[1.0], [2.0]], dtype=dtype.float32)
@ms_function
def construct(self, x_): def construct(self, x_):
return self.u('log_prob', x_) return self.u.log_prob(x_)
def test_log_likelihood(): def test_log_likelihood():
""" """
@ -80,9 +77,8 @@ class KL(nn.Cell):
super(KL, self).__init__() super(KL, self).__init__()
self.u = msd.Uniform([0.0], [1.5], dtype=dtype.float32) self.u = msd.Uniform([0.0], [1.5], dtype=dtype.float32)
@ms_function
def construct(self, x_, y_): def construct(self, x_, y_):
return self.u('kl_loss', 'Uniform', x_, y_) return self.u.kl_loss('Uniform', x_, y_)
def test_kl_loss(): def test_kl_loss():
""" """
@ -106,9 +102,8 @@ class Basics(nn.Cell):
super(Basics, self).__init__() super(Basics, self).__init__()
self.u = msd.Uniform([0.0], [3.0], dtype=dtype.float32) self.u = msd.Uniform([0.0], [3.0], dtype=dtype.float32)
@ms_function
def construct(self): def construct(self):
return self.u('mean'), self.u('sd') return self.u.mean(), self.u.sd()
def test_basics(): def test_basics():
""" """
@ -131,9 +126,8 @@ class Sampling(nn.Cell):
self.u = msd.Uniform([0.0], [[1.0], [2.0]], seed=seed, dtype=dtype.float32) self.u = msd.Uniform([0.0], [[1.0], [2.0]], seed=seed, dtype=dtype.float32)
self.shape = shape self.shape = shape
@ms_function
def construct(self, low=None, high=None): def construct(self, low=None, high=None):
return self.u('sample', self.shape, low, high) return self.u.sample(self.shape, low, high)
def test_sample(): def test_sample():
""" """
@ -155,9 +149,8 @@ class CDF(nn.Cell):
super(CDF, self).__init__() super(CDF, self).__init__()
self.u = msd.Uniform([0.0], [1.0], dtype=dtype.float32) self.u = msd.Uniform([0.0], [1.0], dtype=dtype.float32)
@ms_function
def construct(self, x_): def construct(self, x_):
return self.u('cdf', x_) return self.u.cdf(x_)
def test_cdf(): def test_cdf():
""" """
@ -179,9 +172,8 @@ class LogCDF(nn.Cell):
super(LogCDF, self).__init__() super(LogCDF, self).__init__()
self.u = msd.Uniform([0.0], [1.0], dtype=dtype.float32) self.u = msd.Uniform([0.0], [1.0], dtype=dtype.float32)
@ms_function
def construct(self, x_): def construct(self, x_):
return self.u('log_cdf', x_) return self.u.log_cdf(x_)
class SF(nn.Cell): class SF(nn.Cell):
""" """
@ -191,9 +183,8 @@ class SF(nn.Cell):
super(SF, self).__init__() super(SF, self).__init__()
self.u = msd.Uniform([0.0], [1.0], dtype=dtype.float32) self.u = msd.Uniform([0.0], [1.0], dtype=dtype.float32)
@ms_function
def construct(self, x_): def construct(self, x_):
return self.u('survival_function', x_) return self.u.survival_function(x_)
class LogSF(nn.Cell): class LogSF(nn.Cell):
""" """
@ -203,9 +194,8 @@ class LogSF(nn.Cell):
super(LogSF, self).__init__() super(LogSF, self).__init__()
self.u = msd.Uniform([0.0], [1.0], dtype=dtype.float32) self.u = msd.Uniform([0.0], [1.0], dtype=dtype.float32)
@ms_function
def construct(self, x_): def construct(self, x_):
return self.u('log_survival', x_) return self.u.log_survival(x_)
class EntropyH(nn.Cell): class EntropyH(nn.Cell):
""" """
@ -215,9 +205,8 @@ class EntropyH(nn.Cell):
super(EntropyH, self).__init__() super(EntropyH, self).__init__()
self.u = msd.Uniform([0.0], [1.0, 2.0], dtype=dtype.float32) self.u = msd.Uniform([0.0], [1.0, 2.0], dtype=dtype.float32)
@ms_function
def construct(self): def construct(self):
return self.u('entropy') return self.u.entropy()
def test_entropy(): def test_entropy():
""" """
@ -238,12 +227,11 @@ class CrossEntropy(nn.Cell):
super(CrossEntropy, self).__init__() super(CrossEntropy, self).__init__()
self.u = msd.Uniform([0.0], [1.5], dtype=dtype.float32) self.u = msd.Uniform([0.0], [1.5], dtype=dtype.float32)
@ms_function
def construct(self, x_, y_): def construct(self, x_, y_):
entropy = self.u('entropy') entropy = self.u.entropy()
kl_loss = self.u('kl_loss', 'Uniform', x_, y_) kl_loss = self.u.kl_loss('Uniform', x_, y_)
h_sum_kl = entropy + kl_loss h_sum_kl = entropy + kl_loss
cross_entropy = self.u('cross_entropy', 'Uniform', x_, y_) cross_entropy = self.u.cross_entropy('Uniform', x_, y_)
return h_sum_kl - cross_entropy return h_sum_kl - cross_entropy
def test_log_cdf(): def test_log_cdf():

@ -49,12 +49,12 @@ class BernoulliProb(nn.Cell):
self.b = msd.Bernoulli(0.5, dtype=dtype.int32) self.b = msd.Bernoulli(0.5, dtype=dtype.int32)
def construct(self, value): def construct(self, value):
prob = self.b('prob', value) prob = self.b.prob(value)
log_prob = self.b('log_prob', value) log_prob = self.b.log_prob(value)
cdf = self.b('cdf', value) cdf = self.b.cdf(value)
log_cdf = self.b('log_cdf', value) log_cdf = self.b.log_cdf(value)
sf = self.b('survival_function', value) sf = self.b.survival_function(value)
log_sf = self.b('log_survival', value) log_sf = self.b.log_survival(value)
return prob + log_prob + cdf + log_cdf + sf + log_sf return prob + log_prob + cdf + log_cdf + sf + log_sf
def test_bernoulli_prob(): def test_bernoulli_prob():
@ -75,12 +75,12 @@ class BernoulliProb1(nn.Cell):
self.b = msd.Bernoulli(dtype=dtype.int32) self.b = msd.Bernoulli(dtype=dtype.int32)
def construct(self, value, probs): def construct(self, value, probs):
prob = self.b('prob', value, probs) prob = self.b.prob(value, probs)
log_prob = self.b('log_prob', value, probs) log_prob = self.b.log_prob(value, probs)
cdf = self.b('cdf', value, probs) cdf = self.b.cdf(value, probs)
log_cdf = self.b('log_cdf', value, probs) log_cdf = self.b.log_cdf(value, probs)
sf = self.b('survival_function', value, probs) sf = self.b.survival_function(value, probs)
log_sf = self.b('log_survival', value, probs) log_sf = self.b.log_survival(value, probs)
return prob + log_prob + cdf + log_cdf + sf + log_sf return prob + log_prob + cdf + log_cdf + sf + log_sf
def test_bernoulli_prob1(): def test_bernoulli_prob1():
@ -103,8 +103,8 @@ class BernoulliKl(nn.Cell):
self.b2 = msd.Bernoulli(dtype=dtype.int32) self.b2 = msd.Bernoulli(dtype=dtype.int32)
def construct(self, probs_b, probs_a): def construct(self, probs_b, probs_a):
kl1 = self.b1('kl_loss', 'Bernoulli', probs_b) kl1 = self.b1.kl_loss('Bernoulli', probs_b)
kl2 = self.b2('kl_loss', 'Bernoulli', probs_b, probs_a) kl2 = self.b2.kl_loss('Bernoulli', probs_b, probs_a)
return kl1 + kl2 return kl1 + kl2
def test_kl(): def test_kl():
@ -127,8 +127,8 @@ class BernoulliCrossEntropy(nn.Cell):
self.b2 = msd.Bernoulli(dtype=dtype.int32) self.b2 = msd.Bernoulli(dtype=dtype.int32)
def construct(self, probs_b, probs_a): def construct(self, probs_b, probs_a):
h1 = self.b1('cross_entropy', 'Bernoulli', probs_b) h1 = self.b1.cross_entropy('Bernoulli', probs_b)
h2 = self.b2('cross_entropy', 'Bernoulli', probs_b, probs_a) h2 = self.b2.cross_entropy('Bernoulli', probs_b, probs_a)
return h1 + h2 return h1 + h2
def test_cross_entropy(): def test_cross_entropy():
@ -150,11 +150,11 @@ class BernoulliBasics(nn.Cell):
self.b = msd.Bernoulli([0.3, 0.5], dtype=dtype.int32) self.b = msd.Bernoulli([0.3, 0.5], dtype=dtype.int32)
def construct(self): def construct(self):
mean = self.b('mean') mean = self.b.mean()
sd = self.b('sd') sd = self.b.sd()
var = self.b('var') var = self.b.var()
mode = self.b('mode') mode = self.b.mode()
entropy = self.b('entropy') entropy = self.b.entropy()
return mean + sd + var + mode + entropy return mean + sd + var + mode + entropy
def test_bascis(): def test_bascis():
@ -164,3 +164,28 @@ def test_bascis():
net = BernoulliBasics() net = BernoulliBasics()
ans = net() ans = net()
assert isinstance(ans, Tensor) assert isinstance(ans, Tensor)
class BernoulliConstruct(nn.Cell):
"""
Bernoulli distribution: going through construct.
"""
def __init__(self):
super(BernoulliConstruct, self).__init__()
self.b = msd.Bernoulli(0.5, dtype=dtype.int32)
self.b1 = msd.Bernoulli(dtype=dtype.int32)
def construct(self, value, probs):
prob = self.b('prob', value)
prob1 = self.b('prob', value, probs)
prob2 = self.b1('prob', value, probs)
return prob + prob1 + prob2
def test_bernoulli_construct():
"""
Test probability function going through construct.
"""
net = BernoulliConstruct()
value = Tensor([0, 0, 0, 0, 0], dtype=dtype.float32)
probs = Tensor([0.5], dtype=dtype.float32)
ans = net(value, probs)
assert isinstance(ans, Tensor)

@ -50,12 +50,12 @@ class ExponentialProb(nn.Cell):
self.e = msd.Exponential(0.5, dtype=dtype.float32) self.e = msd.Exponential(0.5, dtype=dtype.float32)
def construct(self, value): def construct(self, value):
prob = self.e('prob', value) prob = self.e.prob(value)
log_prob = self.e('log_prob', value) log_prob = self.e.log_prob(value)
cdf = self.e('cdf', value) cdf = self.e.cdf(value)
log_cdf = self.e('log_cdf', value) log_cdf = self.e.log_cdf(value)
sf = self.e('survival_function', value) sf = self.e.survival_function(value)
log_sf = self.e('log_survival', value) log_sf = self.e.log_survival(value)
return prob + log_prob + cdf + log_cdf + sf + log_sf return prob + log_prob + cdf + log_cdf + sf + log_sf
def test_exponential_prob(): def test_exponential_prob():
@ -76,12 +76,12 @@ class ExponentialProb1(nn.Cell):
self.e = msd.Exponential(dtype=dtype.float32) self.e = msd.Exponential(dtype=dtype.float32)
def construct(self, value, rate): def construct(self, value, rate):
prob = self.e('prob', value, rate) prob = self.e.prob(value, rate)
log_prob = self.e('log_prob', value, rate) log_prob = self.e.log_prob(value, rate)
cdf = self.e('cdf', value, rate) cdf = self.e.cdf(value, rate)
log_cdf = self.e('log_cdf', value, rate) log_cdf = self.e.log_cdf(value, rate)
sf = self.e('survival_function', value, rate) sf = self.e.survival_function(value, rate)
log_sf = self.e('log_survival', value, rate) log_sf = self.e.log_survival(value, rate)
return prob + log_prob + cdf + log_cdf + sf + log_sf return prob + log_prob + cdf + log_cdf + sf + log_sf
def test_exponential_prob1(): def test_exponential_prob1():
@ -104,8 +104,8 @@ class ExponentialKl(nn.Cell):
self.e2 = msd.Exponential(dtype=dtype.float32) self.e2 = msd.Exponential(dtype=dtype.float32)
def construct(self, rate_b, rate_a): def construct(self, rate_b, rate_a):
kl1 = self.e1('kl_loss', 'Exponential', rate_b) kl1 = self.e1.kl_loss('Exponential', rate_b)
kl2 = self.e2('kl_loss', 'Exponential', rate_b, rate_a) kl2 = self.e2.kl_loss('Exponential', rate_b, rate_a)
return kl1 + kl2 return kl1 + kl2
def test_kl(): def test_kl():
@ -128,8 +128,8 @@ class ExponentialCrossEntropy(nn.Cell):
self.e2 = msd.Exponential(dtype=dtype.float32) self.e2 = msd.Exponential(dtype=dtype.float32)
def construct(self, rate_b, rate_a): def construct(self, rate_b, rate_a):
h1 = self.e1('cross_entropy', 'Exponential', rate_b) h1 = self.e1.cross_entropy('Exponential', rate_b)
h2 = self.e2('cross_entropy', 'Exponential', rate_b, rate_a) h2 = self.e2.cross_entropy('Exponential', rate_b, rate_a)
return h1 + h2 return h1 + h2
def test_cross_entropy(): def test_cross_entropy():
@ -151,11 +151,11 @@ class ExponentialBasics(nn.Cell):
self.e = msd.Exponential([0.3, 0.5], dtype=dtype.float32) self.e = msd.Exponential([0.3, 0.5], dtype=dtype.float32)
def construct(self): def construct(self):
mean = self.e('mean') mean = self.e.mean()
sd = self.e('sd') sd = self.e.sd()
var = self.e('var') var = self.e.var()
mode = self.e('mode') mode = self.e.mode()
entropy = self.e('entropy') entropy = self.e.entropy()
return mean + sd + var + mode + entropy return mean + sd + var + mode + entropy
def test_bascis(): def test_bascis():
@ -165,3 +165,29 @@ def test_bascis():
net = ExponentialBasics() net = ExponentialBasics()
ans = net() ans = net()
assert isinstance(ans, Tensor) assert isinstance(ans, Tensor)
class ExpConstruct(nn.Cell):
"""
Exponential distribution: going through construct.
"""
def __init__(self):
super(ExpConstruct, self).__init__()
self.e = msd.Exponential(0.5, dtype=dtype.float32)
self.e1 = msd.Exponential(dtype=dtype.float32)
def construct(self, value, rate):
prob = self.e('prob', value)
prob1 = self.e('prob', value, rate)
prob2 = self.e1('prob', value, rate)
return prob + prob1 + prob2
def test_exp_construct():
"""
Test probability function going through construct.
"""
net = ExpConstruct()
value = Tensor([0, 0, 0, 0, 0], dtype=dtype.float32)
probs = Tensor([0.5], dtype=dtype.float32)
ans = net(value, probs)
assert isinstance(ans, Tensor)

@ -50,12 +50,12 @@ class GeometricProb(nn.Cell):
self.g = msd.Geometric(0.5, dtype=dtype.int32) self.g = msd.Geometric(0.5, dtype=dtype.int32)
def construct(self, value): def construct(self, value):
prob = self.g('prob', value) prob = self.g.prob(value)
log_prob = self.g('log_prob', value) log_prob = self.g.log_prob(value)
cdf = self.g('cdf', value) cdf = self.g.cdf(value)
log_cdf = self.g('log_cdf', value) log_cdf = self.g.log_cdf(value)
sf = self.g('survival_function', value) sf = self.g.survival_function(value)
log_sf = self.g('log_survival', value) log_sf = self.g.log_survival(value)
return prob + log_prob + cdf + log_cdf + sf + log_sf return prob + log_prob + cdf + log_cdf + sf + log_sf
def test_geometric_prob(): def test_geometric_prob():
@ -76,12 +76,12 @@ class GeometricProb1(nn.Cell):
self.g = msd.Geometric(dtype=dtype.int32) self.g = msd.Geometric(dtype=dtype.int32)
def construct(self, value, probs): def construct(self, value, probs):
prob = self.g('prob', value, probs) prob = self.g.prob(value, probs)
log_prob = self.g('log_prob', value, probs) log_prob = self.g.log_prob(value, probs)
cdf = self.g('cdf', value, probs) cdf = self.g.cdf(value, probs)
log_cdf = self.g('log_cdf', value, probs) log_cdf = self.g.log_cdf(value, probs)
sf = self.g('survival_function', value, probs) sf = self.g.survival_function(value, probs)
log_sf = self.g('log_survival', value, probs) log_sf = self.g.log_survival(value, probs)
return prob + log_prob + cdf + log_cdf + sf + log_sf return prob + log_prob + cdf + log_cdf + sf + log_sf
def test_geometric_prob1(): def test_geometric_prob1():
@ -105,8 +105,8 @@ class GeometricKl(nn.Cell):
self.g2 = msd.Geometric(dtype=dtype.int32) self.g2 = msd.Geometric(dtype=dtype.int32)
def construct(self, probs_b, probs_a): def construct(self, probs_b, probs_a):
kl1 = self.g1('kl_loss', 'Geometric', probs_b) kl1 = self.g1.kl_loss('Geometric', probs_b)
kl2 = self.g2('kl_loss', 'Geometric', probs_b, probs_a) kl2 = self.g2.kl_loss('Geometric', probs_b, probs_a)
return kl1 + kl2 return kl1 + kl2
def test_kl(): def test_kl():
@ -129,8 +129,8 @@ class GeometricCrossEntropy(nn.Cell):
self.g2 = msd.Geometric(dtype=dtype.int32) self.g2 = msd.Geometric(dtype=dtype.int32)
def construct(self, probs_b, probs_a): def construct(self, probs_b, probs_a):
h1 = self.g1('cross_entropy', 'Geometric', probs_b) h1 = self.g1.cross_entropy('Geometric', probs_b)
h2 = self.g2('cross_entropy', 'Geometric', probs_b, probs_a) h2 = self.g2.cross_entropy('Geometric', probs_b, probs_a)
return h1 + h2 return h1 + h2
def test_cross_entropy(): def test_cross_entropy():
@ -152,11 +152,11 @@ class GeometricBasics(nn.Cell):
self.g = msd.Geometric([0.3, 0.5], dtype=dtype.int32) self.g = msd.Geometric([0.3, 0.5], dtype=dtype.int32)
def construct(self): def construct(self):
mean = self.g('mean') mean = self.g.mean()
sd = self.g('sd') sd = self.g.sd()
var = self.g('var') var = self.g.var()
mode = self.g('mode') mode = self.g.mode()
entropy = self.g('entropy') entropy = self.g.entropy()
return mean + sd + var + mode + entropy return mean + sd + var + mode + entropy
def test_bascis(): def test_bascis():
@ -166,3 +166,29 @@ def test_bascis():
net = GeometricBasics() net = GeometricBasics()
ans = net() ans = net()
assert isinstance(ans, Tensor) assert isinstance(ans, Tensor)
class GeoConstruct(nn.Cell):
"""
Bernoulli distribution: going through construct.
"""
def __init__(self):
super(GeoConstruct, self).__init__()
self.g = msd.Geometric(0.5, dtype=dtype.int32)
self.g1 = msd.Geometric(dtype=dtype.int32)
def construct(self, value, probs):
prob = self.g('prob', value)
prob1 = self.g('prob', value, probs)
prob2 = self.g1('prob', value, probs)
return prob + prob1 + prob2
def test_geo_construct():
"""
Test probability function going through construct.
"""
net = GeoConstruct()
value = Tensor([0, 0, 0, 0, 0], dtype=dtype.float32)
probs = Tensor([0.5], dtype=dtype.float32)
ans = net(value, probs)
assert isinstance(ans, Tensor)

@ -50,12 +50,12 @@ class NormalProb(nn.Cell):
self.normal = msd.Normal(3.0, 4.0, dtype=dtype.float32) self.normal = msd.Normal(3.0, 4.0, dtype=dtype.float32)
def construct(self, value): def construct(self, value):
prob = self.normal('prob', value) prob = self.normal.prob(value)
log_prob = self.normal('log_prob', value) log_prob = self.normal.log_prob(value)
cdf = self.normal('cdf', value) cdf = self.normal.cdf(value)
log_cdf = self.normal('log_cdf', value) log_cdf = self.normal.log_cdf(value)
sf = self.normal('survival_function', value) sf = self.normal.survival_function(value)
log_sf = self.normal('log_survival', value) log_sf = self.normal.log_survival(value)
return prob + log_prob + cdf + log_cdf + sf + log_sf return prob + log_prob + cdf + log_cdf + sf + log_sf
def test_normal_prob(): def test_normal_prob():
@ -77,12 +77,12 @@ class NormalProb1(nn.Cell):
self.normal = msd.Normal() self.normal = msd.Normal()
def construct(self, value, mean, sd): def construct(self, value, mean, sd):
prob = self.normal('prob', value, mean, sd) prob = self.normal.prob(value, mean, sd)
log_prob = self.normal('log_prob', value, mean, sd) log_prob = self.normal.log_prob(value, mean, sd)
cdf = self.normal('cdf', value, mean, sd) cdf = self.normal.cdf(value, mean, sd)
log_cdf = self.normal('log_cdf', value, mean, sd) log_cdf = self.normal.log_cdf(value, mean, sd)
sf = self.normal('survival_function', value, mean, sd) sf = self.normal.survival_function(value, mean, sd)
log_sf = self.normal('log_survival', value, mean, sd) log_sf = self.normal.log_survival(value, mean, sd)
return prob + log_prob + cdf + log_cdf + sf + log_sf return prob + log_prob + cdf + log_cdf + sf + log_sf
def test_normal_prob1(): def test_normal_prob1():
@ -106,8 +106,8 @@ class NormalKl(nn.Cell):
self.n2 = msd.Normal(dtype=dtype.float32) self.n2 = msd.Normal(dtype=dtype.float32)
def construct(self, mean_b, sd_b, mean_a, sd_a): def construct(self, mean_b, sd_b, mean_a, sd_a):
kl1 = self.n1('kl_loss', 'Normal', mean_b, sd_b) kl1 = self.n1.kl_loss('Normal', mean_b, sd_b)
kl2 = self.n2('kl_loss', 'Normal', mean_b, sd_b, mean_a, sd_a) kl2 = self.n2.kl_loss('Normal', mean_b, sd_b, mean_a, sd_a)
return kl1 + kl2 return kl1 + kl2
def test_kl(): def test_kl():
@ -132,8 +132,8 @@ class NormalCrossEntropy(nn.Cell):
self.n2 = msd.Normal(dtype=dtype.float32) self.n2 = msd.Normal(dtype=dtype.float32)
def construct(self, mean_b, sd_b, mean_a, sd_a): def construct(self, mean_b, sd_b, mean_a, sd_a):
h1 = self.n1('cross_entropy', 'Normal', mean_b, sd_b) h1 = self.n1.cross_entropy('Normal', mean_b, sd_b)
h2 = self.n2('cross_entropy', 'Normal', mean_b, sd_b, mean_a, sd_a) h2 = self.n2.cross_entropy('Normal', mean_b, sd_b, mean_a, sd_a)
return h1 + h2 return h1 + h2
def test_cross_entropy(): def test_cross_entropy():
@ -157,10 +157,10 @@ class NormalBasics(nn.Cell):
self.n = msd.Normal(3.0, 4.0, dtype=dtype.float32) self.n = msd.Normal(3.0, 4.0, dtype=dtype.float32)
def construct(self): def construct(self):
mean = self.n('mean') mean = self.n.mean()
sd = self.n('sd') sd = self.n.sd()
mode = self.n('mode') mode = self.n.mode()
entropy = self.n('entropy') entropy = self.n.entropy()
return mean + sd + mode + entropy return mean + sd + mode + entropy
def test_bascis(): def test_bascis():
@ -170,3 +170,30 @@ def test_bascis():
net = NormalBasics() net = NormalBasics()
ans = net() ans = net()
assert isinstance(ans, Tensor) assert isinstance(ans, Tensor)
class NormalConstruct(nn.Cell):
"""
Normal distribution: going through construct.
"""
def __init__(self):
super(NormalConstruct, self).__init__()
self.normal = msd.Normal(3.0, 4.0)
self.normal1 = msd.Normal()
def construct(self, value, mean, sd):
prob = self.normal('prob', value)
prob1 = self.normal('prob', value, mean, sd)
prob2 = self.normal1('prob', value, mean, sd)
return prob + prob1 + prob2
def test_normal_construct():
"""
Test probability function going through construct.
"""
net = NormalConstruct()
value = Tensor([0.5, 1.0], dtype=dtype.float32)
mean = Tensor([0.0], dtype=dtype.float32)
sd = Tensor([1.0], dtype=dtype.float32)
ans = net(value, mean, sd)
assert isinstance(ans, Tensor)

@ -60,12 +60,12 @@ class UniformProb(nn.Cell):
self.u = msd.Uniform(3.0, 4.0, dtype=dtype.float32) self.u = msd.Uniform(3.0, 4.0, dtype=dtype.float32)
def construct(self, value): def construct(self, value):
prob = self.u('prob', value) prob = self.u.prob(value)
log_prob = self.u('log_prob', value) log_prob = self.u.log_prob(value)
cdf = self.u('cdf', value) cdf = self.u.cdf(value)
log_cdf = self.u('log_cdf', value) log_cdf = self.u.log_cdf(value)
sf = self.u('survival_function', value) sf = self.u.survival_function(value)
log_sf = self.u('log_survival', value) log_sf = self.u.log_survival(value)
return prob + log_prob + cdf + log_cdf + sf + log_sf return prob + log_prob + cdf + log_cdf + sf + log_sf
def test_uniform_prob(): def test_uniform_prob():
@ -86,12 +86,12 @@ class UniformProb1(nn.Cell):
self.u = msd.Uniform(dtype=dtype.float32) self.u = msd.Uniform(dtype=dtype.float32)
def construct(self, value, low, high): def construct(self, value, low, high):
prob = self.u('prob', value, low, high) prob = self.u.prob(value, low, high)
log_prob = self.u('log_prob', value, low, high) log_prob = self.u.log_prob(value, low, high)
cdf = self.u('cdf', value, low, high) cdf = self.u.cdf(value, low, high)
log_cdf = self.u('log_cdf', value, low, high) log_cdf = self.u.log_cdf(value, low, high)
sf = self.u('survival_function', value, low, high) sf = self.u.survival_function(value, low, high)
log_sf = self.u('log_survival', value, low, high) log_sf = self.u.log_survival(value, low, high)
return prob + log_prob + cdf + log_cdf + sf + log_sf return prob + log_prob + cdf + log_cdf + sf + log_sf
def test_uniform_prob1(): def test_uniform_prob1():
@ -115,8 +115,8 @@ class UniformKl(nn.Cell):
self.u2 = msd.Uniform(dtype=dtype.float32) self.u2 = msd.Uniform(dtype=dtype.float32)
def construct(self, low_b, high_b, low_a, high_a): def construct(self, low_b, high_b, low_a, high_a):
kl1 = self.u1('kl_loss', 'Uniform', low_b, high_b) kl1 = self.u1.kl_loss('Uniform', low_b, high_b)
kl2 = self.u2('kl_loss', 'Uniform', low_b, high_b, low_a, high_a) kl2 = self.u2.kl_loss('Uniform', low_b, high_b, low_a, high_a)
return kl1 + kl2 return kl1 + kl2
def test_kl(): def test_kl():
@ -141,8 +141,8 @@ class UniformCrossEntropy(nn.Cell):
self.u2 = msd.Uniform(dtype=dtype.float32) self.u2 = msd.Uniform(dtype=dtype.float32)
def construct(self, low_b, high_b, low_a, high_a): def construct(self, low_b, high_b, low_a, high_a):
h1 = self.u1('cross_entropy', 'Uniform', low_b, high_b) h1 = self.u1.cross_entropy('Uniform', low_b, high_b)
h2 = self.u2('cross_entropy', 'Uniform', low_b, high_b, low_a, high_a) h2 = self.u2.cross_entropy('Uniform', low_b, high_b, low_a, high_a)
return h1 + h2 return h1 + h2
def test_cross_entropy(): def test_cross_entropy():
@ -166,10 +166,10 @@ class UniformBasics(nn.Cell):
self.u = msd.Uniform(3.0, 4.0, dtype=dtype.float32) self.u = msd.Uniform(3.0, 4.0, dtype=dtype.float32)
def construct(self): def construct(self):
mean = self.u('mean') mean = self.u.mean()
sd = self.u('sd') sd = self.u.sd()
var = self.u('var') var = self.u.var()
entropy = self.u('entropy') entropy = self.u.entropy()
return mean + sd + var + entropy return mean + sd + var + entropy
def test_bascis(): def test_bascis():
@ -179,3 +179,30 @@ def test_bascis():
net = UniformBasics() net = UniformBasics()
ans = net() ans = net()
assert isinstance(ans, Tensor) assert isinstance(ans, Tensor)
class UniConstruct(nn.Cell):
"""
Unifrom distribution: going through construct.
"""
def __init__(self):
super(UniConstruct, self).__init__()
self.u = msd.Uniform(-4.0, 4.0)
self.u1 = msd.Uniform()
def construct(self, value, low, high):
prob = self.u('prob', value)
prob1 = self.u('prob', value, low, high)
prob2 = self.u1('prob', value, low, high)
return prob + prob1 + prob2
def test_uniform_construct():
"""
Test probability function going through construct.
"""
net = UniConstruct()
value = Tensor([-5.0, 0.0, 1.0, 5.0], dtype=dtype.float32)
low = Tensor([-1.0], dtype=dtype.float32)
high = Tensor([1.0], dtype=dtype.float32)
ans = net(value, low, high)
assert isinstance(ans, Tensor)

Loading…
Cancel
Save