@@ -25,7 +25,7 @@ from ...common import dtype as mstype
from ..._checkparam import Validator as validator

-__all__ = ['ReduceLogSumExp', 'Range', 'LinSpace', 'LGamma', 'DiGamma', 'IGamma', 'MatMul', 'Moments']
+__all__ = ['ReduceLogSumExp', 'Range', 'LinSpace', 'LGamma', 'DiGamma', 'IGamma', 'LBeta', 'MatMul', 'Moments']

@constexpr

@@ -227,6 +227,9 @@ class LGamma(Cell):
    when x is an integer less than or equal to 0, return +inf
    when x = +/- inf, return +inf

    Supported Platforms:
        ``Ascend`` ``GPU``

    Inputs:
        - **input_x** (Tensor) - The input tensor. Only float16 and float32 are supported.
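The special values quoted in this docstring can be cross-checked against SciPy's `gammaln`; the sketch below is illustrative only and is not part of the patch:

```python
# Illustrative check of the special values quoted in the LGamma docstring:
# lgamma has poles (+inf) at the non-positive integers and at +inf.
import numpy as np
from scipy.special import gammaln

print(gammaln(np.array([0.0, -1.0, -2.0])))  # [inf inf inf]
print(gammaln(np.inf))                       # inf
```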

@@ -346,6 +349,9 @@ class DiGamma(Cell):
    digamma(x) = digamma(1 - x) - pi * cot(pi * x)

    Supported Platforms:
        ``Ascend`` ``GPU``

    Inputs:
        - **input_x** (Tensor[Number]) - The input tensor. Only float16 and float32 are supported.
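The reflection formula above, which extends digamma to negative arguments, is easy to sanity-check with SciPy; this sketch is illustrative and not part of the patch:

```python
# Illustrative check of digamma(x) = digamma(1 - x) - pi * cot(pi * x)
# at a negative, non-integer argument.
import numpy as np
from scipy.special import psi  # SciPy's digamma

x = -2.3
lhs = psi(x)
rhs = psi(1.0 - x) - np.pi / np.tan(np.pi * x)  # cot(z) = 1 / tan(z)
print(np.isclose(lhs, rhs))  # True
```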

@@ -609,6 +615,9 @@ class IGamma(Cell):
    Above :math:`Q(a, x)` is the upper regularized incomplete Gamma function.

    Supported Platforms:
        ``Ascend``

    Inputs:
        - **a** (Tensor) - The input tensor. With float16 or float32 data type. `a` should have
          the same dtype as `x`.
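For orientation, the class computes the lower regularized incomplete Gamma function :math:`P(a, x) = 1 - Q(a, x)`. The complementary pair can be checked with SciPy (illustrative only, not part of the patch):

```python
# Illustrative check that the lower (P) and upper (Q) regularized incomplete
# Gamma functions are complementary: P(a, x) + Q(a, x) = 1.
from scipy.special import gammainc, gammaincc

a, x = 2.0, 4.22
print(gammainc(a, x) + gammaincc(a, x))  # 1.0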

@@ -679,6 +688,102 @@ class IGamma(Cell):
        return output


class LBeta(Cell):
    r"""
    This is semantically equal to lgamma(x) + lgamma(y) - lgamma(x + y).

    The method is more accurate for arguments above 8. The reason for accuracy loss in the naive computation
    is catastrophic cancellation between the lgammas. This method avoids the numeric cancellation by explicitly
    decomposing lgamma into the Stirling approximation and an explicit log_gamma_correction, and cancelling
    the large terms from the Stirling approximation analytically.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Inputs:
        - **x** (Tensor) - The input tensor. With float16 or float32 data type. `x` should have
          the same dtype as `y`.
        - **y** (Tensor) - The input tensor. With float16 or float32 data type. `y` should have
          the same dtype as `x`.

    Outputs:
        Tensor, has the same dtype as `x` and `y`.

    Examples:
        >>> input_x = Tensor(np.array([2.0, 4.0, 6.0, 8.0]).astype(np.float32))
        >>> input_y = Tensor(np.array([2.0, 3.0, 14.0, 15.0]).astype(np.float32))
        >>> lbeta = nn.LBeta()
        >>> output = lbeta(input_x, input_y)
        >>> print(output)
        [-1.7917596 -4.094345 -12.000229 -14.754799]
    """

    def __init__(self):
        super(LBeta, self).__init__()
        # const numbers
        self.log_2pi = np.log(2 * np.pi)
        # minimax coefficients of the Stirling-series correction, highest order
        # first; the constant term 0.8333...e-01 is 1/12, from the series
        # lgamma(x) - stirling(x) ~ 1/(12x) - 1/(360x^3) + 1/(1260x^5) - ...
        self.minimax_coeff = [-0.165322962780713e-02,
                              0.837308034031215e-03,
                              -0.595202931351870e-03,
                              0.793650666825390e-03,
                              -0.277777777760991e-02,
                              0.833333333333333e-01]

        # operations
        self.log = P.Log()
        self.log1p = P.Log1p()
        self.less = P.Less()
        self.select = P.Select()
        self.shape = P.Shape()
        self.dtype = P.DType()
        self.lgamma = LGamma()

    def construct(self, x, y):
        x_dtype = self.dtype(x)
        y_dtype = self.dtype(y)
        _check_input_dtype("x", x_dtype, [mstype.float16, mstype.float32], self.cls_name)
        _check_input_dtype("y", y_dtype, x_dtype, self.cls_name)
        x_plus_y = x + y
        # broadcast both inputs to the shape of x + y before ordering them
        broadcastto = P.BroadcastTo(self.shape(x_plus_y))
        x = broadcastto(x)
        y = broadcastto(y)
        comp_less = self.less(x, y)
        x_min = self.select(comp_less, x, y)
        y_max = self.select(comp_less, y, x)

        # Stirling-series remainder lgamma(t) - ((t - 0.5)*log(t) - t + 0.5*log(2*pi)),
        # i.e. ~ 1/(12t) - 1/(360t^3) + ..., evaluated by Horner's scheme in 1/t^2.
        # It runs on tensors at graph-execution time, so it is a plain nested
        # function rather than a compile-time @constexpr helper.
        def _log_gamma_correction(x, minimax_coeff):
            inverse_x = 1. / x
            inverse_x_squared = inverse_x * inverse_x
            accum = minimax_coeff[0]
            for i in range(1, 6):
                accum = accum * inverse_x_squared + minimax_coeff[i]
            return accum * inverse_x

        log_gamma_correction_x = _log_gamma_correction(x_min, self.minimax_coeff)
        log_gamma_correction_y = _log_gamma_correction(y_max, self.minimax_coeff)
        log_gamma_correction_x_y = _log_gamma_correction(x_plus_y, self.minimax_coeff)

        # Two large arguments case: y >= x >= 8.
        log_beta_two_large = 0.5 * self.log_2pi - 0.5 * self.log(y_max) \
                             + log_gamma_correction_x + log_gamma_correction_y - log_gamma_correction_x_y \
                             + (x_min - 0.5) * self.log(x_min / (x_min + y_max)) - y_max * self.log1p(x_min / y_max)

        # lgamma(y_max) - lgamma(x_min + y_max) for big y_max, with the large
        # Stirling terms cancelled analytically
        cancelled_stirling = -1 * (x_min + y_max - 0.5) * self.log1p(x_min / y_max) - x_min * self.log(y_max) + x_min
        correction = log_gamma_correction_y - log_gamma_correction_x_y
        log_gamma_difference_big_y = correction + cancelled_stirling

        # One large argument case: x < 8, y >= 8.
        log_beta_one_large = self.lgamma(x_min) + log_gamma_difference_big_y

        # Small arguments case: x <= y < 8, where the naive formula is accurate.
        log_beta_small = self.lgamma(x_min) + self.lgamma(y_max) - self.lgamma(x_min + y_max)

        # select the branch matching the magnitudes of x_min and y_max
        comp_xless8 = self.less(x_min, 8)
        comp_yless8 = self.less(y_max, 8)
        temp = self.select(comp_yless8, log_beta_small, log_beta_one_large)
        return self.select(comp_xless8, temp, log_beta_two_large)
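To see the cancellation the docstring describes, compare the naive float32 formula against a float64 reference for one small and one large argument; a minimal sketch using SciPy only for the reference value (illustrative, not part of the patch):

```python
# Illustrative sketch of the catastrophic cancellation LBeta avoids: for
# x = 5, y = 1e6, gammaln(y) and gammaln(x + y) are both ~1.28e7 and agree in
# their first six digits, so subtracting them in float32 (ulp ~ 1.0 at that
# magnitude) can lose a percent-level chunk of the ~-65.9 result.
import numpy as np
from scipy.special import betaln, gammaln

x, y = 5.0, 1.0e6
naive_fp32 = (np.float32(gammaln(x)) + np.float32(gammaln(y))
              - np.float32(gammaln(x + y)))
reference = betaln(x, y)  # float64, evaluated without the cancellation
print(naive_fp32, reference)
```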


@constexpr
def get_broadcast_matmul_shape(x_shape, y_shape):
    """get broadcast_matmul shape"""