Add LBeta op at nn level

pull/9092/head
peixu_ren 4 years ago
parent eb696440d0
commit 6355f46d32

@@ -25,7 +25,7 @@ from ...common import dtype as mstype
from ..._checkparam import Validator as validator
-__all__ = ['ReduceLogSumExp', 'Range', 'LinSpace', 'LGamma', 'DiGamma', 'IGamma', 'MatMul', 'Moments']
+__all__ = ['ReduceLogSumExp', 'Range', 'LinSpace', 'LGamma', 'DiGamma', 'IGamma', 'LBeta', 'MatMul', 'Moments']
@constexpr
@@ -227,6 +227,9 @@ class LGamma(Cell):
when x is an integer less or equal to 0, return +inf
when x = +/- inf, return +inf
Supported Platforms:
``Ascend`` ``GPU``
Inputs:
- **input_x** (Tensor) - The input tensor. Only float16, float32 are supported.
@@ -346,6 +349,9 @@ class DiGamma(Cell):
digamma(x) = digamma(1 - x) - pi * cot(pi * x)
Supported Platforms:
``Ascend`` ``GPU``
Inputs:
- **input_x** (Tensor[Number]) - The input tensor. Only float16, float32 are supported.
@@ -609,6 +615,9 @@ class IGamma(Cell):
Above :math:`Q(a, x)` is the upper regularized complete Gamma function.
Supported Platforms:
``Ascend``
Inputs:
- **a** (Tensor) - The input tensor. With float16 or float32 data type. `a` should have
the same dtype with `x`.
@@ -679,6 +688,102 @@ class IGamma(Cell):
return output
class LBeta(Cell):
    r"""
    Computes the log of the Beta function; this is semantically equal to lgamma(x) + lgamma(y) - lgamma(x + y).

    The method is more accurate for arguments above 8. The reason for accuracy loss in the naive computation
    is catastrophic cancellation between the lgammas. This method avoids the numeric cancellation by
    decomposing lgamma into the Stirling approximation plus an explicit log_gamma_correction term, and
    cancelling the large Stirling terms analytically.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Inputs:
        - **x** (Tensor) - The input tensor. With float16 or float32 data type. `x` should have
          the same dtype as `y`.
        - **y** (Tensor) - The input tensor. With float16 or float32 data type. `y` should have
          the same dtype as `x`.

    Outputs:
        Tensor, has the same dtype as `x` and `y`.

    Examples:
        >>> input_x = Tensor(np.array([2.0, 4.0, 6.0, 8.0]).astype(np.float32))
        >>> input_y = Tensor(np.array([2.0, 3.0, 14.0, 15.0]).astype(np.float32))
        >>> lbeta = nn.LBeta()
        >>> output = lbeta(input_x, input_y)
        >>> print(output)
        [-1.7917596 -4.094345 -12.000229 -14.754799]
    """

    def __init__(self):
        super(LBeta, self).__init__()
        # constants: log(2*pi) and the minimax coefficients of the Stirling-series
        # correction to lgamma, ordered from the highest to the lowest power of 1/x**2
        self.log_2pi = np.log(2 * np.pi)
        self.minimax_coeff = [-0.165322962780713e-02,
                              0.837308034031215e-03,
                              -0.595202931351870e-03,
                              0.793650666825390e-03,
                              -0.277777777760991e-02,
                              0.833333333333333e-01]
        # operations
        self.log = P.Log()
        self.log1p = P.Log1p()
        self.less = P.Less()
        self.select = P.Select()
        self.shape = P.Shape()
        self.dtype = P.DType()
        self.lgamma = LGamma()

    def construct(self, x, y):
        x_dtype = self.dtype(x)
        y_dtype = self.dtype(y)
        _check_input_dtype("input_x", x_dtype, [mstype.float16, mstype.float32], self.cls_name)
        _check_input_dtype("input_y", y_dtype, x_dtype, self.cls_name)
        x_plus_y = x + y
        # broadcast both inputs to the shape of x + y, then order them element-wise
        broadcastto = P.BroadcastTo(self.shape(x_plus_y))
        x = broadcastto(x)
        y = broadcastto(y)
        comp_less = self.less(x, y)
        x_min = self.select(comp_less, x, y)
        y_max = self.select(comp_less, y, x)

        # minimax polynomial for lgamma(x) minus its Stirling approximation (for large x)
        @constexpr
        def _log_gamma_correction(x, minimax_coeff):
            inverse_x = 1. / x
            inverse_x_squared = inverse_x * inverse_x
            accum = minimax_coeff[0]
            for i in range(1, 6):
                accum = accum * inverse_x_squared + minimax_coeff[i]
            return accum * inverse_x

        log_gamma_correction_x = _log_gamma_correction(x_min, self.minimax_coeff)
        log_gamma_correction_y = _log_gamma_correction(y_max, self.minimax_coeff)
        log_gamma_correction_x_y = _log_gamma_correction(x_plus_y, self.minimax_coeff)

        # Two large arguments case: y >= x >= 8.
        log_beta_two_large = 0.5 * self.log_2pi - 0.5 * self.log(y_max) \
            + log_gamma_correction_x + log_gamma_correction_y - log_gamma_correction_x_y \
            + (x_min - 0.5) * self.log(x_min / (x_min + y_max)) - y_max * self.log1p(x_min / y_max)

        cancelled_stirling = -1 * (x_min + y_max - 0.5) * self.log1p(x_min / y_max) \
            - x_min * self.log(y_max) + x_min
        correction = log_gamma_correction_y - log_gamma_correction_x_y
        log_gamma_difference_big_y = correction + cancelled_stirling

        # One large argument case: x < 8, y >= 8.
        log_beta_one_large = self.lgamma(x_min) + log_gamma_difference_big_y

        # Small arguments case: x <= y < 8.
        log_beta_small = self.lgamma(x_min) + self.lgamma(y_max) - self.lgamma(x_min + y_max)

        # Pick the formula matching each element's regime.
        comp_xless8 = self.less(x_min, 8)
        comp_yless8 = self.less(y_max, 8)
        temp = self.select(comp_yless8, log_beta_small, log_beta_one_large)
        return self.select(comp_xless8, temp, log_beta_two_large)
@constexpr
def get_broadcast_matmul_shape(x_shape, y_shape):
    """get broadcast_matmul shape"""

@@ -602,6 +602,11 @@ test_cases = [
        'block': nn.DiGamma(),
        'desc_inputs': [Tensor(np.array([3, 4, 5, 6]).astype(np.float32))],
        'skip': ['backward']}),
    ('LBeta', {
        'block': nn.LBeta(),
        'desc_inputs': [Tensor(np.array([3, 4, 5, 6]).astype(np.float32)),
                        Tensor(np.array([3, 4, 5, 6]).astype(np.float32))],
        'skip': ['backward']}),
    ('FlattenNet', {
        'block': FlattenNet(),
        'desc_inputs': [Tensor(np.ones([1, 2, 3, 4], np.float32))],
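
For reference (not part of the commit), the expected output quoted in the LBeta docstring example, and hence the behaviour exercised by the new test entry, can be cross-checked against SciPy's betaln:

import numpy as np
from scipy.special import betaln

x = np.array([2.0, 4.0, 6.0, 8.0])
y = np.array([2.0, 3.0, 14.0, 15.0])
print(betaln(x, y))  # should match the docstring output above to float32 precision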
