From 9083e9dcd616b0f32c4b07be5b1cb52fd183304e Mon Sep 17 00:00:00 2001 From: Xun Deng Date: Mon, 31 Aug 2020 12:42:05 -0400 Subject: [PATCH] fixed prob, survival function of exponential distribution --- .../probability/distribution/exponential.py | 32 ++++++++++++++++--- 1 file changed, 27 insertions(+), 5 deletions(-) diff --git a/mindspore/nn/probability/distribution/exponential.py b/mindspore/nn/probability/distribution/exponential.py index 43886ec096..3ecc8250bb 100644 --- a/mindspore/nn/probability/distribution/exponential.py +++ b/mindspore/nn/probability/distribution/exponential.py @@ -198,9 +198,9 @@ class Exponential(Distribution): return self._entropy(rate) + self._kl_loss(dist, rate_b, rate) - def _prob(self, value, rate=None): + def _log_prob(self, value, rate=None): r""" - pdf of Exponential distribution. + log_pdf of Exponential distribution. Args: Args: @@ -211,15 +211,16 @@ class Exponential(Distribution): Value should be greater or equal to zero. .. math:: - pdf(x) = rate * \exp(-1 * \lambda * x) if x >= 0 else 0 + log_pdf(x) = \log(rate) - rate * x if x >= 0 else 0 """ value = self._check_value(value, "value") value = self.cast(value, self.dtype) rate = self._check_param(rate) - prob = self.exp(self.log(rate) - rate * value) + prob = self.log(rate) - rate * value zeros = self.fill(self.dtypeop(prob), self.shape(prob), 0.0) + neginf = self.fill(self.dtypeop(prob), self.shape(prob), -np.inf) comp = self.less(value, zeros) - return self.select(comp, zeros, prob) + return self.select(comp, neginf, prob) def _cdf(self, value, rate=None): r""" @@ -243,6 +244,27 @@ class Exponential(Distribution): comp = self.less(value, zeros) return self.select(comp, zeros, cdf) + def _log_survival(self, value, rate=None): + r""" + log survival_function of Exponential distribution. + + Args: + value (Tensor): value to be evaluated. + rate (Tensor): rate of the distribution. Default: self.rate. + + Note: + Value should be greater or equal to zero. + + .. math:: + log_survival_function(x) = -1 * \lambda * x if x >= 0 else 0 + """ + value = self._check_value(value, 'value') + value = self.cast(value, self.dtype) + rate = self._check_param(rate) + sf = -1. * rate * value + zeros = self.fill(self.dtypeop(sf), self.shape(sf), 0.0) + comp = self.less(value, zeros) + return self.select(comp, zeros, sf) def _kl_loss(self, dist, rate_b, rate=None): """