From 5e9178c5b63e880e4fcfadd8436696cd3638056d Mon Sep 17 00:00:00 2001
From: peixu_ren
Date: Thu, 1 Oct 2020 11:08:37 -0400
Subject: [PATCH] Add IGamma operator

---
 mindspore/nn/layer/math.py         | 272 ++++++++++++++++++++++++++++
 tests/ut/python/ops/test_nn_ops.py |   5 +
 2 files changed, 268 insertions(+), 9 deletions(-)

diff --git a/mindspore/nn/layer/math.py b/mindspore/nn/layer/math.py
index 295acf8749..d50aefa10d 100644
--- a/mindspore/nn/layer/math.py
+++ b/mindspore/nn/layer/math.py
@@ -25,7 +25,12 @@
 from ...common import dtype as mstype
 from ..._checkparam import Validator as validator
 
-__all__ = ['ReduceLogSumExp', 'Range', 'LinSpace', 'LGamma', 'MatMul', 'Moments']
+__all__ = ['ReduceLogSumExp', 'Range', 'LinSpace', 'LGamma', 'IGamma', 'MatMul', 'Moments']
+
+
+@constexpr
+def _check_input_dtype(param_name, input_dtype, allow_dtypes, cls_name):
+    validator.check_type_name(param_name, input_dtype, allow_dtypes, cls_name)
 
 
 class ReduceLogSumExp(Cell):
@@ -43,7 +48,7 @@ class ReduceLogSumExp(Cell):
         Default : False.
 
     Inputs:
-        - **input_x** (Tensor[Number]) - The input tensor. With float16 or float32 data type.
+        - **input_x** (Tensor) - The input tensor. With float16 or float32 data type.
 
     Outputs:
         Tensor, has the same dtype as the `input_x`.
@@ -213,7 +218,7 @@ class LGamma(Cell):
         when x = +/- inf, return +inf
 
     Inputs:
-        - **input_x** (Tensor[Number]) - The input tensor. Only float16, float32 are supported.
+        - **input_x** (Tensor) - The input tensor. Only float16, float32 are supported.
 
     Outputs:
         Tensor, has the same shape and dtype as the `input_x`.
@@ -267,7 +272,7 @@ class LGamma(Cell):
 
     def construct(self, input_x):
         input_dtype = self.dtype(input_x)
-        check_tensors_dtype_same(input_dtype, [mstype.float16, mstype.float32], "LGamma")
+        _check_input_dtype("input", input_dtype, [mstype.float16, mstype.float32], self.cls_name)
         infinity = self.fill(input_dtype, self.shape(input_x), self.inf)
 
         need_to_reflect = self.less(input_x, 0.5)
@@ -307,6 +312,260 @@ class LGamma(Cell):
         return self.select(self.isfinite(input_x), result, infinity)
 
 
+eps_fp16 = Tensor(np.finfo(np.float16).eps, mstype.float16)
+eps_fp32 = Tensor(np.finfo(np.float32).eps, mstype.float32)
+
+
+def _while_helper_func(cond, body, vals):
+    while cond(vals).any():
+        vals = body(vals)
+    return vals
+
+
+def _IgammaSeries(ax, x, a, enabled):
+    """Helper function for computing Igamma using a power series."""
+
+    logicaland = P.LogicalAnd()
+    greater = P.Greater()
+    fill = P.Fill()
+    shape = P.Shape()
+    dtype = P.DType()
+    select = P.Select()
+
+    if dtype(ax) == mstype.float16:
+        epsilon = eps_fp16
+    else:
+        epsilon = eps_fp32
+
+    def cond(vals):
+        enabled = vals[0]
+        return enabled
+
+    def body(vals):
+        enabled = vals[0]
+        r = vals[1]
+        c = vals[2]
+        ans = vals[3]
+        x = vals[4]
+        dc_da = vals[5]
+        dans_da = vals[6]
+
+        r = r + 1
+        dc_da = dc_da * (x / r) + (-1 * c * x) / (r * r)
+        dans_da = dans_da + dc_da
+        c = c * (x / r)
+        ans = ans + c
+        conditional = logicaland(enabled, greater(c / ans, epsilon))
+
+        return (conditional, select(enabled, r, vals[1]),
+                select(enabled, c, vals[2]), select(enabled, ans, vals[3]),
+                select(enabled, x, vals[4]), select(enabled, dc_da, vals[5]),
+                select(enabled, dans_da, vals[6]))
+
+    ones = fill(dtype(a), shape(a), 1)
+    zeros = fill(dtype(a), shape(a), 0)
+    vals = (enabled, a, ones, ones, x, zeros, zeros)
+
+    vals = _while_helper_func(cond, body, vals)
+    ans = vals[3]
+    return (ans * ax) / a
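The series branch above is the textbook power-series expansion P(a, x) = x^a * e^-x * sum_{k>=0} x^k / Gamma(a + k + 1); the extra `dc_da`/`dans_da` slots carry the derivative with respect to `a` so the loop stays differentiable. A minimal scalar sketch of the same recurrence (illustrative only, not part of the patch; it assumes a > 0 and x > 0, since the graph version handles edge cases through masks):

```python
import math

def igamma_series(a, x, eps=1e-8):
    """P(a, x) via x^a * e^-x * sum_k x^k / Gamma(a + k + 1)."""
    ax = math.exp(a * math.log(x) - x - math.lgamma(a))
    r, c, ans = a, 1.0, 1.0
    while c / ans > eps:   # same stopping rule as the graph version
        r += 1.0
        c *= x / r         # next term: previous term times x / (a + k)
        ans += c
    return ans * ax / a

# igamma_series(2.0, 2.0) -> ~0.593994, matching the docstring example below.
```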
+
+
+def _IgammacContinuedFraction(ax, x, a, enabled):
+    """Helper function for computing Igammac using a continued fraction."""
+
+    abs_x = P.Abs()
+    logicaland = P.LogicalAnd()
+    greater = P.Greater()
+    less = P.Less()
+    notequal = P.NotEqual()
+    fill = P.Fill()
+    shape = P.Shape()
+    dtype = P.DType()
+    select = P.Select()
+
+    if dtype(ax) == mstype.float16:
+        epsilon = eps_fp16
+    else:
+        epsilon = eps_fp32
+
+    def cond(vals):
+        enabled = vals[0]
+        c = vals[5]
+        return logicaland(less(c, 2000), enabled)
+
+    def body(vals):
+        enabled = vals[0]
+        ans = vals[1]
+        t = vals[2]
+        y = vals[3]
+        z = vals[4]
+        c = vals[5]
+        pkm1 = vals[6]
+        qkm1 = vals[7]
+        pkm2 = vals[8]
+        qkm2 = vals[9]
+
+        dpkm2_da = vals[10]
+        dqkm2_da = vals[11]
+        dpkm1_da = vals[12]
+        dqkm1_da = vals[13]
+        dans_da = vals[14]
+
+        c = c + 1
+        y = y + 1
+        z = z + 2
+
+        yc = y * c
+        pk = pkm1 * z - pkm2 * yc
+        qk = qkm1 * z - qkm2 * yc
+        qk_is_nonzero = notequal(qk, 0)
+        r = pk / qk
+
+        t = select(qk_is_nonzero, abs_x((ans - r) / r), fill(dtype(t), shape(t), 1))
+        ans = select(qk_is_nonzero, r, ans)
+
+        dpk_da = dpkm1_da * z - pkm1 - dpkm2_da * yc + pkm2 * c
+        dqk_da = dqkm1_da * z - qkm1 - dqkm2_da * yc + qkm2 * c
+        dans_da_new = select(qk_is_nonzero, (dpk_da - ans * dqk_da) / qk, dans_da)
+        grad_conditional = select(qk_is_nonzero,
+                                  abs_x(dans_da_new - dans_da),
+                                  fill(dtype(dans_da), shape(dans_da), 1))
+
+        pkm2 = pkm1
+        pkm1 = pk
+        qkm2 = qkm1
+        qkm1 = qk
+
+        dpkm2_da = dpkm1_da
+        dqkm2_da = dqkm1_da
+        dpkm1_da = dpk_da
+        dqkm1_da = dqk_da
+
+        rescale = greater(abs_x(pk), 1 / epsilon)
+        pkm2 = select(rescale, pkm2 * epsilon, pkm2)
+        pkm1 = select(rescale, pkm1 * epsilon, pkm1)
+        qkm2 = select(rescale, qkm2 * epsilon, qkm2)
+        qkm1 = select(rescale, qkm1 * epsilon, qkm1)
+
+        dpkm2_da = select(rescale, dpkm2_da * epsilon, dpkm2_da)
+        dqkm2_da = select(rescale, dqkm2_da * epsilon, dqkm2_da)
+        dpkm1_da = select(rescale, dpkm1_da * epsilon, dpkm1_da)
+        dqkm1_da = select(rescale, dqkm1_da * epsilon, dqkm1_da)
+
+        conditional = logicaland(enabled, greater(grad_conditional, epsilon))
+
+        return (conditional, select(enabled, ans, vals[1]), select(enabled, t, vals[2]),
+                select(enabled, y, vals[3]), select(enabled, z, vals[4]),
+                c, select(enabled, pkm1, vals[6]),
+                select(enabled, qkm1, vals[7]), select(enabled, pkm2, vals[8]),
+                select(enabled, qkm2, vals[9]), select(enabled, dpkm2_da, vals[10]),
+                select(enabled, dqkm2_da, vals[11]), select(enabled, dpkm1_da, vals[12]),
+                select(enabled, dqkm1_da, vals[13]), select(enabled, dans_da_new, vals[14]))
+
+    y = 1 - a
+    z = x + y + 1
+    c = fill(dtype(x), shape(x), 0)
+    pkm2 = fill(dtype(x), shape(x), 1)
+    qkm2 = x
+    pkm1 = x + 1
+    qkm1 = z * x
+    ans = pkm1 / qkm1
+    t = fill(dtype(x), shape(x), 1)
+    dpkm2_da = fill(dtype(x), shape(x), 0)
+    dqkm2_da = fill(dtype(x), shape(x), 0)
+    dpkm1_da = fill(dtype(x), shape(x), 0)
+    dqkm1_da = -x
+    dans_da = (dpkm1_da - ans * dqkm1_da) / qkm1
+    vals = (enabled, ans, t, y, z, c, pkm1, qkm1, pkm2, qkm2, dpkm2_da, dqkm2_da, dpkm1_da, dqkm1_da, dans_da)
+    vals = _while_helper_func(cond, body, vals)
+    ans = vals[1]
+    return ans * ax
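The branch above follows the Cephes `igamc` scheme: it evaluates Legendre's continued fraction for the upper function Q(a, x) through its convergents p_k / q_k, rescaling them by `epsilon` once they exceed 1 / `epsilon` to avoid overflow, while the `dpk_da`/`dqk_da` slots track the derivative with respect to `a` for autodiff. A scalar sketch without the derivative bookkeeping (illustrative only, not part of the patch; it stops on value convergence, whereas the graph version stops on gradient convergence):

```python
import math

def igammac_cf(a, x, eps=1e-8, max_iter=2000):
    """Q(a, x) via the continued fraction, following the Cephes recurrence."""
    ax = math.exp(a * math.log(x) - x - math.lgamma(a))
    y, z, c = 1.0 - a, x + (1.0 - a) + 1.0, 0.0
    pkm2, qkm2 = 1.0, x          # trailing convergent numerator/denominator
    pkm1, qkm1 = x + 1.0, z * x  # leading convergent numerator/denominator
    ans = pkm1 / qkm1
    for _ in range(max_iter):    # same 2000-step cap as the graph version
        c += 1.0
        y += 1.0
        z += 2.0
        yc = y * c
        pk = pkm1 * z - pkm2 * yc
        qk = qkm1 * z - qkm2 * yc
        if qk != 0.0:
            r = pk / qk
            t = abs((ans - r) / r)
            ans = r
        else:
            t = 1.0
        pkm2, pkm1 = pkm1, pk
        qkm2, qkm1 = qkm1, qk
        if abs(pk) > 1.0 / eps:  # rescale convergents to avoid overflow
            pkm2, pkm1 = pkm2 * eps, pkm1 * eps
            qkm2, qkm1 = qkm2 * eps, qkm1 * eps
        if t <= eps:
            break
    return ans * ax

# igammac_cf(2.0, 4.0) -> ~0.09158 (Q), so P = 1 - Q ≈ 0.90842.
```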
+
+
+class IGamma(Cell):
+    r"""
+    Calculates the lower regularized incomplete Gamma function.
+
+    The lower regularized incomplete Gamma function is defined as:
+
+    .. math::
+        P(a, x) = \gamma(a, x) / \Gamma(a) = 1 - Q(a, x)
+
+    where
+
+    .. math::
+        \gamma(a, x) = \int_0^x t^{a-1} e^{-t} dt
+
+    is the lower incomplete Gamma function.
+
+    Above, :math:`Q(a, x)` is the upper regularized incomplete Gamma function.
+
+    Inputs:
+        - **a** (Tensor) - The input tensor. With float16 or float32 data type. `a` should have
+          the same dtype as `x`.
+        - **x** (Tensor) - The input tensor. With float16 or float32 data type. `x` should have
+          the same dtype as `a`.
+
+    Outputs:
+        Tensor, has the same dtype as `a` and `x`.
+
+    Examples:
+        >>> input_a = Tensor(np.array([2.0, 4.0, 6.0, 8.0]).astype(np.float32))
+        >>> input_x = Tensor(np.array([2.0, 3.0, 4.0, 5.0]).astype(np.float32))
+        >>> igamma = nn.IGamma()
+        >>> output = igamma(input_a, input_x)
+        >>> print(output)
+        [0.593994 0.35276785 0.21486944 0.13337152]
+    """
+
+    def __init__(self):
+        super(IGamma, self).__init__()
+        # const numbers
+        self.log_maxfloat16 = Tensor(np.log(np.finfo(np.float16).max), mstype.float16)
+        self.log_maxfloat32 = Tensor(np.log(np.finfo(np.float32).max), mstype.float32)
+
+        # operations
+        self.logicaland = P.LogicalAnd()
+        self.logicalor = P.LogicalOr()
+        self.logicalnot = P.LogicalNot()
+        self.equal = P.Equal()
+        self.greater = P.Greater()
+        self.less = P.Less()
+        self.neg = P.Neg()
+        self.log = P.Log()
+        self.exp = P.Exp()
+        self.select = P.Select()
+        self.zeroslike = P.ZerosLike()
+        self.fill = P.Fill()
+        self.shape = P.Shape()
+        self.dtype = P.DType()
+        self.lgamma = LGamma()
+        self.const = P.ScalarToArray()
+        self.cast = P.Cast()
+
+    def construct(self, a, x):
+        a_dtype = self.dtype(a)
+        x_dtype = self.dtype(x)
+        _check_input_dtype("input_a", a_dtype, [mstype.float16, mstype.float32], self.cls_name)
+        _check_input_dtype("input_x", x_dtype, a_dtype, self.cls_name)
+        x_is_zero = self.equal(x, 0)
+        domain_error = self.logicalor(self.less(x, 0), self.less(a, 0))
+        use_igammac = self.logicaland(self.greater(x, 1), self.greater(x, a))
+        ax = a * self.log(x) - x - self.lgamma(a)
+        if a_dtype == mstype.float16:
+            log_maxfloat = self.log_maxfloat16
+        else:
+            log_maxfloat = self.log_maxfloat32
+        underflow = self.less(ax, self.neg(log_maxfloat))
+        ax = self.exp(ax)
+        enabled = self.logicalnot(self.logicalor(self.logicalor(x_is_zero, domain_error), underflow))
+        output = self.select(use_igammac,
+                             1 - _IgammacContinuedFraction(ax, x, a, self.logicaland(enabled, use_igammac)),
+                             _IgammaSeries(ax, x, a, self.logicaland(enabled, self.logicalnot(use_igammac))))
+        output = self.select(x_is_zero, self.zeroslike(output), output)
+        output = self.select(domain_error, self.fill(self.dtype(a), self.shape(a), np.nan), output)
+        return output
+
+
 @constexpr
 def get_broadcast_matmul_shape(x_shape, y_shape):
     """get broadcast_matmul shape"""
@@ -453,11 +712,6 @@ class MatMul(Cell):
         return matmul_broadcast
 
 
-@constexpr
-def _check_input_dtype(param_name, input_dtype, allow_dtypes, cls_name):
-    validator.check_type_name(param_name, input_dtype, allow_dtypes, cls_name)
-
-
 class Moments(Cell):
     """
     Calculate the mean and variance of `x`.
diff --git a/tests/ut/python/ops/test_nn_ops.py b/tests/ut/python/ops/test_nn_ops.py
index abf0b034d3..7a5e062614 100644
--- a/tests/ut/python/ops/test_nn_ops.py
+++ b/tests/ut/python/ops/test_nn_ops.py
@@ -593,6 +593,11 @@ test_cases = [
         'block': nn.LGamma(),
         'desc_inputs': [Tensor(np.array([3, 4, 5, 6]).astype(np.float32))],
         'skip': ['backward']}),
+    ('IGamma', {
+        'block': nn.IGamma(),
+        'desc_inputs': [Tensor(np.array([3, 4, 5, 6]).astype(np.float32)),
+                        Tensor(np.array([3, 4, 5, 6]).astype(np.float32))],
+        'skip': ['backward']}),
     ('FlattenNet', {
         'block': FlattenNet(),
         'desc_inputs': [Tensor(np.ones([1, 2, 3, 4], np.float32))],
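With the patch applied, the new cell can be cross-checked against `scipy.special.gammainc`, which computes the same lower regularized function. A quick manual check (sketch only, not part of this change; assumes SciPy is installed):

```python
import numpy as np
import scipy.special as sps
import mindspore.nn as nn
from mindspore import Tensor

a = np.array([2.0, 4.0, 6.0, 8.0], np.float32)
x = np.array([2.0, 3.0, 4.0, 5.0], np.float32)

out = nn.IGamma()(Tensor(a), Tensor(x)).asnumpy()

# scipy.special.gammainc(a, x) is the same lower regularized P(a, x);
# loose tolerance because the cell computes in float32.
np.testing.assert_allclose(out, sps.gammainc(a, x), rtol=1e-4)
```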