|
|
|
@ -212,6 +212,7 @@ class Poisson(Distribution):
|
|
|
|
|
"""
|
|
|
|
|
value = self._check_value(value, "value")
|
|
|
|
|
value = self.cast(value, self.dtype)
|
|
|
|
|
value = self.floor(value)
|
|
|
|
|
rate = self._check_param_type(rate)
|
|
|
|
|
log_rate = self.log(rate)
|
|
|
|
|
zeros = self.fill(self.dtypeop(value), self.shape(value), 0.0)
|
|
|
|
@ -239,6 +240,7 @@ class Poisson(Distribution):
|
|
|
|
|
"""
|
|
|
|
|
value = self._check_value(value, 'value')
|
|
|
|
|
value = self.cast(value, self.dtype)
|
|
|
|
|
value = self.floor(value)
|
|
|
|
|
rate = self._check_param_type(rate)
|
|
|
|
|
zeros = self.fill(self.dtypeop(value), self.shape(value), 0.0)
|
|
|
|
|
comp = self.less(value, zeros)
|
|
|
|
@ -259,6 +261,9 @@ class Poisson(Distribution):
|
|
|
|
|
"""
|
|
|
|
|
shape = self.checktuple(shape, 'shape')
|
|
|
|
|
rate = self._check_param_type(rate)
|
|
|
|
|
|
|
|
|
|
# now Poisson sampler supports only fp32
|
|
|
|
|
rate = self.cast(rate, mstype.float32)
|
|
|
|
|
origin_shape = shape + self.shape(rate)
|
|
|
|
|
if origin_shape == ():
|
|
|
|
|
sample_shape = (1,)
|
|
|
|
|