@@ -404,8 +404,18 @@ class Categorical(Distribution):
one of K possible categories , with the probability of each category
one of K possible categories , with the probability of each category
separately specified .
separately specified .
The probability mass function (pmf) is:
.. math::
    pmf(x; p) = \prod_{i=1}^{K} p_i^{[x = i]}
In the above equation:
* :math:`[x = i]`: evaluates to 1 if :math:`x = i`, 0 otherwise.
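A minimal sketch of the pmf above, not the library's implementation. It assumes the per-category probabilities are already normalized and passed as a NumPy array, and that categories are indexed from 0 rather than 1. The name `categorical_pmf` is hypothetical.

```python
import numpy as np


def categorical_pmf(x, probs):
    """Evaluate prod_i probs[i] ** [x == i] for a 0-based category index x."""
    probs = np.asarray(probs, dtype=np.float64)
    # Indicator vector: 1.0 where the index equals x, 0.0 elsewhere.
    indicator = (np.arange(probs.shape[0]) == x).astype(np.float64)
    # Every factor with indicator 0 is 1, so the product is just probs[x].
    return float(np.prod(probs ** indicator))
```

Because every factor whose indicator is 0 equals 1, the product reduces to the probability of the selected category.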
Args :
Args :
logits ( list | numpy . ndarray | Variable ) : The logits input of categorical distribution .
logits (list|numpy.ndarray|Variable): The logits input of the categorical distribution. The data type is float32.
Examples :
Examples :
. . code - block : : python
. . code - block : : python
@@ -439,7 +449,7 @@ class Categorical(Distribution):
def __init__ ( self , logits ) :
def __init__ ( self , logits ) :
"""
"""
Args :
Args :
logits : A float32 tensor
logits (list|numpy.ndarray|Variable): The logits input of the categorical distribution. The data type is float32.
"""
"""
if self . _validate_args ( logits ) :
if self . _validate_args ( logits ) :
self . logits = logits
self . logits = logits
@@ -450,7 +460,7 @@ class Categorical(Distribution):
""" The KL-divergence between two Categorical distributions.
""" The KL-divergence between two Categorical distributions.
Args :
Args :
other ( Categorical ) : instance of Categorical .
other ( Categorical ) : instance of Categorical . The data type is float32 .
Returns :
Returns :
Variable : kl - divergence between two Categorical distributions .
Variable : kl - divergence between two Categorical distributions .
@@ -477,7 +487,7 @@ class Categorical(Distribution):
""" Shannon entropy in nats.
""" Shannon entropy in nats.
Returns :
Returns :
Variable : Shannon entropy of Categorical distribution .
Variable : Shannon entropy of Categorical distribution . The data type is float32 .
"""
"""
logits = self . logits - nn . reduce_max ( self . logits , dim = - 1 , keep_dim = True )
logits = self . logits - nn . reduce_max ( self . logits , dim = - 1 , keep_dim = True )
@@ -495,10 +505,31 @@ class MultivariateNormalDiag(Distribution):
A multivariate normal ( also called Gaussian ) distribution parameterized by a mean vector
A multivariate normal ( also called Gaussian ) distribution parameterized by a mean vector
and a covariance matrix .
and a covariance matrix .
The probability density function (pdf) is:
.. math::
    pdf(x; loc, scale) = \\frac{e^{-\\frac{||y||^2}{2}}}{Z}
where:
.. math::
    y = inv(scale) @ (x - loc)
    Z = (2\\pi)^{0.5k} |det(scale)|
In the above equation:
* :math:`inv`: denotes the matrix inverse.
* :math:`@`: denotes matrix multiplication.
* :math:`det`: denotes the determinant.
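A minimal NumPy sketch of the density, not the library's implementation. It follows the formula exactly as written above, so `scale` is used directly in `inv(scale)` and `det(scale)`. The name `mvn_diag_pdf` is hypothetical.

```python
import numpy as np


def mvn_diag_pdf(x, loc, scale):
    """Evaluate exp(-||y||^2 / 2) / Z with y and Z as defined in the docstring."""
    x = np.asarray(x, dtype=np.float64)
    loc = np.asarray(loc, dtype=np.float64)
    scale = np.asarray(scale, dtype=np.float64)
    k = loc.shape[0]
    # y = inv(scale) @ (x - loc)
    y = np.linalg.inv(scale) @ (x - loc)
    # Z = (2 * pi) ** (0.5 * k) * |det(scale)|
    z = (2.0 * np.pi) ** (0.5 * k) * abs(np.linalg.det(scale))
    return float(np.exp(-0.5 * np.dot(y, y)) / z)
```

With `loc = [0, 0]` and an identity `scale`, the value at `x = loc` is `1 / (2 * pi)`, which makes a quick sanity check.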
Args :
Args :
loc ( list | numpy . ndarray | Variable ) : The mean of multivariateNormal distribution .
scale ( list | numpy . ndarray | Variable ) : The positive definite diagonal covariance matrix of
multivariateNormal distribution .
loc (list|numpy.ndarray|Variable): The mean of the multivariate normal distribution, with shape :math:`[k]`.
The data type is float32.
scale (list|numpy.ndarray|Variable): The positive definite diagonal covariance matrix of the multivariate normal
distribution, with shape :math:`[k, k]`. All off-diagonal elements are 0. The data type is
float32.
Examples :
Examples :
. . code - block : : python
. . code - block : : python
@@ -570,7 +601,7 @@ class MultivariateNormalDiag(Distribution):
""" Shannon entropy in nats.
""" Shannon entropy in nats.
Returns :
Returns :
Variable : Shannon entropy of Multivariate Normal distribution .
Variable : Shannon entropy of Multivariate Normal distribution . The data type is float32 .
"""
"""
entropy = 0.5 * (
entropy = 0.5 * (
@@ -586,7 +617,7 @@ class MultivariateNormalDiag(Distribution):
other ( MultivariateNormalDiag ) : instance of Multivariate Normal .
other ( MultivariateNormalDiag ) : instance of Multivariate Normal .
Returns :
Returns :
Variable : kl - divergence between two Multivariate Normal distributions .
Variable : kl - divergence between two Multivariate Normal distributions . The data type is float32 .
"""
"""
assert isinstance ( other , MultivariateNormalDiag )
assert isinstance ( other , MultivariateNormalDiag )