|
|
|
@ -28,7 +28,7 @@ class Categorical(Distribution):
|
|
|
|
|
probs (Tensor, list, numpy.ndarray, Parameter, float): event probabilities.
|
|
|
|
|
logits (Tensor, list, numpy.ndarray, Parameter, float): event log-odds.
|
|
|
|
|
seed (int): seed to use in sampling. Default: 0.
|
|
|
|
|
dtype (mindspore.dtype): type of the distribution. Default: mstype.int32.
|
|
|
|
|
dtype (mstype.int32): type of the distribution. Default: mstype.int32.
|
|
|
|
|
name (str): name of the distribution. Default: Categorical.
|
|
|
|
|
|
|
|
|
|
Note:
|
|
|
|
@ -49,7 +49,7 @@ class Categorical(Distribution):
|
|
|
|
|
>>>
|
|
|
|
|
>>> # Similar calls can be made to logits
|
|
|
|
|
>>> ans = self.ca.probs
|
|
|
|
|
>>> # value should be Tensor
|
|
|
|
|
>>> # value should be Tensor(mstype.float32, bool, mstype.int32)
|
|
|
|
|
>>> ans = self.ca.log_prob(value)
|
|
|
|
|
>>>
|
|
|
|
|
>>> # Usage of enumerate_support
|
|
|
|
@ -210,6 +210,8 @@ class Categorical(Distribution):
|
|
|
|
|
def enumerate_support(self, expand=True):
|
|
|
|
|
r"""
|
|
|
|
|
Enumerate categories.
|
|
|
|
|
Args:
|
|
|
|
|
expand (Bool): whether to expand.
|
|
|
|
|
"""
|
|
|
|
|
num_events = self._num_events
|
|
|
|
|
values = nn.Range(0., num_events, 1)()
|
|
|
|
|