From 98e2a48e4c6ab39f2f64aed8f0863614dd612105 Mon Sep 17 00:00:00 2001 From: peixu_ren Date: Thu, 20 Aug 2020 22:59:47 -0400 Subject: [PATCH] Fix errors in transformed_distribution --- .../nn/probability/distribution/transformed_distribution.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mindspore/nn/probability/distribution/transformed_distribution.py b/mindspore/nn/probability/distribution/transformed_distribution.py index 259f105d4e..1ec53dbf64 100644 --- a/mindspore/nn/probability/distribution/transformed_distribution.py +++ b/mindspore/nn/probability/distribution/transformed_distribution.py @@ -104,7 +104,7 @@ class TransformedDistribution(Distribution): return self.exp(self._log_prob(*args, **kwargs)) def _sample(self, *args, **kwargs): - org_sample = self.distribution("sample", shape) + org_sample = self.distribution("sample", *args, **kwargs) return self.bijector("forward", org_sample) def _mean(self, *args, **kwargs): @@ -113,6 +113,6 @@ class TransformedDistribution(Distribution): This function maybe overridden by derived class. """ if not self.is_linear_transformation: - raise_not_impl_error(mean) + raise_not_impl_error("mean") return self.bijector("forward", self.distribution("mean"))