|
|
|
@ -15,6 +15,7 @@
|
|
|
|
|
"""basic"""
|
|
|
|
|
from mindspore.nn.cell import Cell
|
|
|
|
|
from mindspore._checkparam import Validator as validator
|
|
|
|
|
from mindspore._checkparam import Rel
|
|
|
|
|
from ._utils.utils import calc_broadcast_shape_from_param, check_scalar_from_param
|
|
|
|
|
|
|
|
|
|
class Distribution(Cell):
|
|
|
|
@ -28,12 +29,15 @@ class Distribution(Cell):
|
|
|
|
|
|
|
|
|
|
Note:
|
|
|
|
|
Derived class should override operations such as ,_mean, _prob,
|
|
|
|
|
and _log_prob. Arguments should be passed in through *args or **kwargs.
|
|
|
|
|
and _log_prob. Required arguments, such as value for _prob,
|
|
|
|
|
should be passed in through args or kwargs. dist_spec_args which specify
|
|
|
|
|
a new distribution are optional.
|
|
|
|
|
|
|
|
|
|
Dist_spec_args are unique for each type of distribution. For example, mean and sd
|
|
|
|
|
are the dist_spec_args for a Normal distribution.
|
|
|
|
|
dist_spec_args are unique for each type of distribution. For example, mean and sd
|
|
|
|
|
are the dist_spec_args for a Normal distribution, while rate is the dist_spec_args
|
|
|
|
|
for exponential distribution.
|
|
|
|
|
|
|
|
|
|
For all functions, passing in dist_spec_args, are optional.
|
|
|
|
|
For all functions, passing in dist_spec_args, is optional.
|
|
|
|
|
Passing in the additional dist_spec_args will make the result to be evaluated with
|
|
|
|
|
new distribution specified by the dist_spec_args. But it won't change the
|
|
|
|
|
original distribuion.
|
|
|
|
@ -49,7 +53,7 @@ class Distribution(Cell):
|
|
|
|
|
"""
|
|
|
|
|
super(Distribution, self).__init__()
|
|
|
|
|
validator.check_value_type('name', name, [str], 'distribution_name')
|
|
|
|
|
validator.check_value_type('seed', seed, [int], name)
|
|
|
|
|
validator.check_integer('seed', seed, 0, Rel.GE, name)
|
|
|
|
|
|
|
|
|
|
self._name = name
|
|
|
|
|
self._seed = seed
|
|
|
|
@ -191,7 +195,7 @@ class Distribution(Cell):
|
|
|
|
|
|
|
|
|
|
Note:
|
|
|
|
|
Args must include value.
|
|
|
|
|
Dist_spec_args are optional.
|
|
|
|
|
dist_spec_args are optional.
|
|
|
|
|
"""
|
|
|
|
|
return self._call_log_prob(*args, **kwargs)
|
|
|
|
|
|
|
|
|
@ -210,7 +214,7 @@ class Distribution(Cell):
|
|
|
|
|
|
|
|
|
|
Note:
|
|
|
|
|
Args must include value.
|
|
|
|
|
Dist_spec_args are optional.
|
|
|
|
|
dist_spec_args are optional.
|
|
|
|
|
"""
|
|
|
|
|
return self._call_prob(*args, **kwargs)
|
|
|
|
|
|
|
|
|
@ -229,7 +233,7 @@ class Distribution(Cell):
|
|
|
|
|
|
|
|
|
|
Note:
|
|
|
|
|
Args must include value.
|
|
|
|
|
Dist_spec_args are optional.
|
|
|
|
|
dist_spec_args are optional.
|
|
|
|
|
"""
|
|
|
|
|
return self._call_cdf(*args, **kwargs)
|
|
|
|
|
|
|
|
|
@ -266,7 +270,7 @@ class Distribution(Cell):
|
|
|
|
|
|
|
|
|
|
Note:
|
|
|
|
|
Args must include value.
|
|
|
|
|
Dist_spec_args are optional.
|
|
|
|
|
dist_spec_args are optional.
|
|
|
|
|
"""
|
|
|
|
|
return self._call_log_cdf(*args, **kwargs)
|
|
|
|
|
|
|
|
|
@ -285,7 +289,7 @@ class Distribution(Cell):
|
|
|
|
|
|
|
|
|
|
Note:
|
|
|
|
|
Args must include value.
|
|
|
|
|
Dist_spec_args are optional.
|
|
|
|
|
dist_spec_args are optional.
|
|
|
|
|
"""
|
|
|
|
|
return self._call_survival(*args, **kwargs)
|
|
|
|
|
|
|
|
|
@ -313,7 +317,7 @@ class Distribution(Cell):
|
|
|
|
|
|
|
|
|
|
Note:
|
|
|
|
|
Args must include value.
|
|
|
|
|
Dist_spec_args are optional.
|
|
|
|
|
dist_spec_args are optional.
|
|
|
|
|
"""
|
|
|
|
|
return self._call_log_survival(*args, **kwargs)
|
|
|
|
|
|
|
|
|
@ -341,7 +345,7 @@ class Distribution(Cell):
|
|
|
|
|
Evaluate the mean.
|
|
|
|
|
|
|
|
|
|
Note:
|
|
|
|
|
Dist_spec_args are optional.
|
|
|
|
|
dist_spec_args are optional.
|
|
|
|
|
"""
|
|
|
|
|
return self._mean(*args, **kwargs)
|
|
|
|
|
|
|
|
|
@ -350,7 +354,7 @@ class Distribution(Cell):
|
|
|
|
|
Evaluate the mode.
|
|
|
|
|
|
|
|
|
|
Note:
|
|
|
|
|
Dist_spec_args are optional.
|
|
|
|
|
dist_spec_args are optional.
|
|
|
|
|
"""
|
|
|
|
|
return self._mode(*args, **kwargs)
|
|
|
|
|
|
|
|
|
@ -359,7 +363,7 @@ class Distribution(Cell):
|
|
|
|
|
Evaluate the standard deviation.
|
|
|
|
|
|
|
|
|
|
Note:
|
|
|
|
|
Dist_spec_args are optional.
|
|
|
|
|
dist_spec_args are optional.
|
|
|
|
|
"""
|
|
|
|
|
return self._call_sd(*args, **kwargs)
|
|
|
|
|
|
|
|
|
@ -368,7 +372,7 @@ class Distribution(Cell):
|
|
|
|
|
Evaluate the variance.
|
|
|
|
|
|
|
|
|
|
Note:
|
|
|
|
|
Dist_spec_args are optional.
|
|
|
|
|
dist_spec_args are optional.
|
|
|
|
|
"""
|
|
|
|
|
return self._call_var(*args, **kwargs)
|
|
|
|
|
|
|
|
|
@ -395,7 +399,7 @@ class Distribution(Cell):
|
|
|
|
|
Evaluate the entropy.
|
|
|
|
|
|
|
|
|
|
Note:
|
|
|
|
|
Dist_spec_args are optional.
|
|
|
|
|
dist_spec_args are optional.
|
|
|
|
|
"""
|
|
|
|
|
return self._entropy(*args, **kwargs)
|
|
|
|
|
|
|
|
|
@ -424,7 +428,7 @@ class Distribution(Cell):
|
|
|
|
|
|
|
|
|
|
Note:
|
|
|
|
|
Shape of the sample is default to ().
|
|
|
|
|
Dist_spec_args are optional.
|
|
|
|
|
dist_spec_args are optional.
|
|
|
|
|
"""
|
|
|
|
|
return self._sample(*args, **kwargs)
|
|
|
|
|
|
|
|
|
|