edited common_dtype and check_param dtype logic

pull/6020/head
Xun Deng 5 years ago
parent 600704ddde
commit 67325d63b0

@ -110,6 +110,8 @@ def check_scalar_from_param(params):
Notes: String parameters are excluded.
"""
for value in params.values():
if value is None:
continue
if isinstance(value, (msp.bijector.Bijector, msp.distribution.Distribution)):
return params['distribution'].is_scalar_batch
if isinstance(value, Parameter):
@ -358,23 +360,29 @@ class CheckTensor(PrimitiveWithInfer):
return x
raise TypeError(f"For {name}, input type should be a Tensor or Parameter.")
def common_dtype(arg_a, name_a, arg_b, name_b, hint_type):
def set_param_type(args, hint_type):
"""
check if arg_a and arg_b have the same dtype.
Find the common type among arguments.
Args:
args (dict): dictionary of arguments, {'name':value}.
hint_type (mindspore.dtype): hint type to return.
Raises:
TypeError: if tensors in args are not the same dtype.
"""
if hasattr(arg_a, 'dtype') and hasattr(arg_b, 'dtype'):
if isinstance(arg_a, np.ndarray):
a_dtype = mstype.pytype_to_dtype(arg_a.dtype)
else:
a_dtype = arg_a.dtype
if isinstance(arg_b, np.ndarray):
b_dtype = mstype.pytype_to_dtype(arg_b.dtype)
else:
b_dtype = arg_b.dtype
if a_dtype != b_dtype:
raise TypeError(f"{name_a} and {name_b} should have the same dtype.")
int_type = mstype.int_type + mstype.uint_type
if a_dtype in int_type or a_dtype == mstype.float64:
return mstype.float32
return a_dtype
return hint_type
common_dtype = None
for name, arg in args.items():
if hasattr(arg, 'dtype'):
if isinstance(arg, np.ndarray):
cur_dtype = mstype.pytype_to_dtype(arg.dtype)
else:
cur_dtype = arg.dtype
if common_dtype is None:
common_dtype = cur_dtype
elif cur_dtype != common_dtype:
raise TypeError(f"{name} should have the same dtype as other arguments.")
int_type = mstype.int_type + mstype.uint_type
if common_dtype in int_type or common_dtype == mstype.float64:
return mstype.float32
return hint_type if common_dtype is None else common_dtype

@ -17,7 +17,7 @@ from mindspore.common import dtype as mstype
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from .distribution import Distribution
from ._utils.utils import cast_to_tensor, check_prob, check_type, check_distribution_name, raise_none_error
from ._utils.utils import cast_to_tensor, check_prob, check_type, check_distribution_name, set_param_type
from ._utils.custom_ops import exp_generic, log_generic, erf_generic
@ -119,13 +119,16 @@ class Bernoulli(Distribution):
valid_dtype = mstype.int_type + mstype.uint_type + mstype.float_type
check_type(dtype, valid_dtype, type(self).__name__)
super(Bernoulli, self).__init__(seed, dtype, name, param)
self.parameter_type = mstype.float32
self.parameter_type = set_param_type({'probs1': probs}, mstype.float32)
if probs is not None:
self._probs = cast_to_tensor(probs, mstype.float32)
self._probs = cast_to_tensor(probs, self.parameter_type)
check_prob(self.probs)
else:
self._probs = probs
self.default_parameters = [self.probs]
self.parameter_names = ['probs1']
# ops needed for the class
self.exp = exp_generic
self.log = log_generic
@ -157,24 +160,12 @@ class Bernoulli(Distribution):
"""
return self._probs
def _check_param(self, probs1):
"""
Check availablity of distribution specific args `probs1`.
"""
if probs1 is not None:
if self.context_mode == 0:
self.checktensor(probs1, 'probs1')
else:
probs1 = self.checktensor(probs1, 'probs1')
return self.cast(probs1, self.parameter_type)
return self.probs if self.probs is not None else raise_none_error('probs1')
def _mean(self, probs1=None):
r"""
.. math::
MEAN(B) = probs1
"""
probs1 = self._check_param(probs1)
probs1 = self._check_param_type(probs1)
return probs1
def _mode(self, probs1=None):
@ -182,7 +173,7 @@ class Bernoulli(Distribution):
.. math::
MODE(B) = 1 if probs1 > 0.5 else = 0
"""
probs1 = self._check_param(probs1)
probs1 = self._check_param_type(probs1)
prob_type = self.dtypeop(probs1)
zeros = self.fill(prob_type, self.shape(probs1), 0.0)
ones = self.fill(prob_type, self.shape(probs1), 1.0)
@ -194,7 +185,7 @@ class Bernoulli(Distribution):
.. math::
VAR(B) = probs1 * probs0
"""
probs1 = self._check_param(probs1)
probs1 = self._check_param_type(probs1)
probs0 = 1.0 - probs1
return self.exp(self.log(probs0) + self.log(probs1))
@ -203,11 +194,11 @@ class Bernoulli(Distribution):
.. math::
H(B) = -probs0 * \log(probs0) - probs1 * \log(probs1)
"""
probs1 = self._check_param(probs1)
probs1 = self._check_param_type(probs1)
probs0 = 1.0 - probs1
return -(probs0 * self.log(probs0)) - (probs1 * self.log(probs1))
def _cross_entropy(self, dist, probs1_b, probs1_a=None):
def _cross_entropy(self, dist, probs1_b, probs1=None):
"""
Evaluate cross_entropy between Bernoulli distributions.
@ -217,7 +208,7 @@ class Bernoulli(Distribution):
probs1_a (Tensor): `probs1` of distribution a. Default: self.probs.
"""
check_distribution_name(dist, 'Bernoulli')
return self._entropy(probs1_a) + self._kl_loss(dist, probs1_b, probs1_a)
return self._entropy(probs1) + self._kl_loss(dist, probs1_b, probs1)
def _log_prob(self, value, probs1=None):
r"""
@ -233,7 +224,7 @@ class Bernoulli(Distribution):
"""
value = self._check_value(value, 'value')
value = self.cast(value, mstype.float32)
probs1 = self._check_param(probs1)
probs1 = self._check_param_type(probs1)
probs0 = 1.0 - probs1
return self.log(probs1) * value + self.log(probs0) * (1.0 - value)
@ -253,7 +244,7 @@ class Bernoulli(Distribution):
value = self._check_value(value, 'value')
value = self.cast(value, mstype.float32)
value = self.floor(value)
probs1 = self._check_param(probs1)
probs1 = self._check_param_type(probs1)
prob_type = self.dtypeop(probs1)
value = value * self.fill(prob_type, self.shape(probs1), 1.0)
probs0 = 1.0 - probs1 * self.fill(prob_type, self.shape(value), 1.0)
@ -264,7 +255,7 @@ class Bernoulli(Distribution):
less_than_zero = self.select(comp_zero, zeros, probs0)
return self.select(comp_one, less_than_zero, ones)
def _kl_loss(self, dist, probs1_b, probs1_a=None):
def _kl_loss(self, dist, probs1_b, probs1=None):
r"""
Evaluate bernoulli-bernoulli kl divergence, i.e. KL(a||b).
@ -280,7 +271,7 @@ class Bernoulli(Distribution):
check_distribution_name(dist, 'Bernoulli')
probs1_b = self._check_value(probs1_b, 'probs1_b')
probs1_b = self.cast(probs1_b, self.parameter_type)
probs1_a = self._check_param(probs1_a)
probs1_a = self._check_param_type(probs1)
probs0_a = 1.0 - probs1_a
probs0_b = 1.0 - probs1_b
return probs1_a * self.log(probs1_a / probs1_b) + probs0_a * self.log(probs0_a / probs0_b)
@ -297,7 +288,7 @@ class Bernoulli(Distribution):
Tensor, shape is shape + batch_shape.
"""
shape = self.checktuple(shape, 'shape')
probs1 = self._check_param(probs1)
probs1 = self._check_param_type(probs1)
origin_shape = shape + self.shape(probs1)
if origin_shape == ():
sample_shape = (1,)

File diff suppressed because it is too large Load Diff

@ -18,8 +18,7 @@ from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.common import dtype as mstype
from .distribution import Distribution
from ._utils.utils import cast_to_tensor, check_greater_zero, check_type, check_distribution_name,\
raise_none_error
from ._utils.utils import cast_to_tensor, check_greater_zero, check_type, check_distribution_name, set_param_type
from ._utils.custom_ops import exp_generic, log_generic
class Exponential(Distribution):
@ -121,15 +120,19 @@ class Exponential(Distribution):
valid_dtype = mstype.float_type
check_type(dtype, valid_dtype, type(self).__name__)
super(Exponential, self).__init__(seed, dtype, name, param)
self.parameter_type = dtype
self.parameter_type = set_param_type({'rate': rate}, self.dtype)
if rate is not None:
self._rate = cast_to_tensor(rate, self.parameter_type)
check_greater_zero(self._rate, "rate")
else:
self._rate = rate
self.default_parameters = [self.rate]
self.parameter_names = ['rate']
self.minval = np.finfo(np.float).tiny
# ops needed for the class
self.exp = exp_generic
self.log = log_generic
@ -156,28 +159,16 @@ class Exponential(Distribution):
@property
def rate(self):
"""
Return rate of the distribution.
Return `rate` of the distribution.
"""
return self._rate
def _check_param(self, rate):
"""
Check availablity of distribution specific argument `rate`.
"""
if rate is not None:
if self.context_mode == 0:
self.checktensor(rate, 'rate')
else:
rate = self.checktensor(rate, 'rate')
return self.cast(rate, self.parameter_type)
return self.rate if self.rate is not None else raise_none_error('rate')
def _mean(self, rate=None):
r"""
.. math::
MEAN(EXP) = \frac{1.0}{\lambda}.
"""
rate = self._check_param(rate)
rate = self._check_param_type(rate)
return 1.0 / rate
def _mode(self, rate=None):
@ -185,7 +176,7 @@ class Exponential(Distribution):
.. math::
MODE(EXP) = 0.
"""
rate = self._check_param(rate)
rate = self._check_param_type(rate)
return self.fill(self.dtype, self.shape(rate), 0.)
def _sd(self, rate=None):
@ -193,7 +184,7 @@ class Exponential(Distribution):
.. math::
sd(EXP) = \frac{1.0}{\lambda}.
"""
rate = self._check_param(rate)
rate = self._check_param_type(rate)
return 1.0 / rate
def _entropy(self, rate=None):
@ -201,7 +192,7 @@ class Exponential(Distribution):
.. math::
H(Exp) = 1 - \log(\lambda).
"""
rate = self._check_param(rate)
rate = self._check_param_type(rate)
return 1.0 - self.log(rate)
def _cross_entropy(self, dist, rate_b, rate=None):
@ -234,7 +225,7 @@ class Exponential(Distribution):
"""
value = self._check_value(value, "value")
value = self.cast(value, self.dtype)
rate = self._check_param(rate)
rate = self._check_param_type(rate)
prob = self.log(rate) - rate * value
zeros = self.fill(self.dtypeop(prob), self.shape(prob), 0.0)
neginf = self.fill(self.dtypeop(prob), self.shape(prob), -np.inf)
@ -257,7 +248,7 @@ class Exponential(Distribution):
"""
value = self._check_value(value, 'value')
value = self.cast(value, self.dtype)
rate = self._check_param(rate)
rate = self._check_param_type(rate)
cdf = 1.0 - self.exp(-1. * rate * value)
zeros = self.fill(self.dtypeop(cdf), self.shape(cdf), 0.0)
comp = self.less(value, zeros)
@ -279,7 +270,7 @@ class Exponential(Distribution):
"""
value = self._check_value(value, 'value')
value = self.cast(value, self.dtype)
rate = self._check_param(rate)
rate = self._check_param_type(rate)
sf = -1. * rate * value
zeros = self.fill(self.dtypeop(sf), self.shape(sf), 0.0)
comp = self.less(value, zeros)
@ -297,7 +288,7 @@ class Exponential(Distribution):
check_distribution_name(dist, 'Exponential')
rate_b = self._check_value(rate_b, 'rate_b')
rate_b = self.cast(rate_b, self.parameter_type)
rate_a = self._check_param(rate)
rate_a = self._check_param_type(rate)
return self.log(rate_a) - self.log(rate_b) + rate_b / rate_a - 1.0
def _sample(self, shape=(), rate=None):
@ -312,7 +303,7 @@ class Exponential(Distribution):
Tensor, shape is shape + batch_shape.
"""
shape = self.checktuple(shape, 'shape')
rate = self._check_param(rate)
rate = self._check_param_type(rate)
origin_shape = shape + self.shape(rate)
if origin_shape == ():
sample_shape = (1,)

