|
|
|
@ -17,7 +17,7 @@ import numpy as np
|
|
|
|
|
from mindspore.ops import operations as P
|
|
|
|
|
from mindspore.common import dtype as mstype
|
|
|
|
|
from .distribution import Distribution
|
|
|
|
|
from ._utils.utils import logits_to_probs, probs_to_logits, check_tensor_type, cast_to_tensor
|
|
|
|
|
from ._utils.utils import logits_to_probs, probs_to_logits, check_type, check_tensor_type, cast_to_tensor, raise_probs_logits_error
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Categorical(Distribution):
|
|
|
|
@ -71,9 +71,11 @@ class Categorical(Distribution):
|
|
|
|
|
dtype=mstype.int32,
|
|
|
|
|
name="Categorical"):
|
|
|
|
|
param = dict(locals())
|
|
|
|
|
valid_dtype = mstype.int_type
|
|
|
|
|
check_type(dtype, valid_dtype, "Categorical")
|
|
|
|
|
super(Categorical, self).__init__(seed, dtype, name, param)
|
|
|
|
|
if (probs is None) == (logits is None):
|
|
|
|
|
raise ValueError("Either 'prob' or 'logits' must be specified, but not both.")
|
|
|
|
|
raise_probs_logits_error()
|
|
|
|
|
self.reduce_sum = P.ReduceSum(keep_dims=True)
|
|
|
|
|
self.log = P.Log()
|
|
|
|
|
self.exp = P.Exp()
|
|
|
|
@ -127,8 +129,7 @@ class Categorical(Distribution):
|
|
|
|
|
Returns:
|
|
|
|
|
Tensor, shape is shape(probs)[:-1] + sample_shape
|
|
|
|
|
"""
|
|
|
|
|
if not isinstance(sample_shape, tuple):
|
|
|
|
|
raise ValueError("sample shape must be a tuple")
|
|
|
|
|
self.checktuple(sample_shape, 'shape')
|
|
|
|
|
num_sample = 1
|
|
|
|
|
for i in sample_shape:
|
|
|
|
|
num_sample *= i
|
|
|
|
@ -136,7 +137,7 @@ class Categorical(Distribution):
|
|
|
|
|
samples = self.mutinomial(probs_2d, num_sample)
|
|
|
|
|
extend_shape = sample_shape
|
|
|
|
|
if len(self.shape(self._probs)) > 1:
|
|
|
|
|
extend_shape = self.shape(self._probs)[:-1] + sample_shape
|
|
|
|
|
extend_shape = sample_shape + self.shape(self._probs)[:-1]
|
|
|
|
|
return self.cast(self.reshape(samples, extend_shape), self.dtype)
|
|
|
|
|
|
|
|
|
|
def _broad_cast_shape(self, a, b):
|
|
|
|
@ -183,15 +184,16 @@ class Categorical(Distribution):
|
|
|
|
|
if value is not None:
|
|
|
|
|
check_tensor_type("value", value, [mstype.float32, bool, mstype.int32])
|
|
|
|
|
value = self.expandim(self.cast(value, mstype.float32), -1)
|
|
|
|
|
broad_shape = self._broad_cast_shape(value, self._logits)
|
|
|
|
|
index = cast_to_tensor(np.arange(self.shape(value)[0]).astype(np.float32))
|
|
|
|
|
index = self.expandim(index, -1)
|
|
|
|
|
logits = self._logits if self._logits.dim() == 1 else self.expandim(self._logits, 0)
|
|
|
|
|
broad_shape = self._broad_cast_shape(value, logits)
|
|
|
|
|
broad = P.BroadcastTo(broad_shape)
|
|
|
|
|
value = broad(value)[..., :1]
|
|
|
|
|
index = cast_to_tensor(np.arange(broad_shape[-1]).astype(np.float32))
|
|
|
|
|
index = self.expandim(index, -1)
|
|
|
|
|
index = broad(index)[..., :1]
|
|
|
|
|
value = self.concat((index, value))
|
|
|
|
|
value = self.cast(value, mstype.int32)
|
|
|
|
|
return self.gather(self._logits, value)
|
|
|
|
|
return self.gather(logits, value)
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
def _entropy(self):
|
|
|
|
@ -209,7 +211,7 @@ class Categorical(Distribution):
|
|
|
|
|
Enumerate categories.
|
|
|
|
|
"""
|
|
|
|
|
num_events = self._num_events
|
|
|
|
|
values = cast_to_tensor(np.arange(num_events).astype(np.int32), mstype.int32)
|
|
|
|
|
values = cast_to_tensor(np.arange(num_events).astype(np.int32), mstype.float32)
|
|
|
|
|
values = self.reshape(values, (num_events, 1))
|
|
|
|
|
if expand:
|
|
|
|
|
values = P.BroadcastTo((num_events, self._batch_shape))(values)
|
|
|
|
|