!6517 Improve the logic of constructors and utility functions in distribution

Merge pull request !6517 from XunDeng/add_parameter_v1
pull/6517/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 3eff68f8aa

@ -19,17 +19,18 @@ from .utils import *
from .custom_ops import *
__all__ = [
'convert_to_batch',
'cast_to_tensor',
'check_greater',
'check_greater_equal_zero',
'check_greater_zero',
'calc_broadcast_shape_from_param',
'check_scalar_from_param',
'check_prob',
'check_type',
'exp_generic',
'expm1_generic',
'log_generic',
'log1p_generic',
'broadcast_to',
'set_param_type',
'CheckTensor',
'CheckTuple',
]

@ -72,3 +72,12 @@ def log1p_generic(x):
Log1p ops on GPU device or when device_target == GPU.
"""
return log_generic(x + 1.0)
def broadcast_to(x, target):
"""
Broadcast x to the shape of target.
"""
shape = P.Shape()
if shape(x) == shape(target):
return x
return P.BroadcastTo(shape(target))(x)

@ -19,13 +19,10 @@ from mindspore._checkparam import Validator as validator
from mindspore.common.tensor import Tensor
from mindspore.common.parameter import Parameter
from mindspore.common import dtype as mstype
from mindspore.ops import _utils as utils
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore.ops.primitive import constexpr, PrimitiveWithInfer, prim_attr_register
import mindspore.nn as nn
import mindspore.nn.probability as msp
def cast_to_tensor(t, hint_type=mstype.float32):
"""
@ -46,41 +43,13 @@ def cast_to_tensor(t, hint_type=mstype.float32):
raise ValueError(f'Input cannot be None in cast_to_tensor')
if isinstance(t, Parameter):
return t
t_type = hint_type
if isinstance(t, Tensor):
# convert the type of tensor to dtype
return Tensor(t.asnumpy(), dtype=t_type)
if isinstance(t, (list, np.ndarray)):
return Tensor(t, dtype=t_type)
if isinstance(t, bool):
raise TypeError(f'Input cannot be Type Bool')
if isinstance(t, (int, float)):
return Tensor(t, dtype=t_type)
if isinstance(t, (Tensor, np.ndarray, list, int, float)):
return Tensor(t, dtype=hint_type)
invalid_type = type(t)
raise TypeError(
f"Unable to convert input of type {invalid_type} to a Tensor of type {t_type}")
def convert_to_batch(t, batch_shape, required_type):
"""
Convert a Tensor to a given batch shape.
Args:
t (int, float, list, numpy.ndarray, Tensor, Parameter): Tensor to be converted.
batch_shape (tuple): desired batch shape.
dtype (mindspore.dtype): desired dtype.
Raises:
RuntimeError: if the converison cannot be done.
Returns:
Tensor, with shape of batch_shape.
"""
if isinstance(t, Parameter):
return t
t = cast_to_tensor(t, required_type)
return Tensor(np.broadcast_to(t.asnumpy(), batch_shape), dtype=required_type)
f"Unable to convert input of type {invalid_type} to a Tensor of type {hint_type}")
def cast_type_for_device(dtype):
"""
@ -100,54 +69,6 @@ def cast_type_for_device(dtype):
return dtype
def check_scalar_from_param(params):
"""
Check if params are all scalars.
Args:
params (dict): parameters used to initialize distribution.
Notes: String parameters are excluded.
"""
for value in params.values():
if value is None:
continue
if isinstance(value, (msp.bijector.Bijector, msp.distribution.Distribution)):
return params['distribution'].is_scalar_batch
if isinstance(value, Parameter):
return False
if not isinstance(value, (int, float, str, type(params['dtype']))):
return False
return True
def calc_broadcast_shape_from_param(params):
"""
Calculate the broadcast shape from params.
Args:
params (dict): parameters used to initialize distribution.
Returns:
tuple.
"""
broadcast_shape = []
for value in params.values():
if isinstance(value, (msp.bijector.Bijector, msp.distribution.Distribution)):
return params['distribution'].broadcast_shape
if isinstance(value, (str, type(params['dtype']))):
continue
if value is None:
return None
if isinstance(value, Parameter):
value_t = value.data
else:
value_t = cast_to_tensor(value, mstype.float32)
broadcast_shape = utils.get_broadcast_shape(
broadcast_shape, list(value_t.shape), params['name'])
return tuple(broadcast_shape)
def check_greater_equal_zero(value, name):
"""
Check if the given Tensor is greater zero.
@ -371,6 +292,9 @@ def set_param_type(args, hint_type):
Raises:
TypeError: if tensors in args are not the same dtype.
"""
int_type = mstype.int_type + mstype.uint_type
if hint_type in int_type:
hint_type = mstype.float32
common_dtype = None
for name, arg in args.items():
if hasattr(arg, 'dtype'):
@ -382,7 +306,6 @@ def set_param_type(args, hint_type):
common_dtype = cur_dtype
elif cur_dtype != common_dtype:
raise TypeError(f"{name} should have the same dtype as other arguments.")
int_type = mstype.int_type + mstype.uint_type
if common_dtype in int_type or common_dtype == mstype.float64:
return mstype.float32
return hint_type if common_dtype is None else common_dtype

@ -17,7 +17,7 @@ from mindspore.common import dtype as mstype
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from .distribution import Distribution
from ._utils.utils import cast_to_tensor, check_prob, check_type, check_distribution_name, set_param_type
from ._utils.utils import check_prob, check_type, check_distribution_name
from ._utils.custom_ops import exp_generic, log_generic
@ -116,18 +116,14 @@ class Bernoulli(Distribution):
Constructor of Bernoulli.
"""
param = dict(locals())
param['param_dict'] = {'probs': probs}
valid_dtype = mstype.int_type + mstype.uint_type + mstype.float_type
check_type(dtype, valid_dtype, type(self).__name__)
super(Bernoulli, self).__init__(seed, dtype, name, param)
self.parameter_type = set_param_type({'probs1': probs}, mstype.float32)
if probs is not None:
self._probs = cast_to_tensor(probs, self.parameter_type)
check_prob(self.probs)
else:
self._probs = probs
self.default_parameters = [self.probs]
self.parameter_names = ['probs1']
self._probs = self._add_parameter(probs, 'probs')
if self._probs is not None:
check_prob(self.probs)
# ops needed for the class
self.exp = exp_generic
@ -135,14 +131,11 @@ class Bernoulli(Distribution):
self.squeeze = P.Squeeze(0)
self.cast = P.Cast()
self.const = P.ScalarToArray()
self.dtypeop = P.DType()
self.floor = P.Floor()
self.fill = P.Fill()
self.less = P.Less()
self.shape = P.Shape()
self.select = P.Select()
self.sq = P.Square()
self.sqrt = P.Sqrt()
self.uniform = C.uniform
def extend_repr(self):
@ -173,9 +166,8 @@ class Bernoulli(Distribution):
MODE(B) = 1 if probs1 > 0.5 else = 0
"""
probs1 = self._check_param_type(probs1)
prob_type = self.dtypeop(probs1)
zeros = self.fill(prob_type, self.shape(probs1), 0.0)
ones = self.fill(prob_type, self.shape(probs1), 1.0)
zeros = self.fill(self.dtype, self.shape(probs1), 0.0)
ones = self.fill(self.dtype, self.shape(probs1), 1.0)
comp = self.less(0.5, probs1)
return self.select(comp, ones, zeros)
@ -244,13 +236,13 @@ class Bernoulli(Distribution):
value = self.cast(value, self.parameter_type)
value = self.floor(value)
probs1 = self._check_param_type(probs1)
prob_type = self.dtypeop(probs1)
value = value * self.fill(prob_type, self.shape(probs1), 1.0)
probs0 = 1.0 - probs1 * self.fill(prob_type, self.shape(value), 1.0)
broadcast_shape_tensor = value * probs1
value = self.broadcast(value, broadcast_shape_tensor)
probs0 = self.broadcast((1.0 - probs1), broadcast_shape_tensor)
comp_zero = self.less(value, 0.0)
comp_one = self.less(value, 1.0)
zeros = self.fill(prob_type, self.shape(value), 0.0)
ones = self.fill(prob_type, self.shape(value), 1.0)
zeros = self.fill(self.parameter_type, self.shape(broadcast_shape_tensor), 0.0)
ones = self.fill(self.parameter_type, self.shape(broadcast_shape_tensor), 1.0)
less_than_zero = self.select(comp_zero, zeros, probs0)
return self.select(comp_one, less_than_zero, ones)

@ -14,13 +14,14 @@
# ============================================================================
"""basic"""
from mindspore import context
from mindspore.ops import operations as P
from mindspore.nn.cell import Cell
from mindspore._checkparam import Validator as validator
from mindspore._checkparam import Rel
from mindspore.common import get_seed
from ._utils.utils import calc_broadcast_shape_from_param, check_scalar_from_param, cast_type_for_device,\
raise_none_error
from ._utils.utils import raise_none_error, cast_to_tensor, set_param_type, cast_type_for_device
from ._utils.utils import CheckTuple, CheckTensor
from ._utils.custom_ops import broadcast_to, exp_generic, log_generic
class Distribution(Cell):
@ -68,14 +69,16 @@ class Distribution(Cell):
self._seed = seed
self._dtype = cast_type_for_device(dtype)
self._parameters = {}
# parsing parameters
for k in param.keys():
if not(k == 'self' or k.startswith('_')):
self._parameters[k] = param[k]
# some attributes
self._broadcast_shape = calc_broadcast_shape_from_param(
self.parameters)
self._is_scalar_batch = check_scalar_from_param(self.parameters)
self.parameter_type = set_param_type(self.parameters['param_dict'], dtype)
self._broadcast_shape = self._calc_broadcast_shape()
self._is_scalar_batch = self._check_is_scalar_batch()
# set the function to call according to the derived class's attributes
self._set_prob()
@ -91,6 +94,18 @@ class Distribution(Cell):
self.context_mode = context.get_context('mode')
self.checktuple = CheckTuple()
self.checktensor = CheckTensor()
self.broadcast = broadcast_to
# ops needed for the base class
self.cast_base = P.Cast()
self.dtype_base = P.DType()
self.exp_base = exp_generic
self.fill_base = P.Fill()
self.log_base = log_generic
self.sametypeshape_base = P.SameTypeShape()
self.sq_base = P.Square()
self.sqrt_base = P.Sqrt()
self.shape_base = P.Shape()
@property
def name(self):
@ -116,6 +131,21 @@ class Distribution(Cell):
def broadcast_shape(self):
return self._broadcast_shape
def _add_parameter(self, value, name):
"""
Cast `value` to a tensor and add it to `self.default_parameters`.
Add `name` into and `self.parameter_names`.
"""
# initialize the attributes if they do not exist yet
if not hasattr(self, 'default_parameters'):
self.default_parameters = []
self.parameter_names = []
# cast value to a tensor if it is not None
value_t = None if value is None else cast_to_tensor(value, self.parameter_type)
self.default_parameters += [value_t,]
self.parameter_names += [name,]
return value_t
def _check_param_type(self, *args):
"""
Check the availability and validity of default parameters and `dist_spec_args`.
@ -123,6 +153,7 @@ class Distribution(Cell):
are None, the parameters must be passed in through `args`.
"""
broadcast_shape = None
broadcast_shape_tensor = None
common_dtype = None
out = []
@ -139,17 +170,17 @@ class Distribution(Cell):
# broadcast if the number of args > 1
if broadcast_shape is None:
broadcast_shape = self.shape(arg)
common_dtype = self.dtypeop(arg)
broadcast_shape = self.shape_base(arg)
common_dtype = self.dtype_base(arg)
broadcast_shape_tensor = self.fill_base(common_dtype, broadcast_shape, 1.0)
else:
ones = self.fill(self.dtypeop(arg), broadcast_shape, 1.0)
broadcast_shape = self.shape(arg + ones)
broadcast_shape = self.shape_base(arg + broadcast_shape_tensor)
broadcast_shape_tensor = self.fill_base(common_dtype, broadcast_shape, 1.0)
arg = self.broadcast(arg, broadcast_shape_tensor)
# check if the arguments have the same dtype
arg = arg * self.fill(self.dtypeop(arg), broadcast_shape, 1.0)
dtype_tensor = self.fill(common_dtype, broadcast_shape, 1.0)
self.sametypeshape(arg, dtype_tensor)
arg = self.cast(arg, self.parameter_type)
self.sametypeshape_base(arg, broadcast_shape_tensor)
arg = self.cast_base(arg, self.parameter_type)
out.append(arg)
if len(out) == 1:
@ -158,7 +189,7 @@ class Distribution(Cell):
# broadcast all args to broadcast_shape
result = ()
for arg in out:
arg = arg * self.fill(self.dtypeop(arg), broadcast_shape, 1.0)
arg = self.broadcast(arg, broadcast_shape_tensor)
result = result + (arg,)
return result
@ -171,6 +202,38 @@ class Distribution(Cell):
return value
return self.checktensor(value, name)
def _check_is_scalar_batch(self):
"""
Check if the parameters used during initialization are scalars.
"""
if hasattr(self, 'distribution'):
return self._distribution.is_scalar_batch
param_dict = self.parameters['param_dict']
for value in param_dict.values():
if value is None:
continue
if not isinstance(value, (int, float)):
return False
return True
def _calc_broadcast_shape(self):
"""
Calculate the broadcast shape of the parameters used during initialization.
"""
if hasattr(self, 'distribution'):
return self._distribution.broadcast_shape
param_dict = self.parameters['param_dict']
broadcast_shape_tensor = None
for value in param_dict.values():
if value is None:
return None
if broadcast_shape_tensor is None:
broadcast_shape_tensor = cast_to_tensor(value)
else:
value = cast_to_tensor(value)
broadcast_shape_tensor = (value + broadcast_shape_tensor)
return broadcast_shape_tensor.shape
def _set_prob(self):
"""
Set probability funtion based on the availability of `_prob` and `_log_likehood`.
@ -280,7 +343,7 @@ class Distribution(Cell):
.. math::
probability(x) = \exp(log_likehood(x))
"""
return self.exp(self._log_prob(value, *args, **kwargs))
return self.exp_base(self._log_prob(value, *args, **kwargs))
def prob(self, value, *args, **kwargs):
"""
@ -304,7 +367,7 @@ class Distribution(Cell):
.. math::
log_prob(x) = \log(prob(x))
"""
return self.log(self._prob(value, *args, **kwargs))
return self.log_base(self._prob(value, *args, **kwargs))
def cdf(self, value, *args, **kwargs):
"""
@ -328,7 +391,7 @@ class Distribution(Cell):
.. math::
cdf(x) = \exp(log_cdf(x))
"""
return self.exp(self._log_cdf(value, *args, **kwargs))
return self.exp_base(self._log_cdf(value, *args, **kwargs))
def _calc_cdf_from_survival(self, value, *args, **kwargs):
r"""
@ -346,7 +409,7 @@ class Distribution(Cell):
.. math::
cdf(x) = 1 - (\exp(log_survival(x)))
"""
return 1.0 - self.exp(self._log_survival(value, *args, **kwargs))
return 1.0 - self.exp_base(self._log_survival(value, *args, **kwargs))
def log_cdf(self, value, *args, **kwargs):
"""
@ -370,7 +433,7 @@ class Distribution(Cell):
.. math::
log_cdf(x) = \log(cdf(x))
"""
return self.log(self._call_cdf(value, *args, **kwargs))
return self.log_base(self._call_cdf(value, *args, **kwargs))
def survival_function(self, value, *args, **kwargs):
"""
@ -403,7 +466,7 @@ class Distribution(Cell):
.. math::
survival(x) = \exp(survival_function(x))
"""
return self.exp(self._log_survival(value, *args, **kwargs))
return self.exp_base(self._log_survival(value, *args, **kwargs))
def log_survival(self, value, *args, **kwargs):
"""
@ -427,7 +490,7 @@ class Distribution(Cell):
.. math::
log_survival(x) = \log(survival_function(x))
"""
return self.log(self._call_survival(value, *args, **kwargs))
return self.log_base(self._call_survival(value, *args, **kwargs))
def kl_loss(self, dist, *args, **kwargs):
"""
@ -507,7 +570,7 @@ class Distribution(Cell):
.. math::
STD(x) = \sqrt(VAR(x))
"""
return self.sqrt(self._var(*args, **kwargs))
return self.sqrt_base(self._var(*args, **kwargs))
def _calc_var_from_sd(self, *args, **kwargs):
r"""
@ -516,7 +579,7 @@ class Distribution(Cell):
.. math::
VAR(x) = STD(x) ^ 2
"""
return self.sq(self._sd(*args, **kwargs))
return self.sq_base(self._sd(*args, **kwargs))
def entropy(self, *args, **kwargs):
"""

