|
|
@ -17,7 +17,7 @@ from mindspore.ops import operations as P
|
|
|
|
from mindspore.ops import composite as C
|
|
|
|
from mindspore.ops import composite as C
|
|
|
|
from mindspore.common import dtype as mstype
|
|
|
|
from mindspore.common import dtype as mstype
|
|
|
|
from .distribution import Distribution
|
|
|
|
from .distribution import Distribution
|
|
|
|
from ._utils.utils import convert_to_batch, check_greater, check_type, check_distribution_name,\
|
|
|
|
from ._utils.utils import cast_to_tensor, check_greater, check_type, check_distribution_name,\
|
|
|
|
raise_none_error
|
|
|
|
raise_none_error
|
|
|
|
from ._utils.custom_ops import exp_generic, log_generic
|
|
|
|
from ._utils.custom_ops import exp_generic, log_generic
|
|
|
|
|
|
|
|
|
|
|
@ -101,12 +101,12 @@ class Uniform(Distribution):
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
param = dict(locals())
|
|
|
|
param = dict(locals())
|
|
|
|
valid_dtype = mstype.float_type
|
|
|
|
valid_dtype = mstype.float_type
|
|
|
|
check_type(dtype, valid_dtype, "Uniform")
|
|
|
|
check_type(dtype, valid_dtype, type(self).__name__)
|
|
|
|
super(Uniform, self).__init__(seed, dtype, name, param)
|
|
|
|
super(Uniform, self).__init__(seed, dtype, name, param)
|
|
|
|
self.parameter_type = dtype
|
|
|
|
self.parameter_type = dtype
|
|
|
|
if low is not None and high is not None:
|
|
|
|
if low is not None and high is not None:
|
|
|
|
self._low = convert_to_batch(low, self.broadcast_shape, dtype)
|
|
|
|
self._low = cast_to_tensor(low, dtype)
|
|
|
|
self._high = convert_to_batch(high, self.broadcast_shape, dtype)
|
|
|
|
self._high = cast_to_tensor(high, dtype)
|
|
|
|
check_greater(self.low, self.high, "low value", "high value")
|
|
|
|
check_greater(self.low, self.high, "low value", "high value")
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
self._low = low
|
|
|
|
self._low = low
|
|
|
@ -142,12 +142,18 @@ class Uniform(Distribution):
|
|
|
|
Check availablity of distribution specific args low and high.
|
|
|
|
Check availablity of distribution specific args low and high.
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
if low is not None:
|
|
|
|
if low is not None:
|
|
|
|
self.checktensor(low, 'low')
|
|
|
|
if self.context_mode == 0:
|
|
|
|
|
|
|
|
self.checktensor(low, 'low')
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
low = self.checktensor(low, 'low')
|
|
|
|
low = self.cast(low, self.parameter_type)
|
|
|
|
low = self.cast(low, self.parameter_type)
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
low = self.low if self.low is not None else raise_none_error('low')
|
|
|
|
low = self.low if self.low is not None else raise_none_error('low')
|
|
|
|
if high is not None:
|
|
|
|
if high is not None:
|
|
|
|
self.checktensor(high, 'high')
|
|
|
|
if self.context_mode == 0:
|
|
|
|
|
|
|
|
self.checktensor(high, 'high')
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
high = self.checktensor(high, 'high')
|
|
|
|
high = self.cast(high, self.parameter_type)
|
|
|
|
high = self.cast(high, self.parameter_type)
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
high = self.high if self.high is not None else raise_none_error('high')
|
|
|
|
high = self.high if self.high is not None else raise_none_error('high')
|
|
|
@ -231,7 +237,7 @@ class Uniform(Distribution):
|
|
|
|
pdf(x) = \frac{1.0}{high -low} if low <= x <= high;
|
|
|
|
pdf(x) = \frac{1.0}{high -low} if low <= x <= high;
|
|
|
|
pdf(x) = 0 if x > high;
|
|
|
|
pdf(x) = 0 if x > high;
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
self.checktensor(value, 'value')
|
|
|
|
value = self._check_value(value, 'value')
|
|
|
|
value = self.cast(value, self.dtype)
|
|
|
|
value = self.cast(value, self.dtype)
|
|
|
|
low, high = self._check_param(low, high)
|
|
|
|
low, high = self._check_param(low, high)
|
|
|
|
neg_ones = self.fill(self.dtype, self.shape(value), -1.0)
|
|
|
|
neg_ones = self.fill(self.dtype, self.shape(value), -1.0)
|
|
|
@ -255,9 +261,9 @@ class Uniform(Distribution):
|
|
|
|
high_a (Tensor): upper bound of distribution a. Default: self.high.
|
|
|
|
high_a (Tensor): upper bound of distribution a. Default: self.high.
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
check_distribution_name(dist, 'Uniform')
|
|
|
|
check_distribution_name(dist, 'Uniform')
|
|
|
|
self.checktensor(low_b, 'low_b')
|
|
|
|
low_b = self._check_value(low_b, 'low_b')
|
|
|
|
low_b = self.cast(low_b, self.parameter_type)
|
|
|
|
low_b = self.cast(low_b, self.parameter_type)
|
|
|
|
self.checktensor(high_b, 'high_b')
|
|
|
|
high_b = self._check_value(high_b, 'high_b')
|
|
|
|
high_b = self.cast(high_b, self.parameter_type)
|
|
|
|
high_b = self.cast(high_b, self.parameter_type)
|
|
|
|
low_a, high_a = self._check_param(low, high)
|
|
|
|
low_a, high_a = self._check_param(low, high)
|
|
|
|
kl = self.log(high_b - low_b) - self.log(high_a - low_a)
|
|
|
|
kl = self.log(high_b - low_b) - self.log(high_a - low_a)
|
|
|
@ -278,7 +284,7 @@ class Uniform(Distribution):
|
|
|
|
cdf(x) = \frac{x - low}{high -low} if low <= x <= high;
|
|
|
|
cdf(x) = \frac{x - low}{high -low} if low <= x <= high;
|
|
|
|
cdf(x) = 1 if x > high;
|
|
|
|
cdf(x) = 1 if x > high;
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
self.checktensor(value, 'value')
|
|
|
|
value = self._check_value(value, 'value')
|
|
|
|
value = self.cast(value, self.dtype)
|
|
|
|
value = self.cast(value, self.dtype)
|
|
|
|
low, high = self._check_param(low, high)
|
|
|
|
low, high = self._check_param(low, high)
|
|
|
|
prob = (value - low) / (high - low)
|
|
|
|
prob = (value - low) / (high - low)
|
|
|
|