parent
c1b9efe8e6
commit
877b561e77
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,249 @@
|
|||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""
|
||||||
|
Test nn.probability.distribution.Categorical.
|
||||||
|
"""
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
import mindspore.nn as nn
|
||||||
|
import mindspore.nn.probability.distribution as msd
|
||||||
|
from mindspore import dtype
|
||||||
|
from mindspore import Tensor
|
||||||
|
|
||||||
|
|
||||||
|
def test_arguments():
|
||||||
|
"""
|
||||||
|
Args passing during initialization.
|
||||||
|
"""
|
||||||
|
c = msd.Categorical()
|
||||||
|
assert isinstance(c, msd.Distribution)
|
||||||
|
c = msd.Categorical([0.1, 0.9], dtype=dtype.int32)
|
||||||
|
assert isinstance(c, msd.Distribution)
|
||||||
|
|
||||||
|
|
||||||
|
def test_type():
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
msd.Categorical([0.1], dtype=dtype.bool_)
|
||||||
|
|
||||||
|
|
||||||
|
def test_name():
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
msd.Categorical([0.1], name=1.0)
|
||||||
|
|
||||||
|
|
||||||
|
def test_seed():
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
msd.Categorical([0.1], seed='seed')
|
||||||
|
|
||||||
|
|
||||||
|
def test_prob():
|
||||||
|
"""
|
||||||
|
Invalid probability.
|
||||||
|
"""
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
msd.Categorical([-0.1], dtype=dtype.int32)
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
msd.Categorical([1.1], dtype=dtype.int32)
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
msd.Categorical([0.0], dtype=dtype.int32)
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
msd.Categorical([1.0], dtype=dtype.int32)
|
||||||
|
|
||||||
|
def test_categorical_sum():
|
||||||
|
"""
|
||||||
|
Invaild probabilities.
|
||||||
|
"""
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
msd.Categorical([[0.1, 0.2], [0.4, 0.6]], dtype=dtype.int32)
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
msd.Categorical([[0.5, 0.7], [0.6, 0.6]], dtype=dtype.int32)
|
||||||
|
|
||||||
|
def rank():
|
||||||
|
"""
|
||||||
|
Rank dimenshion less than 1.
|
||||||
|
"""
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
msd.Categorical(0.2, dtype=dtype.int32)
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
msd.Categorical(np.array(0.3).astype(np.float32), dtype=dtype.int32)
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
msd.Categorical(Tensor(np.array(0.3).astype(np.float32)), dtype=dtype.int32)
|
||||||
|
|
||||||
|
class CategoricalProb(nn.Cell):
|
||||||
|
"""
|
||||||
|
Categorical distribution: initialize with probs.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super(CategoricalProb, self).__init__()
|
||||||
|
self.c = msd.Categorical([0.7, 0.3], dtype=dtype.int32)
|
||||||
|
|
||||||
|
def construct(self, value):
|
||||||
|
prob = self.c.prob(value)
|
||||||
|
log_prob = self.c.log_prob(value)
|
||||||
|
cdf = self.c.cdf(value)
|
||||||
|
log_cdf = self.c.log_cdf(value)
|
||||||
|
sf = self.c.survival_function(value)
|
||||||
|
log_sf = self.c.log_survival(value)
|
||||||
|
return prob + log_prob + cdf + log_cdf + sf + log_sf
|
||||||
|
|
||||||
|
|
||||||
|
def test_categorical_prob():
|
||||||
|
"""
|
||||||
|
Test probability functions: passing value through construct.
|
||||||
|
"""
|
||||||
|
net = CategoricalProb()
|
||||||
|
value = Tensor([0, 1, 0, 1, 0], dtype=dtype.float32)
|
||||||
|
ans = net(value)
|
||||||
|
assert isinstance(ans, Tensor)
|
||||||
|
|
||||||
|
|
||||||
|
class CategoricalProb1(nn.Cell):
|
||||||
|
"""
|
||||||
|
Categorical distribution: initialize without probs.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super(CategoricalProb1, self).__init__()
|
||||||
|
self.c = msd.Categorical(dtype=dtype.int32)
|
||||||
|
|
||||||
|
def construct(self, value, probs):
|
||||||
|
prob = self.c.prob(value, probs)
|
||||||
|
log_prob = self.c.log_prob(value, probs)
|
||||||
|
cdf = self.c.cdf(value, probs)
|
||||||
|
log_cdf = self.c.log_cdf(value, probs)
|
||||||
|
sf = self.c.survival_function(value, probs)
|
||||||
|
log_sf = self.c.log_survival(value, probs)
|
||||||
|
return prob + log_prob + cdf + log_cdf + sf + log_sf
|
||||||
|
|
||||||
|
|
||||||
|
def test_categorical_prob1():
|
||||||
|
"""
|
||||||
|
Test probability functions: passing value/probs through construct.
|
||||||
|
"""
|
||||||
|
net = CategoricalProb1()
|
||||||
|
value = Tensor([0, 1, 0, 1, 0], dtype=dtype.float32)
|
||||||
|
probs = Tensor([0.3, 0.7], dtype=dtype.float32)
|
||||||
|
ans = net(value, probs)
|
||||||
|
assert isinstance(ans, Tensor)
|
||||||
|
|
||||||
|
|
||||||
|
class CategoricalKl(nn.Cell):
|
||||||
|
"""
|
||||||
|
Test class: kl_loss between Categorical distributions.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super(CategoricalKl, self).__init__()
|
||||||
|
self.c1 = msd.Categorical([0.2, 0.2, 0.6], dtype=dtype.int32)
|
||||||
|
self.c2 = msd.Categorical(dtype=dtype.int32)
|
||||||
|
|
||||||
|
def construct(self, probs_b, probs_a):
|
||||||
|
kl1 = self.c1.kl_loss('Categorical', probs_b)
|
||||||
|
kl2 = self.c2.kl_loss('Categorical', probs_b, probs_a)
|
||||||
|
return kl1 + kl2
|
||||||
|
|
||||||
|
|
||||||
|
def test_kl():
|
||||||
|
"""
|
||||||
|
Test kl_loss function.
|
||||||
|
"""
|
||||||
|
ber_net = CategoricalKl()
|
||||||
|
probs_b = Tensor([0.3, 0.1, 0.6], dtype=dtype.float32)
|
||||||
|
probs_a = Tensor([0.7, 0.2, 0.1], dtype=dtype.float32)
|
||||||
|
ans = ber_net(probs_b, probs_a)
|
||||||
|
assert isinstance(ans, Tensor)
|
||||||
|
|
||||||
|
|
||||||
|
class CategoricalCrossEntropy(nn.Cell):
|
||||||
|
"""
|
||||||
|
Test class: cross_entropy of Categorical distribution.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super(CategoricalCrossEntropy, self).__init__()
|
||||||
|
self.c1 = msd.Categorical([0.1, 0.7, 0.2], dtype=dtype.int32)
|
||||||
|
self.c2 = msd.Categorical(dtype=dtype.int32)
|
||||||
|
|
||||||
|
def construct(self, probs_b, probs_a):
|
||||||
|
h1 = self.c1.cross_entropy('Categorical', probs_b)
|
||||||
|
h2 = self.c2.cross_entropy('Categorical', probs_b, probs_a)
|
||||||
|
return h1 + h2
|
||||||
|
|
||||||
|
|
||||||
|
def test_cross_entropy():
|
||||||
|
"""
|
||||||
|
Test cross_entropy between Categorical distributions.
|
||||||
|
"""
|
||||||
|
net = CategoricalCrossEntropy()
|
||||||
|
probs_b = Tensor([0.3, 0.1, 0.6], dtype=dtype.float32)
|
||||||
|
probs_a = Tensor([0.7, 0.2, 0.1], dtype=dtype.float32)
|
||||||
|
ans = net(probs_b, probs_a)
|
||||||
|
assert isinstance(ans, Tensor)
|
||||||
|
|
||||||
|
|
||||||
|
class CategoricalConstruct(nn.Cell):
|
||||||
|
"""
|
||||||
|
Categorical distribution: going through construct.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super(CategoricalConstruct, self).__init__()
|
||||||
|
self.c = msd.Categorical([0.1, 0.8, 0.1], dtype=dtype.int32)
|
||||||
|
self.c1 = msd.Categorical(dtype=dtype.int32)
|
||||||
|
|
||||||
|
def construct(self, value, probs):
|
||||||
|
prob = self.c('prob', value)
|
||||||
|
prob1 = self.c('prob', value, probs)
|
||||||
|
prob2 = self.c1('prob', value, probs)
|
||||||
|
return prob + prob1 + prob2
|
||||||
|
|
||||||
|
def test_categorical_construct():
|
||||||
|
"""
|
||||||
|
Test probability function going through construct.
|
||||||
|
"""
|
||||||
|
net = CategoricalConstruct()
|
||||||
|
value = Tensor([0, 1, 2, 0, 0], dtype=dtype.float32)
|
||||||
|
probs = Tensor([0.5, 0.4, 0.1], dtype=dtype.float32)
|
||||||
|
ans = net(value, probs)
|
||||||
|
assert isinstance(ans, Tensor)
|
||||||
|
|
||||||
|
|
||||||
|
class CategoricalBasics(nn.Cell):
|
||||||
|
"""
|
||||||
|
Test class: basic mean/var/mode/entropy function.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super(CategoricalBasics, self).__init__()
|
||||||
|
self.c = msd.Categorical([0.2, 0.7, 0.1], dtype=dtype.int32)
|
||||||
|
self.c1 = msd.Categorical(dtype=dtype.int32)
|
||||||
|
|
||||||
|
def construct(self, probs):
|
||||||
|
basics1 = self.c.mean() + self.c.var() + self.c.mode() + self.c.entropy()
|
||||||
|
basics2 = self.c1.mean(probs) + self.c1.var(probs) +\
|
||||||
|
self.c1.mode(probs) + self.c1.entropy(probs)
|
||||||
|
return basics1 + basics2
|
||||||
|
|
||||||
|
|
||||||
|
def test_basics():
|
||||||
|
"""
|
||||||
|
Test basics functionality of Categorical distribution.
|
||||||
|
"""
|
||||||
|
net = CategoricalBasics()
|
||||||
|
probs = Tensor([0.7, 0.2, 0.1], dtype=dtype.float32)
|
||||||
|
ans = net(probs)
|
||||||
|
assert isinstance(ans, Tensor)
|
Loading…
Reference in new issue