@ -18,7 +18,7 @@ from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.common import dtype as mstype
from .distribution import Distribution
from ._utils.utils import cast_to_tensor, check_greater_zero, check_type, check_distribution_name, set_param_type
from ._utils.utils import check_greater_zero, check_type, check_distribution_name
from ._utils.custom_ops import exp_generic, log_generic
@ -118,18 +118,14 @@ class Exponential(Distribution):
Constructor of Exponential.
"""
param = dict(locals())
param['param_dict'] = {'rate': rate}
valid_dtype = mstype.float_type
check_type(dtype, valid_dtype, type(self).__name__)
super(Exponential, self).__init__(seed, dtype, name, param)
self.parameter_type = set_param_type({'rate': rate}, self.dtype)
if rate is not None:
self._rate = cast_to_tensor(rate, self.parameter_type)
check_greater_zero(self._rate, "rate")
else:
self._rate = rate
self.default_parameters = [self.rate]
self.parameter_names = ['rate']
self._rate = self._add_parameter(rate, 'rate')
if self.rate is not None:
check_greater_zero(self.rate, 'rate')
self.minval = np.finfo(np.float).tiny
@ -144,8 +140,6 @@ class Exponential(Distribution):
self.less = P.Less()
self.select = P.Select()
self.shape = P.Shape()
self.sqrt = P.Sqrt()
self.sq = P.Square()
self.uniform = C.uniform
def extend_repr(self):

@ -18,8 +18,7 @@ from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.common import dtype as mstype
from .distribution import Distribution
from ._utils.utils import cast_to_tensor, check_prob, check_type, check_distribution_name,\
set_param_type
from ._utils.utils import check_prob, check_type, check_distribution_name
from ._utils.custom_ops import exp_generic, log_generic
@ -120,18 +119,14 @@ class Geometric(Distribution):
Constructor of Geometric distribution.
"""
param = dict(locals())
param['param_dict'] = {'probs': probs}
valid_dtype = mstype.int_type + mstype.uint_type + mstype.float_type
check_type(dtype, valid_dtype, type(self).__name__)
super(Geometric, self).__init__(seed, dtype, name, param)
self.parameter_type = set_param_type({'probs1': probs}, mstype.float32)
if probs is not None:
self._probs = cast_to_tensor(probs, self.parameter_type)
check_prob(self._probs)
else:
self._probs = probs
self.default_parameters = [self.probs]
self.parameter_names = ['probs1']
self._probs = self._add_parameter(probs, 'probs')
if self._probs is not None:
check_prob(self.probs)
self.minval = np.finfo(np.float).tiny
@ -150,7 +145,6 @@ class Geometric(Distribution):
self.select = P.Select()
self.shape = P.Shape()
self.sq = P.Square()
self.sqrt = P.Sqrt()
self.uniform = C.uniform
def extend_repr(self):
@ -181,7 +175,7 @@ class Geometric(Distribution):
MODE(Geo) = 0
"""
probs1 = self._check_param_type(probs1)
return self.fill(self.dtypeop(probs1), self.shape(probs1), 0.)
return self.fill(self.dtype, self.shape(probs1), 0.)
def _var(self, probs1=None):
r"""
@ -229,7 +223,7 @@ class Geometric(Distribution):
value = self.floor(value)
probs1 = self._check_param_type(probs1)
pmf = self.exp(self.log(1.0 - probs1) * value + self.log(probs1))
zeros = self.fill(self.dtypeop(probs1), self.shape(pmf), 0.0)
zeros = self.fill(self.dtypeop(pmf), self.shape(pmf), 0.0)
comp = self.less(value, zeros)
return self.select(comp, zeros, pmf)
@ -252,7 +246,7 @@ class Geometric(Distribution):
probs1 = self._check_param_type(probs1)
probs0 = 1.0 - probs1
cdf = 1.0 - self.pow(probs0, value + 1.0)
zeros = self.fill(self.dtypeop(probs1), self.shape(cdf), 0.0)
zeros = self.fill(self.dtypeop(cdf), self.shape(cdf), 0.0)
comp = self.less(value, zeros)
return self.select(comp, zeros, cdf)

