!9808 Fix minor bugs in distribution classes

From: @shallydeng
Reviewed-by: @zichun_ye,@wang_zi_dong
Signed-off-by: @zichun_ye
pull/9808/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit c08c284e01

@ -238,7 +238,8 @@ class Beta(Distribution):
comp1 = self.greater(concentration1, 1.)
comp2 = self.greater(concentration0, 1.)
cond = self.logicaland(comp1, comp2)
nan = self.fill(self.dtype, self.broadcast_shape, np.nan)
batch_shape = self.shape(concentration1 + concentration0)
nan = self.fill(self.dtype, batch_shape, np.nan)
mode = (concentration1 - 1.) / (concentration1 + concentration0 - 2.)
return self.select(cond, mode, nan)

@ -212,6 +212,7 @@ class Poisson(Distribution):
"""
value = self._check_value(value, "value")
value = self.cast(value, self.dtype)
value = self.floor(value)
rate = self._check_param_type(rate)
log_rate = self.log(rate)
zeros = self.fill(self.dtypeop(value), self.shape(value), 0.0)
@ -239,6 +240,7 @@ class Poisson(Distribution):
"""
value = self._check_value(value, 'value')
value = self.cast(value, self.dtype)
value = self.floor(value)
rate = self._check_param_type(rate)
zeros = self.fill(self.dtypeop(value), self.shape(value), 0.0)
comp = self.less(value, zeros)
@ -259,6 +261,9 @@ class Poisson(Distribution):
"""
shape = self.checktuple(shape, 'shape')
rate = self._check_param_type(rate)
# now Poisson sampler supports only fp32
rate = self.cast(rate, mstype.float32)
origin_shape = shape + self.shape(rate)
if origin_shape == ():
sample_shape = (1,)

Loading…
Cancel
Save