Remove the redundant parameters from bijectors and transformed distribution

pull/6862/head
peixu_ren 4 years ago
parent 70221f5261
commit c7563d53bf

@ -49,5 +49,4 @@ class Exp(PowerTransform):
def __init__(self,
name='Exp'):
param = dict(locals())
super(Exp, self).__init__(name=name, param=param)
super(Exp, self).__init__(name=name)

@ -39,9 +39,6 @@ class PowerTransform(Bijector):
Args:
power (int or float): The scale factor. Default: 0.
name (str): The name of the bijector. Default: 'PowerTransform'.
param (dict): The parameters used to initialize the bijector. These parameters are only used when other
Bijectors inherit from powertransform to pass in parameters. In this case the derived Bijector may overwrite
the argument `param`. Default: None.
Examples:
>>> # To initialize a PowerTransform bijector of power 0.5.
@ -65,9 +62,8 @@ class PowerTransform(Bijector):
def __init__(self,
power=0,
name='PowerTransform',
param=None):
param = dict(locals()) if param is None else param
name='PowerTransform'):
param = dict(locals())
super(PowerTransform, self).__init__(name=name, param=param)
validator.check_value_type('power', power, [int, float], self.name)
validator.check_number("power", power, 0, Rel.GE, self.name)

@ -135,7 +135,7 @@ class LogNormal(msd.TransformedDistribution):
"""
super(LogNormal, self).__init__(distribution=msd.Normal(loc, scale, dtype=dtype),
bijector=msb.Exp(),
dtype=dtype, seed=seed, name=name)
seed=seed, name=name)
self.log_2pi = np.log(2 * np.pi)

@ -14,10 +14,9 @@
# ============================================================================
"""Transformed Distribution"""
from mindspore._checkparam import Validator as validator
from mindspore.common import dtype as mstype
import mindspore.nn as nn
from .distribution import Distribution
from ._utils.utils import check_type, raise_not_impl_error
from ._utils.utils import raise_not_impl_error
from ._utils.custom_ops import exp_generic, log_generic
@ -30,7 +29,6 @@ class TransformedDistribution(Distribution):
Args:
bijector (Bijector): The transformation to perform.
distribution (Distribution): The original distribution.
dtype (mindspore.dtype): The type of the event samples.
seed (int): The seed is used in sampling. The global seed is used if it is None.
name (str): The name of the transformed distribution. Default: 'transformed_distribution'.
@ -45,16 +43,14 @@ class TransformedDistribution(Distribution):
>>> import mindspore.nn.probability.distribution as msd
>>> import mindspore.nn.probability.bijector as msb
>>> ln = msd.TransformedDistribution(msb.Exp(),
>>> msd.Normal(0.0, 1.0, dtype=mstype.float32),
>>> dtype=mstype.float32)
>>> msd.Normal(0.0, 1.0, dtype=mstype.float32))
>>>
>>> # To use a transformed distribution in a network.
>>> class net(Cell):
>>> def __init__(self):
>>> super(net, self).__init__():
>>> self.ln = msd.TransformedDistribution(msb.Exp(),
>>> msd.Normal(0.0, 1.0, dtype=mstype.float32),
>>> dtype=mstype.float32)
>>> msd.Normal(0.0, 1.0, dtype=mstype.float32))
>>>
>>> def construct(self, value):
>>> # Similar calls can be made to other functions
@ -65,7 +61,6 @@ class TransformedDistribution(Distribution):
def __init__(self,
bijector,
distribution,
dtype,
seed=None,
name="transformed_distribution"):
"""
@ -76,9 +71,7 @@ class TransformedDistribution(Distribution):
[nn.probability.bijector.Bijector], type(self).__name__)
validator.check_value_type('distribution', distribution,
[Distribution], type(self).__name__)
valid_dtype = mstype.number_type
check_type(dtype, valid_dtype, type(self).__name__)
super(TransformedDistribution, self).__init__(seed, dtype, name, param)
super(TransformedDistribution, self).__init__(seed, distribution.dtype, name, param)
self._bijector = bijector
self._distribution = distribution
@ -96,6 +89,10 @@ class TransformedDistribution(Distribution):
def distribution(self):
return self._distribution
@property
def dtype(self):
return self.distribution.dtype
@property
def is_linear_transformation(self):
return self._is_linear_transformation

Loading…
Cancel
Save