|
|
|
@ -19,7 +19,7 @@ from mindspore.ops import composite as C
|
|
|
|
|
from mindspore.common import dtype as mstype
|
|
|
|
|
from .distribution import Distribution
|
|
|
|
|
from ._utils.utils import cast_to_tensor, check_greater_zero, check_type, check_distribution_name,\
|
|
|
|
|
raise_none_error
|
|
|
|
|
raise_none_error, common_dtype
|
|
|
|
|
from ._utils.custom_ops import exp_generic, expm1_generic, log_generic, erf_generic
|
|
|
|
|
|
|
|
|
|
class Normal(Distribution):
|
|
|
|
@ -104,7 +104,7 @@ class Normal(Distribution):
|
|
|
|
|
valid_dtype = mstype.float_type
|
|
|
|
|
check_type(dtype, valid_dtype, type(self).__name__)
|
|
|
|
|
super(Normal, self).__init__(seed, dtype, name, param)
|
|
|
|
|
self.parameter_type = dtype
|
|
|
|
|
self.parameter_type = common_dtype(mean, 'mean', sd, 'sd', self.dtype)
|
|
|
|
|
if mean is not None and sd is not None:
|
|
|
|
|
self._mean_value = cast_to_tensor(mean, self.parameter_type)
|
|
|
|
|
self._sd_value = cast_to_tensor(sd, self.parameter_type)
|
|
|
|
@ -126,6 +126,8 @@ class Normal(Distribution):
|
|
|
|
|
self.sq = P.Square()
|
|
|
|
|
self.sqrt = P.Sqrt()
|
|
|
|
|
self.zeroslike = P.ZerosLike()
|
|
|
|
|
self.dtypeop = P.DType()
|
|
|
|
|
self.sametypeshape = P.SameTypeShape()
|
|
|
|
|
|
|
|
|
|
def extend_repr(self):
|
|
|
|
|
if self.is_scalar_batch:
|
|
|
|
@ -143,7 +145,6 @@ class Normal(Distribution):
|
|
|
|
|
self.checktensor(mean, 'mean')
|
|
|
|
|
else:
|
|
|
|
|
mean = self.checktensor(mean, 'mean')
|
|
|
|
|
mean = self.cast(mean, self.parameter_type)
|
|
|
|
|
else:
|
|
|
|
|
mean = self._mean_value if self._mean_value is not None else raise_none_error('mean')
|
|
|
|
|
if sd is not None:
|
|
|
|
@ -151,12 +152,14 @@ class Normal(Distribution):
|
|
|
|
|
self.checktensor(sd, 'sd')
|
|
|
|
|
else:
|
|
|
|
|
sd = self.checktensor(sd, 'sd')
|
|
|
|
|
sd = self.cast(sd, self.parameter_type)
|
|
|
|
|
else:
|
|
|
|
|
sd = self._sd_value if self._sd_value is not None else raise_none_error('sd')
|
|
|
|
|
batch_shape = self.shape(mean + sd)
|
|
|
|
|
mean = mean * self.fill(self.dtype, batch_shape, 1.0)
|
|
|
|
|
sd = sd * self.fill(self.dtype, batch_shape, 1.0)
|
|
|
|
|
mean = mean * self.fill(self.dtypeop(mean), batch_shape, 1.0)
|
|
|
|
|
sd = sd * self.fill(self.dtypeop(sd), batch_shape, 1.0)
|
|
|
|
|
self.sametypeshape(mean, sd)
|
|
|
|
|
mean = self.cast(mean, self.parameter_type)
|
|
|
|
|
sd = self.cast(sd, self.parameter_type)
|
|
|
|
|
return mean, sd
|
|
|
|
|
|
|
|
|
|
def _mean(self, mean=None, sd=None):
|
|
|
|
|