fix _check_values_dtype_in_probs method in Distribution class (#27046)

numel
pangyoki 5 years ago committed by GitHub
parent b150f2b3a6
commit f0b26313d8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -138,7 +138,7 @@ class Distribution(object):
convert value's dtype to be consistent with param's dtype.
Args:
param (int|float|list|numpy.ndarray|Tensor): low and high in Uniform class, loc and scale in Normal class.
param (Tensor): low and high in Uniform class, loc and scale in Normal class.
value (Tensor): The input tensor.
Returns:
@ -152,6 +152,7 @@ class Distribution(object):
)
return core.ops.cast(value, 'in_dtype', value.dtype,
'out_dtype', param.dtype)
return value
check_variable_and_dtype(value, 'value', ['float32', 'float64'],
'log_prob')

Loading…
Cancel
Save