|
|
|
@ -20,7 +20,6 @@ from .. import functional as F
|
|
|
|
|
from ..primitive import constexpr
|
|
|
|
|
from .multitype_ops import _constexpr_utils as const_utils
|
|
|
|
|
from ...common import dtype as mstype
|
|
|
|
|
from ...common.tensor import Tensor
|
|
|
|
|
from ..._checkparam import Validator as validator
|
|
|
|
|
from ..._checkparam import check_int_positive
|
|
|
|
|
from ..._checkparam import Rel
|
|
|
|
@ -134,9 +133,7 @@ def multinomial(inputs, num_sample=None, replacement=True, seed=0):
|
|
|
|
|
n_dist = 1
|
|
|
|
|
if len(shape(inputs)) > 1:
|
|
|
|
|
n_dist = shape(inputs)[-2]
|
|
|
|
|
a = Tensor(0.0, mstype.float32)
|
|
|
|
|
b = Tensor(1.0, mstype.float32)
|
|
|
|
|
random_uniform = P.UniformReal(seed=seed)((n_dist * num_sample,), a, b)
|
|
|
|
|
random_uniform = P.UniformReal(seed=seed)((n_dist * num_sample,))
|
|
|
|
|
if n_dist != 1:
|
|
|
|
|
random_uniform = reshape(random_uniform, (n_dist, num_sample))
|
|
|
|
|
vals = P.RealDiv()(P.Log()(random_uniform), inputs + 1e-6)
|
|
|
|
|