diff --git a/mindspore/nn/probability/bijector/power_transform.py b/mindspore/nn/probability/bijector/power_transform.py index 1d9d2c5a88..ca877852a8 100644 --- a/mindspore/nn/probability/bijector/power_transform.py +++ b/mindspore/nn/probability/bijector/power_transform.py @@ -16,6 +16,7 @@ from mindspore.ops import operations as P from mindspore._checkparam import Validator as validator from mindspore._checkparam import Rel +from ..distribution._utils.utils import CheckTensor from .bijector import Bijector class PowerTransform(Bijector): @@ -62,6 +63,8 @@ class PowerTransform(Bijector): self.log1p = self._log1p_by_step self.expm1 = self._expm1_by_step + self.checktensor = CheckTensor() + def _log1p_by_step(self, x): """ Log1p ops on GPU device or when device_target == GPU. @@ -86,11 +89,13 @@ class PowerTransform(Bijector): return shape def _forward(self, x): + self.checktensor(x, 'x') if self.power == 0: return self.exp(x) return self.exp(self.log1p(x * self.power) / self.power) def _inverse(self, y): + self.checktensor(y, 'y') if self.power == 0: return self.log(y) return self.expm1(self.log(y) * self.power) / self.power @@ -107,6 +112,7 @@ class PowerTransform(Bijector): f'(x) = e^\frac{\log(xc + 1)}{c} * \frac{1}{xc + 1} \log(f'(x)) = (\frac{1}{c} - 1) * \log(xc + 1) """ + self.checktensor(x, 'x') if self.power == 0: return x return (1. / self.power - 1) * self.log1p(x * self.power) @@ -123,4 +129,5 @@ class PowerTransform(Bijector): f'(x) = \frac{e^c\log(y)}{y} \log(f'(x)) = \log(\frac{e^c\log(y)}{y}) = (c-1) * \log(y) """ + self.checktensor(y, 'y') return (self.power - 1) * self.log(y) diff --git a/mindspore/nn/probability/bijector/scalar_affine.py b/mindspore/nn/probability/bijector/scalar_affine.py index b48df1f0a7..44de3c68a0 100644 --- a/mindspore/nn/probability/bijector/scalar_affine.py +++ b/mindspore/nn/probability/bijector/scalar_affine.py @@ -15,7 +15,7 @@ """Scalar Affine Bijector""" from mindspore.ops import operations as P from mindspore._checkparam import Validator as validator -from ..distribution._utils.utils import cast_to_tensor +from ..distribution._utils.utils import cast_to_tensor, CheckTensor from .bijector import Bijector class ScalarAffine(Bijector): @@ -54,8 +54,8 @@ class ScalarAffine(Bijector): Constructor of scalar affine bijector. """ param = dict(locals()) - validator.check_value_type('scale', scale, [float], name) - validator.check_value_type('shift', shift, [float], name) + validator.check_value_type('scale', scale, [int, float], name) + validator.check_value_type('shift', shift, [int, float], name) self._scale = cast_to_tensor(scale) self._shift = cast_to_tensor(shift) super(ScalarAffine, self).__init__( @@ -65,8 +65,10 @@ class ScalarAffine(Bijector): dtype=None, param=param) + self.abs = P.Abs() self.log = P.Log() - self.oneslike = P.OnesLike() + + self.checktensor = CheckTensor() @property def scale(self): @@ -88,6 +90,7 @@ class ScalarAffine(Bijector): .. math:: f(x) = a * x + b """ + self.checktensor(x, 'x') return self.scale * x + self.shift def _inverse(self, y): @@ -95,22 +98,25 @@ class ScalarAffine(Bijector): .. math:: f(y) = \frac{y - b}{a} """ + self.checktensor(y, 'y') return (y - self.shift) / self.scale - def _forward_log_jacobian(self, value): + def _forward_log_jacobian(self, x): r""" .. 
math:: f(x) = a * x + b f'(x) = a \log(f'(x)) = \log(a) """ - return self.log(self.scale) * self.oneslike(value) + self.checktensor(x, 'x') + return self.log(self.abs(self.scale)) - def _inverse_log_jacobian(self, value): + def _inverse_log_jacobian(self, y): r""" .. math:: f(y) = \frac{(y - b)}{a} f'(x) = \frac{1.0}{a} \log(f'(x)) = - \log(a) """ - return -1. * self.log(self.scale) * self.oneslike(value) + self.checktensor(y, 'y') + return -1. * self.log(self.abs(self.scale)) diff --git a/mindspore/nn/probability/bijector/softplus.py b/mindspore/nn/probability/bijector/softplus.py index ee86aa2133..070e483707 100644 --- a/mindspore/nn/probability/bijector/softplus.py +++ b/mindspore/nn/probability/bijector/softplus.py @@ -13,10 +13,12 @@ # limitations under the License. # ============================================================================ """Softplus Bijector""" +import numpy as np from mindspore.ops import operations as P +from mindspore.common import dtype as mstype from mindspore.nn.layer.activation import LogSigmoid from mindspore._checkparam import Validator as validator -from ..distribution._utils.utils import cast_to_tensor +from ..distribution._utils.utils import cast_to_tensor, CheckTensor from .bijector import Bijector class Softplus(Bijector): @@ -52,19 +54,28 @@ class Softplus(Bijector): sharpness=1.0, name='Softplus'): param = dict(locals()) - validator.check_value_type('sharpness', sharpness, [float], name) + validator.check_value_type('sharpness', sharpness, [int, float], name) super(Softplus, self).__init__(name=name, param=param) self._sharpness = cast_to_tensor(sharpness) + self.abs = P.Abs() self.exp = P.Exp() self.expm1 = self._expm1_by_step + self.fill = P.Fill() + self.greater = P.Greater() + self.less = P.Less() self.log_sigmoid = LogSigmoid() self.log = P.Log() + self.logicalor = P.LogicalOr() + self.select = P.Select() + self.shape = P.Shape() self.sigmoid = P.Sigmoid() - self.softplus = self._softplus self.inverse_softplus = self._inverse_softplus + self.checktensor = CheckTensor() + self.threshold = np.log(np.finfo(np.float32).eps) + 1 + def _expm1_by_step(self, x): """ Expm1 ops under GPU context. 
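The rewritten `_softplus` and `_inverse_softplus` in the hunks that follow use this `threshold` to route extreme inputs around `exp`/`log`. A minimal NumPy sketch of the same clamping idea, with an assumed helper name `stable_softplus` and made-up test values (illustrative only, not the MindSpore implementation):

```python
# NumPy sketch of the clamping idea used by the new _softplus: outside
# [threshold, -threshold] the exact log(1 + e^x) underflows or overflows,
# so the extremes fall back to e^x (very negative x) and x (very large x).
import numpy as np

threshold = np.log(np.finfo(np.float32).eps) + 1  # about -14.9 for float32

def stable_softplus(x):
    x = np.asarray(x, dtype=np.float64)
    too_small = x < threshold
    too_large = x > -threshold
    safe = np.where(too_small | too_large, 1.0, x)   # keep exp() in a safe range
    y = np.log(np.exp(safe) + 1.0)
    return np.where(too_small, np.exp(x), np.where(too_large, x, y))

x = np.array([-50.0, -1.0, 0.0, 1.0, 50.0])
print(stable_softplus(x))   # ~[1.9e-22, 0.3133, 0.6931, 1.3133, 50.0]
```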
@@ -72,7 +83,15 @@ class Softplus(Bijector): return self.exp(x) - 1.0 def _softplus(self, x): - return self.log(self.exp(x) + 1.0) + too_small = self.less(x, self.threshold) + too_large = self.greater(x, -self.threshold) + too_small_value = self.exp(x) + too_large_value = x + ones = self.fill(mstype.float32, self.shape(x), 1.0) + too_small_or_too_large = self.logicalor(too_small, too_large) + x = self.select(too_small_or_too_large, ones, x) + y = self.log(self.exp(x) + 1.0) + return self.select(too_small, too_small_value, self.select(too_large, too_large_value, y)) def _inverse_softplus(self, x): r""" @@ -80,7 +99,15 @@ class Softplus(Bijector): f(x) = \frac{\log(1 + e^{x}))} f^{-1}(y) = \frac{\log(e^{y} - 1)} """ - return self.log(self.expm1(x)) + too_small = self.less(x, self.threshold) + too_large = self.greater(x, -self.threshold) + too_small_value = self.log(x) + too_large_value = x + ones = self.fill(mstype.float32, self.shape(x), 1.0) + too_small_or_too_large = self.logicalor(too_small, too_large) + x = self.select(too_small_or_too_large, ones, x) + y = x + self.log(self.abs(self.expm1(-x))) + return self.select(too_small, too_small_value, self.select(too_large, too_large_value, y)) @property def sharpness(self): @@ -94,6 +121,7 @@ class Softplus(Bijector): return shape def _forward(self, x): + self.checktensor(x, 'x') scaled_value = self.sharpness * x return self.softplus(scaled_value) / self.sharpness @@ -103,6 +131,7 @@ class Softplus(Bijector): f(x) = \frac{\log(1 + e^{kx}))}{k} f^{-1}(y) = \frac{\log(e^{ky} - 1)}{k} """ + self.checktensor(y, 'y') scaled_value = self.sharpness * y return self.inverse_softplus(scaled_value) / self.sharpness @@ -113,6 +142,7 @@ class Softplus(Bijector): f'(x) = \frac{e^{kx}}{ 1 + e^{kx}} \log(f'(x)) = kx - \log(1 + e^{kx}) = kx - f(kx) """ + self.checktensor(x, 'x') scaled_value = self.sharpness * x return self.log_sigmoid(scaled_value) @@ -123,5 +153,6 @@ class Softplus(Bijector): f'(y) = \frac{e^{ky}}{e^{ky} - 1} \log(f'(y)) = ky - \log(e^{ky} - 1) = ky - f(ky) """ + self.checktensor(y, 'y') scaled_value = self.sharpness * y return scaled_value - self.inverse_softplus(scaled_value) diff --git a/mindspore/nn/probability/distribution/_utils/utils.py b/mindspore/nn/probability/distribution/_utils/utils.py index 58d1c7cd01..e6339d0e08 100644 --- a/mindspore/nn/probability/distribution/_utils/utils.py +++ b/mindspore/nn/probability/distribution/_utils/utils.py @@ -15,7 +15,8 @@ """Utitly functions to help distribution class.""" import numpy as np from mindspore.ops import _utils as utils -from mindspore.ops.primitive import constexpr +from mindspore.ops.primitive import constexpr, PrimitiveWithInfer, prim_attr_register +from mindspore._checkparam import Validator as validator from mindspore.common.tensor import Tensor from mindspore.common.parameter import Parameter from mindspore.common import dtype as mstype @@ -53,7 +54,9 @@ def cast_to_tensor(t, hint_type=mstype.float32): raise TypeError(f'Input cannot be Type Bool') if isinstance(t, (int, float)): return Tensor(t, dtype=t_type) - raise TypeError("Input type is not supported.") + invalid_type = type(t) + raise TypeError(f"Unable to convert input of type {invalid_type} to a Tensor of type {t_type}") + def convert_to_batch(t, batch_shape, required_type): """ @@ -274,5 +277,51 @@ def raise_none_error(name): @constexpr def check_distribution_name(name, expected_name): + if name is None: + raise ValueError(f"Distribution should be a constant which is not None.") if name != expected_name: - raise 
ValueError(f"Distribution should be {expected_name}.") + raise ValueError(f"Expected distribution name is {expected_name}, but got {name}.") + +class CheckTuple(PrimitiveWithInfer): + """ + Check if input is a tuple. + """ + @prim_attr_register + def __init__(self): + """init Cast""" + super(CheckTuple, self).__init__("CheckTuple") + self.init_prim_io_names(inputs=['x'], outputs=['dummy_output']) + + def __infer__(self, x, name): + if not isinstance(x['dtype'], tuple): + raise TypeError("Input type should be a tuple: " + name["value"]) + + out = {'shape': None, + 'dtype': None, + 'value': None} + return out + + def __call__(self, *args): + return + +class CheckTensor(PrimitiveWithInfer): + """ + Check if input is a Tensor. + """ + @prim_attr_register + def __init__(self): + """init Cast""" + super(CheckTensor, self).__init__("CheckTensor") + self.init_prim_io_names(inputs=['x'], outputs=['dummy_output']) + + def __infer__(self, x, name): + src_type = x['dtype'] + validator.check_subclass("input", src_type, [mstype.tensor], name["value"]) + + out = {'shape': None, + 'dtype': None, + 'value': None} + return out + + def __call__(self, *args): + return diff --git a/mindspore/nn/probability/distribution/bernoulli.py b/mindspore/nn/probability/distribution/bernoulli.py index 509c6fe8e7..e0170ffe4e 100644 --- a/mindspore/nn/probability/distribution/bernoulli.py +++ b/mindspore/nn/probability/distribution/bernoulli.py @@ -18,6 +18,7 @@ from mindspore.ops import operations as P from mindspore.ops import composite as C from .distribution import Distribution from ._utils.utils import cast_to_tensor, check_prob, check_type, check_distribution_name, raise_none_error +from ._utils.utils import CheckTensor, CheckTuple class Bernoulli(Distribution): """ @@ -123,6 +124,9 @@ class Bernoulli(Distribution): self.sqrt = P.Sqrt() self.uniform = C.uniform + self.checktensor = CheckTensor() + self.checktuple = CheckTuple() + def extend_repr(self): if self.is_scalar_batch: str_info = f'probs = {self.probs}' @@ -137,14 +141,21 @@ class Bernoulli(Distribution): """ return self._probs + def _check_param(self, probs1): + """ + Check availablity of distribution specific args probs1. + """ + if probs1 is not None: + self.checktensor(probs1, 'probs1') + return self.cast(probs1, self.parameter_type) + return self.probs if self.probs is not None else raise_none_error('probs1') + def _mean(self, probs1=None): r""" .. math:: MEAN(B) = probs1 """ - probs1 = self.cast(probs1, self.parameter_type) if probs1 is not None else self.probs - if probs1 is None: - raise_none_error("probs1") + probs1 = self._check_param(probs1) return probs1 def _mode(self, probs1=None): @@ -152,9 +163,7 @@ class Bernoulli(Distribution): .. math:: MODE(B) = 1 if probs1 > 0.5 else = 0 """ - probs1 = self.cast(probs1, self.parameter_type) if probs1 is not None else self.probs - if probs1 is None: - raise_none_error("probs1") + probs1 = self._check_param(probs1) prob_type = self.dtypeop(probs1) zeros = self.fill(prob_type, self.shape(probs1), 0.0) ones = self.fill(prob_type, self.shape(probs1), 1.0) @@ -166,24 +175,20 @@ class Bernoulli(Distribution): .. math:: VAR(B) = probs1 * probs0 """ - probs1 = self.cast(probs1, self.parameter_type) if probs1 is not None else self.probs - if probs1 is None: - raise_none_error("probs1") + probs1 = self._check_param(probs1) probs0 = 1.0 - probs1 return self.exp(self.log(probs0) + self.log(probs1)) - def _entropy(self, probs=None): + def _entropy(self, probs1=None): r""" .. 
math:: H(B) = -probs0 * \log(probs0) - probs1 * \log(probs1) """ - probs1 = self.cast(probs, self.parameter_type) if probs is not None else self.probs - if probs1 is None: - raise_none_error("probs") + probs1 = self._check_param(probs1) probs0 = 1 - probs1 return -1 * (probs0 * self.log(probs0)) - (probs1 * self.log(probs1)) - def _cross_entropy(self, dist, probs1_b, probs1_a=None): + def _cross_entropy(self, dist, probs1_b, probs1=None): """ Evaluate cross_entropy between Bernoulli distributions. @@ -193,9 +198,9 @@ class Bernoulli(Distribution): probs1_a (Tensor): probs1 of distribution a. Default: self.probs. """ check_distribution_name(dist, 'Bernoulli') - return self._entropy(probs=probs1_a) + self._kl_loss(dist, probs1_b, probs1_a) + return self._entropy(probs1) + self._kl_loss(dist, probs1_b, probs1) - def _log_prob(self, value, probs=None): + def _log_prob(self, value, probs1=None): r""" pmf of Bernoulli distribution. @@ -207,17 +212,14 @@ class Bernoulli(Distribution): pmf(k) = probs1 if k = 1; pmf(k) = probs0 if k = 0; """ - if value is None: - raise_none_error("value") + self.checktensor(value, 'value') value = self.cast(value, mstype.float32) value = self.floor(value) - probs1 = self.cast(probs, self.parameter_type) if probs is not None else self.probs - if probs1 is None: - raise_none_error("probs") + probs1 = self._check_param(probs1) probs0 = 1.0 - probs1 return self.log(probs1) * value + self.log(probs0) * (1.0 - value) - def _cdf(self, value, probs=None): + def _cdf(self, value, probs1=None): r""" cdf of Bernoulli distribution. @@ -230,13 +232,10 @@ class Bernoulli(Distribution): cdf(k) = probs0 if 0 <= k <1; cdf(k) = 1 if k >=1; """ - if value is None: - raise_none_error("value") + self.checktensor(value, 'value') value = self.cast(value, mstype.float32) value = self.floor(value) - probs1 = self.cast(probs, self.parameter_type) if probs is not None else self.probs - if probs1 is None: - raise_none_error("probs") + probs1 = self._check_param(probs1) prob_type = self.dtypeop(probs1) value = value * self.fill(prob_type, self.shape(probs1), 1.0) probs0 = 1.0 - probs1 * self.fill(prob_type, self.shape(value), 1.0) @@ -247,7 +246,7 @@ class Bernoulli(Distribution): less_than_zero = self.select(comp_zero, zeros, probs0) return self.select(comp_one, less_than_zero, ones) - def _kl_loss(self, dist, probs1_b, probs1_a=None): + def _kl_loss(self, dist, probs1_b, probs1=None): r""" Evaluate bernoulli-bernoulli kl divergence, i.e. KL(a||b). @@ -261,17 +260,14 @@ class Bernoulli(Distribution): probs0_a * \log(\frac{probs0_a}{probs0_b}) """ check_distribution_name(dist, 'Bernoulli') - if probs1_b is None: - raise_none_error("probs1_b") + self.checktensor(probs1_b, 'probs1_b') probs1_b = self.cast(probs1_b, self.parameter_type) - probs1_a = self.cast(probs1_a, self.parameter_type) if probs1_a is not None else self.probs - if probs1_a is None: - raise_none_error("probs1_a") + probs1_a = self._check_param(probs1) probs0_a = 1.0 - probs1_a probs0_b = 1.0 - probs1_b return probs1_a * self.log(probs1_a / probs1_b) + probs0_a * self.log(probs0_a / probs0_b) - def _sample(self, shape=(), probs=None): + def _sample(self, shape=(), probs1=None): """ Sampling. @@ -282,9 +278,8 @@ class Bernoulli(Distribution): Returns: Tensor, shape is shape + batch_shape. 
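As a side note on the closed form used in `_kl_loss` above, a tiny NumPy sketch with toy probabilities (illustrative only, not MindSpore code):

```python
# Check of KL(Bern(p_a) || Bern(p_b)) = p_a*log(p_a/p_b) + (1-p_a)*log((1-p_a)/(1-p_b))
import numpy as np

def bernoulli_kl(p_a, p_b):
    return p_a * np.log(p_a / p_b) + (1 - p_a) * np.log((1 - p_a) / (1 - p_b))

assert np.isclose(bernoulli_kl(0.3, 0.3), 0.0)   # KL(a||a) = 0
assert bernoulli_kl(0.3, 0.7) > 0                # KL is non-negative
# definition: sum over k in {1, 0} of p_a(k) * log(p_a(k) / p_b(k))
brute = sum(pa * np.log(pa / pb) for pa, pb in [(0.3, 0.7), (0.7, 0.3)])
assert np.isclose(bernoulli_kl(0.3, 0.7), brute)
```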
""" - probs1 = self.cast(probs, self.parameter_type) if probs is not None else self.probs - if probs1 is None: - raise_none_error("probs") + self.checktuple(shape, 'shape') + probs1 = self._check_param(probs1) origin_shape = shape + self.shape(probs1) if origin_shape == (): sample_shape = (1,) diff --git a/mindspore/nn/probability/distribution/exponential.py b/mindspore/nn/probability/distribution/exponential.py index b89b8af627..a8f38df16f 100644 --- a/mindspore/nn/probability/distribution/exponential.py +++ b/mindspore/nn/probability/distribution/exponential.py @@ -20,6 +20,7 @@ from mindspore.common import dtype as mstype from .distribution import Distribution from ._utils.utils import cast_to_tensor, check_greater_zero, check_type, check_distribution_name,\ raise_none_error +from ._utils.utils import CheckTensor, CheckTuple class Exponential(Distribution): """ @@ -125,6 +126,9 @@ class Exponential(Distribution): self.sq = P.Square() self.uniform = C.uniform + self.checktensor = CheckTensor() + self.checktuple = CheckTuple() + def extend_repr(self): if self.is_scalar_batch: str_info = f'rate = {self.rate}' @@ -139,14 +143,21 @@ class Exponential(Distribution): """ return self._rate + def _check_param(self, rate): + """ + Check availablity of distribution specific args rate. + """ + if rate is not None: + self.checktensor(rate, 'rate') + return self.cast(rate, self.parameter_type) + return self.rate if self.rate is not None else raise_none_error('rate') + def _mean(self, rate=None): r""" .. math:: MEAN(EXP) = \frac{1.0}{\lambda}. """ - rate = self.cast(rate, self.parameter_type) if rate is not None else self.rate - if rate is None: - raise_none_error("rate") + rate = self._check_param(rate) return 1.0 / rate def _mode(self, rate=None): @@ -154,9 +165,7 @@ class Exponential(Distribution): .. math:: MODE(EXP) = 0. """ - rate = self.cast(rate, self.parameter_type) if rate is not None else self.rate - if rate is None: - raise_none_error("rate") + rate = self._check_param(rate) return self.fill(self.dtype, self.shape(rate), 0.) def _sd(self, rate=None): @@ -164,9 +173,7 @@ class Exponential(Distribution): .. math:: sd(EXP) = \frac{1.0}{\lambda}. """ - rate = self.cast(rate, self.parameter_type) if rate is not None else self.rate - if rate is None: - raise_none_error("rate") + rate = self._check_param(rate) return 1.0 / rate def _entropy(self, rate=None): @@ -174,13 +181,10 @@ class Exponential(Distribution): .. math:: H(Exp) = 1 - \log(\lambda). """ - rate = self.cast(rate, self.parameter_type) if rate is not None else self.rate - if rate is None: - raise_none_error("rate") + rate = self._check_param(rate) return 1.0 - self.log(rate) - - def _cross_entropy(self, dist, rate_b, rate_a=None): + def _cross_entropy(self, dist, rate_b, rate=None): """ Evaluate cross_entropy between Exponential distributions. @@ -190,7 +194,7 @@ class Exponential(Distribution): rate_a (Tensor): rate of distribution a. Default: self.rate. """ check_distribution_name(dist, 'Exponential') - return self._entropy(rate=rate_a) + self._kl_loss(dist, rate_b, rate_a) + return self._entropy(rate) + self._kl_loss(dist, rate_b, rate) def _prob(self, value, rate=None): @@ -208,12 +212,9 @@ class Exponential(Distribution): .. 
math:: pdf(x) = rate * \exp(-1 * \lambda * x) if x >= 0 else 0 """ - if value is None: - raise_none_error("value") + self.checktensor(value, "value") value = self.cast(value, self.dtype) - rate = self.cast(rate, self.parameter_type) if rate is not None else self.rate - if rate is None: - raise_none_error("rate") + rate = self._check_param(rate) prob = self.exp(self.log(rate) - rate * value) zeros = self.fill(self.dtypeop(prob), self.shape(prob), 0.0) comp = self.less(value, zeros) @@ -233,19 +234,16 @@ class Exponential(Distribution): .. math:: cdf(x) = 1.0 - \exp(-1 * \lambda * x) if x >= 0 else 0 """ - if value is None: - raise_none_error("value") + self.checktensor(value, 'value') value = self.cast(value, self.dtype) - rate = self.cast(rate, self.parameter_type) if rate is not None else self.rate - if rate is None: - raise_none_error("rate") + rate = self._check_param(rate) cdf = 1.0 - self.exp(-1. * rate * value) zeros = self.fill(self.dtypeop(cdf), self.shape(cdf), 0.0) comp = self.less(value, zeros) return self.select(comp, zeros, cdf) - def _kl_loss(self, dist, rate_b, rate_a=None): + def _kl_loss(self, dist, rate_b, rate=None): """ Evaluate exp-exp kl divergence, i.e. KL(a||b). @@ -255,12 +253,9 @@ class Exponential(Distribution): rate_a (Tensor): rate of distribution a. Default: self.rate. """ check_distribution_name(dist, 'Exponential') - if rate_b is None: - raise_none_error("rate_b") + self.checktensor(rate_b, 'rate_b') rate_b = self.cast(rate_b, self.parameter_type) - rate_a = self.cast(rate_a, self.parameter_type) if rate_a is not None else self.rate - if rate_a is None: - raise_none_error("rate_a") + rate_a = self._check_param(rate) return self.log(rate_a) - self.log(rate_b) + rate_b / rate_a - 1.0 def _sample(self, shape=(), rate=None): @@ -274,9 +269,8 @@ class Exponential(Distribution): Returns: Tensor, shape is shape + batch_shape. """ - rate = self.cast(rate, self.parameter_type) if rate is not None else self.rate - if rate is None: - raise_none_error("rate") + self.checktuple(shape, 'shape') + rate = self._check_param(rate) origin_shape = shape + self.shape(rate) if origin_shape == (): sample_shape = (1,) diff --git a/mindspore/nn/probability/distribution/geometric.py b/mindspore/nn/probability/distribution/geometric.py index 45acecfe86..1f16ef0240 100644 --- a/mindspore/nn/probability/distribution/geometric.py +++ b/mindspore/nn/probability/distribution/geometric.py @@ -20,6 +20,7 @@ from mindspore.common import dtype as mstype from .distribution import Distribution from ._utils.utils import cast_to_tensor, check_prob, check_type, check_distribution_name,\ raise_none_error +from ._utils.utils import CheckTensor, CheckTuple class Geometric(Distribution): """ @@ -129,6 +130,9 @@ class Geometric(Distribution): self.sqrt = P.Sqrt() self.uniform = C.uniform + self.checktensor = CheckTensor() + self.checktuple = CheckTuple() + def extend_repr(self): if self.is_scalar_batch: str_info = f'probs = {self.probs}' @@ -143,14 +147,21 @@ class Geometric(Distribution): """ return self._probs + def _check_param(self, probs1): + """ + Check availablity of distribution specific args probs1. + """ + if probs1 is not None: + self.checktensor(probs1, 'probs1') + return self.cast(probs1, self.parameter_type) + return self.probs if self.probs is not None else raise_none_error('probs1') + def _mean(self, probs1=None): r""" .. 
math:: MEAN(Geo) = \fratc{1 - probs1}{probs1} """ - probs1 = self.cast(probs1, self.parameter_type) if probs1 is not None else self.probs - if probs1 is None: - raise_none_error("probs1") + probs1 = self._check_param(probs1) return (1. - probs1) / probs1 def _mode(self, probs1=None): @@ -158,9 +169,7 @@ class Geometric(Distribution): .. math:: MODE(Geo) = 0 """ - probs1 = self.cast(probs1, self.parameter_type) if probs1 is not None else self.probs - if probs1 is None: - raise_none_error("probs1") + probs1 = self._check_param(probs1) return self.fill(self.dtypeop(probs1), self.shape(probs1), 0.) def _var(self, probs1=None): @@ -168,23 +177,19 @@ class Geometric(Distribution): .. math:: VAR(Geo) = \frac{1 - probs1}{probs1 ^ {2}} """ - probs1 = self.cast(probs1, self.parameter_type) if probs1 is not None else self.probs - if probs1 is None: - raise_none_error("probs1") + probs1 = self._check_param(probs1) return (1.0 - probs1) / self.sq(probs1) - def _entropy(self, probs=None): + def _entropy(self, probs1=None): r""" .. math:: H(Geo) = \frac{-1 * probs0 \log_2 (1-probs0)\ - prob1 * \log_2 (1-probs1)\ }{probs1} """ - probs1 = self.cast(probs, self.parameter_type) if probs is not None else self.probs - if probs1 is None: - raise_none_error("probs") + probs1 = self._check_param(probs1) probs0 = 1.0 - probs1 return (-probs0 * self.log(probs0) - probs1 * self.log(probs1)) / probs1 - def _cross_entropy(self, dist, probs1_b, probs1_a=None): + def _cross_entropy(self, dist, probs1_b, probs1=None): r""" Evaluate cross_entropy between Geometric distributions. @@ -194,9 +199,9 @@ class Geometric(Distribution): probs1_a (Tensor): probability of success of distribution a. Default: self.probs. """ check_distribution_name(dist, 'Geometric') - return self._entropy(probs=probs1_a) + self._kl_loss(dist, probs1_b, probs1_a) + return self._entropy(probs1) + self._kl_loss(dist, probs1_b, probs1) - def _prob(self, value, probs=None): + def _prob(self, value, probs1=None): r""" pmf of Geometric distribution. @@ -208,19 +213,16 @@ class Geometric(Distribution): pmf(k) = probs0 ^k * probs1 if k >= 0; pmf(k) = 0 if k < 0. """ - if value is None: - raise_none_error("value") + self.checktensor(value, 'value') value = self.cast(value, mstype.float32) value = self.floor(value) - probs1 = self.cast(probs, self.parameter_type) if probs is not None else self.probs - if probs1 is None: - raise_none_error("probs") + probs1 = self._check_param(probs1) pmf = self.exp(self.log(1.0 - probs1) * value + self.log(probs1)) zeros = self.fill(self.dtypeop(probs1), self.shape(pmf), 0.0) comp = self.less(value, zeros) return self.select(comp, zeros, pmf) - def _cdf(self, value, probs=None): + def _cdf(self, value, probs1=None): r""" cdf of Geometric distribution. @@ -233,13 +235,10 @@ class Geometric(Distribution): cdf(k) = 0 if k < 0. """ - if value is None: - raise_none_error("value") + self.checktensor(value, 'value') value = self.cast(value, mstype.float32) value = self.floor(value) - probs1 = self.cast(probs, self.parameter_type) if probs is not None else self.probs - if probs1 is None: - raise_none_error("probs") + probs1 = self._check_param(probs1) probs0 = 1.0 - probs1 cdf = 1.0 - self.pow(probs0, value + 1.0) zeros = self.fill(self.dtypeop(probs1), self.shape(cdf), 0.0) @@ -247,7 +246,7 @@ class Geometric(Distribution): return self.select(comp, zeros, cdf) - def _kl_loss(self, dist, probs1_b, probs1_a=None): + def _kl_loss(self, dist, probs1_b, probs1=None): r""" Evaluate Geometric-Geometric kl divergence, i.e. KL(a||b). 
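A quick NumPy sanity check of the closed forms used in `_prob` and `_cdf` above, counting the support as the number of failures before the first success (illustrative values only, not MindSpore code):

```python
# Geometric pmf/cdf for k = 0, 1, 2, ... with success probability p
import numpy as np

p = 0.4
k = np.arange(0, 10)

pmf = np.exp(np.log(1.0 - p) * k + np.log(p))   # same rewrite as in _prob
cdf = 1.0 - (1.0 - p) ** (k + 1.0)              # same closed form as in _cdf

assert np.allclose(pmf, (1.0 - p) ** k * p)     # pmf(k) = probs0^k * probs1
assert np.allclose(np.cumsum(pmf), cdf)         # cdf is the running sum of the pmf
```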
@@ -260,17 +259,14 @@ class Geometric(Distribution): KL(a||b) = \log(\frac{probs1_a}{probs1_b}) + \frac{probs0_a}{probs1_a} * \log(\frac{probs0_a}{probs0_b}) """ check_distribution_name(dist, 'Geometric') - if probs1_b is None: - raise_none_error("probs1_b") + self.checktensor(probs1_b, 'probs1_b') probs1_b = self.cast(probs1_b, self.parameter_type) - probs1_a = self.cast(probs1_a, self.parameter_type) if probs1_a is not None else self.probs - if probs1_a is None: - raise_none_error("probs1_a") + probs1_a = self._check_param(probs1) probs0_a = 1.0 - probs1_a probs0_b = 1.0 - probs1_b return self.log(probs1_a / probs1_b) + (probs0_a / probs1_a) * self.log(probs0_a / probs0_b) - def _sample(self, shape=(), probs=None): + def _sample(self, shape=(), probs1=None): """ Sampling. @@ -281,9 +277,8 @@ class Geometric(Distribution): Returns: Tensor, shape is shape + batch_shape. """ - probs1 = self.cast(probs, self.parameter_type) if probs is not None else self.probs - if probs1 is None: - raise_none_error("probs") + self.checktuple(shape, 'shape') + probs1 = self._check_param(probs1) origin_shape = shape + self.shape(probs1) if origin_shape == (): sample_shape = (1,) diff --git a/mindspore/nn/probability/distribution/normal.py b/mindspore/nn/probability/distribution/normal.py index 86c867696f..ec72d5ea78 100644 --- a/mindspore/nn/probability/distribution/normal.py +++ b/mindspore/nn/probability/distribution/normal.py @@ -20,6 +20,7 @@ from mindspore.common import dtype as mstype from .distribution import Distribution from ._utils.utils import convert_to_batch, check_greater_zero, check_type, check_distribution_name,\ raise_none_error +from ._utils.utils import CheckTensor, CheckTuple class Normal(Distribution): """ @@ -112,7 +113,6 @@ class Normal(Distribution): self._mean_value = mean self._sd_value = sd - #ops needed for the class self.squeeze = P.Squeeze(0) self.cast = P.Cast() @@ -127,6 +127,9 @@ class Normal(Distribution): self.sqrt = P.Sqrt() self.zeroslike = P.ZerosLike() + self.checktensor = CheckTensor() + self.checktuple = CheckTuple() + def extend_repr(self): if self.is_scalar_batch: str_info = f'mean = {self._mean_value}, standard deviation = {self._sd_value}' @@ -140,40 +143,44 @@ class Normal(Distribution): """ return self.exp(x) - 1.0 + def _check_param(self, mean, sd): + """ + Check availablity of distribution specific args mean and sd. + """ + if mean is not None: + self.checktensor(mean, 'mean') + mean = self.cast(mean, self.parameter_type) + else: + mean = self._mean_value if self._mean_value is not None else raise_none_error('mean') + if sd is not None: + self.checktensor(sd, 'sd') + sd = self.cast(sd, self.parameter_type) + else: + sd = self._sd_value if self._sd_value is not None else raise_none_error('sd') + batch_shape = self.shape(mean + sd) + mean = mean * self.fill(self.dtype, batch_shape, 1.0) + sd = sd * self.fill(self.dtype, batch_shape, 1.0) + return mean, sd + def _mean(self, mean=None, sd=None): """ Mean of the distribution. """ - mean = self.cast(mean, self.parameter_type) if mean is not None else self._mean_value - if mean is None: - raise_none_error("mean") - sd = self.cast(sd, self.parameter_type) if sd is not None else self._sd_value - if sd is None: - raise_none_error("sd") + mean, sd = self._check_param(mean, sd) return mean def _mode(self, mean=None, sd=None): """ Mode of the distribution. 
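The new `_check_param` helper above also broadcasts `mean` and `sd` to the combined batch shape by multiplying each with ones of `shape(mean + sd)`. A NumPy stand-in for that broadcasting step, with made-up shapes (not MindSpore code):

```python
# After this step both parameters share one batch shape, so later element-wise
# ops and sampling see consistent batch dimensions.
import numpy as np

mean = np.array([[0.0], [1.0]])        # shape (2, 1)
sd = np.array([1.0, 2.0, 3.0])         # shape (3,)

batch_shape = np.broadcast(mean, sd).shape      # (2, 3), like shape(mean + sd)
ones = np.ones(batch_shape, dtype=np.float32)
mean_b, sd_b = mean * ones, sd * ones

assert mean_b.shape == sd_b.shape == (2, 3)
```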
""" - mean = self.cast(mean, self.parameter_type) if mean is not None else self._mean_value - if mean is None: - raise_none_error("mean") - sd = self.cast(sd, self.parameter_type) if sd is not None else self._sd_value - if sd is None: - raise_none_error("sd") + mean, sd = self._check_param(mean, sd) return mean def _sd(self, mean=None, sd=None): """ Standard deviation of the distribution. """ - mean = self.cast(mean, self.parameter_type) if mean is not None else self._mean_value - if mean is None: - raise_none_error("mean") - sd = self.cast(sd, self.parameter_type) if sd is not None else self._sd_value - if sd is None: - raise_none_error("sd") + mean, sd = self._check_param(mean, sd) return sd def _entropy(self, mean=None, sd=None): @@ -183,15 +190,10 @@ class Normal(Distribution): .. math:: H(X) = \log(\sqrt(numpy.e * 2. * numpy.pi * \sq(\sigma))) """ - mean = self.cast(mean, self.parameter_type) if mean is not None else self._mean_value - if mean is None: - raise_none_error("mean") - sd = self.cast(sd, self.parameter_type) if sd is not None else self._sd_value - if sd is None: - raise_none_error("sd") + mean, sd = self._check_param(mean, sd) return self.log(self.sqrt(self.const(np.e * 2. * np.pi))) + self.log(sd) - def _cross_entropy(self, dist, mean_b, sd_b, mean_a=None, sd_a=None): + def _cross_entropy(self, dist, mean_b, sd_b, mean=None, sd=None): r""" Evaluate cross_entropy between normal distributions. @@ -203,7 +205,7 @@ class Normal(Distribution): sd_a (Tensor): standard deviation distribution a. Default: self._sd_value. """ check_distribution_name(dist, 'Normal') - return self._entropy(mean=mean_a, sd=sd_a) + self._kl_loss(dist, mean_b, sd_b, mean_a, sd_a) + return self._entropy(mean, sd) + self._kl_loss(dist, mean_b, sd_b, mean, sd) def _log_prob(self, value, mean=None, sd=None): r""" @@ -217,15 +219,9 @@ class Normal(Distribution): .. math:: L(x) = -1* \frac{(x - \mu)^2}{2. * \sigma^2} - \log(\sqrt(2* \pi * \sigma^2)) """ - if value is None: - raise_none_error("value") + self.checktensor(value, 'value') value = self.cast(value, self.dtype) - mean = self.cast(mean, self.parameter_type) if mean is not None else self._mean_value - if mean is None: - raise_none_error("mean") - sd = self.cast(sd, self.parameter_type) if sd is not None else self._sd_value - if sd is None: - raise_none_error("sd") + mean, sd = self._check_param(mean, sd) unnormalized_log_prob = -1. * (self.sq(value - mean)) / (2. * self.sq(sd)) neg_normalization = -1. * self.log(self.const(2. * np.pi)) / 2. - self.log(sd) return unnormalized_log_prob + neg_normalization @@ -242,20 +238,14 @@ class Normal(Distribution): .. math:: cdf(x) = 0.5 * (1+ Erf((x - \mu) / ( \sigma * \sqrt(2)))) """ - if value is None: - raise_none_error("value") + self.checktensor(value, 'value') value = self.cast(value, self.dtype) - mean = self.cast(mean, self.parameter_type) if mean is not None else self._mean_value - if mean is None: - raise_none_error("mean") - sd = self.cast(sd, self.parameter_type) if sd is not None else self._sd_value - if sd is None: - raise_none_error("sd") + mean, sd = self._check_param(mean, sd) sqrt2 = self.sqrt(self.const(2.0)) adjusted = (value - mean) / (sd * sqrt2) return 0.5 * (1.0 + self.erf(adjusted)) - def _kl_loss(self, dist, mean_b, sd_b, mean_a=None, sd_a=None): + def _kl_loss(self, dist, mean_b, sd_b, mean=None, sd=None): r""" Evaluate Normal-Normal kl divergence, i.e. KL(a||b). 
@@ -271,23 +261,15 @@ class Normal(Distribution): 0.5 * EXPM1(2 * (\log(STD(a)) - \log(STD(b))) - (\log(STD(a)) - \log(STD(b))) """ check_distribution_name(dist, 'Normal') - if mean_b is None: - raise_none_error("mean_b") - if sd_b is None: - raise_none_error("sd_b") + self.checktensor(mean_b, 'mean_b') + self.checktensor(sd_b, 'sd_b') mean_b = self.cast(mean_b, self.parameter_type) sd_b = self.cast(sd_b, self.parameter_type) - mean_a = self.cast(mean_a, self.parameter_type) if mean_a is not None else self._mean_value - sd_a = self.cast(sd_a, self.parameter_type) if sd_a is not None else self._sd_value - if mean_a is None: - raise_none_error("mean_a") - if sd_a is None: - raise_none_error("sd_a") + mean_a, sd_a = self._check_param(mean, sd) diff_log_scale = self.log(sd_a) - self.log(sd_b) squared_diff = self.sq(mean_a / sd_b - mean_b / sd_b) return 0.5 * squared_diff + 0.5 * self.expm1(2 * diff_log_scale) - diff_log_scale - def _sample(self, shape=(), mean=None, sd=None): """ Sampling. @@ -300,12 +282,8 @@ class Normal(Distribution): Returns: Tensor, shape is shape + batch_shape. """ - mean = self.cast(mean, self.parameter_type) if mean is not None else self._mean_value - if mean is None: - raise_none_error("mean") - sd = self.cast(sd, self.parameter_type) if sd is not None else self._sd_value - if sd is None: - raise_none_error("sd") + self.checktuple(shape, 'shape') + mean, sd = self._check_param(mean, sd) batch_shape = self.shape(mean + sd) origin_shape = shape + batch_shape if origin_shape == (): diff --git a/mindspore/nn/probability/distribution/uniform.py b/mindspore/nn/probability/distribution/uniform.py index 2d1324804f..a37162a507 100644 --- a/mindspore/nn/probability/distribution/uniform.py +++ b/mindspore/nn/probability/distribution/uniform.py @@ -19,6 +19,7 @@ from mindspore.common import dtype as mstype from .distribution import Distribution from ._utils.utils import convert_to_batch, check_greater, check_type, check_distribution_name,\ raise_none_error +from ._utils.utils import CheckTensor, CheckTuple class Uniform(Distribution): """ @@ -129,6 +130,9 @@ class Uniform(Distribution): self.zeroslike = P.ZerosLike() self.uniform = C.uniform + self.checktensor = CheckTensor() + self.checktuple = CheckTuple() + def extend_repr(self): if self.is_scalar_batch: str_info = f'low = {self.low}, high = {self.high}' @@ -136,6 +140,25 @@ class Uniform(Distribution): str_info = f'batch_shape = {self._broadcast_shape}' return str_info + def _check_param(self, low, high): + """ + Check availablity of distribution specific args low and high. + """ + if low is not None: + self.checktensor(low, 'low') + low = self.cast(low, self.parameter_type) + else: + low = self.low if self.low is not None else raise_none_error('low') + if high is not None: + self.checktensor(high, 'high') + high = self.cast(high, self.parameter_type) + else: + high = self.high if self.high is not None else raise_none_error('high') + batch_shape = self.shape(high - low) + high = high * self.fill(self.dtype, batch_shape, 1.0) + low = low * self.fill(self.dtype, batch_shape, 1.0) + return low, high + @property def low(self): """ @@ -156,12 +179,7 @@ class Uniform(Distribution): .. 
math:: range(U) = high -low """ - low = self.cast(low, self.parameter_type) if low is not None else self.low - if low is None: - raise_none_error("low") - high = self.cast(high, self.parameter_type) if high is not None else self.high - if high is None: - raise_none_error("high") + low, high = self._check_param(low, high) return high - low def _mean(self, low=None, high=None): @@ -169,12 +187,7 @@ class Uniform(Distribution): .. math:: MEAN(U) = \frac{low + high}{2}. """ - low = self.cast(low, self.parameter_type) if low is not None else self.low - if low is None: - raise_none_error("low") - high = self.cast(high, self.parameter_type) if high is not None else self.high - if high is None: - raise_none_error("high") + low, high = self._check_param(low, high) return (low + high) / 2. def _var(self, low=None, high=None): @@ -182,12 +195,7 @@ class Uniform(Distribution): .. math:: VAR(U) = \frac{(high -low) ^ 2}{12}. """ - low = self.cast(low, self.parameter_type) if low is not None else self.low - if low is None: - raise_none_error("low") - high = self.cast(high, self.parameter_type) if high is not None else self.high - if high is None: - raise_none_error("high") + low, high = self._check_param(low, high) return self.sq(high - low) / 12.0 def _entropy(self, low=None, high=None): @@ -195,15 +203,10 @@ class Uniform(Distribution): .. math:: H(U) = \log(high - low). """ - low = self.cast(low, self.parameter_type) if low is not None else self.low - if low is None: - raise_none_error("low") - high = self.cast(high, self.parameter_type) if high is not None else self.high - if high is None: - raise_none_error("high") + low, high = self._check_param(low, high) return self.log(high - low) - def _cross_entropy(self, dist, low_b, high_b, low_a=None, high_a=None): + def _cross_entropy(self, dist, low_b, high_b, low=None, high=None): """ Evaluate cross_entropy between Uniform distributoins. @@ -215,7 +218,7 @@ class Uniform(Distribution): high_a (Tensor): upper bound of distribution a. Default: self.high. """ check_distribution_name(dist, 'Uniform') - return self._entropy(low=low_a, high=high_a) + self._kl_loss(dist, low_b, high_b, low_a, high_a) + return self._entropy(low, high) + self._kl_loss(dist, low_b, high_b, low, high) def _prob(self, value, low=None, high=None): r""" @@ -231,15 +234,9 @@ class Uniform(Distribution): pdf(x) = \frac{1.0}{high -low} if low <= x <= high; pdf(x) = 0 if x > high; """ - if value is None: - raise_none_error("value") + self.checktensor(value, 'value') value = self.cast(value, self.dtype) - low = self.cast(low, self.parameter_type) if low is not None else self.low - if low is None: - raise_none_error("low") - high = self.cast(high, self.parameter_type) if high is not None else self.high - if high is None: - raise_none_error("high") + low, high = self._check_param(low, high) neg_ones = self.fill(self.dtype, self.shape(value), -1.0) prob = self.exp(neg_ones * self.log(high - low)) broadcast_shape = self.shape(prob) @@ -249,7 +246,7 @@ class Uniform(Distribution): less_than_low = self.select(comp_lo, zeros, prob) return self.select(comp_hi, less_than_low, zeros) - def _kl_loss(self, dist, low_b, high_b, low_a=None, high_a=None): + def _kl_loss(self, dist, low_b, high_b, low=None, high=None): """ Evaluate uniform-uniform kl divergence, i.e. KL(a||b). @@ -261,19 +258,12 @@ class Uniform(Distribution): high_a (Tensor): upper bound of distribution a. Default: self.high. 
""" check_distribution_name(dist, 'Uniform') - if low_b is None: - raise_none_error("low_b") - if high_b is None: - raise_none_error("high_b") + self.checktensor(low_b, 'low_b') low_b = self.cast(low_b, self.parameter_type) + self.checktensor(high_b, 'high_b') high_b = self.cast(high_b, self.parameter_type) - low_a = self.cast(low_a, self.parameter_type) if low_a is not None else self.low - if low_a is None: - raise_none_error("low_a") - high_a = self.cast(high_a, self.parameter_type) if high_a is not None else self.high - if high_a is None: - raise_none_error("high_a") - kl = self.log(high_b - low_b) / self.log(high_a - low_a) + low_a, high_a = self._check_param(low, high) + kl = self.log(high_b - low_b) - self.log(high_a - low_a) comp = self.logicaland(self.lessequal(low_b, low_a), self.lessequal(high_a, high_b)) return self.select(comp, kl, self.log(self.zeroslike(kl))) @@ -291,15 +281,9 @@ class Uniform(Distribution): cdf(x) = \frac{x - low}{high -low} if low <= x <= high; cdf(x) = 1 if x > high; """ - if value is None: - raise_none_error("value") + self.checktensor(value, 'value') value = self.cast(value, self.dtype) - low = self.cast(low, self.parameter_type) if low is not None else self.low - if low is None: - raise_none_error("low") - high = self.cast(high, self.parameter_type) if high is not None else self.high - if high is None: - raise_none_error("high") + low, high = self._check_param(low, high) prob = (value - low) / (high - low) broadcast_shape = self.shape(prob) zeros = self.fill(self.dtypeop(prob), broadcast_shape, 0.0) @@ -321,12 +305,8 @@ class Uniform(Distribution): Returns: Tensor, shape is shape + batch_shape. """ - low = self.cast(low, self.parameter_type) if low is not None else self.low - if low is None: - raise_none_error("low") - high = self.cast(high, self.parameter_type) if high is not None else self.high - if high is None: - raise_none_error("high") + self.checktuple(shape, 'shape') + low, high = self._check_param(low, high) broadcast_shape = self.shape(low + high) origin_shape = shape + broadcast_shape if origin_shape == (): diff --git a/tests/st/ops/ascend/test_bijector/test_scalar_affine.py b/tests/st/ops/ascend/test_bijector/test_scalar_affine.py index 137f5e0f05..11eda97976 100644 --- a/tests/st/ops/ascend/test_bijector/test_scalar_affine.py +++ b/tests/st/ops/ascend/test_bijector/test_scalar_affine.py @@ -75,7 +75,7 @@ def test_forward_jacobian(): forward_jacobian = Net2() x = Tensor([2.0, 3.0, 4.0, 5.0], dtype=dtype.float32) ans = forward_jacobian(x) - expected = np.log([2.0, 2.0, 2.0, 2.0]) + expected = np.log([2.0]) tol = 1e-6 assert (np.abs(ans.asnumpy() - expected) < tol).all() @@ -94,6 +94,6 @@ def test_backward_jacobian(): backward_jacobian = Net3() x = Tensor([2.0, 3.0, 4.0, 5.0], dtype=dtype.float32) ans = backward_jacobian(x) - expected = np.log([0.5, 0.5, 0.5, 0.5]) + expected = np.log([0.5]) tol = 1e-6 assert (np.abs(ans.asnumpy() - expected) < tol).all() diff --git a/tests/st/ops/ascend/test_bijector/test_softplus.py b/tests/st/ops/ascend/test_bijector/test_softplus.py index 9bf33aa254..0a0909a261 100644 --- a/tests/st/ops/ascend/test_bijector/test_softplus.py +++ b/tests/st/ops/ascend/test_bijector/test_softplus.py @@ -20,7 +20,7 @@ import mindspore.nn.probability.bijector as msb from mindspore import Tensor from mindspore import dtype -context.set_context(device_target="Ascend") +context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") class Net(nn.Cell): """ diff --git a/tests/st/ops/ascend/test_distribution/test_uniform.py 
b/tests/st/ops/ascend/test_distribution/test_uniform.py
index 5e54f2cdcc..6892d49e30 100644
--- a/tests/st/ops/ascend/test_distribution/test_uniform.py
+++ b/tests/st/ops/ascend/test_distribution/test_uniform.py
@@ -88,7 +88,7 @@ def test_kl_loss():
     high_a = 1.5
     low_b = -1.0
     high_b = 2.0
-    expect_kl_loss = np.log(high_b - low_b) / np.log(high_a - low_a)
+    expect_kl_loss = np.log(high_b - low_b) - np.log(high_a - low_a)
     kl = KL()
     output = kl(Tensor(low_b, dtype=dtype.float32), Tensor(high_b, dtype=dtype.float32))
     tol = 1e-6
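A quick NumPy check of why this expectation changed: for nested supports, KL(U_a || U_b) is the difference of the log ranges, not their ratio (the bounds below are illustrative, not the ones used in the test):

```python
# When [low_a, high_a] lies inside [low_b, high_b],
# KL(U_a || U_b) = log(high_b - low_b) - log(high_a - low_a).
import numpy as np

low_a, high_a = 0.0, 1.5
low_b, high_b = -1.0, 2.0

correct = np.log(high_b - low_b) - np.log(high_a - low_a)   # log(3.0) - log(1.5) ~= 0.693
old = np.log(high_b - low_b) / np.log(high_a - low_a)       # ~= 2.709, not a KL divergence

print(correct, old)
```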