diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/multinomial_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/multinomial_gpu_kernel.h index aea841d376..436dc50d78 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/multinomial_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/multinomial_gpu_kernel.h @@ -33,7 +33,12 @@ namespace kernel { template <typename T> class MultinomialGpuKernel : public GpuKernel { public: - MultinomialGpuKernel() : input_size_0_(0), output_size_(0), distributions_(0), workspace_size_(sizeof(curandState)) {} + MultinomialGpuKernel() + : input_size_0_(0), + output_size_(0), + distributions_(0), + workspace_size_(sizeof(curandState)), + replacement_(true) {} ~MultinomialGpuKernel() override = default; const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } @@ -49,6 +54,19 @@ class MultinomialGpuKernel : public GpuKernel { int categories = SizeToInt(inputs[0]->size / sizeof(T)) / distributions_; int num_sample = SizeToInt(outputs[0]->size / sizeof(T)) / distributions_; // check input + T *cum_sum_input = nullptr; + CHECK_CUDA_RET_WITH_EXCEPT(cudaMalloc(reinterpret_cast<void **>(&cum_sum_input), input_size_0_), + "cudaMalloc failed."); + CheckPeram(input_addr, cum_sum_input, categories, stream_ptr); + if (replacement_) { + Multinomial(seed_, cum_sum_input, num_sample, devStates, output_addr, IntToSize(distributions_), + IntToSize(categories), reinterpret_cast<cudaStream_t>(stream_ptr)); + } + CHECK_CUDA_RET_WITH_EXCEPT(cudaFree(cum_sum_input), "cudaFree failed."); + return true; + } + + void CheckPeram(const T *input_addr, T *cum_sum_input, int categories, void *stream_ptr) { T *flag = nullptr; T *cflag = nullptr; CHECK_CUDA_RET_WITH_EXCEPT(cudaMalloc(reinterpret_cast<void **>(&cflag), sizeof(T)), "cudaMalloc failed."); @@ -67,9 +85,6 @@ class MultinomialGpuKernel : public GpuKernel { if (*flag > 0) { MS_LOG(EXCEPTION) << "Input is invalid (input element < 0)"; } - T *cum_sum_input = nullptr; -
CHECK_CUDA_RET_WITH_EXCEPT(cudaMalloc(reinterpret_cast<void **>(&cum_sum_input), input_size_0_), - "cudaMalloc failed."); CumSum(input_addr, cum_sum_input, cum_sum_input, IntToSize(distributions_), IntToSize(categories), 1, IntToSize(categories), 1, false, false, reinterpret_cast<cudaStream_t>(stream_ptr)); CHECK_CUDA_RET_WITH_EXCEPT(cudaStreamSynchronize(reinterpret_cast<cudaStream_t>(stream_ptr)), @@ -82,14 +97,10 @@ class MultinomialGpuKernel : public GpuKernel { if (*flag > 0) { MS_LOG(EXCEPTION) << "Input is invalid (sum <= 0)"; } - Multinomial(seed_, cum_sum_input, num_sample, devStates, output_addr, IntToSize(distributions_), - IntToSize(categories), reinterpret_cast<cudaStream_t>(stream_ptr)); - - CHECK_CUDA_RET_WITH_EXCEPT(cudaFree(cum_sum_input), "cudaFree failed."); CHECK_CUDA_RET_WITH_EXCEPT(cudaFree(cflag), "cudaFree failed."); CHECK_CUDA_RET_WITH_EXCEPT(cudaFreeHost(flag), "cudaFreeHost failed."); - return true; } + bool Init(const CNodePtr &kernel_node) override { std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node); size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); @@ -114,9 +125,15 @@ class MultinomialGpuKernel : public GpuKernel { } auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); output_size_ = sizeof(int); - for (size_t i = 0; i < output_shape.size(); i++) { - output_size_ *= output_shape[i]; - workspace_size_ *= output_shape[i]; + workspace_size_ = sizeof(int); + replacement_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("replacement")); + if (replacement_) { + for (size_t i = 0; i < output_shape.size(); i++) { + output_size_ *= output_shape[i]; + } + } + if (replacement_) { + workspace_size_ = output_size_; } seed_ = GetValue<int>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("seed")); InitSizeLists(); @@ -136,6 +153,7 @@ class MultinomialGpuKernel : public GpuKernel { size_t output_size_; size_t distributions_; size_t workspace_size_; + bool replacement_; int seed_; std::vector<size_t> input_size_list_; std::vector<size_t> output_size_list_; diff --git
a/mindspore/ops/composite/random_ops.py b/mindspore/ops/composite/random_ops.py index d2a9bac1f1..c1d0bf52c9 100644 --- a/mindspore/ops/composite/random_ops.py +++ b/mindspore/ops/composite/random_ops.py @@ -20,8 +20,6 @@ from .. import functional as F from ..primitive import constexpr from .multitype_ops import _constexpr_utils as const_utils from ...common import dtype as mstype -from ..._checkparam import Validator as validator -from ..._checkparam import Rel # set graph-level RNG seed _GRAPH_SEED = 0 @@ -204,14 +202,13 @@ def multinomial(inputs, num_sample, replacement=True, seed=0): Note: The rows of input do not need to sum to one (in which case we use the values as weights), but must be non-negative, finite and have a non-zero sum. - Args: - seed (int): Seed data is used as entropy source for Random number engines generating pseudo-random numbers. - Default: 0. - Inputs: - - **input** (Tensor) - the input tensor containing probabilities, must be 1 or 2 dims. - - **num_samples** (int) - number of samples to draw. - - **replacement** (bool, optional) - whether to draw with replacement or not, default True. + Args: + input (Tensor) - the input tensor containing probabilities, must be 1 or 2 dims. + num_samples (int) - number of samples to draw. + replacement (bool, optional) - whether to draw with replacement or not, default True. + seed (int, optional) - used as entropy source for Random number engines generating pseudo-random numbers. + Must be non-negative. Default: 0. Outputs: Tensor. have the same rows with input, each row has num_samples sampled indices. 
@@ -222,21 +219,19 @@ def multinomial(inputs, num_sample, replacement=True, seed=0): """ shape = P.Shape() reshape = P.Reshape() - validator.check_value_type('replacement', replacement, (bool,), None) - validator.check_value_type('num_sample', num_sample, (int,), None) - validator.check_integer("num_sample", num_sample, 0, Rel.GT, None) if inputs.dim() != 1 and inputs.dim() != 2: raise ValueError("inputs dim must be 1d or 2d") if not replacement: + P.Multinomial(replacement=replacement, seed=seed)(inputs, num_sample) if shape(inputs)[-1] < num_sample: raise ValueError("num_sample must be less than shape(input)[-1] without replacement") n_dist = 1 if len(shape(inputs)) > 1: n_dist = shape(inputs)[-2] - random_uniform = P.UniformReal(seed=seed)((n_dist * num_sample,)) + random_uniform = P.UniformReal(seed=seed)((n_dist * shape(inputs)[-1],)) if n_dist != 1: - random_uniform = reshape(random_uniform, (n_dist, num_sample)) + random_uniform = reshape(random_uniform, (n_dist, shape(inputs)[-1])) vals = P.RealDiv()(P.Log()(random_uniform), inputs + 1e-6) _, indices = P.TopK()(vals, num_sample) return indices - return P.Multinomial(seed=seed)(inputs, num_sample) + return P.Multinomial(replacement=replacement, seed=seed)(inputs, num_sample) diff --git a/mindspore/ops/operations/random_ops.py b/mindspore/ops/operations/random_ops.py index 25cb20474b..2c915ba374 100644 --- a/mindspore/ops/operations/random_ops.py +++ b/mindspore/ops/operations/random_ops.py @@ -438,11 +438,12 @@ class Multinomial(PrimitiveWithInfer): but must be non-negative, finite and have a non-zero sum. Args: seed (int): Seed data is used as entropy source for Random number engines generating pseudo-random numbers. - Default: 0. + Must be non-negative. Default: 0. + replacement(bool) - whether to draw with replacement or not. Inputs: - **input** (Tensor[float32]) - the input tensor containing the cumsum of probabilities, must be 1 or 2 dims. - - **num_samples** (int) - number of samples to draw. 
+ - **num_samples** (int32) - number of samples to draw. Outputs: Tensor. have the same rows with input, each row has num_samples sampled indices. @@ -450,13 +451,15 @@ class Multinomial(PrimitiveWithInfer): Examples: >>> input = Tensor([0., 9., 4., 0.], mstype.float32) >>> multinomial = P.Multinomial(seed=10) - >>> output = multinomial(input, 2) + >>> output = multinomial(input, 2, True) """ @prim_attr_register - def __init__(self, seed=0): + def __init__(self, replacement=True, seed=0): """init""" validator.check_value_type("seed", seed, [int], self.name) + validator.check_integer("seed", seed, 0, Rel.GE, self.name) + validator.check_value_type("replacement", replacement, [bool], self.name) self.init_prim_io_names(inputs=['input', 'num_sample'], outputs=['output']) def __infer__(self, inputs, num_samples): @@ -467,7 +470,7 @@ class Multinomial(PrimitiveWithInfer): num_samples_value = num_samples["value"] if num_samples_value is None: raise ValueError(f"For {self.name}, shape nust be const") - validator.check_value_type("num_samples", num_samples_value, [int], self.name) + validator.check_value_type("num_samples", num_samples_value, (int,), self.name) validator.check_integer("num_samples", num_samples_value, 0, Rel.GT, None) y_shape = (num_samples_value,) if len(input_shape) == 2: