fix minor bug in catgorical distribution

pull/9050/head
Xun Deng 4 years ago
parent f4c126ddeb
commit a058881b90

@ -285,11 +285,11 @@ class CheckTuple(PrimitiveWithInfer):
return out
def __call__(self, x, name):
if context.get_context("mode") == 0:
return x["value"]
# Pynative mode
# The op is not used in a cell
if isinstance(x, tuple):
return x
if context.get_context("mode") == 0:
return x["value"]
raise TypeError(f"For {name}, input type should be a tuple.")

@ -108,7 +108,7 @@ class Categorical(Distribution):
name="Categorical"):
param = dict(locals())
param['param_dict'] = {'probs': probs}
valid_dtype = mstype.int_type
valid_dtype = mstype.int_type + mstype.float_type
Validator.check_type_name("dtype", dtype, valid_dtype, type(self).__name__)
super(Categorical, self).__init__(seed, dtype, name, param)
@ -134,8 +134,8 @@ class Categorical(Distribution):
self.exp = exp_generic
self.expand_dim = P.ExpandDims()
self.fill = P.Fill()
self.floor = P.Floor()
self.gather = P.GatherNd()
self.issubclass = P.IsSubClass()
self.less = P.Less()
self.log = log_generic
self.log_softmax = P.LogSoftmax()
@ -153,6 +153,7 @@ class Categorical(Distribution):
self.transpose = P.Transpose()
self.index_type = mstype.int32
self.nan = np.nan
def extend_repr(self):
@ -255,7 +256,10 @@ class Categorical(Distribution):
probs (Tensor): Event probabilities. Default: self.probs.
"""
value = self._check_value(value, 'value')
value = self.cast(value, self.parameter_type)
if self.issubclass(self.dtype, mstype.float_):
value = self.cast(value, self.index_type)
else:
value = self.cast(value, self.dtype)
probs = self._check_param_type(probs)
logits = self.log(probs)
@ -294,8 +298,8 @@ class Categorical(Distribution):
# index into logit_pmf, fill in out_of_bound places with -inf
# reshape into label shape N
logits_pmf = self.gather(self.reshape(logits, (-1, num_classes)), index)
neg_inf = self.fill(self.dtypeop(logits_pmf), self.shape(logits_pmf), -np.inf)
logits_pmf = self.select(out_of_bound, neg_inf, logits_pmf)
nan = self.fill(self.dtypeop(logits_pmf), self.shape(logits_pmf), self.nan)
logits_pmf = self.select(out_of_bound, nan, logits_pmf)
ans = self.reshape(logits_pmf, label_shape)
if drop_dim:
return self.squeeze(ans)
@ -310,8 +314,10 @@ class Categorical(Distribution):
probs (Tensor): Event probabilities. Default: self.probs.
"""
value = self._check_value(value, 'value')
value = self.cast(value, self.parameter_type)
value = self.floor(value)
if self.issubclass(self.dtype, mstype.float_):
value = self.cast(value, self.index_type)
else:
value = self.cast(value, self.dtype)
probs = self._check_param_type(probs)
# handle the case when value is of shape () and probs is a scalar batch

@ -219,7 +219,7 @@ def test_log_survival():
Test log survival funciton.
"""
expect_logsurvival = np.log([1., 0.3, 0.3, 0.3, 0.3])
x_ = Tensor(np.array([-0.1, 0, 0, 0.5, 0.5]).astype(np.float32), dtype=dtype.float32)
x_ = Tensor(np.array([-2, 0, 0, 0.5, 0.5]).astype(np.float32), dtype=dtype.float32)
log_sf = LogSF()
output = log_sf(x_)
tol = 1e-6

Loading…
Cancel
Save