|
|
|
@ -13,7 +13,9 @@
|
|
|
|
|
# limitations under the License.
|
|
|
|
|
# ============================================================================
|
|
|
|
|
"""Transformed Distribution"""
|
|
|
|
|
import numpy as np
|
|
|
|
|
from mindspore._checkparam import Validator as validator
|
|
|
|
|
from mindspore.ops import operations as P
|
|
|
|
|
import mindspore.nn as nn
|
|
|
|
|
from .distribution import Distribution
|
|
|
|
|
from ._utils.utils import raise_not_impl_error
|
|
|
|
@ -80,6 +82,8 @@ class TransformedDistribution(Distribution):
|
|
|
|
|
self.parameter_names = distribution.parameter_names
|
|
|
|
|
self.exp = exp_generic
|
|
|
|
|
self.log = log_generic
|
|
|
|
|
self.equal_base = P.Equal()
|
|
|
|
|
self.select_base = P.Select()
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def bijector(self):
|
|
|
|
@ -125,7 +129,8 @@ class TransformedDistribution(Distribution):
|
|
|
|
|
inverse_value = self.bijector("inverse", value)
|
|
|
|
|
unadjust_prob = self.distribution("log_prob", inverse_value, *args, **kwargs)
|
|
|
|
|
log_jacobian = self.bijector("inverse_log_jacobian", value)
|
|
|
|
|
return unadjust_prob + log_jacobian
|
|
|
|
|
isneginf = self.equal_base(unadjust_prob, -np.inf)
|
|
|
|
|
return self.select_base(isneginf, unadjust_prob, unadjust_prob + log_jacobian)
|
|
|
|
|
|
|
|
|
|
def _prob(self, value, *args, **kwargs):
|
|
|
|
|
return self.exp(self._log_prob(value, *args, **kwargs))
|
|
|
|
|