|
|
|
@ -112,6 +112,7 @@ class Geometric(Distribution):
|
|
|
|
|
self.minval = np.finfo(np.float).tiny
|
|
|
|
|
|
|
|
|
|
# ops needed for the class
|
|
|
|
|
self.squeeze = P.Squeeze(0)
|
|
|
|
|
self.cast = P.Cast()
|
|
|
|
|
self.const = P.ScalarToArray()
|
|
|
|
|
self.dtypeop = P.DType()
|
|
|
|
@ -283,8 +284,16 @@ class Geometric(Distribution):
|
|
|
|
|
probs1 = self.cast(probs, self.parameter_type) if probs is not None else self.probs
|
|
|
|
|
if probs1 is None:
|
|
|
|
|
raise_none_error("probs")
|
|
|
|
|
origin_shape = shape + self.shape(probs1)
|
|
|
|
|
if origin_shape == ():
|
|
|
|
|
sample_shape = (1,)
|
|
|
|
|
else:
|
|
|
|
|
sample_shape = origin_shape
|
|
|
|
|
minval = self.const(self.minval)
|
|
|
|
|
maxval = self.const(1.0)
|
|
|
|
|
sample_uniform = self.uniform(shape + self.shape(probs1), minval, maxval, self.seed)
|
|
|
|
|
sample_uniform = self.uniform(sample_shape, minval, maxval, self.seed)
|
|
|
|
|
sample = self.floor(self.log(sample_uniform) / self.log(1.0 - probs1))
|
|
|
|
|
return self.cast(sample, self.dtype)
|
|
|
|
|
value = self.cast(sample, self.dtype)
|
|
|
|
|
if origin_shape == ():
|
|
|
|
|
value = self.squeeze(value)
|
|
|
|
|
return value
|
|
|
|
|