fix bincount

pull/14668/head^2
huangmengxi 4 years ago
parent 22e10059b8
commit 2dfb0db4bd

@ -4636,9 +4636,7 @@ def bincount(x, weights=None, minlength=0, length=None):
if F.shape(x) != F.shape(weights):
_raise_value_error('`x` and `weights` must have the same length')
idx_mapping *= weights
if _get_device() == 'Ascend':
idx_mapping = F.cast(idx_mapping, mstype.float32)
return F.reduce_sum(idx_mapping, 1).ravel().astype(mstype.int32)
return F.reduce_sum(idx_mapping.astype(mstype.float32), 1).ravel()
def histogram(a, bins=10, range=None, weights=None, density=False): # pylint: disable=redefined-builtin

Loading…
Cancel
Save