@@ -14,13 +14,14 @@
 # ============================================================================
 """basic"""
 from mindspore import context
 from mindspore.ops import operations as P
 from mindspore.nn.cell import Cell
 from mindspore._checkparam import Validator as validator
 from mindspore._checkparam import Rel
 from mindspore.common import get_seed
-from ._utils.utils import calc_broadcast_shape_from_param, check_scalar_from_param, cast_type_for_device,\
-    raise_none_error
+from ._utils.utils import raise_none_error, cast_to_tensor, set_param_type, cast_type_for_device
 from ._utils.utils import CheckTuple, CheckTensor
 from ._utils.custom_ops import broadcast_to, exp_generic, log_generic


 class Distribution(Cell):
@@ -68,14 +69,16 @@ class Distribution(Cell):
         self._seed = seed
         self._dtype = cast_type_for_device(dtype)
         self._parameters = {}

         # parsing parameters
         for k in param.keys():
             if not(k == 'self' or k.startswith('_')):
                 self._parameters[k] = param[k]

         # some attributes
-        self._broadcast_shape = calc_broadcast_shape_from_param(
-            self.parameters)
-        self._is_scalar_batch = check_scalar_from_param(self.parameters)
+        self.parameter_type = set_param_type(self.parameters['param_dict'], dtype)
+        self._broadcast_shape = self._calc_broadcast_shape()
+        self._is_scalar_batch = self._check_is_scalar_batch()

         # set the function to call according to the derived class's attributes
         self._set_prob()
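
A minimal pure-Python sketch (no MindSpore imports) of what the parsing loop above does: a subclass typically hands over a snapshot of its own `__init__` locals as `param`, and every entry except `self` and underscore-prefixed keys ends up in `self._parameters`. The `Toy` class and its parameters below are hypothetical.

    class Toy:
        def __init__(self, loc=0.0, scale=1.0, name='Toy'):
            param = dict(locals())   # {'self': ..., 'loc': ..., 'scale': ..., 'name': ...}
            self._parameters = {}
            # same filtering rule as in the constructor above
            for k in param.keys():
                if not (k == 'self' or k.startswith('_')):
                    self._parameters[k] = param[k]

    print(Toy(loc=2.0)._parameters)   # {'loc': 2.0, 'scale': 1.0, 'name': 'Toy'}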
@@ -91,6 +94,18 @@ class Distribution(Cell):
         self.context_mode = context.get_context('mode')
         self.checktuple = CheckTuple()
         self.checktensor = CheckTensor()
+        self.broadcast = broadcast_to

+        # ops needed for the base class
+        self.cast_base = P.Cast()
+        self.dtype_base = P.DType()
+        self.exp_base = exp_generic
+        self.fill_base = P.Fill()
+        self.log_base = log_generic
+        self.sametypeshape_base = P.SameTypeShape()
+        self.sq_base = P.Square()
+        self.sqrt_base = P.Sqrt()
+        self.shape_base = P.Shape()
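
These `*_base` ops (together with `broadcast_to`, `exp_generic` and `log_generic` from the imports) are the primitives the base class leans on for the fallback formulas further down, e.g. `prob = exp_base(log_prob)` and `sd = sqrt_base(var)`; the `_base` suffix presumably keeps them distinct from ops a subclass sets up for itself. A rough stand-in sketch of that layering, with `math` replacing the MindSpore ops:

    import math

    class BaseSketch:
        def __init__(self):
            # base-class copies of the primitives, kept under distinct names
            self.exp_base = math.exp
            self.log_base = math.log
            self.sqrt_base = math.sqrt

        def _calc_prob_from_log_prob(self, log_p):
            return self.exp_base(log_p)

        def _calc_sd_from_var(self, var):
            return self.sqrt_base(var)

    b = BaseSketch()
    print(b._calc_prob_from_log_prob(math.log(0.25)))  # ~0.25
    print(b._calc_sd_from_var(4.0))                    # 2.0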

     @property
     def name(self):
@@ -116,6 +131,21 @@ class Distribution(Cell):
     def broadcast_shape(self):
         return self._broadcast_shape

+    def _add_parameter(self, value, name):
+        """
+        Cast `value` to a tensor and add it to `self.default_parameters`.
+        Add `name` to `self.parameter_names`.
+        """
+        # initialize the attributes if they do not exist yet
+        if not hasattr(self, 'default_parameters'):
+            self.default_parameters = []
+            self.parameter_names = []
+        # cast value to a tensor if it is not None
+        value_t = None if value is None else cast_to_tensor(value, self.parameter_type)
+        self.default_parameters += [value_t,]
+        self.parameter_names += [name,]
+        return value_t
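
A sketch of how a hypothetical subclass would use the new `_add_parameter` helper: each default parameter is converted once and remembered in two parallel lists, so `_check_param_type` can later fall back to it when a caller omits that argument. Names such as `loc`/`scale` are illustrative, and the pure-Python mock only mirrors the bookkeeping (with `float()` standing in for `cast_to_tensor`).

    class AddParamSketch:
        def _add_parameter(self, value, name):
            # same bookkeeping as above, without the tensor cast
            if not hasattr(self, 'default_parameters'):
                self.default_parameters = []
                self.parameter_names = []
            value_t = None if value is None else float(value)
            self.default_parameters += [value_t]
            self.parameter_names += [name]
            return value_t

    d = AddParamSketch()
    d._add_parameter(0.0, 'loc')     # hypothetical subclass registering its default
    d._add_parameter(None, 'scale')  # scale left to be supplied at call time
    print(d.parameter_names, d.default_parameters)   # ['loc', 'scale'] [0.0, None]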

     def _check_param_type(self, *args):
         """
         Check the availability and validity of default parameters and `dist_spec_args`.
@@ -123,6 +153,7 @@ class Distribution(Cell):
         are None, the parameters must be passed in through `args`.
         """
         broadcast_shape = None
+        broadcast_shape_tensor = None
         common_dtype = None
         out = []
@@ -139,17 +170,17 @@ class Distribution(Cell):
             # broadcast if the number of args > 1
             if broadcast_shape is None:
-                broadcast_shape = self.shape(arg)
-                common_dtype = self.dtypeop(arg)
+                broadcast_shape = self.shape_base(arg)
+                common_dtype = self.dtype_base(arg)
+                broadcast_shape_tensor = self.fill_base(common_dtype, broadcast_shape, 1.0)
             else:
-                ones = self.fill(self.dtypeop(arg), broadcast_shape, 1.0)
-                broadcast_shape = self.shape(arg + ones)
+                broadcast_shape = self.shape_base(arg + broadcast_shape_tensor)
+                broadcast_shape_tensor = self.fill_base(common_dtype, broadcast_shape, 1.0)
+                arg = self.broadcast(arg, broadcast_shape_tensor)
+                # check if the arguments have the same dtype
-            arg = arg * self.fill(self.dtypeop(arg), broadcast_shape, 1.0)
-            dtype_tensor = self.fill(common_dtype, broadcast_shape, 1.0)
-            self.sametypeshape(arg, dtype_tensor)
-            arg = self.cast(arg, self.parameter_type)
+                self.sametypeshape_base(arg, broadcast_shape_tensor)
+            arg = self.cast_base(arg, self.parameter_type)
             out.append(arg)

         if len(out) == 1:
@@ -158,7 +189,7 @@ class Distribution(Cell):
         # broadcast all args to broadcast_shape
         result = ()
         for arg in out:
-            arg = arg * self.fill(self.dtypeop(arg), broadcast_shape, 1.0)
+            arg = self.broadcast(arg, broadcast_shape_tensor)
             result = result + (arg,)
         return result
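
The loop above computes one common shape for all arguments by repeatedly adding a ones tensor of the running broadcast shape, then stretches every argument to that shape. A NumPy sketch of the same idea (NumPy is used here only as a stand-in for the MindSpore ops):

    import numpy as np

    def common_broadcast(*args):
        broadcast_ones = None
        out = []
        for arg in args:
            arg = np.asarray(arg, dtype=np.float32)
            if broadcast_ones is None:
                broadcast_ones = np.ones(arg.shape, dtype=arg.dtype)            # fill(dtype, shape, 1.0)
            else:
                broadcast_ones = np.ones((arg + broadcast_ones).shape, dtype=arg.dtype)
            out.append(arg)
        # broadcast every argument to the final common shape
        return tuple(np.broadcast_to(a, broadcast_ones.shape) for a in out)

    a, b = common_broadcast(np.zeros((2, 1)), np.arange(3))
    print(a.shape, b.shape)   # (2, 3) (2, 3)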
@@ -171,6 +202,38 @@ class Distribution(Cell):
             return value
         return self.checktensor(value, name)

+    def _check_is_scalar_batch(self):
+        """
+        Check if the parameters used during initialization are scalars.
+        """
+        if hasattr(self, 'distribution'):
+            return self._distribution.is_scalar_batch
+        param_dict = self.parameters['param_dict']
+        for value in param_dict.values():
+            if value is None:
+                continue
+            if not isinstance(value, (int, float)):
+                return False
+        return True
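
In other words, the batch is treated as scalar only when every initialization parameter was given as a Python `int` or `float` (unset parameters are skipped). A small stand-alone mirror of that rule; the parameter names are made up:

    def is_scalar_batch(param_dict):
        for value in param_dict.values():
            if value is None:
                continue
            if not isinstance(value, (int, float)):
                return False
        return True

    print(is_scalar_batch({'loc': 0.0, 'scale': 1.0}))        # True
    print(is_scalar_batch({'loc': [0.0, 1.0], 'scale': 1.0})) # False (a vector parameter)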

+    def _calc_broadcast_shape(self):
+        """
+        Calculate the broadcast shape of the parameters used during initialization.
+        """
+        if hasattr(self, 'distribution'):
+            return self._distribution.broadcast_shape
+        param_dict = self.parameters['param_dict']
+        broadcast_shape_tensor = None
+        for value in param_dict.values():
+            if value is None:
+                return None
+            if broadcast_shape_tensor is None:
+                broadcast_shape_tensor = cast_to_tensor(value)
+            else:
+                value = cast_to_tensor(value)
+                broadcast_shape_tensor = (value + broadcast_shape_tensor)
+        return broadcast_shape_tensor.shape
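
The broadcast shape is obtained the same way as in `_check_param_type`: the parameters are cast to tensors and added together, and the shape of the sum is the common batch shape (any `None` parameter makes the shape undecidable until call time). A NumPy illustration of the shape arithmetic, with hypothetical parameter values:

    import numpy as np

    loc = np.zeros((2, 1), dtype=np.float32)   # hypothetical 'loc' parameter
    scale = np.ones((3,), dtype=np.float32)    # hypothetical 'scale' parameter
    print((loc + scale).shape)                 # (2, 3) -- the broadcast batch shape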

     def _set_prob(self):
         """
         Set the probability function based on the availability of `_prob` and `_log_likelihood`.
@@ -280,7 +343,7 @@ class Distribution(Cell):
         .. math::
             probability(x) = \exp(log_likelihood(x))
         """
-        return self.exp(self._log_prob(value, *args, **kwargs))
+        return self.exp_base(self._log_prob(value, *args, **kwargs))
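
This fallback is what lets a subclass implement only `_log_prob` and still expose `prob`: the public entry point simply exponentiates the log-density. A toy dispatch sketch in plain Python (the density below is made up and unnormalized):

    import math

    class LogProbOnly:
        def _log_prob(self, x):
            return -abs(x)                     # toy log-density

        def prob(self, x):
            # mirrors _calc_prob_from_log_prob: probability(x) = exp(log_prob(x))
            return math.exp(self._log_prob(x))

    print(LogProbOnly().prob(0.5))             # exp(-0.5) ≈ 0.6065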

     def prob(self, value, *args, **kwargs):
         """
@@ -304,7 +367,7 @@ class Distribution(Cell):
         .. math::
             log_prob(x) = \log(prob(x))
         """
-        return self.log(self._prob(value, *args, **kwargs))
+        return self.log_base(self._prob(value, *args, **kwargs))

     def cdf(self, value, *args, **kwargs):
         """
@@ -328,7 +391,7 @@ class Distribution(Cell):
         .. math::
             cdf(x) = \exp(log_cdf(x))
         """
-        return self.exp(self._log_cdf(value, *args, **kwargs))
+        return self.exp_base(self._log_cdf(value, *args, **kwargs))

     def _calc_cdf_from_survival(self, value, *args, **kwargs):
         r"""
@@ -346,7 +409,7 @@ class Distribution(Cell):
         .. math::
             cdf(x) = 1 - \exp(log_survival(x))
         """
-        return 1.0 - self.exp(self._log_survival(value, *args, **kwargs))
+        return 1.0 - self.exp_base(self._log_survival(value, *args, **kwargs))
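
CDF and survival function are complements, so either can be recovered from the (log of the) other. A quick numeric check using the standard exponential distribution, where log_survival(x) = -x:

    import math

    x = 1.5                                   # exponential(1): S(x) = exp(-x)
    survival = math.exp(-x)                   # survival(x) = exp(log_survival(x))
    cdf = 1.0 - survival                      # cdf(x) = 1 - exp(log_survival(x))
    print(round(survival, 5), round(cdf, 5))  # 0.22313 0.77687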

     def log_cdf(self, value, *args, **kwargs):
         """
@@ -370,7 +433,7 @@ class Distribution(Cell):
         .. math::
             log_cdf(x) = \log(cdf(x))
         """
-        return self.log(self._call_cdf(value, *args, **kwargs))
+        return self.log_base(self._call_cdf(value, *args, **kwargs))

     def survival_function(self, value, *args, **kwargs):
         """
@@ -403,7 +466,7 @@ class Distribution(Cell):
         .. math::
             survival(x) = \exp(log_survival(x))
         """
-        return self.exp(self._log_survival(value, *args, **kwargs))
+        return self.exp_base(self._log_survival(value, *args, **kwargs))

     def log_survival(self, value, *args, **kwargs):
         """
@@ -427,7 +490,7 @@ class Distribution(Cell):
         .. math::
             log_survival(x) = \log(survival_function(x))
         """
-        return self.log(self._call_survival(value, *args, **kwargs))
+        return self.log_base(self._call_survival(value, *args, **kwargs))

     def kl_loss(self, dist, *args, **kwargs):
         """
@@ -507,7 +570,7 @@ class Distribution(Cell):
         .. math::
             STD(x) = \sqrt{VAR(x)}
         """
-        return self.sqrt(self._var(*args, **kwargs))
+        return self.sqrt_base(self._var(*args, **kwargs))

     def _calc_var_from_sd(self, *args, **kwargs):
         r"""
@@ -516,7 +579,7 @@ class Distribution(Cell):
         .. math::
             VAR(x) = STD(x) ^ 2
         """
-        return self.sq(self._sd(*args, **kwargs))
+        return self.sq_base(self._sd(*args, **kwargs))
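
Standard deviation and variance are each other's inverse transforms here, so whichever one a subclass implements, the other follows. A one-line numeric check:

    import math

    var = 4.0
    sd = math.sqrt(var)   # STD(x) = sqrt(VAR(x))
    print(sd, sd ** 2)    # 2.0 4.0 -- VAR(x) = STD(x) ** 2 recovers the variance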

     def entropy(self, *args, **kwargs):
         """