Fix errors in transformed_distribution

pull/4894/head
peixu_ren 5 years ago
parent 3449abd741
commit 98e2a48e4c

@ -104,7 +104,7 @@ class TransformedDistribution(Distribution):
return self.exp(self._log_prob(*args, **kwargs)) return self.exp(self._log_prob(*args, **kwargs))
def _sample(self, *args, **kwargs): def _sample(self, *args, **kwargs):
org_sample = self.distribution("sample", shape) org_sample = self.distribution("sample", *args, **kwargs)
return self.bijector("forward", org_sample) return self.bijector("forward", org_sample)
def _mean(self, *args, **kwargs): def _mean(self, *args, **kwargs):
@ -113,6 +113,6 @@ class TransformedDistribution(Distribution):
This function maybe overridden by derived class. This function maybe overridden by derived class.
""" """
if not self.is_linear_transformation: if not self.is_linear_transformation:
raise_not_impl_error(mean) raise_not_impl_error("mean")
return self.bijector("forward", self.distribution("mean")) return self.bijector("forward", self.distribution("mean"))

Loading…
Cancel
Save