add sample functions in normal and bermoulli distributions

pull/2605/head
peixu_ren 5 years ago committed by Xun Deng
parent 0aa26c1815
commit bef1fc7f19

@ -15,9 +15,9 @@
# ============================================================================
"""Utitly functions to help distribution class."""
import numpy as np
from mindspore.ops import operations as P
from mindspore.ops import _utils as utils
from ....common.tensor import Tensor
from ....common.tensor import Tensor, MetaTensor
from ....common.parameter import Parameter
from ....common import dtype as mstype
@ -33,15 +33,17 @@ def cast_to_tensor(t, dtype=mstype.float32):
Cast an user input value into a Tensor of dtype.
Args:
t (int/float/list/numpy.ndarray/Tensor).
dtype (mindspore.dtype).
t (int, float, list, numpy.ndarray, Tensor, Parameter): object to be cast to Tensor.
dtype (mindspore.dtype): dtype of the Tensor. Default: mstype.float32.
Raises:
RuntimeError: if t cannot be cast to Tensor.
Outputs:
Returns:
Tensor.
"""
if isinstance(t, Parameter):
return t
if isinstance(t, Tensor):
#check if the Tensor in shape of Tensor(4)
if t.dim() == 0:
@ -61,9 +63,9 @@ def calc_batch_size(batch_shape):
Calculate the size of a given batch_shape.
Args:
batch_shape (tuple)
batch_shape (tuple): batch shape to be calculated.
Outputs:
Returns:
int.
"""
return int(np.prod(batch_shape))
@ -73,23 +75,26 @@ def convert_to_batch(t, batch_shape, dtype):
Convert a Tensor to a given batch shape.
Args:
t (Tensor)
batch_shape (tuple)
dtype (mindspore.dtype)
t (Tensor, Parameter): Tensor to be converted.
batch_shape (tuple): desired batch shape.
dtype (mindspore.dtype): desired dtype.
Raises:
RuntimeError: if the converison cannot be done.
Outputs:
Returns:
Tensor, with shape of batch_shape.
"""
if isinstance(t, Parameter):
return t
t = cast_to_tensor(t, dtype)
reshape = P.Reshape()
if t.shape != batch_shape:
mul = calc_batch_size(batch_shape) // t.size()
if (calc_batch_size(batch_shape) % t.size()) != 0:
raise RuntimeError("Cannot cast the tensor to the given batch shape.")
temp = list(t.asnumpy()) * mul
return reshape(Tensor(temp), batch_shape)
temp = np.reshape(temp, batch_shape)
return Tensor(temp, dtype)
return t
def check_scalar_from_param(params):
@ -97,7 +102,7 @@ def check_scalar_from_param(params):
Check if params are all scalars.
Args:
params (dict): parameters used to initialized distribution.
params (dict): parameters used to initialize distribution.
Notes: String parameters are excluded.
"""
@ -116,9 +121,9 @@ def calc_broadcast_shape_from_param(params):
Calculate the broadcast shape from params.
Args:
params (dict): parameters used to initialized distribution.
params (dict): parameters used to initialize distribution.
Outputs:
Returns:
tuple.
"""
broadcast_shape = []
@ -127,7 +132,10 @@ def calc_broadcast_shape_from_param(params):
continue
if value is None:
return None
value_t = cast_to_tensor(value, params['dtype'])
if isinstance(value, Parameter):
value_t = value.default_input
else:
value_t = cast_to_tensor(value, params['dtype'])
broadcast_shape = utils.get_broadcast_shape(broadcast_shape, list(value_t.shape), params['name'])
return tuple(broadcast_shape)
@ -136,36 +144,37 @@ def check_greater_equal_zero(value, name):
Check if the given Tensor is greater zero.
Args:
value (Tensor)
value (Tensor, Parameter): value to be checked.
name (str) : name of the value.
Raises:
ValueError: if the input value is less than zero.
"""
less = P.Less()
zeros = Tensor([0.0], dtype=value.dtype)
value = less(value, zeros)
if value.asnumpy().any():
raise ValueError('{} should be greater than zero.'.format(name))
if isinstance(value, Parameter):
if isinstance(value.default_input, MetaTensor):
return
value = value.default_input
comp = np.less(value.asnumpy(), np.zeros(value.shape))
if comp.any():
raise ValueError(f'{name} should be greater than zero.')
def check_greater(a, b, name_a, name_b):
"""
Check if Tensor b is strictly greater than Tensor a.
Args:
a (Tensor)
b (Tensor)
a (Tensor): input tensor a.
b (Tensor): input tensor b.
name_a (str): name of Tensor_a.
name_b (str): name of Tensor_b.
Raises:
ValueError: if b is less than or equal to a
"""
less = P.Less()
value = less(a, b)
if not value.asnumpy().all():
raise ValueError('{} should be less than {}'.format(name_a, name_b))
comp = np.less(a.asnumpy(), b.asnumpy())
if not comp.all():
raise ValueError(f'{name_a} should be less than {name_b}')
def check_prob(p):
@ -173,18 +182,18 @@ def check_prob(p):
Check if p is a proper probability, i.e. 0 <= p <=1.
Args:
p (Tensor): value to check.
p (Tensor, Parameter): value to be checked.
Raises:
ValueError: if p is not a proper probability.
"""
less = P.Less()
greater = P.Greater()
zeros = Tensor([0.0], dtype=p.dtype)
ones = Tensor([1.0], dtype=p.dtype)
comp = less(p, zeros)
if comp.asnumpy().any():
if isinstance(p, Parameter):
if isinstance(p.default_input, MetaTensor):
return
p = p.default_input
comp = np.less(p.asnumpy(), np.zeros(p.shape))
if comp.any():
raise ValueError('Probabilities should be greater than or equal to zero')
comp = greater(p, ones)
if comp.asnumpy().any():
comp = np.greater(p.asnumpy(), np.ones(p.shape))
if comp.any():
raise ValueError('Probabilities should be less than or equal to one')

@ -23,21 +23,24 @@ class Bernoulli(Distribution):
Example class: Bernoulli Distribution.
Args:
probs (int/float/list/numpy.ndarray/Tensor): probability of 1 as outcome.
dtype (mindspore.dtype): type of the distribution, default to int32.
probs (int, float, list, numpy.ndarray, Tensor, Parameter): probability of 1 as outcome.
seed (int): seed to use in sampling. Default: 0.
dtype (mindspore.dtype): type of the distribution. Default: mstype.int32.
name (str): name of the distribution. Default: Bernoulli.
Note:
probs should be proper probabilities (0 <= p <= 1).
Examples:
>>> # To initialize a Bernoulli distribution which has equal probability of getting 1 and 0
>>> b = nn.Bernoulli(0.5, dtype = dtype.int32)
>>> b = nn.Bernoulli(0.5, dtype = mstype.int32)
>>> # The following create two independent Bernoulli distributions
>>> b = nn.Bernoulli([0.7, 0.2], dtype = dtype.int32)
>>> b = nn.Bernoulli([0.7, 0.2], dtype = mstype.int32)
"""
def __init__(self,
probs=None,
seed=0,
dtype=mstype.int32,
name="Bernoulli"):
"""
@ -47,7 +50,6 @@ class Bernoulli(Distribution):
super(Bernoulli, self).__init__(dtype, name, param)
if probs is not None:
self._probs = cast_to_tensor(probs)
# check if the input probability is valid
check_prob(self._probs)
else:
self._probs = probs
@ -58,7 +60,17 @@ class Bernoulli(Distribution):
self.mul = P.Mul()
self.sqrt = P.Sqrt()
self.realdiv = P.RealDiv()
self.shape = P.Shape()
self.const = P.ScalarToArray()
self.less = P.Less()
self.cast = P.Cast()
self.normal = P.Normal(seed=seed)
self.erf = P.Erf()
self.sqrt = P.Sqrt()
def extend_repr(self):
str_info = f'probs = {self._probs}'
return str_info
def probs(self):
"""
@ -66,21 +78,25 @@ class Bernoulli(Distribution):
"""
return self._probs
def _mean(self):
def _mean(self, name='mean', probs1=None):
r"""
.. math::
MEAN(B) = probs1
"""
if name == 'mean':
return self._probs if probs1 is None else probs1
return None
return self._probs
def _var(self):
def _var(self, name='var', probs1=None):
r"""
.. math::
VAR(B) = probs1 * probs0
"""
probs0 = self.add(1, -1 * self._probs)
return self.mul(probs0, self._probs)
if name in ('sd', 'var'):
probs1 = self._probs if probs1 is None else probs1
probs0 = self.add(1, -1 * probs1)
return self.mul(probs0, probs1)
return None
def _prob(self, name, value, probs=None):
r"""
@ -89,18 +105,20 @@ class Bernoulli(Distribution):
Args:
name (str): name of the function. Should be "prob" when passed in from construct.
value (Tensor): a Tensor composed of only zeros and ones.
probs (Tensor): probability of outcome is 1. Default to self._probs.
probs (Tensor): probability of outcome is 1. Default: self._probs.
.. math::
pmf(k) = probs1 if k = 1;
pmf(k) = probs0 if k = 0;
"""
probs1 = self._probs if probs is None else probs
probs0 = self.add(1, -1 * probs1)
return self.add(self.mul(probs1, value),
self.mul(probs0, self.add(1, -1 * value)))
if name in ('prob', 'log_prob'):
probs1 = self._probs if probs is None else probs
probs0 = self.add(1, -1 * probs1)
return self.add(self.mul(probs1, value),
self.mul(probs0, self.add(1, -1 * value)))
return None
def _kl_loss(self, name, dist, probs1_b):
def _kl_loss(self, name, dist, probs1_b, probs1_a=None):
r"""
Evaluate bernoulli-bernoulli kl divergence, i.e. KL(a||b).
@ -108,19 +126,42 @@ class Bernoulli(Distribution):
name (str): name of the funtion. Should always be "kl_loss" when passed in from construct.
dist (str): type of the distributions. Should be "Bernoulli" in this case.
probs1_b (Tensor): probs1 of distribution b.
probs1_a (Tensor): probs1 of distribution a. Default: self._probs.
.. math::
KL(a||b) = probs1_a * \log(\fract{probs1_a}{probs1_b}) +
probs0_a * \log(\fract{probs0_a}{probs0_b})
"""
if dist == 'Bernoulli':
probs1_a = self._probs
if name == 'kl_loss' and dist == 'Bernoulli':
probs1_a = self._probs if probs1_a is None else probs1_a
probs0_a = self.add(1, -1 * probs1_a)
probs0_b = self.add(1, -1 * probs1_b)
return self.add(probs1_a * self.log(self.realdiv(probs1_a, probs1_b)),
probs0_a * self.log(self.realdiv(probs0_a, probs0_b)))
return None
def extend_repr(self):
str_info = 'probs={}'.format(self._probs)
return str_info
def _sample(self, name, shape=(), probs=None):
"""
Sampling.
Args:
name (str): name of the function. Should always be 'sample' when passed in from construct.
shape (tuple): shape of the sample. Default: ().
probs (Tensor): probs1 of the samples. Default: self._probs.
Returns:
Tensor, shape is shape + batch_shape.
"""
if name == 'sample':
probs1 = self._probs if probs is None else probs
batch_shape = self.shape(probs1)
sample_shape = shape + batch_shape
mean_zero = self.const(0.0)
sd_one = self.const(1.0)
sqrt_two = self.sqrt(self.const(2.0))
sample_norm = self.normal(sample_shape, mean_zero, sd_one)
sample_uniform = 0.5 * (1 + self.erf(self.realdiv(sample_norm, sqrt_two)))
sample = self.less(sample_uniform, probs1)
sample = self.cast(sample, self._dtype)
return sample
return None

@ -21,6 +21,11 @@ class Distribution(Cell):
"""
Base class for all mathematical distributions.
Args:
dtype (mindspore.dtype): type of the distribution.
name (str): name of the distribution.
param (dict): parameters used to initialize the distribution.
Note:
Derived class should override operations such as ,_mean, _prob,
and _log_prob. Functions should be called through construct when
@ -97,14 +102,8 @@ class Distribution(Cell):
Note:
value is casted to Tensor for further calculation.
Args:
name (str): name of the calling function.
value (Tensor): values to be evaluated.
mean (Tensor): mean of the distirbution. Default: self.mean.
sd (Tensor): standard deviation of the distribution. Default: self.sd.
Outputs:
Tensor, shape: broadcast_shape of the distribution.
Returns:
Tensor, shape is the broadcast_shape of the distribution.
"""
return self._call_log_prob(*args)
@ -114,36 +113,9 @@ class Distribution(Cell):
.. math::
probability(x) = \exp(log_likehood(x))
Args:
name (str): name of the calling function.
value (Tensor): values to be evaluated.
mean (Tensor): mean of the distribution. Default: self.mean.
sd (Tensor): standard deviation of the distritbuion. Default: self.sd.
"""
return self.exp(self._log_likelihood(*args))
def _call_prob(self, *args):
"""
Raises:
NotImplementedError when derived class didn't override _prob or _log_likelihood.
"""
raise NotImplementedError('pdf/pmf is not implemented: {}'.format(type(self).__name__))
def _call_log_prob(self, *args):
"""
Raises:
NotImplementedError when derived class didn't override _prob or _log_likelihood.
"""
raise NotImplementedError('log_probability is not implemented: {}'.format(type(self).__name__))
def _call_sd(self):
"""
Raises:
NotImplementedError when derived class didn't override _sd or _var.
"""
raise NotImplementedError('standard deviation is not implemented: {}'.format(type(self).__name__))
def prob(self, *args):
"""
Evaluate the prob (pdf or pmf) at given value.
@ -151,14 +123,8 @@ class Distribution(Cell):
Note:
value is casted to Tensor for further calculation.
Args:
name (str): name of the calling function.
value (Tensor): values to be evaluated.
mean (Tensor): mean of the distribution.
sd (Tensor): standard deviation of the distritbuion.
Outputs:
Tensor, shape: broadcast_shape of the distribution.
Returns:
Tensor, shape is the broadcast_shape of the distribution.
"""
return self._call_prob(*args)
@ -176,8 +142,8 @@ class Distribution(Cell):
Evaluate the KL divergence. Parameters of the second distribution should be
passed in through **kwargs.
Outputs:
Tensor, shape: broadcast_shape of the distribution and input distribution.
Returns:
Tensor, shape is the broadcast_shape of the distribution and input distribution.
"""
return self._kl_loss(**kwargs)
@ -185,8 +151,8 @@ class Distribution(Cell):
"""
Evaluate the mean.
Outputs:
Tensor, shape: broadcast_shape of the distribution.
Returns:
Tensor, shape is the broadcast_shape of the distribution.
"""
return self._mean(**kwargs)
@ -194,19 +160,19 @@ class Distribution(Cell):
"""
Evaluate the standard deviation.
Outputs:
Tensor, with shape of broadcast_shape of the distribution.
Returns:
Tensor, shape is the broadcast_shape of the distribution.
"""
return self._call_sd(**kwargs)
def _calc_sd_from_var(self, **kwargs):
def _calc_sd_from_var(self, *args):
r"""
Evaluate log probability from probability.
.. math::
STD(x) = \sqrt(VAR(x))
"""
return self.sqrt(self._var(**kwargs))
return self.sqrt(self._var(*args))
def construct(self, *inputs):
"""
@ -226,7 +192,9 @@ class Distribution(Cell):
if inputs[0] == 'kl_loss':
return self._kl_loss(*inputs)
if inputs[0] == 'mean':
return self._mean()
return self._mean(*inputs)
if inputs[0] == 'sd':
return self._call_sd()
return self._call_sd(*inputs)
if inputs[0] == 'sample':
return self._sample(*inputs)
return None

