|
|
|
@ -27,25 +27,24 @@ class Distribution(Cell):
|
|
|
|
|
Base class for all mathematical distributions.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
seed (int): random seed used in sampling. Global seed is used if it is None. Default: None.
|
|
|
|
|
dtype (mindspore.dtype): the type of the event samples. Default: subclass dtype.
|
|
|
|
|
name (str): Python str name prefixed to Ops created by this class. Default: subclass name.
|
|
|
|
|
param (dict): parameters used to initialize the distribution.
|
|
|
|
|
seed (int): The global seed is used in sampling. Global seed is used if it is None. Default: None.
|
|
|
|
|
dtype (mindspore.dtype): The type of the event samples. Default: subclass dtype.
|
|
|
|
|
name (str): Python string name prefixed to operations created by this class. Default: subclass name.
|
|
|
|
|
param (dict): The parameters used to initialize the distribution.
|
|
|
|
|
|
|
|
|
|
Note:
|
|
|
|
|
Derived class should override operations such as ,_mean, _prob,
|
|
|
|
|
and _log_prob. Required arguments, such as value for _prob,
|
|
|
|
|
should be passed in through args or kwargs. dist_spec_args which specify
|
|
|
|
|
Derived class should override operations such as `_mean`, `_prob`,
|
|
|
|
|
and `_log_prob`. Required arguments, such as value for `_prob`,
|
|
|
|
|
should be passed in through `args` or `kwargs`. dist_spec_args which specify
|
|
|
|
|
a new distribution are optional.
|
|
|
|
|
|
|
|
|
|
dist_spec_args are unique for each type of distribution. For example, mean and sd
|
|
|
|
|
are the dist_spec_args for a Normal distribution, while rate is the dist_spec_args
|
|
|
|
|
dist_spec_args are unique for each type of distribution. For example, `mean` and `sd`
|
|
|
|
|
are the dist_spec_args for a Normal distribution, while `rate` is the dist_spec_args
|
|
|
|
|
for exponential distribution.
|
|
|
|
|
|
|
|
|
|
For all functions, passing in dist_spec_args, is optional.
|
|
|
|
|
Passing in the additional dist_spec_args will make the result to be evaluated with
|
|
|
|
|
new distribution specified by the dist_spec_args. But it won't change the
|
|
|
|
|
original distribuion.
|
|
|
|
|
Passing in the additional dist_spec_args will evaluate the result to be evaluated with
|
|
|
|
|
new distribution specified by the dist_spec_args. But it will not change the original distribution.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def __init__(self,
|
|
|
|
@ -118,7 +117,7 @@ class Distribution(Cell):
|
|
|
|
|
|
|
|
|
|
def _check_value(self, value, name):
|
|
|
|
|
"""
|
|
|
|
|
Check availability fo value as a Tensor.
|
|
|
|
|
Check availability of `value` as a Tensor.
|
|
|
|
|
"""
|
|
|
|
|
if self.context_mode == 0:
|
|
|
|
|
self.checktensor(value, name)
|
|
|
|
@ -127,7 +126,7 @@ class Distribution(Cell):
|
|
|
|
|
|
|
|
|
|
def _set_prob(self):
|
|
|
|
|
"""
|
|
|
|
|
Set probability funtion based on the availability of _prob and _log_likehood.
|
|
|
|
|
Set probability funtion based on the availability of `_prob` and `_log_likehood`.
|
|
|
|
|
"""
|
|
|
|
|
if hasattr(self, '_prob'):
|
|
|
|
|
self._call_prob = self._prob
|
|
|
|
@ -136,7 +135,7 @@ class Distribution(Cell):
|
|
|
|
|
|
|
|
|
|
def _set_sd(self):
|
|
|
|
|
"""
|
|
|
|
|
Set standard deviation based on the availability of _sd and _var.
|
|
|
|
|
Set standard deviation based on the availability of `_sd` and `_var`.
|
|
|
|
|
"""
|
|
|
|
|
if hasattr(self, '_sd'):
|
|
|
|
|
self._call_sd = self._sd
|
|
|
|
@ -145,7 +144,7 @@ class Distribution(Cell):
|
|
|
|
|
|
|
|
|
|
def _set_var(self):
|
|
|
|
|
"""
|
|
|
|
|
Set variance based on the availability of _sd and _var.
|
|
|
|
|
Set variance based on the availability of `_sd` and `_var`.
|
|
|
|
|
"""
|
|
|
|
|
if hasattr(self, '_var'):
|
|
|
|
|
self._call_var = self._var
|
|
|
|
@ -154,7 +153,7 @@ class Distribution(Cell):
|
|
|
|
|
|
|
|
|
|
def _set_log_prob(self):
|
|
|
|
|
"""
|
|
|
|
|
Set log probability based on the availability of _prob and _log_prob.
|
|
|
|
|
Set log probability based on the availability of `_prob` and `_log_prob`.
|
|
|
|
|
"""
|
|
|
|
|
if hasattr(self, '_log_prob'):
|
|
|
|
|
self._call_log_prob = self._log_prob
|
|
|
|
@ -163,7 +162,8 @@ class Distribution(Cell):
|
|
|
|
|
|
|
|
|
|
def _set_cdf(self):
|
|
|
|
|
"""
|
|
|
|
|
Set cdf based on the availability of _cdf and _log_cdf and survival_functions.
|
|
|
|
|
Set cumulative distribution function (cdf) based on the availability of `_cdf` and `_log_cdf` and
|
|
|
|
|
`survival_functions`.
|
|
|
|
|
"""
|
|
|
|
|
if hasattr(self, '_cdf'):
|
|
|
|
|
self._call_cdf = self._cdf
|
|
|
|
@ -176,8 +176,8 @@ class Distribution(Cell):
|
|
|
|
|
|
|
|
|
|
def _set_survival(self):
|
|
|
|
|
"""
|
|
|
|
|
Set survival function based on the availability of _survival function and _log_survival
|
|
|
|
|
and _call_cdf.
|
|
|
|
|
Set survival function based on the availability of _survival function and `_log_survival`
|
|
|
|
|
and `_call_cdf`.
|
|
|
|
|
"""
|
|
|
|
|
if hasattr(self, '_survival_function'):
|
|
|
|
|
self._call_survival = self._survival_function
|
|
|
|
@ -188,7 +188,7 @@ class Distribution(Cell):
|
|
|
|
|
|
|
|
|
|
def _set_log_cdf(self):
|
|
|
|
|
"""
|
|
|
|
|
Set log cdf based on the availability of _log_cdf and _call_cdf.
|
|
|
|
|
Set log cdf based on the availability of `_log_cdf` and `_call_cdf`.
|
|
|
|
|
"""
|
|
|
|
|
if hasattr(self, '_log_cdf'):
|
|
|
|
|
self._call_log_cdf = self._log_cdf
|
|
|
|
@ -197,7 +197,7 @@ class Distribution(Cell):
|
|
|
|
|
|
|
|
|
|
def _set_log_survival(self):
|
|
|
|
|
"""
|
|
|
|
|
Set log survival based on the availability of _log_survival and _call_survival.
|
|
|
|
|
Set log survival based on the availability of `_log_survival` and `_call_survival`.
|
|
|
|
|
"""
|
|
|
|
|
if hasattr(self, '_log_survival'):
|
|
|
|
|
self._call_log_survival = self._log_survival
|
|
|
|
@ -206,7 +206,7 @@ class Distribution(Cell):
|
|
|
|
|
|
|
|
|
|
def _set_cross_entropy(self):
|
|
|
|
|
"""
|
|
|
|
|
Set log survival based on the availability of _cross_entropy.
|
|
|
|
|
Set log survival based on the availability of `_cross_entropy`.
|
|
|
|
|
"""
|
|
|
|
|
if hasattr(self, '_cross_entropy'):
|
|
|
|
|
self._call_cross_entropy = self._cross_entropy
|
|
|
|
@ -216,7 +216,7 @@ class Distribution(Cell):
|
|
|
|
|
Evaluate the log probability(pdf or pmf) at the given value.
|
|
|
|
|
|
|
|
|
|
Note:
|
|
|
|
|
Args must include value.
|
|
|
|
|
The argument `args` must include `value`.
|
|
|
|
|
dist_spec_args are optional.
|
|
|
|
|
"""
|
|
|
|
|
return self._call_log_prob(*args, **kwargs)
|
|
|
|
@ -235,7 +235,7 @@ class Distribution(Cell):
|
|
|
|
|
Evaluate the probability (pdf or pmf) at given value.
|
|
|
|
|
|
|
|
|
|
Note:
|
|
|
|
|
Args must include value.
|
|
|
|
|
The argument `args` must include `value`.
|
|
|
|
|
dist_spec_args are optional.
|
|
|
|
|
"""
|
|
|
|
|
return self._call_prob(*args, **kwargs)
|
|
|
|
@ -254,7 +254,7 @@ class Distribution(Cell):
|
|
|
|
|
Evaluate the cdf at given value.
|
|
|
|
|
|
|
|
|
|
Note:
|
|
|
|
|
Args must include value.
|
|
|
|
|
The argument `args` must include `value`.
|
|
|
|
|
dist_spec_args are optional.
|
|
|
|
|
"""
|
|
|
|
|
return self._call_cdf(*args, **kwargs)
|
|
|
|
@ -291,7 +291,7 @@ class Distribution(Cell):
|
|
|
|
|
Evaluate the log cdf at given value.
|
|
|
|
|
|
|
|
|
|
Note:
|
|
|
|
|
Args must include value.
|
|
|
|
|
The argument `args` must include `value`.
|
|
|
|
|
dist_spec_args are optional.
|
|
|
|
|
"""
|
|
|
|
|
return self._call_log_cdf(*args, **kwargs)
|
|
|
|
@ -310,7 +310,7 @@ class Distribution(Cell):
|
|
|
|
|
Evaluate the survival function at given value.
|
|
|
|
|
|
|
|
|
|
Note:
|
|
|
|
|
Args must include value.
|
|
|
|
|
The argument `args` must include `value`.
|
|
|
|
|
dist_spec_args are optional.
|
|
|
|
|
"""
|
|
|
|
|
return self._call_survival(*args, **kwargs)
|
|
|
|
@ -338,7 +338,7 @@ class Distribution(Cell):
|
|
|
|
|
Evaluate the log survival function at given value.
|
|
|
|
|
|
|
|
|
|
Note:
|
|
|
|
|
Args must include value.
|
|
|
|
|
The arguments `args` must include `value`.
|
|
|
|
|
dist_spec_args are optional.
|
|
|
|
|
"""
|
|
|
|
|
return self._call_log_survival(*args, **kwargs)
|
|
|
|
@ -357,7 +357,7 @@ class Distribution(Cell):
|
|
|
|
|
Evaluate the KL divergence, i.e. KL(a||b).
|
|
|
|
|
|
|
|
|
|
Note:
|
|
|
|
|
Args must include type of the distribution, parameters of distribution b.
|
|
|
|
|
The argument `args` must include the type of the distribution, parameters of distribution b.
|
|
|
|
|
Parameters for distribution a are optional.
|
|
|
|
|
"""
|
|
|
|
|
return self._kl_loss(*args, **kwargs)
|
|
|
|
@ -430,7 +430,7 @@ class Distribution(Cell):
|
|
|
|
|
Evaluate the cross_entropy between distribution a and b.
|
|
|
|
|
|
|
|
|
|
Note:
|
|
|
|
|
Args must include type of the distribution, parameters of distribution b.
|
|
|
|
|
The argument `args` must include the type of the distribution, parameters of distribution b.
|
|
|
|
|
Parameters for distribution a are optional.
|
|
|
|
|
"""
|
|
|
|
|
return self._call_cross_entropy(*args, **kwargs)
|
|
|
|
@ -456,17 +456,17 @@ class Distribution(Cell):
|
|
|
|
|
|
|
|
|
|
def construct(self, name, *args, **kwargs):
|
|
|
|
|
"""
|
|
|
|
|
Override construct in Cell.
|
|
|
|
|
Override `construct` in Cell.
|
|
|
|
|
|
|
|
|
|
Note:
|
|
|
|
|
Names of supported functions include:
|
|
|
|
|
'prob', 'log_prob', 'cdf', 'log_cdf', 'survival_function', 'log_survival'
|
|
|
|
|
'var', 'sd', 'entropy', 'kl_loss', 'cross_entropy', 'sample'.
|
|
|
|
|
'var', 'sd', 'entropy', 'kl_loss', 'cross_entropy', and 'sample'.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
name (str): name of the function.
|
|
|
|
|
*args (list): list of positional arguments needed for the function.
|
|
|
|
|
**kwargs (dictionary): dictionary of keyword arguments needed for the function.
|
|
|
|
|
name (str): The name of the function.
|
|
|
|
|
*args (list): A list of positional arguments that the function needs.
|
|
|
|
|
**kwargs (dictionary): A dictionary of keyword arguments that the function needs.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
if name == 'log_prob':
|
|
|
|
|