@ -18,8 +18,7 @@ from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.common import dtype as mstype
from .distribution import Distribution
from ._utils.utils import cast_to_tensor, check_greater_zero, check_type, check_distribution_name,\
set_param_type
from ._utils.utils import check_greater_zero, check_type, check_distribution_name
from ._utils.custom_ops import exp_generic, expm1_generic, log_generic
@ -125,23 +124,15 @@ class Normal(Distribution):
Constructor of Normal.
"""
param = dict(locals())
param['param_dict'] = {'mean': mean, 'sd': sd}
valid_dtype = mstype.float_type
check_type(dtype, valid_dtype, type(self).__name__)
super(Normal, self).__init__(seed, dtype, name, param)
self.parameter_type = set_param_type(
{'mean': mean, 'sd': sd}, self.dtype)
if mean is not None and sd is not None:
self._mean_value = cast_to_tensor(mean, self.parameter_type)
self._sd_value = cast_to_tensor(sd, self.parameter_type)
check_greater_zero(self._sd_value, "Standard deviation")
else:
self._mean_value = mean if mean is None else cast_to_tensor(
mean, self.parameter_type)
self._sd_value = sd if sd is None else cast_to_tensor(
sd, self.parameter_type)
self.default_parameters = [self._mean_value, self._sd_value]
self.parameter_names = ['mean', 'sd']
self._mean_value = self._add_parameter(mean, 'mean')
self._sd_value = self._add_parameter(sd, 'sd')
if self._sd_value is not None:
check_greater_zero(self._sd_value, "Standard deviation")
# ops needed for the class
self.exp = exp_generic
@ -151,13 +142,9 @@ class Normal(Distribution):
self.squeeze = P.Squeeze(0)
self.cast = P.Cast()
self.const = P.ScalarToArray()
self.fill = P.Fill()
self.shape = P.Shape()
self.sq = P.Square()
self.sqrt = P.Sqrt()
self.zeroslike = P.ZerosLike()
self.dtypeop = P.DType()
self.sametypeshape = P.SameTypeShape()
def extend_repr(self):
if self.is_scalar_batch:

@ -81,6 +81,8 @@ class TransformedDistribution(Distribution):
self._bijector = bijector
self._distribution = distribution
self._is_linear_transformation = bijector.is_constant_jacobian
self.default_parameters = distribution.default_parameters
self.parameter_names = distribution.parameter_names
self.exp = exp_generic
self.log = log_generic

@ -17,8 +17,7 @@ from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.common import dtype as mstype
from .distribution import Distribution
from ._utils.utils import cast_to_tensor, check_greater, check_type, check_distribution_name,\
set_param_type
from ._utils.utils import check_greater, check_type, check_distribution_name
from ._utils.custom_ops import exp_generic, log_generic
@ -124,23 +123,16 @@ class Uniform(Distribution):
Constructor of Uniform distribution.
"""
param = dict(locals())
param['param_dict'] = {'low': low, 'high': high}
valid_dtype = mstype.float_type
check_type(dtype, valid_dtype, type(self).__name__)
super(Uniform, self).__init__(seed, dtype, name, param)
self.parameter_type = set_param_type(
{'low': low, 'high': high}, self.dtype)
if low is not None and high is not None:
self._low = cast_to_tensor(low, self.parameter_type)
self._high = cast_to_tensor(high, self.parameter_type)
check_greater(self.low, self.high, "low value", "high value")
else:
self._low = low if low is None else cast_to_tensor(
low, self.parameter_type)
self._high = high if high is None else cast_to_tensor(
high, self.parameter_type)
self.default_parameters = [self.low, self.high]
self.parameter_names = ['low', 'high']
self._low = self._add_parameter(low, 'low')
self._high = self._add_parameter(high, 'high')
if self.low is not None and self.high is not None:
check_greater(self.low, self.high, 'low', 'high')
# ops needed for the class
self.exp = exp_generic
@ -156,12 +148,9 @@ class Uniform(Distribution):
self.select = P.Select()
self.shape = P.Shape()
self.sq = P.Square()
self.sqrt = P.Sqrt()
self.zeroslike = P.ZerosLike()
self.uniform = C.uniform
self.sametypeshape = P.SameTypeShape()
def extend_repr(self):
if self.is_scalar_batch:
str_info = f'low = {self.low}, high = {self.high}'

Loading…
Cancel
Save