@ -25,23 +25,27 @@ class Normal(Distribution):
Example class: Normal distribution.
Args:
mean (int/float/list/numpy.ndarray/Tensor): mean of the Gaussian distribution
standard deviation (int/float/list/numpy.ndarray/Tensor): vairance of the Gaussian distribution
dtype (mindspore.dtype): type of the distribution
mean (int, float, list, numpy.ndarray, Tensor, Parameter): mean of the Gaussian distribution.
sd (int, float, list, numpy.ndarray, Tensor, Parameter): stddev of the Gaussian distribution.
seed (int): seed to use in sampling. Default: 0.
dtype (mindspore.dtype): type of the distribution. Default: mstype.float32.
name (str): name of the distribution. Default: Normal.
Note:
Standard deviation should be greater than zero.
Examples:
>>> # To initialize a normal distribution of mean 3.0 and standard deviation 4.0
>>> n = nn.Normal(3.0, 4.0, dtype=dtype.float32)
>>> n = nn.Normal(3.0, 4.0, dtype=mstype.float32)
>>> # The following create two independent normal distributions
>>> n = nn.Normal([3.0, 3.0], [4.0, 4.0], dtype=dtype.float32)
>>> n = nn.Normal([3.0, 3.0], [4.0, 4.0], dtype=mstype.float32)
"""
def __init__(self,
mean=None,
sd=None,
seed=0,
dtype=mstype.float32,
name="Normal"):
"""
@ -52,7 +56,6 @@ class Normal(Distribution):
if mean is not None and sd is not None:
self._mean_value = convert_to_batch(mean, self._broadcast_shape, dtype)
self._sd_value = convert_to_batch(sd, self._broadcast_shape, dtype)
#check validity of standard deviation
check_greater_equal_zero(self._sd_value, "Standard deviation")
else:
self._mean_value = mean
@ -61,11 +64,20 @@ class Normal(Distribution):
#ops needed for the class
self.exp = P.Exp()
self.add = P.TensorAdd()
self.mul = P.Mul()
self.sq = P.Square()
self.log = P.Log()
self.sqrt = P.Sqrt()
self.realdiv = P.RealDiv()
self.expm1 = P.Expm1() if get_context('device_target') == 'Ascend' else self._expm1_by_step
self.normal = P.Normal(seed=seed)
self.shape = P.Shape()
self.zeroslike = P.ZerosLike()
self.const = P.ScalarToArray()
def extend_repr(self):
str_info = f'mean = {self._mean_value}, standard deviation = {self._sd_value}'
return str_info
def _expm1_by_step(self, x):
"""
@ -73,17 +85,23 @@ class Normal(Distribution):
"""
return self.add(self.exp(x), -1)
def _mean(self):
def _mean(self, name='mean', mean=None, sd=None):
"""
Mean of the distribution.
"""
return self._mean_value
if name == 'mean':
mean = self._mean_value if mean is None or sd is None else mean
return mean
return None
def _sd(self):
def _sd(self, name='sd', mean=None, sd=None):
"""
Standard deviation of the distribution.
"""
return self._sd_value
if name in ('sd', 'var'):
sd = self._sd_value if mean is None or sd is None else sd
return sd
return None
def _log_likelihood(self, name, value, mean=None, sd=None):
r"""
@ -92,33 +110,60 @@ class Normal(Distribution):
.. math::
L(x) = -1* \fract{(x - \mu)^2}{2. * \sigma^2} - \log(\sqrt(2* \pi * \sigma^2))
"""
mean = self._mean_value if mean is None else mean
sd = self._sd_value if sd is None else sd
unnormalized_log_prob = -1. * self.realdiv(self.sq(self.add(value, -1. * mean)),
2. * self.sq(sd))
neg_normalization = -1. * self.log(self.sqrt(2. * np.pi * self.sq(sd)))
return self.add(unnormalized_log_prob, neg_normalization)
def _kl_loss(self, name, dist, mean, sd):
if name in ('prob', 'log_prob'):
mean = self._mean_value if mean is None else mean
sd = self._sd_value if sd is None else sd
unnormalized_log_prob = -1. * self.realdiv(self.sq(self.add(value, -1. * mean)),
2. * self.sq(sd))
neg_normalization = -1. * self.log(self.sqrt(2. * np.pi * self.sq(sd)))
return self.add(unnormalized_log_prob, neg_normalization)
return None
def _kl_loss(self, name, dist, mean_b, sd_b, mean_a=None, sd_a=None):
r"""
Evaluate Normal-Normal kl divergence, i.e. KL(a||b).
Args:
name (str): name of the funtion passed in from construct. Should always be "kl_loss".
dist (str): type of the distributions. Should be "Normal" in this case.
mean (Tensor): mean of distribution b.
sd (Tensor): standard deviation distribution b.
mean_b (Tensor): mean of distribution b.
sd_b (Tensor): standard deviation distribution b.
mean_a (Tensor): mean of distribution a. Default: self._mean_value.
sd_a (Tensor): standard deviation distribution a. Default: self._sd_value.
.. math::
KL(a||b) = 0.5 * (\fract{MEAN(a)}{STD(b)} - \fract{MEAN(b)}{STD(b)}) ^ 2 +
0.5 * EXPM1(2 * (\log(STD(a)) - \log(STD(b))) - (\log(STD(a)) - \log(STD(b)))
"""
if dist == 'Normal':
diff_log_scale = self.add(self.log(self._sd_value), - self.log(sd))
squared_diff = self.sq(self.add(self.realdiv(self._mean_value, sd), - self.realdiv(mean, sd)))
if name == 'kl_loss' and dist == 'Normal':
mean_a = self._mean_value if mean_a is None else mean_a
sd_a = self._sd_value if sd_a is None else sd_a
diff_log_scale = self.add(self.log(sd_a), - self.log(sd_b))
squared_diff = self.sq(self.add(self.realdiv(mean_a, sd_b), - self.realdiv(mean_b, sd_b)))
return self.add(self.add(0.5 * squared_diff, 0.5 * self.expm1(2 * diff_log_scale)), - diff_log_scale)
return None
def extend_repr(self):
str_info = 'mean={}, standard deviation={}'.format(self._mean_value, self._sd_value)
return str_info
def _sample(self, name, shape=(), mean=None, sd=None):
"""
Sampling.
Args:
name (str): name of the function. Should always be 'sample' when passed in from construct.
shape (tuple): shape of the sample. Default: ().
mean (Tensor): mean of the samples. Default: self._mean_value.
sd (Tensor): standard deviation of the samples. Default: self._sd_value.
Returns:
Tensor, shape is shape + batch_shape.
"""
if name == 'sample':
mean = self._mean_value if mean is None else mean
sd = self._sd_value if sd is None else sd
batch_shape = self.shape(self.add(self.zeroslike(mean), self.zeroslike(sd)))
sample_shape = shape + batch_shape
mean_zero = self.const(0.0)
sd_one = self.const(1.0)
sample_norm = self.normal(sample_shape, mean_zero, sd_one)
sample = self.add(mean, self.mul(sample_norm, sd))
return sample
return None

