|
|
|
@ -16,10 +16,11 @@
|
|
|
|
|
import numpy as np
|
|
|
|
|
from mindspore.ops import operations as P
|
|
|
|
|
from mindspore.ops import composite as C
|
|
|
|
|
from mindspore._checkparam import Validator
|
|
|
|
|
import mindspore.nn as nn
|
|
|
|
|
from mindspore.common import dtype as mstype
|
|
|
|
|
from .distribution import Distribution
|
|
|
|
|
from ._utils.utils import check_prob, check_sum_equal_one, check_type, check_rank,\
|
|
|
|
|
from ._utils.utils import check_prob, check_sum_equal_one, check_rank,\
|
|
|
|
|
check_distribution_name, raise_not_implemented_util
|
|
|
|
|
from ._utils.custom_ops import exp_generic, log_generic, broadcast_to
|
|
|
|
|
|
|
|
|
@ -107,7 +108,7 @@ class Categorical(Distribution):
|
|
|
|
|
param = dict(locals())
|
|
|
|
|
param['param_dict'] = {'probs': probs}
|
|
|
|
|
valid_dtype = mstype.int_type
|
|
|
|
|
check_type(dtype, valid_dtype, "Categorical")
|
|
|
|
|
Validator.check_type("Categorical", dtype, valid_dtype)
|
|
|
|
|
super(Categorical, self).__init__(seed, dtype, name, param)
|
|
|
|
|
|
|
|
|
|
self._probs = self._add_parameter(probs, 'probs')
|
|
|
|
|