@@ -19,6 +19,7 @@ from mindspore.common import dtype as mstype
 from .distribution import Distribution
 from ._utils.utils import convert_to_batch, check_greater, check_type, check_distribution_name,\
     raise_none_error
+from ._utils.utils import CheckTensor, CheckTuple
 
 class Uniform(Distribution):
     """
@@ -129,6 +130,9 @@ class Uniform(Distribution):
         self.zeroslike = P.ZerosLike()
         self.uniform = C.uniform
 
+        self.checktensor = CheckTensor()
+        self.checktuple = CheckTuple()
+
     def extend_repr(self):
         if self.is_scalar_batch:
             str_info = f'low = {self.low}, high = {self.high}'
@@ -136,6 +140,25 @@ class Uniform(Distribution):
             str_info = f'batch_shape = {self._broadcast_shape}'
         return str_info
 
+    def _check_param(self, low, high):
+        """
+        Check availability of distribution specific args low and high.
+        """
+        if low is not None:
+            self.checktensor(low, 'low')
+            low = self.cast(low, self.parameter_type)
+        else:
+            low = self.low if self.low is not None else raise_none_error('low')
+        if high is not None:
+            self.checktensor(high, 'high')
+            high = self.cast(high, self.parameter_type)
+        else:
+            high = self.high if self.high is not None else raise_none_error('high')
+        batch_shape = self.shape(high - low)
+        high = high * self.fill(self.dtype, batch_shape, 1.0)
+        low = low * self.fill(self.dtype, batch_shape, 1.0)
+        return low, high
+
     @property
     def low(self):
         """
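For reference, a minimal NumPy sketch (not MindSpore code) of the behaviour the new _check_param helper centralises: explicit arguments take precedence over the stored parameters, a missing bound raises, and both bounds are broadcast to one batch shape. The names check_param_sketch, default_low, and default_high are illustrative only.

import numpy as np

def check_param_sketch(low, high, default_low=None, default_high=None):
    """Fall back to stored parameters, then broadcast both bounds."""
    low = low if low is not None else default_low
    high = high if high is not None else default_high
    if low is None or high is None:
        raise ValueError("both 'low' and 'high' must be available")
    low = np.asarray(low, dtype=np.float32)
    high = np.asarray(high, dtype=np.float32)
    # Broadcast both bounds to the common batch shape, mirroring the
    # fill/multiply trick used in _check_param.
    batch_shape = np.broadcast(low, high).shape
    ones = np.ones(batch_shape, dtype=np.float32)
    return low * ones, high * ones

lo, hi = check_param_sketch([0.0, 1.0], 3.0)   # hi is broadcast to array([3., 3.])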
@@ -156,12 +179,7 @@ class Uniform(Distribution):
         .. math::
             range(U) = high -low
         """
-        low = self.cast(low, self.parameter_type) if low is not None else self.low
-        if low is None:
-            raise_none_error("low")
-        high = self.cast(high, self.parameter_type) if high is not None else self.high
-        if high is None:
-            raise_none_error("high")
+        low, high = self._check_param(low, high)
         return high - low
 
     def _mean(self, low=None, high=None):
@@ -169,12 +187,7 @@ class Uniform(Distribution):
         .. math::
             MEAN(U) = \frac{low + high}{2}.
         """
-        low = self.cast(low, self.parameter_type) if low is not None else self.low
-        if low is None:
-            raise_none_error("low")
-        high = self.cast(high, self.parameter_type) if high is not None else self.high
-        if high is None:
-            raise_none_error("high")
+        low, high = self._check_param(low, high)
         return (low + high) / 2.
 
     def _var(self, low=None, high=None):
@@ -182,12 +195,7 @@ class Uniform(Distribution):
         .. math::
             VAR(U) = \frac{(high -low) ^ 2}{12}.
         """
-        low = self.cast(low, self.parameter_type) if low is not None else self.low
-        if low is None:
-            raise_none_error("low")
-        high = self.cast(high, self.parameter_type) if high is not None else self.high
-        if high is None:
-            raise_none_error("high")
+        low, high = self._check_param(low, high)
         return self.sq(high - low) / 12.0
 
     def _entropy(self, low=None, high=None):
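As a quick numeric sanity check of the closed forms above, MEAN(U) = (low + high)/2 and VAR(U) = (high - low)^2/12, a small NumPy Monte Carlo sketch; the sample size and tolerances are arbitrary.

import numpy as np

low, high = -1.0, 3.0
samples = np.random.default_rng(0).uniform(low, high, size=1_000_000)

assert abs(samples.mean() - (low + high) / 2) < 1e-2          # analytic mean = 1.0
assert abs(samples.var() - (high - low) ** 2 / 12) < 1e-2     # analytic variance = 16/12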
@@ -195,15 +203,10 @@ class Uniform(Distribution):
         .. math::
             H(U) = \log(high - low).
         """
-        low = self.cast(low, self.parameter_type) if low is not None else self.low
-        if low is None:
-            raise_none_error("low")
-        high = self.cast(high, self.parameter_type) if high is not None else self.high
-        if high is None:
-            raise_none_error("high")
+        low, high = self._check_param(low, high)
         return self.log(high - low)
 
-    def _cross_entropy(self, dist, low_b, high_b, low_a=None, high_a=None):
+    def _cross_entropy(self, dist, low_b, high_b, low=None, high=None):
         """
         Evaluate cross_entropy between Uniform distributions.
@@ -215,7 +218,7 @@ class Uniform(Distribution):
             high_a (Tensor): upper bound of distribution a. Default: self.high.
         """
         check_distribution_name(dist, 'Uniform')
-        return self._entropy(low=low_a, high=high_a) + self._kl_loss(dist, low_b, high_b, low_a, high_a)
+        return self._entropy(low, high) + self._kl_loss(dist, low_b, high_b, low, high)
 
     def _prob(self, value, low=None, high=None):
         r"""
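The updated return line relies on the identity H(a, b) = H(a) + KL(a||b), which for uniform distributions reduces to log(high_b - low_b) when the support of a lies inside the support of b; a tiny NumPy check under that assumption:

import numpy as np

low_a, high_a = 0.0, 1.0     # distribution a, support contained in b's support
low_b, high_b = -1.0, 3.0    # distribution b

entropy_a = np.log(high_a - low_a)
kl_ab = np.log(high_b - low_b) - np.log(high_a - low_a)
cross_entropy = np.log(high_b - low_b)   # H(a, b) = -E_a[log p_b(X)] for uniforms

assert np.isclose(entropy_a + kl_ab, cross_entropy)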
@@ -231,15 +234,9 @@ class Uniform(Distribution):
             pdf(x) = \frac{1.0}{high -low} if low <= x <= high;
             pdf(x) = 0 if x > high;
         """
-        if value is None:
-            raise_none_error("value")
+        self.checktensor(value, 'value')
         value = self.cast(value, self.dtype)
-        low = self.cast(low, self.parameter_type) if low is not None else self.low
-        if low is None:
-            raise_none_error("low")
-        high = self.cast(high, self.parameter_type) if high is not None else self.high
-        if high is None:
-            raise_none_error("high")
+        low, high = self._check_param(low, high)
         neg_ones = self.fill(self.dtype, self.shape(value), -1.0)
         prob = self.exp(neg_ones * self.log(high - low))
         broadcast_shape = self.shape(prob)
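For comparison, a NumPy sketch of the same piecewise density, pdf(x) = 1/(high - low) on [low, high] and 0 elsewhere, with np.where standing in for the select/fill pattern used above; all names are illustrative.

import numpy as np

def uniform_pdf_sketch(value, low, high):
    value, low, high = np.broadcast_arrays(value, low, high)
    prob = np.exp(-np.log(high - low))            # 1 / (high - low), as in _prob
    inside = (value >= low) & (value <= high)
    return np.where(inside, prob, 0.0)

print(uniform_pdf_sketch([-0.5, 0.5, 1.5], 0.0, 1.0))   # -> [0. 1. 0.]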
@@ -249,7 +246,7 @@ class Uniform(Distribution):
         less_than_low = self.select(comp_lo, zeros, prob)
         return self.select(comp_hi, less_than_low, zeros)
 
-    def _kl_loss(self, dist, low_b, high_b, low_a=None, high_a=None):
+    def _kl_loss(self, dist, low_b, high_b, low=None, high=None):
         """
         Evaluate uniform-uniform kl divergence, i.e. KL(a||b).
@@ -261,19 +258,12 @@ class Uniform(Distribution):
             high_a (Tensor): upper bound of distribution a. Default: self.high.
         """
         check_distribution_name(dist, 'Uniform')
-        if low_b is None:
-            raise_none_error("low_b")
-        if high_b is None:
-            raise_none_error("high_b")
+        self.checktensor(low_b, 'low_b')
         low_b = self.cast(low_b, self.parameter_type)
+        self.checktensor(high_b, 'high_b')
         high_b = self.cast(high_b, self.parameter_type)
-        low_a = self.cast(low_a, self.parameter_type) if low_a is not None else self.low
-        if low_a is None:
-            raise_none_error("low_a")
-        high_a = self.cast(high_a, self.parameter_type) if high_a is not None else self.high
-        if high_a is None:
-            raise_none_error("high_a")
-        kl = self.log(high_b - low_b) / self.log(high_a - low_a)
+        low_a, high_a = self._check_param(low, high)
+        kl = self.log(high_b - low_b) - self.log(high_a - low_a)
         comp = self.logicaland(self.lessequal(low_b, low_a), self.lessequal(high_a, high_b))
         return self.select(comp, kl, self.log(self.zeroslike(kl)))
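The rewritten kl line corrects an operator bug: for uniforms with the support of a contained in the support of b, KL(a||b) = log(high_b - low_b) - log(high_a - low_a), i.e. the log of the ratio of interval lengths, not a quotient of logarithms. A short NumPy illustration of the difference:

import numpy as np

low_a, high_a = 0.0, 2.0
low_b, high_b = -1.0, 3.0     # support of a lies inside support of b

kl_fixed = np.log(high_b - low_b) - np.log(high_a - low_a)   # log(4/2) ~= 0.693
kl_old = np.log(high_b - low_b) / np.log(high_a - low_a)     # ~= 2.0, not a KL divergence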
@@ -291,15 +281,9 @@ class Uniform(Distribution):
             cdf(x) = \frac{x - low}{high -low} if low <= x <= high;
             cdf(x) = 1 if x > high;
         """
-        if value is None:
-            raise_none_error("value")
+        self.checktensor(value, 'value')
         value = self.cast(value, self.dtype)
-        low = self.cast(low, self.parameter_type) if low is not None else self.low
-        if low is None:
-            raise_none_error("low")
-        high = self.cast(high, self.parameter_type) if high is not None else self.high
-        if high is None:
-            raise_none_error("high")
+        low, high = self._check_param(low, high)
         prob = (value - low) / (high - low)
         broadcast_shape = self.shape(prob)
         zeros = self.fill(self.dtypeop(prob), broadcast_shape, 0.0)
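Analogously, a NumPy sketch of the clamped cdf, (x - low)/(high - low) clipped to [0, 1], where np.clip plays the role of the two select operations; names are illustrative.

import numpy as np

def uniform_cdf_sketch(value, low, high):
    value, low, high = np.broadcast_arrays(value, low, high)
    prob = (value - low) / (high - low)
    return np.clip(prob, 0.0, 1.0)     # 0 below low, 1 above high

print(uniform_cdf_sketch([-0.5, 0.5, 1.5], 0.0, 1.0))   # -> [0.  0.5 1. ]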
@@ -321,12 +305,8 @@ class Uniform(Distribution):
         Returns:
             Tensor, shape is shape + batch_shape.
         """
-        low = self.cast(low, self.parameter_type) if low is not None else self.low
-        if low is None:
-            raise_none_error("low")
-        high = self.cast(high, self.parameter_type) if high is not None else self.high
-        if high is None:
-            raise_none_error("high")
+        self.checktuple(shape, 'shape')
+        low, high = self._check_param(low, high)
         broadcast_shape = self.shape(low + high)
         origin_shape = shape + broadcast_shape
         if origin_shape == ():