!3756 Change distribution api

Merge pull request !3756 from XunDeng/pp_poc_v3
pull/3756/MERGE
mindspore-ci-bot committed 5 years ago via Gitee
commit af9398b39a
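In short, this PR replaces the string-dispatch calling convention, where the first argument named the function to run, with ordinary named methods on the distribution instance. A minimal before/after sketch of the change (assuming a working MindSpore install; the Bernoulli parameters are placeholders):

    import mindspore.nn.probability.distribution as msd
    from mindspore import Tensor, dtype

    b = msd.Bernoulli(0.7, dtype=dtype.int32)
    value = Tensor([0, 1], dtype=dtype.float32)

    prob_old = b('prob', value)   # old style: dispatch by function name
    prob_new = b.prob(value)      # new style: named method

Both forms still work after this change (the string form is routed through Distribution.construct, as the new *Construct tests below verify), but the named methods are now the primary API.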

File diff suppressed because it is too large.

@ -27,11 +27,7 @@ class Distribution(Cell):
Note:
Derived class should override operations such as _mean, _prob,
and _log_prob. Functions should be called through construct when
used inside a network. Arguments should be passed in through *args
in the form of function name followed by additional arguments.
Functions such as cdf and prob, require a value to be passed in while
functions such as mean and sd do not require arguments other than name.
and _log_prob. Arguments should be passed in through *args.
Dist_spec_args are unique for each type of distribution. For example, mean and sd
are the dist_spec_args for a Normal distribution.
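As a sketch of that note: for a Normal distribution the dist_spec_args are mean and sd, which can either be fixed at construction or supplied per call (mirroring the NormalProb1 test further down; values are placeholders):

    import mindspore.nn.probability.distribution as msd
    from mindspore import Tensor, dtype

    value = Tensor([2.0], dtype=dtype.float32)

    n_fixed = msd.Normal(3.0, 4.0, dtype=dtype.float32)
    p1 = n_fixed.prob(value)            # dist_spec_args optional here

    n_free = msd.Normal()               # no parameters at construction
    mean = Tensor([3.0], dtype=dtype.float32)
    sd = Tensor([4.0], dtype=dtype.float32)
    p2 = n_free.prob(value, mean, sd)   # pass the dist_spec_args per call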
@ -73,11 +69,6 @@ class Distribution(Cell):
self._set_log_survival()
self._set_cross_entropy()
self._prob_functions = ('prob', 'log_prob')
self._cdf_survival_functions = ('cdf', 'log_cdf', 'survival_function', 'log_survival')
self._variance_functions = ('var', 'sd')
self._divergence_functions = ('kl_loss', 'cross_entropy')
@property
def name(self):
return self._name
@ -185,7 +176,7 @@ class Distribution(Cell):
Evaluate the log probability(pdf or pmf) at the given value.
Note:
Args must include name of the function and value.
Args must include value.
Dist_spec_args are optional.
"""
return self._call_log_prob(*args)
@ -204,7 +195,7 @@ class Distribution(Cell):
Evaluate the probability (pdf or pmf) at given value.
Note:
Args must include name of the function and value.
Args must include value.
Dist_spec_args are optional.
"""
return self._call_prob(*args)
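Inside a network these functions are called through a Cell's construct, as in the updated tests below; a minimal sketch (ProbNet is an illustrative name, GRAPH_MODE and the parameter values are assumptions, and the tests additionally set device_target="Ascend"):

    import mindspore.context as context
    import mindspore.nn as nn
    import mindspore.nn.probability.distribution as msd
    from mindspore import Tensor, dtype

    context.set_context(mode=context.GRAPH_MODE)

    class ProbNet(nn.Cell):
        def __init__(self):
            super(ProbNet, self).__init__()
            self.b = msd.Bernoulli(0.7, dtype=dtype.int32)

        def construct(self, x_):
            return self.b.prob(x_)   # was: self.b('prob', x_)

    ans = ProbNet()(Tensor([0, 1, 1], dtype=dtype.float32))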
@ -223,7 +214,7 @@ class Distribution(Cell):
Evaluate the cdf at given value.
Note:
Args must include name of the function and value.
Args must include value.
Dist_spec_args are optional.
"""
return self._call_cdf(*args)
@ -260,7 +251,7 @@ class Distribution(Cell):
Evaluate the log cdf at given value.
Note:
Args must include name of the function and value.
Args must include value.
Dist_spec_args are optional.
"""
return self._call_log_cdf(*args)
@ -279,7 +270,7 @@ class Distribution(Cell):
Evaluate the survival function at given value.
Note:
Args must include name of the function and value.
Args must include value.
Dist_spec_args are optional.
"""
return self._call_survival(*args)
@ -307,7 +298,7 @@ class Distribution(Cell):
Evaluate the log survival function at given value.
Note:
Args must include name of the function and value.
Args must include value.
Dist_spec_args are optional.
"""
return self._call_log_survival(*args)
@ -326,7 +317,7 @@ class Distribution(Cell):
Evaluate the KL divergence, i.e. KL(a||b).
Note:
Args must include name of the function, type of the distribution, parameters of distribution b.
Args must include type of the distribution, parameters of distribution b.
Parameters for distribution a are optional.
"""
return self._kl_loss(*args)
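So for the divergence functions the leading argument is now the distribution type of b, followed by b's parameters; a short sketch (parameter values are placeholders):

    import mindspore.nn.probability.distribution as msd
    from mindspore import Tensor, dtype

    b = msd.Bernoulli(0.7, dtype=dtype.int32)      # distribution a
    probs_b = Tensor([0.5], dtype=dtype.float32)   # parameter of b
    kl = b.kl_loss('Bernoulli', probs_b)           # KL(a||b)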
@ -336,7 +327,7 @@ class Distribution(Cell):
Evaluate the mean.
Note:
Args must include the name of function. Dist_spec_args are optional.
Dist_spec_args are optional.
"""
return self._mean(*args)
@ -345,7 +336,7 @@ class Distribution(Cell):
Evaluate the mode.
Note:
Args must include the name of function. Dist_spec_args are optional.
Dist_spec_args are optional.
"""
return self._mode(*args)
@ -354,7 +345,7 @@ class Distribution(Cell):
Evaluate the standard deviation.
Note:
Args must include the name of function. Dist_spec_args are optional.
Dist_spec_args are optional.
"""
return self._call_sd(*args)
@ -363,7 +354,7 @@ class Distribution(Cell):
Evaluate the variance.
Note:
Args must include the name of function. Dist_spec_args are optional.
Dist_spec_args are optional.
"""
return self._call_var(*args)
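With the function name dropped, the moment-style functions take no required arguments when the distribution's parameters were fixed at construction, e.g. (mirroring the GeometricBasics test below):

    import mindspore.nn.probability.distribution as msd
    from mindspore import dtype

    g = msd.Geometric([0.3, 0.5], dtype=dtype.int32)
    mean, sd, var, mode = g.mean(), g.sd(), g.var(), g.mode()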
@ -390,7 +381,7 @@ class Distribution(Cell):
Evaluate the entropy.
Note:
Args must include the name of function. Dist_spec_args are optional.
Dist_spec_args are optional.
"""
return self._entropy(*args)
@ -399,7 +390,7 @@ class Distribution(Cell):
Evaluate the cross_entropy between distribution a and b.
Note:
Args must include name of the function, type of the distribution, parameters of distribution b.
Args must include type of the distribution, parameters of distribution b.
Parameters for distribution a are optional.
"""
return self._call_cross_entropy(*args)
@ -421,13 +412,13 @@ class Distribution(Cell):
*args (list): arguments passed in through construct.
Note:
Args must include name of the function.
Shape of the sample and dist_spec_args are optional.
Shape of the sample defaults to ().
Dist_spec_args are optional.
"""
return self._sample(*args)
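A sketch of the new sample signature (the seed value is an assumption; per the note above, calling it with no arguments should default the sample shape to ()):

    import mindspore.nn.probability.distribution as msd
    from mindspore import dtype

    n = msd.Normal(3.0, 4.0, seed=10, dtype=dtype.float32)
    s1 = n.sample((2, 3))   # 2x3 draws from Normal(3.0, 4.0)
    s2 = n.sample()         # sample shape defaults to ()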
def construct(self, *inputs):
def construct(self, name, *args):
"""
Override construct in Cell.
@ -437,35 +428,36 @@ class Distribution(Cell):
'var', 'sd', 'entropy', 'kl_loss', 'cross_entropy', 'sample'.
Args:
*inputs (list): inputs[0] is always the name of the function.
"""
if inputs[0] == 'log_prob':
return self._call_log_prob(*inputs)
if inputs[0] == 'prob':
return self._call_prob(*inputs)
if inputs[0] == 'cdf':
return self._call_cdf(*inputs)
if inputs[0] == 'log_cdf':
return self._call_log_cdf(*inputs)
if inputs[0] == 'survival_function':
return self._call_survival(*inputs)
if inputs[0] == 'log_survival':
return self._call_log_survival(*inputs)
if inputs[0] == 'kl_loss':
return self._kl_loss(*inputs)
if inputs[0] == 'mean':
return self._mean(*inputs)
if inputs[0] == 'mode':
return self._mode(*inputs)
if inputs[0] == 'sd':
return self._call_sd(*inputs)
if inputs[0] == 'var':
return self._call_var(*inputs)
if inputs[0] == 'entropy':
return self._entropy(*inputs)
if inputs[0] == 'cross_entropy':
return self._call_cross_entropy(*inputs)
if inputs[0] == 'sample':
return self._sample(*inputs)
name (str): name of the function.
*args (list): list of arguments needed for the function.
"""
if name == 'log_prob':
return self._call_log_prob(*args)
if name == 'prob':
return self._call_prob(*args)
if name == 'cdf':
return self._call_cdf(*args)
if name == 'log_cdf':
return self._call_log_cdf(*args)
if name == 'survival_function':
return self._call_survival(*args)
if name == 'log_survival':
return self._call_log_survival(*args)
if name == 'kl_loss':
return self._kl_loss(*args)
if name == 'mean':
return self._mean(*args)
if name == 'mode':
return self._mode(*args)
if name == 'sd':
return self._call_sd(*args)
if name == 'var':
return self._call_var(*args)
if name == 'entropy':
return self._entropy(*args)
if name == 'cross_entropy':
return self._call_cross_entropy(*args)
if name == 'sample':
return self._sample(*args)
return None
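The name-based form therefore still works when a call site must pick the function at runtime, now routed through construct(name, *args); a minimal sketch (DispatchNet is an illustrative name):

    import mindspore.nn as nn
    import mindspore.nn.probability.distribution as msd
    from mindspore import Tensor, dtype

    class DispatchNet(nn.Cell):
        def __init__(self):
            super(DispatchNet, self).__init__()
            self.b = msd.Bernoulli(0.5, dtype=dtype.int32)

        def construct(self, value):
            # Dispatches through Distribution.construct;
            # equivalent to self.b.prob(value).
            return self.b('prob', value)

    out = DispatchNet()(Tensor([0, 1], dtype=dtype.float32))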

File diff suppressed because it is too large.

File diff suppressed because it is too large.

File diff suppressed because it is too large.

File diff suppressed because it is too large.

@ -19,7 +19,6 @@ import mindspore.context as context
import mindspore.nn as nn
import mindspore.nn.probability.distribution as msd
from mindspore import Tensor
from mindspore.common.api import ms_function
from mindspore import dtype
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
@ -32,9 +31,8 @@ class Prob(nn.Cell):
super(Prob, self).__init__()
self.b = msd.Bernoulli(0.7, dtype=dtype.int32)
@ms_function
def construct(self, x_):
return self.b('prob', x_)
return self.b.prob(x_)
def test_pmf():
"""
@ -57,9 +55,8 @@ class LogProb(nn.Cell):
super(LogProb, self).__init__()
self.b = msd.Bernoulli(0.7, dtype=dtype.int32)
@ms_function
def construct(self, x_):
return self.b('log_prob', x_)
return self.b.log_prob(x_)
def test_log_likelihood():
"""
@ -81,9 +78,8 @@ class KL(nn.Cell):
super(KL, self).__init__()
self.b = msd.Bernoulli(0.7, dtype=dtype.int32)
@ms_function
def construct(self, x_):
return self.b('kl_loss', 'Bernoulli', x_)
return self.b.kl_loss('Bernoulli', x_)
def test_kl_loss():
"""
@ -107,9 +103,8 @@ class Basics(nn.Cell):
super(Basics, self).__init__()
self.b = msd.Bernoulli([0.3, 0.5, 0.7], dtype=dtype.int32)
@ms_function
def construct(self):
return self.b('mean'), self.b('sd'), self.b('mode')
return self.b.mean(), self.b.sd(), self.b.mode()
def test_basics():
"""
@ -134,9 +129,8 @@ class Sampling(nn.Cell):
self.b = msd.Bernoulli([0.7, 0.5], seed=seed, dtype=dtype.int32)
self.shape = shape
@ms_function
def construct(self, probs=None):
return self.b('sample', self.shape, probs)
return self.b.sample(self.shape, probs)
def test_sample():
"""
@ -155,9 +149,8 @@ class CDF(nn.Cell):
super(CDF, self).__init__()
self.b = msd.Bernoulli(0.7, dtype=dtype.int32)
@ms_function
def construct(self, x_):
return self.b('cdf', x_)
return self.b.cdf(x_)
def test_cdf():
"""
@ -171,7 +164,6 @@ def test_cdf():
tol = 1e-6
assert (np.abs(output.asnumpy() - expect_cdf) < tol).all()
class LogCDF(nn.Cell):
"""
Test class: log cdf of bernoulli distributions.
@ -180,9 +172,8 @@ class LogCDF(nn.Cell):
super(LogCDF, self).__init__()
self.b = msd.Bernoulli(0.7, dtype=dtype.int32)
@ms_function
def construct(self, x_):
return self.b('log_cdf', x_)
return self.b.log_cdf(x_)
def test_logcdf():
"""
@ -205,9 +196,8 @@ class SF(nn.Cell):
super(SF, self).__init__()
self.b = msd.Bernoulli(0.7, dtype=dtype.int32)
@ms_function
def construct(self, x_):
return self.b('survival_function', x_)
return self.b.survival_function(x_)
def test_survival():
"""
@ -230,9 +220,8 @@ class LogSF(nn.Cell):
super(LogSF, self).__init__()
self.b = msd.Bernoulli(0.7, dtype=dtype.int32)
@ms_function
def construct(self, x_):
return self.b('log_survival', x_)
return self.b.log_survival(x_)
def test_log_survival():
"""
@ -254,9 +243,8 @@ class EntropyH(nn.Cell):
super(EntropyH, self).__init__()
self.b = msd.Bernoulli(0.7, dtype=dtype.int32)
@ms_function
def construct(self):
return self.b('entropy')
return self.b.entropy()
def test_entropy():
"""
@ -277,12 +265,11 @@ class CrossEntropy(nn.Cell):
super(CrossEntropy, self).__init__()
self.b = msd.Bernoulli(0.7, dtype=dtype.int32)
@ms_function
def construct(self, x_):
entropy = self.b('entropy')
kl_loss = self.b('kl_loss', 'Bernoulli', x_)
entropy = self.b.entropy()
kl_loss = self.b.kl_loss('Bernoulli', x_)
h_sum_kl = entropy + kl_loss
cross_entropy = self.b('cross_entropy', 'Bernoulli', x_)
cross_entropy = self.b.cross_entropy('Bernoulli', x_)
return h_sum_kl - cross_entropy
def test_cross_entropy():

@ -19,7 +19,6 @@ import mindspore.context as context
import mindspore.nn as nn
import mindspore.nn.probability.distribution as msd
from mindspore import Tensor
from mindspore.common.api import ms_function
from mindspore import dtype
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
@ -32,9 +31,8 @@ class Prob(nn.Cell):
super(Prob, self).__init__()
self.e = msd.Exponential([[1.0], [0.5]], dtype=dtype.float32)
@ms_function
def construct(self, x_):
return self.e('prob', x_)
return self.e.prob(x_)
def test_pdf():
"""
@ -56,9 +54,8 @@ class LogProb(nn.Cell):
super(LogProb, self).__init__()
self.e = msd.Exponential([[1.0], [0.5]], dtype=dtype.float32)
@ms_function
def construct(self, x_):
return self.e('log_prob', x_)
return self.e.log_prob(x_)
def test_log_likelihood():
"""
@ -80,9 +77,8 @@ class KL(nn.Cell):
super(KL, self).__init__()
self.e = msd.Exponential([1.5], dtype=dtype.float32)
@ms_function
def construct(self, x_):
return self.e('kl_loss', 'Exponential', x_)
return self.e.kl_loss('Exponential', x_)
def test_kl_loss():
"""
@ -104,9 +100,8 @@ class Basics(nn.Cell):
super(Basics, self).__init__()
self.e = msd.Exponential([0.5], dtype=dtype.float32)
@ms_function
def construct(self):
return self.e('mean'), self.e('sd'), self.e('mode')
return self.e.mean(), self.e.sd(), self.e.mode()
def test_basics():
"""
@ -131,9 +126,8 @@ class Sampling(nn.Cell):
self.e = msd.Exponential([[1.0], [0.5]], seed=seed, dtype=dtype.float32)
self.shape = shape
@ms_function
def construct(self, rate=None):
return self.e('sample', self.shape, rate)
return self.e.sample(self.shape, rate)
def test_sample():
"""
@ -154,9 +148,8 @@ class CDF(nn.Cell):
super(CDF, self).__init__()
self.e = msd.Exponential([[1.0], [0.5]], dtype=dtype.float32)
@ms_function
def construct(self, x_):
return self.e('cdf', x_)
return self.e.cdf(x_)
def test_cdf():
"""
@ -178,9 +171,8 @@ class LogCDF(nn.Cell):
super(LogCDF, self).__init__()
self.e = msd.Exponential([[1.0], [0.5]], dtype=dtype.float32)
@ms_function
def construct(self, x_):
return self.e('log_cdf', x_)
return self.e.log_cdf(x_)
def test_log_cdf():
"""
@ -202,9 +194,8 @@ class SF(nn.Cell):
super(SF, self).__init__()
self.e = msd.Exponential([[1.0], [0.5]], dtype=dtype.float32)
@ms_function
def construct(self, x_):
return self.e('survival_function', x_)
return self.e.survival_function(x_)
def test_survival():
"""
@ -226,9 +217,8 @@ class LogSF(nn.Cell):
super(LogSF, self).__init__()
self.e = msd.Exponential([[1.0], [0.5]], dtype=dtype.float32)
@ms_function
def construct(self, x_):
return self.e('log_survival', x_)
return self.e.log_survival(x_)
def test_log_survival():
"""
@ -250,9 +240,8 @@ class EntropyH(nn.Cell):
super(EntropyH, self).__init__()
self.e = msd.Exponential([[1.0], [0.5]], dtype=dtype.float32)
@ms_function
def construct(self):
return self.e('entropy')
return self.e.entropy()
def test_entropy():
"""
@ -273,12 +262,11 @@ class CrossEntropy(nn.Cell):
super(CrossEntropy, self).__init__()
self.e = msd.Exponential([1.0], dtype=dtype.float32)
@ms_function
def construct(self, x_):
entropy = self.e('entropy')
kl_loss = self.e('kl_loss', 'Exponential', x_)
entropy = self.e.entropy()
kl_loss = self.e.kl_loss('Exponential', x_)
h_sum_kl = entropy + kl_loss
cross_entropy = self.e('cross_entropy', 'Exponential', x_)
cross_entropy = self.e.cross_entropy('Exponential', x_)
return h_sum_kl - cross_entropy
def test_cross_entropy():

@ -19,7 +19,6 @@ import mindspore.context as context
import mindspore.nn as nn
import mindspore.nn.probability.distribution as msd
from mindspore import Tensor
from mindspore.common.api import ms_function
from mindspore import dtype
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
@ -32,9 +31,8 @@ class Prob(nn.Cell):
super(Prob, self).__init__()
self.g = msd.Geometric(0.7, dtype=dtype.int32)
@ms_function
def construct(self, x_):
return self.g('prob', x_)
return self.g.prob(x_)
def test_pmf():
"""
@ -56,9 +54,8 @@ class LogProb(nn.Cell):
super(LogProb, self).__init__()
self.g = msd.Geometric(0.7, dtype=dtype.int32)
@ms_function
def construct(self, x_):
return self.g('log_prob', x_)
return self.g.log_prob(x_)
def test_log_likelihood():
"""
@ -80,9 +77,8 @@ class KL(nn.Cell):
super(KL, self).__init__()
self.g = msd.Geometric(0.7, dtype=dtype.int32)
@ms_function
def construct(self, x_):
return self.g('kl_loss', 'Geometric', x_)
return self.g.kl_loss('Geometric', x_)
def test_kl_loss():
"""
@ -106,9 +102,8 @@ class Basics(nn.Cell):
super(Basics, self).__init__()
self.g = msd.Geometric([0.5, 0.5], dtype=dtype.int32)
@ms_function
def construct(self):
return self.g('mean'), self.g('sd'), self.g('mode')
return self.g.mean(), self.g.sd(), self.g.mode()
def test_basics():
"""
@ -133,9 +128,8 @@ class Sampling(nn.Cell):
self.g = msd.Geometric([0.7, 0.5], seed=seed, dtype=dtype.int32)
self.shape = shape
@ms_function
def construct(self, probs=None):
return self.g('sample', self.shape, probs)
return self.g.sample(self.shape, probs)
def test_sample():
"""
@ -154,9 +148,8 @@ class CDF(nn.Cell):
super(CDF, self).__init__()
self.g = msd.Geometric(0.7, dtype=dtype.int32)
@ms_function
def construct(self, x_):
return self.g('cdf', x_)
return self.g.cdf(x_)
def test_cdf():
"""
@ -178,9 +171,8 @@ class LogCDF(nn.Cell):
super(LogCDF, self).__init__()
self.g = msd.Geometric(0.7, dtype=dtype.int32)
@ms_function
def construct(self, x_):
return self.g('log_cdf', x_)
return self.g.log_cdf(x_)
def test_logcdf():
"""
@ -202,9 +194,8 @@ class SF(nn.Cell):
super(SF, self).__init__()
self.g = msd.Geometric(0.7, dtype=dtype.int32)
@ms_function
def construct(self, x_):
return self.g('survival_function', x_)
return self.g.survival_function(x_)
def test_survival():
"""
@ -226,9 +217,8 @@ class LogSF(nn.Cell):
super(LogSF, self).__init__()
self.g = msd.Geometric(0.7, dtype=dtype.int32)
@ms_function
def construct(self, x_):
return self.g('log_survival', x_)
return self.g.log_survival(x_)
def test_log_survival():
"""
@ -250,9 +240,8 @@ class EntropyH(nn.Cell):
super(EntropyH, self).__init__()
self.g = msd.Geometric(0.7, dtype=dtype.int32)
@ms_function
def construct(self):
return self.g('entropy')
return self.g.entropy()
def test_entropy():
"""
@ -273,12 +262,11 @@ class CrossEntropy(nn.Cell):
super(CrossEntropy, self).__init__()
self.g = msd.Geometric(0.7, dtype=dtype.int32)
@ms_function
def construct(self, x_):
entropy = self.g('entropy')
kl_loss = self.g('kl_loss', 'Geometric', x_)
entropy = self.g.entropy()
kl_loss = self.g.kl_loss('Geometric', x_)
h_sum_kl = entropy + kl_loss
ans = self.g('cross_entropy', 'Geometric', x_)
ans = self.g.cross_entropy('Geometric', x_)
return h_sum_kl - ans
def test_cross_entropy():

@ -19,7 +19,6 @@ import mindspore.context as context
import mindspore.nn as nn
import mindspore.nn.probability.distribution as msd
from mindspore import Tensor
from mindspore.common.api import ms_function
from mindspore import dtype
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
@ -32,9 +31,8 @@ class Prob(nn.Cell):
super(Prob, self).__init__()
self.n = msd.Normal(np.array([3.0]), np.array([[2.0], [4.0]]), dtype=dtype.float32)
@ms_function
def construct(self, x_):
return self.n('prob', x_)
return self.n.prob(x_)
def test_pdf():
"""
@ -55,9 +53,8 @@ class LogProb(nn.Cell):
super(LogProb, self).__init__()
self.n = msd.Normal(np.array([3.0]), np.array([[2.0], [4.0]]), dtype=dtype.float32)
@ms_function
def construct(self, x_):
return self.n('log_prob', x_)
return self.n.log_prob(x_)
def test_log_likelihood():
"""
@ -79,9 +76,8 @@ class KL(nn.Cell):
super(KL, self).__init__()
self.n = msd.Normal(np.array([3.0]), np.array([4.0]), dtype=dtype.float32)
@ms_function
def construct(self, x_, y_):
return self.n('kl_loss', 'Normal', x_, y_)
return self.n.kl_loss('Normal', x_, y_)
def test_kl_loss():
@ -113,9 +109,8 @@ class Basics(nn.Cell):
super(Basics, self).__init__()
self.n = msd.Normal(np.array([3.0]), np.array([2.0, 4.0]), dtype=dtype.float32)
@ms_function
def construct(self):
return self.n('mean'), self.n('sd'), self.n('mode')
return self.n.mean(), self.n.sd(), self.n.mode()
def test_basics():
"""
@ -139,9 +134,8 @@ class Sampling(nn.Cell):
self.n = msd.Normal(np.array([3.0]), np.array([[2.0], [4.0]]), seed=seed, dtype=dtype.float32)
self.shape = shape
@ms_function
def construct(self, mean=None, sd=None):
return self.n('sample', self.shape, mean, sd)
return self.n.sample(self.shape, mean, sd)
def test_sample():
"""
@ -163,9 +157,8 @@ class CDF(nn.Cell):
super(CDF, self).__init__()
self.n = msd.Normal(np.array([3.0]), np.array([[2.0], [4.0]]), dtype=dtype.float32)
@ms_function
def construct(self, x_):
return self.n('cdf', x_)
return self.n.cdf(x_)
def test_cdf():
@ -187,9 +180,8 @@ class LogCDF(nn.Cell):
super(LogCDF, self).__init__()
self.n = msd.Normal(np.array([3.0]), np.array([[2.0], [4.0]]), dtype=dtype.float32)
@ms_function
def construct(self, x_):
return self.n('log_cdf', x_)
return self.n.log_cdf(x_)
def test_log_cdf():
"""
@ -210,9 +202,8 @@ class SF(nn.Cell):
super(SF, self).__init__()
self.n = msd.Normal(np.array([3.0]), np.array([[2.0], [4.0]]), dtype=dtype.float32)
@ms_function
def construct(self, x_):
return self.n('survival_function', x_)
return self.n.survival_function(x_)
def test_survival():
"""
@ -233,9 +224,8 @@ class LogSF(nn.Cell):
super(LogSF, self).__init__()
self.n = msd.Normal(np.array([3.0]), np.array([[2.0], [4.0]]), dtype=dtype.float32)
@ms_function
def construct(self, x_):
return self.n('log_survival', x_)
return self.n.log_survival(x_)
def test_log_survival():
"""
@ -256,9 +246,8 @@ class EntropyH(nn.Cell):
super(EntropyH, self).__init__()
self.n = msd.Normal(np.array([3.0]), np.array([[2.0], [4.0]]), dtype=dtype.float32)
@ms_function
def construct(self):
return self.n('entropy')
return self.n.entropy()
def test_entropy():
"""
@ -279,12 +268,11 @@ class CrossEntropy(nn.Cell):
super(CrossEntropy, self).__init__()
self.n = msd.Normal(np.array([3.0]), np.array([4.0]), dtype=dtype.float32)
@ms_function
def construct(self, x_, y_):
entropy = self.n('entropy')
kl_loss = self.n('kl_loss', 'Normal', x_, y_)
entropy = self.n.entropy()
kl_loss = self.n.kl_loss('Normal', x_, y_)
h_sum_kl = entropy + kl_loss
cross_entropy = self.n('cross_entropy', 'Normal', x_, y_)
cross_entropy = self.n.cross_entropy('Normal', x_, y_)
return h_sum_kl - cross_entropy
def test_cross_entropy():
@ -297,3 +285,40 @@ def test_cross_entropy():
diff = cross_entropy(mean, sd)
tol = 1e-6
assert (np.abs(diff.asnumpy() - np.zeros(diff.shape)) < tol).all()
class Net(nn.Cell):
"""
Test class: expand single distribution instance to multiple graphs
by specifying the attributes.
"""
def __init__(self):
super(Net, self).__init__()
self.normal = msd.Normal(0., 1., dtype=dtype.float32)
def construct(self, x_, y_):
kl = self.normal.kl_loss('Normal', x_, y_)
prob = self.normal.prob(kl)
return prob
def test_multiple_graphs():
"""
Test multiple graphs case.
"""
prob = Net()
mean_a = np.array([0.0]).astype(np.float32)
sd_a = np.array([1.0]).astype(np.float32)
mean_b = np.array([1.0]).astype(np.float32)
sd_b = np.array([1.0]).astype(np.float32)
ans = prob(Tensor(mean_b), Tensor(sd_b))
diff_log_scale = np.log(sd_a) - np.log(sd_b)
squared_diff = np.square(mean_a / sd_b - mean_b / sd_b)
expect_kl_loss = 0.5 * squared_diff + 0.5 * \
np.expm1(2 * diff_log_scale) - diff_log_scale
norm_benchmark = stats.norm(np.array([0.0]), np.array([1.0]))
expect_prob = norm_benchmark.pdf(expect_kl_loss).astype(np.float32)
tol = 1e-6
assert (np.abs(ans.asnumpy() - expect_prob) < tol).all()

@ -1,62 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test cases for new api of normal distribution"""
import numpy as np
from scipy import stats
import mindspore.nn as nn
import mindspore.nn.probability.distribution as msd
from mindspore import dtype
from mindspore import Tensor
import mindspore.context as context
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
class Net(nn.Cell):
"""
Test class: new api of normal distribution.
"""
def __init__(self):
super(Net, self).__init__()
self.normal = msd.Normal(0., 1., dtype=dtype.float32)
def construct(self, x_, y_):
kl = self.normal.kl_loss('kl_loss', 'Normal', x_, y_)
prob = self.normal.prob('prob', kl)
return prob
def test_new_api():
"""
Test new api of normal distribution.
"""
prob = Net()
mean_a = np.array([0.0]).astype(np.float32)
sd_a = np.array([1.0]).astype(np.float32)
mean_b = np.array([1.0]).astype(np.float32)
sd_b = np.array([1.0]).astype(np.float32)
ans = prob(Tensor(mean_b), Tensor(sd_b))
diff_log_scale = np.log(sd_a) - np.log(sd_b)
squared_diff = np.square(mean_a / sd_b - mean_b / sd_b)
expect_kl_loss = 0.5 * squared_diff + 0.5 * \
np.expm1(2 * diff_log_scale) - diff_log_scale
norm_benchmark = stats.norm(np.array([0.0]), np.array([1.0]))
expect_prob = norm_benchmark.pdf(expect_kl_loss).astype(np.float32)
tol = 1e-6
assert (np.abs(ans.asnumpy() - expect_prob) < tol).all()

@ -19,7 +19,6 @@ import mindspore.context as context
import mindspore.nn as nn
import mindspore.nn.probability.distribution as msd
from mindspore import Tensor
from mindspore.common.api import ms_function
from mindspore import dtype
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
@ -32,9 +31,8 @@ class Prob(nn.Cell):
super(Prob, self).__init__()
self.u = msd.Uniform([0.0], [[1.0], [2.0]], dtype=dtype.float32)
@ms_function
def construct(self, x_):
return self.u('prob', x_)
return self.u.prob(x_)
def test_pdf():
"""
@ -56,9 +54,8 @@ class LogProb(nn.Cell):
super(LogProb, self).__init__()
self.u = msd.Uniform([0.0], [[1.0], [2.0]], dtype=dtype.float32)
@ms_function
def construct(self, x_):
return self.u('log_prob', x_)
return self.u.log_prob(x_)
def test_log_likelihood():
"""
@ -80,9 +77,8 @@ class KL(nn.Cell):
super(KL, self).__init__()
self.u = msd.Uniform([0.0], [1.5], dtype=dtype.float32)
@ms_function
def construct(self, x_, y_):
return self.u('kl_loss', 'Uniform', x_, y_)
return self.u.kl_loss('Uniform', x_, y_)
def test_kl_loss():
"""
@ -106,9 +102,8 @@ class Basics(nn.Cell):
super(Basics, self).__init__()
self.u = msd.Uniform([0.0], [3.0], dtype=dtype.float32)
@ms_function
def construct(self):
return self.u('mean'), self.u('sd')
return self.u.mean(), self.u.sd()
def test_basics():
"""
@ -131,9 +126,8 @@ class Sampling(nn.Cell):
self.u = msd.Uniform([0.0], [[1.0], [2.0]], seed=seed, dtype=dtype.float32)
self.shape = shape
@ms_function
def construct(self, low=None, high=None):
return self.u('sample', self.shape, low, high)
return self.u.sample(self.shape, low, high)
def test_sample():
"""
@ -155,9 +149,8 @@ class CDF(nn.Cell):
super(CDF, self).__init__()
self.u = msd.Uniform([0.0], [1.0], dtype=dtype.float32)
@ms_function
def construct(self, x_):
return self.u('cdf', x_)
return self.u.cdf(x_)
def test_cdf():
"""
@ -179,9 +172,8 @@ class LogCDF(nn.Cell):
super(LogCDF, self).__init__()
self.u = msd.Uniform([0.0], [1.0], dtype=dtype.float32)
@ms_function
def construct(self, x_):
return self.u('log_cdf', x_)
return self.u.log_cdf(x_)
class SF(nn.Cell):
"""
@ -191,9 +183,8 @@ class SF(nn.Cell):
super(SF, self).__init__()
self.u = msd.Uniform([0.0], [1.0], dtype=dtype.float32)
@ms_function
def construct(self, x_):
return self.u('survival_function', x_)
return self.u.survival_function(x_)
class LogSF(nn.Cell):
"""
@ -203,9 +194,8 @@ class LogSF(nn.Cell):
super(LogSF, self).__init__()
self.u = msd.Uniform([0.0], [1.0], dtype=dtype.float32)
@ms_function
def construct(self, x_):
return self.u('log_survival', x_)
return self.u.log_survival(x_)
class EntropyH(nn.Cell):
"""
@ -215,9 +205,8 @@ class EntropyH(nn.Cell):
super(EntropyH, self).__init__()
self.u = msd.Uniform([0.0], [1.0, 2.0], dtype=dtype.float32)
@ms_function
def construct(self):
return self.u('entropy')
return self.u.entropy()
def test_entropy():
"""
@ -238,12 +227,11 @@ class CrossEntropy(nn.Cell):
super(CrossEntropy, self).__init__()
self.u = msd.Uniform([0.0], [1.5], dtype=dtype.float32)
@ms_function
def construct(self, x_, y_):
entropy = self.u('entropy')
kl_loss = self.u('kl_loss', 'Uniform', x_, y_)
entropy = self.u.entropy()
kl_loss = self.u.kl_loss('Uniform', x_, y_)
h_sum_kl = entropy + kl_loss
cross_entropy = self.u('cross_entropy', 'Uniform', x_, y_)
cross_entropy = self.u.cross_entropy('Uniform', x_, y_)
return h_sum_kl - cross_entropy
def test_log_cdf():

@ -49,12 +49,12 @@ class BernoulliProb(nn.Cell):
self.b = msd.Bernoulli(0.5, dtype=dtype.int32)
def construct(self, value):
prob = self.b('prob', value)
log_prob = self.b('log_prob', value)
cdf = self.b('cdf', value)
log_cdf = self.b('log_cdf', value)
sf = self.b('survival_function', value)
log_sf = self.b('log_survival', value)
prob = self.b.prob(value)
log_prob = self.b.log_prob(value)
cdf = self.b.cdf(value)
log_cdf = self.b.log_cdf(value)
sf = self.b.survival_function(value)
log_sf = self.b.log_survival(value)
return prob + log_prob + cdf + log_cdf + sf + log_sf
def test_bernoulli_prob():
@ -75,12 +75,12 @@ class BernoulliProb1(nn.Cell):
self.b = msd.Bernoulli(dtype=dtype.int32)
def construct(self, value, probs):
prob = self.b('prob', value, probs)
log_prob = self.b('log_prob', value, probs)
cdf = self.b('cdf', value, probs)
log_cdf = self.b('log_cdf', value, probs)
sf = self.b('survival_function', value, probs)
log_sf = self.b('log_survival', value, probs)
prob = self.b.prob(value, probs)
log_prob = self.b.log_prob(value, probs)
cdf = self.b.cdf(value, probs)
log_cdf = self.b.log_cdf(value, probs)
sf = self.b.survival_function(value, probs)
log_sf = self.b.log_survival(value, probs)
return prob + log_prob + cdf + log_cdf + sf + log_sf
def test_bernoulli_prob1():
@ -103,8 +103,8 @@ class BernoulliKl(nn.Cell):
self.b2 = msd.Bernoulli(dtype=dtype.int32)
def construct(self, probs_b, probs_a):
kl1 = self.b1('kl_loss', 'Bernoulli', probs_b)
kl2 = self.b2('kl_loss', 'Bernoulli', probs_b, probs_a)
kl1 = self.b1.kl_loss('Bernoulli', probs_b)
kl2 = self.b2.kl_loss('Bernoulli', probs_b, probs_a)
return kl1 + kl2
def test_kl():
@ -127,8 +127,8 @@ class BernoulliCrossEntropy(nn.Cell):
self.b2 = msd.Bernoulli(dtype=dtype.int32)
def construct(self, probs_b, probs_a):
h1 = self.b1('cross_entropy', 'Bernoulli', probs_b)
h2 = self.b2('cross_entropy', 'Bernoulli', probs_b, probs_a)
h1 = self.b1.cross_entropy('Bernoulli', probs_b)
h2 = self.b2.cross_entropy('Bernoulli', probs_b, probs_a)
return h1 + h2
def test_cross_entropy():
@ -150,11 +150,11 @@ class BernoulliBasics(nn.Cell):
self.b = msd.Bernoulli([0.3, 0.5], dtype=dtype.int32)
def construct(self):
mean = self.b('mean')
sd = self.b('sd')
var = self.b('var')
mode = self.b('mode')
entropy = self.b('entropy')
mean = self.b.mean()
sd = self.b.sd()
var = self.b.var()
mode = self.b.mode()
entropy = self.b.entropy()
return mean + sd + var + mode + entropy
def test_bascis():
@ -164,3 +164,28 @@ def test_bascis():
net = BernoulliBasics()
ans = net()
assert isinstance(ans, Tensor)
class BernoulliConstruct(nn.Cell):
"""
Bernoulli distribution: going through construct.
"""
def __init__(self):
super(BernoulliConstruct, self).__init__()
self.b = msd.Bernoulli(0.5, dtype=dtype.int32)
self.b1 = msd.Bernoulli(dtype=dtype.int32)
def construct(self, value, probs):
prob = self.b('prob', value)
prob1 = self.b('prob', value, probs)
prob2 = self.b1('prob', value, probs)
return prob + prob1 + prob2
def test_bernoulli_construct():
"""
Test probability function going through construct.
"""
net = BernoulliConstruct()
value = Tensor([0, 0, 0, 0, 0], dtype=dtype.float32)
probs = Tensor([0.5], dtype=dtype.float32)
ans = net(value, probs)
assert isinstance(ans, Tensor)

@ -50,12 +50,12 @@ class ExponentialProb(nn.Cell):
self.e = msd.Exponential(0.5, dtype=dtype.float32)
def construct(self, value):
prob = self.e('prob', value)
log_prob = self.e('log_prob', value)
cdf = self.e('cdf', value)
log_cdf = self.e('log_cdf', value)
sf = self.e('survival_function', value)
log_sf = self.e('log_survival', value)
prob = self.e.prob(value)
log_prob = self.e.log_prob(value)
cdf = self.e.cdf(value)
log_cdf = self.e.log_cdf(value)
sf = self.e.survival_function(value)
log_sf = self.e.log_survival(value)
return prob + log_prob + cdf + log_cdf + sf + log_sf
def test_exponential_prob():
@ -76,12 +76,12 @@ class ExponentialProb1(nn.Cell):
self.e = msd.Exponential(dtype=dtype.float32)
def construct(self, value, rate):
prob = self.e('prob', value, rate)
log_prob = self.e('log_prob', value, rate)
cdf = self.e('cdf', value, rate)
log_cdf = self.e('log_cdf', value, rate)
sf = self.e('survival_function', value, rate)
log_sf = self.e('log_survival', value, rate)
prob = self.e.prob(value, rate)
log_prob = self.e.log_prob(value, rate)
cdf = self.e.cdf(value, rate)
log_cdf = self.e.log_cdf(value, rate)
sf = self.e.survival_function(value, rate)
log_sf = self.e.log_survival(value, rate)
return prob + log_prob + cdf + log_cdf + sf + log_sf
def test_exponential_prob1():
@ -104,8 +104,8 @@ class ExponentialKl(nn.Cell):
self.e2 = msd.Exponential(dtype=dtype.float32)
def construct(self, rate_b, rate_a):
kl1 = self.e1('kl_loss', 'Exponential', rate_b)
kl2 = self.e2('kl_loss', 'Exponential', rate_b, rate_a)
kl1 = self.e1.kl_loss('Exponential', rate_b)
kl2 = self.e2.kl_loss('Exponential', rate_b, rate_a)
return kl1 + kl2
def test_kl():
@ -128,8 +128,8 @@ class ExponentialCrossEntropy(nn.Cell):
self.e2 = msd.Exponential(dtype=dtype.float32)
def construct(self, rate_b, rate_a):
h1 = self.e1('cross_entropy', 'Exponential', rate_b)
h2 = self.e2('cross_entropy', 'Exponential', rate_b, rate_a)
h1 = self.e1.cross_entropy('Exponential', rate_b)
h2 = self.e2.cross_entropy('Exponential', rate_b, rate_a)
return h1 + h2
def test_cross_entropy():
@ -151,11 +151,11 @@ class ExponentialBasics(nn.Cell):
self.e = msd.Exponential([0.3, 0.5], dtype=dtype.float32)
def construct(self):
mean = self.e('mean')
sd = self.e('sd')
var = self.e('var')
mode = self.e('mode')
entropy = self.e('entropy')
mean = self.e.mean()
sd = self.e.sd()
var = self.e.var()
mode = self.e.mode()
entropy = self.e.entropy()
return mean + sd + var + mode + entropy
def test_bascis():
@ -165,3 +165,29 @@ def test_bascis():
net = ExponentialBasics()
ans = net()
assert isinstance(ans, Tensor)
class ExpConstruct(nn.Cell):
"""
Exponential distribution: going through construct.
"""
def __init__(self):
super(ExpConstruct, self).__init__()
self.e = msd.Exponential(0.5, dtype=dtype.float32)
self.e1 = msd.Exponential(dtype=dtype.float32)
def construct(self, value, rate):
prob = self.e('prob', value)
prob1 = self.e('prob', value, rate)
prob2 = self.e1('prob', value, rate)
return prob + prob1 + prob2
def test_exp_construct():
"""
Test probability function going through construct.
"""
net = ExpConstruct()
value = Tensor([0, 0, 0, 0, 0], dtype=dtype.float32)
probs = Tensor([0.5], dtype=dtype.float32)
ans = net(value, probs)
assert isinstance(ans, Tensor)

@ -50,12 +50,12 @@ class GeometricProb(nn.Cell):
self.g = msd.Geometric(0.5, dtype=dtype.int32)
def construct(self, value):
prob = self.g('prob', value)
log_prob = self.g('log_prob', value)
cdf = self.g('cdf', value)
log_cdf = self.g('log_cdf', value)
sf = self.g('survival_function', value)
log_sf = self.g('log_survival', value)
prob = self.g.prob(value)
log_prob = self.g.log_prob(value)
cdf = self.g.cdf(value)
log_cdf = self.g.log_cdf(value)
sf = self.g.survival_function(value)
log_sf = self.g.log_survival(value)
return prob + log_prob + cdf + log_cdf + sf + log_sf
def test_geometric_prob():
@ -76,12 +76,12 @@ class GeometricProb1(nn.Cell):
self.g = msd.Geometric(dtype=dtype.int32)
def construct(self, value, probs):
prob = self.g('prob', value, probs)
log_prob = self.g('log_prob', value, probs)
cdf = self.g('cdf', value, probs)
log_cdf = self.g('log_cdf', value, probs)
sf = self.g('survival_function', value, probs)
log_sf = self.g('log_survival', value, probs)
prob = self.g.prob(value, probs)
log_prob = self.g.log_prob(value, probs)
cdf = self.g.cdf(value, probs)
log_cdf = self.g.log_cdf(value, probs)
sf = self.g.survival_function(value, probs)
log_sf = self.g.log_survival(value, probs)
return prob + log_prob + cdf + log_cdf + sf + log_sf
def test_geometric_prob1():
@ -105,8 +105,8 @@ class GeometricKl(nn.Cell):
self.g2 = msd.Geometric(dtype=dtype.int32)
def construct(self, probs_b, probs_a):
kl1 = self.g1('kl_loss', 'Geometric', probs_b)
kl2 = self.g2('kl_loss', 'Geometric', probs_b, probs_a)
kl1 = self.g1.kl_loss('Geometric', probs_b)
kl2 = self.g2.kl_loss('Geometric', probs_b, probs_a)
return kl1 + kl2
def test_kl():
@ -129,8 +129,8 @@ class GeometricCrossEntropy(nn.Cell):
self.g2 = msd.Geometric(dtype=dtype.int32)
def construct(self, probs_b, probs_a):
h1 = self.g1('cross_entropy', 'Geometric', probs_b)
h2 = self.g2('cross_entropy', 'Geometric', probs_b, probs_a)
h1 = self.g1.cross_entropy('Geometric', probs_b)
h2 = self.g2.cross_entropy('Geometric', probs_b, probs_a)
return h1 + h2
def test_cross_entropy():
@ -152,11 +152,11 @@ class GeometricBasics(nn.Cell):
self.g = msd.Geometric([0.3, 0.5], dtype=dtype.int32)
def construct(self):
mean = self.g('mean')
sd = self.g('sd')
var = self.g('var')
mode = self.g('mode')
entropy = self.g('entropy')
mean = self.g.mean()
sd = self.g.sd()
var = self.g.var()
mode = self.g.mode()
entropy = self.g.entropy()
return mean + sd + var + mode + entropy
def test_bascis():
@ -166,3 +166,29 @@ def test_bascis():
net = GeometricBasics()
ans = net()
assert isinstance(ans, Tensor)
class GeoConstruct(nn.Cell):
"""
Geometric distribution: going through construct.
"""
def __init__(self):
super(GeoConstruct, self).__init__()
self.g = msd.Geometric(0.5, dtype=dtype.int32)
self.g1 = msd.Geometric(dtype=dtype.int32)
def construct(self, value, probs):
prob = self.g('prob', value)
prob1 = self.g('prob', value, probs)
prob2 = self.g1('prob', value, probs)
return prob + prob1 + prob2
def test_geo_construct():
"""
Test probability function going through construct.
"""
net = GeoConstruct()
value = Tensor([0, 0, 0, 0, 0], dtype=dtype.float32)
probs = Tensor([0.5], dtype=dtype.float32)
ans = net(value, probs)
assert isinstance(ans, Tensor)

@ -50,12 +50,12 @@ class NormalProb(nn.Cell):
self.normal = msd.Normal(3.0, 4.0, dtype=dtype.float32)
def construct(self, value):
prob = self.normal('prob', value)
log_prob = self.normal('log_prob', value)
cdf = self.normal('cdf', value)
log_cdf = self.normal('log_cdf', value)
sf = self.normal('survival_function', value)
log_sf = self.normal('log_survival', value)
prob = self.normal.prob(value)
log_prob = self.normal.log_prob(value)
cdf = self.normal.cdf(value)
log_cdf = self.normal.log_cdf(value)
sf = self.normal.survival_function(value)
log_sf = self.normal.log_survival(value)
return prob + log_prob + cdf + log_cdf + sf + log_sf
def test_normal_prob():
@ -77,12 +77,12 @@ class NormalProb1(nn.Cell):
self.normal = msd.Normal()
def construct(self, value, mean, sd):
prob = self.normal('prob', value, mean, sd)
log_prob = self.normal('log_prob', value, mean, sd)
cdf = self.normal('cdf', value, mean, sd)
log_cdf = self.normal('log_cdf', value, mean, sd)
sf = self.normal('survival_function', value, mean, sd)
log_sf = self.normal('log_survival', value, mean, sd)
prob = self.normal.prob(value, mean, sd)
log_prob = self.normal.log_prob(value, mean, sd)
cdf = self.normal.cdf(value, mean, sd)
log_cdf = self.normal.log_cdf(value, mean, sd)
sf = self.normal.survival_function(value, mean, sd)
log_sf = self.normal.log_survival(value, mean, sd)
return prob + log_prob + cdf + log_cdf + sf + log_sf
def test_normal_prob1():
@ -106,8 +106,8 @@ class NormalKl(nn.Cell):
self.n2 = msd.Normal(dtype=dtype.float32)
def construct(self, mean_b, sd_b, mean_a, sd_a):
kl1 = self.n1('kl_loss', 'Normal', mean_b, sd_b)
kl2 = self.n2('kl_loss', 'Normal', mean_b, sd_b, mean_a, sd_a)
kl1 = self.n1.kl_loss('Normal', mean_b, sd_b)
kl2 = self.n2.kl_loss('Normal', mean_b, sd_b, mean_a, sd_a)
return kl1 + kl2
def test_kl():
@ -132,8 +132,8 @@ class NormalCrossEntropy(nn.Cell):
self.n2 = msd.Normal(dtype=dtype.float32)
def construct(self, mean_b, sd_b, mean_a, sd_a):
h1 = self.n1('cross_entropy', 'Normal', mean_b, sd_b)
h2 = self.n2('cross_entropy', 'Normal', mean_b, sd_b, mean_a, sd_a)
h1 = self.n1.cross_entropy('Normal', mean_b, sd_b)
h2 = self.n2.cross_entropy('Normal', mean_b, sd_b, mean_a, sd_a)
return h1 + h2
def test_cross_entropy():
@ -157,10 +157,10 @@ class NormalBasics(nn.Cell):
self.n = msd.Normal(3.0, 4.0, dtype=dtype.float32)
def construct(self):
mean = self.n('mean')
sd = self.n('sd')
mode = self.n('mode')
entropy = self.n('entropy')
mean = self.n.mean()
sd = self.n.sd()
mode = self.n.mode()
entropy = self.n.entropy()
return mean + sd + mode + entropy
def test_bascis():
@ -170,3 +170,30 @@ def test_bascis():
net = NormalBasics()
ans = net()
assert isinstance(ans, Tensor)
class NormalConstruct(nn.Cell):
"""
Normal distribution: going through construct.
"""
def __init__(self):
super(NormalConstruct, self).__init__()
self.normal = msd.Normal(3.0, 4.0)
self.normal1 = msd.Normal()
def construct(self, value, mean, sd):
prob = self.normal('prob', value)
prob1 = self.normal('prob', value, mean, sd)
prob2 = self.normal1('prob', value, mean, sd)
return prob + prob1 + prob2
def test_normal_construct():
"""
Test probability function going through construct.
"""
net = NormalConstruct()
value = Tensor([0.5, 1.0], dtype=dtype.float32)
mean = Tensor([0.0], dtype=dtype.float32)
sd = Tensor([1.0], dtype=dtype.float32)
ans = net(value, mean, sd)
assert isinstance(ans, Tensor)

@ -60,12 +60,12 @@ class UniformProb(nn.Cell):
self.u = msd.Uniform(3.0, 4.0, dtype=dtype.float32)
def construct(self, value):
prob = self.u('prob', value)
log_prob = self.u('log_prob', value)
cdf = self.u('cdf', value)
log_cdf = self.u('log_cdf', value)
sf = self.u('survival_function', value)
log_sf = self.u('log_survival', value)
prob = self.u.prob(value)
log_prob = self.u.log_prob(value)
cdf = self.u.cdf(value)
log_cdf = self.u.log_cdf(value)
sf = self.u.survival_function(value)
log_sf = self.u.log_survival(value)
return prob + log_prob + cdf + log_cdf + sf + log_sf
def test_uniform_prob():
@ -86,12 +86,12 @@ class UniformProb1(nn.Cell):
self.u = msd.Uniform(dtype=dtype.float32)
def construct(self, value, low, high):
prob = self.u('prob', value, low, high)
log_prob = self.u('log_prob', value, low, high)
cdf = self.u('cdf', value, low, high)
log_cdf = self.u('log_cdf', value, low, high)
sf = self.u('survival_function', value, low, high)
log_sf = self.u('log_survival', value, low, high)
prob = self.u.prob(value, low, high)
log_prob = self.u.log_prob(value, low, high)
cdf = self.u.cdf(value, low, high)
log_cdf = self.u.log_cdf(value, low, high)
sf = self.u.survival_function(value, low, high)
log_sf = self.u.log_survival(value, low, high)
return prob + log_prob + cdf + log_cdf + sf + log_sf
def test_uniform_prob1():
@ -115,8 +115,8 @@ class UniformKl(nn.Cell):
self.u2 = msd.Uniform(dtype=dtype.float32)
def construct(self, low_b, high_b, low_a, high_a):
kl1 = self.u1('kl_loss', 'Uniform', low_b, high_b)
kl2 = self.u2('kl_loss', 'Uniform', low_b, high_b, low_a, high_a)
kl1 = self.u1.kl_loss('Uniform', low_b, high_b)
kl2 = self.u2.kl_loss('Uniform', low_b, high_b, low_a, high_a)
return kl1 + kl2
def test_kl():
@ -141,8 +141,8 @@ class UniformCrossEntropy(nn.Cell):
self.u2 = msd.Uniform(dtype=dtype.float32)
def construct(self, low_b, high_b, low_a, high_a):
h1 = self.u1('cross_entropy', 'Uniform', low_b, high_b)
h2 = self.u2('cross_entropy', 'Uniform', low_b, high_b, low_a, high_a)
h1 = self.u1.cross_entropy('Uniform', low_b, high_b)
h2 = self.u2.cross_entropy('Uniform', low_b, high_b, low_a, high_a)
return h1 + h2
def test_cross_entropy():
@ -166,10 +166,10 @@ class UniformBasics(nn.Cell):
self.u = msd.Uniform(3.0, 4.0, dtype=dtype.float32)
def construct(self):
mean = self.u('mean')
sd = self.u('sd')
var = self.u('var')
entropy = self.u('entropy')
mean = self.u.mean()
sd = self.u.sd()
var = self.u.var()
entropy = self.u.entropy()
return mean + sd + var + entropy
def test_bascis():
@ -179,3 +179,30 @@ def test_bascis():
net = UniformBasics()
ans = net()
assert isinstance(ans, Tensor)
class UniConstruct(nn.Cell):
"""
Uniform distribution: going through construct.
"""
def __init__(self):
super(UniConstruct, self).__init__()
self.u = msd.Uniform(-4.0, 4.0)
self.u1 = msd.Uniform()
def construct(self, value, low, high):
prob = self.u('prob', value)
prob1 = self.u('prob', value, low, high)
prob2 = self.u1('prob', value, low, high)
return prob + prob1 + prob2
def test_uniform_construct():
"""
Test probability function going through construct.
"""
net = UniConstruct()
value = Tensor([-5.0, 0.0, 1.0, 5.0], dtype=dtype.float32)
low = Tensor([-1.0], dtype=dtype.float32)
high = Tensor([1.0], dtype=dtype.float32)
ans = net(value, low, high)
assert isinstance(ans, Tensor)
