diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/multinomial_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/multinomial_impl.cu
index 2877413963..ed7754a7b9 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/multinomial_impl.cu
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/multinomial_impl.cu
@@ -102,13 +102,15 @@ __global__ void MultinomialKernel(int seed, T *input, int num_sample, curandStat
 }
 
 template <typename T>
-void Multinomial(int seed, T *input, int num_sample, curandState *globalState, int *output, size_t distributions,
-                 size_t categories, cudaStream_t cuda_stream) {
+void Multinomial(int seed, int seed2, T *input, int num_sample, curandState *globalState, int *output,
+                 size_t distributions, size_t categories, cudaStream_t cuda_stream) {
   int RNG_seed = 0;
-  if (seed != 0) {
+  std::random_device rd;
+  if (seed2 != 0) {
+    RNG_seed = seed2;
+  } else if (seed != 0) {
     RNG_seed = seed;
   } else {
-    std::random_device rd;
     RNG_seed = static_cast<int>(rd());
   }
   int count = distributions * num_sample;
@@ -117,8 +119,8 @@ void Multinomial(int seed, T *input, int num_sample, curandState *globalState, i
   return;
 }
 
-template void Multinomial(int seed, float *input, int num_sample, curandState *globalState, int *output,
-                          size_t distributions, size_t categories, cudaStream_t cuda_stream);
+template void Multinomial(int seed, int seed2, float *input, int num_sample, curandState *globalState,
+                          int *output, size_t distributions, size_t categories, cudaStream_t cuda_stream);
 template void CheckNonNeg(const size_t size, const float *input, float *output, cudaStream_t cuda_stream);
 template void CheckZero(const size_t distributions, const size_t categories, const float *input, float *output,
                         cudaStream_t cuda_stream);
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/multinomial_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/multinomial_impl.cuh
index 2c4153c127..8830ce0ae0 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/multinomial_impl.cuh
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/multinomial_impl.cuh
@@ -20,8 +20,8 @@
 #include "runtime/device/gpu/cuda_common.h"
 
 template <typename T>
-void Multinomial(int seed, T *input, int num_sample, curandState *globalState, int *output, size_t distributions,
-                 size_t categories, cudaStream_t cuda_stream);
+void Multinomial(int seed, int seed2, T *input, int num_sample, curandState *globalState, int *output,
+                 size_t distributions, size_t categories, cudaStream_t cuda_stream);
 template <typename T>
 void CheckNonNeg(const size_t size, const T *input, T *output, cudaStream_t stream);
 template <typename T>
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/multinomial_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/multinomial_gpu_kernel.h
index 0738318d4e..ec12133cdc 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/multinomial_gpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/multinomial_gpu_kernel.h
@@ -32,7 +32,13 @@ namespace kernel {
 template <typename T>
 class MultinomialGpuKernel : public GpuKernel {
  public:
-  MultinomialGpuKernel() : input_size_0_(0), output_size_(0), distributions_(0), workspace_size_(sizeof(curandState)) {}
+  MultinomialGpuKernel()
+      : input_size_0_(0),
+        output_size_(0),
+        distributions_(0),
+        workspace_size_(sizeof(curandState)),
+        seed_(0),
+        seed2_(0) {}
   ~MultinomialGpuKernel() override = default;
   const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
@@ -52,7 +58,7 @@ class MultinomialGpuKernel : public GpuKernel {
                IntToSize(categories), 1, false, false, reinterpret_cast<cudaStream_t>(stream_ptr));
     NormInput(cum_sum_input, IntToSize(distributions_), IntToSize(categories),
               reinterpret_cast<cudaStream_t>(stream_ptr));
-    Multinomial(seed_, cum_sum_input, num_sample, devStates, output_addr, IntToSize(distributions_),
+    Multinomial(seed_, seed2_, cum_sum_input, num_sample, devStates, output_addr, IntToSize(distributions_),
                 IntToSize(categories), reinterpret_cast<cudaStream_t>(stream_ptr));
     return true;
   }
@@ -87,6 +93,7 @@ class MultinomialGpuKernel : public GpuKernel {
     }
     workspace_size_ = output_size_;
     seed_ = GetValue<int>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("seed"));
+    seed2_ = GetValue<int>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("seed2"));
     InitSizeLists();
     return true;
   }
@@ -106,6 +113,7 @@ class MultinomialGpuKernel : public GpuKernel {
   size_t distributions_;
   size_t workspace_size_;
   int seed_;
+  int seed2_;
   std::vector<size_t> input_size_list_;
   std::vector<size_t> output_size_list_;
   std::vector<size_t> workspace_size_list_;
diff --git a/mindspore/nn/probability/distribution/categorical.py b/mindspore/nn/probability/distribution/categorical.py
index 63d8bea947..aa70704771 100644
--- a/mindspore/nn/probability/distribution/categorical.py
+++ b/mindspore/nn/probability/distribution/categorical.py
@@ -33,7 +33,7 @@ class Categorical(Distribution):
     Args:
         probs (Tensor, list, numpy.ndarray, Parameter): Event probabilities.
         seed (int): The global seed is used in sampling. Global seed is used if it is None. Default: None.
-        dtype (mindspore.dtype): The type of the distribution. Default: mstype.int32.
+        dtype (mindspore.dtype): The type of the event samples. Default: mstype.int32.
         name (str): The name of the distribution. Default: Categorical.
 
     Note:
diff --git a/mindspore/ops/composite/multitype_ops/_constexpr_utils.py b/mindspore/ops/composite/multitype_ops/_constexpr_utils.py
index 9489991abf..e8b8479016 100644
--- a/mindspore/ops/composite/multitype_ops/_constexpr_utils.py
+++ b/mindspore/ops/composite/multitype_ops/_constexpr_utils.py
@@ -109,6 +109,11 @@ def is_same_type(inst, type_):
     """
     return inst == type_
 
+@constexpr
+def check_valid_dim(dim, name):
+    if dim not in (1, 2):
+        raise ValueError(
+            f"For {name}, inputs dim must be 1d or 2d")
 
 @constexpr
 def check_valid_type(data_type, value_type, name):
diff --git a/mindspore/ops/composite/random_ops.py b/mindspore/ops/composite/random_ops.py
index e0927317b7..6f6d6165d5 100644
--- a/mindspore/ops/composite/random_ops.py
+++ b/mindspore/ops/composite/random_ops.py
@@ -205,7 +205,7 @@ def poisson(shape, mean, seed=None):
     value = random_poisson(shape, mean)
     return value
 
-def multinomial(inputs, num_sample, replacement=True, seed=0):
+def multinomial(inputs, num_sample, replacement=True, seed=None):
     r"""
     Returns a tensor sampled from the multinomial probability distribution located in the corresponding
     row of the input tensor.
@@ -232,18 +232,18 @@ def multinomial(inputs, num_sample, replacement=True, seed=0):
     """
     shape = P.Shape()
     reshape = P.Reshape()
-    if inputs.dim() != 1 and inputs.dim() != 2:
-        const_utils.raise_value_error("inputs dim must be 1d or 2d")
+    const_utils.check_valid_dim(len(shape(inputs)), "multinomial")
+    seed1, seed2 = _get_seed(seed, "multinomial")
     if not replacement:
         if shape(inputs)[-1] < num_sample:
             const_utils.raise_value_error("num_sample must be less than shape(input)[-1] without replacement")
         n_dist = 1
         if len(shape(inputs)) > 1:
             n_dist = shape(inputs)[-2]
-        random_uniform = P.UniformReal(seed=seed)((n_dist * shape(inputs)[-1],))
+        random_uniform = P.UniformReal(seed1, seed2)((n_dist * shape(inputs)[-1],))
         if n_dist != 1:
             random_uniform = reshape(random_uniform, (n_dist, shape(inputs)[-1]))
         vals = P.RealDiv()(P.Log()(random_uniform), inputs + 1e-6)
         _, indices = P.TopK()(vals, num_sample)
         return indices
-    return P.Multinomial(seed=seed)(inputs, num_sample)
+    return P.Multinomial(seed1, seed2)(inputs, num_sample)
diff --git a/mindspore/ops/operations/random_ops.py b/mindspore/ops/operations/random_ops.py
index 44b03d2acc..e05335800d 100644
--- a/mindspore/ops/operations/random_ops.py
+++ b/mindspore/ops/operations/random_ops.py
@@ -433,8 +433,8 @@ class Multinomial(PrimitiveWithInfer):
     The rows of input do not need to sum to one (in which case we use the values as weights),
     but must be non-negative, finite and have a non-zero sum.
     Args:
-        seed (int): Seed data is used as entropy source for Random number engines to generate pseudo-random numbers.
-            Must be non-negative. Default: 0.
+        seed (int): Random seed, must be non-negative. Default: 0.
+        seed2 (int): Random seed2, must be non-negative. Default: 0.
     Inputs:
         - **input** (Tensor[float32]) - the input tensor containing the cumsum of probabilities, must be 1 or 2
           dimensions.
@@ -450,10 +450,10 @@ class Multinomial(PrimitiveWithInfer):
     """
 
     @prim_attr_register
-    def __init__(self, seed=0):
+    def __init__(self, seed=0, seed2=0):
         """init"""
-        Validator.check_value_type("seed", seed, [int], self.name)
         Validator.check_non_negative_int(seed, "seed", self.name)
+        Validator.check_non_negative_int(seed2, "seed2", self.name)
         self.init_prim_io_names(inputs=['input', 'num_sample'], outputs=['output'])
 
     def __infer__(self, inputs, num_samples):
diff --git a/tests/st/ops/gpu/test_multinomial_op.py b/tests/st/ops/gpu/test_multinomial_op.py
index 07fc47f4b5..322a8fb85b 100644
--- a/tests/st/ops/gpu/test_multinomial_op.py
+++ b/tests/st/ops/gpu/test_multinomial_op.py
@@ -17,9 +17,20 @@ import numpy as np
 import pytest
 from mindspore.ops import composite as C
 import mindspore.context as context
+import mindspore.nn as nn
 from mindspore import Tensor
 
-context.set_context(device_target='GPU')
+context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
+
+class Net(nn.Cell):
+    def __init__(self, sample, replacement, seed=0):
+        super(Net, self).__init__()
+        self.sample = sample
+        self.replacement = replacement
+        self.seed = seed
+
+    def construct(self, x):
+        return C.multinomial(x, self.sample, self.replacement, self.seed)
 
 @pytest.mark.level0
 @pytest.mark.platform_x86_gpu_training
@@ -27,9 +38,12 @@ context.set_context(device_target='GPU')
 def test_multinomial():
     x0 = Tensor(np.array([0.9, 0.2]).astype(np.float32))
     x1 = Tensor(np.array([[0.9, 0.2], [0.9, 0.2]]).astype(np.float32))
-    out0 = C.multinomial(x0, 1, True)
-    out1 = C.multinomial(x0, 2, True)
-    out2 = C.multinomial(x1, 6, True)
+    net0 = Net(1, True, 20)
+    net1 = Net(2, True, 20)
+    net2 = Net(6, True, 20)
+    out0 = net0(x0)
+    out1 = net1(x0)
+    out2 = net2(x1)
     assert out0.asnumpy().shape == (1,)
     assert out1.asnumpy().shape == (2,)
     assert out2.asnumpy().shape == (2, 6)
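Usage sketch (not part of the patch): a minimal example of how the reworked seed handling is exercised from Python, mirroring the updated GPU test above. It assumes a GPU-enabled MindSpore build containing this change; the names SampleNet and probs are illustrative only.

import numpy as np
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import composite as C

context.set_context(mode=context.GRAPH_MODE, device_target='GPU')


class SampleNet(nn.Cell):
    """Illustrative wrapper: draws num_sample indices per row of a probability tensor."""
    def __init__(self, num_sample, replacement=True, seed=5):
        super(SampleNet, self).__init__()
        self.num_sample = num_sample
        self.replacement = replacement
        self.seed = seed

    def construct(self, probs):
        # With this patch, C.multinomial derives (seed1, seed2) from the user-facing
        # seed via _get_seed and forwards both to P.Multinomial on the GPU path.
        return C.multinomial(probs, self.num_sample, self.replacement, self.seed)


probs = Tensor(np.array([[0.9, 0.2], [0.3, 0.7]]).astype(np.float32))
net = SampleNet(num_sample=4, replacement=True, seed=5)
print(net(probs).asnumpy().shape)  # expected: (2, 4)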