From 4aa836dae1329ddbf7d10151e4dc325af11360fd Mon Sep 17 00:00:00 2001
From: peixu_ren
Date: Tue, 13 Oct 2020 11:41:18 -0400
Subject: [PATCH] Add Digamma op

---
 mindspore/nn/layer/math.py         | 97 +++++++++++++++++++++++++++++-
 tests/ut/python/ops/test_nn_ops.py |  4 ++
 2 files changed, 100 insertions(+), 1 deletion(-)

diff --git a/mindspore/nn/layer/math.py b/mindspore/nn/layer/math.py
index d50aefa10d..b066d44b0b 100644
--- a/mindspore/nn/layer/math.py
+++ b/mindspore/nn/layer/math.py
@@ -25,7 +25,7 @@ from ...common import dtype as mstype
 from ..._checkparam import Validator as validator
 
 
-__all__ = ['ReduceLogSumExp', 'Range', 'LinSpace', 'LGamma', 'IGamma', 'MatMul', 'Moments']
+__all__ = ['ReduceLogSumExp', 'Range', 'LinSpace', 'LGamma', 'DiGamma', 'IGamma', 'MatMul', 'Moments']
 
 
 @constexpr
@@ -312,6 +312,101 @@ class LGamma(Cell):
         return self.select(self.isfinite(input_x), result, infinity)
 
 
+class DiGamma(Cell):
+    r"""
+    Calculates Digamma using Lanczos' approximation, referring to "A Precision Approximation of the Gamma Function".
+    The algorithm is:
+
+    .. math::
+        digamma(z + 1) = log(t(z)) + A'(z) / A(z) - kLanczosGamma / t(z)
+
+        t(z) = z + kLanczosGamma + 1/2
+
+        A(z) = kBaseLanczosCoeff + \sum_{k=1}^n \frac{kLanczosCoefficients[k]}{z + k}
+
+        A'(z) = -\sum_{k=1}^n \frac{kLanczosCoefficients[k]}{(z + k)^2}
+
+    However, if the input is less than 0.5, use Euler's reflection formula:
+
+    .. math::
+
+        digamma(x) = digamma(1 - x) - pi * cot(pi * x)
+
+    Inputs:
+        - **input_x** (Tensor[Number]) - The input tensor. Only float16 and float32 are supported.
+
+    Outputs:
+        Tensor, has the same shape and dtype as `input_x`.
+
+    Examples:
+        >>> input_x = Tensor(np.array([2, 3, 4]).astype(np.float32))
+        >>> op = nn.DiGamma()
+        >>> output = op(input_x)
+        [0.42278463 0.92278427 1.2561178]
+    """
+
+    def __init__(self):
+        super(DiGamma, self).__init__()
+        # const numbers
+        self.k_lanczos_gamma = 7
+        self.k_base_lanczos_coeff = 0.99999999999980993227684700473478
+        self.k_lanczos_coefficients = [676.520368121885098567009190444019,
+                                       -1259.13921672240287047156078755283,
+                                       771.3234287776530788486528258894,
+                                       -176.61502916214059906584551354,
+                                       12.507343278686904814458936853,
+                                       -0.13857109526572011689554707,
+                                       9.984369578019570859563e-6,
+                                       1.50563273514931155834e-7]
+        self.nan = np.nan
+        self.pi = np.pi
+        self.lanczos_gamma_plus_one_half = self.k_lanczos_gamma + 0.5
+        self.log_lanczos_gamma_plus_one_half = np.log(self.lanczos_gamma_plus_one_half)
+
+        # operations
+        self.log1p = P.Log1p()
+        self.abs = P.Abs()
+        self.shape = P.Shape()
+        self.dtype = P.DType()
+        self.fill = P.Fill()
+        self.floor = P.Floor()
+        self.equal = P.Equal()
+        self.less = P.Less()
+        self.select = P.Select()
+        self.sin = P.Sin()
+        self.cos = P.Cos()
+        self.logicaland = P.LogicalAnd()
+
+    def construct(self, input_x):
+        input_dtype = self.dtype(input_x)
+        _check_input_dtype("input_x", input_dtype, [mstype.float16, mstype.float32], self.cls_name)
+        need_to_reflect = self.less(input_x, 0.5)
+        neg_input = -input_x
+        z = self.select(need_to_reflect, neg_input, input_x - 1)
+
+        # Accumulate the Lanczos series: denom is A(z) and num is A'(z).
+        num = 0
+        denom = self.k_base_lanczos_coeff
+        for i in range(8):
+            num = num - self.k_lanczos_coefficients[i] / ((z + i + 1) * (z + i + 1))
+            denom = denom + self.k_lanczos_coefficients[i] / (z + i + 1)
+
+        t = z + self.lanczos_gamma_plus_one_half
+        log_t = self.log1p(z / self.lanczos_gamma_plus_one_half) + self.log_lanczos_gamma_plus_one_half
+
+        y = log_t + num / denom - self.k_lanczos_gamma / t
+
+        # Shift the input into [-0.5, 0.5] so that pi * reduced_input stays accurate in the cotangent.
+        reduced_input = input_x + self.abs(self.floor(input_x + 0.5))
+        reflection = y - self.pi * self.cos(self.pi * reduced_input) / self.sin(self.pi * reduced_input)
+        real_result = self.select(need_to_reflect, reflection, y)
+        nan = self.fill(self.dtype(input_x), self.shape(input_x), np.nan)
+
+        # Return NaN at negative integers, where digamma has poles.
+        return self.select(self.logicaland(self.less(input_x, 0), self.equal(input_x, self.floor(input_x))),
+                           nan, real_result)
+
+
 eps_fp16 = Tensor(np.finfo(np.float16).eps, mstype.float16)
 eps_fp32 = Tensor(np.finfo(np.float32).eps, mstype.float32)
 
diff --git a/tests/ut/python/ops/test_nn_ops.py b/tests/ut/python/ops/test_nn_ops.py
index 7a5e062614..710395fd44 100644
--- a/tests/ut/python/ops/test_nn_ops.py
+++ b/tests/ut/python/ops/test_nn_ops.py
@@ -598,6 +598,10 @@ test_cases = [
         'desc_inputs': [Tensor(np.array([3, 4, 5, 6]).astype(np.float32)),
                         Tensor(np.array([3, 4, 5, 6]).astype(np.float32))],
         'skip': ['backward']}),
+    ('DiGamma', {
+        'block': nn.DiGamma(),
+        'desc_inputs': [Tensor(np.array([3, 4, 5, 6]).astype(np.float32))],
+        'skip': ['backward']}),
     ('FlattenNet', {
         'block': FlattenNet(),
         'desc_inputs': [Tensor(np.ones([1, 2, 3, 4], np.float32))],
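
A minimal interactive check of the new cell (a sketch, not part of the patch; it assumes a MindSpore build with this change applied, and the expected values are the ones from the DiGamma docstring example):

    import numpy as np
    import mindspore.nn as nn
    from mindspore import Tensor

    # Evaluate digamma at 2, 3 and 4; the output should be close to
    # [0.42278463 0.92278427 1.2561178].
    x = Tensor(np.array([2, 3, 4]).astype(np.float32))
    digamma = nn.DiGamma()
    print(digamma(x))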