@ -65,12 +65,25 @@ class Net3(nn.Cell):
"""
def __init__(self):
super(Net3, self).__init__()
self.b = nn.Bernoulli([0.7, 0.5], dtype=dtype.int32)
self.b = nn.Bernoulli([0.5, 0.5], dtype=dtype.int32)
@ms_function
def construct(self):
return self.b('mean'), self.b('sd')
class Net4(nn.Cell):
"""
Test class: log probability of bernoulli distribution.
"""
def __init__(self, shape, seed=0):
super(Net4, self).__init__()
self.b = nn.Bernoulli([0.7, 0.5], seed=seed, dtype=dtype.int32)
self.shape = shape
@ms_function
def construct(self, probs=None):
return self.b('sample', self.shape, probs)
def test_pmf():
"""
Test pmf.
@ -80,10 +93,8 @@ def test_pmf():
pdf = Net()
x_ = Tensor(np.array([0, 1, 0, 1, 1]).astype(np.int32), dtype=dtype.float32)
output = pdf(x_)
print("expected_pmf: ", expect_pmf)
print("ans: ", output.asnumpy())
tol = 1e-6
assert (output.asnumpy() - expect_pmf < tol).all()
assert (np.abs(output.asnumpy() - expect_pmf) < tol).all()
def test_log_likelihood():
"""
@ -94,10 +105,8 @@ def test_log_likelihood():
logprob = Net1()
x_ = Tensor(np.array([0, 1, 0, 1, 1]).astype(np.int32), dtype=dtype.float32)
output = logprob(x_)
print("expected_log_probability: ", expect_logpmf)
print("ans: ", output.asnumpy())
tol = 1e-6
assert (output.asnumpy() - expect_logpmf < tol).all()
assert (np.abs(output.asnumpy() - expect_logpmf) < tol).all()
def test_kl_loss():
"""
@ -110,10 +119,8 @@ def test_kl_loss():
expect_kl_loss = probs1_a * np.log(probs1_a / probs1_b) + probs0_a * np.log(probs0_a / probs0_b)
kl_loss = Net2()
output = kl_loss(Tensor([probs1_b], dtype=dtype.float32))
print("expected_kl_loss: ", expect_kl_loss)
print("ans: ", output.asnumpy())
tol = 1e-6
assert (output.asnumpy() - expect_kl_loss < tol).all()
assert (np.abs(output.asnumpy() - expect_kl_loss) < tol).all()
def test_basics():
"""
@ -121,8 +128,20 @@ def test_basics():
"""
basics = Net3()
mean, sd = basics()
print("mean : ", mean)
print("sd : ", sd)
expect_mean = [0.5, 0.5]
assert (mean.asnumpy() == expect_mean).all()
assert (sd.asnumpy() == expect_mean).all()
b = nn.Bernoulli([0.7, 0.5], dtype=dtype.int32)
probs = b.probs()
print("probs is ", probs)
expect_probs = [0.7, 0.5]
tol = 1e-6
assert (np.abs(probs.asnumpy() - expect_probs) < tol).all()
def test_sample():
"""
Test sample.
"""
shape = (2, 3)
sample = Net4(shape)
output = sample()
assert output.shape == (2, 3, 2)

