|
|
|
@ -14,10 +14,9 @@
|
|
|
|
|
# ============================================================================
|
|
|
|
|
"""Transformed Distribution"""
|
|
|
|
|
from mindspore._checkparam import Validator as validator
|
|
|
|
|
from mindspore.common import dtype as mstype
|
|
|
|
|
import mindspore.nn as nn
|
|
|
|
|
from .distribution import Distribution
|
|
|
|
|
from ._utils.utils import check_type, raise_not_impl_error
|
|
|
|
|
from ._utils.utils import raise_not_impl_error
|
|
|
|
|
from ._utils.custom_ops import exp_generic, log_generic
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -30,7 +29,6 @@ class TransformedDistribution(Distribution):
|
|
|
|
|
Args:
|
|
|
|
|
bijector (Bijector): The transformation to perform.
|
|
|
|
|
distribution (Distribution): The original distribution.
|
|
|
|
|
dtype (mindspore.dtype): The type of the event samples.
|
|
|
|
|
seed (int): The seed is used in sampling. The global seed is used if it is None.
|
|
|
|
|
name (str): The name of the transformed distribution. Default: 'transformed_distribution'.
|
|
|
|
|
|
|
|
|
@ -45,16 +43,14 @@ class TransformedDistribution(Distribution):
|
|
|
|
|
>>> import mindspore.nn.probability.distribution as msd
|
|
|
|
|
>>> import mindspore.nn.probability.bijector as msb
|
|
|
|
|
>>> ln = msd.TransformedDistribution(msb.Exp(),
|
|
|
|
|
>>> msd.Normal(0.0, 1.0, dtype=mstype.float32),
|
|
|
|
|
>>> dtype=mstype.float32)
|
|
|
|
|
>>> msd.Normal(0.0, 1.0, dtype=mstype.float32))
|
|
|
|
|
>>>
|
|
|
|
|
>>> # To use a transformed distribution in a network.
|
|
|
|
|
>>> class net(Cell):
|
|
|
|
|
>>> def __init__(self):
|
|
|
|
|
>>> super(net, self).__init__():
|
|
|
|
|
>>> self.ln = msd.TransformedDistribution(msb.Exp(),
|
|
|
|
|
>>> msd.Normal(0.0, 1.0, dtype=mstype.float32),
|
|
|
|
|
>>> dtype=mstype.float32)
|
|
|
|
|
>>> msd.Normal(0.0, 1.0, dtype=mstype.float32))
|
|
|
|
|
>>>
|
|
|
|
|
>>> def construct(self, value):
|
|
|
|
|
>>> # Similar calls can be made to other functions
|
|
|
|
@ -65,7 +61,6 @@ class TransformedDistribution(Distribution):
|
|
|
|
|
def __init__(self,
|
|
|
|
|
bijector,
|
|
|
|
|
distribution,
|
|
|
|
|
dtype,
|
|
|
|
|
seed=None,
|
|
|
|
|
name="transformed_distribution"):
|
|
|
|
|
"""
|
|
|
|
@ -76,9 +71,7 @@ class TransformedDistribution(Distribution):
|
|
|
|
|
[nn.probability.bijector.Bijector], type(self).__name__)
|
|
|
|
|
validator.check_value_type('distribution', distribution,
|
|
|
|
|
[Distribution], type(self).__name__)
|
|
|
|
|
valid_dtype = mstype.number_type
|
|
|
|
|
check_type(dtype, valid_dtype, type(self).__name__)
|
|
|
|
|
super(TransformedDistribution, self).__init__(seed, dtype, name, param)
|
|
|
|
|
super(TransformedDistribution, self).__init__(seed, distribution.dtype, name, param)
|
|
|
|
|
|
|
|
|
|
self._bijector = bijector
|
|
|
|
|
self._distribution = distribution
|
|
|
|
@ -96,6 +89,10 @@ class TransformedDistribution(Distribution):
|
|
|
|
|
def distribution(self):
|
|
|
|
|
return self._distribution
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def dtype(self):
|
|
|
|
|
return self.distribution.dtype
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def is_linear_transformation(self):
|
|
|
|
|
return self._is_linear_transformation
|
|
|
|
|