refine Categorical and MultivariateNormalDiag en doc (#20723)

* refine Categorical and MultivariateNormalDiag en doc test=develop, test=document_fix

* refine Categorical and MultivariateNormalDiag en doc test=develop, test=document_fix
yaoxuefeng
Aurelius84 6 years ago committed by GitHub
parent dfa0549f87
commit 28dd2a58df
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -404,8 +404,18 @@ class Categorical(Distribution):
one of K possible categories, with the probability of each category one of K possible categories, with the probability of each category
separately specified. separately specified.
The probability mass function (pmf) is:
.. math::
pmf(k; p_i) = \prod_{i=1}^{k} p_i^{[x=i]}
In the above equation:
* :math:`[x=i]` : it evaluates to 1 if :math:`x==i` , 0 otherwise.
Args: Args:
logits(list|numpy.ndarray|Variable): The logits input of categorical distribution. logits(list|numpy.ndarray|Variable): The logits input of categorical distribution. The data type is float32.
Examples: Examples:
.. code-block:: python .. code-block:: python
@ -439,7 +449,7 @@ class Categorical(Distribution):
def __init__(self, logits): def __init__(self, logits):
""" """
Args: Args:
logits: A float32 tensor logits(list|numpy.ndarray|Variable): The logits input of categorical distribution. The data type is float32.
""" """
if self._validate_args(logits): if self._validate_args(logits):
self.logits = logits self.logits = logits
@ -450,7 +460,7 @@ class Categorical(Distribution):
"""The KL-divergence between two Categorical distributions. """The KL-divergence between two Categorical distributions.
Args: Args:
other (Categorical): instance of Categorical. other (Categorical): instance of Categorical. The data type is float32.
Returns: Returns:
Variable: kl-divergence between two Categorical distributions. Variable: kl-divergence between two Categorical distributions.
@ -477,7 +487,7 @@ class Categorical(Distribution):
"""Shannon entropy in nats. """Shannon entropy in nats.
Returns: Returns:
Variable: Shannon entropy of Categorical distribution. Variable: Shannon entropy of Categorical distribution. The data type is float32.
""" """
logits = self.logits - nn.reduce_max(self.logits, dim=-1, keep_dim=True) logits = self.logits - nn.reduce_max(self.logits, dim=-1, keep_dim=True)
@ -495,10 +505,31 @@ class MultivariateNormalDiag(Distribution):
A multivariate normal (also called Gaussian) distribution parameterized by a mean vector A multivariate normal (also called Gaussian) distribution parameterized by a mean vector
and a covariance matrix. and a covariance matrix.
The probability density function (pdf) is:
.. math::
pdf(x; loc, scale) = \\frac{e^{-\\frac{||y||^2}{2}}}{Z}
where:
.. math::
y = inv(scale) @ (x - loc)
Z = (2\\pi)^{0.5k} |det(scale)|
In the above equation:
* :math:`inv` : denotes to take the inverse of the matrix.
* :math:`@` : denotes matrix multiplication.
* :math:`det` : denotes to evaluate the determinant.
Args: Args:
loc(list|numpy.ndarray|Variable): The mean of multivariateNormal distribution. loc(list|numpy.ndarray|Variable): The mean of multivariateNormal distribution with shape :math:`[k]` .
scale(list|numpy.ndarray|Variable): The positive definite diagonal covariance matrix of The data type is float32.
multivariateNormal distribution. scale(list|numpy.ndarray|Variable): The positive definite diagonal covariance matrix of multivariateNormal
distribution with shape :math:`[k, k]` . All elements are 0 except diagonal elements. The data type is
float32.
Examples: Examples:
.. code-block:: python .. code-block:: python
@ -570,7 +601,7 @@ class MultivariateNormalDiag(Distribution):
"""Shannon entropy in nats. """Shannon entropy in nats.
Returns: Returns:
Variable: Shannon entropy of Multivariate Normal distribution. Variable: Shannon entropy of Multivariate Normal distribution. The data type is float32.
""" """
entropy = 0.5 * ( entropy = 0.5 * (
@ -586,7 +617,7 @@ class MultivariateNormalDiag(Distribution):
other (MultivariateNormalDiag): instance of Multivariate Normal. other (MultivariateNormalDiag): instance of Multivariate Normal.
Returns: Returns:
Variable: kl-divergence between two Multivariate Normal distributions. Variable: kl-divergence between two Multivariate Normal distributions. The data type is float32.
""" """
assert isinstance(other, MultivariateNormalDiag) assert isinstance(other, MultivariateNormalDiag)

Loading…
Cancel
Save