|
|
|
@ -105,7 +105,7 @@ class TransformedDistribution(Distribution):
|
|
|
|
|
return self.exp(self._log_prob(*args, **kwargs))
|
|
|
|
|
|
|
|
|
|
def _sample(self, *args, **kwargs):
|
|
|
|
|
org_sample = self.distribution("sample", shape)
|
|
|
|
|
org_sample = self.distribution("sample", *args, **kwargs)
|
|
|
|
|
return self.bijector("forward", org_sample)
|
|
|
|
|
|
|
|
|
|
def _mean(self, *args, **kwargs):
|
|
|
|
@ -114,6 +114,6 @@ class TransformedDistribution(Distribution):
|
|
|
|
|
This function maybe overridden by derived class.
|
|
|
|
|
"""
|
|
|
|
|
if not self.is_linear_transformation:
|
|
|
|
|
raise_not_impl_error(mean)
|
|
|
|
|
raise_not_impl_error("mean")
|
|
|
|
|
|
|
|
|
|
return self.bijector("forward", self.distribution("mean"))
|
|
|
|
|