@ -19,7 +19,7 @@ from mindspore.ops import composite as C
from mindspore.common import dtype as mstype
from .distribution import Distribution
from ._utils.utils import cast_to_tensor, check_prob, check_type, check_distribution_name,\
raise_none_error
set_param_type
from ._utils.custom_ops import exp_generic, log_generic
@ -123,13 +123,16 @@ class Geometric(Distribution):
valid_dtype = mstype.int_type + mstype.uint_type + mstype.float_type
check_type(dtype, valid_dtype, type(self).__name__)
super(Geometric, self).__init__(seed, dtype, name, param)
self.parameter_type = mstype.float32
self.parameter_type = set_param_type({'probs1': probs}, mstype.float32)
if probs is not None:
self._probs = cast_to_tensor(probs, self.parameter_type)
check_prob(self._probs)
else:
self._probs = probs
self.default_parameters = [self.probs]
self.parameter_names = ['probs1']
self.minval = np.finfo(np.float).tiny
# ops needed for the class
@ -164,24 +167,12 @@ class Geometric(Distribution):
"""
return self._probs
def _check_param(self, probs1):
"""
Check availablity of distribution specific args probs1.
"""
if probs1 is not None:
if self.context_mode == 0:
self.checktensor(probs1, 'probs1')
else:
probs1 = self.checktensor(probs1, 'probs1')
return self.cast(probs1, self.parameter_type)
return self.probs if self.probs is not None else raise_none_error('probs1')
def _mean(self, probs1=None):
r"""
.. math::
MEAN(Geo) = \fratc{1 - probs1}{probs1}
"""
probs1 = self._check_param(probs1)
probs1 = self._check_param_type(probs1)
return (1. - probs1) / probs1
def _mode(self, probs1=None):
@ -189,7 +180,7 @@ class Geometric(Distribution):
.. math::
MODE(Geo) = 0
"""
probs1 = self._check_param(probs1)
probs1 = self._check_param_type(probs1)
return self.fill(self.dtypeop(probs1), self.shape(probs1), 0.)
def _var(self, probs1=None):
@ -197,7 +188,7 @@ class Geometric(Distribution):
.. math::
VAR(Geo) = \frac{1 - probs1}{probs1 ^ {2}}
"""
probs1 = self._check_param(probs1)
probs1 = self._check_param_type(probs1)
return (1.0 - probs1) / self.sq(probs1)
def _entropy(self, probs1=None):
@ -205,7 +196,7 @@ class Geometric(Distribution):
.. math::
H(Geo) = \frac{-1 * probs0 \log_2 (1-probs0)\ - prob1 * \log_2 (1-probs1)\ }{probs1}
"""
probs1 = self._check_param(probs1)
probs1 = self._check_param_type(probs1)
probs0 = 1.0 - probs1
return (-probs0 * self.log(probs0) - probs1 * self.log(probs1)) / probs1
@ -236,7 +227,7 @@ class Geometric(Distribution):
value = self._check_value(value, 'value')
value = self.cast(value, mstype.float32)
value = self.floor(value)
probs1 = self._check_param(probs1)
probs1 = self._check_param_type(probs1)
pmf = self.exp(self.log(1.0 - probs1) * value + self.log(probs1))
zeros = self.fill(self.dtypeop(probs1), self.shape(pmf), 0.0)
comp = self.less(value, zeros)
@ -258,7 +249,7 @@ class Geometric(Distribution):
value = self._check_value(value, 'value')
value = self.cast(value, mstype.float32)
value = self.floor(value)
probs1 = self._check_param(probs1)
probs1 = self._check_param_type(probs1)
probs0 = 1.0 - probs1
cdf = 1.0 - self.pow(probs0, value + 1.0)
zeros = self.fill(self.dtypeop(probs1), self.shape(cdf), 0.0)
@ -280,7 +271,7 @@ class Geometric(Distribution):
check_distribution_name(dist, 'Geometric')
probs1_b = self._check_value(probs1_b, 'probs1_b')
probs1_b = self.cast(probs1_b, self.parameter_type)
probs1_a = self._check_param(probs1)
probs1_a = self._check_param_type(probs1)
probs0_a = 1.0 - probs1_a
probs0_b = 1.0 - probs1_b
return self.log(probs1_a / probs1_b) + (probs0_a / probs1_a) * self.log(probs0_a / probs0_b)
@ -297,7 +288,7 @@ class Geometric(Distribution):
Tensor, shape is shape + batch_shape.
"""
shape = self.checktuple(shape, 'shape')
probs1 = self._check_param(probs1)
probs1 = self._check_param_type(probs1)
origin_shape = shape + self.shape(probs1)
if origin_shape == ():
sample_shape = (1,)

@ -19,7 +19,7 @@ from mindspore.ops import composite as C
from mindspore.common import dtype as mstype
from .distribution import Distribution
from ._utils.utils import cast_to_tensor, check_greater_zero, check_type, check_distribution_name,\
raise_none_error, common_dtype
set_param_type
from ._utils.custom_ops import exp_generic, expm1_generic, log_generic, erf_generic
class Normal(Distribution):
@ -127,14 +127,17 @@ class Normal(Distribution):
valid_dtype = mstype.float_type
check_type(dtype, valid_dtype, type(self).__name__)
super(Normal, self).__init__(seed, dtype, name, param)
self.parameter_type = common_dtype(mean, 'mean', sd, 'sd', self.dtype)
self.parameter_type = set_param_type({'mean': mean, 'sd': sd}, self.dtype)
if mean is not None and sd is not None:
self._mean_value = cast_to_tensor(mean, self.parameter_type)
self._sd_value = cast_to_tensor(sd, self.parameter_type)
check_greater_zero(self._sd_value, "Standard deviation")
else:
self._mean_value = mean
self._sd_value = sd
self._mean_value = mean if mean is None else cast_to_tensor(mean, self.parameter_type)
self._sd_value = sd if sd is None else cast_to_tensor(sd, self.parameter_type)
self.default_parameters = [self._mean_value, self._sd_value]
self.parameter_names = ['mean', 'sd']
#ops needed for the class
self.exp = exp_generic
@ -159,51 +162,25 @@ class Normal(Distribution):
str_info = f'batch_shape = {self._broadcast_shape}'
return str_info
def _check_param(self, mean, sd):
"""
Check availablity of distribution specific args `mean` and `sd`.
"""
if mean is not None:
if self.context_mode == 0:
self.checktensor(mean, 'mean')
else:
mean = self.checktensor(mean, 'mean')
else:
mean = self._mean_value if self._mean_value is not None else raise_none_error('mean')
if sd is not None:
if self.context_mode == 0:
self.checktensor(sd, 'sd')
else:
sd = self.checktensor(sd, 'sd')
else:
sd = self._sd_value if self._sd_value is not None else raise_none_error('sd')
batch_shape = self.shape(mean + sd)
mean = mean * self.fill(self.dtypeop(mean), batch_shape, 1.0)
sd = sd * self.fill(self.dtypeop(sd), batch_shape, 1.0)
self.sametypeshape(mean, sd)
mean = self.cast(mean, self.parameter_type)
sd = self.cast(sd, self.parameter_type)
return mean, sd
def _mean(self, mean=None, sd=None):
"""
The mean of the distribution.
"""
mean, sd = self._check_param(mean, sd)
mean, sd = self._check_param_type(mean, sd)
return mean
def _mode(self, mean=None, sd=None):
"""
The mode of the distribution.
"""
mean, sd = self._check_param(mean, sd)
mean, sd = self._check_param_type(mean, sd)
return mean
def _sd(self, mean=None, sd=None):
"""
The standard deviation of the distribution.
"""
mean, sd = self._check_param(mean, sd)
mean, sd = self._check_param_type(mean, sd)
return sd
def _entropy(self, mean=None, sd=None):
@ -213,7 +190,7 @@ class Normal(Distribution):
.. math::
H(X) = \log(\sqrt(numpy.e * 2. * numpy.pi * \sq(\sigma)))
"""
mean, sd = self._check_param(mean, sd)
mean, sd = self._check_param_type(mean, sd)
return self.log(self.sqrt(self.const(np.e * 2. * np.pi))) + self.log(sd)
def _cross_entropy(self, dist, mean_b, sd_b, mean=None, sd=None):
@ -244,7 +221,7 @@ class Normal(Distribution):
"""
value = self._check_value(value, 'value')
value = self.cast(value, self.dtype)
mean, sd = self._check_param(mean, sd)
mean, sd = self._check_param_type(mean, sd)
unnormalized_log_prob = -1. * (self.sq(value - mean)) / (2. * self.sq(sd))
neg_normalization = -1. * self.log(self.const(2. * np.pi)) / 2. - self.log(sd)
return unnormalized_log_prob + neg_normalization
@ -263,7 +240,7 @@ class Normal(Distribution):
"""
value = self._check_value(value, 'value')
value = self.cast(value, self.dtype)
mean, sd = self._check_param(mean, sd)
mean, sd = self._check_param_type(mean, sd)
sqrt2 = self.sqrt(self.const(2.0))
adjusted = (value - mean) / (sd * sqrt2)
return 0.5 * (1.0 + self.erf(adjusted))
@ -288,7 +265,7 @@ class Normal(Distribution):
sd_b = self._check_value(sd_b, 'sd_b')
mean_b = self.cast(mean_b, self.parameter_type)
sd_b = self.cast(sd_b, self.parameter_type)
mean_a, sd_a = self._check_param(mean, sd)
mean_a, sd_a = self._check_param_type(mean, sd)
diff_log_scale = self.log(sd_a) - self.log(sd_b)
squared_diff = self.sq(mean_a / sd_b - mean_b / sd_b)
return 0.5 * squared_diff + 0.5 * self.expm1(2 * diff_log_scale) - diff_log_scale
@ -306,7 +283,7 @@ class Normal(Distribution):
Tensor, shape is shape + batch_shape.
"""
shape = self.checktuple(shape, 'shape')
mean, sd = self._check_param(mean, sd)
mean, sd = self._check_param_type(mean, sd)
batch_shape = self.shape(mean + sd)
origin_shape = shape + batch_shape
if origin_shape == ():

@ -18,7 +18,7 @@ from mindspore.ops import composite as C
from mindspore.common import dtype as mstype
from .distribution import Distribution
from ._utils.utils import cast_to_tensor, check_greater, check_type, check_distribution_name,\
raise_none_error, common_dtype
set_param_type
from ._utils.custom_ops import exp_generic, log_generic
class Uniform(Distribution):
@ -126,14 +126,17 @@ class Uniform(Distribution):
valid_dtype = mstype.float_type
check_type(dtype, valid_dtype, type(self).__name__)
super(Uniform, self).__init__(seed, dtype, name, param)
self.parameter_type = common_dtype(low, 'low', high, 'high', self.dtype)
self.parameter_type = set_param_type({'low': low, 'high': high}, self.dtype)
if low is not None and high is not None:
self._low = cast_to_tensor(low, dtype)
self._high = cast_to_tensor(high, dtype)
self._low = cast_to_tensor(low, self.parameter_type)
self._high = cast_to_tensor(high, self.parameter_type)
check_greater(self.low, self.high, "low value", "high value")
else:
self._low = low
self._high = high
self._low = low if low is None else cast_to_tensor(low, self.parameter_type)
self._high = high if high is None else cast_to_tensor(high, self.parameter_type)
self.default_parameters = [self.low, self.high]
self.parameter_names = ['low', 'high']
# ops needed for the class
self.exp = exp_generic
@ -162,32 +165,6 @@ class Uniform(Distribution):
str_info = f'batch_shape = {self._broadcast_shape}'
return str_info
def _check_param(self, low, high):
"""
Check availablity of distribution specific args `low` and `high`.
"""
if low is not None:
if self.context_mode == 0:
self.checktensor(low, 'low')
else:
low = self.checktensor(low, 'low')
else:
low = self.low if self.low is not None else raise_none_error('low')
if high is not None:
if self.context_mode == 0:
self.checktensor(high, 'high')
else:
high = self.checktensor(high, 'high')
else:
high = self.high if self.high is not None else raise_none_error('high')
batch_shape = self.shape(high - low)
high = high * self.fill(self.dtypeop(high), batch_shape, 1.0)
low = low * self.fill(self.dtypeop(low), batch_shape, 1.0)
self.sametypeshape(high, low)
low = self.cast(low, self.parameter_type)
high = self.cast(high, self.parameter_type)
return low, high
@property
def low(self):
"""
@ -209,7 +186,7 @@ class Uniform(Distribution):
.. math::
range(U) = high -low
"""
low, high = self._check_param(low, high)
low, high = self._check_param_type(low, high)
return high - low
def _mean(self, low=None, high=None):
@ -217,7 +194,7 @@ class Uniform(Distribution):
.. math::
MEAN(U) = \frac{low + high}{2}.
"""
low, high = self._check_param(low, high)
low, high = self._check_param_type(low, high)
return (low + high) / 2.
def _var(self, low=None, high=None):
@ -225,7 +202,7 @@ class Uniform(Distribution):
.. math::
VAR(U) = \frac{(high -low) ^ 2}{12}.
"""
low, high = self._check_param(low, high)
low, high = self._check_param_type(low, high)
return self.sq(high - low) / 12.0
def _entropy(self, low=None, high=None):
@ -233,7 +210,7 @@ class Uniform(Distribution):
.. math::
H(U) = \log(high - low).
"""
low, high = self._check_param(low, high)
low, high = self._check_param_type(low, high)
return self.log(high - low)
def _cross_entropy(self, dist, low_b, high_b, low=None, high=None):
@ -266,7 +243,7 @@ class Uniform(Distribution):
"""
value = self._check_value(value, 'value')
value = self.cast(value, self.dtype)
low, high = self._check_param(low, high)
low, high = self._check_param_type(low, high)
neg_ones = self.fill(self.dtype, self.shape(value), -1.0)
prob = self.exp(neg_ones * self.log(high - low))
broadcast_shape = self.shape(prob)
@ -292,7 +269,7 @@ class Uniform(Distribution):
low_b = self.cast(low_b, self.parameter_type)
high_b = self._check_value(high_b, 'high_b')
high_b = self.cast(high_b, self.parameter_type)
low_a, high_a = self._check_param(low, high)
low_a, high_a = self._check_param_type(low, high)
kl = self.log(high_b - low_b) - self.log(high_a - low_a)
comp = self.logicaland(self.lessequal(low_b, low_a), self.lessequal(high_a, high_b))
return self.select(comp, kl, self.log(self.zeroslike(kl)))
@ -313,7 +290,7 @@ class Uniform(Distribution):
"""
value = self._check_value(value, 'value')
value = self.cast(value, self.dtype)
low, high = self._check_param(low, high)
low, high = self._check_param_type(low, high)
prob = (value - low) / (high - low)
broadcast_shape = self.shape(prob)
zeros = self.fill(self.dtypeop(prob), broadcast_shape, 0.0)
@ -336,7 +313,7 @@ class Uniform(Distribution):
Tensor, shape is shape + batch_shape.
"""
shape = self.checktuple(shape, 'shape')
low, high = self._check_param(low, high)
low, high = self._check_param_type(low, high)
broadcast_shape = self.shape(low + high)
origin_shape = shape + broadcast_shape
if origin_shape == ():

@ -0,0 +1,182 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Test util functions used in distribution classes.
"""
import numpy as np
import pytest
from mindspore.nn.cell import Cell
from mindspore import context
from mindspore import dtype
from mindspore import Tensor
from mindspore.common.parameter import Parameter
from mindspore.nn.probability.distribution._utils.utils import set_param_type, \
cast_to_tensor, CheckTuple, CheckTensor
def test_set_param_type():
"""
Test set_param_type function.
"""
tensor_fp16 = Tensor(0.1, dtype=dtype.float16)
tensor_fp32 = Tensor(0.1, dtype=dtype.float32)
tensor_fp64 = Tensor(0.1, dtype=dtype.float64)
tensor_int32 = Tensor(0.1, dtype=dtype.int32)
array_fp32 = np.array(1.0).astype(np.float32)
array_fp64 = np.array(1.0).astype(np.float64)
array_int32 = np.array(1.0).astype(np.int32)
dict1 = {'a': tensor_fp32, 'b': 1.0, 'c': tensor_fp32}
dict2 = {'a': tensor_fp32, 'b': 1.0, 'c': tensor_fp64}
dict3 = {'a': tensor_int32, 'b': 1.0, 'c': tensor_int32}
dict4 = {'a': array_fp32, 'b': 1.0, 'c': tensor_fp32}
dict5 = {'a': array_fp32, 'b': 1.0, 'c': array_fp64}
dict6 = {'a': array_fp32, 'b': 1.0, 'c': array_int32}
dict7 = {'a': 1.0}
dict8 = {'a': 1.0, 'b': 1.0, 'c': 1.0}
dict9 = {'a': tensor_fp16, 'b': tensor_fp16, 'c': tensor_fp16}
dict10 = {'a': tensor_fp64, 'b': tensor_fp64, 'c': tensor_fp64}
dict11 = {'a': array_fp64, 'b': array_fp64, 'c': tensor_fp64}
ans1 = set_param_type(dict1, dtype.float16)
assert ans1 == dtype.float32
with pytest.raises(TypeError):
set_param_type(dict2, dtype.float32)
ans3 = set_param_type(dict3, dtype.float16)
assert ans3 == dtype.float32
ans4 = set_param_type(dict4, dtype.float16)
assert ans4 == dtype.float32
with pytest.raises(TypeError):
set_param_type(dict5, dtype.float32)
with pytest.raises(TypeError):
set_param_type(dict6, dtype.float32)
ans7 = set_param_type(dict7, dtype.float32)
assert ans7 == dtype.float32
ans8 = set_param_type(dict8, dtype.float32)
assert ans8 == dtype.float32
ans9 = set_param_type(dict9, dtype.float32)
assert ans9 == dtype.float16
ans10 = set_param_type(dict10, dtype.float32)
assert ans10 == dtype.float32
ans11 = set_param_type(dict11, dtype.float32)
assert ans11 == dtype.float32
def test_cast_to_tensor():
"""
Test cast_to_tensor.
"""
with pytest.raises(ValueError):
cast_to_tensor(None, dtype.float32)
with pytest.raises(TypeError):
cast_to_tensor(True, dtype.float32)
with pytest.raises(TypeError):
cast_to_tensor({'a': 1, 'b': 2}, dtype.float32)
with pytest.raises(TypeError):
cast_to_tensor('tensor', dtype.float32)
ans1 = cast_to_tensor(Parameter(Tensor(0.1, dtype=dtype.float32), 'param'))
assert isinstance(ans1, Parameter)
ans2 = cast_to_tensor(np.array(1.0).astype(np.float32))
assert isinstance(ans2, Tensor)
ans3 = cast_to_tensor([1.0, 2.0])
assert isinstance(ans3, Tensor)
ans4 = cast_to_tensor(Tensor(0.1, dtype=dtype.float32), dtype.float32)
assert isinstance(ans4, Tensor)
ans5 = cast_to_tensor(0.1, dtype.float32)
assert isinstance(ans5, Tensor)
ans6 = cast_to_tensor(1, dtype.float32)
assert isinstance(ans6, Tensor)
class Net(Cell):
"""
Test class: CheckTuple.
"""
def __init__(self, value):
super(Net, self).__init__()
self.checktuple = CheckTuple()
self.value = value
def construct(self, value=None):
if value is None:
return self.checktuple(self.value, 'input')
return self.checktuple(value, 'input')
def test_check_tuple():
"""
Test CheckTuple.
"""
net1 = Net((1, 2, 3))
ans1 = net1()
assert isinstance(ans1, tuple)
with pytest.raises(TypeError):
net2 = Net('tuple')
net2()
context.set_context(mode=context.GRAPH_MODE)
net3 = Net((1, 2, 3))
ans3 = net3()
assert isinstance(ans3, tuple)
with pytest.raises(TypeError):
net4 = Net('tuple')
net4()
class Net1(Cell):
"""
Test class: CheckTensor.
"""
def __init__(self, value):
super(Net1, self).__init__()
self.checktensor = CheckTensor()
self.value = value
self.context = context.get_context('mode')
def construct(self, value=None):
value = self.value if value is None else value
if self.context == 0:
self.checktensor(value, 'input')
return value
return self.checktensor(value, 'input')
def test_check_tensor():
"""
Test CheckTensor.
"""
value = Tensor(0.1, dtype=dtype.float32)
net1 = Net1(value)
ans1 = net1()
assert isinstance(ans1, Tensor)
ans1 = net1(value)
assert isinstance(ans1, Tensor)
with pytest.raises(TypeError):
net2 = Net1('tuple')
net2()
context.set_context(mode=context.GRAPH_MODE)
net3 = Net1(value)
ans3 = net3()
assert isinstance(ans3, Tensor)
ans3 = net3(value)
assert isinstance(ans3, Tensor)
with pytest.raises(TypeError):
net4 = Net1('tuple')
net4()
Loading…
Cancel
Save