@ -65,12 +65,25 @@ class Net3(nn.Cell):
"""
def __init__(self):
super(Net3, self).__init__()
self.n = nn.Normal(np.array([3.0]), np.array([[2.0], [4.0]]), dtype=dtype.float32)
self.n = nn.Normal(np.array([3.0]), np.array([2.0, 4.0]), dtype=dtype.float32)
@ms_function
def construct(self):
return self.n('mean'), self.n('sd')
class Net4(nn.Cell):
"""
Test class: mean/sd of normal distribution.
"""
def __init__(self, shape, seed=0):
super(Net4, self).__init__()
self.n = nn.Normal(np.array([3.0]), np.array([[2.0], [4.0]]), seed=seed, dtype=dtype.float32)
self.shape = shape
@ms_function
def construct(self, mean=None, sd=None):
return self.n('sample', self.shape, mean, sd)
def test_pdf():
"""
Test pdf.
@ -79,10 +92,8 @@ def test_pdf():
expect_pdf = norm_benchmark.pdf([1.0, 2.0]).astype(np.float32)
pdf = Net()
output = pdf(Tensor([1.0, 2.0], dtype=dtype.float32))
print("expected_pdf: ", expect_pdf)
print("ans: ", output.asnumpy())
tol = 1e-6
assert (output.asnumpy() - expect_pdf < tol).all()
assert (np.abs(output.asnumpy() - expect_pdf) < tol).all()
def test_log_likelihood():
"""
@ -92,10 +103,8 @@ def test_log_likelihood():
expect_logpdf = norm_benchmark.logpdf([1.0, 2.0]).astype(np.float32)
logprob = Net1()
output = logprob(Tensor([1.0, 2.0], dtype=dtype.float32))
print("expected_log_probability: ", expect_logpdf)
print("ans: ", output.asnumpy())
tol = 1e-6
assert (output.asnumpy() - expect_logpdf < tol).all()
assert (np.abs(output.asnumpy() - expect_logpdf) < tol).all()
def test_kl_loss():
"""
@ -115,10 +124,8 @@ def test_kl_loss():
mean = Tensor(mean_b, dtype=dtype.float32)
sd = Tensor(sd_b, dtype=dtype.float32)
output = kl_loss(mean, sd)
print("expected_kl_loss: ", expect_kl_loss)
print("ans: ", output.asnumpy())
tol = 1e-6
assert (output.asnumpy() - expect_kl_loss < tol).all()
assert (np.abs(output.asnumpy() - expect_kl_loss) < tol).all()
def test_basics():
"""
@ -126,5 +133,20 @@ def test_basics():
"""
basics = Net3()
mean, sd = basics()
print("mean is ", mean)
print("sd is ", sd)
expect_mean = [3.0, 3.0]
expect_sd = [2.0, 4.0]
tol = 1e-6
assert (np.abs(mean.asnumpy() - expect_mean) < tol).all()
assert (np.abs(sd.asnumpy() - expect_sd) < tol).all()
def test_sample():
"""
Test sample.
"""
shape = (2, 3)
seed = 10
mean = Tensor([2.0], dtype=dtype.float32)
sd = Tensor([2.0, 2.0, 2.0], dtype=dtype.float32)
sample = Net4(shape, seed=seed)
output = sample(mean, sd)
assert output.shape == (2, 3, 3)

File diff suppressed because it is too large Load Diff
Loading…
Cancel
Save