!4874 Change the interfaces in transformation base class

Merge pull request !4874 from peixu_ren/custom_bijector
pull/4874/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit 5ab41b1c26

@ -275,6 +275,10 @@ def check_type(data_type, value_type, name):
def raise_none_error(name): def raise_none_error(name):
raise ValueError(f"{name} should be specified. Value cannot be None") raise ValueError(f"{name} should be specified. Value cannot be None")
@constexpr
def raise_not_impl_error(name):
raise ValueError(f"{name} function should be implemented for non-linear transformation")
@constexpr @constexpr
def check_distribution_name(name, expected_name): def check_distribution_name(name, expected_name):
if name is None: if name is None:

@ -18,7 +18,7 @@ from mindspore._checkparam import Validator as validator
from mindspore.common import dtype as mstype from mindspore.common import dtype as mstype
import mindspore.nn as nn import mindspore.nn as nn
from .distribution import Distribution from .distribution import Distribution
from ._utils.utils import check_type from ._utils.utils import check_type, raise_not_impl_error
class TransformedDistribution(Distribution): class TransformedDistribution(Distribution):
""" """
@ -56,6 +56,7 @@ class TransformedDistribution(Distribution):
self._distribution = distribution self._distribution = distribution
self._is_linear_transformation = bijector.is_constant_jacobian self._is_linear_transformation = bijector.is_constant_jacobian
self.exp = P.Exp() self.exp = P.Exp()
self.log = P.Log()
@property @property
def bijector(self): def bijector(self):
@ -69,37 +70,49 @@ class TransformedDistribution(Distribution):
def is_linear_transformation(self): def is_linear_transformation(self):
return self._is_linear_transformation return self._is_linear_transformation
def _cdf(self, value): def _cdf(self, *args, **kwargs):
r""" r"""
.. math:: .. math::
Y = g(X) Y = g(X)
P(Y <= a) = P(X <= g^{-1}(a)) P(Y <= a) = P(X <= g^{-1}(a))
""" """
inverse_value = self.bijector.inverse(value) inverse_value = self.bijector("inverse", *args, **kwargs)
return self.distribution.cdf(inverse_value) return self.distribution("cdf", inverse_value)
def _log_prob(self, value): def _log_cdf(self, *args, **kwargs):
return self.log(self._cdf(*args, **kwargs))
def _survival_function(self, *args, **kwargs):
return 1.0 - self._cdf(*args, **kwargs)
def _log_survival(self, *args, **kwargs):
return self.log(self._survival_function(*args, **kwargs))
def _log_prob(self, *args, **kwargs):
r""" r"""
.. math:: .. math::
Y = g(X) Y = g(X)
Py(a) = Px(g^{-1}(a)) * (g^{-1})'(a) Py(a) = Px(g^{-1}(a)) * (g^{-1})'(a)
\log(Py(a)) = \log(Px(g^{-1}(a))) + \log((g^{-1})'(a)) \log(Py(a)) = \log(Px(g^{-1}(a))) + \log((g^{-1})'(a))
""" """
inverse_value = self.bijector.inverse(value) inverse_value = self.bijector("inverse", *args, **kwargs)
unadjust_prob = self.distribution.log_prob(inverse_value) unadjust_prob = self.distribution("log_prob", inverse_value)
log_jacobian = self.bijector.inverse_log_jacobian(value) log_jacobian = self.bijector("inverse_log_jacobian", *args, **kwargs)
return unadjust_prob + log_jacobian return unadjust_prob + log_jacobian
def _prob(self, value): def _prob(self, *args, **kwargs):
return self.exp(self._log_prob(value)) return self.exp(self._log_prob(*args, **kwargs))
def _sample(self, shape): def _sample(self, *args, **kwargs):
org_sample = self.distribution.sample(shape) org_sample = self.distribution("sample", shape)
return self.bijector.forward(org_sample) return self.bijector("forward", org_sample)
def _mean(self): def _mean(self, *args, **kwargs):
""" """
Note: Note:
This function maybe overridden by derived class. This function maybe overridden by derived class.
""" """
return self.bijector.forward(self.distribution.mean()) if not self.is_linear_transformation:
raise_not_impl_error(mean)
return self.bijector("forward", self.distribution("mean"))

Loading…
Cancel
Save