edited common_dtype and check_param dtype logic

pull/6020/head
Xun Deng 5 years ago
parent 600704ddde
commit 67325d63b0

@@ -110,6 +110,8 @@ def check_scalar_from_param(params):
     Notes: String parameters are excluded.
     """
     for value in params.values():
+        if value is None:
+            continue
         if isinstance(value, (msp.bijector.Bijector, msp.distribution.Distribution)):
             return params['distribution'].is_scalar_batch
         if isinstance(value, Parameter):
@@ -358,23 +360,29 @@ class CheckTensor(PrimitiveWithInfer):
             return x
         raise TypeError(f"For {name}, input type should be a Tensor or Parameter.")

-def common_dtype(arg_a, name_a, arg_b, name_b, hint_type):
+def set_param_type(args, hint_type):
     """
-    check if arg_a and arg_b have the same dtype.
+    Find the common type among arguments.
+
+    Args:
+        args (dict): dictionary of arguments, {'name':value}.
+        hint_type (mindspore.dtype): hint type to return.
+
+    Raises:
+        TypeError: if tensors in args are not the same dtype.
     """
-    if hasattr(arg_a, 'dtype') and hasattr(arg_b, 'dtype'):
-        if isinstance(arg_a, np.ndarray):
-            a_dtype = mstype.pytype_to_dtype(arg_a.dtype)
-        else:
-            a_dtype = arg_a.dtype
-        if isinstance(arg_b, np.ndarray):
-            b_dtype = mstype.pytype_to_dtype(arg_b.dtype)
-        else:
-            b_dtype = arg_b.dtype
-        if a_dtype != b_dtype:
-            raise TypeError(f"{name_a} and {name_b} should have the same dtype.")
-        int_type = mstype.int_type + mstype.uint_type
-        if a_dtype in int_type or a_dtype == mstype.float64:
-            return mstype.float32
-        return a_dtype
-    return hint_type
+    common_dtype = None
+    for name, arg in args.items():
+        if hasattr(arg, 'dtype'):
+            if isinstance(arg, np.ndarray):
+                cur_dtype = mstype.pytype_to_dtype(arg.dtype)
+            else:
+                cur_dtype = arg.dtype
+            if common_dtype is None:
+                common_dtype = cur_dtype
+            elif cur_dtype != common_dtype:
+                raise TypeError(f"{name} should have the same dtype as other arguments.")
+    int_type = mstype.int_type + mstype.uint_type
+    if common_dtype in int_type or common_dtype == mstype.float64:
+        return mstype.float32
+    return hint_type if common_dtype is None else common_dtype
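
A quick illustration of how the rewritten helper resolves dtypes (variable names and values here are illustrative, not part of the commit; the import path is the one used by the new test file further below):

# Illustrative sketch of set_param_type's promotion rules.
import numpy as np
from mindspore import Tensor
from mindspore import dtype
from mindspore.nn.probability.distribution._utils.utils import set_param_type

probs = Tensor(0.5, dtype=dtype.float16)
# All dtype-carrying arguments agree on float16, so float16 wins over the hint.
assert set_param_type({'probs1': probs}, dtype.float32) == dtype.float16
# Plain Python scalars carry no dtype, so the hint type is returned.
assert set_param_type({'probs1': 0.5}, dtype.float32) == dtype.float32
# Integer and float64 arguments are mapped to float32.
assert set_param_type({'rate': np.array(2, np.int32)}, dtype.float16) == dtype.float32
# Mixing dtypes raises TypeError, e.g.:
# set_param_type({'low': Tensor(0.0, dtype.float32), 'high': Tensor(1.0, dtype.float64)}, dtype.float32)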

@@ -17,7 +17,7 @@ from mindspore.common import dtype as mstype
 from mindspore.ops import operations as P
 from mindspore.ops import composite as C
 from .distribution import Distribution
-from ._utils.utils import cast_to_tensor, check_prob, check_type, check_distribution_name, raise_none_error
+from ._utils.utils import cast_to_tensor, check_prob, check_type, check_distribution_name, set_param_type
 from ._utils.custom_ops import exp_generic, log_generic, erf_generic
@@ -119,13 +119,16 @@ class Bernoulli(Distribution):
         valid_dtype = mstype.int_type + mstype.uint_type + mstype.float_type
         check_type(dtype, valid_dtype, type(self).__name__)
         super(Bernoulli, self).__init__(seed, dtype, name, param)
-        self.parameter_type = mstype.float32
+        self.parameter_type = set_param_type({'probs1': probs}, mstype.float32)
         if probs is not None:
-            self._probs = cast_to_tensor(probs, mstype.float32)
+            self._probs = cast_to_tensor(probs, self.parameter_type)
             check_prob(self.probs)
         else:
             self._probs = probs
+        self.default_parameters = [self.probs]
+        self.parameter_names = ['probs1']

         # ops needed for the class
         self.exp = exp_generic
         self.log = log_generic
@@ -157,24 +160,12 @@ class Bernoulli(Distribution):
         """
         return self._probs

-    def _check_param(self, probs1):
-        """
-        Check availablity of distribution specific args `probs1`.
-        """
-        if probs1 is not None:
-            if self.context_mode == 0:
-                self.checktensor(probs1, 'probs1')
-            else:
-                probs1 = self.checktensor(probs1, 'probs1')
-            return self.cast(probs1, self.parameter_type)
-        return self.probs if self.probs is not None else raise_none_error('probs1')
-
     def _mean(self, probs1=None):
         r"""
         .. math::
             MEAN(B) = probs1
         """
-        probs1 = self._check_param(probs1)
+        probs1 = self._check_param_type(probs1)
         return probs1

     def _mode(self, probs1=None):
@@ -182,7 +173,7 @@ class Bernoulli(Distribution):
         .. math::
             MODE(B) = 1 if probs1 > 0.5 else = 0
         """
-        probs1 = self._check_param(probs1)
+        probs1 = self._check_param_type(probs1)
         prob_type = self.dtypeop(probs1)
         zeros = self.fill(prob_type, self.shape(probs1), 0.0)
         ones = self.fill(prob_type, self.shape(probs1), 1.0)
@@ -194,7 +185,7 @@ class Bernoulli(Distribution):
         .. math::
             VAR(B) = probs1 * probs0
         """
-        probs1 = self._check_param(probs1)
+        probs1 = self._check_param_type(probs1)
         probs0 = 1.0 - probs1
         return self.exp(self.log(probs0) + self.log(probs1))
@@ -203,11 +194,11 @@ class Bernoulli(Distribution):
         .. math::
             H(B) = -probs0 * \log(probs0) - probs1 * \log(probs1)
         """
-        probs1 = self._check_param(probs1)
+        probs1 = self._check_param_type(probs1)
         probs0 = 1.0 - probs1
         return -(probs0 * self.log(probs0)) - (probs1 * self.log(probs1))

-    def _cross_entropy(self, dist, probs1_b, probs1_a=None):
+    def _cross_entropy(self, dist, probs1_b, probs1=None):
         """
         Evaluate cross_entropy between Bernoulli distributions.
@@ -217,7 +208,7 @@ class Bernoulli(Distribution):
             probs1_a (Tensor): `probs1` of distribution a. Default: self.probs.
         """
         check_distribution_name(dist, 'Bernoulli')
-        return self._entropy(probs1_a) + self._kl_loss(dist, probs1_b, probs1_a)
+        return self._entropy(probs1) + self._kl_loss(dist, probs1_b, probs1)

     def _log_prob(self, value, probs1=None):
         r"""
@@ -233,7 +224,7 @@ class Bernoulli(Distribution):
         """
         value = self._check_value(value, 'value')
         value = self.cast(value, mstype.float32)
-        probs1 = self._check_param(probs1)
+        probs1 = self._check_param_type(probs1)
         probs0 = 1.0 - probs1
         return self.log(probs1) * value + self.log(probs0) * (1.0 - value)
@@ -253,7 +244,7 @@ class Bernoulli(Distribution):
         value = self._check_value(value, 'value')
         value = self.cast(value, mstype.float32)
         value = self.floor(value)
-        probs1 = self._check_param(probs1)
+        probs1 = self._check_param_type(probs1)
         prob_type = self.dtypeop(probs1)
         value = value * self.fill(prob_type, self.shape(probs1), 1.0)
         probs0 = 1.0 - probs1 * self.fill(prob_type, self.shape(value), 1.0)
@@ -264,7 +255,7 @@ class Bernoulli(Distribution):
         less_than_zero = self.select(comp_zero, zeros, probs0)
         return self.select(comp_one, less_than_zero, ones)

-    def _kl_loss(self, dist, probs1_b, probs1_a=None):
+    def _kl_loss(self, dist, probs1_b, probs1=None):
         r"""
         Evaluate bernoulli-bernoulli kl divergence, i.e. KL(a||b).
@@ -280,7 +271,7 @@ class Bernoulli(Distribution):
         check_distribution_name(dist, 'Bernoulli')
         probs1_b = self._check_value(probs1_b, 'probs1_b')
         probs1_b = self.cast(probs1_b, self.parameter_type)
-        probs1_a = self._check_param(probs1_a)
+        probs1_a = self._check_param_type(probs1)
         probs0_a = 1.0 - probs1_a
         probs0_b = 1.0 - probs1_b
         return probs1_a * self.log(probs1_a / probs1_b) + probs0_a * self.log(probs0_a / probs0_b)
@@ -297,7 +288,7 @@ class Bernoulli(Distribution):
             Tensor, shape is shape + batch_shape.
         """
         shape = self.checktuple(shape, 'shape')
-        probs1 = self._check_param(probs1)
+        probs1 = self._check_param_type(probs1)
         origin_shape = shape + self.shape(probs1)
         if origin_shape == ():
             sample_shape = (1,)
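
One practical effect of the Bernoulli changes above: `parameter_type` now follows the dtype of `probs` instead of being pinned to float32. A minimal sketch, inferred from the constructor diff (illustrative values; assumes the usual `msd` alias for the distribution package):

# Illustrative sketch, not part of the commit.
import mindspore.nn.probability.distribution as msd
from mindspore import Tensor
from mindspore import dtype

b = msd.Bernoulli(probs=Tensor(0.3, dtype=dtype.float16), dtype=dtype.int32)
# set_param_type({'probs1': probs}, float32) resolves to float16 here,
# so probs is stored as a float16 tensor rather than being cast to float32.
assert b.parameter_type == dtype.float16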

File diff suppressed because it is too large
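
The suppressed file is presumably the shared Distribution base class, where the per-distribution `_check_param` helpers removed in the files below are replaced by a single `_check_param_type` driven by the new `default_parameters` and `parameter_names` attributes. Since that diff is not shown, the following is only a hypothetical reconstruction from the call sites; the committed implementation (including the broadcasting and `sametypeshape` handling that the two-parameter distributions previously did) may differ:

# Hypothetical sketch only; NOT the committed code, which lives in the suppressed diff above.
def _check_param_type(self, *params):
    """Validate user-supplied parameters, fall back to stored defaults, cast to parameter_type."""
    out = []
    for value, default, name in zip(params, self.default_parameters, self.parameter_names):
        if value is not None:
            if self.context_mode == 0:
                self.checktensor(value, name)          # graph mode: check in place
            else:
                value = self.checktensor(value, name)  # pynative mode: check returns the tensor
        else:
            value = default if default is not None else raise_none_error(name)
        out.append(self.cast(value, self.parameter_type))
    return out[0] if len(out) == 1 else tuple(out)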

@@ -18,8 +18,7 @@ from mindspore.ops import operations as P
 from mindspore.ops import composite as C
 from mindspore.common import dtype as mstype
 from .distribution import Distribution
-from ._utils.utils import cast_to_tensor, check_greater_zero, check_type, check_distribution_name,\
-    raise_none_error
+from ._utils.utils import cast_to_tensor, check_greater_zero, check_type, check_distribution_name, set_param_type
 from ._utils.custom_ops import exp_generic, log_generic

 class Exponential(Distribution):
@@ -121,15 +120,19 @@ class Exponential(Distribution):
         valid_dtype = mstype.float_type
         check_type(dtype, valid_dtype, type(self).__name__)
         super(Exponential, self).__init__(seed, dtype, name, param)
-        self.parameter_type = dtype
+        self.parameter_type = set_param_type({'rate': rate}, self.dtype)
         if rate is not None:
             self._rate = cast_to_tensor(rate, self.parameter_type)
             check_greater_zero(self._rate, "rate")
         else:
             self._rate = rate
+        self.default_parameters = [self.rate]
+        self.parameter_names = ['rate']
         self.minval = np.finfo(np.float).tiny

         # ops needed for the class
         self.exp = exp_generic
         self.log = log_generic
@@ -156,28 +159,16 @@ class Exponential(Distribution):
     @property
     def rate(self):
         """
-        Return rate of the distribution.
+        Return `rate` of the distribution.
         """
         return self._rate

-    def _check_param(self, rate):
-        """
-        Check availablity of distribution specific argument `rate`.
-        """
-        if rate is not None:
-            if self.context_mode == 0:
-                self.checktensor(rate, 'rate')
-            else:
-                rate = self.checktensor(rate, 'rate')
-            return self.cast(rate, self.parameter_type)
-        return self.rate if self.rate is not None else raise_none_error('rate')
-
     def _mean(self, rate=None):
         r"""
         .. math::
             MEAN(EXP) = \frac{1.0}{\lambda}.
         """
-        rate = self._check_param(rate)
+        rate = self._check_param_type(rate)
         return 1.0 / rate

     def _mode(self, rate=None):
@@ -185,7 +176,7 @@ class Exponential(Distribution):
         .. math::
             MODE(EXP) = 0.
         """
-        rate = self._check_param(rate)
+        rate = self._check_param_type(rate)
         return self.fill(self.dtype, self.shape(rate), 0.)

     def _sd(self, rate=None):
@@ -193,7 +184,7 @@ class Exponential(Distribution):
         .. math::
             sd(EXP) = \frac{1.0}{\lambda}.
         """
-        rate = self._check_param(rate)
+        rate = self._check_param_type(rate)
         return 1.0 / rate

     def _entropy(self, rate=None):
@@ -201,7 +192,7 @@ class Exponential(Distribution):
         .. math::
             H(Exp) = 1 - \log(\lambda).
         """
-        rate = self._check_param(rate)
+        rate = self._check_param_type(rate)
         return 1.0 - self.log(rate)

     def _cross_entropy(self, dist, rate_b, rate=None):
@@ -234,7 +225,7 @@ class Exponential(Distribution):
         """
         value = self._check_value(value, "value")
         value = self.cast(value, self.dtype)
-        rate = self._check_param(rate)
+        rate = self._check_param_type(rate)
         prob = self.log(rate) - rate * value
         zeros = self.fill(self.dtypeop(prob), self.shape(prob), 0.0)
         neginf = self.fill(self.dtypeop(prob), self.shape(prob), -np.inf)
@@ -257,7 +248,7 @@ class Exponential(Distribution):
         """
         value = self._check_value(value, 'value')
         value = self.cast(value, self.dtype)
-        rate = self._check_param(rate)
+        rate = self._check_param_type(rate)
         cdf = 1.0 - self.exp(-1. * rate * value)
         zeros = self.fill(self.dtypeop(cdf), self.shape(cdf), 0.0)
         comp = self.less(value, zeros)
@@ -279,7 +270,7 @@ class Exponential(Distribution):
         """
         value = self._check_value(value, 'value')
         value = self.cast(value, self.dtype)
-        rate = self._check_param(rate)
+        rate = self._check_param_type(rate)
         sf = -1. * rate * value
         zeros = self.fill(self.dtypeop(sf), self.shape(sf), 0.0)
         comp = self.less(value, zeros)
@@ -297,7 +288,7 @@ class Exponential(Distribution):
         check_distribution_name(dist, 'Exponential')
         rate_b = self._check_value(rate_b, 'rate_b')
         rate_b = self.cast(rate_b, self.parameter_type)
-        rate_a = self._check_param(rate)
+        rate_a = self._check_param_type(rate)
         return self.log(rate_a) - self.log(rate_b) + rate_b / rate_a - 1.0

     def _sample(self, shape=(), rate=None):
@@ -312,7 +303,7 @@ class Exponential(Distribution):
             Tensor, shape is shape + batch_shape.
         """
         shape = self.checktuple(shape, 'shape')
-        rate = self._check_param(rate)
+        rate = self._check_param_type(rate)
         origin_shape = shape + self.shape(rate)
         if origin_shape == ():
             sample_shape = (1,)

@@ -19,7 +19,7 @@ from mindspore.ops import composite as C
 from mindspore.common import dtype as mstype
 from .distribution import Distribution
 from ._utils.utils import cast_to_tensor, check_prob, check_type, check_distribution_name,\
-    raise_none_error
+    set_param_type
 from ._utils.custom_ops import exp_generic, log_generic
@@ -123,13 +123,16 @@ class Geometric(Distribution):
         valid_dtype = mstype.int_type + mstype.uint_type + mstype.float_type
         check_type(dtype, valid_dtype, type(self).__name__)
         super(Geometric, self).__init__(seed, dtype, name, param)
-        self.parameter_type = mstype.float32
+        self.parameter_type = set_param_type({'probs1': probs}, mstype.float32)
         if probs is not None:
             self._probs = cast_to_tensor(probs, self.parameter_type)
             check_prob(self._probs)
         else:
             self._probs = probs
+        self.default_parameters = [self.probs]
+        self.parameter_names = ['probs1']
         self.minval = np.finfo(np.float).tiny

         # ops needed for the class
@@ -164,24 +167,12 @@ class Geometric(Distribution):
         """
         return self._probs

-    def _check_param(self, probs1):
-        """
-        Check availablity of distribution specific args probs1.
-        """
-        if probs1 is not None:
-            if self.context_mode == 0:
-                self.checktensor(probs1, 'probs1')
-            else:
-                probs1 = self.checktensor(probs1, 'probs1')
-            return self.cast(probs1, self.parameter_type)
-        return self.probs if self.probs is not None else raise_none_error('probs1')
-
     def _mean(self, probs1=None):
         r"""
         .. math::
             MEAN(Geo) = \fratc{1 - probs1}{probs1}
         """
-        probs1 = self._check_param(probs1)
+        probs1 = self._check_param_type(probs1)
         return (1. - probs1) / probs1

     def _mode(self, probs1=None):
@@ -189,7 +180,7 @@ class Geometric(Distribution):
         .. math::
             MODE(Geo) = 0
         """
-        probs1 = self._check_param(probs1)
+        probs1 = self._check_param_type(probs1)
         return self.fill(self.dtypeop(probs1), self.shape(probs1), 0.)

     def _var(self, probs1=None):
@@ -197,7 +188,7 @@ class Geometric(Distribution):
         .. math::
             VAR(Geo) = \frac{1 - probs1}{probs1 ^ {2}}
         """
-        probs1 = self._check_param(probs1)
+        probs1 = self._check_param_type(probs1)
         return (1.0 - probs1) / self.sq(probs1)

     def _entropy(self, probs1=None):
@@ -205,7 +196,7 @@ class Geometric(Distribution):
         .. math::
             H(Geo) = \frac{-1 * probs0 \log_2 (1-probs0)\ - prob1 * \log_2 (1-probs1)\ }{probs1}
         """
-        probs1 = self._check_param(probs1)
+        probs1 = self._check_param_type(probs1)
         probs0 = 1.0 - probs1
         return (-probs0 * self.log(probs0) - probs1 * self.log(probs1)) / probs1
@@ -236,7 +227,7 @@ class Geometric(Distribution):
         value = self._check_value(value, 'value')
         value = self.cast(value, mstype.float32)
         value = self.floor(value)
-        probs1 = self._check_param(probs1)
+        probs1 = self._check_param_type(probs1)
         pmf = self.exp(self.log(1.0 - probs1) * value + self.log(probs1))
         zeros = self.fill(self.dtypeop(probs1), self.shape(pmf), 0.0)
         comp = self.less(value, zeros)
@@ -258,7 +249,7 @@ class Geometric(Distribution):
         value = self._check_value(value, 'value')
         value = self.cast(value, mstype.float32)
         value = self.floor(value)
-        probs1 = self._check_param(probs1)
+        probs1 = self._check_param_type(probs1)
         probs0 = 1.0 - probs1
         cdf = 1.0 - self.pow(probs0, value + 1.0)
         zeros = self.fill(self.dtypeop(probs1), self.shape(cdf), 0.0)
@@ -280,7 +271,7 @@ class Geometric(Distribution):
         check_distribution_name(dist, 'Geometric')
         probs1_b = self._check_value(probs1_b, 'probs1_b')
         probs1_b = self.cast(probs1_b, self.parameter_type)
-        probs1_a = self._check_param(probs1)
+        probs1_a = self._check_param_type(probs1)
         probs0_a = 1.0 - probs1_a
         probs0_b = 1.0 - probs1_b
         return self.log(probs1_a / probs1_b) + (probs0_a / probs1_a) * self.log(probs0_a / probs0_b)
@@ -297,7 +288,7 @@ class Geometric(Distribution):
             Tensor, shape is shape + batch_shape.
         """
         shape = self.checktuple(shape, 'shape')
-        probs1 = self._check_param(probs1)
+        probs1 = self._check_param_type(probs1)
         origin_shape = shape + self.shape(probs1)
         if origin_shape == ():
             sample_shape = (1,)

@@ -19,7 +19,7 @@ from mindspore.ops import composite as C
 from mindspore.common import dtype as mstype
 from .distribution import Distribution
 from ._utils.utils import cast_to_tensor, check_greater_zero, check_type, check_distribution_name,\
-    raise_none_error, common_dtype
+    set_param_type
 from ._utils.custom_ops import exp_generic, expm1_generic, log_generic, erf_generic

 class Normal(Distribution):
@@ -127,14 +127,17 @@ class Normal(Distribution):
         valid_dtype = mstype.float_type
         check_type(dtype, valid_dtype, type(self).__name__)
         super(Normal, self).__init__(seed, dtype, name, param)
-        self.parameter_type = common_dtype(mean, 'mean', sd, 'sd', self.dtype)
+        self.parameter_type = set_param_type({'mean': mean, 'sd': sd}, self.dtype)
         if mean is not None and sd is not None:
             self._mean_value = cast_to_tensor(mean, self.parameter_type)
             self._sd_value = cast_to_tensor(sd, self.parameter_type)
             check_greater_zero(self._sd_value, "Standard deviation")
         else:
-            self._mean_value = mean
-            self._sd_value = sd
+            self._mean_value = mean if mean is None else cast_to_tensor(mean, self.parameter_type)
+            self._sd_value = sd if sd is None else cast_to_tensor(sd, self.parameter_type)
+        self.default_parameters = [self._mean_value, self._sd_value]
+        self.parameter_names = ['mean', 'sd']

         #ops needed for the class
         self.exp = exp_generic
@@ -159,51 +162,25 @@ class Normal(Distribution):
             str_info = f'batch_shape = {self._broadcast_shape}'
         return str_info

-    def _check_param(self, mean, sd):
-        """
-        Check availablity of distribution specific args `mean` and `sd`.
-        """
-        if mean is not None:
-            if self.context_mode == 0:
-                self.checktensor(mean, 'mean')
-            else:
-                mean = self.checktensor(mean, 'mean')
-        else:
-            mean = self._mean_value if self._mean_value is not None else raise_none_error('mean')
-        if sd is not None:
-            if self.context_mode == 0:
-                self.checktensor(sd, 'sd')
-            else:
-                sd = self.checktensor(sd, 'sd')
-        else:
-            sd = self._sd_value if self._sd_value is not None else raise_none_error('sd')
-        batch_shape = self.shape(mean + sd)
-        mean = mean * self.fill(self.dtypeop(mean), batch_shape, 1.0)
-        sd = sd * self.fill(self.dtypeop(sd), batch_shape, 1.0)
-        self.sametypeshape(mean, sd)
-        mean = self.cast(mean, self.parameter_type)
-        sd = self.cast(sd, self.parameter_type)
-        return mean, sd
-
     def _mean(self, mean=None, sd=None):
         """
         The mean of the distribution.
         """
-        mean, sd = self._check_param(mean, sd)
+        mean, sd = self._check_param_type(mean, sd)
         return mean

     def _mode(self, mean=None, sd=None):
         """
         The mode of the distribution.
         """
-        mean, sd = self._check_param(mean, sd)
+        mean, sd = self._check_param_type(mean, sd)
         return mean

     def _sd(self, mean=None, sd=None):
         """
         The standard deviation of the distribution.
         """
-        mean, sd = self._check_param(mean, sd)
+        mean, sd = self._check_param_type(mean, sd)
         return sd

     def _entropy(self, mean=None, sd=None):
@@ -213,7 +190,7 @@ class Normal(Distribution):
         .. math::
             H(X) = \log(\sqrt(numpy.e * 2. * numpy.pi * \sq(\sigma)))
         """
-        mean, sd = self._check_param(mean, sd)
+        mean, sd = self._check_param_type(mean, sd)
         return self.log(self.sqrt(self.const(np.e * 2. * np.pi))) + self.log(sd)

     def _cross_entropy(self, dist, mean_b, sd_b, mean=None, sd=None):
@@ -244,7 +221,7 @@ class Normal(Distribution):
         """
         value = self._check_value(value, 'value')
         value = self.cast(value, self.dtype)
-        mean, sd = self._check_param(mean, sd)
+        mean, sd = self._check_param_type(mean, sd)
         unnormalized_log_prob = -1. * (self.sq(value - mean)) / (2. * self.sq(sd))
         neg_normalization = -1. * self.log(self.const(2. * np.pi)) / 2. - self.log(sd)
         return unnormalized_log_prob + neg_normalization
@@ -263,7 +240,7 @@ class Normal(Distribution):
         """
         value = self._check_value(value, 'value')
         value = self.cast(value, self.dtype)
-        mean, sd = self._check_param(mean, sd)
+        mean, sd = self._check_param_type(mean, sd)
         sqrt2 = self.sqrt(self.const(2.0))
         adjusted = (value - mean) / (sd * sqrt2)
         return 0.5 * (1.0 + self.erf(adjusted))
@@ -288,7 +265,7 @@ class Normal(Distribution):
         sd_b = self._check_value(sd_b, 'sd_b')
         mean_b = self.cast(mean_b, self.parameter_type)
         sd_b = self.cast(sd_b, self.parameter_type)
-        mean_a, sd_a = self._check_param(mean, sd)
+        mean_a, sd_a = self._check_param_type(mean, sd)
         diff_log_scale = self.log(sd_a) - self.log(sd_b)
         squared_diff = self.sq(mean_a / sd_b - mean_b / sd_b)
         return 0.5 * squared_diff + 0.5 * self.expm1(2 * diff_log_scale) - diff_log_scale
@@ -306,7 +283,7 @@ class Normal(Distribution):
             Tensor, shape is shape + batch_shape.
         """
         shape = self.checktuple(shape, 'shape')
-        mean, sd = self._check_param(mean, sd)
+        mean, sd = self._check_param_type(mean, sd)
         batch_shape = self.shape(mean + sd)
         origin_shape = shape + batch_shape
         if origin_shape == ():
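
Note also how `set_param_type` interacts with float64 inputs in the Normal constructor above: integer and float64 dtypes are mapped to float32, so NumPy float64 parameters end up stored as float32 tensors rather than raising. A hedged sketch (illustrative values only, mirroring the dict10/dict11 cases in the new tests below):

# Illustrative sketch; follows from set_param_type plus the constructor change above.
import numpy as np
import mindspore.nn.probability.distribution as msd
from mindspore import dtype

n = msd.Normal(np.array([0.0], np.float64), np.array([1.0], np.float64), dtype=dtype.float32)
# Both arguments resolve to float64, which set_param_type maps down to float32.
assert n.parameter_type == dtype.float32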

@@ -18,7 +18,7 @@ from mindspore.ops import composite as C
 from mindspore.common import dtype as mstype
 from .distribution import Distribution
 from ._utils.utils import cast_to_tensor, check_greater, check_type, check_distribution_name,\
-    raise_none_error, common_dtype
+    set_param_type
 from ._utils.custom_ops import exp_generic, log_generic

 class Uniform(Distribution):
@@ -126,14 +126,17 @@ class Uniform(Distribution):
         valid_dtype = mstype.float_type
         check_type(dtype, valid_dtype, type(self).__name__)
         super(Uniform, self).__init__(seed, dtype, name, param)
-        self.parameter_type = common_dtype(low, 'low', high, 'high', self.dtype)
+        self.parameter_type = set_param_type({'low': low, 'high': high}, self.dtype)
         if low is not None and high is not None:
-            self._low = cast_to_tensor(low, dtype)
-            self._high = cast_to_tensor(high, dtype)
+            self._low = cast_to_tensor(low, self.parameter_type)
+            self._high = cast_to_tensor(high, self.parameter_type)
             check_greater(self.low, self.high, "low value", "high value")
         else:
-            self._low = low
-            self._high = high
+            self._low = low if low is None else cast_to_tensor(low, self.parameter_type)
+            self._high = high if high is None else cast_to_tensor(high, self.parameter_type)
+        self.default_parameters = [self.low, self.high]
+        self.parameter_names = ['low', 'high']

         # ops needed for the class
         self.exp = exp_generic
@@ -162,32 +165,6 @@ class Uniform(Distribution):
             str_info = f'batch_shape = {self._broadcast_shape}'
         return str_info

-    def _check_param(self, low, high):
-        """
-        Check availablity of distribution specific args `low` and `high`.
-        """
-        if low is not None:
-            if self.context_mode == 0:
-                self.checktensor(low, 'low')
-            else:
-                low = self.checktensor(low, 'low')
-        else:
-            low = self.low if self.low is not None else raise_none_error('low')
-        if high is not None:
-            if self.context_mode == 0:
-                self.checktensor(high, 'high')
-            else:
-                high = self.checktensor(high, 'high')
-        else:
-            high = self.high if self.high is not None else raise_none_error('high')
-        batch_shape = self.shape(high - low)
-        high = high * self.fill(self.dtypeop(high), batch_shape, 1.0)
-        low = low * self.fill(self.dtypeop(low), batch_shape, 1.0)
-        self.sametypeshape(high, low)
-        low = self.cast(low, self.parameter_type)
-        high = self.cast(high, self.parameter_type)
-        return low, high
-
     @property
     def low(self):
         """
@@ -209,7 +186,7 @@ class Uniform(Distribution):
         .. math::
             range(U) = high -low
         """
-        low, high = self._check_param(low, high)
+        low, high = self._check_param_type(low, high)
         return high - low

     def _mean(self, low=None, high=None):
@@ -217,7 +194,7 @@ class Uniform(Distribution):
         .. math::
             MEAN(U) = \frac{low + high}{2}.
         """
-        low, high = self._check_param(low, high)
+        low, high = self._check_param_type(low, high)
         return (low + high) / 2.

     def _var(self, low=None, high=None):
@@ -225,7 +202,7 @@ class Uniform(Distribution):
         .. math::
             VAR(U) = \frac{(high -low) ^ 2}{12}.
         """
-        low, high = self._check_param(low, high)
+        low, high = self._check_param_type(low, high)
         return self.sq(high - low) / 12.0

     def _entropy(self, low=None, high=None):
@@ -233,7 +210,7 @@ class Uniform(Distribution):
         .. math::
             H(U) = \log(high - low).
         """
-        low, high = self._check_param(low, high)
+        low, high = self._check_param_type(low, high)
         return self.log(high - low)

     def _cross_entropy(self, dist, low_b, high_b, low=None, high=None):
@@ -266,7 +243,7 @@ class Uniform(Distribution):
         """
         value = self._check_value(value, 'value')
         value = self.cast(value, self.dtype)
-        low, high = self._check_param(low, high)
+        low, high = self._check_param_type(low, high)
         neg_ones = self.fill(self.dtype, self.shape(value), -1.0)
         prob = self.exp(neg_ones * self.log(high - low))
         broadcast_shape = self.shape(prob)
@@ -292,7 +269,7 @@ class Uniform(Distribution):
         low_b = self.cast(low_b, self.parameter_type)
         high_b = self._check_value(high_b, 'high_b')
         high_b = self.cast(high_b, self.parameter_type)
-        low_a, high_a = self._check_param(low, high)
+        low_a, high_a = self._check_param_type(low, high)
         kl = self.log(high_b - low_b) - self.log(high_a - low_a)
         comp = self.logicaland(self.lessequal(low_b, low_a), self.lessequal(high_a, high_b))
         return self.select(comp, kl, self.log(self.zeroslike(kl)))
@@ -313,7 +290,7 @@ class Uniform(Distribution):
         """
         value = self._check_value(value, 'value')
         value = self.cast(value, self.dtype)
-        low, high = self._check_param(low, high)
+        low, high = self._check_param_type(low, high)
         prob = (value - low) / (high - low)
         broadcast_shape = self.shape(prob)
         zeros = self.fill(self.dtypeop(prob), broadcast_shape, 0.0)
@@ -336,7 +313,7 @@ class Uniform(Distribution):
             Tensor, shape is shape + batch_shape.
         """
         shape = self.checktuple(shape, 'shape')
-        low, high = self._check_param(low, high)
+        low, high = self._check_param_type(low, high)
         broadcast_shape = self.shape(low + high)
         origin_shape = shape + broadcast_shape
         if origin_shape == ():

@@ -0,0 +1,182 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Test util functions used in distribution classes.
"""
import numpy as np
import pytest
from mindspore.nn.cell import Cell
from mindspore import context
from mindspore import dtype
from mindspore import Tensor
from mindspore.common.parameter import Parameter
from mindspore.nn.probability.distribution._utils.utils import set_param_type, \
    cast_to_tensor, CheckTuple, CheckTensor

def test_set_param_type():
    """
    Test set_param_type function.
    """
    tensor_fp16 = Tensor(0.1, dtype=dtype.float16)
    tensor_fp32 = Tensor(0.1, dtype=dtype.float32)
    tensor_fp64 = Tensor(0.1, dtype=dtype.float64)
    tensor_int32 = Tensor(0.1, dtype=dtype.int32)
    array_fp32 = np.array(1.0).astype(np.float32)
    array_fp64 = np.array(1.0).astype(np.float64)
    array_int32 = np.array(1.0).astype(np.int32)
    dict1 = {'a': tensor_fp32, 'b': 1.0, 'c': tensor_fp32}
    dict2 = {'a': tensor_fp32, 'b': 1.0, 'c': tensor_fp64}
    dict3 = {'a': tensor_int32, 'b': 1.0, 'c': tensor_int32}
    dict4 = {'a': array_fp32, 'b': 1.0, 'c': tensor_fp32}
    dict5 = {'a': array_fp32, 'b': 1.0, 'c': array_fp64}
    dict6 = {'a': array_fp32, 'b': 1.0, 'c': array_int32}
    dict7 = {'a': 1.0}
    dict8 = {'a': 1.0, 'b': 1.0, 'c': 1.0}
    dict9 = {'a': tensor_fp16, 'b': tensor_fp16, 'c': tensor_fp16}
    dict10 = {'a': tensor_fp64, 'b': tensor_fp64, 'c': tensor_fp64}
    dict11 = {'a': array_fp64, 'b': array_fp64, 'c': tensor_fp64}
    ans1 = set_param_type(dict1, dtype.float16)
    assert ans1 == dtype.float32
    with pytest.raises(TypeError):
        set_param_type(dict2, dtype.float32)
    ans3 = set_param_type(dict3, dtype.float16)
    assert ans3 == dtype.float32
    ans4 = set_param_type(dict4, dtype.float16)
    assert ans4 == dtype.float32
    with pytest.raises(TypeError):
        set_param_type(dict5, dtype.float32)
    with pytest.raises(TypeError):
        set_param_type(dict6, dtype.float32)
    ans7 = set_param_type(dict7, dtype.float32)
    assert ans7 == dtype.float32
    ans8 = set_param_type(dict8, dtype.float32)
    assert ans8 == dtype.float32
    ans9 = set_param_type(dict9, dtype.float32)
    assert ans9 == dtype.float16
    ans10 = set_param_type(dict10, dtype.float32)
    assert ans10 == dtype.float32
    ans11 = set_param_type(dict11, dtype.float32)
    assert ans11 == dtype.float32

def test_cast_to_tensor():
    """
    Test cast_to_tensor.
    """
    with pytest.raises(ValueError):
        cast_to_tensor(None, dtype.float32)
    with pytest.raises(TypeError):
        cast_to_tensor(True, dtype.float32)
    with pytest.raises(TypeError):
        cast_to_tensor({'a': 1, 'b': 2}, dtype.float32)
    with pytest.raises(TypeError):
        cast_to_tensor('tensor', dtype.float32)
    ans1 = cast_to_tensor(Parameter(Tensor(0.1, dtype=dtype.float32), 'param'))
    assert isinstance(ans1, Parameter)
    ans2 = cast_to_tensor(np.array(1.0).astype(np.float32))
    assert isinstance(ans2, Tensor)
    ans3 = cast_to_tensor([1.0, 2.0])
    assert isinstance(ans3, Tensor)
    ans4 = cast_to_tensor(Tensor(0.1, dtype=dtype.float32), dtype.float32)
    assert isinstance(ans4, Tensor)
    ans5 = cast_to_tensor(0.1, dtype.float32)
    assert isinstance(ans5, Tensor)
    ans6 = cast_to_tensor(1, dtype.float32)
    assert isinstance(ans6, Tensor)

class Net(Cell):
    """
    Test class: CheckTuple.
    """
    def __init__(self, value):
        super(Net, self).__init__()
        self.checktuple = CheckTuple()
        self.value = value

    def construct(self, value=None):
        if value is None:
            return self.checktuple(self.value, 'input')
        return self.checktuple(value, 'input')

def test_check_tuple():
    """
    Test CheckTuple.
    """
    net1 = Net((1, 2, 3))
    ans1 = net1()
    assert isinstance(ans1, tuple)
    with pytest.raises(TypeError):
        net2 = Net('tuple')
        net2()

    context.set_context(mode=context.GRAPH_MODE)
    net3 = Net((1, 2, 3))
    ans3 = net3()
    assert isinstance(ans3, tuple)
    with pytest.raises(TypeError):
        net4 = Net('tuple')
        net4()

class Net1(Cell):
    """
    Test class: CheckTensor.
    """
    def __init__(self, value):
        super(Net1, self).__init__()
        self.checktensor = CheckTensor()
        self.value = value
        self.context = context.get_context('mode')

    def construct(self, value=None):
        value = self.value if value is None else value
        if self.context == 0:
            self.checktensor(value, 'input')
            return value
        return self.checktensor(value, 'input')

def test_check_tensor():
    """
    Test CheckTensor.
    """
    value = Tensor(0.1, dtype=dtype.float32)
    net1 = Net1(value)
    ans1 = net1()
    assert isinstance(ans1, Tensor)
    ans1 = net1(value)
    assert isinstance(ans1, Tensor)
    with pytest.raises(TypeError):
        net2 = Net1('tuple')
        net2()

    context.set_context(mode=context.GRAPH_MODE)
    net3 = Net1(value)
    ans3 = net3()
    assert isinstance(ans3, Tensor)
    ans3 = net3(value)
    assert isinstance(ans3, Tensor)
    with pytest.raises(TypeError):
        net4 = Net1('tuple')
        net4()