!9092 Add LBeta op at nn level

From: @peixu_ren
Reviewed-by: 
Signed-off-by:
pull/9092/MERGE
Committed-by: mindspore-ci-bot (via Gitee)
commit dcde1fc70a

@@ -25,7 +25,7 @@ from ...common import dtype as mstype
from ..._checkparam import Validator as validator

-__all__ = ['ReduceLogSumExp', 'Range', 'LinSpace', 'LGamma', 'DiGamma', 'IGamma', 'MatMul', 'Moments']
+__all__ = ['ReduceLogSumExp', 'Range', 'LinSpace', 'LGamma', 'DiGamma', 'IGamma', 'LBeta', 'MatMul', 'Moments']
@constexpr
@@ -227,6 +227,9 @@ class LGamma(Cell):
    when x is an integer less than or equal to 0, return +inf
    when x = +/- inf, return +inf

    Supported Platforms:
        ``Ascend`` ``GPU``

    Inputs:
        - **input_x** (Tensor) - The input tensor. Only float16, float32 are supported.
@@ -346,6 +349,9 @@ class DiGamma(Cell):
    digamma(x) = digamma(1 - x) - pi * cot(pi * x)

    Supported Platforms:
        ``Ascend`` ``GPU``

    Inputs:
        - **input_x** (Tensor[Number]) - The input tensor. Only float16, float32 are supported.
@@ -609,6 +615,9 @@ class IGamma(Cell):
    Above :math:`Q(a, x)` is the upper regularized complete Gamma function.

    Supported Platforms:
        ``Ascend``

    Inputs:
        - **a** (Tensor) - The input tensor. With float16 or float32 data type. `a` should have
          the same dtype as `x`.
@@ -679,6 +688,102 @@ class IGamma(Cell):
        return output


class LBeta(Cell):
    r"""
    This is semantically equal to lgamma(x) + lgamma(y) - lgamma(x + y).

    The method is more accurate for arguments above 8. The accuracy loss in the naive computation
    comes from catastrophic cancellation between the lgamma terms. This method avoids that cancellation
    by decomposing each lgamma into the Stirling approximation plus an explicit log_gamma_correction
    term, and cancelling the large Stirling terms analytically.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Inputs:
        - **x** (Tensor) - The input tensor. With float16 or float32 data type. `x` should have
          the same dtype as `y`.
        - **y** (Tensor) - The input tensor. With float16 or float32 data type. `y` should have
          the same dtype as `x`.

    Outputs:
        Tensor, has the same dtype as `x` and `y`.

    Examples:
        >>> input_x = Tensor(np.array([2.0, 4.0, 6.0, 8.0]).astype(np.float32))
        >>> input_y = Tensor(np.array([2.0, 3.0, 14.0, 15.0]).astype(np.float32))
        >>> lbeta = nn.LBeta()
        >>> output = lbeta(input_x, input_y)
        >>> print(output)
        [-1.7917596 -4.094345 -12.000229 -14.754799]
    """
    def __init__(self):
        super(LBeta, self).__init__()
        # const numbers
        self.log_2pi = np.log(2 * np.pi)
        self.minimax_coeff = [-0.165322962780713e-02,
                              0.837308034031215e-03,
                              -0.595202931351870e-03,
                              0.793650666825390e-03,
                              -0.277777777760991e-02,
                              0.833333333333333e-01]
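        # The six constants above are a minimax polynomial fit of the residual
        # log(Gamma(x)) minus its Stirling approximation (apparently over the
        # x >= 8 regime used below). They closely track the classical Stirling
        # series coefficients B_{2n} / (2n * (2n - 1)), i.e. 1/12, -1/360,
        # 1/1260, ..., listed from highest to lowest order for Horner evaluation.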
        # operations
        self.log = P.Log()
        self.log1p = P.Log1p()
        self.less = P.Less()
        self.select = P.Select()
        self.shape = P.Shape()
        self.dtype = P.DType()
        self.lgamma = LGamma()

    def construct(self, x, y):
        x_dtype = self.dtype(x)
        y_dtype = self.dtype(y)
        _check_input_dtype("input_x", x_dtype, [mstype.float16, mstype.float32], self.cls_name)
        _check_input_dtype("input_y", y_dtype, x_dtype, self.cls_name)
        x_plus_y = x + y
        broadcastto = P.BroadcastTo(self.shape(x_plus_y))
        x = broadcastto(x)
        y = broadcastto(y)
        comp_less = self.less(x, y)
        x_min = self.select(comp_less, x, y)
        y_max = self.select(comp_less, y, x)
        def _log_gamma_correction(x, minimax_coeff):
            # Horner evaluation of the minimax polynomial in 1 / x**2: the
            # correction term of the Stirling approximation to log(Gamma(x)).
            # It runs on Tensors at execution time, so it must not be @constexpr.
            inverse_x = 1. / x
            inverse_x_squared = inverse_x * inverse_x
            accum = minimax_coeff[0]
            for i in range(1, 6):
                accum = accum * inverse_x_squared + minimax_coeff[i]
            return accum * inverse_x
        log_gamma_correction_x = _log_gamma_correction(x_min, self.minimax_coeff)
        log_gamma_correction_y = _log_gamma_correction(y_max, self.minimax_coeff)
        log_gamma_correction_x_y = _log_gamma_correction(x_plus_y, self.minimax_coeff)

        # Two large arguments case: y >= x >= 8.
        log_beta_two_large = 0.5 * self.log_2pi - 0.5 * self.log(y_max) \
            + log_gamma_correction_x + log_gamma_correction_y - log_gamma_correction_x_y \
            + (x_min - 0.5) * self.log(x_min / (x_min + y_max)) - y_max * self.log1p(x_min / y_max)
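        # The expression above is lgamma(x) + lgamma(y) - lgamma(x + y) with each
        # lgamma replaced by its Stirling form 0.5 * log(2 * pi) + (z - 0.5) * log(z)
        # - z + log_gamma_correction(z). The large (z - 0.5) * log(z) - z terms
        # cancel analytically, so only well-conditioned quantities are evaluated.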
        cancelled_stirling = -1 * (x_min + y_max - 0.5) * self.log1p(x_min / y_max) \
            - x_min * self.log(y_max) + x_min
        correction = log_gamma_correction_y - log_gamma_correction_x_y
        log_gamma_difference_big_y = correction + cancelled_stirling

        # One large argument case: x < 8, y >= 8.
        log_beta_one_large = self.lgamma(x_min) + log_gamma_difference_big_y

        # Small arguments case: x <= y < 8.
        log_beta_small = self.lgamma(x_min) + self.lgamma(y_max) - self.lgamma(x_min + y_max)

        comp_xless8 = self.less(x_min, 8)
        comp_yless8 = self.less(y_max, 8)
        temp = self.select(comp_yless8, log_beta_small, log_beta_one_large)
        return self.select(comp_xless8, temp, log_beta_two_large)
@constexpr
def get_broadcast_matmul_shape(x_shape, y_shape):
"""get broadcast_matmul shape"""

@@ -602,6 +602,11 @@ test_cases = [
        'block': nn.DiGamma(),
        'desc_inputs': [Tensor(np.array([3, 4, 5, 6]).astype(np.float32))],
        'skip': ['backward']}),
    ('LBeta', {
        'block': nn.LBeta(),
        'desc_inputs': [Tensor(np.array([3, 4, 5, 6]).astype(np.float32)),
                        Tensor(np.array([3, 4, 5, 6]).astype(np.float32))],
        'skip': ['backward']}),
    ('FlattenNet', {
        'block': FlattenNet(),
        'desc_inputs': [Tensor(np.ones([1, 2, 3, 4], np.float32))],

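Beyond the harness entry above, the forward pass can be cross-checked against SciPy's betaln on the docstring's example values. This sketch is illustrative and not part of the patch; it assumes a MindSpore build that already includes this change, plus SciPy:

```python
# Cross-checks nn.LBeta against SciPy's betaln on the docstring example.
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor
from scipy.special import betaln

x = np.array([2.0, 4.0, 6.0, 8.0], np.float32)
y = np.array([2.0, 3.0, 14.0, 15.0], np.float32)

output = nn.LBeta()(Tensor(x), Tensor(y))
print(output)        # docstring expects [-1.7917596 -4.094345 -12.000229 -14.754799]
print(betaln(x, y))  # float64 reference; should agree to float32 precision
```

The new test_cases entry drives the same forward pass with both inputs set to [3, 4, 5, 6] and, per 'skip': ['backward'], generates no gradient check.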