fix checktensor in pynative mode

commit 4fe8b3d395
parent b8da525fb1
Author: Xun Deng, 5 years ago
Branch: pull/5193/head
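In short, this commit makes tensor checking usable in PyNative mode: CheckTensor.__call__ used to return None, so the commit changes it to return the validated Tensor (or raise TypeError), adds a mode-aware _check_value helper to Bijector and Distribution, and rewrites call sites from self.checktensor(x, 'value') to x = self._check_value(x, 'value'). A minimal sketch of the helper's dispatch, assuming GRAPH_MODE == 0 as in mindspore.context:

    from mindspore import context

    def _check_value(self, value, name):
        # self.checktensor is the CheckTensor() primitive created in __init__.
        if self.context_mode == 0:
            # Graph mode: the primitive validates during shape/dtype inference,
            # so its call result is not usable here; pass the input through.
            self.checktensor(value, name)
            return value
        # PyNative mode: __call__ now returns the validated Tensor.
        return self.checktensor(value, name)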

@@ -13,8 +13,10 @@
 # limitations under the License.
 # ============================================================================
 """Bijector"""
+from mindspore import context
 from mindspore.nn.cell import Cell
 from mindspore._checkparam import Validator as validator
+from ..distribution._utils.utils import CheckTensor
 from ..distribution import Distribution
 from ..distribution import TransformedDistribution
@@ -40,7 +42,7 @@ class Bijector(Cell):
         Constructor of bijector class.
         """
         super(Bijector, self).__init__()
-        validator.check_value_type('name', name, [str], 'Bijector')
+        validator.check_value_type('name', name, [str], type(self).__name__)
         validator.check_value_type('is_constant_jacobian', is_constant_jacobian, [bool], name)
         validator.check_value_type('is_injective', is_injective, [bool], name)
         self._name = name
@@ -53,6 +55,9 @@ class Bijector(Cell):
         self._is_constant_jacobian = is_constant_jacobian
         self._is_injective = is_injective
+        self.context_mode = context.get_context('mode')
+        self.checktensor = CheckTensor()
 
     @property
     def name(self):
         return self._name
@@ -73,6 +78,15 @@ class Bijector(Cell):
     def is_injective(self):
         return self._is_injective
 
+    def _check_value(self, value, name):
+        """
+        Check availability fo value as a Tensor.
+        """
+        if self.context_mode == 0:
+            self.checktensor(value, name)
+            return value
+        return self.checktensor(value, name)
+
     def forward(self, *args, **kwargs):
         """
         Forward transformation: transform the input value to another distribution.

@@ -16,7 +16,6 @@
 from mindspore.ops import operations as P
 from mindspore._checkparam import Validator as validator
 from mindspore._checkparam import Rel
-from ..distribution._utils.utils import CheckTensor
 from ..distribution._utils.custom_ops import exp_generic, expm1_generic, log_generic, log1p_generic
 from .bijector import Bijector
@@ -66,8 +65,6 @@ class PowerTransform(Bijector):
         self.log = log_generic
         self.log1p = log1p_generic
 
-        self.checktensor = CheckTensor()
-
     @property
     def power(self):
         return self._power
@@ -80,13 +77,13 @@ class PowerTransform(Bijector):
         return shape
 
     def _forward(self, x):
-        self.checktensor(x, 'value')
+        x = self._check_value(x, 'value')
         if self.power == 0:
             return self.exp(x)
         return self.exp(self.log1p(x * self.power) / self.power)
 
     def _inverse(self, y):
-        self.checktensor(y, 'value')
+        y = self._check_value(y, 'value')
         if self.power == 0:
             return self.log(y)
         return self.expm1(self.log(y) * self.power) / self.power
@@ -103,7 +100,7 @@ class PowerTransform(Bijector):
             f'(x) = e^\frac{\log(xc + 1)}{c} * \frac{1}{xc + 1}
             \log(f'(x)) = (\frac{1}{c} - 1) * \log(xc + 1)
         """
-        self.checktensor(x, 'value')
+        x = self._check_value(x, 'value')
         if self.power == 0:
             return x
         return (1. / self.power - 1) * self.log1p(x * self.power)
@@ -120,5 +117,5 @@ class PowerTransform(Bijector):
             f'(x) = \frac{e^c\log(y)}{y}
             \log(f'(x)) = \log(\frac{e^c\log(y)}{y}) = (c-1) * \log(y)
         """
-        self.checktensor(y, 'value')
+        y = self._check_value(y, 'value')
         return (self.power - 1) * self.log(y)

@@ -15,7 +15,7 @@
 """Scalar Affine Bijector"""
 from mindspore.ops import operations as P
 from mindspore._checkparam import Validator as validator
-from ..distribution._utils.utils import cast_to_tensor, CheckTensor
+from ..distribution._utils.utils import cast_to_tensor
 from ..distribution._utils.custom_ops import log_generic
 from .bijector import Bijector
@@ -57,8 +57,8 @@ class ScalarAffine(Bijector):
         Constructor of scalar affine bijector.
         """
         param = dict(locals())
-        validator.check_value_type('scale', scale, [int, float], name)
-        validator.check_value_type('shift', shift, [int, float], name)
+        validator.check_value_type('scale', scale, [int, float], type(self).__name__)
+        validator.check_value_type('shift', shift, [int, float], type(self).__name__)
         self._scale = cast_to_tensor(scale)
         self._shift = cast_to_tensor(shift)
         super(ScalarAffine, self).__init__(
@@ -71,8 +71,6 @@ class ScalarAffine(Bijector):
         self.abs = P.Abs()
         self.log = log_generic
 
-        self.checktensor = CheckTensor()
-
     @property
     def scale(self):
         return self._scale
@@ -93,7 +91,7 @@ class ScalarAffine(Bijector):
         .. math::
             f(x) = a * x + b
         """
-        self.checktensor(x, 'value')
+        x = self._check_value(x, 'value')
         return self.scale * x + self.shift
 
     def _inverse(self, y):
@@ -101,7 +99,7 @@ class ScalarAffine(Bijector):
         .. math::
             f(y) = \frac{y - b}{a}
         """
-        self.checktensor(y, 'value')
+        y = self._check_value(y, 'value')
         return (y - self.shift) / self.scale
 
     def _forward_log_jacobian(self, x):
@@ -111,7 +109,7 @@ class ScalarAffine(Bijector):
             f'(x) = a
             \log(f'(x)) = \log(a)
         """
-        self.checktensor(x, 'value')
+        x = self._check_value(x, 'value')
         return self.log(self.abs(self.scale))
 
     def _inverse_log_jacobian(self, y):
@@ -121,5 +119,5 @@ class ScalarAffine(Bijector):
             f'(x) = \frac{1.0}{a}
             \log(f'(x)) = - \log(a)
         """
-        self.checktensor(y, 'value')
+        y = self._check_value(y, 'value')
        return -1. * self.log(self.abs(self.scale))

@@ -18,7 +18,7 @@ from mindspore.ops import operations as P
 from mindspore.common import dtype as mstype
 from mindspore.nn.layer.activation import LogSigmoid
 from mindspore._checkparam import Validator as validator
-from ..distribution._utils.utils import cast_to_tensor, CheckTensor
+from ..distribution._utils.utils import cast_to_tensor
 from ..distribution._utils.custom_ops import exp_generic, expm1_generic, log_generic
 from .bijector import Bijector
@@ -57,7 +57,7 @@ class Softplus(Bijector):
                 sharpness=1.0,
                 name='Softplus'):
         param = dict(locals())
-        validator.check_value_type('sharpness', sharpness, [int, float], name)
+        validator.check_value_type('sharpness', sharpness, [int, float], type(self).__name__)
         super(Softplus, self).__init__(name=name, param=param)
         self._sharpness = cast_to_tensor(sharpness)
@@ -76,7 +76,6 @@ class Softplus(Bijector):
         self.softplus = self._softplus
         self.inverse_softplus = self._inverse_softplus
 
-        self.checktensor = CheckTensor()
         self.threshold = np.log(np.finfo(np.float32).eps) + 1
         self.tiny = np.exp(self.threshold)
@@ -119,7 +118,7 @@ class Softplus(Bijector):
         return shape
 
     def _forward(self, x):
-        self.checktensor(x, 'value')
+        x = self._check_value(x, 'value')
         scaled_value = self.sharpness * x
         return self.softplus(scaled_value) / self.sharpness
@@ -129,7 +128,7 @@ class Softplus(Bijector):
             f(x) = \frac{\log(1 + e^{kx}))}{k}
            f^{-1}(y) = \frac{\log(e^{ky} - 1)}{k}
         """
-        self.checktensor(y, 'value')
+        y = self._check_value(y, 'value')
         scaled_value = self.sharpness * y
         return self.inverse_softplus(scaled_value) / self.sharpness
@@ -140,7 +139,7 @@ class Softplus(Bijector):
             f'(x) = \frac{e^{kx}}{ 1 + e^{kx}}
             \log(f'(x)) = kx - \log(1 + e^{kx}) = kx - f(kx)
         """
-        self.checktensor(x, 'value')
+        x = self._check_value(x, 'value')
         scaled_value = self.sharpness * x
         return self.log_sigmoid(scaled_value)
@@ -151,6 +150,6 @@ class Softplus(Bijector):
             f'(y) = \frac{e^{ky}}{e^{ky} - 1}
             \log(f'(y)) = ky - \log(e^{ky} - 1) = ky - f(ky)
         """
-        self.checktensor(y, 'value')
+        y = self._check_value(y, 'value')
         scaled_value = self.sharpness * y
         return scaled_value - self.inverse_softplus(scaled_value)

@@ -342,7 +342,7 @@ class CheckTuple(PrimitiveWithInfer):
         # Pynative mode
         if isinstance(x, tuple):
             return x
-        raise TypeError(f"For {name['value']}, Input type should b a tuple.")
+        raise TypeError(f"For {name}, input type should be a tuple.")
 
 
 class CheckTensor(PrimitiveWithInfer):
@@ -365,4 +365,6 @@ class CheckTensor(PrimitiveWithInfer):
         return out
 
     def __call__(self, x, name):
-        return
+        if isinstance(x, Tensor):
+            return x
+        raise TypeError(f"For {name}, input type should be a Tensor.")
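For reference, a small usage sketch (illustration only, not part of the diff) of the fixed behavior in PyNative mode; the absolute import path is inferred from the relative imports above and is an assumption:

    import numpy as np
    from mindspore import context, Tensor
    from mindspore.nn.probability.distribution._utils.utils import CheckTensor

    context.set_context(mode=context.PYNATIVE_MODE)

    check = CheckTensor()
    x = Tensor(np.array([0.5], dtype=np.float32))
    out = check(x, 'value')   # previously returned None; now returns x itself
    assert out is x

    check([0.5], 'value')     # non-Tensor input now raises:
                              # TypeError: For value, input type should be a Tensor.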

@@ -99,7 +99,7 @@ class Bernoulli(Distribution):
         """
         param = dict(locals())
         valid_dtype = mstype.int_type + mstype.uint_type + mstype.float_type
-        check_type(dtype, valid_dtype, "Bernoulli")
+        check_type(dtype, valid_dtype, type(self).__name__)
         super(Bernoulli, self).__init__(seed, dtype, name, param)
         self.parameter_type = mstype.float32
         if probs is not None:
@@ -144,7 +144,10 @@ class Bernoulli(Distribution):
         Check availablity of distribution specific args probs1.
         """
         if probs1 is not None:
-            self.checktensor(probs1, 'probs1')
+            if self.context_mode == 0:
+                self.checktensor(probs1, 'probs1')
+            else:
+                probs1 = self.checktensor(probs1, 'probs1')
             return self.cast(probs1, self.parameter_type)
         return self.probs if self.probs is not None else raise_none_error('probs1')
@@ -210,7 +213,7 @@ class Bernoulli(Distribution):
             pmf(k) = probs1 if k = 1;
             pmf(k) = probs0 if k = 0;
         """
-        self.checktensor(value, 'value')
+        value = self._check_value(value, 'value')
         value = self.cast(value, mstype.float32)
         probs1 = self._check_param(probs1)
         probs0 = 1.0 - probs1
@@ -229,7 +232,7 @@ class Bernoulli(Distribution):
             cdf(k) = probs0 if 0 <= k <1;
             cdf(k) = 1 if k >=1;
         """
-        self.checktensor(value, 'value')
+        value = self._check_value(value, 'value')
         value = self.cast(value, mstype.float32)
         value = self.floor(value)
         probs1 = self._check_param(probs1)
@@ -257,7 +260,7 @@ class Bernoulli(Distribution):
             probs0_a * \log(\frac{probs0_a}{probs0_b})
         """
         check_distribution_name(dist, 'Bernoulli')
-        self.checktensor(probs1_b, 'probs1_b')
+        probs1_b = self._check_value(probs1_b, 'probs1_b')
         probs1_b = self.cast(probs1_b, self.parameter_type)
         probs1_a = self._check_param(probs1)
         probs0_a = 1.0 - probs1_a

@@ -13,6 +13,7 @@
 # limitations under the License.
 # ============================================================================
 """basic"""
+from mindspore import context
 from mindspore.nn.cell import Cell
 from mindspore._checkparam import Validator as validator
 from mindspore._checkparam import Rel
@@ -54,7 +55,7 @@ class Distribution(Cell):
         Constructor of distribution class.
         """
         super(Distribution, self).__init__()
-        validator.check_value_type('name', name, [str], 'distribution_name')
+        validator.check_value_type('name', name, [str], type(self).__name__)
         validator.check_integer('seed', seed, 0, Rel.GE, name)
         self._name = name
@@ -81,6 +82,7 @@ class Distribution(Cell):
         self._set_log_survival()
         self._set_cross_entropy()
 
+        self.context_mode = context.get_context('mode')
         self.checktuple = CheckTuple()
         self.checktensor = CheckTensor()
@@ -108,6 +110,15 @@ class Distribution(Cell):
     def broadcast_shape(self):
         return self._broadcast_shape
 
+    def _check_value(self, value, name):
+        """
+        Check availability fo value as a Tensor.
+        """
+        if self.context_mode == 0:
+            self.checktensor(value, name)
+            return value
+        return self.checktensor(value, name)
+
     def _set_prob(self):
         """
         Set probability funtion based on the availability of _prob and _log_likehood.

@@ -100,7 +100,7 @@ class Exponential(Distribution):
         """
         param = dict(locals())
         valid_dtype = mstype.float_type
-        check_type(dtype, valid_dtype, "Exponential")
+        check_type(dtype, valid_dtype, type(self).__name__)
         super(Exponential, self).__init__(seed, dtype, name, param)
         self.parameter_type = dtype
         if rate is not None:
@@ -146,7 +146,10 @@ class Exponential(Distribution):
         Check availablity of distribution specific args rate.
         """
         if rate is not None:
-            self.checktensor(rate, 'rate')
+            if self.context_mode == 0:
+                self.checktensor(rate, 'rate')
+            else:
+                rate = self.checktensor(rate, 'rate')
             return self.cast(rate, self.parameter_type)
         return self.rate if self.rate is not None else raise_none_error('rate')
@@ -210,7 +213,7 @@ class Exponential(Distribution):
         .. math::
             pdf(x) = rate * \exp(-1 * \lambda * x) if x >= 0 else 0
         """
-        self.checktensor(value, "value")
+        value = self._check_value(value, "value")
         value = self.cast(value, self.dtype)
         rate = self._check_param(rate)
         prob = self.exp(self.log(rate) - rate * value)
@@ -232,7 +235,7 @@
         .. math::
             cdf(x) = 1.0 - \exp(-1 * \lambda * x) if x >= 0 else 0
         """
-        self.checktensor(value, 'value')
+        value = self._check_value(value, 'value')
         value = self.cast(value, self.dtype)
         rate = self._check_param(rate)
         cdf = 1.0 - self.exp(-1. * rate * value)
@@ -251,7 +254,7 @@
             rate_a (Tensor): rate of distribution a. Default: self.rate.
         """
         check_distribution_name(dist, 'Exponential')
-        self.checktensor(rate_b, 'rate_b')
+        rate_b = self._check_value(rate_b, 'rate_b')
         rate_b = self.cast(rate_b, self.parameter_type)
         rate_a = self._check_param(rate)
         return self.log(rate_a) - self.log(rate_b) + rate_b / rate_a - 1.0

@@ -102,7 +102,7 @@ class Geometric(Distribution):
         """
         param = dict(locals())
         valid_dtype = mstype.int_type + mstype.uint_type + mstype.float_type
-        check_type(dtype, valid_dtype, "Geometric")
+        check_type(dtype, valid_dtype, type(self).__name__)
         super(Geometric, self).__init__(seed, dtype, name, param)
         self.parameter_type = mstype.float32
         if probs is not None:
@@ -150,7 +150,10 @@ class Geometric(Distribution):
         Check availablity of distribution specific args probs1.
         """
         if probs1 is not None:
-            self.checktensor(probs1, 'probs1')
+            if self.context_mode == 0:
+                self.checktensor(probs1, 'probs1')
+            else:
+                probs1 = self.checktensor(probs1, 'probs1')
             return self.cast(probs1, self.parameter_type)
         return self.probs if self.probs is not None else raise_none_error('probs1')
@@ -211,7 +214,7 @@ class Geometric(Distribution):
             pmf(k) = probs0 ^k * probs1 if k >= 0;
             pmf(k) = 0 if k < 0.
         """
-        self.checktensor(value, 'value')
+        value = self._check_value(value, 'value')
         value = self.cast(value, mstype.float32)
         value = self.floor(value)
         probs1 = self._check_param(probs1)
@@ -233,7 +236,7 @@
             cdf(k) = 0 if k < 0.
         """
-        self.checktensor(value, 'value')
+        value = self._check_value(value, 'value')
         value = self.cast(value, mstype.float32)
         value = self.floor(value)
         probs1 = self._check_param(probs1)
@@ -256,7 +259,7 @@
             KL(a||b) = \log(\frac{probs1_a}{probs1_b}) + \frac{probs0_a}{probs1_a} * \log(\frac{probs0_a}{probs0_b})
         """
         check_distribution_name(dist, 'Geometric')
-        self.checktensor(probs1_b, 'probs1_b')
+        probs1_b = self._check_value(probs1_b, 'probs1_b')
         probs1_b = self.cast(probs1_b, self.parameter_type)
         probs1_a = self._check_param(probs1)
         probs0_a = 1.0 - probs1_a

@@ -18,7 +18,7 @@ from mindspore.ops import operations as P
 from mindspore.ops import composite as C
 from mindspore.common import dtype as mstype
 from .distribution import Distribution
-from ._utils.utils import convert_to_batch, check_greater_zero, check_type, check_distribution_name,\
+from ._utils.utils import cast_to_tensor, check_greater_zero, check_type, check_distribution_name,\
     raise_none_error
 from ._utils.custom_ops import exp_generic, expm1_generic, log_generic, erf_generic
@@ -102,12 +102,12 @@ class Normal(Distribution):
         """
         param = dict(locals())
         valid_dtype = mstype.float_type
-        check_type(dtype, valid_dtype, "Normal")
+        check_type(dtype, valid_dtype, type(self).__name__)
         super(Normal, self).__init__(seed, dtype, name, param)
         self.parameter_type = dtype
         if mean is not None and sd is not None:
-            self._mean_value = convert_to_batch(mean, self.broadcast_shape, self.parameter_type)
-            self._sd_value = convert_to_batch(sd, self.broadcast_shape, self.parameter_type)
+            self._mean_value = cast_to_tensor(mean, self.parameter_type)
+            self._sd_value = cast_to_tensor(sd, self.parameter_type)
             check_greater_zero(self._sd_value, "Standard deviation")
         else:
             self._mean_value = mean
@@ -139,12 +139,18 @@ class Normal(Distribution):
         Check availablity of distribution specific args mean and sd.
         """
         if mean is not None:
-            self.checktensor(mean, 'mean')
+            if self.context_mode == 0:
+                self.checktensor(mean, 'mean')
+            else:
+                mean = self.checktensor(mean, 'mean')
             mean = self.cast(mean, self.parameter_type)
         else:
             mean = self._mean_value if self._mean_value is not None else raise_none_error('mean')
         if sd is not None:
-            self.checktensor(sd, 'sd')
+            if self.context_mode == 0:
+                self.checktensor(sd, 'sd')
+            else:
+                sd = self.checktensor(sd, 'sd')
             sd = self.cast(sd, self.parameter_type)
         else:
             sd = self._sd_value if self._sd_value is not None else raise_none_error('sd')
@@ -210,7 +216,7 @@ class Normal(Distribution):
         .. math::
             L(x) = -1* \frac{(x - \mu)^2}{2. * \sigma^2} - \log(\sqrt(2* \pi * \sigma^2))
         """
-        self.checktensor(value, 'value')
+        value = self._check_value(value, 'value')
         value = self.cast(value, self.dtype)
         mean, sd = self._check_param(mean, sd)
         unnormalized_log_prob = -1. * (self.sq(value - mean)) / (2. * self.sq(sd))
@@ -229,7 +235,7 @@ class Normal(Distribution):
         .. math::
             cdf(x) = 0.5 * (1+ Erf((x - \mu) / ( \sigma * \sqrt(2))))
         """
-        self.checktensor(value, 'value')
+        value = self._check_value(value, 'value')
         value = self.cast(value, self.dtype)
         mean, sd = self._check_param(mean, sd)
         sqrt2 = self.sqrt(self.const(2.0))
@@ -252,8 +258,8 @@
             0.5 * EXPM1(2 * (\log(STD(a)) - \log(STD(b))) - (\log(STD(a)) - \log(STD(b)))
         """
         check_distribution_name(dist, 'Normal')
-        self.checktensor(mean_b, 'mean_b')
-        self.checktensor(sd_b, 'sd_b')
+        mean_b = self._check_value(mean_b, 'mean_b')
+        sd_b = self._check_value(sd_b, 'sd_b')
         mean_b = self.cast(mean_b, self.parameter_type)
         sd_b = self.cast(sd_b, self.parameter_type)
         mean_a, sd_a = self._check_param(mean, sd)

@@ -46,10 +46,10 @@ class TransformedDistribution(Distribution):
         Constructor of transformed_distribution class.
         """
         param = dict(locals())
-        validator.check_value_type('bijector', bijector, [nn.probability.bijector.Bijector], name)
-        validator.check_value_type('distribution', distribution, [Distribution], name)
+        validator.check_value_type('bijector', bijector, [nn.probability.bijector.Bijector], type(self).__name__)
+        validator.check_value_type('distribution', distribution, [Distribution], type(self).__name__)
         valid_dtype = mstype.number_type
-        check_type(dtype, valid_dtype, "transformed_distribution")
+        check_type(dtype, valid_dtype, type(self).__name__)
         super(TransformedDistribution, self).__init__(seed, dtype, name, param)
         self._bijector = bijector

@@ -17,7 +17,7 @@ from mindspore.ops import operations as P
 from mindspore.ops import composite as C
 from mindspore.common import dtype as mstype
 from .distribution import Distribution
-from ._utils.utils import convert_to_batch, check_greater, check_type, check_distribution_name,\
+from ._utils.utils import cast_to_tensor, check_greater, check_type, check_distribution_name,\
     raise_none_error
 from ._utils.custom_ops import exp_generic, log_generic
@@ -101,12 +101,12 @@ class Uniform(Distribution):
         """
         param = dict(locals())
         valid_dtype = mstype.float_type
-        check_type(dtype, valid_dtype, "Uniform")
+        check_type(dtype, valid_dtype, type(self).__name__)
         super(Uniform, self).__init__(seed, dtype, name, param)
         self.parameter_type = dtype
         if low is not None and high is not None:
-            self._low = convert_to_batch(low, self.broadcast_shape, dtype)
-            self._high = convert_to_batch(high, self.broadcast_shape, dtype)
+            self._low = cast_to_tensor(low, dtype)
+            self._high = cast_to_tensor(high, dtype)
             check_greater(self.low, self.high, "low value", "high value")
         else:
             self._low = low
@@ -142,12 +142,18 @@ class Uniform(Distribution):
         Check availablity of distribution specific args low and high.
         """
         if low is not None:
-            self.checktensor(low, 'low')
+            if self.context_mode == 0:
+                self.checktensor(low, 'low')
+            else:
+                low = self.checktensor(low, 'low')
             low = self.cast(low, self.parameter_type)
         else:
             low = self.low if self.low is not None else raise_none_error('low')
         if high is not None:
-            self.checktensor(high, 'high')
+            if self.context_mode == 0:
+                self.checktensor(high, 'high')
+            else:
+                high = self.checktensor(high, 'high')
             high = self.cast(high, self.parameter_type)
         else:
             high = self.high if self.high is not None else raise_none_error('high')
@@ -231,7 +237,7 @@ class Uniform(Distribution):
            pdf(x) = \frac{1.0}{high -low} if low <= x <= high;
            pdf(x) = 0 if x > high;
         """
-        self.checktensor(value, 'value')
+        value = self._check_value(value, 'value')
         value = self.cast(value, self.dtype)
         low, high = self._check_param(low, high)
         neg_ones = self.fill(self.dtype, self.shape(value), -1.0)
@@ -255,9 +261,9 @@ class Uniform(Distribution):
            high_a (Tensor): upper bound of distribution a. Default: self.high.
         """
         check_distribution_name(dist, 'Uniform')
-        self.checktensor(low_b, 'low_b')
+        low_b = self._check_value(low_b, 'low_b')
         low_b = self.cast(low_b, self.parameter_type)
-        self.checktensor(high_b, 'high_b')
+        high_b = self._check_value(high_b, 'high_b')
         high_b = self.cast(high_b, self.parameter_type)
         low_a, high_a = self._check_param(low, high)
         kl = self.log(high_b - low_b) - self.log(high_a - low_a)
@@ -278,7 +284,7 @@ class Uniform(Distribution):
            cdf(x) = \frac{x - low}{high -low} if low <= x <= high;
            cdf(x) = 1 if x > high;
         """
-        self.checktensor(value, 'value')
+        value = self._check_value(value, 'value')
         value = self.cast(value, self.dtype)
         low, high = self._check_param(low, high)
         prob = (value - low) / (high - low)
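Taken together, these changes let PyNative users pass Tensors to distribution and bijector methods without the value check swallowing the input. A hedged end-to-end sketch (the Normal constructor follows the diff above; log_prob is assumed from the Distribution base class):

    import numpy as np
    from mindspore import context, Tensor
    from mindspore.common import dtype as mstype
    import mindspore.nn.probability.distribution as msd

    context.set_context(mode=context.PYNATIVE_MODE)

    # mean and sd are now converted with cast_to_tensor instead of convert_to_batch.
    n = msd.Normal(0.0, 1.0, dtype=mstype.float32)

    value = Tensor(np.array([0.5, 1.5], dtype=np.float32))
    # log_prob routes its input through _check_value; in PyNative mode the value
    # returned by CheckTensor.__call__ is used instead of being dropped.
    print(n.log_prob(value))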
