From 322f154abb25658ec64869c3220dc407f75fc8c0 Mon Sep 17 00:00:00 2001 From: Xun Deng Date: Tue, 29 Sep 2020 16:13:37 -0400 Subject: [PATCH] fixed neginf plus zero issue under fp16 --- .../probability/distribution/transformed_distribution.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/mindspore/nn/probability/distribution/transformed_distribution.py b/mindspore/nn/probability/distribution/transformed_distribution.py index cab8f2662f..883e0ed4ee 100644 --- a/mindspore/nn/probability/distribution/transformed_distribution.py +++ b/mindspore/nn/probability/distribution/transformed_distribution.py @@ -13,7 +13,9 @@ # limitations under the License. # ============================================================================ """Transformed Distribution""" +import numpy as np from mindspore._checkparam import Validator as validator +from mindspore.ops import operations as P import mindspore.nn as nn from .distribution import Distribution from ._utils.utils import raise_not_impl_error @@ -80,6 +82,8 @@ class TransformedDistribution(Distribution): self.parameter_names = distribution.parameter_names self.exp = exp_generic self.log = log_generic + self.equal_base = P.Equal() + self.select_base = P.Select() @property def bijector(self): @@ -125,7 +129,8 @@ class TransformedDistribution(Distribution): inverse_value = self.bijector("inverse", value) unadjust_prob = self.distribution("log_prob", inverse_value, *args, **kwargs) log_jacobian = self.bijector("inverse_log_jacobian", value) - return unadjust_prob + log_jacobian + isneginf = self.equal_base(unadjust_prob, -np.inf) + return self.select_base(isneginf, unadjust_prob, unadjust_prob + log_jacobian) def _prob(self, value, *args, **kwargs): return self.exp(self._log_prob(value, *args, **kwargs))