diff --git a/mindspore/nn/probability/distribution/_utils/utils.py b/mindspore/nn/probability/distribution/_utils/utils.py index 604fa54494..30f127b761 100644 --- a/mindspore/nn/probability/distribution/_utils/utils.py +++ b/mindspore/nn/probability/distribution/_utils/utils.py @@ -158,6 +158,22 @@ def check_prob(p): if not comp.all(): raise ValueError('Probabilities should be less than one') +def check_sum_equal_one(probs): + """ + Used in categorical distribution. Check that probabilities along the last axis sum to one. + """ + prob_sum = np.sum(probs.asnumpy(), axis=-1) + # Compare with a tolerance: sums of floating-point probabilities are rarely exactly 1.0. + comp = np.isclose(prob_sum, 1.0) + if not comp.all(): + raise ValueError('Probabilities for each category should sum to one for Categorical distribution.') + +def check_rank(probs): + """ + Used in categorical distribution. Check that probs has rank >= 1. + """ + if probs.asnumpy().ndim == 0: + raise ValueError('probs for Categorical distribution must have rank >= 1.') def logits_to_probs(logits, is_binary=False): """ diff --git a/mindspore/nn/probability/distribution/categorical.py b/mindspore/nn/probability/distribution/categorical.py index 970647b5dc..0fce365557 100644 --- a/mindspore/nn/probability/distribution/categorical.py +++ b/mindspore/nn/probability/distribution/categorical.py @@ -13,108 +13,150 @@ # limitations under the License. 
# ============================================================================ """Categorical Distribution""" +import numpy as np from mindspore.ops import operations as P +from mindspore.ops import composite as C import mindspore.nn as nn from mindspore.common import dtype as mstype from .distribution import Distribution -from ._utils.utils import logits_to_probs, probs_to_logits, check_type, cast_to_tensor, \ - raise_probs_logits_error +from ._utils.utils import check_prob, check_sum_equal_one, check_type, check_rank,\ + check_distribution_name, raise_not_implemented_util +from ._utils.custom_ops import exp_generic, log_generic, broadcast_to class Categorical(Distribution): """ - Create a categorical distribution parameterized by either probabilities or logits (but not both). + Create a categorical distribution parameterized by event probabilities. Args: probs (Tensor, list, numpy.ndarray, Parameter): Event probabilities. - logits (Tensor, list, numpy.ndarray, Parameter, float): Event log-odds. seed (int): The global seed is used in sampling. Global seed is used if it is None. Default: None. dtype (mindspore.dtype): The type of the distribution. Default: mstype.int32. name (str): The name of the distribution. Default: Categorical. Note: - `probs` must be non-negative, finite and have a non-zero sum, and it will be normalized to sum to 1. + `probs` must have rank at least 1; its values must be proper probabilities that sum to 1. 
Examples: - >>> # To initialize a Categorical distribution of prob is [0.5, 0.5] + >>> # To initialize a Categorical distribution of probs [0.5, 0.5] >>> import mindspore.nn.probability.distribution as msd >>> b = msd.Categorical(probs = [0.5, 0.5], dtype=mstype.int32) >>> - >>> # To use Categorical in a network + >>> # To use a Categorical distribution in a network >>> class net(Cell): >>> def __init__(self, probs): - >>> super(net, self).__init__(): - >>> self.ca = msd.Categorical(probs=probs, dtype=mstype.int32) + >>> super(net, self).__init__() + >>> self.ca = msd.Categorical(probs=[0.2, 0.8], dtype=mstype.int32) + >>> self.ca1 = msd.Categorical(dtype=mstype.int32) + >>> >>> # All the following calls in construct are valid >>> def construct(self, value): >>> - >>> # Similar calls can be made to logits - >>> ans = self.ca.probs - >>> # value must be Tensor(mstype.float32, bool, mstype.int32) - >>> ans = self.ca.log_prob(value) + >>> # Private interfaces of probability functions corresponding to public interfaces, including + >>> # `prob`, `log_prob`, `cdf`, `log_cdf`, `survival_function`, and `log_survival`, are the same as follows. + >>> # Args: + >>> # value (Tensor): the value to be evaluated. + >>> # probs (Tensor): event probabilities. Default: self.probs. + >>> + >>> # Examples of `prob`. + >>> # Similar calls can be made to other probability functions + >>> # by replacing `prob` by the name of the function. + >>> ans = self.ca.prob(value) + >>> # Evaluate `prob` with respect to distribution b. + >>> ans = self.ca.prob(value, probs_b) + >>> # `probs` must be passed in during function calls. + >>> ans = self.ca1.prob(value, probs_a) + >>> + >>> # Functions `mean`, `sd`, `var`, and `entropy` have the same arguments. + >>> # Args: + >>> # probs (Tensor): event probabilities. Default: self.probs. >>> - >>> # Usage of enumerate_support - >>> ans = self.ca.enumerate_support() + >>> # Examples of `mean`. `sd`, `var`, and `entropy` are similar. 
+ >>> ans = self.ca.mean() # return 0.8 + >>> ans = self.ca.mean(probs_b) + >>> # `probs` must be passed in during function calls. + >>> ans = self.ca1.mean(probs_a) >>> - >>> # Usage of entropy - >>> ans = self.ca.entropy() + >>> # Interfaces of `kl_loss` and `cross_entropy` are the same as follows: + >>> # Args: + >>> # dist (str): the name of the distribution. Only 'Categorical' is supported. + >>> # probs_b (Tensor): event probabilities of distribution b. + >>> # probs (Tensor): event probabilities of distribution a. Default: self.probs. >>> - >>> # Sample + >>> # Examples of kl_loss. `cross_entropy` is similar. + >>> ans = self.ca.kl_loss('Categorical', probs_b) + >>> ans = self.ca.kl_loss('Categorical', probs_b, probs_a) + >>> # An additional `probs` must be passed in. + >>> ans = self.ca1.kl_loss('Categorical', probs_b, probs_a) + >>> + >>> # Examples of `sample`. + >>> # Args: + >>> # shape (tuple): the shape of the sample. Default: (). + >>> # probs (Tensor): event probabilities. Default: self.probs. 
>>> ans = self.ca.sample() >>> ans = self.ca.sample((2,3)) - >>> ans = self.ca.sample((2,)) + >>> ans = self.ca.sample((2,3), probs_b) + >>> ans = self.ca1.sample((2,3), probs_a) """ def __init__(self, probs=None, - logits=None, seed=None, dtype=mstype.int32, name="Categorical"): param = dict(locals()) - param['param_dict'] = {'probs': probs, 'logits': logits} + param['param_dict'] = {'probs': probs} valid_dtype = mstype.int_type check_type(dtype, valid_dtype, "Categorical") super(Categorical, self).__init__(seed, dtype, name, param) - if (probs is None) == (logits is None): - raise_probs_logits_error() - self.reduce_sum = P.ReduceSum(keep_dims=True) - self.reduce_sum1 = P.ReduceSum(keep_dims=False) - self.log = P.Log() - self.exp = P.Exp() - self.shape = P.Shape() - self.reshape = P.Reshape() - self.div = P.RealDiv() - self.size = P.Size() - self.mutinomial = P.Multinomial(seed=self.seed) + + self._probs = self._add_parameter(probs, 'probs') + if self.probs is not None: + check_rank(self.probs) + check_prob(self.probs) + check_sum_equal_one(self.probs) + + # update is_scalar_batch and broadcast_shape + # drop one dimension + if self.probs.shape[:-1] == (): + self._is_scalar_batch = True + self._broadcast_shape = self._broadcast_shape[:-1] + + self.argmax = P.Argmax() + self.broadcast = broadcast_to self.cast = P.Cast() - self.expandim = P.ExpandDims() - self.gather = P.GatherNd() + self.clip_by_value = C.clip_by_value self.concat = P.Concat(-1) + self.cumsum = P.CumSum() + self.dtypeop = P.DType() + self.exp = exp_generic + self.expand_dim = P.ExpandDims() + self.fill = P.Fill() + self.floor = P.Floor() + self.gather = P.GatherNd() + self.less = P.Less() + self.log = log_generic + self.log_softmax = P.LogSoftmax() + self.logicor = P.LogicalOr() + self.multinomial = P.Multinomial(seed=self.seed) + self.reshape = P.Reshape() + self.reduce_sum = P.ReduceSum(keep_dims=True) + self.select = P.Select() + self.shape = P.Shape() + self.softmax = P.Softmax() + self.squeeze 
= P.Squeeze() + self.square = P.Square() self.transpose = P.Transpose() - if probs is not None: - self._probs = cast_to_tensor(probs, mstype.float32) - input_sum = self.reduce_sum(self._probs, -1) - self._probs = self.div(self._probs, input_sum) - self._logits = probs_to_logits(self._probs) - self._param = self._probs - else: - self._logits = cast_to_tensor(logits, mstype.float32) - input_sum = self.reduce_sum(self.exp(self._logits), -1) - self._logits = self._logits - self.log(input_sum) - self._probs = logits_to_probs(self._logits) - self._param = self._logits - self._num_events = self.shape(self._param)[-1] - self._param2d = self.reshape(self._param, (-1, self._num_events)) - self._batch_shape = self.shape(self._param)[:-1] - self._batch_shape_n = (1,) * len(self._batch_shape) - @property - def logits(self): - """ - Return the logits. - """ - return self._logits + self.index_type = mstype.int32 + + + def extend_repr(self): + if self.is_scalar_batch: + str_info = f'probs = {self.probs}' + else: + str_info = f'batch_shape = {self._broadcast_shape}' + return str_info @property def probs(self): @@ -123,68 +165,214 @@ class Categorical(Distribution): """ return self._probs - def _sample(self, sample_shape=()): + def _mean(self, probs=None): + r""" + .. math:: + E[X] = \sum_{i=0}^{num_classes-1} i*p_i """ - Sampling. + probs = self._check_param_type(probs) + num_classes = self.shape(probs)[-1] + index = nn.Range(0., num_classes, 1.)() + return self.reduce_sum(index * probs, -1) + + def _mode(self, probs=None): + probs = self._check_param_type(probs) + mode = self.cast(self.argmax(probs), self.dtype) + return self.squeeze(mode) + + def _var(self, probs=None): + r""" + .. 
math:: + VAR(X) = E[X^{2}] - (E[X])^{2} + """ + probs = self._check_param_type(probs) + num_classes = self.shape(probs)[-1] + index = nn.Range(0., num_classes, 1.)() + return self.reduce_sum(self.square(index) * probs, -1) -\ + self.square(self.reduce_sum(index * probs, -1)) + + def _entropy(self, probs=None): + r""" + Evaluate entropy. + + .. math:: + H(X) = -\sum(logits * probs) + """ + probs = self._check_param_type(probs) + logits = self.log(probs) + return self.squeeze(-self.reduce_sum(logits * probs, -1)) + + def _kl_loss(self, dist, probs_b, probs=None): + """ + Evaluate KL divergence KL(a || b) between Categorical distributions. Args: - sample_shape (tuple): The shape of the sample. Default: (). + dist (str): The type of the distributions. Should be "Categorical" in this case. + probs_b (Tensor): Event probabilities of distribution b. + probs (Tensor): Event probabilities of distribution a. Default: self.probs. + """ + check_distribution_name(dist, 'Categorical') + probs_b = self._check_value(probs_b, 'probs_b') + probs_b = self.cast(probs_b, self.parameter_type) + probs_a = self._check_param_type(probs) + logits_a = self.log(probs_a) + logits_b = self.log(probs_b) + return self.squeeze(self.reduce_sum( + self.softmax(logits_a) * (self.log_softmax(logits_a) - (self.log_softmax(logits_b))), -1)) - Returns: - Tensor, shape is shape(probs)[:-1] + sample_shape + def _cross_entropy(self, dist, probs_b, probs=None): """ - self.checktuple(sample_shape, 'shape') - num_sample = 1 - for i in sample_shape: - num_sample *= i - probs_2d = self.reshape(self._probs, (-1, self._num_events)) - samples = self.mutinomial(probs_2d, num_sample) - samples = self.transpose(samples, (1, 0)) - extend_shape = sample_shape - if len(self.shape(self._probs)) > 1: - extend_shape = sample_shape + self.shape(self._probs)[:-1] - return self.cast(self.reshape(samples, extend_shape), self.dtype) - - def _log_prob(self, value): + Evaluate cross entropy between Categorical distributions. 
+ + Args: + dist (str): The type of the distributions. Should be "Categorical" in this case. + probs_b (Tensor): Event probabilities of distribution b. + probs (Tensor): Event probabilities of distribution a. Default: self.probs. + """ + check_distribution_name(dist, 'Categorical') + return self._entropy(probs) + self._kl_loss(dist, probs_b, probs) + + def _log_prob(self, value, probs=None): r""" Evaluate log probability. Args: value (Tensor): The value to be evaluated. + probs (Tensor): Event probabilities. Default: self.probs. """ value = self._check_value(value, 'value') - value = self.expandim(self.cast(value, mstype.float32), -1) - broad_shape = self.shape(value + self._logits) - broad = P.BroadcastTo(broad_shape) - logits_pmf = self.reshape(broad(self._logits), (-1, broad_shape[-1])) - value = self.reshape(broad(value)[..., :1], (-1, 1)) - index = nn.Range(0., self.shape(value)[0], 1)() - index = self.reshape(index, (-1, 1)) - value = self.concat((index, value)) - value = self.cast(value, mstype.int32) - return self.reshape(self.gather(logits_pmf, value), broad_shape[:-1]) - - def _entropy(self): - r""" - Evaluate entropy. + value = self.cast(value, self.parameter_type) + probs = self._check_param_type(probs) + logits = self.log(probs) + + # handle the case when value is of shape () and probs is a scalar batch + drop_dim = False + if self.shape(value) == () and self.shape(probs)[:-1] == (): + drop_dim = True + # manually add one more dimension: () -> (1,) + # drop this dimension before return + value = self.expand_dim(value, -1) + + value = self.expand_dim(value, -1) + + broadcast_shape_tensor = logits * value + broadcast_shape = self.shape(broadcast_shape_tensor) + # broadcast_shape (N, C) + num_classes = broadcast_shape[-1] + label_shape = broadcast_shape[:-1] - .. 
math:: - H(X) = -\sum(logits * probs) - """ - p_log_p = self._logits * self._probs - return self.reduce_sum1(-p_log_p, -1) + # broadcasting logits and value + # logit_pmf shape (num of labels, C) + logits = self.broadcast(logits, broadcast_shape_tensor) + value = self.broadcast(value, broadcast_shape_tensor)[..., :1] - def enumerate_support(self, expand=True): + # flatten value to shape (number of labels, 1) + # clip value to be in range from 0 to num_classes -1 and cast into int32 + value = self.reshape(value, (-1, 1)) + out_of_bound = self.squeeze(self.logicor(\ + self.less(value, 0.0), self.less(num_classes-1, value))) + value_clipped = self.clip_by_value(value, 0.0, num_classes - 1) + value_clipped = self.cast(value_clipped, self.index_type) + # create index from 0 ... NumOfLabels + index = self.reshape(nn.Range(0, self.shape(value)[0], 1)(), (-1, 1)) + index = self.concat((index, value_clipped)) + + # index into logit_pmf, fill in out_of_bound places with -inf + # reshape into label shape N + logits_pmf = self.gather(self.reshape(logits, (-1, num_classes)), index) + neg_inf = self.fill(self.dtypeop(logits_pmf), self.shape(logits_pmf), -np.inf) + logits_pmf = self.select(out_of_bound, neg_inf, logits_pmf) + ans = self.reshape(logits_pmf, label_shape) + if drop_dim: + return self.squeeze(ans) + return ans + + def _cdf(self, value, probs=None): r""" - Enumerate categories. - - Args: - expand (Bool): Whether to expand. - """ - num_events = self._num_events - values = nn.Range(0., num_events, 1)() - values = self.reshape(values, (num_events,) + self._batch_shape_n) - if expand: - values = P.BroadcastTo((num_events,) + self._batch_shape)(values) - values = self.cast(values, mstype.int32) - return values + Cumulative distribution function (cdf) of Categorical distributions. + + Args: + value (Tensor): The value to be evaluated. + probs (Tensor): Event probabilities. Default: self.probs. 
+ """ + value = self._check_value(value, 'value') + value = self.cast(value, self.parameter_type) + value = self.floor(value) + probs = self._check_param_type(probs) + + # handle the case when value is of shape () and probs is a scalar batch + drop_dim = False + if self.shape(value) == () and self.shape(probs)[:-1] == (): + drop_dim = True + # manually add one more dimension: () -> (1,) + # drop this dimension before return + value = self.expand_dim(value, -1) + + value = self.expand_dim(value, -1) + + broadcast_shape_tensor = probs * value + broadcast_shape = self.shape(broadcast_shape_tensor) + # broadcast_shape (N, C) + num_classes = broadcast_shape[-1] + label_shape = broadcast_shape[:-1] + + probs = self.broadcast(probs, broadcast_shape_tensor) + value = self.broadcast(value, broadcast_shape_tensor)[..., :1] + + # flatten value to shape (number of labels, 1) + value = self.reshape(value, (-1, 1)) + + # drop one dimension to match cdf + # clip value to be in range from 0 to num_classes -1 and cast into int32 + less_than_zero = self.squeeze(self.less(value, 0.0)) + value_clipped = self.clip_by_value(value, 0.0, num_classes - 1) + value_clipped = self.cast(value_clipped, self.index_type) + + index = self.reshape(nn.Range(0, self.shape(value)[0], 1)(), (-1, 1)) + index = self.concat((index, value_clipped)) + + # reshape probs and fill less_than_zero places with 0 + probs = self.reshape(probs, (-1, num_classes)) + cdf = self.gather(self.cumsum(probs, 1), index) + zeros = self.fill(self.dtypeop(cdf), self.shape(cdf), 0.0) + cdf = self.select(less_than_zero, zeros, cdf) + cdf = self.reshape(cdf, label_shape) + + if drop_dim: + return self.squeeze(cdf) + return cdf + + def _sample(self, shape=(), probs=None): + """ + Sampling. + + Args: + shape (tuple): The shape of the sample. Default: (). + probs (Tensor): Event probabilities. Default: self.probs. 
+ + Returns: + Tensor, whose shape is `shape` followed by the batch shape shape(probs)[:-1]. + """ + if self.device_target == 'Ascend': + raise_not_implemented_util('On d backend, sample', self.name) + shape = self.checktuple(shape, 'shape') + probs = self._check_param_type(probs) + num_classes = self.shape(probs)[-1] + batch_shape = self.shape(probs)[:-1] + + sample_shape = shape + batch_shape + drop_dim = False + if sample_shape == (): + drop_dim = True + sample_shape = (1,) + + probs_2d = self.reshape(probs, (-1, num_classes)) + sample_tensor = self.fill(self.dtype, shape, 1.0) + sample_tensor = self.reshape(sample_tensor, (-1, 1)) + num_sample = self.shape(sample_tensor)[0] + samples = self.multinomial(probs_2d, num_sample) + samples = self.squeeze(self.transpose(samples, (1, 0))) + samples = self.cast(self.reshape(samples, sample_shape), self.dtype) + if drop_dim: + return self.squeeze(samples) + return samples diff --git a/mindspore/nn/probability/distribution/distribution.py b/mindspore/nn/probability/distribution/distribution.py index 0547810c4e..ea3d8a3d43 100644 --- a/mindspore/nn/probability/distribution/distribution.py +++ b/mindspore/nn/probability/distribution/distribution.py @@ -96,6 +96,7 @@ class Distribution(Cell): self._set_cross_entropy() self.context_mode = context.get_context('mode') + self.device_target = context.get_context('device_target') self.checktuple = CheckTuple() self.checktensor = CheckTensor() self.broadcast = broadcast_to diff --git a/tests/st/probability/distribution/test_categorical.py b/tests/st/probability/distribution/test_categorical.py new file mode 100644 index 0000000000..7dd972749a --- /dev/null +++ b/tests/st/probability/distribution/test_categorical.py @@ -0,0 +1,273 @@ +# Copyright 2019 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""test cases for cat distribution""" +import numpy as np +import pytest +from scipy import stats +import mindspore.context as context +import mindspore.nn as nn +import mindspore.nn.probability.distribution as msd +from mindspore import Tensor +from mindspore import dtype + +context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + +class Prob(nn.Cell): + """ + Test class: probability of categorical distribution. + """ + def __init__(self): + super(Prob, self).__init__() + self.c = msd.Categorical([0.7, 0.3], dtype=dtype.int32) + + def construct(self, x_): + return self.c.prob(x_) + +def test_pmf(): + """ + Test pmf. + """ + expect_pmf = [0.7, 0.3, 0.7, 0.3, 0.3] + pmf = Prob() + x_ = Tensor(np.array([0, 1, 0, 1, 1]).astype(np.int32), dtype=dtype.float32) + output = pmf(x_) + tol = 1e-6 + assert (np.abs(output.asnumpy() - expect_pmf) < tol).all() + + +class LogProb(nn.Cell): + """ + Test class: log probability of categorical distribution. + """ + def __init__(self): + super(LogProb, self).__init__() + self.c = msd.Categorical([0.7, 0.3], dtype=dtype.int32) + + def construct(self, x_): + return self.c.log_prob(x_) + +def test_log_likelihood(): + """ + Test log_pmf. 
+ """ + expect_logpmf = np.log([0.7, 0.3, 0.7, 0.3, 0.3]) + logprob = LogProb() + x_ = Tensor(np.array([0, 1, 0, 1, 1]).astype(np.int32), dtype=dtype.float32) + output = logprob(x_) + tol = 1e-6 + assert (np.abs(output.asnumpy() - expect_logpmf) < tol).all() + +class KL(nn.Cell): + """ + Test class: kl_loss between categorical distributions. + """ + def __init__(self): + super(KL, self).__init__() + self.c = msd.Categorical([0.7, 0.3], dtype=dtype.int32) + + def construct(self, x_): + return self.c.kl_loss('Categorical', x_) + +def test_kl_loss(): + """ + Test kl_loss. + """ + kl_loss = KL() + output = kl_loss(Tensor([0.7, 0.3], dtype=dtype.float32)) + tol = 1e-6 + assert (np.abs(output.asnumpy()) < tol).all() + +class Sampling(nn.Cell): + """ + Test class: sampling of categorical distribution. + """ + def __init__(self): + super(Sampling, self).__init__() + self.c = msd.Categorical([0.2, 0.1, 0.7], dtype=dtype.int32) + self.shape = (2, 3) + + def construct(self): + return self.c.sample(self.shape) + +def test_sample(): + """ + Test sample. + """ + with pytest.raises(NotImplementedError): + sample = Sampling() + sample() + +class Basics(nn.Cell): + """ + Test class: mean/var/mode of categorical distribution. + """ + def __init__(self): + super(Basics, self).__init__() + self.c = msd.Categorical([0.2, 0.1, 0.7], dtype=dtype.int32) + + def construct(self): + return self.c.mean(), self.c.var(), self.c.mode() + +def test_basics(): + """ + Test mean/variance/mode. + """ + basics = Basics() + mean, var, mode = basics() + expect_mean = 0 * 0.2 + 1 * 0.1 + 2 * 0.7 + expect_var = 0 * 0.2 + 1 * 0.1 + 4 * 0.7 - (expect_mean * expect_mean) + expect_mode = 2 + tol = 1e-6 + assert (np.abs(mean.asnumpy() - expect_mean) < tol).all() + assert (np.abs(var.asnumpy() - expect_var) < tol).all() + assert (np.abs(mode.asnumpy() - expect_mode) < tol).all() + + +class CDF(nn.Cell): + """ + Test class: cdf of categorical distributions. 
+ """ + def __init__(self): + super(CDF, self).__init__() + self.c = msd.Categorical([0.7, 0.3], dtype=dtype.int32) + + def construct(self, x_): + return self.c.cdf(x_) + +def test_cdf(): + """ + Test cdf. + """ + expect_cdf = [0.7, 0.7, 1, 0.7, 1] + x_ = Tensor(np.array([0, 0, 1, 0, 1]).astype(np.int32), dtype=dtype.float32) + cdf = CDF() + output = cdf(x_) + tol = 1e-6 + assert (np.abs(output.asnumpy() - expect_cdf) < tol).all() + +class LogCDF(nn.Cell): + """ + Test class: log cdf of categorical distributions. + """ + def __init__(self): + super(LogCDF, self).__init__() + self.c = msd.Categorical([0.7, 0.3], dtype=dtype.int32) + + def construct(self, x_): + return self.c.log_cdf(x_) + +def test_logcdf(): + """ + Test log_cdf. + """ + expect_logcdf = np.log([0.7, 0.7, 1, 0.7, 1]) + x_ = Tensor(np.array([0, 0, 1, 0, 1]).astype(np.int32), dtype=dtype.float32) + logcdf = LogCDF() + output = logcdf(x_) + tol = 1e-6 + assert (np.abs(output.asnumpy() - expect_logcdf) < tol).all() + + +class SF(nn.Cell): + """ + Test class: survival function of categorical distributions. + """ + def __init__(self): + super(SF, self).__init__() + self.c = msd.Categorical([0.7, 0.3], dtype=dtype.int32) + + def construct(self, x_): + return self.c.survival_function(x_) + +def test_survival(): + """ + Test survival function. + """ + expect_survival = [0.3, 0., 0., 0.3, 0.3] + x_ = Tensor(np.array([0, 1, 1, 0, 0]).astype(np.int32), dtype=dtype.float32) + sf = SF() + output = sf(x_) + tol = 1e-6 + assert (np.abs(output.asnumpy() - expect_survival) < tol).all() + + +class LogSF(nn.Cell): + """ + Test class: log survival function of categorical distributions. + """ + def __init__(self): + super(LogSF, self).__init__() + self.c = msd.Categorical([0.7, 0.3], dtype=dtype.int32) + + def construct(self, x_): + return self.c.log_survival(x_) + +def test_log_survival(): + """ + Test log survival function. 
+ """ + expect_logsurvival = np.log([1., 0.3, 0.3, 0.3, 0.3]) + x_ = Tensor(np.array([-0.1, 0, 0, 0.5, 0.5]).astype(np.float32), dtype=dtype.float32) + log_sf = LogSF() + output = log_sf(x_) + tol = 1e-6 + assert (np.abs(output.asnumpy() - expect_logsurvival) < tol).all() + +class EntropyH(nn.Cell): + """ + Test class: entropy of categorical distributions. + """ + def __init__(self): + super(EntropyH, self).__init__() + self.c = msd.Categorical([0.7, 0.3], dtype=dtype.int32) + + def construct(self): + return self.c.entropy() + +def test_entropy(): + """ + Test entropy. + """ + cat_benchmark = stats.multinomial(n=1, p=[0.7, 0.3]) + expect_entropy = cat_benchmark.entropy().astype(np.float32) + entropy = EntropyH() + output = entropy() + tol = 1e-6 + assert (np.abs(output.asnumpy() - expect_entropy) < tol).all() + +class CrossEntropy(nn.Cell): + """ + Test class: cross entropy between categorical distributions. + """ + def __init__(self): + super(CrossEntropy, self).__init__() + self.c = msd.Categorical([0.7, 0.3], dtype=dtype.int32) + + def construct(self, x_): + entropy = self.c.entropy() + kl_loss = self.c.kl_loss('Categorical', x_) + h_sum_kl = entropy + kl_loss + cross_entropy = self.c.cross_entropy('Categorical', x_) + return h_sum_kl - cross_entropy + +def test_cross_entropy(): + """ + Test cross_entropy. 
+ """ + cross_entropy = CrossEntropy() + prob = Tensor([0.7, 0.3], dtype=dtype.float32) + diff = cross_entropy(prob) + tol = 1e-6 + assert (np.abs(diff.asnumpy()) < tol).all() diff --git a/tests/ut/python/nn/probability/distribution/test_categorical.py b/tests/ut/python/nn/probability/distribution/test_categorical.py new file mode 100644 index 0000000000..346bb369ff --- /dev/null +++ b/tests/ut/python/nn/probability/distribution/test_categorical.py @@ -0,0 +1,249 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Test nn.probability.distribution.Categorical. +""" +import numpy as np +import pytest + +import mindspore.nn as nn +import mindspore.nn.probability.distribution as msd +from mindspore import dtype +from mindspore import Tensor + + +def test_arguments(): + """ + Args passing during initialization. + """ + c = msd.Categorical() + assert isinstance(c, msd.Distribution) + c = msd.Categorical([0.1, 0.9], dtype=dtype.int32) + assert isinstance(c, msd.Distribution) + + +def test_type(): + with pytest.raises(TypeError): + msd.Categorical([0.1], dtype=dtype.bool_) + + +def test_name(): + with pytest.raises(TypeError): + msd.Categorical([0.1], name=1.0) + + +def test_seed(): + with pytest.raises(TypeError): + msd.Categorical([0.1], seed='seed') + + +def test_prob(): + """ + Invalid probability. + """ + with pytest.raises(ValueError): + msd.Categorical([-0.1], dtype=dtype.int32) + with pytest.raises(ValueError): + msd.Categorical([1.1], dtype=dtype.int32) + with pytest.raises(ValueError): + msd.Categorical([0.0], dtype=dtype.int32) + with pytest.raises(ValueError): + msd.Categorical([1.0], dtype=dtype.int32) + +def test_categorical_sum(): + """ + Invalid probabilities. + """ + with pytest.raises(ValueError): + msd.Categorical([[0.1, 0.2], [0.4, 0.6]], dtype=dtype.int32) + with pytest.raises(ValueError): + msd.Categorical([[0.5, 0.7], [0.6, 0.6]], dtype=dtype.int32) + +def test_rank(): + """ + Rank dimension less than 1. 
+ """ + with pytest.raises(ValueError): + msd.Categorical(0.2, dtype=dtype.int32) + with pytest.raises(ValueError): + msd.Categorical(np.array(0.3).astype(np.float32), dtype=dtype.int32) + with pytest.raises(ValueError): + msd.Categorical(Tensor(np.array(0.3).astype(np.float32)), dtype=dtype.int32) + +class CategoricalProb(nn.Cell): + """ + Categorical distribution: initialize with probs. + """ + + def __init__(self): + super(CategoricalProb, self).__init__() + self.c = msd.Categorical([0.7, 0.3], dtype=dtype.int32) + + def construct(self, value): + prob = self.c.prob(value) + log_prob = self.c.log_prob(value) + cdf = self.c.cdf(value) + log_cdf = self.c.log_cdf(value) + sf = self.c.survival_function(value) + log_sf = self.c.log_survival(value) + return prob + log_prob + cdf + log_cdf + sf + log_sf + + +def test_categorical_prob(): + """ + Test probability functions: passing value through construct. + """ + net = CategoricalProb() + value = Tensor([0, 1, 0, 1, 0], dtype=dtype.float32) + ans = net(value) + assert isinstance(ans, Tensor) + + +class CategoricalProb1(nn.Cell): + """ + Categorical distribution: initialize without probs. + """ + + def __init__(self): + super(CategoricalProb1, self).__init__() + self.c = msd.Categorical(dtype=dtype.int32) + + def construct(self, value, probs): + prob = self.c.prob(value, probs) + log_prob = self.c.log_prob(value, probs) + cdf = self.c.cdf(value, probs) + log_cdf = self.c.log_cdf(value, probs) + sf = self.c.survival_function(value, probs) + log_sf = self.c.log_survival(value, probs) + return prob + log_prob + cdf + log_cdf + sf + log_sf + + +def test_categorical_prob1(): + """ + Test probability functions: passing value/probs through construct. 
+ """ + net = CategoricalProb1() + value = Tensor([0, 1, 0, 1, 0], dtype=dtype.float32) + probs = Tensor([0.3, 0.7], dtype=dtype.float32) + ans = net(value, probs) + assert isinstance(ans, Tensor) + + +class CategoricalKl(nn.Cell): + """ + Test class: kl_loss between Categorical distributions. + """ + + def __init__(self): + super(CategoricalKl, self).__init__() + self.c1 = msd.Categorical([0.2, 0.2, 0.6], dtype=dtype.int32) + self.c2 = msd.Categorical(dtype=dtype.int32) + + def construct(self, probs_b, probs_a): + kl1 = self.c1.kl_loss('Categorical', probs_b) + kl2 = self.c2.kl_loss('Categorical', probs_b, probs_a) + return kl1 + kl2 + + +def test_kl(): + """ + Test kl_loss function. + """ + net = CategoricalKl() + probs_b = Tensor([0.3, 0.1, 0.6], dtype=dtype.float32) + probs_a = Tensor([0.7, 0.2, 0.1], dtype=dtype.float32) + ans = net(probs_b, probs_a) + assert isinstance(ans, Tensor) + + +class CategoricalCrossEntropy(nn.Cell): + """ + Test class: cross_entropy of Categorical distribution. + """ + + def __init__(self): + super(CategoricalCrossEntropy, self).__init__() + self.c1 = msd.Categorical([0.1, 0.7, 0.2], dtype=dtype.int32) + self.c2 = msd.Categorical(dtype=dtype.int32) + + def construct(self, probs_b, probs_a): + h1 = self.c1.cross_entropy('Categorical', probs_b) + h2 = self.c2.cross_entropy('Categorical', probs_b, probs_a) + return h1 + h2 + + +def test_cross_entropy(): + """ + Test cross_entropy between Categorical distributions. + """ + net = CategoricalCrossEntropy() + probs_b = Tensor([0.3, 0.1, 0.6], dtype=dtype.float32) + probs_a = Tensor([0.7, 0.2, 0.1], dtype=dtype.float32) + ans = net(probs_b, probs_a) + assert isinstance(ans, Tensor) + + +class CategoricalConstruct(nn.Cell): + """ + Categorical distribution: going through construct. 
+ """ + + def __init__(self): + super(CategoricalConstruct, self).__init__() + self.c = msd.Categorical([0.1, 0.8, 0.1], dtype=dtype.int32) + self.c1 = msd.Categorical(dtype=dtype.int32) + + def construct(self, value, probs): + prob = self.c('prob', value) + prob1 = self.c('prob', value, probs) + prob2 = self.c1('prob', value, probs) + return prob + prob1 + prob2 + +def test_categorical_construct(): + """ + Test probability function going through construct. + """ + net = CategoricalConstruct() + value = Tensor([0, 1, 2, 0, 0], dtype=dtype.float32) + probs = Tensor([0.5, 0.4, 0.1], dtype=dtype.float32) + ans = net(value, probs) + assert isinstance(ans, Tensor) + + +class CategoricalBasics(nn.Cell): + """ + Test class: basic mean/var/mode/entropy function. + """ + + def __init__(self): + super(CategoricalBasics, self).__init__() + self.c = msd.Categorical([0.2, 0.7, 0.1], dtype=dtype.int32) + self.c1 = msd.Categorical(dtype=dtype.int32) + + def construct(self, probs): + basics1 = self.c.mean() + self.c.var() + self.c.mode() + self.c.entropy() + basics2 = self.c1.mean(probs) + self.c1.var(probs) +\ + self.c1.mode(probs) + self.c1.entropy(probs) + return basics1 + basics2 + + +def test_basics(): + """ + Test basics functionality of Categorical distribution. + """ + net = CategoricalBasics() + probs = Tensor([0.7, 0.2, 0.1], dtype=dtype.float32) + ans = net(probs) + assert isinstance(ans, Tensor)