|
|
|
@ -18,7 +18,7 @@ from mindspore.ops import operations as P
|
|
|
|
|
from mindspore.ops import composite as C
|
|
|
|
|
from mindspore._checkparam import Validator
|
|
|
|
|
from .distribution import Distribution
|
|
|
|
|
from ._utils.utils import check_prob, check_distribution_name
|
|
|
|
|
from ._utils.utils import check_prob, check_distribution_name, clamp_probs
|
|
|
|
|
from ._utils.custom_ops import exp_generic, log_generic
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -86,7 +86,6 @@ class Bernoulli(Distribution):
|
|
|
|
|
>>> ans = b2.mean(probs_a)
|
|
|
|
|
>>> print(ans.shape)
|
|
|
|
|
(1,)
|
|
|
|
|
>>> print(ans.shape)
|
|
|
|
|
>>> # Interfaces of `kl_loss` and `cross_entropy` are the same as follows:
|
|
|
|
|
>>> # Args:
|
|
|
|
|
>>> # dist (str): the name of the distribution. Only 'Bernoulli' is supported.
|
|
|
|
@ -132,7 +131,8 @@ class Bernoulli(Distribution):
|
|
|
|
|
param = dict(locals())
|
|
|
|
|
param['param_dict'] = {'probs': probs}
|
|
|
|
|
valid_dtype = mstype.int_type + mstype.uint_type + mstype.float_type
|
|
|
|
|
Validator.check_type_name("dtype", dtype, valid_dtype, type(self).__name__)
|
|
|
|
|
Validator.check_type_name(
|
|
|
|
|
"dtype", dtype, valid_dtype, type(self).__name__)
|
|
|
|
|
super(Bernoulli, self).__init__(seed, dtype, name, param)
|
|
|
|
|
|
|
|
|
|
self._probs = self._add_parameter(probs, 'probs')
|
|
|
|
@ -241,6 +241,9 @@ class Bernoulli(Distribution):
|
|
|
|
|
value = self._check_value(value, 'value')
|
|
|
|
|
value = self.cast(value, self.parameter_type)
|
|
|
|
|
probs1 = self._check_param_type(probs1)
|
|
|
|
|
|
|
|
|
|
# clamp value for numerical stability
|
|
|
|
|
probs1 = clamp_probs(probs1)
|
|
|
|
|
probs0 = 1.0 - probs1
|
|
|
|
|
return self.log(probs1) * value + self.log(probs0) * (1.0 - value)
|
|
|
|
|
|
|
|
|
@ -266,8 +269,10 @@ class Bernoulli(Distribution):
|
|
|
|
|
probs0 = self.broadcast((1.0 - probs1), broadcast_shape_tensor)
|
|
|
|
|
comp_zero = self.less(value, 0.0)
|
|
|
|
|
comp_one = self.less(value, 1.0)
|
|
|
|
|
zeros = self.fill(self.parameter_type, self.shape(broadcast_shape_tensor), 0.0)
|
|
|
|
|
ones = self.fill(self.parameter_type, self.shape(broadcast_shape_tensor), 1.0)
|
|
|
|
|
zeros = self.fill(self.parameter_type, self.shape(
|
|
|
|
|
broadcast_shape_tensor), 0.0)
|
|
|
|
|
ones = self.fill(self.parameter_type, self.shape(
|
|
|
|
|
broadcast_shape_tensor), 1.0)
|
|
|
|
|
less_than_zero = self.select(comp_zero, zeros, probs0)
|
|
|
|
|
return self.select(comp_one, less_than_zero, ones)
|
|
|
|
|
|
|
|
|
|