|
|
|
@ -138,7 +138,7 @@ class Distribution(object):
|
|
|
|
convert value's dtype to be consistent with param's dtype.
|
|
|
|
convert value's dtype to be consistent with param's dtype.
|
|
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
Args:
|
|
|
|
param (int|float|list|numpy.ndarray|Tensor): low and high in Uniform class, loc and scale in Normal class.
|
|
|
|
param (Tensor): low and high in Uniform class, loc and scale in Normal class.
|
|
|
|
value (Tensor): The input tensor.
|
|
|
|
value (Tensor): The input tensor.
|
|
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
Returns:
|
|
|
|
@ -152,6 +152,7 @@ class Distribution(object):
|
|
|
|
)
|
|
|
|
)
|
|
|
|
return core.ops.cast(value, 'in_dtype', value.dtype,
|
|
|
|
return core.ops.cast(value, 'in_dtype', value.dtype,
|
|
|
|
'out_dtype', param.dtype)
|
|
|
|
'out_dtype', param.dtype)
|
|
|
|
|
|
|
|
return value
|
|
|
|
|
|
|
|
|
|
|
|
check_variable_and_dtype(value, 'value', ['float32', 'float64'],
|
|
|
|
check_variable_and_dtype(value, 'value', ['float32', 'float64'],
|
|
|
|
'log_prob')
|
|
|
|
'log_prob')
|
|
|
|
|