From 6355f46d32d67defd1c2779429c7ff96f6be2afc Mon Sep 17 00:00:00 2001
From: peixu_ren
Date: Thu, 26 Nov 2020 11:20:39 -0500
Subject: [PATCH] Add LBeta op at nn level

---
 mindspore/nn/layer/math.py         | 107 ++++++++++++++++++++++++++++-
 tests/ut/python/ops/test_nn_ops.py |   5 ++
 2 files changed, 111 insertions(+), 1 deletion(-)

diff --git a/mindspore/nn/layer/math.py b/mindspore/nn/layer/math.py
index 391e32b35f..ec759a5c15 100644
--- a/mindspore/nn/layer/math.py
+++ b/mindspore/nn/layer/math.py
@@ -25,7 +25,7 @@
 from ...common import dtype as mstype
 from ..._checkparam import Validator as validator
 
-__all__ = ['ReduceLogSumExp', 'Range', 'LinSpace', 'LGamma', 'DiGamma', 'IGamma', 'MatMul', 'Moments']
+__all__ = ['ReduceLogSumExp', 'Range', 'LinSpace', 'LGamma', 'DiGamma', 'IGamma', 'LBeta', 'MatMul', 'Moments']
 
 
 @constexpr
@@ -227,6 +227,9 @@ class LGamma(Cell):
         when x is an integer less or equal to 0, return +inf
         when x = +/- inf, return +inf
 
+    Supported Platforms:
+        ``Ascend`` ``GPU``
+
     Inputs:
         - **input_x** (Tensor) - The input tensor. Only float16, float32 are supported.
 
@@ -346,6 +349,9 @@ class DiGamma(Cell):
 
         digamma(x) = digamma(1 - x) - pi * cot(pi * x)
 
+    Supported Platforms:
+        ``Ascend`` ``GPU``
+
     Inputs:
         - **input_x** (Tensor[Number]) - The input tensor. Only float16, float32 are supported.
 
@@ -609,6 +615,9 @@ class IGamma(Cell):
 
     Above :math:`Q(a, x)` is the upper regularized complete Gamma function.
 
+    Supported Platforms:
+        ``Ascend``
+
     Inputs:
         - **a** (Tensor) - The input tensor. With float16 or float32 data type. `a` should have
           the same dtype with `x`.
@@ -679,6 +688,102 @@ class IGamma(Cell):
         return output
 
 
+class LBeta(Cell):
+    r"""
+    Computes the log of the Beta function, which is semantically equal to lgamma(x) + lgamma(y) - lgamma(x + y).
+
+    The method is more accurate for arguments above 8. The reason for accuracy loss in the naive computation
+    is catastrophic cancellation between the lgammas. This method avoids the numeric cancellation by explicitly
+    decomposing lgamma into the Stirling approximation and an explicit log_gamma_correction, and cancelling
+    the large terms from the Stirling approximation analytically.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU``
+
+    Inputs:
+        - **x** (Tensor) - The input tensor. With float16 or float32 data type. `x` should have
+          the same dtype with `y`.
+        - **y** (Tensor) - The input tensor. With float16 or float32 data type. `y` should have
+          the same dtype with `x`.
+
+    Outputs:
+        Tensor, has the same dtype as `x` and `y`.
+
+    Examples:
+        >>> input_x = Tensor(np.array([2.0, 4.0, 6.0, 8.0]).astype(np.float32))
+        >>> input_y = Tensor(np.array([2.0, 3.0, 14.0, 15.0]).astype(np.float32))
+        >>> lbeta = nn.LBeta()
+        >>> output = lbeta(input_x, input_y)
+        >>> print(output)
+        [-1.7917596 -4.094345 -12.000229 -14.754799]
+    """
+
+    def __init__(self):
+        super(LBeta, self).__init__()
+        # const numbers
+        self.log_2pi = np.log(2 * np.pi)
+        self.minimax_coeff = [-0.165322962780713e-02,
+                              0.837308034031215e-03,
+                              -0.595202931351870e-03,
+                              0.793650666825390e-03,
+                              -0.277777777760991e-02,
+                              0.833333333333333e-01]
+
+        # operations
+        self.log = P.Log()
+        self.log1p = P.Log1p()
+        self.less = P.Less()
+        self.select = P.Select()
+        self.shape = P.Shape()
+        self.dtype = P.DType()
+        self.lgamma = LGamma()
+
+    def construct(self, x, y):
+        x_dtype = self.dtype(x)
+        y_dtype = self.dtype(y)
+        _check_input_dtype("input_x", x_dtype, [mstype.float16, mstype.float32], self.cls_name)
+        _check_input_dtype("input_y", y_dtype, x_dtype, self.cls_name)
+        x_plus_y = x + y
+        broadcastto = P.BroadcastTo(self.shape(x_plus_y))
+        x = broadcastto(x)
+        y = broadcastto(y)
+        comp_less = self.less(x, y)
+        x_min = self.select(comp_less, x, y)
+        y_max = self.select(comp_less, y, x)
+
+        @constexpr
+        def _log_gamma_correction(x, minimax_coeff):
+            inverse_x = 1. / x
+            inverse_x_squared = inverse_x * inverse_x
+            accum = minimax_coeff[0]
+            for i in range(1, 6):
+                accum = accum * inverse_x_squared + minimax_coeff[i]
+            return accum * inverse_x
+
+        log_gamma_correction_x = _log_gamma_correction(x_min, self.minimax_coeff)
+        log_gamma_correction_y = _log_gamma_correction(y_max, self.minimax_coeff)
+        log_gamma_correction_x_y = _log_gamma_correction(x_plus_y, self.minimax_coeff)
+
+        # Two large arguments case: y >= x >= 8.
+        log_beta_two_large = 0.5 * self.log_2pi - 0.5 * self.log(y_max) \
+                             + log_gamma_correction_x + log_gamma_correction_y - log_gamma_correction_x_y \
+                             + (x_min - 0.5) * self.log(x_min / (x_min + y_max)) - y_max * self.log1p(x_min / y_max)
+
+        cancelled_stirling = -1 * (x_min + y_max - 0.5) * self.log1p(x_min / y_max) - x_min * self.log(y_max) + x_min
+        correction = log_gamma_correction_y - log_gamma_correction_x_y
+        log_gamma_difference_big_y = correction + cancelled_stirling
+
+        # One large argument case: x < 8, y >= 8.
+        log_beta_one_large = self.lgamma(x_min) + log_gamma_difference_big_y
+
+        # Small arguments case: x <= y < 8.
+        log_beta_small = self.lgamma(x_min) + self.lgamma(y_max) - self.lgamma(x_min + y_max)
+        comp_xless8 = self.less(x_min, 8)
+        comp_yless8 = self.less(y_max, 8)
+        temp = self.select(comp_yless8, log_beta_small, log_beta_one_large)
+        return self.select(comp_xless8, temp, log_beta_two_large)
+
+
 @constexpr
 def get_broadcast_matmul_shape(x_shape, y_shape):
     """get broadcast_matmul shape"""
 
diff --git a/tests/ut/python/ops/test_nn_ops.py b/tests/ut/python/ops/test_nn_ops.py
index 710395fd44..385370ab7a 100644
--- a/tests/ut/python/ops/test_nn_ops.py
+++ b/tests/ut/python/ops/test_nn_ops.py
@@ -602,6 +602,11 @@
         'block': nn.DiGamma(),
         'desc_inputs': [Tensor(np.array([3, 4, 5, 6]).astype(np.float32))],
         'skip': ['backward']}),
+    ('LBeta', {
+        'block': nn.LBeta(),
+        'desc_inputs': [Tensor(np.array([3, 4, 5, 6]).astype(np.float32)),
+                        Tensor(np.array([3, 4, 5, 6]).astype(np.float32))],
+        'skip': ['backward']}),
     ('FlattenNet', {
        'block': FlattenNet(),
        'desc_inputs': [Tensor(np.ones([1, 2, 3, 4], np.float32))],
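
A quick reviewer-side cross-check of the docstring example: the quantity LBeta computes is, by definition, lgamma(x) + lgamma(y) - lgamma(x + y), which SciPy exposes as scipy.special.betaln. The snippet below is only a local sanity-check sketch, assuming NumPy and SciPy are installed; it is not part of the patch itself.

    # Reference values for the LBeta docstring example, computed with SciPy's
    # betaln, i.e. gammaln(x) + gammaln(y) - gammaln(x + y) evaluated in float64.
    import numpy as np
    from scipy.special import betaln

    x = np.array([2.0, 4.0, 6.0, 8.0], dtype=np.float32)
    y = np.array([2.0, 3.0, 14.0, 15.0], dtype=np.float32)

    print(betaln(x, y))
    # Should agree with the docstring output up to float32 precision:
    # [-1.7917596 -4.094345 -12.000229 -14.754799]

The agreement is only approximate because the docstring example feeds float32 tensors, while betaln evaluates in double precision.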