diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/multinomial_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/multinomial_gpu_kernel.h index e5d01b4ac6..aea841d376 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/multinomial_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/multinomial_gpu_kernel.h @@ -54,11 +54,15 @@ class MultinomialGpuKernel : public GpuKernel { CHECK_CUDA_RET_WITH_EXCEPT(cudaMalloc(reinterpret_cast(&cflag), sizeof(T)), "cudaMalloc failed."); CHECK_CUDA_RET_WITH_EXCEPT(cudaMallocHost(&flag, sizeof(T)), "cudaMallocHost failed."); CalFloatStatus(input_size_0_ / sizeof(T), input_addr, cflag, reinterpret_cast(stream_ptr)); + CHECK_CUDA_RET_WITH_EXCEPT(cudaStreamSynchronize(reinterpret_cast(stream_ptr)), + "cudaStreamSynchronize failed."); CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpy(flag, cflag, sizeof(T), cudaMemcpyDeviceToHost), "cudaMemcpyAsync failed."); if (*flag > 0) { MS_LOG(EXCEPTION) << "Input is invalid (containing NaN, -inf or inf)"; } CheckNonNeg(input_size_0_ / sizeof(T), input_addr, cflag, reinterpret_cast(stream_ptr)); + CHECK_CUDA_RET_WITH_EXCEPT(cudaStreamSynchronize(reinterpret_cast(stream_ptr)), + "cudaStreamSynchronize failed."); CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpy(flag, cflag, sizeof(T), cudaMemcpyDeviceToHost), "cudaMemcpyAsync failed."); if (*flag > 0) { MS_LOG(EXCEPTION) << "Input is invalid (input element < 0)"; @@ -68,14 +72,16 @@ class MultinomialGpuKernel : public GpuKernel { "cudaMalloc failed."); CumSum(input_addr, cum_sum_input, cum_sum_input, IntToSize(distributions_), IntToSize(categories), 1, IntToSize(categories), 1, false, false, reinterpret_cast(stream_ptr)); + CHECK_CUDA_RET_WITH_EXCEPT(cudaStreamSynchronize(reinterpret_cast(stream_ptr)), + "cudaStreamSynchronize failed."); CheckZero(IntToSize(distributions_), IntToSize(categories), cum_sum_input, cflag, reinterpret_cast(stream_ptr)); + CHECK_CUDA_RET_WITH_EXCEPT(cudaStreamSynchronize(reinterpret_cast(stream_ptr)), + "cudaStreamSynchronize failed."); CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpy(flag, cflag, sizeof(T), cudaMemcpyDeviceToHost), "cudaMemcpyAsync failed."); if (*flag > 0) { MS_LOG(EXCEPTION) << "Input is invalid (sum <= 0)"; } - CHECK_CUDA_RET_WITH_EXCEPT(cudaStreamSynchronize(reinterpret_cast(stream_ptr)), - "cudaStreamSynchronize failed."); Multinomial(seed_, cum_sum_input, num_sample, devStates, output_addr, IntToSize(distributions_), IntToSize(categories), reinterpret_cast(stream_ptr)); diff --git a/mindspore/nn/probability/distribution/__init__.py b/mindspore/nn/probability/distribution/__init__.py index ea6b743e29..bf30bae154 100644 --- a/mindspore/nn/probability/distribution/__init__.py +++ b/mindspore/nn/probability/distribution/__init__.py @@ -25,6 +25,7 @@ from .bernoulli import Bernoulli from .exponential import Exponential from .uniform import Uniform from .geometric import Geometric +from .categorical import Categorical __all__ = ['Distribution', 'TransformedDistribution', @@ -32,4 +33,5 @@ __all__ = ['Distribution', 'Bernoulli', 'Exponential', 'Uniform', + 'Categorical', 'Geometric',] diff --git a/mindspore/nn/probability/distribution/_utils/utils.py b/mindspore/nn/probability/distribution/_utils/utils.py index 723ea8dc83..f1297545c2 100644 --- a/mindspore/nn/probability/distribution/_utils/utils.py +++ b/mindspore/nn/probability/distribution/_utils/utils.py @@ -256,16 +256,16 @@ def check_tensor_type(name, inputs, valid_type): Check if inputs is proper. Args: - inputs: Tensor to be checked. name: inputs name + inputs: Tensor to be checked. Raises: ValueError: if inputs is not a proper Tensor. """ if not isinstance(inputs, Tensor): raise TypeError(f"{name} should be a Tensor") - inputs = P.DType()(inputs) - if inputs not in valid_type: + input_type = P.DType()(inputs) + if input_type not in valid_type: raise TypeError(f"{name} dtype is invalid") def check_type(data_type, value_type, name): diff --git a/mindspore/nn/probability/distribution/categorical.py b/mindspore/nn/probability/distribution/categorical.py new file mode 100644 index 0000000000..81b4152e9f --- /dev/null +++ b/mindspore/nn/probability/distribution/categorical.py @@ -0,0 +1,217 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Categorical Distribution""" +import numpy as np +from mindspore.ops import operations as P +from mindspore.common import dtype as mstype +from .distribution import Distribution +from ._utils.utils import logits_to_probs, probs_to_logits, check_tensor_type, cast_to_tensor + + +class Categorical(Distribution): + """ + Creates a categorical distribution parameterized by either probs or logits (but not both). + + Args: + probs (Tensor, list, numpy.ndarray, Parameter, float): event probabilities. + logits (Tensor, list, numpy.ndarray, Parameter, float): event log-odds. + seed (int): seed to use in sampling. Default: 0. + dtype (mindspore.dtype): type of the distribution. Default: mstype.int32. + name (str): name of the distribution. Default: Categorical. + + Note: + probs must be non-negative, finite and have a non-zero sum, and it will be normalized to sum to 1. + + Examples: + >>> # To initialize a Categorical distribution of prob is [0.5, 0.5] + >>> import mindspore.nn.probability.distribution as msd + >>> b = msd.Categorical(probs = [0.5, 0.5], dtype=mstype.int32) + >>> + >>> # To use Categorical in a network + >>> class net(Cell): + >>> def __init__(self, probs): + >>> super(net, self).__init__(): + >>> self.ca = msd.Categorical(probs=probs, dtype=mstype.int32) + >>> # All the following calls in construct are valid + >>> def construct(self, value): + >>> + >>> # Similar calls can be made to logits + >>> ans = self.ca.probs + >>> # value should be Tensor + >>> ans = self.ca.log_prob(value) + >>> + >>> # Usage of enumerate_support + >>> ans = self.ca.enumerate_support() + >>> + >>> # Usage of entropy + >>> ans = self.ca.entropy() + >>> + >>> # Sample + >>> ans = self.ca.sample() + >>> ans = self.ca.sample((2,3)) + >>> ans = self.ca.sample((2,)) + """ + + def __init__(self, + probs=None, + logits=None, + seed=0, + dtype=mstype.int32, + name="Categorical"): + param = dict(locals()) + super(Categorical, self).__init__(seed, dtype, name, param) + if (probs is None) == (logits is None): + raise ValueError("Either 'prob' or 'logits' must be specified, but not both.") + self.reduce_sum = P.ReduceSum(keep_dims=True) + self.log = P.Log() + self.exp = P.Exp() + self.shape = P.Shape() + self.reshape = P.Reshape() + self.div = P.RealDiv() + self.size = P.Size() + self.mutinomial = P.Multinomial(seed=seed) + self.cast = P.Cast() + self.expandim = P.ExpandDims() + self.gather = P.GatherNd() + self.concat = P.Concat(-1) + if probs is not None: + self._probs = cast_to_tensor(probs, mstype.float32) + input_sum = self.reduce_sum(self._probs, -1) + self._probs = self.div(self._probs, input_sum) + self._logits = probs_to_logits(self._probs) + self._param = self._probs + else: + self._logits = cast_to_tensor(logits, mstype.float32) + input_sum = self.reduce_sum(self.exp(self._logits), -1) + self._logits = self._logits - self.log(input_sum) + self._probs = logits_to_probs(self._logits) + self._param = self._logits + self._num_events = self.shape(self._param)[-1] + self._param2d = self.reshape(self._param, (-1, self._num_events)) + self._batch_shape = self.shape(self._param2d)[0] + + + @property + def logits(self): + """ + Returns the logits. + """ + return self._logits + + @property + def probs(self): + """ + Returns the probability. + """ + return self._probs + + def _sample(self, sample_shape=(1,)): + """ + Sampling. + + Args: + sample_shape (tuple): shape of the sample. Default: (1,). + + Returns: + Tensor, shape is shape(probs)[:-1] + sample_shape + """ + if not isinstance(sample_shape, tuple): + raise ValueError("sample shape must be a tuple") + num_sample = 1 + for i in sample_shape: + num_sample *= i + probs_2d = self.reshape(self._probs, (-1, self._num_events)) + samples = self.mutinomial(probs_2d, num_sample) + extend_shape = sample_shape + if len(self.shape(self._probs)) > 1: + extend_shape = self.shape(self._probs)[:-1] + sample_shape + return self.cast(self.reshape(samples, extend_shape), self.dtype) + + def _broad_cast_shape(self, a, b): + """ + Broadcast Tensor shape. + + Args: + a (Tensor): A Tensor need to Broadcast. + b (Tensor): Another Tensor need to Broadcast. + + Returns: + Tuple, Broadcast shape. + """ + shape_a = self.shape(a) + shape_b = self.shape(b) + size_a = len(shape_a) + size_b = len(shape_b) + if size_a > size_b: + size = size_a + shape_out = list(shape_a) + shape_short = list(shape_b) + diff_size = size_a - size_b + else: + size = size_b + shape_out = list(shape_b) + shape_short = list(shape_a) + diff_size = size_b - size_a + for i in range(diff_size, size): + if shape_out[i] == shape_short[i - diff_size]: + continue + if shape_out[i] == 1 or shape_short[i - diff_size] == 1: + shape_out[i] = shape_out[i] * shape_short[i - diff_size] + else: + raise ValueError(f"Shape {shape_a} and {shape_b} is not broadcastable.") + return tuple(shape_out) + + def _log_prob(self, value): + r""" + Evaluate log probability. + + Args: + value (Tensor): value to be evaluated. The dtype could be mstype.float32, bool, mstype.int32. + """ + if value is not None: + check_tensor_type("value", value, [mstype.float32, bool, mstype.int32]) + value = self.expandim(self.cast(value, mstype.float32), -1) + broad_shape = self._broad_cast_shape(value, self._logits) + broad = P.BroadcastTo(broad_shape) + value = broad(value)[..., :1] + index = cast_to_tensor(np.arange(broad_shape[-1]).astype(np.float32)) + index = self.expandim(index, -1) + index = broad(index)[..., :1] + value = self.concat((index, value)) + value = self.cast(value, mstype.int32) + return self.gather(self._logits, value) + return None + + def _entropy(self): + r""" + Evaluate entropy. + + .. math:: + H(X) = -\sum(logits * probs) + """ + p_log_p = self._logits * self._probs + return self.reduce_sum(-p_log_p, -1) + + def enumerate_support(self, expand=True): + r""" + Enumerate categories. + """ + num_events = self._num_events + values = cast_to_tensor(np.arange(num_events).astype(np.int32), mstype.int32) + values = self.reshape(values, (num_events, 1)) + if expand: + values = P.BroadcastTo((num_events, self._batch_shape))(values) + values = self.cast(values, mstype.int32) + return values diff --git a/mindspore/ops/composite/random_ops.py b/mindspore/ops/composite/random_ops.py index 2d0dccf992..d2a9bac1f1 100644 --- a/mindspore/ops/composite/random_ops.py +++ b/mindspore/ops/composite/random_ops.py @@ -196,7 +196,7 @@ def poisson(shape, mean, seed=0): value = random_poisson(shape, mean) return value -def multinomial(inputs, num_sample=None, replacement=True, seed=0): +def multinomial(inputs, num_sample, replacement=True, seed=0): r""" Returns a tensor sampled from the multinomial probability distribution located in the corresponding row of tensor input. @@ -210,7 +210,7 @@ def multinomial(inputs, num_sample=None, replacement=True, seed=0): Inputs: - **input** (Tensor) - the input tensor containing probabilities, must be 1 or 2 dims. - - **num_samples** (int) - number of samples to draw, default None. + - **num_samples** (int) - number of samples to draw. - **replacement** (bool, optional) - whether to draw with replacement or not, default True. Outputs: @@ -233,9 +233,7 @@ def multinomial(inputs, num_sample=None, replacement=True, seed=0): n_dist = 1 if len(shape(inputs)) > 1: n_dist = shape(inputs)[-2] - a = Tensor(0.0, mstype.float32) - b = Tensor(1.0, mstype.float32) - random_uniform = P.UniformReal(seed=seed)((n_dist * num_sample,), a, b) + random_uniform = P.UniformReal(seed=seed)((n_dist * num_sample,)) if n_dist != 1: random_uniform = reshape(random_uniform, (n_dist, num_sample)) vals = P.RealDiv()(P.Log()(random_uniform), inputs + 1e-6) diff --git a/tests/st/ops/gpu/test_categorical_op.py b/tests/st/ops/gpu/test_categorical_op.py new file mode 100644 index 0000000000..9c124c5b20 --- /dev/null +++ b/tests/st/ops/gpu/test_categorical_op.py @@ -0,0 +1,36 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import mindspore.context as context +import mindspore.nn.probability.distribution as msd + +context.set_context(device_target='GPU') + +def test_categorical1(): + cat1 = msd.Categorical(probs=[[0.9, 0.2], [0.9, 0.2]]) + cat1out1 = cat1.sample((1,)) + cat1out2 = cat1.sample((3, 2)) + cat1out3 = cat1.sample((6,)) + assert cat1out1.asnumpy().shape == (2, 1) + assert cat1out2.asnumpy().shape == (2, 3, 2) + assert cat1out3.asnumpy().shape == (2, 6) + + cat1 = msd.Categorical(probs=[0.9, 0.2]) + cat1out1 = cat1.sample((1,)) + cat1out2 = cat1.sample((3, 2)) + cat1out3 = cat1.sample((6,)) + assert cat1out1.asnumpy().shape == (1,) + assert cat1out2.asnumpy().shape == (3, 2) + assert cat1out3.asnumpy().shape == (6,)