update bernoull dist: clamp prob for log_prob/prob

fix doc
pull/12485/head
Zichun Ye 4 years ago
parent f9f24ca94d
commit b8fd0c196c

@ -18,7 +18,7 @@ from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore._checkparam import Validator
from .distribution import Distribution
from ._utils.utils import check_prob, check_distribution_name
from ._utils.utils import check_prob, check_distribution_name, clamp_probs
from ._utils.custom_ops import exp_generic, log_generic
@ -86,7 +86,6 @@ class Bernoulli(Distribution):
>>> ans = b2.mean(probs_a)
>>> print(ans.shape)
(1,)
>>> print(ans.shape)
>>> # Interfaces of `kl_loss` and `cross_entropy` are the same as follows:
>>> # Args:
>>> # dist (str): the name of the distribution. Only 'Bernoulli' is supported.
@ -132,7 +131,8 @@ class Bernoulli(Distribution):
param = dict(locals())
param['param_dict'] = {'probs': probs}
valid_dtype = mstype.int_type + mstype.uint_type + mstype.float_type
Validator.check_type_name("dtype", dtype, valid_dtype, type(self).__name__)
Validator.check_type_name(
"dtype", dtype, valid_dtype, type(self).__name__)
super(Bernoulli, self).__init__(seed, dtype, name, param)
self._probs = self._add_parameter(probs, 'probs')
@ -241,6 +241,9 @@ class Bernoulli(Distribution):
value = self._check_value(value, 'value')
value = self.cast(value, self.parameter_type)
probs1 = self._check_param_type(probs1)
# clamp value for numerical stability
probs1 = clamp_probs(probs1)
probs0 = 1.0 - probs1
return self.log(probs1) * value + self.log(probs0) * (1.0 - value)
@ -266,8 +269,10 @@ class Bernoulli(Distribution):
probs0 = self.broadcast((1.0 - probs1), broadcast_shape_tensor)
comp_zero = self.less(value, 0.0)
comp_one = self.less(value, 1.0)
zeros = self.fill(self.parameter_type, self.shape(broadcast_shape_tensor), 0.0)
ones = self.fill(self.parameter_type, self.shape(broadcast_shape_tensor), 1.0)
zeros = self.fill(self.parameter_type, self.shape(
broadcast_shape_tensor), 0.0)
ones = self.fill(self.parameter_type, self.shape(
broadcast_shape_tensor), 1.0)
less_than_zero = self.select(comp_zero, zeros, probs0)
return self.select(comp_one, less_than_zero, ones)

Loading…
Cancel
Save