|
|
|
@ -108,7 +108,7 @@ class Categorical(Distribution):
|
|
|
|
|
name="Categorical"):
|
|
|
|
|
param = dict(locals())
|
|
|
|
|
param['param_dict'] = {'probs': probs}
|
|
|
|
|
valid_dtype = mstype.int_type
|
|
|
|
|
valid_dtype = mstype.int_type + mstype.float_type
|
|
|
|
|
Validator.check_type_name("dtype", dtype, valid_dtype, type(self).__name__)
|
|
|
|
|
super(Categorical, self).__init__(seed, dtype, name, param)
|
|
|
|
|
|
|
|
|
@ -134,8 +134,8 @@ class Categorical(Distribution):
|
|
|
|
|
self.exp = exp_generic
|
|
|
|
|
self.expand_dim = P.ExpandDims()
|
|
|
|
|
self.fill = P.Fill()
|
|
|
|
|
self.floor = P.Floor()
|
|
|
|
|
self.gather = P.GatherNd()
|
|
|
|
|
self.issubclass = P.IsSubClass()
|
|
|
|
|
self.less = P.Less()
|
|
|
|
|
self.log = log_generic
|
|
|
|
|
self.log_softmax = P.LogSoftmax()
|
|
|
|
@ -153,6 +153,7 @@ class Categorical(Distribution):
|
|
|
|
|
self.transpose = P.Transpose()
|
|
|
|
|
|
|
|
|
|
self.index_type = mstype.int32
|
|
|
|
|
self.nan = np.nan
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def extend_repr(self):
|
|
|
|
@ -255,7 +256,10 @@ class Categorical(Distribution):
|
|
|
|
|
probs (Tensor): Event probabilities. Default: self.probs.
|
|
|
|
|
"""
|
|
|
|
|
value = self._check_value(value, 'value')
|
|
|
|
|
value = self.cast(value, self.parameter_type)
|
|
|
|
|
if self.issubclass(self.dtype, mstype.float_):
|
|
|
|
|
value = self.cast(value, self.index_type)
|
|
|
|
|
else:
|
|
|
|
|
value = self.cast(value, self.dtype)
|
|
|
|
|
probs = self._check_param_type(probs)
|
|
|
|
|
logits = self.log(probs)
|
|
|
|
|
|
|
|
|
@ -294,8 +298,8 @@ class Categorical(Distribution):
|
|
|
|
|
# index into logit_pmf, fill in out_of_bound places with -inf
|
|
|
|
|
# reshape into label shape N
|
|
|
|
|
logits_pmf = self.gather(self.reshape(logits, (-1, num_classes)), index)
|
|
|
|
|
neg_inf = self.fill(self.dtypeop(logits_pmf), self.shape(logits_pmf), -np.inf)
|
|
|
|
|
logits_pmf = self.select(out_of_bound, neg_inf, logits_pmf)
|
|
|
|
|
nan = self.fill(self.dtypeop(logits_pmf), self.shape(logits_pmf), self.nan)
|
|
|
|
|
logits_pmf = self.select(out_of_bound, nan, logits_pmf)
|
|
|
|
|
ans = self.reshape(logits_pmf, label_shape)
|
|
|
|
|
if drop_dim:
|
|
|
|
|
return self.squeeze(ans)
|
|
|
|
@ -310,8 +314,10 @@ class Categorical(Distribution):
|
|
|
|
|
probs (Tensor): Event probabilities. Default: self.probs.
|
|
|
|
|
"""
|
|
|
|
|
value = self._check_value(value, 'value')
|
|
|
|
|
value = self.cast(value, self.parameter_type)
|
|
|
|
|
value = self.floor(value)
|
|
|
|
|
if self.issubclass(self.dtype, mstype.float_):
|
|
|
|
|
value = self.cast(value, self.index_type)
|
|
|
|
|
else:
|
|
|
|
|
value = self.cast(value, self.dtype)
|
|
|
|
|
probs = self._check_param_type(probs)
|
|
|
|
|
|
|
|
|
|
# handle the case when value is of shape () and probs is a scalar batch
|
|
|
|
|