diff --git a/mindspore/nn/probability/distribution/_utils/custom_ops.py b/mindspore/nn/probability/distribution/_utils/custom_ops.py index 430a311613..e8acac6a07 100644 --- a/mindspore/nn/probability/distribution/_utils/custom_ops.py +++ b/mindspore/nn/probability/distribution/_utils/custom_ops.py @@ -15,24 +15,30 @@ """Utitly functions to help distribution class.""" import numpy as np from mindspore.ops import operations as P +from mindspore.common import dtype as mstype def log_by_step(input_x): """ Log op on Ascend is calculated as log(abs(x)). Fix this with putting negative values as nan. """ - select = P.Select() log = P.Log() + less = P.Less() lessequal = P.LessEqual() fill = P.Fill() + cast = P.Cast() dtype = P.DType() shape = P.Shape() + select = P.Select() + input_x = cast(input_x, mstype.float32) + nan = fill(dtype(input_x), shape(input_x), np.nan) + inf = fill(dtype(input_x), shape(input_x), np.inf) + neg_x = less(input_x, 0.0) nonpos_x = lessequal(input_x, 0.0) log_x = log(input_x) - nan = fill(dtype(input_x), shape(input_x), np.nan) - result = select(nonpos_x, nan, log_x) - return result + result = select(nonpos_x, -inf, log_x) + return select(neg_x, nan, result) def log1p_by_step(x): """ diff --git a/tests/ut/python/nn/distribution/test_bernoulli.py b/tests/ut/python/nn/distribution/test_bernoulli.py index e04438f0a9..29fe10a844 100644 --- a/tests/ut/python/nn/distribution/test_bernoulli.py +++ b/tests/ut/python/nn/distribution/test_bernoulli.py @@ -157,51 +157,127 @@ def test_cross_entropy(): ans = net(probs_b, probs_a) assert isinstance(ans, Tensor) -class BernoulliBasics(nn.Cell): +class BernoulliConstruct(nn.Cell): + """ + Bernoulli distribution: going through construct. + """ + def __init__(self): + super(BernoulliConstruct, self).__init__() + self.b = msd.Bernoulli(0.5, dtype=dtype.int32) + self.b1 = msd.Bernoulli(dtype=dtype.int32) + + def construct(self, value, probs): + prob = self.b('prob', value) + prob1 = self.b('prob', value, probs) + prob2 = self.b1('prob', value, probs) + return prob + prob1 + prob2 + +def test_bernoulli_construct(): + """ + Test probability function going through construct. + """ + net = BernoulliConstruct() + value = Tensor([0, 0, 0, 0, 0], dtype=dtype.float32) + probs = Tensor([0.5], dtype=dtype.float32) + ans = net(value, probs) + assert isinstance(ans, Tensor) + +class BernoulliMean(nn.Cell): """ Test class: basic mean/sd/var/mode/entropy function. """ def __init__(self): - super(BernoulliBasics, self).__init__() + super(BernoulliMean, self).__init__() self.b = msd.Bernoulli([0.3, 0.5], dtype=dtype.int32) def construct(self): mean = self.b.mean() + return mean + +def test_mean(): + """ + Test mean/sd/var/mode/entropy functionality of Bernoulli distribution. + """ + net = BernoulliMean() + ans = net() + assert isinstance(ans, Tensor) + +class BernoulliSd(nn.Cell): + """ + Test class: basic mean/sd/var/mode/entropy function. + """ + def __init__(self): + super(BernoulliSd, self).__init__() + self.b = msd.Bernoulli([0.3, 0.5], dtype=dtype.int32) + + def construct(self): sd = self.b.sd() + return sd + +def test_sd(): + """ + Test mean/sd/var/mode/entropy functionality of Bernoulli distribution. + """ + net = BernoulliSd() + ans = net() + assert isinstance(ans, Tensor) + +class BernoulliVar(nn.Cell): + """ + Test class: basic mean/sd/var/mode/entropy function. + """ + def __init__(self): + super(BernoulliVar, self).__init__() + self.b = msd.Bernoulli([0.3, 0.5], dtype=dtype.int32) + + def construct(self): var = self.b.var() + return var + +def test_var(): + """ + Test mean/sd/var/mode/entropy functionality of Bernoulli distribution. + """ + net = BernoulliVar() + ans = net() + assert isinstance(ans, Tensor) + +class BernoulliMode(nn.Cell): + """ + Test class: basic mean/sd/var/mode/entropy function. + """ + def __init__(self): + super(BernoulliMode, self).__init__() + self.b = msd.Bernoulli([0.3, 0.5], dtype=dtype.int32) + + def construct(self): mode = self.b.mode() - entropy = self.b.entropy() - return mean + sd + var + mode + entropy + return mode -def test_bascis(): +def test_mode(): """ Test mean/sd/var/mode/entropy functionality of Bernoulli distribution. """ - net = BernoulliBasics() + net = BernoulliMode() ans = net() assert isinstance(ans, Tensor) -class BernoulliConstruct(nn.Cell): +class BernoulliEntropy(nn.Cell): """ - Bernoulli distribution: going through construct. + Test class: basic mean/sd/var/mode/entropy function. """ def __init__(self): - super(BernoulliConstruct, self).__init__() - self.b = msd.Bernoulli(0.5, dtype=dtype.int32) - self.b1 = msd.Bernoulli(dtype=dtype.int32) + super(BernoulliEntropy, self).__init__() + self.b = msd.Bernoulli([0.3, 0.5], dtype=dtype.int32) - def construct(self, value, probs): - prob = self.b('prob', value) - prob1 = self.b('prob', value, probs) - prob2 = self.b1('prob', value, probs) - return prob + prob1 + prob2 + def construct(self): + entropy = self.b.entropy() + return entropy -def test_bernoulli_construct(): +def test_entropy(): """ - Test probability function going through construct. + Test mean/sd/var/mode/entropy functionality of Bernoulli distribution. """ - net = BernoulliConstruct() - value = Tensor([0, 0, 0, 0, 0], dtype=dtype.float32) - probs = Tensor([0.5], dtype=dtype.float32) - ans = net(value, probs) + net = BernoulliEntropy() + ans = net() assert isinstance(ans, Tensor)